From 8f6639028f980bbb96c21e2025f37996ea74e411 Mon Sep 17 00:00:00 2001 From: JonnyWong16 Date: Sat, 23 Nov 2019 19:03:04 -0800 Subject: [PATCH] Add cheroot-8.2.1 --- lib/cheroot/__init__.py | 15 + lib/cheroot/__main__.py | 6 + lib/cheroot/_compat.py | 110 ++ lib/cheroot/cli.py | 234 ++++ lib/cheroot/connections.py | 279 ++++ lib/cheroot/errors.py | 58 + lib/cheroot/makefile.py | 447 ++++++ lib/cheroot/server.py | 2114 +++++++++++++++++++++++++++++ lib/cheroot/ssl/__init__.py | 52 + lib/cheroot/ssl/builtin.py | 210 +++ lib/cheroot/ssl/pyopenssl.py | 343 +++++ lib/cheroot/test/__init__.py | 1 + lib/cheroot/test/conftest.py | 69 + lib/cheroot/test/helper.py | 168 +++ lib/cheroot/test/test__compat.py | 62 + lib/cheroot/test/test_conn.py | 980 +++++++++++++ lib/cheroot/test/test_core.py | 415 ++++++ lib/cheroot/test/test_dispatch.py | 55 + lib/cheroot/test/test_errors.py | 30 + lib/cheroot/test/test_makefile.py | 52 + lib/cheroot/test/test_server.py | 235 ++++ lib/cheroot/test/test_ssl.py | 474 +++++++ lib/cheroot/test/webtest.py | 605 +++++++++ lib/cheroot/testing.py | 153 +++ lib/cheroot/workers/__init__.py | 1 + lib/cheroot/workers/threadpool.py | 323 +++++ lib/cheroot/wsgi.py | 434 ++++++ 27 files changed, 7925 insertions(+) create mode 100644 lib/cheroot/__init__.py create mode 100644 lib/cheroot/__main__.py create mode 100644 lib/cheroot/_compat.py create mode 100644 lib/cheroot/cli.py create mode 100644 lib/cheroot/connections.py create mode 100644 lib/cheroot/errors.py create mode 100644 lib/cheroot/makefile.py create mode 100644 lib/cheroot/server.py create mode 100644 lib/cheroot/ssl/__init__.py create mode 100644 lib/cheroot/ssl/builtin.py create mode 100644 lib/cheroot/ssl/pyopenssl.py create mode 100644 lib/cheroot/test/__init__.py create mode 100644 lib/cheroot/test/conftest.py create mode 100644 lib/cheroot/test/helper.py create mode 100644 lib/cheroot/test/test__compat.py create mode 100644 lib/cheroot/test/test_conn.py create mode 100644 
lib/cheroot/test/test_core.py create mode 100644 lib/cheroot/test/test_dispatch.py create mode 100644 lib/cheroot/test/test_errors.py create mode 100644 lib/cheroot/test/test_makefile.py create mode 100644 lib/cheroot/test/test_server.py create mode 100644 lib/cheroot/test/test_ssl.py create mode 100644 lib/cheroot/test/webtest.py create mode 100644 lib/cheroot/testing.py create mode 100644 lib/cheroot/workers/__init__.py create mode 100644 lib/cheroot/workers/threadpool.py create mode 100644 lib/cheroot/wsgi.py diff --git a/lib/cheroot/__init__.py b/lib/cheroot/__init__.py new file mode 100644 index 00000000..30d38cab --- /dev/null +++ b/lib/cheroot/__init__.py @@ -0,0 +1,15 @@ +"""High-performance, pure-Python HTTP server used by CherryPy.""" + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +try: + import pkg_resources +except ImportError: + pass + + +try: + __version__ = pkg_resources.get_distribution('cheroot').version +except Exception: + __version__ = 'unknown' diff --git a/lib/cheroot/__main__.py b/lib/cheroot/__main__.py new file mode 100644 index 00000000..d2e27c10 --- /dev/null +++ b/lib/cheroot/__main__.py @@ -0,0 +1,6 @@ +"""Stub for accessing the Cheroot CLI tool.""" + +from .cli import main + +if __name__ == '__main__': + main() diff --git a/lib/cheroot/_compat.py b/lib/cheroot/_compat.py new file mode 100644 index 00000000..79899b9d --- /dev/null +++ b/lib/cheroot/_compat.py @@ -0,0 +1,110 @@ +"""Compatibility code for using Cheroot with various versions of Python.""" + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +import platform +import re + +import six + +try: + import ssl + IS_ABOVE_OPENSSL10 = ssl.OPENSSL_VERSION_INFO >= (1, 1) + del ssl +except ImportError: + IS_ABOVE_OPENSSL10 = None + + +IS_PYPY = platform.python_implementation() == 'PyPy' + + +SYS_PLATFORM = platform.system() +IS_WINDOWS = SYS_PLATFORM == 'Windows' +IS_LINUX = SYS_PLATFORM == 'Linux' 
+IS_MACOS = SYS_PLATFORM == 'Darwin' + +PLATFORM_ARCH = platform.machine() +IS_PPC = PLATFORM_ARCH.startswith('ppc') + + +if not six.PY2: + def ntob(n, encoding='ISO-8859-1'): + """Return the native string as bytes in the given encoding.""" + assert_native(n) + # In Python 3, the native string type is unicode + return n.encode(encoding) + + def ntou(n, encoding='ISO-8859-1'): + """Return the native string as unicode with the given encoding.""" + assert_native(n) + # In Python 3, the native string type is unicode + return n + + def bton(b, encoding='ISO-8859-1'): + """Return the byte string as native string in the given encoding.""" + return b.decode(encoding) +else: + # Python 2 + def ntob(n, encoding='ISO-8859-1'): + """Return the native string as bytes in the given encoding.""" + assert_native(n) + # In Python 2, the native string type is bytes. Assume it's already + # in the given encoding, which for ISO-8859-1 is almost always what + # was intended. + return n + + def ntou(n, encoding='ISO-8859-1'): + """Return the native string as unicode with the given encoding.""" + assert_native(n) + # In Python 2, the native string type is bytes. + # First, check for the special encoding 'escape'. The test suite uses + # this to signal that it wants to pass a string with embedded \uXXXX + # escapes, but without having to prefix it with u'' for Python 2, + # but no prefix for Python 3. + if encoding == 'escape': + return re.sub( + r'\\u([0-9a-zA-Z]{4})', + lambda m: six.unichr(int(m.group(1), 16)), + n.decode('ISO-8859-1'), + ) + # Assume it's already in the given encoding, which for ISO-8859-1 + # is almost always what was intended. + return n.decode(encoding) + + def bton(b, encoding='ISO-8859-1'): + """Return the byte string as native string in the given encoding.""" + return b + + +def assert_native(n): + """Check whether the input is of nativ ``str`` type. 
+ + Raises: + TypeError: in case of failed check + + """ + if not isinstance(n, str): + raise TypeError('n must be a native str (got %s)' % type(n).__name__) + + +if not six.PY2: + """Python 3 has memoryview builtin.""" + # Python 2.7 has it backported, but socket.write() does + # str(memoryview(b'0' * 100)) -> + # instead of accessing it correctly. + memoryview = memoryview +else: + """Link memoryview to buffer under Python 2.""" + memoryview = buffer # noqa: F821 + + +def extract_bytes(mv): + """Retrieve bytes out of memoryview/buffer or bytes.""" + if isinstance(mv, memoryview): + return bytes(mv) if six.PY2 else mv.tobytes() + + if isinstance(mv, bytes): + return mv + + raise ValueError diff --git a/lib/cheroot/cli.py b/lib/cheroot/cli.py new file mode 100644 index 00000000..f46e7dea --- /dev/null +++ b/lib/cheroot/cli.py @@ -0,0 +1,234 @@ +"""Command line tool for starting a Cheroot WSGI/HTTP server instance. + +Basic usage:: + + # Start a server on 127.0.0.1:8000 with the default settings + # for the WSGI app myapp/wsgi.py:application() + cheroot myapp.wsgi + + # Start a server on 0.0.0.0:9000 with 8 threads + # for the WSGI app myapp/wsgi.py:main_app() + cheroot myapp.wsgi:main_app --bind 0.0.0.0:9000 --threads 8 + + # Start a server for the cheroot.server.Gateway subclass + # myapp/gateway.py:HTTPGateway + cheroot myapp.gateway:HTTPGateway + + # Start a server on the UNIX socket /var/spool/myapp.sock + cheroot myapp.wsgi --bind /var/spool/myapp.sock + + # Start a server on the abstract UNIX socket CherootServer + cheroot myapp.wsgi --bind @CherootServer +""" + +import argparse +from importlib import import_module +import os +import sys +import contextlib + +import six + +from . import server +from . import wsgi + + +__metaclass__ = type + + +class BindLocation: + """A class for storing the bind location for a Cheroot instance.""" + + +class TCPSocket(BindLocation): + """TCPSocket.""" + + def __init__(self, address, port): + """Initialize. 
+ + Args: + address (str): Host name or IP address + port (int): TCP port number + """ + self.bind_addr = address, port + + +class UnixSocket(BindLocation): + """UnixSocket.""" + + def __init__(self, path): + """Initialize.""" + self.bind_addr = path + + +class AbstractSocket(BindLocation): + """AbstractSocket.""" + + def __init__(self, addr): + """Initialize.""" + self.bind_addr = '\0{}'.format(self.abstract_socket) + + +class Application: + """Application.""" + + @classmethod + def resolve(cls, full_path): + """Read WSGI app/Gateway path string and import application module.""" + mod_path, _, app_path = full_path.partition(':') + app = getattr(import_module(mod_path), app_path or 'application') + + with contextlib.suppress(TypeError): + if issubclass(app, server.Gateway): + return GatewayYo(app) + + return cls(app) + + def __init__(self, wsgi_app): + """Initialize.""" + if not callable(wsgi_app): + raise TypeError( + 'Application must be a callable object or ' + 'cheroot.server.Gateway subclass', + ) + self.wsgi_app = wsgi_app + + def server_args(self, parsed_args): + """Return keyword args for Server class.""" + args = { + arg: value + for arg, value in vars(parsed_args).items() + if not arg.startswith('_') and value is not None + } + args.update(vars(self)) + return args + + def server(self, parsed_args): + """Server.""" + return wsgi.Server(**self.server_args(parsed_args)) + + +class GatewayYo: + """Gateway.""" + + def __init__(self, gateway): + """Init.""" + self.gateway = gateway + + def server(self, parsed_args): + """Server.""" + server_args = vars(self) + server_args['bind_addr'] = parsed_args['bind_addr'] + if parsed_args.max is not None: + server_args['maxthreads'] = parsed_args.max + if parsed_args.numthreads is not None: + server_args['minthreads'] = parsed_args.numthreads + return server.HTTPServer(**server_args) + + +def parse_wsgi_bind_location(bind_addr_string): + """Convert bind address string to a BindLocation.""" + # try and match for an 
IP/hostname and port + match = six.moves.urllib.parse.urlparse('//{}'.format(bind_addr_string)) + try: + addr = match.hostname + port = match.port + if addr is not None or port is not None: + return TCPSocket(addr, port) + except ValueError: + pass + + # else, assume a UNIX socket path + # if the string begins with an @ symbol, use an abstract socket + if bind_addr_string.startswith('@'): + return AbstractSocket(bind_addr_string[1:]) + return UnixSocket(path=bind_addr_string) + + +def parse_wsgi_bind_addr(bind_addr_string): + """Convert bind address string to bind address parameter.""" + return parse_wsgi_bind_location(bind_addr_string).bind_addr + + +_arg_spec = { + '_wsgi_app': dict( + metavar='APP_MODULE', + type=Application.resolve, + help='WSGI application callable or cheroot.server.Gateway subclass', + ), + '--bind': dict( + metavar='ADDRESS', + dest='bind_addr', + type=parse_wsgi_bind_addr, + default='[::1]:8000', + help='Network interface to listen on (default: [::1]:8000)', + ), + '--chdir': dict( + metavar='PATH', + type=os.chdir, + help='Set the working directory', + ), + '--server-name': dict( + dest='server_name', + type=str, + help='Web server name to be advertised via Server HTTP header', + ), + '--threads': dict( + metavar='INT', + dest='numthreads', + type=int, + help='Minimum number of worker threads', + ), + '--max-threads': dict( + metavar='INT', + dest='max', + type=int, + help='Maximum number of worker threads', + ), + '--timeout': dict( + metavar='INT', + dest='timeout', + type=int, + help='Timeout in seconds for accepted connections', + ), + '--shutdown-timeout': dict( + metavar='INT', + dest='shutdown_timeout', + type=int, + help='Time in seconds to wait for worker threads to cleanly exit', + ), + '--request-queue-size': dict( + metavar='INT', + dest='request_queue_size', + type=int, + help='Maximum number of queued connections', + ), + '--accepted-queue-size': dict( + metavar='INT', + dest='accepted_queue_size', + type=int, + help='Maximum 
number of active requests in queue', + ), + '--accepted-queue-timeout': dict( + metavar='INT', + dest='accepted_queue_timeout', + type=int, + help='Timeout in seconds for putting requests into queue', + ), +} + + +def main(): + """Create a new Cheroot instance with arguments from the command line.""" + parser = argparse.ArgumentParser( + description='Start an instance of the Cheroot WSGI/HTTP server.', + ) + for arg, spec in _arg_spec.items(): + parser.add_argument(arg, **spec) + raw_args = parser.parse_args() + + # ensure cwd in sys.path + '' in sys.path or sys.path.insert(0, '') + + # create a server based on the arguments provided + raw_args._wsgi_app.server(raw_args).safe_start() diff --git a/lib/cheroot/connections.py b/lib/cheroot/connections.py new file mode 100644 index 00000000..943ac65a --- /dev/null +++ b/lib/cheroot/connections.py @@ -0,0 +1,279 @@ +"""Utilities to manage open connections.""" + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +import io +import os +import select +import socket +import time + +from . import errors +from .makefile import MakeFile + +import six + +try: + import fcntl +except ImportError: + try: + from ctypes import windll, WinError + import ctypes.wintypes + _SetHandleInformation = windll.kernel32.SetHandleInformation + _SetHandleInformation.argtypes = [ + ctypes.wintypes.HANDLE, + ctypes.wintypes.DWORD, + ctypes.wintypes.DWORD, + ] + _SetHandleInformation.restype = ctypes.wintypes.BOOL + except ImportError: + def prevent_socket_inheritance(sock): + """Stub inheritance prevention. + + Dummy function, since neither fcntl nor ctypes are available. 
+ """ + pass + else: + def prevent_socket_inheritance(sock): + """Mark the given socket fd as non-inheritable (Windows).""" + if not _SetHandleInformation(sock.fileno(), 1, 0): + raise WinError() +else: + def prevent_socket_inheritance(sock): + """Mark the given socket fd as non-inheritable (POSIX).""" + fd = sock.fileno() + old_flags = fcntl.fcntl(fd, fcntl.F_GETFD) + fcntl.fcntl(fd, fcntl.F_SETFD, old_flags | fcntl.FD_CLOEXEC) + + +class ConnectionManager: + """Class which manages HTTPConnection objects. + + This is for connections which are being kept-alive for follow-up requests. + """ + + def __init__(self, server): + """Initialize ConnectionManager object. + + Args: + server (cheroot.server.HTTPServer): web server object + that uses this ConnectionManager instance. + """ + self.server = server + self.connections = [] + + def put(self, conn): + """Put idle connection into the ConnectionManager to be managed. + + Args: + conn (cheroot.server.HTTPConnection): HTTP connection + to be managed. + """ + conn.last_used = time.time() + conn.ready_with_data = conn.rfile.has_data() + self.connections.append(conn) + + def expire(self): + """Expire least recently used connections. + + This happens if there are either too many open connections, or if the + connections have been timed out. + + This should be called periodically. + """ + if not self.connections: + return + + # Look at the first connection - if it can be closed, then do + # that, and wait for get_conn to return it. + conn = self.connections[0] + if conn.closeable: + return + + # Too many connections? + ka_limit = self.server.keep_alive_conn_limit + if ka_limit is not None and len(self.connections) > ka_limit: + conn.closeable = True + return + + # Connection too old? + if (conn.last_used + self.server.timeout) < time.time(): + conn.closeable = True + return + + def get_conn(self, server_socket): + """Return a HTTPConnection object which is ready to be handled. 
+ + A connection returned by this method should be ready for a worker + to handle it. If there are no connections ready, None will be + returned. + + Any connection returned by this method will need to be `put` + back if it should be examined again for another request. + + Args: + server_socket (socket.socket): Socket to listen to for new + connections. + Returns: + cheroot.server.HTTPConnection instance, or None. + + """ + # Grab file descriptors from sockets, but stop if we find a + # connection which is already marked as ready. + socket_dict = {} + for conn in self.connections: + if conn.closeable or conn.ready_with_data: + break + socket_dict[conn.socket.fileno()] = conn + else: + # No ready connection. + conn = None + + # We have a connection ready for use. + if conn: + self.connections.remove(conn) + return conn + + # Will require a select call. + ss_fileno = server_socket.fileno() + socket_dict[ss_fileno] = server_socket + try: + rlist, _, _ = select.select(list(socket_dict), [], [], 0.1) + # No available socket. + if not rlist: + return None + except OSError: + # Mark any connection which no longer appears valid. + for fno, conn in list(socket_dict.items()): + # If the server socket is invalid, we'll just ignore it and + # wait to be shutdown. + if fno == ss_fileno: + continue + try: + os.fstat(fno) + except OSError: + # Socket is invalid, close the connection, insert at + # the front. + self.connections.remove(conn) + self.connections.insert(0, conn) + conn.closeable = True + + # Wait for the next tick to occur. + return None + + try: + # See if we have a new connection coming in. + rlist.remove(ss_fileno) + except ValueError: + # No new connection, but reuse existing socket. + conn = socket_dict[rlist.pop()] + else: + conn = server_socket + + # All remaining connections in rlist should be marked as ready. + for fno in rlist: + socket_dict[fno].ready_with_data = True + + # New connection. 
+ if conn is server_socket: + return self._from_server_socket(server_socket) + + self.connections.remove(conn) + return conn + + def _from_server_socket(self, server_socket): + try: + s, addr = server_socket.accept() + if self.server.stats['Enabled']: + self.server.stats['Accepts'] += 1 + prevent_socket_inheritance(s) + if hasattr(s, 'settimeout'): + s.settimeout(self.server.timeout) + + mf = MakeFile + ssl_env = {} + # if ssl cert and key are set, we try to be a secure HTTP server + if self.server.ssl_adapter is not None: + try: + s, ssl_env = self.server.ssl_adapter.wrap(s) + except errors.NoSSLError: + msg = ( + 'The client sent a plain HTTP request, but ' + 'this server only speaks HTTPS on this port.' + ) + buf = [ + '%s 400 Bad Request\r\n' % self.server.protocol, + 'Content-Length: %s\r\n' % len(msg), + 'Content-Type: text/plain\r\n\r\n', + msg, + ] + + sock_to_make = s if not six.PY2 else s._sock + wfile = mf(sock_to_make, 'wb', io.DEFAULT_BUFFER_SIZE) + try: + wfile.write(''.join(buf).encode('ISO-8859-1')) + except socket.error as ex: + if ex.args[0] not in errors.socket_errors_to_ignore: + raise + return + if not s: + return + mf = self.server.ssl_adapter.makefile + # Re-apply our timeout since we may have a new socket object + if hasattr(s, 'settimeout'): + s.settimeout(self.server.timeout) + + conn = self.server.ConnectionClass(self.server, s, mf) + + if not isinstance( + self.server.bind_addr, + (six.text_type, six.binary_type), + ): + # optional values + # Until we do DNS lookups, omit REMOTE_HOST + if addr is None: # sometimes this can happen + # figure out if AF_INET or AF_INET6. 
+ if len(s.getsockname()) == 2: + # AF_INET + addr = ('0.0.0.0', 0) + else: + # AF_INET6 + addr = ('::', 0) + conn.remote_addr = addr[0] + conn.remote_port = addr[1] + + conn.ssl_env = ssl_env + return conn + + except socket.timeout: + # The only reason for the timeout in start() is so we can + # notice keyboard interrupts on Win32, which don't interrupt + # accept() by default + return + except socket.error as ex: + if self.server.stats['Enabled']: + self.server.stats['Socket Errors'] += 1 + if ex.args[0] in errors.socket_error_eintr: + # I *think* this is right. EINTR should occur when a signal + # is received during the accept() call; all docs say retry + # the call, and I *think* I'm reading it right that Python + # will then go ahead and poll for and handle the signal + # elsewhere. See + # https://github.com/cherrypy/cherrypy/issues/707. + return + if ex.args[0] in errors.socket_errors_nonblocking: + # Just try again. See + # https://github.com/cherrypy/cherrypy/issues/479. + return + if ex.args[0] in errors.socket_errors_to_ignore: + # Our socket was closed. + # See https://github.com/cherrypy/cherrypy/issues/686. + return + raise + + def close(self): + """Close all monitored connections.""" + for conn in self.connections[:]: + conn.close() + self.connections = [] diff --git a/lib/cheroot/errors.py b/lib/cheroot/errors.py new file mode 100644 index 00000000..80928731 --- /dev/null +++ b/lib/cheroot/errors.py @@ -0,0 +1,58 @@ +"""Collection of exceptions raised and/or processed by Cheroot.""" + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +import errno +import sys + + +class MaxSizeExceeded(Exception): + """Exception raised when a client sends more data then acceptable within limit. 
+ + Depends on ``request.body.maxbytes`` config option if used within CherryPy + """ + + +class NoSSLError(Exception): + """Exception raised when a client speaks HTTP to an HTTPS socket.""" + + +class FatalSSLAlert(Exception): + """Exception raised when the SSL implementation signals a fatal alert.""" + + +def plat_specific_errors(*errnames): + """Return error numbers for all errors in errnames on this platform. + + The 'errno' module contains different global constants depending on + the specific platform (OS). This function will return the list of + numeric values for a given list of potential names. + """ + missing_attr = set([None, ]) + unique_nums = set(getattr(errno, k, None) for k in errnames) + return list(unique_nums - missing_attr) + + +socket_error_eintr = plat_specific_errors('EINTR', 'WSAEINTR') + +socket_errors_to_ignore = plat_specific_errors( + 'EPIPE', + 'EBADF', 'WSAEBADF', + 'ENOTSOCK', 'WSAENOTSOCK', + 'ETIMEDOUT', 'WSAETIMEDOUT', + 'ECONNREFUSED', 'WSAECONNREFUSED', + 'ECONNRESET', 'WSAECONNRESET', + 'ECONNABORTED', 'WSAECONNABORTED', + 'ENETRESET', 'WSAENETRESET', + 'EHOSTDOWN', 'EHOSTUNREACH', +) +socket_errors_to_ignore.append('timed out') +socket_errors_to_ignore.append('The read operation timed out') +socket_errors_nonblocking = plat_specific_errors( + 'EAGAIN', 'EWOULDBLOCK', 'WSAEWOULDBLOCK', +) + +if sys.platform == 'darwin': + socket_errors_to_ignore.extend(plat_specific_errors('EPROTOTYPE')) + socket_errors_nonblocking.extend(plat_specific_errors('EPROTOTYPE')) diff --git a/lib/cheroot/makefile.py b/lib/cheroot/makefile.py new file mode 100644 index 00000000..8a86b338 --- /dev/null +++ b/lib/cheroot/makefile.py @@ -0,0 +1,447 @@ +"""Socket file object.""" + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +import socket + +try: + # prefer slower Python-based io module + import _pyio as io +except ImportError: + # Python 2.6 + import io + +import six + +from . 
import errors +from ._compat import extract_bytes, memoryview + + +# Write only 16K at a time to sockets +SOCK_WRITE_BLOCKSIZE = 16384 + + +class BufferedWriter(io.BufferedWriter): + """Faux file object attached to a socket object.""" + + def write(self, b): + """Write bytes to buffer.""" + self._checkClosed() + if isinstance(b, str): + raise TypeError("can't write str to binary stream") + + with self._write_lock: + self._write_buf.extend(b) + self._flush_unlocked() + return len(b) + + def _flush_unlocked(self): + self._checkClosed('flush of closed file') + while self._write_buf: + try: + # ssl sockets only except 'bytes', not bytearrays + # so perhaps we should conditionally wrap this for perf? + n = self.raw.write(bytes(self._write_buf)) + except io.BlockingIOError as e: + n = e.characters_written + del self._write_buf[:n] + + +class MakeFile_PY2(getattr(socket, '_fileobject', object)): + """Faux file object attached to a socket object.""" + + def __init__(self, *args, **kwargs): + """Initialize faux file object.""" + self.bytes_read = 0 + self.bytes_written = 0 + socket._fileobject.__init__(self, *args, **kwargs) + self._refcount = 0 + + def _reuse(self): + self._refcount += 1 + + def _drop(self): + if self._refcount < 0: + self.close() + else: + self._refcount -= 1 + + def write(self, data): + """Sendall for non-blocking sockets.""" + bytes_sent = 0 + data_mv = memoryview(data) + payload_size = len(data_mv) + while bytes_sent < payload_size: + try: + bytes_sent += self.send( + data_mv[bytes_sent:bytes_sent + SOCK_WRITE_BLOCKSIZE], + ) + except socket.error as e: + if e.args[0] not in errors.socket_errors_nonblocking: + raise + + def send(self, data): + """Send some part of message to the socket.""" + bytes_sent = self._sock.send(extract_bytes(data)) + self.bytes_written += bytes_sent + return bytes_sent + + def flush(self): + """Write all data from buffer to socket and reset write buffer.""" + if self._wbuf: + buffer = ''.join(self._wbuf) + self._wbuf = [] + 
self.write(buffer) + + def recv(self, size): + """Receive message of a size from the socket.""" + while True: + try: + data = self._sock.recv(size) + self.bytes_read += len(data) + return data + except socket.error as e: + what = ( + e.args[0] not in errors.socket_errors_nonblocking + and e.args[0] not in errors.socket_error_eintr + ) + if what: + raise + + class FauxSocket: + """Faux socket with the minimal interface required by pypy.""" + + def _reuse(self): + pass + + _fileobject_uses_str_type = six.PY2 and isinstance( + socket._fileobject(FauxSocket())._rbuf, six.string_types, + ) + + # FauxSocket is no longer needed + del FauxSocket + + if not _fileobject_uses_str_type: + def read(self, size=-1): + """Read data from the socket to buffer.""" + # Use max, disallow tiny reads in a loop as they are very + # inefficient. + # We never leave read() with any leftover data from a new recv() + # call in our internal buffer. + rbufsize = max(self._rbufsize, self.default_bufsize) + # Our use of StringIO rather than lists of string objects returned + # by recv() minimizes memory usage and fragmentation that occurs + # when rbufsize is large compared to the typical return value of + # recv(). + buf = self._rbuf + buf.seek(0, 2) # seek end + if size < 0: + # Read until EOF + # reset _rbuf. we consume it via buf. + self._rbuf = io.BytesIO() + while True: + data = self.recv(rbufsize) + if not data: + break + buf.write(data) + return buf.getvalue() + else: + # Read until size bytes or EOF seen, whichever comes first + buf_len = buf.tell() + if buf_len >= size: + # Already have size bytes in our buffer? Extract and + # return. + buf.seek(0) + rv = buf.read(size) + self._rbuf = io.BytesIO() + self._rbuf.write(buf.read()) + return rv + + # reset _rbuf. we consume it via buf. + self._rbuf = io.BytesIO() + while True: + left = size - buf_len + # recv() will malloc the amount of memory given as its + # parameter even though it often returns much less data + # than that. 
The returned data string is short lived + # as we copy it into a StringIO and free it. This avoids + # fragmentation issues on many platforms. + data = self.recv(left) + if not data: + break + n = len(data) + if n == size and not buf_len: + # Shortcut. Avoid buffer data copies when: + # - We have no data in our buffer. + # AND + # - Our call to recv returned exactly the + # number of bytes we were asked to read. + return data + if n == left: + buf.write(data) + del data # explicit free + break + assert n <= left, 'recv(%d) returned %d bytes' % (left, n) + buf.write(data) + buf_len += n + del data # explicit free + # assert buf_len == buf.tell() + return buf.getvalue() + + def readline(self, size=-1): + """Read line from the socket to buffer.""" + buf = self._rbuf + buf.seek(0, 2) # seek end + if buf.tell() > 0: + # check if we already have it in our buffer + buf.seek(0) + bline = buf.readline(size) + if bline.endswith('\n') or len(bline) == size: + self._rbuf = io.BytesIO() + self._rbuf.write(buf.read()) + return bline + del bline + if size < 0: + # Read until \n or EOF, whichever comes first + if self._rbufsize <= 1: + # Speed up unbuffered case + buf.seek(0) + buffers = [buf.read()] + # reset _rbuf. we consume it via buf. + self._rbuf = io.BytesIO() + data = None + recv = self.recv + while data != '\n': + data = recv(1) + if not data: + break + buffers.append(data) + return ''.join(buffers) + + buf.seek(0, 2) # seek end + # reset _rbuf. we consume it via buf. 
+ self._rbuf = io.BytesIO() + while True: + data = self.recv(self._rbufsize) + if not data: + break + nl = data.find('\n') + if nl >= 0: + nl += 1 + buf.write(data[:nl]) + self._rbuf.write(data[nl:]) + del data + break + buf.write(data) + return buf.getvalue() + + else: + # Read until size bytes or \n or EOF seen, whichever comes + # first + buf.seek(0, 2) # seek end + buf_len = buf.tell() + if buf_len >= size: + buf.seek(0) + rv = buf.read(size) + self._rbuf = io.BytesIO() + self._rbuf.write(buf.read()) + return rv + # reset _rbuf. we consume it via buf. + self._rbuf = io.BytesIO() + while True: + data = self.recv(self._rbufsize) + if not data: + break + left = size - buf_len + # did we just receive a newline? + nl = data.find('\n', 0, left) + if nl >= 0: + nl += 1 + # save the excess data to _rbuf + self._rbuf.write(data[nl:]) + if buf_len: + buf.write(data[:nl]) + break + else: + # Shortcut. Avoid data copy through buf when + # returning a substring of our first recv(). + return data[:nl] + n = len(data) + if n == size and not buf_len: + # Shortcut. Avoid data copy through buf when + # returning exactly all of our first recv(). 
+ return data + if n >= left: + buf.write(data[:left]) + self._rbuf.write(data[left:]) + break + buf.write(data) + buf_len += n + # assert buf_len == buf.tell() + return buf.getvalue() + + def has_data(self): + """Return true if there is buffered data to read.""" + return bool(self._rbuf.getvalue()) + + else: + def read(self, size=-1): + """Read data from the socket to buffer.""" + if size < 0: + # Read until EOF + buffers = [self._rbuf] + self._rbuf = '' + if self._rbufsize <= 1: + recv_size = self.default_bufsize + else: + recv_size = self._rbufsize + + while True: + data = self.recv(recv_size) + if not data: + break + buffers.append(data) + return ''.join(buffers) + else: + # Read until size bytes or EOF seen, whichever comes first + data = self._rbuf + buf_len = len(data) + if buf_len >= size: + self._rbuf = data[size:] + return data[:size] + buffers = [] + if data: + buffers.append(data) + self._rbuf = '' + while True: + left = size - buf_len + recv_size = max(self._rbufsize, left) + data = self.recv(recv_size) + if not data: + break + buffers.append(data) + n = len(data) + if n >= left: + self._rbuf = data[left:] + buffers[-1] = data[:left] + break + buf_len += n + return ''.join(buffers) + + def readline(self, size=-1): + """Read line from the socket to buffer.""" + data = self._rbuf + if size < 0: + # Read until \n or EOF, whichever comes first + if self._rbufsize <= 1: + # Speed up unbuffered case + assert data == '' + buffers = [] + while data != '\n': + data = self.recv(1) + if not data: + break + buffers.append(data) + return ''.join(buffers) + nl = data.find('\n') + if nl >= 0: + nl += 1 + self._rbuf = data[nl:] + return data[:nl] + buffers = [] + if data: + buffers.append(data) + self._rbuf = '' + while True: + data = self.recv(self._rbufsize) + if not data: + break + buffers.append(data) + nl = data.find('\n') + if nl >= 0: + nl += 1 + self._rbuf = data[nl:] + buffers[-1] = data[:nl] + break + return ''.join(buffers) + else: + # Read until size bytes 
or \n or EOF seen, whichever comes + # first + nl = data.find('\n', 0, size) + if nl >= 0: + nl += 1 + self._rbuf = data[nl:] + return data[:nl] + buf_len = len(data) + if buf_len >= size: + self._rbuf = data[size:] + return data[:size] + buffers = [] + if data: + buffers.append(data) + self._rbuf = '' + while True: + data = self.recv(self._rbufsize) + if not data: + break + buffers.append(data) + left = size - buf_len + nl = data.find('\n', 0, left) + if nl >= 0: + nl += 1 + self._rbuf = data[nl:] + buffers[-1] = data[:nl] + break + n = len(data) + if n >= left: + self._rbuf = data[left:] + buffers[-1] = data[:left] + break + buf_len += n + return ''.join(buffers) + + def has_data(self): + """Return true if there is buffered data to read.""" + return bool(self._rbuf) + + +if not six.PY2: + class StreamReader(io.BufferedReader): + """Socket stream reader.""" + + def __init__(self, sock, mode='r', bufsize=io.DEFAULT_BUFFER_SIZE): + """Initialize socket stream reader.""" + super().__init__(socket.SocketIO(sock, mode), bufsize) + self.bytes_read = 0 + + def read(self, *args, **kwargs): + """Capture bytes read.""" + val = super().read(*args, **kwargs) + self.bytes_read += len(val) + return val + + def has_data(self): + """Return true if there is buffered data to read.""" + return len(self._read_buf) > self._read_pos + + class StreamWriter(BufferedWriter): + """Socket stream writer.""" + + def __init__(self, sock, mode='w', bufsize=io.DEFAULT_BUFFER_SIZE): + """Initialize socket stream writer.""" + super().__init__(socket.SocketIO(sock, mode), bufsize) + self.bytes_written = 0 + + def write(self, val, *args, **kwargs): + """Capture bytes written.""" + res = super().write(val, *args, **kwargs) + self.bytes_written += len(val) + return res + + def MakeFile(sock, mode='r', bufsize=io.DEFAULT_BUFFER_SIZE): + """File object attached to a socket object.""" + cls = StreamReader if 'r' in mode else StreamWriter + return cls(sock, mode, bufsize) +else: + StreamReader = 
StreamWriter = MakeFile = MakeFile_PY2 diff --git a/lib/cheroot/server.py b/lib/cheroot/server.py new file mode 100644 index 00000000..991160de --- /dev/null +++ b/lib/cheroot/server.py @@ -0,0 +1,2114 @@ +""" +A high-speed, production ready, thread pooled, generic HTTP server. + +For those of you wanting to understand internals of this module, here's the +basic call flow. The server's listening thread runs a very tight loop, +sticking incoming connections onto a Queue:: + + server = HTTPServer(...) + server.start() + -> while True: + tick() + # This blocks until a request comes in: + child = socket.accept() + conn = HTTPConnection(child, ...) + server.requests.put(conn) + +Worker threads are kept in a pool and poll the Queue, popping off and then +handling each connection in turn. Each connection can consist of an arbitrary +number of requests and their responses, so we run a nested loop:: + + while True: + conn = server.requests.get() + conn.communicate() + -> while True: + req = HTTPRequest(...) + req.parse_request() + -> # Read the Request-Line, e.g. "GET /page HTTP/1.1" + req.rfile.readline() + read_headers(req.rfile, req.inheaders) + req.respond() + -> response = app(...) + try: + for chunk in response: + if chunk: + req.write(chunk) + finally: + if hasattr(response, "close"): + response.close() + if req.close_connection: + return + +For running a server you can invoke :func:`start() ` (it +will run the server forever) or use invoking :func:`prepare() +` and :func:`serve() ` like this:: + + server = HTTPServer(...) + server.prepare() + try: + threading.Thread(target=server.serve).start() + + # waiting/detecting some appropriate stop condition here + ... 
+ + finally: + server.stop() + +And now for a trivial doctest to exercise the test suite + +>>> 'HTTPServer' in globals() +True + +""" + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +import os +import io +import re +import email.utils +import socket +import sys +import time +import traceback as traceback_ +import logging +import platform + +try: + from functools import lru_cache +except ImportError: + from backports.functools_lru_cache import lru_cache + +import six +from six.moves import queue +from six.moves import urllib + +from . import connections, errors, __version__ +from ._compat import bton, ntou +from ._compat import IS_PPC +from .workers import threadpool +from .makefile import MakeFile, StreamWriter + + +__all__ = ( + 'HTTPRequest', 'HTTPConnection', 'HTTPServer', + 'SizeCheckWrapper', 'KnownLengthRFile', 'ChunkedRFile', + 'Gateway', 'get_ssl_adapter_class', +) + + +IS_WINDOWS = platform.system() == 'Windows' +"""Flag indicating whether the app is running under Windows.""" + + +IS_GAE = os.getenv('SERVER_SOFTWARE', '').startswith('Google App Engine/') +"""Flag indicating whether the app is running in GAE env. + +Ref: +https://cloud.google.com/appengine/docs/standard/python/tools +/using-local-server#detecting_application_runtime_environment +""" + + +IS_UID_GID_RESOLVABLE = not IS_WINDOWS and not IS_GAE +"""Indicates whether UID/GID resolution's available under current platform.""" + + +if IS_UID_GID_RESOLVABLE: + try: + import grp + import pwd + except ImportError: + """Unavailable in the current env. + + This shouldn't be happening normally. + All of the known cases are excluded via the if clause. 
+ """ + IS_UID_GID_RESOLVABLE = False + grp, pwd = None, None + import struct + + +if IS_WINDOWS and hasattr(socket, 'AF_INET6'): + if not hasattr(socket, 'IPPROTO_IPV6'): + socket.IPPROTO_IPV6 = 41 + if not hasattr(socket, 'IPV6_V6ONLY'): + socket.IPV6_V6ONLY = 27 + + +if not hasattr(socket, 'SO_PEERCRED'): + """ + NOTE: the value for SO_PEERCRED can be architecture specific, in + which case the getsockopt() will hopefully fail. The arch + specific value could be derived from platform.processor() + """ + socket.SO_PEERCRED = 21 if IS_PPC else 17 + + +LF = b'\n' +CRLF = b'\r\n' +TAB = b'\t' +SPACE = b' ' +COLON = b':' +SEMICOLON = b';' +EMPTY = b'' +ASTERISK = b'*' +FORWARD_SLASH = b'/' +QUOTED_SLASH = b'%2F' +QUOTED_SLASH_REGEX = re.compile(b'(?i)' + QUOTED_SLASH) + + +comma_separated_headers = [ + b'Accept', b'Accept-Charset', b'Accept-Encoding', + b'Accept-Language', b'Accept-Ranges', b'Allow', b'Cache-Control', + b'Connection', b'Content-Encoding', b'Content-Language', b'Expect', + b'If-Match', b'If-None-Match', b'Pragma', b'Proxy-Authenticate', b'TE', + b'Trailer', b'Transfer-Encoding', b'Upgrade', b'Vary', b'Via', b'Warning', + b'WWW-Authenticate', +] + + +if not hasattr(logging, 'statistics'): + logging.statistics = {} + + +class HeaderReader: + """Object for reading headers from an HTTP request. + + Interface and default implementation. + """ + + def __call__(self, rfile, hdict=None): + """ + Read headers from the given stream into the given header dict. + + If hdict is None, a new header dict is created. Returns the populated + header dict. + + Headers which are repeated are folded together using a comma if their + specification so dictates. + + This function raises ValueError when the read bytes violate the HTTP + spec. + You should probably return "400 Bad Request" if this happens. 
+ """ + if hdict is None: + hdict = {} + + while True: + line = rfile.readline() + if not line: + # No more data--illegal end of headers + raise ValueError('Illegal end of headers.') + + if line == CRLF: + # Normal end of headers + break + if not line.endswith(CRLF): + raise ValueError('HTTP requires CRLF terminators') + + if line[0] in (SPACE, TAB): + # It's a continuation line. + v = line.strip() + else: + try: + k, v = line.split(COLON, 1) + except ValueError: + raise ValueError('Illegal header line.') + v = v.strip() + k = self._transform_key(k) + hname = k + + if not self._allow_header(k): + continue + + if k in comma_separated_headers: + existing = hdict.get(hname) + if existing: + v = b', '.join((existing, v)) + hdict[hname] = v + + return hdict + + def _allow_header(self, key_name): + return True + + def _transform_key(self, key_name): + # TODO: what about TE and WWW-Authenticate? + return key_name.strip().title() + + +class DropUnderscoreHeaderReader(HeaderReader): + """Custom HeaderReader to exclude any headers with underscores in them.""" + + def _allow_header(self, key_name): + orig = super(DropUnderscoreHeaderReader, self)._allow_header(key_name) + return orig and '_' not in key_name + + +class SizeCheckWrapper: + """Wraps a file-like object, raising MaxSizeExceeded if too large.""" + + def __init__(self, rfile, maxlen): + """Initialize SizeCheckWrapper instance. + + Args: + rfile (file): file of a limited size + maxlen (int): maximum length of the file being read + """ + self.rfile = rfile + self.maxlen = maxlen + self.bytes_read = 0 + + def _check_length(self): + if self.maxlen and self.bytes_read > self.maxlen: + raise errors.MaxSizeExceeded() + + def read(self, size=None): + """Read a chunk from rfile buffer and return it. + + Args: + size (int): amount of data to read + + Returns: + bytes: Chunk from rfile, limited by size if specified. 
+ + """ + data = self.rfile.read(size) + self.bytes_read += len(data) + self._check_length() + return data + + def readline(self, size=None): + """Read a single line from rfile buffer and return it. + + Args: + size (int): minimum amount of data to read + + Returns: + bytes: One line from rfile. + + """ + if size is not None: + data = self.rfile.readline(size) + self.bytes_read += len(data) + self._check_length() + return data + + # User didn't specify a size ... + # We read the line in chunks to make sure it's not a 100MB line ! + res = [] + while True: + data = self.rfile.readline(256) + self.bytes_read += len(data) + self._check_length() + res.append(data) + # See https://github.com/cherrypy/cherrypy/issues/421 + if len(data) < 256 or data[-1:] == LF: + return EMPTY.join(res) + + def readlines(self, sizehint=0): + """Read all lines from rfile buffer and return them. + + Args: + sizehint (int): hint of minimum amount of data to read + + Returns: + list[bytes]: Lines of bytes read from rfile. + + """ + # Shamelessly stolen from StringIO + total = 0 + lines = [] + line = self.readline(sizehint) + while line: + lines.append(line) + total += len(line) + if 0 < sizehint <= total: + break + line = self.readline(sizehint) + return lines + + def close(self): + """Release resources allocated for rfile.""" + self.rfile.close() + + def __iter__(self): + """Return file iterator.""" + return self + + def __next__(self): + """Generate next file chunk.""" + data = next(self.rfile) + self.bytes_read += len(data) + self._check_length() + return data + + next = __next__ + + +class KnownLengthRFile: + """Wraps a file-like object, returning an empty string when exhausted.""" + + def __init__(self, rfile, content_length): + """Initialize KnownLengthRFile instance. 
+ + Args: + rfile (file): file of a known size + content_length (int): length of the file being read + + """ + self.rfile = rfile + self.remaining = content_length + + def read(self, size=None): + """Read a chunk from rfile buffer and return it. + + Args: + size (int): amount of data to read + + Returns: + bytes: Chunk from rfile, limited by size if specified. + + """ + if self.remaining == 0: + return b'' + if size is None: + size = self.remaining + else: + size = min(size, self.remaining) + + data = self.rfile.read(size) + self.remaining -= len(data) + return data + + def readline(self, size=None): + """Read a single line from rfile buffer and return it. + + Args: + size (int): minimum amount of data to read + + Returns: + bytes: One line from rfile. + + """ + if self.remaining == 0: + return b'' + if size is None: + size = self.remaining + else: + size = min(size, self.remaining) + + data = self.rfile.readline(size) + self.remaining -= len(data) + return data + + def readlines(self, sizehint=0): + """Read all lines from rfile buffer and return them. + + Args: + sizehint (int): hint of minimum amount of data to read + + Returns: + list[bytes]: Lines of bytes read from rfile. + + """ + # Shamelessly stolen from StringIO + total = 0 + lines = [] + line = self.readline(sizehint) + while line: + lines.append(line) + total += len(line) + if 0 < sizehint <= total: + break + line = self.readline(sizehint) + return lines + + def close(self): + """Release resources allocated for rfile.""" + self.rfile.close() + + def __iter__(self): + """Return file iterator.""" + return self + + def __next__(self): + """Generate next file chunk.""" + data = next(self.rfile) + self.remaining -= len(data) + return data + + next = __next__ + + +class ChunkedRFile: + """Wraps a file-like object, returning an empty string when exhausted. + + This class is intended to provide a conforming wsgi.input value for + request entities that have been encoded with the 'chunked' transfer + encoding. 
+ """ + + def __init__(self, rfile, maxlen, bufsize=8192): + """Initialize ChunkedRFile instance. + + Args: + rfile (file): file encoded with the 'chunked' transfer encoding + maxlen (int): maximum length of the file being read + bufsize (int): size of the buffer used to read the file + """ + self.rfile = rfile + self.maxlen = maxlen + self.bytes_read = 0 + self.buffer = EMPTY + self.bufsize = bufsize + self.closed = False + + def _fetch(self): + if self.closed: + return + + line = self.rfile.readline() + self.bytes_read += len(line) + + if self.maxlen and self.bytes_read > self.maxlen: + raise errors.MaxSizeExceeded( + 'Request Entity Too Large', self.maxlen, + ) + + line = line.strip().split(SEMICOLON, 1) + + try: + chunk_size = line.pop(0) + chunk_size = int(chunk_size, 16) + except ValueError: + raise ValueError('Bad chunked transfer size: ' + repr(chunk_size)) + + if chunk_size <= 0: + self.closed = True + return + +# if line: chunk_extension = line[0] + + if self.maxlen and self.bytes_read + chunk_size > self.maxlen: + raise IOError('Request Entity Too Large') + + chunk = self.rfile.read(chunk_size) + self.bytes_read += len(chunk) + self.buffer += chunk + + crlf = self.rfile.read(2) + if crlf != CRLF: + raise ValueError( + "Bad chunked transfer coding (expected '\\r\\n', " + 'got ' + repr(crlf) + ')', + ) + + def read(self, size=None): + """Read a chunk from rfile buffer and return it. + + Args: + size (int): amount of data to read + + Returns: + bytes: Chunk from rfile, limited by size if specified. + + """ + data = EMPTY + + if size == 0: + return data + + while True: + if size and len(data) >= size: + return data + + if not self.buffer: + self._fetch() + if not self.buffer: + # EOF + return data + + if size: + remaining = size - len(data) + data += self.buffer[:remaining] + self.buffer = self.buffer[remaining:] + else: + data += self.buffer + self.buffer = EMPTY + + def readline(self, size=None): + """Read a single line from rfile buffer and return it. 
+ + Args: + size (int): minimum amount of data to read + + Returns: + bytes: One line from rfile. + + """ + data = EMPTY + + if size == 0: + return data + + while True: + if size and len(data) >= size: + return data + + if not self.buffer: + self._fetch() + if not self.buffer: + # EOF + return data + + newline_pos = self.buffer.find(LF) + if size: + if newline_pos == -1: + remaining = size - len(data) + data += self.buffer[:remaining] + self.buffer = self.buffer[remaining:] + else: + remaining = min(size - len(data), newline_pos) + data += self.buffer[:remaining] + self.buffer = self.buffer[remaining:] + else: + if newline_pos == -1: + data += self.buffer + self.buffer = EMPTY + else: + data += self.buffer[:newline_pos] + self.buffer = self.buffer[newline_pos:] + + def readlines(self, sizehint=0): + """Read all lines from rfile buffer and return them. + + Args: + sizehint (int): hint of minimum amount of data to read + + Returns: + list[bytes]: Lines of bytes read from rfile. + + """ + # Shamelessly stolen from StringIO + total = 0 + lines = [] + line = self.readline(sizehint) + while line: + lines.append(line) + total += len(line) + if 0 < sizehint <= total: + break + line = self.readline(sizehint) + return lines + + def read_trailer_lines(self): + """Read HTTP headers and yield them. + + Returns: + Generator: yields CRLF separated lines. 
+ + """ + if not self.closed: + raise ValueError( + 'Cannot read trailers until the request body has been read.', + ) + + while True: + line = self.rfile.readline() + if not line: + # No more data--illegal end of headers + raise ValueError('Illegal end of headers.') + + self.bytes_read += len(line) + if self.maxlen and self.bytes_read > self.maxlen: + raise IOError('Request Entity Too Large') + + if line == CRLF: + # Normal end of headers + break + if not line.endswith(CRLF): + raise ValueError('HTTP requires CRLF terminators') + + yield line + + def close(self): + """Release resources allocated for rfile.""" + self.rfile.close() + + +class HTTPRequest: + """An HTTP Request (and response). + + A single HTTP connection may consist of multiple request/response pairs. + """ + + server = None + """The HTTPServer object which is receiving this request.""" + + conn = None + """The HTTPConnection object on which this request connected.""" + + inheaders = {} + """A dict of request headers.""" + + outheaders = [] + """A list of header tuples to write in the response.""" + + ready = False + """When True, the request has been parsed and is ready to begin generating + the response. When False, signals the calling Connection that the response + should not be generated and the connection should close.""" + + close_connection = False + """Signals the calling Connection that the request should close. This does + not imply an error! The client and/or server may each request that the + connection be closed.""" + + chunked_write = False + """If True, output will be encoded with the "chunked" transfer-coding. + + This value is set automatically inside send_headers.""" + + header_reader = HeaderReader() + """ + A HeaderReader instance or compatible reader. + """ + + def __init__(self, server, conn, proxy_mode=False, strict_mode=True): + """Initialize HTTP request container instance. 
+ + Args: + server (HTTPServer): web server object receiving this request + conn (HTTPConnection): HTTP connection object for this request + proxy_mode (bool): whether this HTTPServer should behave as a PROXY + server for certain requests + strict_mode (bool): whether we should return a 400 Bad Request when + we encounter a request that a HTTP compliant client should not be + making + """ + self.server = server + self.conn = conn + + self.ready = False + self.started_request = False + self.scheme = b'http' + if self.server.ssl_adapter is not None: + self.scheme = b'https' + # Use the lowest-common protocol in case read_request_line errors. + self.response_protocol = 'HTTP/1.0' + self.inheaders = {} + + self.status = '' + self.outheaders = [] + self.sent_headers = False + self.close_connection = self.__class__.close_connection + self.chunked_read = False + self.chunked_write = self.__class__.chunked_write + self.proxy_mode = proxy_mode + self.strict_mode = strict_mode + + def parse_request(self): + """Parse the next HTTP request start-line and message-headers.""" + self.rfile = SizeCheckWrapper( + self.conn.rfile, + self.server.max_request_header_size, + ) + try: + success = self.read_request_line() + except errors.MaxSizeExceeded: + self.simple_response( + '414 Request-URI Too Long', + 'The Request-URI sent with the request exceeds the maximum ' + 'allowed bytes.', + ) + return + else: + if not success: + return + + try: + success = self.read_request_headers() + except errors.MaxSizeExceeded: + self.simple_response( + '413 Request Entity Too Large', + 'The headers sent with the request exceed the maximum ' + 'allowed bytes.', + ) + return + else: + if not success: + return + + self.ready = True + + def read_request_line(self): + """Read and parse first line of the HTTP request. + + Returns: + bool: True if the request line is valid or False if it's malformed. + + """ + # HTTP/1.1 connections are persistent by default. 
If a client + # requests a page, then idles (leaves the connection open), + # then rfile.readline() will raise socket.error("timed out"). + # Note that it does this based on the value given to settimeout(), + # and doesn't need the client to request or acknowledge the close + # (although your TCP stack might suffer for it: cf Apache's history + # with FIN_WAIT_2). + request_line = self.rfile.readline() + + # Set started_request to True so communicate() knows to send 408 + # from here on out. + self.started_request = True + if not request_line: + return False + + if request_line == CRLF: + # RFC 2616 sec 4.1: "...if the server is reading the protocol + # stream at the beginning of a message and receives a CRLF + # first, it should ignore the CRLF." + # But only ignore one leading line! else we enable a DoS. + request_line = self.rfile.readline() + if not request_line: + return False + + if not request_line.endswith(CRLF): + self.simple_response( + '400 Bad Request', 'HTTP requires CRLF terminators', + ) + return False + + try: + method, uri, req_protocol = request_line.strip().split(SPACE, 2) + if not req_protocol.startswith(b'HTTP/'): + self.simple_response( + '400 Bad Request', 'Malformed Request-Line: bad protocol', + ) + return False + rp = req_protocol[5:].split(b'.', 1) + if len(rp) != 2: + self.simple_response( + '400 Bad Request', 'Malformed Request-Line: bad version', + ) + return False + rp = tuple(map(int, rp)) # Minor.Major must be treated as integers + if rp > (1, 1): + self.simple_response( + '505 HTTP Version Not Supported', 'Cannot fulfill request', + ) + return False + except (ValueError, IndexError): + self.simple_response('400 Bad Request', 'Malformed Request-Line') + return False + + self.uri = uri + self.method = method.upper() + + if self.strict_mode and method != self.method: + resp = ( + 'Malformed method name: According to RFC 2616 ' + '(section 5.1.1) and its successors ' + 'RFC 7230 (section 3.1.1) and RFC 7231 (section 4.1) ' + 'method 
names are case-sensitive and uppercase.' + ) + self.simple_response('400 Bad Request', resp) + return False + + try: + if six.PY2: # FIXME: Figure out better way to do this + # Ref: https://stackoverflow.com/a/196392/595220 (like this?) + """This is a dummy check for unicode in URI.""" + ntou(bton(uri, 'ascii'), 'ascii') + scheme, authority, path, qs, fragment = urllib.parse.urlsplit(uri) + except UnicodeError: + self.simple_response('400 Bad Request', 'Malformed Request-URI') + return False + + uri_is_absolute_form = (scheme or authority) + + if self.method == b'OPTIONS': + # TODO: cover this branch with tests + path = ( + uri + # https://tools.ietf.org/html/rfc7230#section-5.3.4 + if (self.proxy_mode and uri_is_absolute_form) + else path + ) + elif self.method == b'CONNECT': + # TODO: cover this branch with tests + if not self.proxy_mode: + self.simple_response('405 Method Not Allowed') + return False + + # `urlsplit()` above parses "example.com:3128" as path part of URI. + # this is a workaround, which makes it detect netloc correctly + uri_split = urllib.parse.urlsplit(b'//' + uri) + _scheme, _authority, _path, _qs, _fragment = uri_split + _port = EMPTY + try: + _port = uri_split.port + except ValueError: + pass + + # FIXME: use third-party validation to make checks against RFC + # the validation doesn't take into account, that urllib parses + # invalid URIs without raising errors + # https://tools.ietf.org/html/rfc7230#section-5.3.3 + invalid_path = ( + _authority != uri + or not _port + or any((_scheme, _path, _qs, _fragment)) + ) + if invalid_path: + self.simple_response( + '400 Bad Request', + 'Invalid path in Request-URI: request-' + 'target must match authority-form.', + ) + return False + + authority = path = _authority + scheme = qs = fragment = EMPTY + else: + disallowed_absolute = ( + self.strict_mode + and not self.proxy_mode + and uri_is_absolute_form + ) + if disallowed_absolute: + # https://tools.ietf.org/html/rfc7230#section-5.3.2 + # (absolute 
form) + """Absolute URI is only allowed within proxies.""" + self.simple_response( + '400 Bad Request', + 'Absolute URI not allowed if server is not a proxy.', + ) + return False + + invalid_path = ( + self.strict_mode + and not uri.startswith(FORWARD_SLASH) + and not uri_is_absolute_form + ) + if invalid_path: + # https://tools.ietf.org/html/rfc7230#section-5.3.1 + # (origin_form) and + """Path should start with a forward slash.""" + resp = ( + 'Invalid path in Request-URI: request-target must contain ' + 'origin-form which starts with absolute-path (URI ' + 'starting with a slash "/").' + ) + self.simple_response('400 Bad Request', resp) + return False + + if fragment: + self.simple_response( + '400 Bad Request', + 'Illegal #fragment in Request-URI.', + ) + return False + + if path is None: + # FIXME: It looks like this case cannot happen + self.simple_response( + '400 Bad Request', + 'Invalid path in Request-URI.', + ) + return False + + # Unquote the path+params (e.g. "/this%20path" -> "/this path"). + # https://www.w3.org/Protocols/rfc2616/rfc2616-sec5.html#sec5.1.2 + # + # But note that "...a URI must be separated into its components + # before the escaped characters within those components can be + # safely decoded." https://www.ietf.org/rfc/rfc2396.txt, sec 2.4.2 + # Therefore, "/this%2Fpath" becomes "/this%2Fpath", not + # "/this/path". + try: + # TODO: Figure out whether exception can really happen here. + # It looks like it's caught on urlsplit() call above. 
+ atoms = [ + urllib.parse.unquote_to_bytes(x) + for x in QUOTED_SLASH_REGEX.split(path) + ] + except ValueError as ex: + self.simple_response('400 Bad Request', ex.args[0]) + return False + path = QUOTED_SLASH.join(atoms) + + if not path.startswith(FORWARD_SLASH): + path = FORWARD_SLASH + path + + if scheme is not EMPTY: + self.scheme = scheme + self.authority = authority + self.path = path + + # Note that, like wsgiref and most other HTTP servers, + # we "% HEX HEX"-unquote the path but not the query string. + self.qs = qs + + # Compare request and server HTTP protocol versions, in case our + # server does not support the requested protocol. Limit our output + # to min(req, server). We want the following output: + # request server actual written supported response + # protocol protocol response protocol feature set + # a 1.0 1.0 1.0 1.0 + # b 1.0 1.1 1.1 1.0 + # c 1.1 1.0 1.0 1.0 + # d 1.1 1.1 1.1 1.1 + # Notice that, in (b), the response will be "HTTP/1.1" even though + # the client only understands 1.0. RFC 2616 10.5.6 says we should + # only return 505 if the _major_ version is different. + sp = int(self.server.protocol[5]), int(self.server.protocol[7]) + + if sp[0] != rp[0]: + self.simple_response('505 HTTP Version Not Supported') + return False + + self.request_protocol = req_protocol + self.response_protocol = 'HTTP/%s.%s' % min(rp, sp) + + return True + + def read_request_headers(self): + """Read self.rfile into self.inheaders. 
Return success.""" + # then all the http headers + try: + self.header_reader(self.rfile, self.inheaders) + except ValueError as ex: + self.simple_response('400 Bad Request', ex.args[0]) + return False + + mrbs = self.server.max_request_body_size + + try: + cl = int(self.inheaders.get(b'Content-Length', 0)) + except ValueError: + self.simple_response( + '400 Bad Request', + 'Malformed Content-Length Header.', + ) + return False + + if mrbs and cl > mrbs: + self.simple_response( + '413 Request Entity Too Large', + 'The entity sent with the request exceeds the maximum ' + 'allowed bytes.', + ) + return False + + # Persistent connection support + if self.response_protocol == 'HTTP/1.1': + # Both server and client are HTTP/1.1 + if self.inheaders.get(b'Connection', b'') == b'close': + self.close_connection = True + else: + # Either the server or client (or both) are HTTP/1.0 + if self.inheaders.get(b'Connection', b'') != b'Keep-Alive': + self.close_connection = True + + # Transfer-Encoding support + te = None + if self.response_protocol == 'HTTP/1.1': + te = self.inheaders.get(b'Transfer-Encoding') + if te: + te = [x.strip().lower() for x in te.split(b',') if x.strip()] + + self.chunked_read = False + + if te: + for enc in te: + if enc == b'chunked': + self.chunked_read = True + else: + # Note that, even if we see "chunked", we must reject + # if there is an extension we don't recognize. + self.simple_response('501 Unimplemented') + self.close_connection = True + return False + + # From PEP 333: + # "Servers and gateways that implement HTTP 1.1 must provide + # transparent support for HTTP 1.1's "expect/continue" mechanism. + # This may be done in any of several ways: + # 1. Respond to requests containing an Expect: 100-continue request + # with an immediate "100 Continue" response, and proceed normally. + # 2. 
Proceed with the request normally, but provide the application + # with a wsgi.input stream that will send the "100 Continue" + # response if/when the application first attempts to read from + # the input stream. The read request must then remain blocked + # until the client responds. + # 3. Wait until the client decides that the server does not support + # expect/continue, and sends the request body on its own. + # (This is suboptimal, and is not recommended.) + # + # We used to do 3, but are now doing 1. Maybe we'll do 2 someday, + # but it seems like it would be a big slowdown for such a rare case. + if self.inheaders.get(b'Expect', b'') == b'100-continue': + # Don't use simple_response here, because it emits headers + # we don't want. See + # https://github.com/cherrypy/cherrypy/issues/951 + msg = self.server.protocol.encode('ascii') + msg += b' 100 Continue\r\n\r\n' + try: + self.conn.wfile.write(msg) + except socket.error as ex: + if ex.args[0] not in errors.socket_errors_to_ignore: + raise + return True + + def respond(self): + """Call the gateway and write its iterable output.""" + mrbs = self.server.max_request_body_size + if self.chunked_read: + self.rfile = ChunkedRFile(self.conn.rfile, mrbs) + else: + cl = int(self.inheaders.get(b'Content-Length', 0)) + if mrbs and mrbs < cl: + if not self.sent_headers: + self.simple_response( + '413 Request Entity Too Large', + 'The entity sent with the request exceeds the ' + 'maximum allowed bytes.', + ) + return + self.rfile = KnownLengthRFile(self.conn.rfile, cl) + + self.server.gateway(self).respond() + self.ready and self.ensure_headers_sent() + + if self.chunked_write: + self.conn.wfile.write(b'0\r\n\r\n') + + def simple_response(self, status, msg=''): + """Write a simple response back to the client.""" + status = str(status) + proto_status = '%s %s\r\n' % (self.server.protocol, status) + content_length = 'Content-Length: %s\r\n' % len(msg) + content_type = 'Content-Type: text/plain\r\n' + buf = [ + 
proto_status.encode('ISO-8859-1'), + content_length.encode('ISO-8859-1'), + content_type.encode('ISO-8859-1'), + ] + + if status[:3] in ('413', '414'): + # Request Entity Too Large / Request-URI Too Long + self.close_connection = True + if self.response_protocol == 'HTTP/1.1': + # This will not be true for 414, since read_request_line + # usually raises 414 before reading the whole line, and we + # therefore cannot know the proper response_protocol. + buf.append(b'Connection: close\r\n') + else: + # HTTP/1.0 had no 413/414 status nor Connection header. + # Emit 400 instead and trust the message body is enough. + status = '400 Bad Request' + + buf.append(CRLF) + if msg: + if isinstance(msg, six.text_type): + msg = msg.encode('ISO-8859-1') + buf.append(msg) + + try: + self.conn.wfile.write(EMPTY.join(buf)) + except socket.error as ex: + if ex.args[0] not in errors.socket_errors_to_ignore: + raise + + def ensure_headers_sent(self): + """Ensure headers are sent to the client if not already sent.""" + if not self.sent_headers: + self.sent_headers = True + self.send_headers() + + def write(self, chunk): + """Write unbuffered data to the client.""" + if self.chunked_write and chunk: + chunk_size_hex = hex(len(chunk))[2:].encode('ascii') + buf = [chunk_size_hex, CRLF, chunk, CRLF] + self.conn.wfile.write(EMPTY.join(buf)) + else: + self.conn.wfile.write(chunk) + + def send_headers(self): + """Assert, process, and send the HTTP response message-headers. + + You must set self.status, and self.outheaders before calling this. + """ + hkeys = [key.lower() for key, value in self.outheaders] + status = int(self.status[:3]) + + if status == 413: + # Request Entity Too Large. Close conn to avoid garbage. + self.close_connection = True + elif b'content-length' not in hkeys: + # "All 1xx (informational), 204 (no content), + # and 304 (not modified) responses MUST NOT + # include a message-body." So no point chunking. 
+ if status < 200 or status in (204, 205, 304): + pass + else: + needs_chunked = ( + self.response_protocol == 'HTTP/1.1' + and self.method != b'HEAD' + ) + if needs_chunked: + # Use the chunked transfer-coding + self.chunked_write = True + self.outheaders.append((b'Transfer-Encoding', b'chunked')) + else: + # Closing the conn is the only way to determine len. + self.close_connection = True + + if b'connection' not in hkeys: + if self.response_protocol == 'HTTP/1.1': + # Both server and client are HTTP/1.1 or better + if self.close_connection: + self.outheaders.append((b'Connection', b'close')) + else: + # Server and/or client are HTTP/1.0 + if not self.close_connection: + self.outheaders.append((b'Connection', b'Keep-Alive')) + + if (not self.close_connection) and (not self.chunked_read): + # Read any remaining request body data on the socket. + # "If an origin server receives a request that does not include an + # Expect request-header field with the "100-continue" expectation, + # the request includes a request body, and the server responds + # with a final status code before reading the entire request body + # from the transport connection, then the server SHOULD NOT close + # the transport connection until it has read the entire request, + # or until the client closes the connection. Otherwise, the client + # might not reliably receive the response message. However, this + # requirement is not be construed as preventing a server from + # defending itself against denial-of-service attacks, or from + # badly broken client implementations." 
+ remaining = getattr(self.rfile, 'remaining', 0) + if remaining > 0: + self.rfile.read(remaining) + + if b'date' not in hkeys: + self.outheaders.append(( + b'Date', + email.utils.formatdate(usegmt=True).encode('ISO-8859-1'), + )) + + if b'server' not in hkeys: + self.outheaders.append(( + b'Server', + self.server.server_name.encode('ISO-8859-1'), + )) + + proto = self.server.protocol.encode('ascii') + buf = [proto + SPACE + self.status + CRLF] + for k, v in self.outheaders: + buf.append(k + COLON + SPACE + v + CRLF) + buf.append(CRLF) + self.conn.wfile.write(EMPTY.join(buf)) + + +class HTTPConnection: + """An HTTP connection (active socket).""" + + remote_addr = None + remote_port = None + ssl_env = None + rbufsize = io.DEFAULT_BUFFER_SIZE + wbufsize = io.DEFAULT_BUFFER_SIZE + RequestHandlerClass = HTTPRequest + peercreds_enabled = False + peercreds_resolve_enabled = False + + # Fields set by ConnectionManager. + closeable = False + last_used = None + ready_with_data = False + + def __init__(self, server, sock, makefile=MakeFile): + """Initialize HTTPConnection instance. + + Args: + server (HTTPServer): web server object receiving this request + sock (socket._socketobject): the raw socket object (usually + TCP) for this connection + makefile (file): a fileobject class for reading from the socket + """ + self.server = server + self.socket = sock + self.rfile = makefile(sock, 'rb', self.rbufsize) + self.wfile = makefile(sock, 'wb', self.wbufsize) + self.requests_seen = 0 + + self.peercreds_enabled = self.server.peercreds_enabled + self.peercreds_resolve_enabled = self.server.peercreds_resolve_enabled + + # LRU cached methods: + # Ref: https://stackoverflow.com/a/14946506/595220 + self.resolve_peer_creds = ( + lru_cache(maxsize=1)(self.resolve_peer_creds) + ) + self.get_peer_creds = ( + lru_cache(maxsize=1)(self.get_peer_creds) + ) + + def communicate(self): + """Read each request and respond appropriately. + + Returns true if the connection should be kept open. 
+ """ + request_seen = False + try: + req = self.RequestHandlerClass(self.server, self) + req.parse_request() + if self.server.stats['Enabled']: + self.requests_seen += 1 + if not req.ready: + # Something went wrong in the parsing (and the server has + # probably already made a simple_response). Return and + # let the conn close. + return False + + request_seen = True + req.respond() + if not req.close_connection: + return True + except socket.error as ex: + errnum = ex.args[0] + # sadly SSL sockets return a different (longer) time out string + timeout_errs = 'timed out', 'The read operation timed out' + if errnum in timeout_errs: + # Don't error if we're between requests; only error + # if 1) no request has been started at all, or 2) we're + # in the middle of a request. + # See https://github.com/cherrypy/cherrypy/issues/853 + if (not request_seen) or (req and req.started_request): + self._conditional_error(req, '408 Request Timeout') + elif errnum not in errors.socket_errors_to_ignore: + self.server.error_log( + 'socket.error %s' % repr(errnum), + level=logging.WARNING, traceback=True, + ) + self._conditional_error(req, '500 Internal Server Error') + except (KeyboardInterrupt, SystemExit): + raise + except errors.FatalSSLAlert: + pass + except errors.NoSSLError: + self._handle_no_ssl(req) + except Exception as ex: + self.server.error_log( + repr(ex), level=logging.ERROR, traceback=True, + ) + self._conditional_error(req, '500 Internal Server Error') + return False + + linger = False + + def _handle_no_ssl(self, req): + if not req or req.sent_headers: + return + # Unwrap wfile + try: + resp_sock = self.socket._sock + except AttributeError: + # self.socket is of OpenSSL.SSL.Connection type + resp_sock = self.socket._socket + self.wfile = StreamWriter(resp_sock, 'wb', self.wbufsize) + msg = ( + 'The client sent a plain HTTP request, but ' + 'this server only speaks HTTPS on this port.' 
+ ) + req.simple_response('400 Bad Request', msg) + self.linger = True + + def _conditional_error(self, req, response): + """Respond with an error. + + Don't bother writing if a response + has already started being written. + """ + if not req or req.sent_headers: + return + + try: + req.simple_response(response) + except errors.FatalSSLAlert: + pass + except errors.NoSSLError: + self._handle_no_ssl(req) + + def close(self): + """Close the socket underlying this connection.""" + self.rfile.close() + + if not self.linger: + self._close_kernel_socket() + self.socket.close() + else: + # On the other hand, sometimes we want to hang around for a bit + # to make sure the client has a chance to read our entire + # response. Skipping the close() calls here delays the FIN + # packet until the socket object is garbage-collected later. + # Someday, perhaps, we'll do the full lingering_close that + # Apache does, but not today. + pass + + def get_peer_creds(self): # LRU cached on per-instance basis, see __init__ + """Return the PID/UID/GID tuple of the peer socket for UNIX sockets. + + This function uses SO_PEERCRED to query the UNIX PID, UID, GID + of the peer, which is only available if the bind address is + a UNIX domain socket. 
+ + Raises: + NotImplementedError: in case of unsupported socket type + RuntimeError: in case of SO_PEERCRED lookup unsupported or disabled + + """ + PEERCRED_STRUCT_DEF = '3i' + + if IS_WINDOWS or self.socket.family != socket.AF_UNIX: + raise NotImplementedError( + 'SO_PEERCRED is only supported in Linux kernel and WSL', + ) + elif not self.peercreds_enabled: + raise RuntimeError( + 'Peer creds lookup is disabled within this server', + ) + + try: + peer_creds = self.socket.getsockopt( + # FIXME: Use LOCAL_CREDS for BSD-like OSs + # Ref: https://gist.github.com/LucaFilipozzi/e4f1e118202aff27af6aadebda1b5d91 # noqa + socket.SOL_SOCKET, socket.SO_PEERCRED, + struct.calcsize(PEERCRED_STRUCT_DEF), + ) + except socket.error as socket_err: + """Non-Linux kernels don't support SO_PEERCRED. + + Refs: + http://welz.org.za/notes/on-peer-cred.html + https://github.com/daveti/tcpSockHack + msdn.microsoft.com/en-us/commandline/wsl/release_notes#build-15025 + """ + six.raise_from( # 3.6+: raise RuntimeError from socket_err + RuntimeError, + socket_err, + ) + else: + pid, uid, gid = struct.unpack(PEERCRED_STRUCT_DEF, peer_creds) + return pid, uid, gid + + @property + def peer_pid(self): + """Return the id of the connected peer process.""" + pid, _, _ = self.get_peer_creds() + return pid + + @property + def peer_uid(self): + """Return the user id of the connected peer process.""" + _, uid, _ = self.get_peer_creds() + return uid + + @property + def peer_gid(self): + """Return the group id of the connected peer process.""" + _, _, gid = self.get_peer_creds() + return gid + + def resolve_peer_creds(self): # LRU cached on per-instance basis + """Return the username and group tuple of the peercreds if available. + + Raises: + NotImplementedError: in case of unsupported OS + RuntimeError: in case of UID/GID lookup unsupported or disabled + + """ + if not IS_UID_GID_RESOLVABLE: + raise NotImplementedError( + 'UID/GID lookup is unavailable under current platform. 
' + 'It can only be done under UNIX-like OS ' + 'but not under the Google App Engine', + ) + elif not self.peercreds_resolve_enabled: + raise RuntimeError( + 'UID/GID lookup is disabled within this server', + ) + + user = pwd.getpwuid(self.peer_uid).pw_name # [0] + group = grp.getgrgid(self.peer_gid).gr_name # [0] + + return user, group + + @property + def peer_user(self): + """Return the username of the connected peer process.""" + user, _ = self.resolve_peer_creds() + return user + + @property + def peer_group(self): + """Return the group of the connected peer process.""" + _, group = self.resolve_peer_creds() + return group + + def _close_kernel_socket(self): + """Close kernel socket in outdated Python versions. + + On old Python versions, + Python's socket module does NOT call close on the kernel + socket when you call socket.close(). We do so manually here + because we want this server to send a FIN TCP segment + immediately. Note this must be called *before* calling + socket.close(), because the latter drops its reference to + the kernel socket. + """ + if six.PY2 and hasattr(self.socket, '_sock'): + self.socket._sock.close() + + +class HTTPServer: + """An HTTP server.""" + + _bind_addr = '127.0.0.1' + _interrupt = None + + gateway = None + """A Gateway instance.""" + + minthreads = None + """The minimum number of worker threads to create (default 10).""" + + maxthreads = None + """The maximum number of worker threads to create. + + (default -1 = no limit)""" + + server_name = None + """The name of the server; defaults to ``self.version``.""" + + protocol = 'HTTP/1.1' + """The version string to write in the Status-Line of all HTTP responses. + + For example, "HTTP/1.1" is the default. This also limits the supported + features used in the response.""" + + request_queue_size = 5 + """The 'backlog' arg to socket.listen(); max queued connections. + + (default 5).""" + + shutdown_timeout = 5 + """The total time to wait for worker threads to cleanly exit. 
+
+    Specified in seconds."""
+
+    timeout = 10
+    """The timeout in seconds for accepted connections (default 10)."""
+
+    version = 'Cheroot/' + __version__
+    """A version string for the HTTPServer."""
+
+    software = None
+    """The value to set for the SERVER_SOFTWARE entry in the WSGI environ.
+
+    If None, this defaults to ``'%s Server' % self.version``.
+    """
+
+    ready = False
+    """Internal flag indicating whether the socket is accepting connections."""
+
+    max_request_header_size = 0
+    """The maximum size, in bytes, for request headers, or 0 for no limit."""
+
+    max_request_body_size = 0
+    """The maximum size, in bytes, for request bodies, or 0 for no limit."""
+
+    nodelay = True
+    """If True (the default since 3.1), sets the TCP_NODELAY socket option."""
+
+    ConnectionClass = HTTPConnection
+    """The class to use for handling HTTP connections."""
+
+    ssl_adapter = None
+    """An instance of ssl.Adapter (or a subclass).
+
+    You must have the corresponding SSL driver library installed.
+    """
+
+    peercreds_enabled = False
+    """If True, peer cred lookup can be performed via UNIX domain socket."""
+
+    peercreds_resolve_enabled = False
+    """If True, username/group will be looked up in the OS from peercreds."""
+
+    keep_alive_conn_limit = 10
+    """The maximum number of waiting keep-alive connections that will be kept open.
+
+    Default is 10. Set to None to have unlimited connections."""
+
+    def __init__(
+        self, bind_addr, gateway,
+        minthreads=10, maxthreads=-1, server_name=None,
+        peercreds_enabled=False, peercreds_resolve_enabled=False,
+    ):
+        """Initialize HTTPServer instance.
+ + Args: + bind_addr (tuple): network interface to listen to + gateway (Gateway): gateway for processing HTTP requests + minthreads (int): minimum number of threads for HTTP thread pool + maxthreads (int): maximum number of threads for HTTP thread pool + server_name (str): web server name to be advertised via Server + HTTP header + """ + self.bind_addr = bind_addr + self.gateway = gateway + + self.requests = threadpool.ThreadPool( + self, min=minthreads or 1, max=maxthreads, + ) + self.connections = connections.ConnectionManager(self) + + if not server_name: + server_name = self.version + self.server_name = server_name + self.peercreds_enabled = peercreds_enabled + self.peercreds_resolve_enabled = ( + peercreds_resolve_enabled and peercreds_enabled + ) + self.clear_stats() + + def clear_stats(self): + """Reset server stat counters..""" + self._start_time = None + self._run_time = 0 + self.stats = { + 'Enabled': False, + 'Bind Address': lambda s: repr(self.bind_addr), + 'Run time': lambda s: (not s['Enabled']) and -1 or self.runtime(), + 'Accepts': 0, + 'Accepts/sec': lambda s: s['Accepts'] / self.runtime(), + 'Queue': lambda s: getattr(self.requests, 'qsize', None), + 'Threads': lambda s: len(getattr(self.requests, '_threads', [])), + 'Threads Idle': lambda s: getattr(self.requests, 'idle', None), + 'Socket Errors': 0, + 'Requests': lambda s: (not s['Enabled']) and -1 or sum( + [w['Requests'](w) for w in s['Worker Threads'].values()], 0, + ), + 'Bytes Read': lambda s: (not s['Enabled']) and -1 or sum( + [w['Bytes Read'](w) for w in s['Worker Threads'].values()], 0, + ), + 'Bytes Written': lambda s: (not s['Enabled']) and -1 or sum( + [w['Bytes Written'](w) for w in s['Worker Threads'].values()], + 0, + ), + 'Work Time': lambda s: (not s['Enabled']) and -1 or sum( + [w['Work Time'](w) for w in s['Worker Threads'].values()], 0, + ), + 'Read Throughput': lambda s: (not s['Enabled']) and -1 or sum( + [w['Bytes Read'](w) / (w['Work Time'](w) or 1e-6) + for w in 
s['Worker Threads'].values()], 0, + ), + 'Write Throughput': lambda s: (not s['Enabled']) and -1 or sum( + [w['Bytes Written'](w) / (w['Work Time'](w) or 1e-6) + for w in s['Worker Threads'].values()], 0, + ), + 'Worker Threads': {}, + } + logging.statistics['Cheroot HTTPServer %d' % id(self)] = self.stats + + def runtime(self): + """Return server uptime.""" + if self._start_time is None: + return self._run_time + else: + return self._run_time + (time.time() - self._start_time) + + def __str__(self): + """Render Server instance representing bind address.""" + return '%s.%s(%r)' % ( + self.__module__, self.__class__.__name__, + self.bind_addr, + ) + + @property + def bind_addr(self): + """Return the interface on which to listen for connections. + + For TCP sockets, a (host, port) tuple. Host values may be any IPv4 + or IPv6 address, or any valid hostname. The string 'localhost' is a + synonym for '127.0.0.1' (or '::1', if your hosts file prefers IPv6). + The string '0.0.0.0' is a special IPv4 entry meaning "any active + interface" (INADDR_ANY), and '::' is the similar IN6ADDR_ANY for + IPv6. The empty string or None are not allowed. + + For UNIX sockets, supply the filename as a string. + + Systemd socket activation is automatic and doesn't require tempering + with this variable. + """ + return self._bind_addr + + @bind_addr.setter + def bind_addr(self, value): + """Set the interface on which to listen for connections.""" + if isinstance(value, tuple) and value[0] in ('', None): + # Despite the socket module docs, using '' does not + # allow AI_PASSIVE to work. Passing None instead + # returns '0.0.0.0' like we want. In other words: + # host AI_PASSIVE result + # '' Y 192.168.x.y + # '' N 192.168.x.y + # None Y 0.0.0.0 + # None N 127.0.0.1 + # But since you can get the same effect with an explicit + # '0.0.0.0', we deny both the empty string and None as values. + raise ValueError( + "Host values of '' or None are not allowed. 
" + "Use '0.0.0.0' (IPv4) or '::' (IPv6) instead " + 'to listen on all active interfaces.', + ) + self._bind_addr = value + + def safe_start(self): + """Run the server forever, and stop it cleanly on exit.""" + try: + self.start() + except (KeyboardInterrupt, IOError): + # The time.sleep call might raise + # "IOError: [Errno 4] Interrupted function call" on KBInt. + self.error_log('Keyboard Interrupt: shutting down') + self.stop() + raise + except SystemExit: + self.error_log('SystemExit raised: shutting down') + self.stop() + raise + + def prepare(self): + """Prepare server to serving requests. + + It binds a socket's port, setups the socket to ``listen()`` and does + other preparing things. + """ + self._interrupt = None + + if self.software is None: + self.software = '%s Server' % self.version + + # Select the appropriate socket + self.socket = None + msg = 'No socket could be created' + if os.getenv('LISTEN_PID', None): + # systemd socket activation + self.socket = socket.fromfd(3, socket.AF_INET, socket.SOCK_STREAM) + elif isinstance(self.bind_addr, (six.text_type, six.binary_type)): + # AF_UNIX socket + try: + self.bind_unix_socket(self.bind_addr) + except socket.error as serr: + msg = '%s -- (%s: %s)' % (msg, self.bind_addr, serr) + six.raise_from(socket.error(msg), serr) + else: + # AF_INET or AF_INET6 socket + # Get the correct address family for our host (allows IPv6 + # addresses) + host, port = self.bind_addr + try: + info = socket.getaddrinfo( + host, port, socket.AF_UNSPEC, + socket.SOCK_STREAM, 0, socket.AI_PASSIVE, + ) + except socket.gaierror: + sock_type = socket.AF_INET + bind_addr = self.bind_addr + + if ':' in host: + sock_type = socket.AF_INET6 + bind_addr = bind_addr + (0, 0) + + info = [(sock_type, socket.SOCK_STREAM, 0, '', bind_addr)] + + for res in info: + af, socktype, proto, canonname, sa = res + try: + self.bind(af, socktype, proto) + break + except socket.error as serr: + msg = '%s -- (%s: %s)' % (msg, sa, serr) + if self.socket: + 
self.socket.close() + self.socket = None + + if not self.socket: + raise socket.error(msg) + + # Timeout so KeyboardInterrupt can be caught on Win32 + self.socket.settimeout(1) + self.socket.listen(self.request_queue_size) + + # Create worker threads + self.requests.start() + + self.ready = True + self._start_time = time.time() + + def serve(self): + """Serve requests, after invoking :func:`prepare()`.""" + while self.ready: + try: + self.tick() + except (KeyboardInterrupt, SystemExit): + raise + except Exception: + self.error_log( + 'Error in HTTPServer.tick', level=logging.ERROR, + traceback=True, + ) + + if self.interrupt: + while self.interrupt is True: + # Wait for self.stop() to complete. See _set_interrupt. + time.sleep(0.1) + if self.interrupt: + raise self.interrupt + + def start(self): + """Run the server forever. + + It is shortcut for invoking :func:`prepare()` then :func:`serve()`. + """ + # We don't have to trap KeyboardInterrupt or SystemExit here, + # because cherrypy.server already does so, calling self.stop() for us. + # If you're using this server with another framework, you should + # trap those exceptions in whatever code block calls start(). + self.prepare() + self.serve() + + def error_log(self, msg='', level=20, traceback=False): + """Write error message to log. 
+ + Args: + msg (str): error message + level (int): logging level + traceback (bool): add traceback to output or not + """ + # Override this in subclasses as desired + sys.stderr.write(msg + '\n') + sys.stderr.flush() + if traceback: + tblines = traceback_.format_exc() + sys.stderr.write(tblines) + sys.stderr.flush() + + def bind(self, family, type, proto=0): + """Create (or recreate) the actual socket object.""" + sock = self.prepare_socket( + self.bind_addr, + family, type, proto, + self.nodelay, self.ssl_adapter, + ) + sock = self.socket = self.bind_socket(sock, self.bind_addr) + self.bind_addr = self.resolve_real_bind_addr(sock) + return sock + + def bind_unix_socket(self, bind_addr): + """Create (or recreate) a UNIX socket object.""" + if IS_WINDOWS: + """ + Trying to access socket.AF_UNIX under Windows + causes an AttributeError. + """ + raise ValueError( # or RuntimeError? + 'AF_UNIX sockets are not supported under Windows.', + ) + + fs_permissions = 0o777 # TODO: allow changing mode + + try: + # Make possible reusing the socket... + os.unlink(self.bind_addr) + except OSError: + """ + File does not exist, which is the primary goal anyway. 
+ """ + except TypeError as typ_err: + err_msg = str(typ_err) + if ( + 'remove() argument 1 must be encoded ' + 'string without null bytes, not unicode' + not in err_msg + and 'embedded NUL character' not in err_msg # py34 + and 'argument must be a ' + 'string without NUL characters' not in err_msg # pypy2 + ): + raise + except ValueError as val_err: + err_msg = str(val_err) + if ( + 'unlink: embedded null ' + 'character in path' not in err_msg + and 'embedded null byte' not in err_msg + and 'argument must be a ' + 'string without NUL characters' not in err_msg # pypy3 + ): + raise + + sock = self.prepare_socket( + bind_addr=bind_addr, + family=socket.AF_UNIX, type=socket.SOCK_STREAM, proto=0, + nodelay=self.nodelay, ssl_adapter=self.ssl_adapter, + ) + + try: + """Linux way of pre-populating fs mode permissions.""" + # Allow everyone access the socket... + os.fchmod(sock.fileno(), fs_permissions) + FS_PERMS_SET = True + except OSError: + FS_PERMS_SET = False + + try: + sock = self.bind_socket(sock, bind_addr) + except socket.error: + sock.close() + raise + + bind_addr = self.resolve_real_bind_addr(sock) + + try: + """FreeBSD/macOS pre-populating fs mode permissions.""" + if not FS_PERMS_SET: + try: + os.lchmod(bind_addr, fs_permissions) + except AttributeError: + os.chmod(bind_addr, fs_permissions, follow_symlinks=False) + FS_PERMS_SET = True + except OSError: + pass + + if not FS_PERMS_SET: + self.error_log( + 'Failed to set socket fs mode permissions', + level=logging.WARNING, + ) + + self.bind_addr = bind_addr + self.socket = sock + return sock + + @staticmethod + def prepare_socket(bind_addr, family, type, proto, nodelay, ssl_adapter): + """Create and prepare the socket object.""" + sock = socket.socket(family, type, proto) + connections.prevent_socket_inheritance(sock) + + host, port = bind_addr[:2] + IS_EPHEMERAL_PORT = port == 0 + + if not (IS_WINDOWS or IS_EPHEMERAL_PORT): + """Enable SO_REUSEADDR for the current socket. 
+
+            Skip for Windows (has different semantics)
+            or ephemeral ports (can steal ports from others).
+
+            Refs:
+            * https://msdn.microsoft.com/en-us/library/ms740621(v=vs.85).aspx
+            * https://github.com/cherrypy/cheroot/issues/114
+            * https://gavv.github.io/blog/ephemeral-port-reuse/
+            """
+            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+        if nodelay and not isinstance(
+                bind_addr,
+                (six.text_type, six.binary_type),
+        ):
+            sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
+
+        if ssl_adapter is not None:
+            sock = ssl_adapter.bind(sock)
+
+        # If listening on the IPV6 any address ('::' = IN6ADDR_ANY),
+        # activate dual-stack. See
+        # https://github.com/cherrypy/cherrypy/issues/871.
+        listening_ipv6 = (
+            hasattr(socket, 'AF_INET6')
+            and family == socket.AF_INET6
+            and host in ('::', '::0', '::0.0.0.0')
+        )
+        if listening_ipv6:
+            try:
+                sock.setsockopt(
+                    socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0,
+                )
+            except (AttributeError, socket.error):
+                # Apparently, the socket option is not available in
+                # this machine's TCP stack
+                pass
+
+        return sock
+
+    @staticmethod
+    def bind_socket(socket_, bind_addr):
+        """Bind the socket to given interface."""
+        socket_.bind(bind_addr)
+        return socket_
+
+    @staticmethod
+    def resolve_real_bind_addr(socket_):
+        """Retrieve actual bind addr from bound socket."""
+        # FIXME: keep requested bind_addr separate from real bound_addr (port
+        # is different in case of ephemeral port 0)
+        bind_addr = socket_.getsockname()
+        if socket_.family in (
+            # Windows doesn't have socket.AF_UNIX, so not using it in check
+            socket.AF_INET,
+            socket.AF_INET6,
+        ):
+            """UNIX domain sockets are strings or bytes.
+
+            In case of bytes with a leading null-byte it's an abstract socket.
+ """ + return bind_addr[:2] + + if isinstance(bind_addr, six.binary_type): + bind_addr = bton(bind_addr) + + return bind_addr + + def tick(self): + """Accept a new connection and put it on the Queue.""" + if not self.ready: + return + + conn = self.connections.get_conn(self.socket) + if conn: + try: + self.requests.put(conn) + except queue.Full: + # Just drop the conn. TODO: write 503 back? + conn.close() + + self.connections.expire() + + @property + def interrupt(self): + """Flag interrupt of the server.""" + return self._interrupt + + @interrupt.setter + def interrupt(self, interrupt): + """Perform the shutdown of this server and save the exception.""" + self._interrupt = True + self.stop() + self._interrupt = interrupt + + def stop(self): + """Gracefully shutdown a server that is serving forever.""" + self.ready = False + if self._start_time is not None: + self._run_time += (time.time() - self._start_time) + self._start_time = None + + sock = getattr(self, 'socket', None) + if sock: + if not isinstance( + self.bind_addr, + (six.text_type, six.binary_type), + ): + # Touch our own socket to make accept() return immediately. + try: + host, port = sock.getsockname()[:2] + except socket.error as ex: + if ex.args[0] not in errors.socket_errors_to_ignore: + # Changed to use error code and not message + # See + # https://github.com/cherrypy/cherrypy/issues/860. + raise + else: + # Note that we're explicitly NOT using AI_PASSIVE, + # here, because we want an actual IP to touch. + # localhost won't work if we've bound to a public IP, + # but it will if we bound to '0.0.0.0' (INADDR_ANY). 
+ for res in socket.getaddrinfo( + host, port, socket.AF_UNSPEC, + socket.SOCK_STREAM, + ): + af, socktype, proto, canonname, sa = res + s = None + try: + s = socket.socket(af, socktype, proto) + # See + # https://groups.google.com/group/cherrypy-users/ + # browse_frm/thread/bbfe5eb39c904fe0 + s.settimeout(1.0) + s.connect((host, port)) + s.close() + except socket.error: + if s: + s.close() + if hasattr(sock, 'close'): + sock.close() + self.socket = None + + self.connections.close() + self.requests.stop(self.shutdown_timeout) + + +class Gateway: + """Base class to interface HTTPServer with other systems, such as WSGI.""" + + def __init__(self, req): + """Initialize Gateway instance with request. + + Args: + req (HTTPRequest): current HTTP request + """ + self.req = req + + def respond(self): + """Process the current request. Must be overridden in a subclass.""" + raise NotImplementedError # pragma: no cover + + +# These may either be ssl.Adapter subclasses or the string names +# of such classes (in which case they will be lazily loaded). +ssl_adapters = { + 'builtin': 'cheroot.ssl.builtin.BuiltinSSLAdapter', + 'pyopenssl': 'cheroot.ssl.pyopenssl.pyOpenSSLAdapter', +} + + +def get_ssl_adapter_class(name='builtin'): + """Return an SSL adapter class for the given name.""" + adapter = ssl_adapters[name.lower()] + if isinstance(adapter, six.string_types): + last_dot = adapter.rfind('.') + attr_name = adapter[last_dot + 1:] + mod_path = adapter[:last_dot] + + try: + mod = sys.modules[mod_path] + if mod is None: + raise KeyError() + except KeyError: + # The last [''] is important. + mod = __import__(mod_path, globals(), locals(), ['']) + + # Let an AttributeError propagate outward. 
+ try: + adapter = getattr(mod, attr_name) + except AttributeError: + raise AttributeError("'%s' object has no attribute '%s'" + % (mod_path, attr_name)) + + return adapter diff --git a/lib/cheroot/ssl/__init__.py b/lib/cheroot/ssl/__init__.py new file mode 100644 index 00000000..d45fd7f1 --- /dev/null +++ b/lib/cheroot/ssl/__init__.py @@ -0,0 +1,52 @@ +"""Implementation of the SSL adapter base interface.""" + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +from abc import ABCMeta, abstractmethod + +from six import add_metaclass + + +@add_metaclass(ABCMeta) +class Adapter: + """Base class for SSL driver library adapters. + + Required methods: + + * ``wrap(sock) -> (wrapped socket, ssl environ dict)`` + * ``makefile(sock, mode='r', bufsize=DEFAULT_BUFFER_SIZE) -> + socket file object`` + """ + + @abstractmethod + def __init__( + self, certificate, private_key, certificate_chain=None, + ciphers=None, + ): + """Set up certificates, private key ciphers and reset context.""" + self.certificate = certificate + self.private_key = private_key + self.certificate_chain = certificate_chain + self.ciphers = ciphers + self.context = None + + @abstractmethod + def bind(self, sock): + """Wrap and return the given socket.""" + return sock + + @abstractmethod + def wrap(self, sock): + """Wrap and return the given socket, plus WSGI environ entries.""" + raise NotImplementedError # pragma: no cover + + @abstractmethod + def get_environ(self): + """Return WSGI environ entries to be merged into each request.""" + raise NotImplementedError # pragma: no cover + + @abstractmethod + def makefile(self, sock, mode='r', bufsize=-1): + """Return socket file object.""" + raise NotImplementedError # pragma: no cover diff --git a/lib/cheroot/ssl/builtin.py b/lib/cheroot/ssl/builtin.py new file mode 100644 index 00000000..d131b2f4 --- /dev/null +++ b/lib/cheroot/ssl/builtin.py @@ -0,0 +1,210 @@ +""" +A library for integrating Python's builtin ``ssl`` 
library with Cheroot. + +The ssl module must be importable for SSL functionality. + +To use this module, set ``HTTPServer.ssl_adapter`` to an instance of +``BuiltinSSLAdapter``. +""" + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +try: + import ssl +except ImportError: + ssl = None + +try: + from _pyio import DEFAULT_BUFFER_SIZE +except ImportError: + try: + from io import DEFAULT_BUFFER_SIZE + except ImportError: + DEFAULT_BUFFER_SIZE = -1 + +import six + +from . import Adapter +from .. import errors +from .._compat import IS_ABOVE_OPENSSL10 +from ..makefile import StreamReader, StreamWriter + +if six.PY2: + import socket + generic_socket_error = socket.error + del socket +else: + generic_socket_error = OSError + + +def _assert_ssl_exc_contains(exc, *msgs): + """Check whether SSL exception contains either of messages provided.""" + if len(msgs) < 1: + raise TypeError( + '_assert_ssl_exc_contains() requires ' + 'at least one message to be passed.', + ) + err_msg_lower = str(exc).lower() + return any(m.lower() in err_msg_lower for m in msgs) + + +class BuiltinSSLAdapter(Adapter): + """A wrapper for integrating Python's builtin ssl module with Cheroot.""" + + certificate = None + """The filename of the server SSL certificate.""" + + private_key = None + """The filename of the server's private key file.""" + + certificate_chain = None + """The filename of the certificate chain file.""" + + context = None + """The ssl.SSLContext that will be used to wrap sockets.""" + + ciphers = None + """The ciphers list of SSL.""" + + CERT_KEY_TO_ENV = { + 'subject': 'SSL_CLIENT_S_DN', + 'issuer': 'SSL_CLIENT_I_DN', + } + + CERT_KEY_TO_LDAP_CODE = { + 'countryName': 'C', + 'stateOrProvinceName': 'ST', + 'localityName': 'L', + 'organizationName': 'O', + 'organizationalUnitName': 'OU', + 'commonName': 'CN', + 'emailAddress': 'Email', + } + + def __init__( + self, certificate, private_key, certificate_chain=None, + ciphers=None, + ): + 
"""Set up context in addition to base class properties if available.""" + if ssl is None: + raise ImportError('You must install the ssl module to use HTTPS.') + + super(BuiltinSSLAdapter, self).__init__( + certificate, private_key, certificate_chain, ciphers, + ) + + self.context = ssl.create_default_context( + purpose=ssl.Purpose.CLIENT_AUTH, + cafile=certificate_chain, + ) + self.context.load_cert_chain(certificate, private_key) + if self.ciphers is not None: + self.context.set_ciphers(ciphers) + + def bind(self, sock): + """Wrap and return the given socket.""" + return super(BuiltinSSLAdapter, self).bind(sock) + + def wrap(self, sock): + """Wrap and return the given socket, plus WSGI environ entries.""" + EMPTY_RESULT = None, {} + try: + s = self.context.wrap_socket( + sock, do_handshake_on_connect=True, server_side=True, + ) + except ssl.SSLError as ex: + if ex.errno == ssl.SSL_ERROR_EOF: + # This is almost certainly due to the cherrypy engine + # 'pinging' the socket to assert it's connectable; + # the 'ping' isn't SSL. + return EMPTY_RESULT + elif ex.errno == ssl.SSL_ERROR_SSL: + if _assert_ssl_exc_contains(ex, 'http request'): + # The client is speaking HTTP to an HTTPS server. + raise errors.NoSSLError + + # Check if it's one of the known errors + # Errors that are caught by PyOpenSSL, but thrown by + # built-in ssl + _block_errors = ( + 'unknown protocol', 'unknown ca', 'unknown_ca', + 'unknown error', + 'https proxy request', 'inappropriate fallback', + 'wrong version number', + 'no shared cipher', 'certificate unknown', + 'ccs received early', + 'certificate verify failed', # client cert w/o trusted CA + ) + if _assert_ssl_exc_contains(ex, *_block_errors): + # Accepted error, let's pass + return EMPTY_RESULT + elif _assert_ssl_exc_contains(ex, 'handshake operation timed out'): + # This error is thrown by builtin SSL after a timeout + # when client is speaking HTTP to an HTTPS server. + # The connection can safely be dropped. 
+ return EMPTY_RESULT + raise + except generic_socket_error as exc: + """It is unclear why exactly this happens. + + It's reproducible only with openssl>1.0 and stdlib ``ssl`` wrapper. + In CherryPy it's triggered by Checker plugin, which connects + to the app listening to the socket port in TLS mode via plain + HTTP during startup (from the same process). + + + Ref: https://github.com/cherrypy/cherrypy/issues/1618 + """ + is_error0 = exc.args == (0, 'Error') + + if is_error0 and IS_ABOVE_OPENSSL10: + return EMPTY_RESULT + raise + return s, self.get_environ(s) + + # TODO: fill this out more with mod ssl env + def get_environ(self, sock): + """Create WSGI environ entries to be merged into each request.""" + cipher = sock.cipher() + ssl_environ = { + 'wsgi.url_scheme': 'https', + 'HTTPS': 'on', + 'SSL_PROTOCOL': cipher[1], + 'SSL_CIPHER': cipher[0], + # SSL_VERSION_INTERFACE string The mod_ssl program version + # SSL_VERSION_LIBRARY string The OpenSSL program version + } + + if self.context and self.context.verify_mode != ssl.CERT_NONE: + client_cert = sock.getpeercert() + if client_cert: + for cert_key, env_var in self.CERT_KEY_TO_ENV.items(): + ssl_environ.update( + self.env_dn_dict(env_var, client_cert.get(cert_key)), + ) + + return ssl_environ + + def env_dn_dict(self, env_prefix, cert_value): + """Return a dict of WSGI environment variables for a client cert DN. + + E.g. SSL_CLIENT_S_DN_CN, SSL_CLIENT_S_DN_C, etc. + See SSL_CLIENT_S_DN_x509 at + https://httpd.apache.org/docs/2.4/mod/mod_ssl.html#envvars. 
+ """ + if not cert_value: + return {} + + env = {} + for rdn in cert_value: + for attr_name, val in rdn: + attr_code = self.CERT_KEY_TO_LDAP_CODE.get(attr_name) + if attr_code: + env['%s_%s' % (env_prefix, attr_code)] = val + return env + + def makefile(self, sock, mode='r', bufsize=DEFAULT_BUFFER_SIZE): + """Return socket file object.""" + cls = StreamReader if 'r' in mode else StreamWriter + return cls(sock, mode, bufsize) diff --git a/lib/cheroot/ssl/pyopenssl.py b/lib/cheroot/ssl/pyopenssl.py new file mode 100644 index 00000000..f51be0d8 --- /dev/null +++ b/lib/cheroot/ssl/pyopenssl.py @@ -0,0 +1,343 @@ +""" +A library for integrating pyOpenSSL with Cheroot. + +The OpenSSL module must be importable for SSL functionality. +You can obtain it from `here `_. + +To use this module, set HTTPServer.ssl_adapter to an instance of +ssl.Adapter. There are two ways to use SSL: + +Method One +---------- + + * ``ssl_adapter.context``: an instance of SSL.Context. + +If this is not None, it is assumed to be an SSL.Context instance, +and will be passed to SSL.Connection on bind(). The developer is +responsible for forming a valid Context object. This approach is +to be preferred for more flexibility, e.g. if the cert and key are +streams instead of files, or need decryption, or SSL.SSLv3_METHOD +is desired instead of the default SSL.SSLv23_METHOD, etc. Consult +the pyOpenSSL documentation for complete options. + +Method Two (shortcut) +--------------------- + + * ``ssl_adapter.certificate``: the filename of the server SSL certificate. + * ``ssl_adapter.private_key``: the filename of the server's private key file. + +Both are None by default. If ssl_adapter.context is None, but .private_key +and .certificate are both given and valid, they will be read, and the +context will be automatically created from them. 
+""" + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +import socket +import threading +import time + +import six + +try: + from OpenSSL import SSL + from OpenSSL import crypto + + try: + ssl_conn_type = SSL.Connection + except AttributeError: + ssl_conn_type = SSL.ConnectionType +except ImportError: + SSL = None + +from . import Adapter +from .. import errors, server as cheroot_server +from ..makefile import StreamReader, StreamWriter + + +class SSLFileobjectMixin: + """Base mixin for an SSL socket stream.""" + + ssl_timeout = 3 + ssl_retry = .01 + + def _safe_call(self, is_reader, call, *args, **kwargs): + """Wrap the given call with SSL error-trapping. + + is_reader: if False EOF errors will be raised. If True, EOF errors + will return "" (to emulate normal sockets). + """ + start = time.time() + while True: + try: + return call(*args, **kwargs) + except SSL.WantReadError: + # Sleep and try again. This is dangerous, because it means + # the rest of the stack has no way of differentiating + # between a "new handshake" error and "client dropped". + # Note this isn't an endless loop: there's a timeout below. + # Ref: https://stackoverflow.com/a/5133568/595220 + time.sleep(self.ssl_retry) + except SSL.WantWriteError: + time.sleep(self.ssl_retry) + except SSL.SysCallError as e: + if is_reader and e.args == (-1, 'Unexpected EOF'): + return b'' + + errnum = e.args[0] + if is_reader and errnum in errors.socket_errors_to_ignore: + return b'' + raise socket.error(errnum) + except SSL.Error as e: + if is_reader and e.args == (-1, 'Unexpected EOF'): + return b'' + + thirdarg = None + try: + thirdarg = e.args[0][0][2] + except IndexError: + pass + + if thirdarg == 'http request': + # The client is talking HTTP to an HTTPS server. 
+ raise errors.NoSSLError() + + raise errors.FatalSSLAlert(*e.args) + + if time.time() - start > self.ssl_timeout: + raise socket.timeout('timed out') + + def recv(self, size): + """Receive message of a size from the socket.""" + return self._safe_call( + True, + super(SSLFileobjectMixin, self).recv, + size, + ) + + def readline(self, size=-1): + """Receive message of a size from the socket. + + Matches the following interface: + https://docs.python.org/3/library/io.html#io.IOBase.readline + """ + return self._safe_call( + True, + super(SSLFileobjectMixin, self).readline, + size, + ) + + def sendall(self, *args, **kwargs): + """Send whole message to the socket.""" + return self._safe_call( + False, + super(SSLFileobjectMixin, self).sendall, + *args, **kwargs + ) + + def send(self, *args, **kwargs): + """Send some part of message to the socket.""" + return self._safe_call( + False, + super(SSLFileobjectMixin, self).send, + *args, **kwargs + ) + + +class SSLFileobjectStreamReader(SSLFileobjectMixin, StreamReader): + """SSL file object attached to a socket object.""" + + +class SSLFileobjectStreamWriter(SSLFileobjectMixin, StreamWriter): + """SSL file object attached to a socket object.""" + + +class SSLConnectionProxyMeta: + """Metaclass for generating a bunch of proxy methods.""" + + def __new__(mcl, name, bases, nmspc): + """Attach a list of proxy methods to a new class.""" + proxy_methods = ( + 'get_context', 'pending', 'send', 'write', 'recv', 'read', + 'renegotiate', 'bind', 'listen', 'connect', 'accept', + 'setblocking', 'fileno', 'close', 'get_cipher_list', + 'getpeername', 'getsockname', 'getsockopt', 'setsockopt', + 'makefile', 'get_app_data', 'set_app_data', 'state_string', + 'sock_shutdown', 'get_peer_certificate', 'want_read', + 'want_write', 'set_connect_state', 'set_accept_state', + 'connect_ex', 'sendall', 'settimeout', 'gettimeout', + 'shutdown', + ) + proxy_methods_no_args = ( + 'shutdown', + ) + + proxy_props = ( + 'family', + ) + + def 
lock_decorator(method): + """Create a proxy method for a new class.""" + def proxy_wrapper(self, *args): + self._lock.acquire() + try: + new_args = ( + args[:] if method not in proxy_methods_no_args else [] + ) + return getattr(self._ssl_conn, method)(*new_args) + finally: + self._lock.release() + return proxy_wrapper + for m in proxy_methods: + nmspc[m] = lock_decorator(m) + nmspc[m].__name__ = m + + def make_property(property_): + """Create a proxy method for a new class.""" + def proxy_prop_wrapper(self): + return getattr(self._ssl_conn, property_) + proxy_prop_wrapper.__name__ = property_ + return property(proxy_prop_wrapper) + for p in proxy_props: + nmspc[p] = make_property(p) + + # Doesn't work via super() for some reason. + # Falling back to type() instead: + return type(name, bases, nmspc) + + +@six.add_metaclass(SSLConnectionProxyMeta) +class SSLConnection: + """A thread-safe wrapper for an SSL.Connection. + + ``*args``: the arguments to create the wrapped ``SSL.Connection(*args)``. + """ + + def __init__(self, *args): + """Initialize SSLConnection instance.""" + self._ssl_conn = SSL.Connection(*args) + self._lock = threading.RLock() + + +class pyOpenSSLAdapter(Adapter): + """A wrapper for integrating pyOpenSSL with Cheroot.""" + + certificate = None + """The filename of the server SSL certificate.""" + + private_key = None + """The filename of the server's private key file.""" + + certificate_chain = None + """Optional. The filename of CA's intermediate certificate bundle. 
+ + This is needed for cheaper "chained root" SSL certificates, and should be + left as None if not required.""" + + context = None + """An instance of SSL.Context.""" + + ciphers = None + """The ciphers list of SSL.""" + + def __init__( + self, certificate, private_key, certificate_chain=None, + ciphers=None, + ): + """Initialize OpenSSL Adapter instance.""" + if SSL is None: + raise ImportError('You must install pyOpenSSL to use HTTPS.') + + super(pyOpenSSLAdapter, self).__init__( + certificate, private_key, certificate_chain, ciphers, + ) + + self._environ = None + + def bind(self, sock): + """Wrap and return the given socket.""" + if self.context is None: + self.context = self.get_context() + conn = SSLConnection(self.context, sock) + self._environ = self.get_environ() + return conn + + def wrap(self, sock): + """Wrap and return the given socket, plus WSGI environ entries.""" + return sock, self._environ.copy() + + def get_context(self): + """Return an SSL.Context from self attributes.""" + # See https://code.activestate.com/recipes/442473/ + c = SSL.Context(SSL.SSLv23_METHOD) + c.use_privatekey_file(self.private_key) + if self.certificate_chain: + c.load_verify_locations(self.certificate_chain) + c.use_certificate_file(self.certificate) + return c + + def get_environ(self): + """Return WSGI environ entries to be merged into each request.""" + ssl_environ = { + 'HTTPS': 'on', + # pyOpenSSL doesn't provide access to any of these AFAICT + # 'SSL_PROTOCOL': 'SSLv2', + # SSL_CIPHER string The cipher specification name + # SSL_VERSION_INTERFACE string The mod_ssl program version + # SSL_VERSION_LIBRARY string The OpenSSL program version + } + + if self.certificate: + # Server certificate attributes + cert = open(self.certificate, 'rb').read() + cert = crypto.load_certificate(crypto.FILETYPE_PEM, cert) + ssl_environ.update({ + 'SSL_SERVER_M_VERSION': cert.get_version(), + 'SSL_SERVER_M_SERIAL': cert.get_serial_number(), + # 'SSL_SERVER_V_START': + # Validity of 
server's certificate (start time), + # 'SSL_SERVER_V_END': + # Validity of server's certificate (end time), + }) + + for prefix, dn in [ + ('I', cert.get_issuer()), + ('S', cert.get_subject()), + ]: + # X509Name objects don't seem to have a way to get the + # complete DN string. Use str() and slice it instead, + # because str(dn) == "" + dnstr = str(dn)[18:-2] + + wsgikey = 'SSL_SERVER_%s_DN' % prefix + ssl_environ[wsgikey] = dnstr + + # The DN should be of the form: /k1=v1/k2=v2, but we must allow + # for any value to contain slashes itself (in a URL). + while dnstr: + pos = dnstr.rfind('=') + dnstr, value = dnstr[:pos], dnstr[pos + 1:] + pos = dnstr.rfind('/') + dnstr, key = dnstr[:pos], dnstr[pos + 1:] + if key and value: + wsgikey = 'SSL_SERVER_%s_DN_%s' % (prefix, key) + ssl_environ[wsgikey] = value + + return ssl_environ + + def makefile(self, sock, mode='r', bufsize=-1): + """Return socket file object.""" + cls = ( + SSLFileobjectStreamReader + if 'r' in mode else + SSLFileobjectStreamWriter + ) + if SSL and isinstance(sock, ssl_conn_type): + wrapped_socket = cls(sock, mode, bufsize) + wrapped_socket.ssl_timeout = sock.gettimeout() + return wrapped_socket + # This is from past: + # TODO: figure out what it's meant for + else: + return cheroot_server.CP_fileobject(sock, mode, bufsize) diff --git a/lib/cheroot/test/__init__.py b/lib/cheroot/test/__init__.py new file mode 100644 index 00000000..e2a7b348 --- /dev/null +++ b/lib/cheroot/test/__init__.py @@ -0,0 +1 @@ +"""Cheroot test suite.""" diff --git a/lib/cheroot/test/conftest.py b/lib/cheroot/test/conftest.py new file mode 100644 index 00000000..b9c8bad4 --- /dev/null +++ b/lib/cheroot/test/conftest.py @@ -0,0 +1,69 @@ +"""Pytest configuration module. + +Contains fixtures, which are tightly bound to the Cheroot framework +itself, useless for end-users' app testing. 
+""" + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +import threading +import time + +import pytest + +from ..server import Gateway, HTTPServer +from ..testing import ( # noqa: F401 + native_server, wsgi_server, +) +from ..testing import get_server_client + + +@pytest.fixture +def wsgi_server_client(wsgi_server): # noqa: F811 + """Create a test client out of given WSGI server.""" + return get_server_client(wsgi_server) + + +@pytest.fixture +def native_server_client(native_server): # noqa: F811 + """Create a test client out of given HTTP server.""" + return get_server_client(native_server) + + +@pytest.fixture +def http_server(): + """Provision a server creator as a fixture.""" + def start_srv(): + bind_addr = yield + if bind_addr is None: + return + httpserver = make_http_server(bind_addr) + yield httpserver + yield httpserver + + srv_creator = iter(start_srv()) + next(srv_creator) + yield srv_creator + try: + while True: + httpserver = next(srv_creator) + if httpserver is not None: + httpserver.stop() + except StopIteration: + pass + + +def make_http_server(bind_addr): + """Create and start an HTTP server bound to bind_addr.""" + httpserver = HTTPServer( + bind_addr=bind_addr, + gateway=Gateway, + ) + + threading.Thread(target=httpserver.safe_start).start() + + while not httpserver.ready: + time.sleep(0.1) + + return httpserver diff --git a/lib/cheroot/test/helper.py b/lib/cheroot/test/helper.py new file mode 100644 index 00000000..0243ac86 --- /dev/null +++ b/lib/cheroot/test/helper.py @@ -0,0 +1,168 @@ +"""A library of helper functions for the Cheroot test suite.""" + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +import datetime +import logging +import os +import sys +import time +import threading +import types + +from six.moves import http_client + +import six + +import cheroot.server +import cheroot.wsgi + +from cheroot.test import webtest + +log = 
logging.getLogger(__name__) +thisdir = os.path.abspath(os.path.dirname(__file__)) + + +config = { + 'bind_addr': ('127.0.0.1', 54583), + 'server': 'wsgi', + 'wsgi_app': None, +} + + +class CherootWebCase(webtest.WebCase): + """Helper class for a web app test suite.""" + + script_name = '' + scheme = 'http' + + available_servers = { + 'wsgi': cheroot.wsgi.Server, + 'native': cheroot.server.HTTPServer, + } + + @classmethod + def setup_class(cls): + """Create and run one HTTP server per class.""" + conf = config.copy() + conf.update(getattr(cls, 'config', {})) + + s_class = conf.pop('server', 'wsgi') + server_factory = cls.available_servers.get(s_class) + if server_factory is None: + raise RuntimeError('Unknown server in config: %s' % conf['server']) + cls.httpserver = server_factory(**conf) + + cls.HOST, cls.PORT = cls.httpserver.bind_addr + if cls.httpserver.ssl_adapter is None: + ssl = '' + cls.scheme = 'http' + else: + ssl = ' (ssl)' + cls.HTTP_CONN = http_client.HTTPSConnection + cls.scheme = 'https' + + v = sys.version.split()[0] + log.info('Python version used to run this test script: %s' % v) + log.info('Cheroot version: %s' % cheroot.__version__) + log.info('HTTP server version: %s%s' % (cls.httpserver.protocol, ssl)) + log.info('PID: %s' % os.getpid()) + + if hasattr(cls, 'setup_server'): + # Clear the wsgi server so that + # it can be updated with the new root + cls.setup_server() + cls.start() + + @classmethod + def teardown_class(cls): + """Cleanup HTTP server.""" + if hasattr(cls, 'setup_server'): + cls.stop() + + @classmethod + def start(cls): + """Load and start the HTTP server.""" + threading.Thread(target=cls.httpserver.safe_start).start() + while not cls.httpserver.ready: + time.sleep(0.1) + + @classmethod + def stop(cls): + """Terminate HTTP server.""" + cls.httpserver.stop() + td = getattr(cls, 'teardown', None) + if td: + td() + + date_tolerance = 2 + + def assertEqualDates(self, dt1, dt2, seconds=None): + """Assert abs(dt1 - dt2) is within Y 
seconds.""" + if seconds is None: + seconds = self.date_tolerance + + if dt1 > dt2: + diff = dt1 - dt2 + else: + diff = dt2 - dt1 + if not diff < datetime.timedelta(seconds=seconds): + raise AssertionError('%r and %r are not within %r seconds.' % + (dt1, dt2, seconds)) + + +class Request: + """HTTP request container.""" + + def __init__(self, environ): + """Initialize HTTP request.""" + self.environ = environ + + +class Response: + """HTTP response container.""" + + def __init__(self): + """Initialize HTTP response.""" + self.status = '200 OK' + self.headers = {'Content-Type': 'text/html'} + self.body = None + + def output(self): + """Generate iterable response body object.""" + if self.body is None: + return [] + elif isinstance(self.body, six.text_type): + return [self.body.encode('iso-8859-1')] + elif isinstance(self.body, six.binary_type): + return [self.body] + else: + return [x.encode('iso-8859-1') for x in self.body] + + +class Controller: + """WSGI app for tests.""" + + def __call__(self, environ, start_response): + """WSGI request handler.""" + req, resp = Request(environ), Response() + try: + # Python 3 supports unicode attribute names + # Python 2 encodes them + handler = self.handlers[environ['PATH_INFO']] + except KeyError: + resp.status = '404 Not Found' + else: + output = handler(req, resp) + if (output is not None + and not any(resp.status.startswith(status_code) + for status_code in ('204', '304'))): + resp.body = output + try: + resp.headers.setdefault('Content-Length', str(len(output))) + except TypeError: + if not isinstance(output, types.GeneratorType): + raise + start_response(resp.status, resp.headers.items()) + return resp.output() diff --git a/lib/cheroot/test/test__compat.py b/lib/cheroot/test/test__compat.py new file mode 100644 index 00000000..c03e5463 --- /dev/null +++ b/lib/cheroot/test/test__compat.py @@ -0,0 +1,62 @@ +# -*- coding: utf-8 -*- +"""Test suite for cross-python compatibility helpers.""" + +from __future__ import 
absolute_import, division, print_function +__metaclass__ = type + +import pytest +import six + +from cheroot._compat import extract_bytes, memoryview, ntob, ntou, bton + + +@pytest.mark.parametrize( + 'func,inp,out', + [ + (ntob, 'bar', b'bar'), + (ntou, 'bar', u'bar'), + (bton, b'bar', 'bar'), + ], +) +def test_compat_functions_positive(func, inp, out): + """Check that compat functions work with correct input.""" + assert func(inp, encoding='utf-8') == out + + +@pytest.mark.parametrize( + 'func', + [ + ntob, + ntou, + ], +) +def test_compat_functions_negative_nonnative(func): + """Check that compat functions fail loudly for incorrect input.""" + non_native_test_str = u'bar' if six.PY2 else b'bar' + with pytest.raises(TypeError): + func(non_native_test_str, encoding='utf-8') + + +def test_ntou_escape(): + """Check that ntou supports escape-encoding under Python 2.""" + expected = u'hišřії' + actual = ntou('hi\u0161\u0159\u0456\u0457', encoding='escape') + assert actual == expected + + +@pytest.mark.parametrize( + 'input_argument,expected_result', + [ + (b'qwerty', b'qwerty'), + (memoryview(b'asdfgh'), b'asdfgh'), + ], +) +def test_extract_bytes(input_argument, expected_result): + """Check that legitimate inputs produce bytes.""" + assert extract_bytes(input_argument) == expected_result + + +def test_extract_bytes_invalid(): + """Ensure that invalid input causes exception to be raised.""" + with pytest.raises(ValueError): + extract_bytes(u'some юнікод їїї') diff --git a/lib/cheroot/test/test_conn.py b/lib/cheroot/test/test_conn.py new file mode 100644 index 00000000..b26ffc96 --- /dev/null +++ b/lib/cheroot/test/test_conn.py @@ -0,0 +1,980 @@ +"""Tests for TCP connection handling, including proper and timely close.""" + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +import socket +import time + +from six.moves import range, http_client, urllib + +import six +import pytest + +from cheroot.test import helper, webtest + + 
+timeout = 1 +pov = 'pPeErRsSiIsStTeEnNcCeE oOfF vViIsSiIoOnN' + + +class Controller(helper.Controller): + """Controller for serving WSGI apps.""" + + def hello(req, resp): + """Render Hello world.""" + return 'Hello, world!' + + def pov(req, resp): + """Render pov value.""" + return pov + + def stream(req, resp): + """Render streaming response.""" + if 'set_cl' in req.environ['QUERY_STRING']: + resp.headers['Content-Length'] = str(10) + + def content(): + for x in range(10): + yield str(x) + + return content() + + def upload(req, resp): + """Process file upload and render thank.""" + if not req.environ['REQUEST_METHOD'] == 'POST': + raise AssertionError("'POST' != request.method %r" % + req.environ['REQUEST_METHOD']) + return "thanks for '%s'" % req.environ['wsgi.input'].read() + + def custom_204(req, resp): + """Render response with status 204.""" + resp.status = '204' + return 'Code = 204' + + def custom_304(req, resp): + """Render response with status 304.""" + resp.status = '304' + return 'Code = 304' + + def err_before_read(req, resp): + """Render response with status 500.""" + resp.status = '500 Internal Server Error' + return 'ok' + + def one_megabyte_of_a(req, resp): + """Render 1MB response.""" + return ['a' * 1024] * 1024 + + def wrong_cl_buffered(req, resp): + """Render buffered response with invalid length value.""" + resp.headers['Content-Length'] = '5' + return 'I have too many bytes' + + def wrong_cl_unbuffered(req, resp): + """Render unbuffered response with invalid length value.""" + resp.headers['Content-Length'] = '5' + return ['I too', ' have too many bytes'] + + def _munge(string): + """Encode PATH_INFO correctly depending on Python version. + + WSGI 1.0 is a mess around unicode. Create endpoints + that match the PATH_INFO that it produces. 
+ """ + if six.PY2: + return string + return string.encode('utf-8').decode('latin-1') + + handlers = { + '/hello': hello, + '/pov': pov, + '/page1': pov, + '/page2': pov, + '/page3': pov, + '/stream': stream, + '/upload': upload, + '/custom/204': custom_204, + '/custom/304': custom_304, + '/err_before_read': err_before_read, + '/one_megabyte_of_a': one_megabyte_of_a, + '/wrong_cl_buffered': wrong_cl_buffered, + '/wrong_cl_unbuffered': wrong_cl_unbuffered, + } + + +@pytest.fixture +def testing_server(wsgi_server_client): + """Attach a WSGI app to the given server and pre-configure it.""" + app = Controller() + + def _timeout(req, resp): + return str(wsgi_server.timeout) + app.handlers['/timeout'] = _timeout + wsgi_server = wsgi_server_client.server_instance + wsgi_server.wsgi_app = app + wsgi_server.max_request_body_size = 1001 + wsgi_server.timeout = timeout + wsgi_server.server_client = wsgi_server_client + wsgi_server.keep_alive_conn_limit = 2 + return wsgi_server + + +@pytest.fixture +def test_client(testing_server): + """Get and return a test client out of the given server.""" + return testing_server.server_client + + +def header_exists(header_name, headers): + """Check that a header is present.""" + return header_name.lower() in (k.lower() for (k, _) in headers) + + +def header_has_value(header_name, header_value, headers): + """Check that a header with a given value is present.""" + return header_name.lower() in ( + k.lower() for (k, v) in headers + if v == header_value + ) + + +def test_HTTP11_persistent_connections(test_client): + """Test persistent HTTP/1.1 connections.""" + # Initialize a persistent HTTP connection + http_connection = test_client.get_connection() + http_connection.auto_open = False + http_connection.connect() + + # Make the first request and assert there's no "Connection: close". 
+ status_line, actual_headers, actual_resp_body = test_client.get( + '/pov', http_conn=http_connection, + ) + actual_status = int(status_line[:3]) + assert actual_status == 200 + assert status_line[4:] == 'OK' + assert actual_resp_body == pov.encode() + assert not header_exists('Connection', actual_headers) + + # Make another request on the same connection. + status_line, actual_headers, actual_resp_body = test_client.get( + '/page1', http_conn=http_connection, + ) + actual_status = int(status_line[:3]) + assert actual_status == 200 + assert status_line[4:] == 'OK' + assert actual_resp_body == pov.encode() + assert not header_exists('Connection', actual_headers) + + # Test client-side close. + status_line, actual_headers, actual_resp_body = test_client.get( + '/page2', http_conn=http_connection, + headers=[('Connection', 'close')], + ) + actual_status = int(status_line[:3]) + assert actual_status == 200 + assert status_line[4:] == 'OK' + assert actual_resp_body == pov.encode() + assert header_has_value('Connection', 'close', actual_headers) + + # Make another request on the same connection, which should error. + with pytest.raises(http_client.NotConnected): + test_client.get('/pov', http_conn=http_connection) + + +@pytest.mark.parametrize( + 'set_cl', + ( + False, # Without Content-Length + True, # With Content-Length + ), +) +def test_streaming_11(test_client, set_cl): + """Test serving of streaming responses with HTTP/1.1 protocol.""" + # Initialize a persistent HTTP connection + http_connection = test_client.get_connection() + http_connection.auto_open = False + http_connection.connect() + + # Make the first request and assert there's no "Connection: close". 
+ status_line, actual_headers, actual_resp_body = test_client.get( + '/pov', http_conn=http_connection, + ) + actual_status = int(status_line[:3]) + assert actual_status == 200 + assert status_line[4:] == 'OK' + assert actual_resp_body == pov.encode() + assert not header_exists('Connection', actual_headers) + + # Make another, streamed request on the same connection. + if set_cl: + # When a Content-Length is provided, the content should stream + # without closing the connection. + status_line, actual_headers, actual_resp_body = test_client.get( + '/stream?set_cl=Yes', http_conn=http_connection, + ) + assert header_exists('Content-Length', actual_headers) + assert not header_has_value('Connection', 'close', actual_headers) + assert not header_exists('Transfer-Encoding', actual_headers) + + assert actual_status == 200 + assert status_line[4:] == 'OK' + assert actual_resp_body == b'0123456789' + else: + # When no Content-Length response header is provided, + # streamed output will either close the connection, or use + # chunked encoding, to determine transfer-length. + status_line, actual_headers, actual_resp_body = test_client.get( + '/stream', http_conn=http_connection, + ) + assert not header_exists('Content-Length', actual_headers) + assert actual_status == 200 + assert status_line[4:] == 'OK' + assert actual_resp_body == b'0123456789' + + chunked_response = False + for k, v in actual_headers: + if k.lower() == 'transfer-encoding': + if str(v) == 'chunked': + chunked_response = True + + if chunked_response: + assert not header_has_value('Connection', 'close', actual_headers) + else: + assert header_has_value('Connection', 'close', actual_headers) + + # Make another request on the same connection, which should + # error. + with pytest.raises(http_client.NotConnected): + test_client.get('/pov', http_conn=http_connection) + + # Try HEAD. + # See https://www.bitbucket.org/cherrypy/cherrypy/issue/864. 
+ # TODO: figure out how can this be possible on an closed connection + # (chunked_response case) + status_line, actual_headers, actual_resp_body = test_client.head( + '/stream', http_conn=http_connection, + ) + assert actual_status == 200 + assert status_line[4:] == 'OK' + assert actual_resp_body == b'' + assert not header_exists('Transfer-Encoding', actual_headers) + + +@pytest.mark.parametrize( + 'set_cl', + ( + False, # Without Content-Length + True, # With Content-Length + ), +) +def test_streaming_10(test_client, set_cl): + """Test serving of streaming responses with HTTP/1.0 protocol.""" + original_server_protocol = test_client.server_instance.protocol + test_client.server_instance.protocol = 'HTTP/1.0' + + # Initialize a persistent HTTP connection + http_connection = test_client.get_connection() + http_connection.auto_open = False + http_connection.connect() + + # Make the first request and assert Keep-Alive. + status_line, actual_headers, actual_resp_body = test_client.get( + '/pov', http_conn=http_connection, + headers=[('Connection', 'Keep-Alive')], + protocol='HTTP/1.0', + ) + actual_status = int(status_line[:3]) + assert actual_status == 200 + assert status_line[4:] == 'OK' + assert actual_resp_body == pov.encode() + assert header_has_value('Connection', 'Keep-Alive', actual_headers) + + # Make another, streamed request on the same connection. + if set_cl: + # When a Content-Length is provided, the content should + # stream without closing the connection. 
+ status_line, actual_headers, actual_resp_body = test_client.get( + '/stream?set_cl=Yes', http_conn=http_connection, + headers=[('Connection', 'Keep-Alive')], + protocol='HTTP/1.0', + ) + actual_status = int(status_line[:3]) + assert actual_status == 200 + assert status_line[4:] == 'OK' + assert actual_resp_body == b'0123456789' + + assert header_exists('Content-Length', actual_headers) + assert header_has_value('Connection', 'Keep-Alive', actual_headers) + assert not header_exists('Transfer-Encoding', actual_headers) + else: + # When a Content-Length is not provided, + # the server should close the connection. + status_line, actual_headers, actual_resp_body = test_client.get( + '/stream', http_conn=http_connection, + headers=[('Connection', 'Keep-Alive')], + protocol='HTTP/1.0', + ) + actual_status = int(status_line[:3]) + assert actual_status == 200 + assert status_line[4:] == 'OK' + assert actual_resp_body == b'0123456789' + + assert not header_exists('Content-Length', actual_headers) + assert not header_has_value('Connection', 'Keep-Alive', actual_headers) + assert not header_exists('Transfer-Encoding', actual_headers) + + # Make another request on the same connection, which should error. + with pytest.raises(http_client.NotConnected): + test_client.get( + '/pov', http_conn=http_connection, + protocol='HTTP/1.0', + ) + + test_client.server_instance.protocol = original_server_protocol + + +@pytest.mark.parametrize( + 'http_server_protocol', + ( + 'HTTP/1.0', + 'HTTP/1.1', + ), +) +def test_keepalive(test_client, http_server_protocol): + """Test Keep-Alive enabled connections.""" + original_server_protocol = test_client.server_instance.protocol + test_client.server_instance.protocol = http_server_protocol + + http_client_protocol = 'HTTP/1.0' + + # Initialize a persistent HTTP connection + http_connection = test_client.get_connection() + http_connection.auto_open = False + http_connection.connect() + + # Test a normal HTTP/1.0 request. 
+ status_line, actual_headers, actual_resp_body = test_client.get( + '/page2', + protocol=http_client_protocol, + ) + actual_status = int(status_line[:3]) + assert actual_status == 200 + assert status_line[4:] == 'OK' + assert actual_resp_body == pov.encode() + assert not header_exists('Connection', actual_headers) + + # Test a keep-alive HTTP/1.0 request. + + status_line, actual_headers, actual_resp_body = test_client.get( + '/page3', headers=[('Connection', 'Keep-Alive')], + http_conn=http_connection, protocol=http_client_protocol, + ) + actual_status = int(status_line[:3]) + assert actual_status == 200 + assert status_line[4:] == 'OK' + assert actual_resp_body == pov.encode() + assert header_has_value('Connection', 'Keep-Alive', actual_headers) + + # Remove the keep-alive header again. + status_line, actual_headers, actual_resp_body = test_client.get( + '/page3', http_conn=http_connection, + protocol=http_client_protocol, + ) + actual_status = int(status_line[:3]) + assert actual_status == 200 + assert status_line[4:] == 'OK' + assert actual_resp_body == pov.encode() + assert not header_exists('Connection', actual_headers) + + test_client.server_instance.protocol = original_server_protocol + + +def test_keepalive_conn_management(test_client): + """Test management of Keep-Alive connections.""" + test_client.server_instance.timeout = 2 + + def connection(): + # Initialize a persistent HTTP connection + http_connection = test_client.get_connection() + http_connection.auto_open = False + http_connection.connect() + return http_connection + + def request(conn): + status_line, actual_headers, actual_resp_body = test_client.get( + '/page3', headers=[('Connection', 'Keep-Alive')], + http_conn=conn, protocol='HTTP/1.0', + ) + actual_status = int(status_line[:3]) + assert actual_status == 200 + assert status_line[4:] == 'OK' + assert actual_resp_body == pov.encode() + assert header_has_value('Connection', 'Keep-Alive', actual_headers) + + disconnect_errors = ( + 
http_client.BadStatusLine, + http_client.CannotSendRequest, + http_client.NotConnected, + ) + + # Make a new connection. + c1 = connection() + request(c1) + + # Make a second one. + c2 = connection() + request(c2) + + # Reusing the first connection should still work. + request(c1) + + # Creating a new connection should still work. + c3 = connection() + request(c3) + + # Allow a tick. + time.sleep(0.2) + + # That's three connections, we should expect the one used less recently + # to be expired. + with pytest.raises(disconnect_errors): + request(c2) + + # But the oldest created one should still be valid. + # (As well as the newest one). + request(c1) + request(c3) + + # Wait for some of our timeout. + time.sleep(1.0) + + # Refresh the third connection. + request(c3) + + # Wait for the remainder of our timeout, plus one tick. + time.sleep(1.2) + + # First connection should now be expired. + with pytest.raises(disconnect_errors): + request(c1) + + # But the third one should still be valid. + request(c3) + + test_client.server_instance.timeout = timeout + + +@pytest.mark.parametrize( + 'timeout_before_headers', + ( + True, + False, + ), +) +def test_HTTP11_Timeout(test_client, timeout_before_headers): + """Check timeout without sending any data. + + The server will close the conn with a 408. + """ + conn = test_client.get_connection() + conn.auto_open = False + conn.connect() + + if not timeout_before_headers: + # Connect but send half the headers only. + conn.send(b'GET /hello HTTP/1.1') + conn.send(('Host: %s' % conn.host).encode('ascii')) + # else: Connect but send nothing. + + # Wait for our socket timeout + time.sleep(timeout * 2) + + # The request should have returned 408 already. + response = conn.response_class(conn.sock, method='GET') + response.begin() + assert response.status == 408 + conn.close() + + +def test_HTTP11_Timeout_after_request(test_client): + """Check timeout after at least one request has succeeded. 
+ + The server should close the connection without 408. + """ + fail_msg = "Writing to timed out socket didn't fail as it should have: %s" + + # Make an initial request + conn = test_client.get_connection() + conn.putrequest('GET', '/timeout?t=%s' % timeout, skip_host=True) + conn.putheader('Host', conn.host) + conn.endheaders() + response = conn.response_class(conn.sock, method='GET') + response.begin() + assert response.status == 200 + actual_body = response.read() + expected_body = str(timeout).encode() + assert actual_body == expected_body + + # Make a second request on the same socket + conn._output(b'GET /hello HTTP/1.1') + conn._output(('Host: %s' % conn.host).encode('ascii')) + conn._send_output() + response = conn.response_class(conn.sock, method='GET') + response.begin() + assert response.status == 200 + actual_body = response.read() + expected_body = b'Hello, world!' + assert actual_body == expected_body + + # Wait for our socket timeout + time.sleep(timeout * 2) + + # Make another request on the same socket, which should error + conn._output(b'GET /hello HTTP/1.1') + conn._output(('Host: %s' % conn.host).encode('ascii')) + conn._send_output() + response = conn.response_class(conn.sock, method='GET') + try: + response.begin() + except (socket.error, http_client.BadStatusLine): + pass + except Exception as ex: + pytest.fail(fail_msg % ex) + else: + if response.status != 408: + pytest.fail(fail_msg % response.read()) + + conn.close() + + # Make another request on a new socket, which should work + conn = test_client.get_connection() + conn.putrequest('GET', '/pov', skip_host=True) + conn.putheader('Host', conn.host) + conn.endheaders() + response = conn.response_class(conn.sock, method='GET') + response.begin() + assert response.status == 200 + actual_body = response.read() + expected_body = pov.encode() + assert actual_body == expected_body + + # Make another request on the same socket, + # but timeout on the headers + conn.send(b'GET /hello HTTP/1.1') + # 
Wait for our socket timeout + time.sleep(timeout * 2) + response = conn.response_class(conn.sock, method='GET') + try: + response.begin() + except (socket.error, http_client.BadStatusLine): + pass + except Exception as ex: + pytest.fail(fail_msg % ex) + else: + if response.status != 408: + pytest.fail(fail_msg % response.read()) + + conn.close() + + # Retry the request on a new connection, which should work + conn = test_client.get_connection() + conn.putrequest('GET', '/pov', skip_host=True) + conn.putheader('Host', conn.host) + conn.endheaders() + response = conn.response_class(conn.sock, method='GET') + response.begin() + assert response.status == 200 + actual_body = response.read() + expected_body = pov.encode() + assert actual_body == expected_body + conn.close() + + +def test_HTTP11_pipelining(test_client): + """Test HTTP/1.1 pipelining. + + httplib doesn't support this directly. + """ + conn = test_client.get_connection() + + # Put request 1 + conn.putrequest('GET', '/hello', skip_host=True) + conn.putheader('Host', conn.host) + conn.endheaders() + + for trial in range(5): + # Put next request + conn._output( + ('GET /hello?%s HTTP/1.1' % trial).encode('iso-8859-1'), + ) + conn._output(('Host: %s' % conn.host).encode('ascii')) + conn._send_output() + + # Retrieve previous response + response = conn.response_class(conn.sock, method='GET') + # there is a bug in python3 regarding the buffering of + # ``conn.sock``. Until that bug get's fixed we will + # monkey patch the ``response`` instance. + # https://bugs.python.org/issue23377 + if not six.PY2: + response.fp = conn.sock.makefile('rb', 0) + response.begin() + body = response.read(13) + assert response.status == 200 + assert body == b'Hello, world!' + + # Retrieve final response + response = conn.response_class(conn.sock, method='GET') + response.begin() + body = response.read() + assert response.status == 200 + assert body == b'Hello, world!' 
+ + conn.close() + + +def test_100_Continue(test_client): + """Test 100-continue header processing.""" + conn = test_client.get_connection() + + # Try a page without an Expect request header first. + # Note that httplib's response.begin automatically ignores + # 100 Continue responses, so we must manually check for it. + conn.putrequest('POST', '/upload', skip_host=True) + conn.putheader('Host', conn.host) + conn.putheader('Content-Type', 'text/plain') + conn.putheader('Content-Length', '4') + conn.endheaders() + conn.send(b"d'oh") + response = conn.response_class(conn.sock, method='POST') + version, status, reason = response._read_status() + assert status != 100 + conn.close() + + # Now try a page with an Expect header... + conn.connect() + conn.putrequest('POST', '/upload', skip_host=True) + conn.putheader('Host', conn.host) + conn.putheader('Content-Type', 'text/plain') + conn.putheader('Content-Length', '17') + conn.putheader('Expect', '100-continue') + conn.endheaders() + response = conn.response_class(conn.sock, method='POST') + + # ...assert and then skip the 100 response + version, status, reason = response._read_status() + assert status == 100 + while True: + line = response.fp.readline().strip() + if line: + pytest.fail( + '100 Continue should not output any headers. 
Got %r' % + line, + ) + else: + break + + # ...send the body + body = b'I am a small file' + conn.send(body) + + # ...get the final response + response.begin() + status_line, actual_headers, actual_resp_body = webtest.shb(response) + actual_status = int(status_line[:3]) + assert actual_status == 200 + expected_resp_body = ("thanks for '%s'" % body).encode() + assert actual_resp_body == expected_resp_body + conn.close() + + +@pytest.mark.parametrize( + 'max_request_body_size', + ( + 0, + 1001, + ), +) +def test_readall_or_close(test_client, max_request_body_size): + """Test a max_request_body_size of 0 (the default) and 1001.""" + old_max = test_client.server_instance.max_request_body_size + + test_client.server_instance.max_request_body_size = max_request_body_size + + conn = test_client.get_connection() + + # Get a POST page with an error + conn.putrequest('POST', '/err_before_read', skip_host=True) + conn.putheader('Host', conn.host) + conn.putheader('Content-Type', 'text/plain') + conn.putheader('Content-Length', '1000') + conn.putheader('Expect', '100-continue') + conn.endheaders() + response = conn.response_class(conn.sock, method='POST') + + # ...assert and then skip the 100 response + version, status, reason = response._read_status() + assert status == 100 + skip = True + while skip: + skip = response.fp.readline().strip() + + # ...send the body + conn.send(b'x' * 1000) + + # ...get the final response + response.begin() + status_line, actual_headers, actual_resp_body = webtest.shb(response) + actual_status = int(status_line[:3]) + assert actual_status == 500 + + # Now try a working page with an Expect header... 
+ conn._output(b'POST /upload HTTP/1.1') + conn._output(('Host: %s' % conn.host).encode('ascii')) + conn._output(b'Content-Type: text/plain') + conn._output(b'Content-Length: 17') + conn._output(b'Expect: 100-continue') + conn._send_output() + response = conn.response_class(conn.sock, method='POST') + + # ...assert and then skip the 100 response + version, status, reason = response._read_status() + assert status == 100 + skip = True + while skip: + skip = response.fp.readline().strip() + + # ...send the body + body = b'I am a small file' + conn.send(body) + + # ...get the final response + response.begin() + status_line, actual_headers, actual_resp_body = webtest.shb(response) + actual_status = int(status_line[:3]) + assert actual_status == 200 + expected_resp_body = ("thanks for '%s'" % body).encode() + assert actual_resp_body == expected_resp_body + conn.close() + + test_client.server_instance.max_request_body_size = old_max + + +def test_No_Message_Body(test_client): + """Test HTTP queries with an empty response body.""" + # Initialize a persistent HTTP connection + http_connection = test_client.get_connection() + http_connection.auto_open = False + http_connection.connect() + + # Make the first request and assert there's no "Connection: close". + status_line, actual_headers, actual_resp_body = test_client.get( + '/pov', http_conn=http_connection, + ) + actual_status = int(status_line[:3]) + assert actual_status == 200 + assert status_line[4:] == 'OK' + assert actual_resp_body == pov.encode() + assert not header_exists('Connection', actual_headers) + + # Make a 204 request on the same connection. 
+ status_line, actual_headers, actual_resp_body = test_client.get( + '/custom/204', http_conn=http_connection, + ) + actual_status = int(status_line[:3]) + assert actual_status == 204 + assert not header_exists('Content-Length', actual_headers) + assert actual_resp_body == b'' + assert not header_exists('Connection', actual_headers) + + # Make a 304 request on the same connection. + status_line, actual_headers, actual_resp_body = test_client.get( + '/custom/304', http_conn=http_connection, + ) + actual_status = int(status_line[:3]) + assert actual_status == 304 + assert not header_exists('Content-Length', actual_headers) + assert actual_resp_body == b'' + assert not header_exists('Connection', actual_headers) + + +@pytest.mark.xfail( + reason='Server does not correctly read trailers/ending of the previous ' + 'HTTP request, thus the second request fails as the server tries ' + r"to parse b'Content-Type: application/json\r\n' as a " + 'Request-Line. This results in HTTP status code 400, instead of 413' + 'Ref: https://github.com/cherrypy/cheroot/issues/69', +) +def test_Chunked_Encoding(test_client): + """Test HTTP uploads with chunked transfer-encoding.""" + # Initialize a persistent HTTP connection + conn = test_client.get_connection() + + # Try a normal chunked request (with extensions) + body = ( + b'8;key=value\r\nxx\r\nxxxx\r\n5\r\nyyyyy\r\n0\r\n' + b'Content-Type: application/json\r\n' + b'\r\n' + ) + conn.putrequest('POST', '/upload', skip_host=True) + conn.putheader('Host', conn.host) + conn.putheader('Transfer-Encoding', 'chunked') + conn.putheader('Trailer', 'Content-Type') + # Note that this is somewhat malformed: + # we shouldn't be sending Content-Length. + # RFC 2616 says the server should ignore it. 
+ conn.putheader('Content-Length', '3') + conn.endheaders() + conn.send(body) + response = conn.getresponse() + status_line, actual_headers, actual_resp_body = webtest.shb(response) + actual_status = int(status_line[:3]) + assert actual_status == 200 + assert status_line[4:] == 'OK' + expected_resp_body = ("thanks for '%s'" % b'xx\r\nxxxxyyyyy').encode() + assert actual_resp_body == expected_resp_body + + # Try a chunked request that exceeds server.max_request_body_size. + # Note that the delimiters and trailer are included. + body = b'3e3\r\n' + (b'x' * 995) + b'\r\n0\r\n\r\n' + conn.putrequest('POST', '/upload', skip_host=True) + conn.putheader('Host', conn.host) + conn.putheader('Transfer-Encoding', 'chunked') + conn.putheader('Content-Type', 'text/plain') + # Chunked requests don't need a content-length + # conn.putheader("Content-Length", len(body)) + conn.endheaders() + conn.send(body) + response = conn.getresponse() + status_line, actual_headers, actual_resp_body = webtest.shb(response) + actual_status = int(status_line[:3]) + assert actual_status == 413 + conn.close() + + +def test_Content_Length_in(test_client): + """Try a non-chunked request where Content-Length exceeds limit. + + (server.max_request_body_size). + Assert error before body send. + """ + # Initialize a persistent HTTP connection + conn = test_client.get_connection() + + conn.putrequest('POST', '/upload', skip_host=True) + conn.putheader('Host', conn.host) + conn.putheader('Content-Type', 'text/plain') + conn.putheader('Content-Length', '9999') + conn.endheaders() + response = conn.getresponse() + status_line, actual_headers, actual_resp_body = webtest.shb(response) + actual_status = int(status_line[:3]) + assert actual_status == 413 + expected_resp_body = ( + b'The entity sent with the request exceeds ' + b'the maximum allowed bytes.' 
+ ) + assert actual_resp_body == expected_resp_body + conn.close() + + +def test_Content_Length_not_int(test_client): + """Test that malicious Content-Length header returns 400.""" + status_line, actual_headers, actual_resp_body = test_client.post( + '/upload', + headers=[ + ('Content-Type', 'text/plain'), + ('Content-Length', 'not-an-integer'), + ], + ) + actual_status = int(status_line[:3]) + + assert actual_status == 400 + assert actual_resp_body == b'Malformed Content-Length Header.' + + +@pytest.mark.parametrize( + 'uri,expected_resp_status,expected_resp_body', + ( + ( + '/wrong_cl_buffered', 500, + ( + b'The requested resource returned more bytes than the ' + b'declared Content-Length.' + ), + ), + ('/wrong_cl_unbuffered', 200, b'I too'), + ), +) +def test_Content_Length_out( + test_client, + uri, expected_resp_status, expected_resp_body, +): + """Test response with Content-Length less than the response body. + + (non-chunked response) + """ + conn = test_client.get_connection() + conn.putrequest('GET', uri, skip_host=True) + conn.putheader('Host', conn.host) + conn.endheaders() + + response = conn.getresponse() + status_line, actual_headers, actual_resp_body = webtest.shb(response) + actual_status = int(status_line[:3]) + + assert actual_status == expected_resp_status + assert actual_resp_body == expected_resp_body + + conn.close() + + +@pytest.mark.xfail( + reason='Sometimes this test fails due to low timeout. 
' + 'Ref: https://github.com/cherrypy/cherrypy/issues/598', +) +def test_598(test_client): + """Test serving large file with a read timeout in place.""" + # Initialize a persistent HTTP connection + conn = test_client.get_connection() + remote_data_conn = urllib.request.urlopen( + '%s://%s:%s/one_megabyte_of_a' + % ('http', conn.host, conn.port), + ) + buf = remote_data_conn.read(512) + time.sleep(timeout * 0.6) + remaining = (1024 * 1024) - 512 + while remaining: + data = remote_data_conn.read(remaining) + if not data: + break + buf += data + remaining -= len(data) + + assert len(buf) == 1024 * 1024 + assert buf == b'a' * 1024 * 1024 + assert remaining == 0 + remote_data_conn.close() + + +@pytest.mark.parametrize( + 'invalid_terminator', + ( + b'\n\n', + b'\r\n\n', + ), +) +def test_No_CRLF(test_client, invalid_terminator): + """Test HTTP queries with no valid CRLF terminators.""" + # Initialize a persistent HTTP connection + conn = test_client.get_connection() + + # (b'%s' % b'') is not supported in Python 3.4, so just use + + conn.send(b'GET /hello HTTP/1.1' + invalid_terminator) + response = conn.response_class(conn.sock, method='GET') + response.begin() + actual_resp_body = response.read() + expected_resp_body = b'HTTP requires CRLF terminators' + assert actual_resp_body == expected_resp_body + conn.close() diff --git a/lib/cheroot/test/test_core.py b/lib/cheroot/test/test_core.py new file mode 100644 index 00000000..aad2bb7f --- /dev/null +++ b/lib/cheroot/test/test_core.py @@ -0,0 +1,415 @@ +"""Tests for managing HTTP issues (malformed requests, etc).""" +# -*- coding: utf-8 -*- +# vim: set fileencoding=utf-8 : + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +import errno +import socket + +import pytest +import six +from six.moves import urllib + +from cheroot.test import helper + + +HTTP_BAD_REQUEST = 400 +HTTP_LENGTH_REQUIRED = 411 +HTTP_NOT_FOUND = 404 +HTTP_OK = 200 +HTTP_VERSION_NOT_SUPPORTED = 505 + + +class 
HelloController(helper.Controller): + """Controller for serving WSGI apps.""" + + def hello(req, resp): + """Render Hello world.""" + return 'Hello world!' + + def body_required(req, resp): + """Render Hello world or set 411.""" + if req.environ.get('Content-Length', None) is None: + resp.status = '411 Length Required' + return + return 'Hello world!' + + def query_string(req, resp): + """Render QUERY_STRING value.""" + return req.environ.get('QUERY_STRING', '') + + def asterisk(req, resp): + """Render request method value.""" + method = req.environ.get('REQUEST_METHOD', 'NO METHOD FOUND') + tmpl = 'Got asterisk URI path with {method} method' + return tmpl.format(**locals()) + + def _munge(string): + """Encode PATH_INFO correctly depending on Python version. + + WSGI 1.0 is a mess around unicode. Create endpoints + that match the PATH_INFO that it produces. + """ + if six.PY2: + return string + return string.encode('utf-8').decode('latin-1') + + handlers = { + '/hello': hello, + '/no_body': hello, + '/body_required': body_required, + '/query_string': query_string, + _munge('/привіт'): hello, + _munge('/Юххууу'): hello, + '/\xa0Ðblah key 0 900 4 data': hello, + '/*': asterisk, + } + + +def _get_http_response(connection, method='GET'): + c = connection + kwargs = {'strict': c.strict} if hasattr(c, 'strict') else {} + # Python 3.2 removed the 'strict' feature, saying: + # "http.client now always assumes HTTP/1.x compliant servers." 
+ return c.response_class(c.sock, method=method, **kwargs) + + +@pytest.fixture +def testing_server(wsgi_server_client): + """Attach a WSGI app to the given server and pre-configure it.""" + wsgi_server = wsgi_server_client.server_instance + wsgi_server.wsgi_app = HelloController() + wsgi_server.max_request_body_size = 30000000 + wsgi_server.server_client = wsgi_server_client + return wsgi_server + + +@pytest.fixture +def test_client(testing_server): + """Get and return a test client out of the given server.""" + return testing_server.server_client + + +def test_http_connect_request(test_client): + """Check that CONNECT query results in Method Not Allowed status.""" + status_line = test_client.connect('/anything')[0] + actual_status = int(status_line[:3]) + assert actual_status == 405 + + +def test_normal_request(test_client): + """Check that normal GET query succeeds.""" + status_line, _, actual_resp_body = test_client.get('/hello') + actual_status = int(status_line[:3]) + assert actual_status == HTTP_OK + assert actual_resp_body == b'Hello world!' + + +def test_query_string_request(test_client): + """Check that GET param is parsed well.""" + status_line, _, actual_resp_body = test_client.get( + '/query_string?test=True', + ) + actual_status = int(status_line[:3]) + assert actual_status == HTTP_OK + assert actual_resp_body == b'test=True' + + +@pytest.mark.parametrize( + 'uri', + ( + '/hello', # plain + '/query_string?test=True', # query + '/{0}?{1}={2}'.format( # quoted unicode + *map(urllib.parse.quote, ('Юххууу', 'ї', 'йо')) + ), + ), +) +def test_parse_acceptable_uri(test_client, uri): + """Check that server responds with OK to valid GET queries.""" + status_line = test_client.get(uri)[0] + actual_status = int(status_line[:3]) + assert actual_status == HTTP_OK + + +@pytest.mark.xfail(six.PY2, reason='Fails on Python 2') +def test_parse_uri_unsafe_uri(test_client): + """Test that malicious URI does not allow HTTP injection. 
+ + This effectively checks that sending GET request with URL + + /%A0%D0blah%20key%200%20900%204%20data + + is not converted into + + GET / + blah key 0 900 4 data + HTTP/1.1 + + which would be a security issue otherwise. + """ + c = test_client.get_connection() + resource = '/\xa0Ðblah key 0 900 4 data'.encode('latin-1') + quoted = urllib.parse.quote(resource) + assert quoted == '/%A0%D0blah%20key%200%20900%204%20data' + request = 'GET {quoted} HTTP/1.1'.format(**locals()) + c._output(request.encode('utf-8')) + c._send_output() + response = _get_http_response(c, method='GET') + response.begin() + assert response.status == HTTP_OK + assert response.read(12) == b'Hello world!' + c.close() + + +def test_parse_uri_invalid_uri(test_client): + """Check that server responds with Bad Request to invalid GET queries. + + Invalid request line test case: it should only contain US-ASCII. + """ + c = test_client.get_connection() + c._output(u'GET /йопта! HTTP/1.1'.encode('utf-8')) + c._send_output() + response = _get_http_response(c, method='GET') + response.begin() + assert response.status == HTTP_BAD_REQUEST + assert response.read(21) == b'Malformed Request-URI' + c.close() + + +@pytest.mark.parametrize( + 'uri', + ( + 'hello', # ascii + 'привіт', # non-ascii + ), +) +def test_parse_no_leading_slash_invalid(test_client, uri): + """Check that server responds with Bad Request to invalid GET queries. + + Invalid request line test case: it should have leading slash (be absolute). + """ + status_line, _, actual_resp_body = test_client.get( + urllib.parse.quote(uri), + ) + actual_status = int(status_line[:3]) + assert actual_status == HTTP_BAD_REQUEST + assert b'starting with a slash' in actual_resp_body + + +def test_parse_uri_absolute_uri(test_client): + """Check that server responds with Bad Request to Absolute URI. + + Only proxy servers should allow this. 
+ """ + status_line, _, actual_resp_body = test_client.get('http://google.com/') + actual_status = int(status_line[:3]) + assert actual_status == HTTP_BAD_REQUEST + expected_body = b'Absolute URI not allowed if server is not a proxy.' + assert actual_resp_body == expected_body + + +def test_parse_uri_asterisk_uri(test_client): + """Check that server responds with OK to OPTIONS with "*" Absolute URI.""" + status_line, _, actual_resp_body = test_client.options('*') + actual_status = int(status_line[:3]) + assert actual_status == HTTP_OK + expected_body = b'Got asterisk URI path with OPTIONS method' + assert actual_resp_body == expected_body + + +def test_parse_uri_fragment_uri(test_client): + """Check that server responds with Bad Request to URI with fragment.""" + status_line, _, actual_resp_body = test_client.get( + '/hello?test=something#fake', + ) + actual_status = int(status_line[:3]) + assert actual_status == HTTP_BAD_REQUEST + expected_body = b'Illegal #fragment in Request-URI.' + assert actual_resp_body == expected_body + + +def test_no_content_length(test_client): + """Test POST query with an empty body being successful.""" + # "The presence of a message-body in a request is signaled by the + # inclusion of a Content-Length or Transfer-Encoding header field in + # the request's message-headers." + # + # Send a message with neither header and no body. + c = test_client.get_connection() + c.request('POST', '/no_body') + response = c.getresponse() + actual_resp_body = response.read() + actual_status = response.status + assert actual_status == HTTP_OK + assert actual_resp_body == b'Hello world!' + + +def test_content_length_required(test_client): + """Test POST query with body failing because of missing Content-Length.""" + # Now send a message that has no Content-Length, but does send a body. + # Verify that CP times out the socket and responds + # with 411 Length Required. 
+ + c = test_client.get_connection() + c.request('POST', '/body_required') + response = c.getresponse() + response.read() + + actual_status = response.status + assert actual_status == HTTP_LENGTH_REQUIRED + + +@pytest.mark.parametrize( + 'request_line,status_code,expected_body', + ( + ( + b'GET /', # missing proto + HTTP_BAD_REQUEST, b'Malformed Request-Line', + ), + ( + b'GET / HTTPS/1.1', # invalid proto + HTTP_BAD_REQUEST, b'Malformed Request-Line: bad protocol', + ), + ( + b'GET / HTTP/1', # invalid version + HTTP_BAD_REQUEST, b'Malformed Request-Line: bad version', + ), + ( + b'GET / HTTP/2.15', # invalid ver + HTTP_VERSION_NOT_SUPPORTED, b'Cannot fulfill request', + ), + ), +) +def test_malformed_request_line( + test_client, request_line, + status_code, expected_body, +): + """Test missing or invalid HTTP version in Request-Line.""" + c = test_client.get_connection() + c._output(request_line) + c._send_output() + response = _get_http_response(c, method='GET') + response.begin() + assert response.status == status_code + assert response.read(len(expected_body)) == expected_body + c.close() + + +def test_malformed_http_method(test_client): + """Test non-uppercase HTTP method.""" + c = test_client.get_connection() + c.putrequest('GeT', '/malformed_method_case') + c.putheader('Content-Type', 'text/plain') + c.endheaders() + + response = c.getresponse() + actual_status = response.status + assert actual_status == HTTP_BAD_REQUEST + actual_resp_body = response.read(21) + assert actual_resp_body == b'Malformed method name' + + +def test_malformed_header(test_client): + """Check that broken HTTP header results in Bad Request.""" + c = test_client.get_connection() + c.putrequest('GET', '/') + c.putheader('Content-Type', 'text/plain') + # See https://www.bitbucket.org/cherrypy/cherrypy/issue/941 + c._output(b'Re, 1.2.3.4#015#012') + c.endheaders() + + response = c.getresponse() + actual_status = response.status + assert actual_status == HTTP_BAD_REQUEST + 
actual_resp_body = response.read(20) + assert actual_resp_body == b'Illegal header line.' + + +def test_request_line_split_issue_1220(test_client): + """Check that HTTP request line of exactly 256 chars length is OK.""" + Request_URI = ( + '/hello?' + 'intervenant-entreprise-evenement_classaction=' + 'evenement-mailremerciements' + '&_path=intervenant-entreprise-evenement' + '&intervenant-entreprise-evenement_action-id=19404' + '&intervenant-entreprise-evenement_id=19404' + '&intervenant-entreprise_id=28092' + ) + assert len('GET %s HTTP/1.1\r\n' % Request_URI) == 256 + + actual_resp_body = test_client.get(Request_URI)[2] + assert actual_resp_body == b'Hello world!' + + +def test_garbage_in(test_client): + """Test that server sends an error for garbage received over TCP.""" + # Connect without SSL regardless of server.scheme + + c = test_client.get_connection() + c._output(b'gjkgjklsgjklsgjkljklsg') + c._send_output() + response = c.response_class(c.sock, method='GET') + try: + response.begin() + actual_status = response.status + assert actual_status == HTTP_BAD_REQUEST + actual_resp_body = response.read(22) + assert actual_resp_body == b'Malformed Request-Line' + c.close() + except socket.error as ex: + # "Connection reset by peer" is also acceptable. 
+        if ex.errno != errno.ECONNRESET:
+            raise
+
+
+class CloseController:
+    """Controller for testing the close callback."""
+
+    def __call__(self, environ, start_response):
+        """Get the req to know header sent status."""
+        self.req = start_response.__self__.req
+        resp = CloseResponse(self.close)
+        start_response(resp.status, resp.headers.items())
+        return resp
+
+    def close(self):
+        """Close, writing hello."""
+        self.req.write(b'hello')
+
+
+class CloseResponse:
+    """Dummy empty response to trigger the no body status."""
+
+    def __init__(self, close):
+        """Use some defaults to ensure we have a header."""
+        self.status = '200 OK'
+        self.headers = {'Content-Type': 'text/html'}
+        self.close = close
+
+    def __getitem__(self, index):
+        """Ensure we don't have a body."""
+        raise IndexError()
+
+    def output(self):
+        """Return self to hook the close method."""
+        return self
+
+
+@pytest.fixture
+def testing_server_close(wsgi_server_client):
+    """Attach a WSGI app to the given server and pre-configure it."""
+    wsgi_server = wsgi_server_client.server_instance
+    wsgi_server.wsgi_app = CloseController()
+    wsgi_server.max_request_body_size = 30000000
+    wsgi_server.server_client = wsgi_server_client
+    return wsgi_server
+
+
+def test_send_header_before_closing(testing_server_close):
+    """Test we are actually sending the headers before calling 'close'."""
+    _, _, resp_body = testing_server_close.server_client.get('/')
+    assert resp_body == b'hello'
diff --git a/lib/cheroot/test/test_dispatch.py b/lib/cheroot/test/test_dispatch.py
new file mode 100644
index 00000000..bc588749
--- /dev/null
+++ b/lib/cheroot/test/test_dispatch.py
@@ -0,0 +1,55 @@
+"""Tests for the HTTP server."""
+# -*- coding: utf-8 -*-
+# vim: set fileencoding=utf-8 :
+
+from __future__ import absolute_import, division, print_function
+
+from cheroot.wsgi import PathInfoDispatcher
+
+
+def wsgi_invoke(app, environ):
+    """Serve 1 request from a WSGI application."""
+    response = {}
+
+    def start_response(status, 
headers): + response.update({ + 'status': status, + 'headers': headers, + }) + + response['body'] = b''.join( + app(environ, start_response), + ) + + return response + + +def test_dispatch_no_script_name(): + """Despatch despite lack of SCRIPT_NAME in environ.""" + # Bare bones WSGI hello world app (from PEP 333). + def app(environ, start_response): + start_response( + '200 OK', [ + ('Content-Type', 'text/plain; charset=utf-8'), + ], + ) + return [u'Hello, world!'.encode('utf-8')] + + # Build a dispatch table. + d = PathInfoDispatcher([ + ('/', app), + ]) + + # Dispatch a request without `SCRIPT_NAME`. + response = wsgi_invoke( + d, { + 'PATH_INFO': '/foo', + }, + ) + assert response == { + 'status': '200 OK', + 'headers': [ + ('Content-Type', 'text/plain; charset=utf-8'), + ], + 'body': b'Hello, world!', + } diff --git a/lib/cheroot/test/test_errors.py b/lib/cheroot/test/test_errors.py new file mode 100644 index 00000000..34b42d90 --- /dev/null +++ b/lib/cheroot/test/test_errors.py @@ -0,0 +1,30 @@ +"""Test suite for ``cheroot.errors``.""" + +import pytest + +from cheroot import errors + +from .._compat import IS_LINUX, IS_MACOS, IS_WINDOWS + + +@pytest.mark.parametrize( + 'err_names,err_nums', + ( + (('', 'some-nonsense-name'), []), + ( + ( + 'EPROTOTYPE', 'EAGAIN', 'EWOULDBLOCK', + 'WSAEWOULDBLOCK', 'EPIPE', + ), + (91, 11, 32) if IS_LINUX else + (32, 35, 41) if IS_MACOS else + (32, 10041, 11, 10035) if IS_WINDOWS else + (), + ), + ), +) +def test_plat_specific_errors(err_names, err_nums): + """Test that plat_specific_errors retrieves correct err num list.""" + actual_err_nums = errors.plat_specific_errors(*err_names) + assert len(actual_err_nums) == len(err_nums) + assert sorted(actual_err_nums) == sorted(err_nums) diff --git a/lib/cheroot/test/test_makefile.py b/lib/cheroot/test/test_makefile.py new file mode 100644 index 00000000..55db5038 --- /dev/null +++ b/lib/cheroot/test/test_makefile.py @@ -0,0 +1,52 @@ +"""self-explanatory.""" + +from cheroot import 
makefile
+
+
+__metaclass__ = type
+
+
+class MockSocket:
+    """Mocks a socket."""
+
+    def __init__(self):
+        """Initialize."""
+        self.messages = []
+
+    def recv_into(self, buf):
+        """Simulate recv_into for Python 3."""
+        if not self.messages:
+            return 0
+        msg = self.messages.pop(0)
+        for index, byte in enumerate(msg):
+            buf[index] = byte
+        return len(msg)
+
+    def recv(self, size):
+        """Simulate recv for Python 2."""
+        try:
+            return self.messages.pop(0)
+        except IndexError:
+            return ''
+
+    def send(self, val):
+        """Simulate a send."""
+        return len(val)
+
+
+def test_bytes_read():
+    """Reader should capture bytes read."""
+    sock = MockSocket()
+    sock.messages.append(b'foo')
+    rfile = makefile.MakeFile(sock, 'r')
+    rfile.read()
+    assert rfile.bytes_read == 3
+
+
+def test_bytes_written():
+    """Writer should capture bytes written."""
+    sock = MockSocket()
+    sock.messages.append(b'foo')
+    wfile = makefile.MakeFile(sock, 'w')
+    wfile.write(b'bar')
+    assert wfile.bytes_written == 3
diff --git a/lib/cheroot/test/test_server.py b/lib/cheroot/test/test_server.py
new file mode 100644
index 00000000..30112354
--- /dev/null
+++ b/lib/cheroot/test/test_server.py
@@ -0,0 +1,235 @@
+"""Tests for the HTTP server."""
+# -*- coding: utf-8 -*-
+# vim: set fileencoding=utf-8 :
+
+from __future__ import absolute_import, division, print_function
+__metaclass__ = type
+
+import os
+import socket
+import tempfile
+import threading
+import uuid
+
+import pytest
+import requests
+import requests_unixsocket
+import six
+
+from .._compat import bton, ntob
+from .._compat import IS_LINUX, IS_MACOS, SYS_PLATFORM
+from ..server import IS_UID_GID_RESOLVABLE, Gateway, HTTPServer
+from ..testing import (
+    ANY_INTERFACE_IPV4,
+    ANY_INTERFACE_IPV6,
+    EPHEMERAL_PORT,
+    get_server_client,
+)
+
+
+unix_only_sock_test = pytest.mark.skipif(
+    not hasattr(socket, 'AF_UNIX'),
+    reason='UNIX domain sockets are only available under UNIX-based OS',
+)
+
+
+non_macos_sock_test = pytest.mark.skipif(
+    IS_MACOS,
+ reason='Peercreds lookup does not work under macOS/BSD currently.', +) + + +@pytest.fixture(params=('abstract', 'file')) +def unix_sock_file(request): + """Check that bound UNIX socket address is stored in server.""" + if request.param == 'abstract': + yield request.getfixturevalue('unix_abstract_sock') + return + tmp_sock_fh, tmp_sock_fname = tempfile.mkstemp() + + yield tmp_sock_fname + + os.close(tmp_sock_fh) + os.unlink(tmp_sock_fname) + + +@pytest.fixture +def unix_abstract_sock(): + """Return an abstract UNIX socket address.""" + if not IS_LINUX: + pytest.skip( + '{os} does not support an abstract ' + 'socket namespace'.format(os=SYS_PLATFORM), + ) + return b''.join(( + b'\x00cheroot-test-socket', + ntob(str(uuid.uuid4())), + )).decode() + + +def test_prepare_makes_server_ready(): + """Check that prepare() makes the server ready, and stop() clears it.""" + httpserver = HTTPServer( + bind_addr=(ANY_INTERFACE_IPV4, EPHEMERAL_PORT), + gateway=Gateway, + ) + + assert not httpserver.ready + assert not httpserver.requests._threads + + httpserver.prepare() + + assert httpserver.ready + assert httpserver.requests._threads + for thr in httpserver.requests._threads: + assert thr.ready + + httpserver.stop() + + assert not httpserver.requests._threads + assert not httpserver.ready + + +def test_stop_interrupts_serve(): + """Check that stop() interrupts running of serve().""" + httpserver = HTTPServer( + bind_addr=(ANY_INTERFACE_IPV4, EPHEMERAL_PORT), + gateway=Gateway, + ) + + httpserver.prepare() + serve_thread = threading.Thread(target=httpserver.serve) + serve_thread.start() + + serve_thread.join(0.5) + assert serve_thread.is_alive() + + httpserver.stop() + + serve_thread.join(0.5) + assert not serve_thread.is_alive() + + +@pytest.mark.parametrize( + 'ip_addr', + ( + ANY_INTERFACE_IPV4, + ANY_INTERFACE_IPV6, + ), +) +def test_bind_addr_inet(http_server, ip_addr): + """Check that bound IP address is stored in server.""" + httpserver = http_server.send((ip_addr, 
EPHEMERAL_PORT)) + + assert httpserver.bind_addr[0] == ip_addr + assert httpserver.bind_addr[1] != EPHEMERAL_PORT + + +@unix_only_sock_test +def test_bind_addr_unix(http_server, unix_sock_file): + """Check that bound UNIX socket address is stored in server.""" + httpserver = http_server.send(unix_sock_file) + + assert httpserver.bind_addr == unix_sock_file + + +@unix_only_sock_test +def test_bind_addr_unix_abstract(http_server, unix_abstract_sock): + """Check that bound UNIX abstract sockaddr is stored in server.""" + httpserver = http_server.send(unix_abstract_sock) + + assert httpserver.bind_addr == unix_abstract_sock + + +PEERCRED_IDS_URI = '/peer_creds/ids' +PEERCRED_TEXTS_URI = '/peer_creds/texts' + + +class _TestGateway(Gateway): + def respond(self): + req = self.req + conn = req.conn + req_uri = bton(req.uri) + if req_uri == PEERCRED_IDS_URI: + peer_creds = conn.peer_pid, conn.peer_uid, conn.peer_gid + self.send_payload('|'.join(map(str, peer_creds))) + return + elif req_uri == PEERCRED_TEXTS_URI: + self.send_payload('!'.join((conn.peer_user, conn.peer_group))) + return + return super(_TestGateway, self).respond() + + def send_payload(self, payload): + req = self.req + req.status = b'200 OK' + req.ensure_headers_sent() + req.write(ntob(payload)) + + +@pytest.fixture +def peercreds_enabled_server_and_client(http_server, unix_sock_file): + """Construct a test server with `peercreds_enabled`.""" + httpserver = http_server.send(unix_sock_file) + httpserver.gateway = _TestGateway + httpserver.peercreds_enabled = True + return httpserver, get_server_client(httpserver) + + +@unix_only_sock_test +@non_macos_sock_test +def test_peercreds_unix_sock(peercreds_enabled_server_and_client): + """Check that peercred lookup works when enabled.""" + httpserver, testclient = peercreds_enabled_server_and_client + bind_addr = httpserver.bind_addr + + if isinstance(bind_addr, six.binary_type): + bind_addr = bind_addr.decode() + + unix_base_uri = 'http+unix://{}'.format( + 
bind_addr.replace('\0', '%00').replace('/', '%2F'), + ) + + expected_peercreds = os.getpid(), os.getuid(), os.getgid() + expected_peercreds = '|'.join(map(str, expected_peercreds)) + + with requests_unixsocket.monkeypatch(): + peercreds_resp = requests.get(unix_base_uri + PEERCRED_IDS_URI) + peercreds_resp.raise_for_status() + assert peercreds_resp.text == expected_peercreds + + peercreds_text_resp = requests.get(unix_base_uri + PEERCRED_TEXTS_URI) + assert peercreds_text_resp.status_code == 500 + + +@pytest.mark.skipif( + not IS_UID_GID_RESOLVABLE, + reason='Modules `grp` and `pwd` are not available ' + 'under the current platform', +) +@unix_only_sock_test +@non_macos_sock_test +def test_peercreds_unix_sock_with_lookup(peercreds_enabled_server_and_client): + """Check that peercred resolution works when enabled.""" + httpserver, testclient = peercreds_enabled_server_and_client + httpserver.peercreds_resolve_enabled = True + + bind_addr = httpserver.bind_addr + + if isinstance(bind_addr, six.binary_type): + bind_addr = bind_addr.decode() + + unix_base_uri = 'http+unix://{}'.format( + bind_addr.replace('\0', '%00').replace('/', '%2F'), + ) + + import grp + import pwd + expected_textcreds = ( + pwd.getpwuid(os.getuid()).pw_name, + grp.getgrgid(os.getgid()).gr_name, + ) + expected_textcreds = '!'.join(map(str, expected_textcreds)) + with requests_unixsocket.monkeypatch(): + peercreds_text_resp = requests.get(unix_base_uri + PEERCRED_TEXTS_URI) + peercreds_text_resp.raise_for_status() + assert peercreds_text_resp.text == expected_textcreds diff --git a/lib/cheroot/test/test_ssl.py b/lib/cheroot/test/test_ssl.py new file mode 100644 index 00000000..caa1ae0a --- /dev/null +++ b/lib/cheroot/test/test_ssl.py @@ -0,0 +1,474 @@ +"""Tests for TLS/SSL support.""" +# -*- coding: utf-8 -*- +# vim: set fileencoding=utf-8 : + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +import functools +import os +import ssl +import sys +import 
threading +import time + +import OpenSSL.SSL +import pytest +import requests +import six +import trustme + +from .._compat import bton, ntob, ntou +from .._compat import IS_ABOVE_OPENSSL10, IS_PYPY +from .._compat import IS_LINUX, IS_MACOS, IS_WINDOWS +from ..server import Gateway, HTTPServer, get_ssl_adapter_class +from ..testing import ( + ANY_INTERFACE_IPV4, + ANY_INTERFACE_IPV6, + EPHEMERAL_PORT, + # get_server_client, + _get_conn_data, + _probe_ipv6_sock, +) + + +IS_GITHUB_ACTIONS_WORKFLOW = bool(os.getenv('GITHUB_WORKFLOW')) +IS_LIBRESSL_BACKEND = ssl.OPENSSL_VERSION.startswith('LibreSSL') +IS_PYOPENSSL_SSL_VERSION_1_0 = ( + OpenSSL.SSL.SSLeay_version(OpenSSL.SSL.SSLEAY_VERSION). + startswith(b'OpenSSL 1.0.') +) +PY27 = sys.version_info[:2] == (2, 7) +PY34 = sys.version_info[:2] == (3, 4) + + +_stdlib_to_openssl_verify = { + ssl.CERT_NONE: OpenSSL.SSL.VERIFY_NONE, + ssl.CERT_OPTIONAL: OpenSSL.SSL.VERIFY_PEER, + ssl.CERT_REQUIRED: + OpenSSL.SSL.VERIFY_PEER + OpenSSL.SSL.VERIFY_FAIL_IF_NO_PEER_CERT, +} + + +fails_under_py3 = pytest.mark.xfail( + not six.PY2, + reason='Fails under Python 3+', +) + + +fails_under_py3_in_pypy = pytest.mark.xfail( + not six.PY2 and IS_PYPY, + reason='Fails under PyPy3', +) + + +missing_ipv6 = pytest.mark.skipif( + not _probe_ipv6_sock('::1'), + reason='' + 'IPv6 is disabled ' + '(for example, under Travis CI ' + 'which runs under GCE supporting only IPv4)', +) + + +class HelloWorldGateway(Gateway): + """Gateway responding with Hello World to root URI.""" + + def respond(self): + """Respond with dummy content via HTTP.""" + req = self.req + req_uri = bton(req.uri) + if req_uri == '/': + req.status = b'200 OK' + req.ensure_headers_sent() + req.write(b'Hello world!') + return + return super(HelloWorldGateway, self).respond() + + +def make_tls_http_server(bind_addr, ssl_adapter, request): + """Create and start an HTTP server bound to bind_addr.""" + httpserver = HTTPServer( + bind_addr=bind_addr, + gateway=HelloWorldGateway, + ) + # 
httpserver.gateway = HelloWorldGateway + httpserver.ssl_adapter = ssl_adapter + + threading.Thread(target=httpserver.safe_start).start() + + while not httpserver.ready: + time.sleep(0.1) + + request.addfinalizer(httpserver.stop) + + return httpserver + + +@pytest.fixture +def tls_http_server(request): + """Provision a server creator as a fixture.""" + return functools.partial(make_tls_http_server, request=request) + + +@pytest.fixture +def ca(): + """Provide a certificate authority via fixture.""" + return trustme.CA() + + +@pytest.fixture +def tls_ca_certificate_pem_path(ca): + """Provide a certificate authority certificate file via fixture.""" + with ca.cert_pem.tempfile() as ca_cert_pem: + yield ca_cert_pem + + +@pytest.fixture +def tls_certificate(ca): + """Provide a leaf certificate via fixture.""" + interface, host, port = _get_conn_data(ANY_INTERFACE_IPV4) + return ca.issue_server_cert(ntou(interface), ) + + +@pytest.fixture +def tls_certificate_chain_pem_path(tls_certificate): + """Provide a certificate chain PEM file path via fixture.""" + with tls_certificate.private_key_and_cert_chain_pem.tempfile() as cert_pem: + yield cert_pem + + +@pytest.fixture +def tls_certificate_private_key_pem_path(tls_certificate): + """Provide a certificate private key PEM file path via fixture.""" + with tls_certificate.private_key_pem.tempfile() as cert_key_pem: + yield cert_key_pem + + +@pytest.mark.parametrize( + 'adapter_type', + ( + 'builtin', + 'pyopenssl', + ), +) +def test_ssl_adapters( + tls_http_server, adapter_type, + tls_certificate, + tls_certificate_chain_pem_path, + tls_certificate_private_key_pem_path, + tls_ca_certificate_pem_path, +): + """Test ability to connect to server via HTTPS using adapters.""" + interface, _host, port = _get_conn_data(ANY_INTERFACE_IPV4) + tls_adapter_cls = get_ssl_adapter_class(name=adapter_type) + tls_adapter = tls_adapter_cls( + tls_certificate_chain_pem_path, tls_certificate_private_key_pem_path, + ) + if adapter_type == 
'pyopenssl': + tls_adapter.context = tls_adapter.get_context() + + tls_certificate.configure_cert(tls_adapter.context) + + tlshttpserver = tls_http_server((interface, port), tls_adapter) + + # testclient = get_server_client(tlshttpserver) + # testclient.get('/') + + interface, _host, port = _get_conn_data( + tlshttpserver.bind_addr, + ) + + resp = requests.get( + 'https://' + interface + ':' + str(port) + '/', + verify=tls_ca_certificate_pem_path, + ) + + assert resp.status_code == 200 + assert resp.text == 'Hello world!' + + +@pytest.mark.parametrize( + 'adapter_type', + ( + 'builtin', + 'pyopenssl', + ), +) +@pytest.mark.parametrize( + 'is_trusted_cert,tls_client_identity', + ( + (True, 'localhost'), (True, '127.0.0.1'), + (True, '*.localhost'), (True, 'not_localhost'), + (False, 'localhost'), + ), +) +@pytest.mark.parametrize( + 'tls_verify_mode', + ( + ssl.CERT_NONE, # server shouldn't validate client cert + ssl.CERT_OPTIONAL, # same as CERT_REQUIRED in client mode, don't use + ssl.CERT_REQUIRED, # server should validate if client cert CA is OK + ), +) +def test_tls_client_auth( + # FIXME: remove twisted logic, separate tests + mocker, + tls_http_server, adapter_type, + ca, + tls_certificate, + tls_certificate_chain_pem_path, + tls_certificate_private_key_pem_path, + tls_ca_certificate_pem_path, + is_trusted_cert, tls_client_identity, + tls_verify_mode, +): + """Verify that client TLS certificate auth works correctly.""" + test_cert_rejection = ( + tls_verify_mode != ssl.CERT_NONE + and not is_trusted_cert + ) + interface, _host, port = _get_conn_data(ANY_INTERFACE_IPV4) + + client_cert_root_ca = ca if is_trusted_cert else trustme.CA() + with mocker.mock_module.patch( + 'idna.core.ulabel', + return_value=ntob(tls_client_identity), + ): + client_cert = client_cert_root_ca.issue_server_cert( + # FIXME: change to issue_cert once new trustme is out + ntou(tls_client_identity), + ) + del client_cert_root_ca + + with 
client_cert.private_key_and_cert_chain_pem.tempfile() as cl_pem: + tls_adapter_cls = get_ssl_adapter_class(name=adapter_type) + tls_adapter = tls_adapter_cls( + tls_certificate_chain_pem_path, + tls_certificate_private_key_pem_path, + ) + if adapter_type == 'pyopenssl': + tls_adapter.context = tls_adapter.get_context() + tls_adapter.context.set_verify( + _stdlib_to_openssl_verify[tls_verify_mode], + lambda conn, cert, errno, depth, preverify_ok: preverify_ok, + ) + else: + tls_adapter.context.verify_mode = tls_verify_mode + + ca.configure_trust(tls_adapter.context) + tls_certificate.configure_cert(tls_adapter.context) + + tlshttpserver = tls_http_server((interface, port), tls_adapter) + + interface, _host, port = _get_conn_data(tlshttpserver.bind_addr) + + make_https_request = functools.partial( + requests.get, + 'https://' + interface + ':' + str(port) + '/', + + # Server TLS certificate verification: + verify=tls_ca_certificate_pem_path, + + # Client TLS certificate verification: + cert=cl_pem, + ) + + if not test_cert_rejection: + resp = make_https_request() + is_req_successful = resp.status_code == 200 + if ( + not is_req_successful + and IS_PYOPENSSL_SSL_VERSION_1_0 + and adapter_type == 'builtin' + and tls_verify_mode == ssl.CERT_REQUIRED + and tls_client_identity == 'localhost' + and is_trusted_cert + ) or PY34: + pytest.xfail( + 'OpenSSL 1.0 has problems with verifying client certs', + ) + assert is_req_successful + assert resp.text == 'Hello world!' 
+ return + + # xfail some flaky tests + # https://github.com/cherrypy/cheroot/issues/237 + issue_237 = ( + IS_MACOS + and adapter_type == 'builtin' + and tls_verify_mode != ssl.CERT_NONE + ) + if issue_237: + pytest.xfail('Test sometimes fails') + + expected_ssl_errors = ( + requests.exceptions.SSLError, + OpenSSL.SSL.Error, + ) if PY34 else ( + requests.exceptions.SSLError, + ) + if IS_WINDOWS or IS_GITHUB_ACTIONS_WORKFLOW: + expected_ssl_errors += requests.exceptions.ConnectionError, + with pytest.raises(expected_ssl_errors) as ssl_err: + make_https_request() + + if PY34 and isinstance(ssl_err, OpenSSL.SSL.Error): + pytest.xfail( + 'OpenSSL behaves wierdly under Python 3.4 ' + 'because of an outdated urllib3', + ) + + try: + err_text = ssl_err.value.args[0].reason.args[0].args[0] + except AttributeError: + if PY34: + pytest.xfail('OpenSSL behaves wierdly under Python 3.4') + elif not six.PY2 and IS_WINDOWS: + err_text = str(ssl_err.value) + else: + raise + + expected_substrings = ( + 'sslv3 alert bad certificate' if IS_LIBRESSL_BACKEND + else 'tlsv1 alert unknown ca', + ) + if not six.PY2: + if IS_MACOS and IS_PYPY and adapter_type == 'pyopenssl': + expected_substrings = ('tlsv1 alert unknown ca', ) + if ( + IS_WINDOWS + and tls_verify_mode in ( + ssl.CERT_REQUIRED, + ssl.CERT_OPTIONAL, + ) + and not is_trusted_cert + and tls_client_identity == 'localhost' + ): + expected_substrings += ( + 'bad handshake: ' + "SysCallError(10054, 'WSAECONNRESET')", + "('Connection aborted.', " + 'OSError("(10054, \'WSAECONNRESET\')"))', + ) + assert any(e in err_text for e in expected_substrings) + + +@pytest.mark.parametrize( + 'ip_addr', + ( + ANY_INTERFACE_IPV4, + ANY_INTERFACE_IPV6, + ), +) +def test_https_over_http_error(http_server, ip_addr): + """Ensure that connecting over HTTPS to HTTP port is handled.""" + httpserver = http_server.send((ip_addr, EPHEMERAL_PORT)) + interface, _host, port = _get_conn_data(httpserver.bind_addr) + with pytest.raises(ssl.SSLError) as 
ssl_err: + six.moves.http_client.HTTPSConnection( + '{interface}:{port}'.format( + interface=interface, + port=port, + ), + ).request('GET', '/') + expected_substring = ( + 'wrong version number' if IS_ABOVE_OPENSSL10 + else 'unknown protocol' + ) + assert expected_substring in ssl_err.value.args[-1] + + +@pytest.mark.parametrize( + 'adapter_type', + ( + 'builtin', + 'pyopenssl', + ), +) +@pytest.mark.parametrize( + 'ip_addr', + ( + ANY_INTERFACE_IPV4, + pytest.param(ANY_INTERFACE_IPV6, marks=missing_ipv6), + ), +) +def test_http_over_https_error( + tls_http_server, adapter_type, + ca, ip_addr, + tls_certificate, + tls_certificate_chain_pem_path, + tls_certificate_private_key_pem_path, +): + """Ensure that connecting over HTTP to HTTPS port is handled.""" + # disable some flaky tests + # https://github.com/cherrypy/cheroot/issues/225 + issue_225 = ( + IS_MACOS + and adapter_type == 'builtin' + ) + if issue_225: + pytest.xfail('Test fails in Travis-CI') + + tls_adapter_cls = get_ssl_adapter_class(name=adapter_type) + tls_adapter = tls_adapter_cls( + tls_certificate_chain_pem_path, tls_certificate_private_key_pem_path, + ) + if adapter_type == 'pyopenssl': + tls_adapter.context = tls_adapter.get_context() + + tls_certificate.configure_cert(tls_adapter.context) + + interface, _host, port = _get_conn_data(ip_addr) + tlshttpserver = tls_http_server((interface, port), tls_adapter) + + interface, host, port = _get_conn_data( + tlshttpserver.bind_addr, + ) + + fqdn = interface + if ip_addr is ANY_INTERFACE_IPV6: + fqdn = '[{}]'.format(fqdn) + + expect_fallback_response_over_plain_http = ( + (adapter_type == 'pyopenssl' + and (IS_ABOVE_OPENSSL10 or not six.PY2)) + or PY27 + ) + if expect_fallback_response_over_plain_http: + resp = requests.get( + 'http://' + fqdn + ':' + str(port) + '/', + ) + assert resp.status_code == 400 + assert resp.text == ( + 'The client sent a plain HTTP request, ' + 'but this server only speaks HTTPS on this port.' 
+ ) + return + + with pytest.raises(requests.exceptions.ConnectionError) as ssl_err: + requests.get( # FIXME: make stdlib ssl behave like PyOpenSSL + 'http://' + fqdn + ':' + str(port) + '/', + ) + + if IS_LINUX: + expected_error_code, expected_error_text = ( + 104, 'Connection reset by peer', + ) + if IS_MACOS: + expected_error_code, expected_error_text = ( + 54, 'Connection reset by peer', + ) + if IS_WINDOWS: + expected_error_code, expected_error_text = ( + 10054, + 'An existing connection was forcibly closed by the remote host', + ) + + underlying_error = ssl_err.value.args[0].args[-1] + err_text = str(underlying_error) + assert underlying_error.errno == expected_error_code, ( + 'The underlying error is {!r}'. + format(underlying_error) + ) + assert expected_error_text in err_text diff --git a/lib/cheroot/test/webtest.py b/lib/cheroot/test/webtest.py new file mode 100644 index 00000000..934b2004 --- /dev/null +++ b/lib/cheroot/test/webtest.py @@ -0,0 +1,605 @@ +"""Extensions to unittest for web frameworks. + +Use the WebCase.getPage method to request a page from your HTTP server. +Framework Integration +===================== +If you have control over your server process, you can handle errors +in the server-side of the HTTP conversation a bit better. You must run +both the client (your WebCase tests) and the server in the same process +(but in separate threads, obviously). +When an error occurs in the framework, call server_error. It will print +the traceback to stdout, and keep any assertions you have from running +(the assumption is that, if the server errors, the page output will not +be of further significance to your tests). 
+""" + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +import pprint +import re +import socket +import sys +import time +import traceback +import os +import json +import unittest +import warnings +import functools + +from six.moves import http_client, map, urllib_parse +import six + +from more_itertools.more import always_iterable +import jaraco.functools + + +def interface(host): + """Return an IP address for a client connection given the server host. + + If the server is listening on '0.0.0.0' (INADDR_ANY) + or '::' (IN6ADDR_ANY), this will return the proper localhost. + """ + if host == '0.0.0.0': + # INADDR_ANY, which should respond on localhost. + return '127.0.0.1' + if host == '::': + # IN6ADDR_ANY, which should respond on localhost. + return '::1' + return host + + +try: + # Jython support + if sys.platform[:4] == 'java': + def getchar(): + """Get a key press.""" + # Hopefully this is enough + return sys.stdin.read(1) + else: + # On Windows, msvcrt.getch reads a single char without output. 
+ import msvcrt + + def getchar(): + """Get a key press.""" + return msvcrt.getch() +except ImportError: + # Unix getchr + import tty + import termios + + def getchar(): + """Get a key press.""" + fd = sys.stdin.fileno() + old_settings = termios.tcgetattr(fd) + try: + tty.setraw(sys.stdin.fileno()) + ch = sys.stdin.read(1) + finally: + termios.tcsetattr(fd, termios.TCSADRAIN, old_settings) + return ch + + +# from jaraco.properties +class NonDataProperty: + """Non-data property decorator.""" + + def __init__(self, fget): + """Initialize a non-data property.""" + assert fget is not None, 'fget cannot be none' + assert callable(fget), 'fget must be callable' + self.fget = fget + + def __get__(self, obj, objtype=None): + """Return a class property.""" + if obj is None: + return self + return self.fget(obj) + + +class WebCase(unittest.TestCase): + """Helper web test suite base.""" + + HOST = '127.0.0.1' + PORT = 8000 + HTTP_CONN = http_client.HTTPConnection + PROTOCOL = 'HTTP/1.1' + + scheme = 'http' + url = None + ssl_context = None + + status = None + headers = None + body = None + + encoding = 'utf-8' + + time = None + + @property + def _Conn(self): + """Return HTTPConnection or HTTPSConnection based on self.scheme. + + * from http.client. + """ + cls_name = '{scheme}Connection'.format(scheme=self.scheme.upper()) + return getattr(http_client, cls_name) + + def get_conn(self, auto_open=False): + """Return a connection to our HTTP server.""" + conn = self._Conn(self.interface(), self.PORT) + # Automatically re-connect? + conn.auto_open = auto_open + conn.connect() + return conn + + def set_persistent(self, on=True, auto_open=False): + """Make our HTTP_CONN persistent (or not). + + If the 'on' argument is True (the default), then self.HTTP_CONN + will be set to an instance of HTTP(S)?Connection + to persist across requests. + As this class only allows for a single open connection, if + self already has an open connection, it will be closed. 
+ """ + try: + self.HTTP_CONN.close() + except (TypeError, AttributeError): + pass + + self.HTTP_CONN = ( + self.get_conn(auto_open=auto_open) + if on + else self._Conn + ) + + @property + def persistent(self): + """Presense of the persistent HTTP connection.""" + return hasattr(self.HTTP_CONN, '__class__') + + @persistent.setter + def persistent(self, on): + self.set_persistent(on) + + def interface(self): + """Return an IP address for a client connection. + + If the server is listening on '0.0.0.0' (INADDR_ANY) + or '::' (IN6ADDR_ANY), this will return the proper localhost. + """ + return interface(self.HOST) + + def getPage( + self, url, headers=None, method='GET', body=None, + protocol=None, raise_subcls=(), + ): + """Open the url with debugging support. Return status, headers, body. + + url should be the identifier passed to the server, typically a + server-absolute path and query string (sent between method and + protocol), and should only be an absolute URI if proxy support is + enabled in the server. + + If the application under test generates absolute URIs, be sure + to wrap them first with strip_netloc:: + + class MyAppWebCase(WebCase): + def getPage(url, *args, **kwargs): + super(MyAppWebCase, self).getPage( + cheroot.test.webtest.strip_netloc(url), + *args, **kwargs + ) + + `raise_subcls` is passed through to openURL. 
+ """ + ServerError.on = False + + if isinstance(url, six.text_type): + url = url.encode('utf-8') + if isinstance(body, six.text_type): + body = body.encode('utf-8') + + # for compatibility, support raise_subcls is None + raise_subcls = raise_subcls or () + + self.url = url + self.time = None + start = time.time() + result = openURL( + url, headers, method, body, self.HOST, self.PORT, + self.HTTP_CONN, protocol or self.PROTOCOL, + raise_subcls=raise_subcls, + ssl_context=self.ssl_context, + ) + self.time = time.time() - start + self.status, self.headers, self.body = result + + # Build a list of request cookies from the previous response cookies. + self.cookies = [ + ('Cookie', v) for k, v in self.headers + if k.lower() == 'set-cookie' + ] + + if ServerError.on: + raise ServerError() + return result + + @NonDataProperty + def interactive(self): + """Determine whether tests are run in interactive mode. + + Load interactivity setting from environment, where + the value can be numeric or a string like true or + False or 1 or 0. 
+ """ + env_str = os.environ.get('WEBTEST_INTERACTIVE', 'True') + is_interactive = bool(json.loads(env_str.lower())) + if is_interactive: + warnings.warn( + 'Interactive test failure interceptor support via ' + 'WEBTEST_INTERACTIVE environment variable is deprecated.', + DeprecationWarning, + ) + return is_interactive + + console_height = 30 + + def _handlewebError(self, msg): + print('') + print(' ERROR: %s' % msg) + + if not self.interactive: + raise self.failureException(msg) + + p = ( + ' Show: ' + '[B]ody [H]eaders [S]tatus [U]RL; ' + '[I]gnore, [R]aise, or sys.e[X]it >> ' + ) + sys.stdout.write(p) + sys.stdout.flush() + while True: + i = getchar().upper() + if not isinstance(i, type('')): + i = i.decode('ascii') + if i not in 'BHSUIRX': + continue + print(i.upper()) # Also prints new line + if i == 'B': + for x, line in enumerate(self.body.splitlines()): + if (x + 1) % self.console_height == 0: + # The \r and comma should make the next line overwrite + sys.stdout.write('<-- More -->\r') + m = getchar().lower() + # Erase our "More" prompt + sys.stdout.write(' \r') + if m == 'q': + break + print(line) + elif i == 'H': + pprint.pprint(self.headers) + elif i == 'S': + print(self.status) + elif i == 'U': + print(self.url) + elif i == 'I': + # return without raising the normal exception + return + elif i == 'R': + raise self.failureException(msg) + elif i == 'X': + sys.exit() + sys.stdout.write(p) + sys.stdout.flush() + + @property + def status_code(self): # noqa: D401; irrelevant for properties + """Integer HTTP status code.""" + return int(self.status[:3]) + + def status_matches(self, expected): + """Check whether actual status matches expected.""" + actual = ( + self.status_code + if isinstance(expected, int) else + self.status + ) + return expected == actual + + def assertStatus(self, status, msg=None): + """Fail if self.status != status. + + status may be integer code, exact string status, or + iterable of allowed possibilities. 
+ """ + if any(map(self.status_matches, always_iterable(status))): + return + + tmpl = 'Status {self.status} does not match {status}' + msg = msg or tmpl.format(**locals()) + self._handlewebError(msg) + + def assertHeader(self, key, value=None, msg=None): + """Fail if (key, [value]) not in self.headers.""" + lowkey = key.lower() + for k, v in self.headers: + if k.lower() == lowkey: + if value is None or str(value) == v: + return v + + if msg is None: + if value is None: + msg = '%r not in headers' % key + else: + msg = '%r:%r not in headers' % (key, value) + self._handlewebError(msg) + + def assertHeaderIn(self, key, values, msg=None): + """Fail if header indicated by key doesn't have one of the values.""" + lowkey = key.lower() + for k, v in self.headers: + if k.lower() == lowkey: + matches = [value for value in values if str(value) == v] + if matches: + return matches + + if msg is None: + msg = '%(key)r not in %(values)r' % vars() + self._handlewebError(msg) + + def assertHeaderItemValue(self, key, value, msg=None): + """Fail if the header does not contain the specified value.""" + actual_value = self.assertHeader(key, msg=msg) + header_values = map(str.strip, actual_value.split(',')) + if value in header_values: + return value + + if msg is None: + msg = '%r not in %r' % (value, header_values) + self._handlewebError(msg) + + def assertNoHeader(self, key, msg=None): + """Fail if key in self.headers.""" + lowkey = key.lower() + matches = [k for k, v in self.headers if k.lower() == lowkey] + if matches: + if msg is None: + msg = '%r in headers' % key + self._handlewebError(msg) + + def assertNoHeaderItemValue(self, key, value, msg=None): + """Fail if the header contains the specified value.""" + lowkey = key.lower() + hdrs = self.headers + matches = [k for k, v in hdrs if k.lower() == lowkey and v == value] + if matches: + if msg is None: + msg = '%r:%r in %r' % (key, value, hdrs) + self._handlewebError(msg) + + def assertBody(self, value, msg=None): + """Fail if 
value != self.body.""" + if isinstance(value, six.text_type): + value = value.encode(self.encoding) + if value != self.body: + if msg is None: + msg = 'expected body:\n%r\n\nactual body:\n%r' % ( + value, self.body, + ) + self._handlewebError(msg) + + def assertInBody(self, value, msg=None): + """Fail if value not in self.body.""" + if isinstance(value, six.text_type): + value = value.encode(self.encoding) + if value not in self.body: + if msg is None: + msg = '%r not in body: %s' % (value, self.body) + self._handlewebError(msg) + + def assertNotInBody(self, value, msg=None): + """Fail if value in self.body.""" + if isinstance(value, six.text_type): + value = value.encode(self.encoding) + if value in self.body: + if msg is None: + msg = '%r found in body' % value + self._handlewebError(msg) + + def assertMatchesBody(self, pattern, msg=None, flags=0): + """Fail if value (a regex pattern) is not in self.body.""" + if isinstance(pattern, six.text_type): + pattern = pattern.encode(self.encoding) + if re.search(pattern, self.body, flags) is None: + if msg is None: + msg = 'No match for %r in body' % pattern + self._handlewebError(msg) + + +methods_with_bodies = ('POST', 'PUT', 'PATCH') + + +def cleanHeaders(headers, method, body, host, port): + """Return request headers, with required headers added (if missing).""" + if headers is None: + headers = [] + + # Add the required Host request header if not present. + # [This specifies the host:port of the server, not the client.] 
+ found = False + for k, v in headers: + if k.lower() == 'host': + found = True + break + if not found: + if port == 80: + headers.append(('Host', host)) + else: + headers.append(('Host', '%s:%s' % (host, port))) + + if method in methods_with_bodies: + # Stick in default type and length headers if not present + found = False + for k, v in headers: + if k.lower() == 'content-type': + found = True + break + if not found: + headers.append( + ('Content-Type', 'application/x-www-form-urlencoded'), + ) + headers.append(('Content-Length', str(len(body or '')))) + + return headers + + +def shb(response): + """Return status, headers, body the way we like from a response.""" + resp_status_line = '%s %s' % (response.status, response.reason) + + if not six.PY2: + return resp_status_line, response.getheaders(), response.read() + + h = [] + key, value = None, None + for line in response.msg.headers: + if line: + if line[0] in ' \t': + value += line.strip() + else: + if key and value: + h.append((key, value)) + key, value = line.split(':', 1) + key = key.strip() + value = value.strip() + if key and value: + h.append((key, value)) + + return resp_status_line, h, response.read() + + +# def openURL(*args, raise_subcls=(), **kwargs): +# py27 compatible signature: +def openURL(*args, **kwargs): + """ + Open a URL, retrying when it fails. + + Specify `raise_subcls` (class or tuple of classes) to exclude + those socket.error subclasses from being suppressed and retried. 
+ """ + raise_subcls = kwargs.pop('raise_subcls', ()) + opener = functools.partial(_open_url_once, *args, **kwargs) + + def on_exception(): + type_, exc = sys.exc_info()[:2] + if isinstance(exc, raise_subcls): + raise + time.sleep(0.5) + + # Try up to 10 times + return jaraco.functools.retry_call( + opener, + retries=9, + cleanup=on_exception, + trap=socket.error, + ) + + +def _open_url_once( + url, headers=None, method='GET', body=None, + host='127.0.0.1', port=8000, http_conn=http_client.HTTPConnection, + protocol='HTTP/1.1', ssl_context=None, +): + """Open the given HTTP resource and return status, headers, and body.""" + headers = cleanHeaders(headers, method, body, host, port) + + # Allow http_conn to be a class or an instance + if hasattr(http_conn, 'host'): + conn = http_conn + else: + kw = {} + if ssl_context: + kw['context'] = ssl_context + conn = http_conn(interface(host), port, **kw) + conn._http_vsn_str = protocol + conn._http_vsn = int(''.join([x for x in protocol if x.isdigit()])) + if not six.PY2 and isinstance(url, bytes): + url = url.decode() + conn.putrequest( + method.upper(), url, skip_host=True, + skip_accept_encoding=True, + ) + for key, value in headers: + conn.putheader(key, value.encode('Latin-1')) + conn.endheaders() + if body is not None: + conn.send(body) + # Handle response + response = conn.getresponse() + s, h, b = shb(response) + if not hasattr(http_conn, 'host'): + # We made our own conn instance. Close it. + conn.close() + return s, h, b + + +def strip_netloc(url): + """Return absolute-URI path from URL. + + Strip the scheme and host from the URL, returning the + server-absolute portion. + + Useful for wrapping an absolute-URI for which only the + path is expected (such as in calls to getPage). 
+
+    >>> strip_netloc('https://google.com/foo/bar?bing#baz')
+    '/foo/bar?bing'
+
+    >>> strip_netloc('//google.com/foo/bar?bing#baz')
+    '/foo/bar?bing'
+
+    >>> strip_netloc('/foo/bar?bing#baz')
+    '/foo/bar?bing'
+    """
+    parsed = urllib_parse.urlparse(url)
+    scheme, netloc, path, params, query, fragment = parsed
+    stripped = '', '', path, params, query, ''
+    return urllib_parse.urlunparse(stripped)
+
+
+# Add any exceptions which your web framework handles
+# normally (that you don't want server_error to trap).
+ignored_exceptions = []
+
+# You'll want to set this to True when you can't guarantee
+# that each response will immediately follow each request;
+# for example, when handling requests via multiple threads.
+ignore_all = False
+
+
+class ServerError(Exception):
+    """Exception for signalling server error."""
+
+    on = False
+
+
+def server_error(exc=None):
+    """Server debug hook.
+
+    Return True if exception handled, False if ignored.
+    You probably want to wrap this, so you can still handle an error using
+    your framework when it's ignored.
+ """ + if exc is None: + exc = sys.exc_info() + + if ignore_all or exc[0] in ignored_exceptions: + return False + else: + ServerError.on = True + print('') + print(''.join(traceback.format_exception(*exc))) + return True diff --git a/lib/cheroot/testing.py b/lib/cheroot/testing.py new file mode 100644 index 00000000..94bb7734 --- /dev/null +++ b/lib/cheroot/testing.py @@ -0,0 +1,153 @@ +"""Pytest fixtures and other helpers for doing testing by end-users.""" + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +from contextlib import closing +import errno +import socket +import threading +import time + +import pytest +from six.moves import http_client + +import cheroot.server +from cheroot.test import webtest +import cheroot.wsgi + +EPHEMERAL_PORT = 0 +NO_INTERFACE = None # Using this or '' will cause an exception +ANY_INTERFACE_IPV4 = '0.0.0.0' +ANY_INTERFACE_IPV6 = '::' + +config = { + cheroot.wsgi.Server: { + 'bind_addr': (NO_INTERFACE, EPHEMERAL_PORT), + 'wsgi_app': None, + }, + cheroot.server.HTTPServer: { + 'bind_addr': (NO_INTERFACE, EPHEMERAL_PORT), + 'gateway': cheroot.server.Gateway, + }, +} + + +def cheroot_server(server_factory): + """Set up and tear down a Cheroot server instance.""" + conf = config[server_factory].copy() + bind_port = conf.pop('bind_addr')[-1] + + for interface in ANY_INTERFACE_IPV6, ANY_INTERFACE_IPV4: + try: + actual_bind_addr = (interface, bind_port) + httpserver = server_factory( # create it + bind_addr=actual_bind_addr, + **conf + ) + except OSError: + pass + else: + break + + httpserver.shutdown_timeout = 0 # Speed-up tests teardown + + threading.Thread(target=httpserver.safe_start).start() # spawn it + while not httpserver.ready: # wait until fully initialized and bound + time.sleep(0.1) + + yield httpserver + + httpserver.stop() # destroy it + + +@pytest.fixture(scope='module') +def wsgi_server(): + """Set up and tear down a Cheroot WSGI server instance.""" + for srv in 
cheroot_server(cheroot.wsgi.Server): + yield srv + + +@pytest.fixture(scope='module') +def native_server(): + """Set up and tear down a Cheroot HTTP server instance.""" + for srv in cheroot_server(cheroot.server.HTTPServer): + yield srv + + +class _TestClient: + def __init__(self, server): + self._interface, self._host, self._port = _get_conn_data( + server.bind_addr, + ) + self.server_instance = server + self._http_connection = self.get_connection() + + def get_connection(self): + name = '{interface}:{port}'.format( + interface=self._interface, + port=self._port, + ) + conn_cls = ( + http_client.HTTPConnection + if self.server_instance.ssl_adapter is None else + http_client.HTTPSConnection + ) + return conn_cls(name) + + def request( + self, uri, method='GET', headers=None, http_conn=None, + protocol='HTTP/1.1', + ): + return webtest.openURL( + uri, method=method, + headers=headers, + host=self._host, port=self._port, + http_conn=http_conn or self._http_connection, + protocol=protocol, + ) + + def __getattr__(self, attr_name): + def _wrapper(uri, **kwargs): + http_method = attr_name.upper() + return self.request(uri, method=http_method, **kwargs) + + return _wrapper + + +def _probe_ipv6_sock(interface): + # Alternate way is to check IPs on interfaces using glibc, like: + # github.com/Gautier/minifail/blob/master/minifail/getifaddrs.py + try: + with closing(socket.socket(family=socket.AF_INET6)) as sock: + sock.bind((interface, 0)) + except (OSError, socket.error) as sock_err: + # In Python 3 socket.error is an alias for OSError + # In Python 2 socket.error is a subclass of IOError + if sock_err.errno != errno.EADDRNOTAVAIL: + raise + else: + return True + + return False + + +def _get_conn_data(bind_addr): + if isinstance(bind_addr, tuple): + host, port = bind_addr + else: + host, port = bind_addr, 0 + + interface = webtest.interface(host) + + if ':' in interface and not _probe_ipv6_sock(interface): + interface = '127.0.0.1' + if ':' in host: + host = interface + + 
return interface, host, port + + +def get_server_client(server): + """Create and return a test client for the given server.""" + return _TestClient(server) diff --git a/lib/cheroot/workers/__init__.py b/lib/cheroot/workers/__init__.py new file mode 100644 index 00000000..098b8f25 --- /dev/null +++ b/lib/cheroot/workers/__init__.py @@ -0,0 +1 @@ +"""HTTP workers pool.""" diff --git a/lib/cheroot/workers/threadpool.py b/lib/cheroot/workers/threadpool.py new file mode 100644 index 00000000..8c1d29f7 --- /dev/null +++ b/lib/cheroot/workers/threadpool.py @@ -0,0 +1,323 @@ +"""A thread-based worker pool.""" + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + + +import collections +import threading +import time +import socket +import warnings + +from six.moves import queue + +from jaraco.functools import pass_none + + +__all__ = ('WorkerThread', 'ThreadPool') + + +class TrueyZero: + """Object which equals and does math like the integer 0 but evals True.""" + + def __add__(self, other): + return other + + def __radd__(self, other): + return other + + +trueyzero = TrueyZero() + +_SHUTDOWNREQUEST = None + + +class WorkerThread(threading.Thread): + """Thread which continuously polls a Queue for Connection objects. + + Due to the timing issues of polling a Queue, a WorkerThread does not + check its own 'ready' flag after it has started. To stop the thread, + it is necessary to stick a _SHUTDOWNREQUEST object onto the Queue + (one for each running WorkerThread). + """ + + conn = None + """The current connection pulled off the Queue, or None.""" + + server = None + """The HTTP Server which spawned this thread, and which owns the + Queue and is placing active connections into it.""" + + ready = False + """A simple flag for the calling server to know when this thread + has begun polling the Queue.""" + + def __init__(self, server): + """Initialize WorkerThread instance. 
+ + Args: + server (cheroot.server.HTTPServer): web server object + receiving this request + """ + self.ready = False + self.server = server + + self.requests_seen = 0 + self.bytes_read = 0 + self.bytes_written = 0 + self.start_time = None + self.work_time = 0 + self.stats = { + 'Requests': lambda s: self.requests_seen + ( + self.start_time is None + and trueyzero + or self.conn.requests_seen + ), + 'Bytes Read': lambda s: self.bytes_read + ( + self.start_time is None + and trueyzero + or self.conn.rfile.bytes_read + ), + 'Bytes Written': lambda s: self.bytes_written + ( + self.start_time is None + and trueyzero + or self.conn.wfile.bytes_written + ), + 'Work Time': lambda s: self.work_time + ( + self.start_time is None + and trueyzero + or time.time() - self.start_time + ), + 'Read Throughput': lambda s: s['Bytes Read'](s) / ( + s['Work Time'](s) or 1e-6 + ), + 'Write Throughput': lambda s: s['Bytes Written'](s) / ( + s['Work Time'](s) or 1e-6 + ), + } + threading.Thread.__init__(self) + + def run(self): + """Process incoming HTTP connections. + + Retrieves incoming connections from thread pool. + """ + self.server.stats['Worker Threads'][self.getName()] = self.stats + try: + self.ready = True + while True: + conn = self.server.requests.get() + if conn is _SHUTDOWNREQUEST: + return + + # Just close the connection and move on. 
+ if conn.closeable: + conn.close() + continue + + self.conn = conn + is_stats_enabled = self.server.stats['Enabled'] + if is_stats_enabled: + self.start_time = time.time() + keep_conn_open = False + try: + keep_conn_open = conn.communicate() + finally: + if keep_conn_open: + self.server.connections.put(conn) + else: + conn.close() + if is_stats_enabled: + self.requests_seen += self.conn.requests_seen + self.bytes_read += self.conn.rfile.bytes_read + self.bytes_written += self.conn.wfile.bytes_written + self.work_time += time.time() - self.start_time + self.start_time = None + self.conn = None + except (KeyboardInterrupt, SystemExit) as ex: + self.server.interrupt = ex + + +class ThreadPool: + """A Request Queue for an HTTPServer which pools threads. + + ThreadPool objects must provide min, get(), put(obj), start() + and stop(timeout) attributes. + """ + + def __init__( + self, server, min=10, max=-1, accepted_queue_size=-1, + accepted_queue_timeout=10, + ): + """Initialize HTTP requests queue instance. 
+ + Args: + server (cheroot.server.HTTPServer): web server object + receiving this request + min (int): minimum number of worker threads + max (int): maximum number of worker threads + accepted_queue_size (int): maximum number of active + requests in queue + accepted_queue_timeout (int): timeout for putting request + into queue + """ + self.server = server + self.min = min + self.max = max + self._threads = [] + self._queue = queue.Queue(maxsize=accepted_queue_size) + self._queue_put_timeout = accepted_queue_timeout + self.get = self._queue.get + self._pending_shutdowns = collections.deque() + + def start(self): + """Start the pool of threads.""" + for i in range(self.min): + self._threads.append(WorkerThread(self.server)) + for worker in self._threads: + worker.setName('CP Server ' + worker.getName()) + worker.start() + for worker in self._threads: + while not worker.ready: + time.sleep(.1) + + @property + def idle(self): # noqa: D401; irrelevant for properties + """Number of worker threads which are idle. Read-only.""" + idles = len([t for t in self._threads if t.conn is None]) + return max(idles - len(self._pending_shutdowns), 0) + + def put(self, obj): + """Put request into queue. 
+ + Args: + obj (cheroot.server.HTTPConnection): HTTP connection + waiting to be processed + """ + self._queue.put(obj, block=True, timeout=self._queue_put_timeout) + + def _clear_dead_threads(self): + # Remove any dead threads from our list + for t in [t for t in self._threads if not t.is_alive()]: + self._threads.remove(t) + try: + self._pending_shutdowns.popleft() + except IndexError: + pass + + def grow(self, amount): + """Spawn new worker threads (not above self.max).""" + if self.max > 0: + budget = max(self.max - len(self._threads), 0) + else: + # self.max <= 0 indicates no maximum + budget = float('inf') + + n_new = min(amount, budget) + + workers = [self._spawn_worker() for i in range(n_new)] + while not all(worker.ready for worker in workers): + time.sleep(.1) + self._threads.extend(workers) + + def _spawn_worker(self): + worker = WorkerThread(self.server) + worker.setName('CP Server ' + worker.getName()) + worker.start() + return worker + + def shrink(self, amount): + """Kill off worker threads (not below self.min).""" + # Grow/shrink the pool if necessary. + # Remove any dead threads from our list + amount -= len(self._pending_shutdowns) + self._clear_dead_threads() + if amount <= 0: + return + + # calculate the number of threads above the minimum + n_extra = max(len(self._threads) - self.min, 0) + + # don't remove more than amount + n_to_remove = min(amount, n_extra) + + # put shutdown requests on the queue equal to the number of threads + # to remove. As each request is processed by a worker, that worker + # will terminate and be culled from the list. + for n in range(n_to_remove): + self._pending_shutdowns.append(None) + self._queue.put(_SHUTDOWNREQUEST) + + def stop(self, timeout=5): + """Terminate all worker threads. 
+
+        Args:
+            timeout (int): time to wait for threads to stop gracefully
+        """
+        # for compatibility, negative timeouts are treated like None
+        # TODO: treat negative timeouts like already expired timeouts
+        if timeout is not None and timeout < 0:
+            timeout = None
+            # BUGFIX: the warnings module has no ``warning`` attribute;
+            # calling it raised AttributeError whenever a negative timeout
+            # was passed. ``warnings.warn`` is the correct API.
+            warnings.warn(
+                'In the future, negative timeouts to Server.stop() '
+                'will be equivalent to a timeout of zero.',
+                stacklevel=2,
+            )
+
+        if timeout is not None:
+            endtime = time.time() + timeout
+
+        # Must shut down threads here so the code that calls
+        # this method can know when all threads are stopped.
+        for worker in self._threads:
+            self._queue.put(_SHUTDOWNREQUEST)
+
+        ignored_errors = (
+            # TODO: explain this exception.
+            AssertionError,
+            # Ignore repeated Ctrl-C. See cherrypy#691.
+            KeyboardInterrupt,
+        )
+
+        for worker in self._clear_threads():
+            remaining_time = timeout and endtime - time.time()
+            try:
+                worker.join(remaining_time)
+                if worker.is_alive():
+                    # Timeout exhausted; forcibly shut down the socket.
+                    self._force_close(worker.conn)
+                    worker.join()
+            except ignored_errors:
+                pass
+
+    @staticmethod
+    @pass_none
+    def _force_close(conn):
+        if conn.rfile.closed:
+            return
+        try:
+            try:
+                conn.socket.shutdown(socket.SHUT_RD)
+            except TypeError:
+                # pyOpenSSL sockets don't take an arg
+                conn.socket.shutdown()
+        except OSError:
+            # shutdown sometimes fails (race with 'closed' check?)
+            # ref #238
+            pass
+
+    def _clear_threads(self):
+        """Clear self._threads and yield all joinable threads."""
+        # threads = pop_all(self._threads)
+        threads, self._threads[:] = self._threads[:], []
+        return (
+            thread
+            for thread in threads
+            if thread is not threading.currentThread()
+        )
+
+    @property
+    def qsize(self):
+        """Return the queue size."""
+        return self._queue.qsize()
diff --git a/lib/cheroot/wsgi.py b/lib/cheroot/wsgi.py
new file mode 100644
index 00000000..30599b35
--- /dev/null
+++ b/lib/cheroot/wsgi.py
@@ -0,0 +1,434 @@
+"""This class holds Cheroot WSGI server implementation.
+ +Simplest example on how to use this server:: + + from cheroot import wsgi + + def my_crazy_app(environ, start_response): + status = '200 OK' + response_headers = [('Content-type','text/plain')] + start_response(status, response_headers) + return [b'Hello world!'] + + addr = '0.0.0.0', 8070 + server = wsgi.Server(addr, my_crazy_app) + server.start() + +The Cheroot WSGI server can serve as many WSGI applications +as you want in one instance by using a PathInfoDispatcher:: + + path_map = { + '/': my_crazy_app, + '/blog': my_blog_app, + } + d = wsgi.PathInfoDispatcher(path_map) + server = wsgi.Server(addr, d) +""" + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +import sys + +import six +from six.moves import filter + +from . import server +from .workers import threadpool +from ._compat import ntob, bton + + +class Server(server.HTTPServer): + """A subclass of HTTPServer which calls a WSGI application.""" + + wsgi_version = (1, 0) + """The version of WSGI to produce.""" + + def __init__( + self, bind_addr, wsgi_app, numthreads=10, server_name=None, + max=-1, request_queue_size=5, timeout=10, shutdown_timeout=5, + accepted_queue_size=-1, accepted_queue_timeout=10, + peercreds_enabled=False, peercreds_resolve_enabled=False, + ): + """Initialize WSGI Server instance. 
+ + Args: + bind_addr (tuple): network interface to listen to + wsgi_app (callable): WSGI application callable + numthreads (int): number of threads for WSGI thread pool + server_name (str): web server name to be advertised via + Server HTTP header + max (int): maximum number of worker threads + request_queue_size (int): the 'backlog' arg to + socket.listen(); max queued connections + timeout (int): the timeout in seconds for accepted connections + shutdown_timeout (int): the total time, in seconds, to + wait for worker threads to cleanly exit + accepted_queue_size (int): maximum number of active + requests in queue + accepted_queue_timeout (int): timeout for putting request + into queue + """ + super(Server, self).__init__( + bind_addr, + gateway=wsgi_gateways[self.wsgi_version], + server_name=server_name, + peercreds_enabled=peercreds_enabled, + peercreds_resolve_enabled=peercreds_resolve_enabled, + ) + self.wsgi_app = wsgi_app + self.request_queue_size = request_queue_size + self.timeout = timeout + self.shutdown_timeout = shutdown_timeout + self.requests = threadpool.ThreadPool( + self, min=numthreads or 1, max=max, + accepted_queue_size=accepted_queue_size, + accepted_queue_timeout=accepted_queue_timeout, + ) + + @property + def numthreads(self): + """Set minimum number of threads.""" + return self.requests.min + + @numthreads.setter + def numthreads(self, value): + self.requests.min = value + + +class Gateway(server.Gateway): + """A base class to interface HTTPServer with WSGI.""" + + def __init__(self, req): + """Initialize WSGI Gateway instance with request. + + Args: + req (HTTPRequest): current HTTP request + """ + super(Gateway, self).__init__(req) + self.started_response = False + self.env = self.get_environ() + self.remaining_bytes_out = None + + @classmethod + def gateway_map(cls): + """Create a mapping of gateways and their versions. 
+ + Returns: + dict[tuple[int,int],class]: map of gateway version and + corresponding class + + """ + return dict( + (gw.version, gw) + for gw in cls.__subclasses__() + ) + + def get_environ(self): + """Return a new environ dict targeting the given wsgi.version.""" + raise NotImplementedError # pragma: no cover + + def respond(self): + """Process the current request. + + From :pep:`333`: + + The start_response callable must not actually transmit + the response headers. Instead, it must store them for the + server or gateway to transmit only after the first + iteration of the application return value that yields + a NON-EMPTY string, or upon the application's first + invocation of the write() callable. + """ + response = self.req.server.wsgi_app(self.env, self.start_response) + try: + for chunk in filter(None, response): + if not isinstance(chunk, six.binary_type): + raise ValueError('WSGI Applications must yield bytes') + self.write(chunk) + finally: + # Send headers if not already sent + self.req.ensure_headers_sent() + if hasattr(response, 'close'): + response.close() + + def start_response(self, status, headers, exc_info=None): + """WSGI callable to begin the HTTP response.""" + # "The application may call start_response more than once, + # if and only if the exc_info argument is provided." + if self.started_response and not exc_info: + raise AssertionError( + 'WSGI start_response called a second ' + 'time with no exc_info.', + ) + self.started_response = True + + # "if exc_info is provided, and the HTTP headers have already been + # sent, start_response must raise an error, and should raise the + # exc_info tuple." + if self.req.sent_headers: + try: + six.reraise(*exc_info) + finally: + exc_info = None + + self.req.status = self._encode_status(status) + + for k, v in headers: + if not isinstance(k, str): + raise TypeError( + 'WSGI response header key %r is not of type str.' 
% k, + ) + if not isinstance(v, str): + raise TypeError( + 'WSGI response header value %r is not of type str.' % v, + ) + if k.lower() == 'content-length': + self.remaining_bytes_out = int(v) + out_header = ntob(k), ntob(v) + self.req.outheaders.append(out_header) + + return self.write + + @staticmethod + def _encode_status(status): + """Cast status to bytes representation of current Python version. + + According to :pep:`3333`, when using Python 3, the response status + and headers must be bytes masquerading as unicode; that is, they + must be of type "str" but are restricted to code points in the + "latin-1" set. + """ + if six.PY2: + return status + if not isinstance(status, str): + raise TypeError('WSGI response status is not of type str.') + return status.encode('ISO-8859-1') + + def write(self, chunk): + """WSGI callable to write unbuffered data to the client. + + This method is also used internally by start_response (to write + data from the iterable returned by the WSGI application). + """ + if not self.started_response: + raise AssertionError('WSGI write called before start_response.') + + chunklen = len(chunk) + rbo = self.remaining_bytes_out + if rbo is not None and chunklen > rbo: + if not self.req.sent_headers: + # Whew. We can send a 500 to the client. + self.req.simple_response( + '500 Internal Server Error', + 'The requested resource returned more bytes than the ' + 'declared Content-Length.', + ) + else: + # Dang. We have probably already sent data. Truncate the chunk + # to fit (so the client doesn't hang) and raise an error later. 
+ chunk = chunk[:rbo] + + self.req.ensure_headers_sent() + + self.req.write(chunk) + + if rbo is not None: + rbo -= chunklen + if rbo < 0: + raise ValueError( + 'Response body exceeds the declared Content-Length.', + ) + + +class Gateway_10(Gateway): + """A Gateway class to interface HTTPServer with WSGI 1.0.x.""" + + version = 1, 0 + + def get_environ(self): + """Return a new environ dict targeting the given wsgi.version.""" + req = self.req + req_conn = req.conn + env = { + # set a non-standard environ entry so the WSGI app can know what + # the *real* server protocol is (and what features to support). + # See http://www.faqs.org/rfcs/rfc2145.html. + 'ACTUAL_SERVER_PROTOCOL': req.server.protocol, + 'PATH_INFO': bton(req.path), + 'QUERY_STRING': bton(req.qs), + 'REMOTE_ADDR': req_conn.remote_addr or '', + 'REMOTE_PORT': str(req_conn.remote_port or ''), + 'REQUEST_METHOD': bton(req.method), + 'REQUEST_URI': bton(req.uri), + 'SCRIPT_NAME': '', + 'SERVER_NAME': req.server.server_name, + # Bah. "SERVER_PROTOCOL" is actually the REQUEST protocol. + 'SERVER_PROTOCOL': bton(req.request_protocol), + 'SERVER_SOFTWARE': req.server.software, + 'wsgi.errors': sys.stderr, + 'wsgi.input': req.rfile, + 'wsgi.input_terminated': bool(req.chunked_read), + 'wsgi.multiprocess': False, + 'wsgi.multithread': True, + 'wsgi.run_once': False, + 'wsgi.url_scheme': bton(req.scheme), + 'wsgi.version': self.version, + } + + if isinstance(req.server.bind_addr, six.string_types): + # AF_UNIX. This isn't really allowed by WSGI, which doesn't + # address unix domain sockets. But it's better than nothing. + env['SERVER_PORT'] = '' + try: + env['X_REMOTE_PID'] = str(req_conn.peer_pid) + env['X_REMOTE_UID'] = str(req_conn.peer_uid) + env['X_REMOTE_GID'] = str(req_conn.peer_gid) + + env['X_REMOTE_USER'] = str(req_conn.peer_user) + env['X_REMOTE_GROUP'] = str(req_conn.peer_group) + + env['REMOTE_USER'] = env['X_REMOTE_USER'] + except RuntimeError: + """Unable to retrieve peer creds data. 
+ + Unsupported by current kernel or socket error happened, or + unsupported socket type, or disabled. + """ + else: + env['SERVER_PORT'] = str(req.server.bind_addr[1]) + + # Request headers + env.update( + ('HTTP_' + bton(k).upper().replace('-', '_'), bton(v)) + for k, v in req.inheaders.items() + ) + + # CONTENT_TYPE/CONTENT_LENGTH + ct = env.pop('HTTP_CONTENT_TYPE', None) + if ct is not None: + env['CONTENT_TYPE'] = ct + cl = env.pop('HTTP_CONTENT_LENGTH', None) + if cl is not None: + env['CONTENT_LENGTH'] = cl + + if req.conn.ssl_env: + env.update(req.conn.ssl_env) + + return env + + +class Gateway_u0(Gateway_10): + """A Gateway class to interface HTTPServer with WSGI u.0. + + WSGI u.0 is an experimental protocol, which uses unicode for keys + and values in both Python 2 and Python 3. + """ + + version = 'u', 0 + + def get_environ(self): + """Return a new environ dict targeting the given wsgi.version.""" + req = self.req + env_10 = super(Gateway_u0, self).get_environ() + env = dict(map(self._decode_key, env_10.items())) + + # Request-URI + enc = env.setdefault(six.u('wsgi.url_encoding'), six.u('utf-8')) + try: + env['PATH_INFO'] = req.path.decode(enc) + env['QUERY_STRING'] = req.qs.decode(enc) + except UnicodeDecodeError: + # Fall back to latin 1 so apps can transcode if needed. 
+ env['wsgi.url_encoding'] = 'ISO-8859-1' + env['PATH_INFO'] = env_10['PATH_INFO'] + env['QUERY_STRING'] = env_10['QUERY_STRING'] + + env.update(map(self._decode_value, env.items())) + + return env + + @staticmethod + def _decode_key(item): + k, v = item + if six.PY2: + k = k.decode('ISO-8859-1') + return k, v + + @staticmethod + def _decode_value(item): + k, v = item + skip_keys = 'REQUEST_URI', 'wsgi.input' + if not six.PY2 or not isinstance(v, bytes) or k in skip_keys: + return k, v + return k, v.decode('ISO-8859-1') + + +wsgi_gateways = Gateway.gateway_map() + + +class PathInfoDispatcher: + """A WSGI dispatcher for dispatch based on the PATH_INFO.""" + + def __init__(self, apps): + """Initialize path info WSGI app dispatcher. + + Args: + apps (dict[str,object]|list[tuple[str,object]]): URI prefix + and WSGI app pairs + """ + try: + apps = list(apps.items()) + except AttributeError: + pass + + # Sort the apps by len(path), descending + def by_path_len(app): + return len(app[0]) + apps.sort(key=by_path_len, reverse=True) + + # The path_prefix strings must start, but not end, with a slash. + # Use "" instead of "/". + self.apps = [(p.rstrip('/'), a) for p, a in apps] + + def __call__(self, environ, start_response): + """Process incoming WSGI request. + + Ref: :pep:`3333` + + Args: + environ (Mapping): a dict containing WSGI environment variables + start_response (callable): function, which sets response + status and headers + + Returns: + list[bytes]: iterable containing bytes to be returned in + HTTP response body + + """ + path = environ['PATH_INFO'] or '/' + for p, app in self.apps: + # The apps list should be sorted by length, descending. 
+            if path.startswith(p + '/') or path == p:
+                environ = environ.copy()
+                environ['SCRIPT_NAME'] = environ.get('SCRIPT_NAME', '') + p
+                environ['PATH_INFO'] = path[len(p):]
+                return app(environ, start_response)
+
+        start_response(
+            '404 Not Found', [
+                ('Content-Type', 'text/plain'),
+                ('Content-Length', '0'),
+            ],
+        )
+        # PEP 3333: the response iterable must yield bytestrings, so use
+        # b'' (on Python 2, b'' == ''; cheroot's own Gateway filters empty
+        # chunks either way, so behavior under cheroot is unchanged).
+        return [b'']
+
+
+# compatibility aliases
+globals().update(
+    WSGIServer=Server,
+    WSGIGateway=Gateway,
+    WSGIGateway_u0=Gateway_u0,
+    WSGIGateway_10=Gateway_10,
+    WSGIPathInfoDispatcher=PathInfoDispatcher,
+)