Update cheroot-8.5.2

This commit is contained in:
JonnyWong16 2021-10-14 21:14:02 -07:00
parent 4ac151d7de
commit 182e5f553e
No known key found for this signature in database
GPG key ID: B1F1F9807184697A
25 changed files with 2171 additions and 602 deletions

View file

@ -1,13 +1,20 @@
# pylint: disable=unused-import
"""Compatibility code for using Cheroot with various versions of Python."""
from __future__ import absolute_import, division, print_function
__metaclass__ = type
import os
import platform
import re
import six
try:
import selectors # lgtm [py/unused-import]
except ImportError:
import selectors2 as selectors # noqa: F401 # lgtm [py/unused-import]
try:
import ssl
IS_ABOVE_OPENSSL10 = ssl.OPENSSL_VERSION_INFO >= (1, 1)
@ -15,6 +22,24 @@ try:
except ImportError:
IS_ABOVE_OPENSSL10 = None
# contextlib.suppress was added in Python 3.4
try:
from contextlib import suppress
except ImportError:
from contextlib import contextmanager
@contextmanager
def suppress(*exceptions):
"""Return a context manager that suppresses the `exceptions`."""
try:
yield
except exceptions:
pass
IS_CI = bool(os.getenv('CI'))
IS_GITHUB_ACTIONS_WORKFLOW = bool(os.getenv('GITHUB_WORKFLOW'))
IS_PYPY = platform.python_implementation() == 'PyPy'
@ -36,7 +61,7 @@ if not six.PY2:
return n.encode(encoding)
def ntou(n, encoding='ISO-8859-1'):
"""Return the native string as unicode with the given encoding."""
"""Return the native string as Unicode with the given encoding."""
assert_native(n)
# In Python 3, the native string type is unicode
return n
@ -55,7 +80,7 @@ else:
return n
def ntou(n, encoding='ISO-8859-1'):
"""Return the native string as unicode with the given encoding."""
"""Return the native string as Unicode with the given encoding."""
assert_native(n)
# In Python 2, the native string type is bytes.
# First, check for the special encoding 'escape'. The test suite uses
@ -78,7 +103,7 @@ else:
def assert_native(n):
"""Check whether the input is of nativ ``str`` type.
"""Check whether the input is of native :py:class:`str` type.
Raises:
TypeError: in case of failed check
@ -89,22 +114,35 @@ def assert_native(n):
if not six.PY2:
"""Python 3 has memoryview builtin."""
"""Python 3 has :py:class:`memoryview` builtin."""
# Python 2.7 has it backported, but socket.write() does
# str(memoryview(b'0' * 100)) -> <memory at 0x7fb6913a5588>
# instead of accessing it correctly.
memoryview = memoryview
else:
"""Link memoryview to buffer under Python 2."""
"""Link :py:class:`memoryview` to buffer under Python 2."""
memoryview = buffer # noqa: F821
def extract_bytes(mv):
"""Retrieve bytes out of memoryview/buffer or bytes."""
r"""Retrieve bytes out of the given input buffer.
:param mv: input :py:func:`buffer`
:type mv: memoryview or bytes
:return: unwrapped bytes
:rtype: bytes
:raises ValueError: if the input is not one of \
:py:class:`memoryview`/:py:func:`buffer` \
or :py:class:`bytes`
"""
if isinstance(mv, memoryview):
return bytes(mv) if six.PY2 else mv.tobytes()
if isinstance(mv, bytes):
return mv
raise ValueError
raise ValueError(
'extract_bytes() only accepts bytes and memoryview/buffer',
)

View file

@ -1,36 +1,42 @@
"""Command line tool for starting a Cheroot WSGI/HTTP server instance.
Basic usage::
Basic usage:
# Start a server on 127.0.0.1:8000 with the default settings
# for the WSGI app myapp/wsgi.py:application()
cheroot myapp.wsgi
.. code-block:: shell-session
# Start a server on 0.0.0.0:9000 with 8 threads
# for the WSGI app myapp/wsgi.py:main_app()
cheroot myapp.wsgi:main_app --bind 0.0.0.0:9000 --threads 8
$ # Start a server on 127.0.0.1:8000 with the default settings
$ # for the WSGI app myapp/wsgi.py:application()
$ cheroot myapp.wsgi
# Start a server for the cheroot.server.Gateway subclass
# myapp/gateway.py:HTTPGateway
cheroot myapp.gateway:HTTPGateway
$ # Start a server on 0.0.0.0:9000 with 8 threads
$ # for the WSGI app myapp/wsgi.py:main_app()
$ cheroot myapp.wsgi:main_app --bind 0.0.0.0:9000 --threads 8
# Start a server on the UNIX socket /var/spool/myapp.sock
cheroot myapp.wsgi --bind /var/spool/myapp.sock
$ # Start a server for the cheroot.server.Gateway subclass
$ # myapp/gateway.py:HTTPGateway
$ cheroot myapp.gateway:HTTPGateway
# Start a server on the abstract UNIX socket CherootServer
cheroot myapp.wsgi --bind @CherootServer
$ # Start a server on the UNIX socket /var/spool/myapp.sock
$ cheroot myapp.wsgi --bind /var/spool/myapp.sock
$ # Start a server on the abstract UNIX socket CherootServer
$ cheroot myapp.wsgi --bind @CherootServer
.. spelling::
cli
"""
import argparse
from importlib import import_module
import os
import sys
import contextlib
import six
from . import server
from . import wsgi
from ._compat import suppress
__metaclass__ = type
@ -49,6 +55,7 @@ class TCPSocket(BindLocation):
Args:
address (str): Host name or IP address
port (int): TCP port number
"""
self.bind_addr = address, port
@ -64,9 +71,9 @@ class UnixSocket(BindLocation):
class AbstractSocket(BindLocation):
"""AbstractSocket."""
def __init__(self, addr):
def __init__(self, abstract_socket):
"""Initialize."""
self.bind_addr = '\0{}'.format(self.abstract_socket)
self.bind_addr = '\x00{sock_path}'.format(sock_path=abstract_socket)
class Application:
@ -77,8 +84,8 @@ class Application:
"""Read WSGI app/Gateway path string and import application module."""
mod_path, _, app_path = full_path.partition(':')
app = getattr(import_module(mod_path), app_path or 'application')
with contextlib.suppress(TypeError):
# suppress the `TypeError` exception, just in case `app` is not a class
with suppress(TypeError):
if issubclass(app, server.Gateway):
return GatewayYo(app)
@ -128,8 +135,17 @@ class GatewayYo:
def parse_wsgi_bind_location(bind_addr_string):
"""Convert bind address string to a BindLocation."""
# if the string begins with an @ symbol, use an abstract socket,
# this is the first condition to verify, otherwise the urlparse
# validation would detect //@<value> as a valid url with a hostname
# with value: "<value>" and port: None
if bind_addr_string.startswith('@'):
return AbstractSocket(bind_addr_string[1:])
# try and match for an IP/hostname and port
match = six.moves.urllib.parse.urlparse('//{}'.format(bind_addr_string))
match = six.moves.urllib.parse.urlparse(
'//{addr}'.format(addr=bind_addr_string),
)
try:
addr = match.hostname
port = match.port
@ -139,9 +155,6 @@ def parse_wsgi_bind_location(bind_addr_string):
pass
# else, assume a UNIX socket path
# if the string begins with an @ symbol, use an abstract socket
if bind_addr_string.startswith('@'):
return AbstractSocket(bind_addr_string[1:])
return UnixSocket(path=bind_addr_string)
@ -151,70 +164,70 @@ def parse_wsgi_bind_addr(bind_addr_string):
_arg_spec = {
'_wsgi_app': dict(
metavar='APP_MODULE',
type=Application.resolve,
help='WSGI application callable or cheroot.server.Gateway subclass',
),
'--bind': dict(
metavar='ADDRESS',
dest='bind_addr',
type=parse_wsgi_bind_addr,
default='[::1]:8000',
help='Network interface to listen on (default: [::1]:8000)',
),
'--chdir': dict(
metavar='PATH',
type=os.chdir,
help='Set the working directory',
),
'--server-name': dict(
dest='server_name',
type=str,
help='Web server name to be advertised via Server HTTP header',
),
'--threads': dict(
metavar='INT',
dest='numthreads',
type=int,
help='Minimum number of worker threads',
),
'--max-threads': dict(
metavar='INT',
dest='max',
type=int,
help='Maximum number of worker threads',
),
'--timeout': dict(
metavar='INT',
dest='timeout',
type=int,
help='Timeout in seconds for accepted connections',
),
'--shutdown-timeout': dict(
metavar='INT',
dest='shutdown_timeout',
type=int,
help='Time in seconds to wait for worker threads to cleanly exit',
),
'--request-queue-size': dict(
metavar='INT',
dest='request_queue_size',
type=int,
help='Maximum number of queued connections',
),
'--accepted-queue-size': dict(
metavar='INT',
dest='accepted_queue_size',
type=int,
help='Maximum number of active requests in queue',
),
'--accepted-queue-timeout': dict(
metavar='INT',
dest='accepted_queue_timeout',
type=int,
help='Timeout in seconds for putting requests into queue',
),
'_wsgi_app': {
'metavar': 'APP_MODULE',
'type': Application.resolve,
'help': 'WSGI application callable or cheroot.server.Gateway subclass',
},
'--bind': {
'metavar': 'ADDRESS',
'dest': 'bind_addr',
'type': parse_wsgi_bind_addr,
'default': '[::1]:8000',
'help': 'Network interface to listen on (default: [::1]:8000)',
},
'--chdir': {
'metavar': 'PATH',
'type': os.chdir,
'help': 'Set the working directory',
},
'--server-name': {
'dest': 'server_name',
'type': str,
'help': 'Web server name to be advertised via Server HTTP header',
},
'--threads': {
'metavar': 'INT',
'dest': 'numthreads',
'type': int,
'help': 'Minimum number of worker threads',
},
'--max-threads': {
'metavar': 'INT',
'dest': 'max',
'type': int,
'help': 'Maximum number of worker threads',
},
'--timeout': {
'metavar': 'INT',
'dest': 'timeout',
'type': int,
'help': 'Timeout in seconds for accepted connections',
},
'--shutdown-timeout': {
'metavar': 'INT',
'dest': 'shutdown_timeout',
'type': int,
'help': 'Time in seconds to wait for worker threads to cleanly exit',
},
'--request-queue-size': {
'metavar': 'INT',
'dest': 'request_queue_size',
'type': int,
'help': 'Maximum number of queued connections',
},
'--accepted-queue-size': {
'metavar': 'INT',
'dest': 'accepted_queue_size',
'type': int,
'help': 'Maximum number of active requests in queue',
},
'--accepted-queue-timeout': {
'metavar': 'INT',
'dest': 'accepted_queue_timeout',
'type': int,
'help': 'Timeout in seconds for putting requests into queue',
},
}

View file

@ -5,11 +5,13 @@ __metaclass__ = type
import io
import os
import select
import socket
import threading
import time
from . import errors
from ._compat import selectors
from ._compat import suppress
from .makefile import MakeFile
import six
@ -47,6 +49,69 @@ else:
fcntl.fcntl(fd, fcntl.F_SETFD, old_flags | fcntl.FD_CLOEXEC)
class _ThreadsafeSelector:
"""Thread-safe wrapper around a DefaultSelector.
There are 2 thread contexts in which it may be accessed:
* the selector thread
* one of the worker threads in workers/threadpool.py
The expected read/write patterns are:
* :py:func:`~iter`: selector thread
* :py:meth:`register`: selector thread and threadpool,
via :py:meth:`~cheroot.workers.threadpool.ThreadPool.put`
* :py:meth:`unregister`: selector thread only
Notably, this means :py:class:`_ThreadsafeSelector` never needs to worry
that connections will be removed behind its back.
The lock is held when iterating or modifying the selector but is not
required when :py:meth:`select()ing <selectors.BaseSelector.select>` on it.
"""
def __init__(self):
self._selector = selectors.DefaultSelector()
self._lock = threading.Lock()
def __len__(self):
with self._lock:
return len(self._selector.get_map() or {})
@property
def connections(self):
"""Retrieve connections registered with the selector."""
with self._lock:
mapping = self._selector.get_map() or {}
for _, (_, sock_fd, _, conn) in mapping.items():
yield (sock_fd, conn)
def register(self, fileobj, events, data=None):
"""Register ``fileobj`` with the selector."""
with self._lock:
return self._selector.register(fileobj, events, data)
def unregister(self, fileobj):
"""Unregister ``fileobj`` from the selector."""
with self._lock:
return self._selector.unregister(fileobj)
def select(self, timeout=None):
"""Return socket fd and data pairs from selectors.select call.
Returns entries ready to read in the form:
(socket_file_descriptor, connection)
"""
return (
(key.fd, key.data)
for key, _ in self._selector.select(timeout=timeout)
)
def close(self):
"""Close the selector."""
with self._lock:
self._selector.close()
class ConnectionManager:
"""Class which manages HTTPConnection objects.
@ -60,21 +125,34 @@ class ConnectionManager:
server (cheroot.server.HTTPServer): web server object
that uses this ConnectionManager instance.
"""
self._serving = False
self._stop_requested = False
self.server = server
self.connections = []
self._selector = _ThreadsafeSelector()
self._selector.register(
server.socket.fileno(),
selectors.EVENT_READ, data=server,
)
def put(self, conn):
"""Put idle connection into the ConnectionManager to be managed.
Args:
conn (cheroot.server.HTTPConnection): HTTP connection
to be managed.
:param conn: HTTP connection to be managed
:type conn: cheroot.server.HTTPConnection
"""
conn.last_used = time.time()
conn.ready_with_data = conn.rfile.has_data()
self.connections.append(conn)
# if this conn doesn't have any more data waiting to be read,
# register it with the selector.
if conn.rfile.has_data():
self.server.process_conn(conn)
else:
self._selector.register(
conn.socket.fileno(), selectors.EVENT_READ, data=conn,
)
def expire(self):
def _expire(self):
"""Expire least recently used connections.
This happens if there are either too many open connections, or if the
@ -82,107 +160,102 @@ class ConnectionManager:
This should be called periodically.
"""
if not self.connections:
return
# find any connections still registered with the selector
# that have not been active recently enough.
threshold = time.time() - self.server.timeout
timed_out_connections = [
(sock_fd, conn)
for (sock_fd, conn) in self._selector.connections
if conn != self.server and conn.last_used < threshold
]
for sock_fd, conn in timed_out_connections:
self._selector.unregister(sock_fd)
conn.close()
# Look at the first connection - if it can be closed, then do
# that, and wait for get_conn to return it.
conn = self.connections[0]
if conn.closeable:
return
def stop(self):
"""Stop the selector loop in run() synchronously.
# Too many connections?
ka_limit = self.server.keep_alive_conn_limit
if ka_limit is not None and len(self.connections) > ka_limit:
conn.closeable = True
return
May take up to half a second.
"""
self._stop_requested = True
while self._serving:
time.sleep(0.01)
# Connection too old?
if (conn.last_used + self.server.timeout) < time.time():
conn.closeable = True
return
def get_conn(self, server_socket):
"""Return a HTTPConnection object which is ready to be handled.
A connection returned by this method should be ready for a worker
to handle it. If there are no connections ready, None will be
returned.
Any connection returned by this method will need to be `put`
back if it should be examined again for another request.
def run(self, expiration_interval):
"""Run the connections selector indefinitely.
Args:
server_socket (socket.socket): Socket to listen to for new
connections.
Returns:
cheroot.server.HTTPConnection instance, or None.
expiration_interval (float): Interval, in seconds, at which
connections will be checked for expiration.
Connections that are ready to process are submitted via
self.server.process_conn()
Connections submitted for processing must be `put()`
back if they should be examined again for another request.
Can be shut down by calling `stop()`.
"""
# Grab file descriptors from sockets, but stop if we find a
# connection which is already marked as ready.
socket_dict = {}
for conn in self.connections:
if conn.closeable or conn.ready_with_data:
break
socket_dict[conn.socket.fileno()] = conn
else:
# No ready connection.
conn = None
# We have a connection ready for use.
if conn:
self.connections.remove(conn)
return conn
# Will require a select call.
ss_fileno = server_socket.fileno()
socket_dict[ss_fileno] = server_socket
self._serving = True
try:
rlist, _, _ = select.select(list(socket_dict), [], [], 0.1)
# No available socket.
if not rlist:
return None
self._run(expiration_interval)
finally:
self._serving = False
def _run(self, expiration_interval):
last_expiration_check = time.time()
while not self._stop_requested:
try:
active_list = self._selector.select(timeout=0.01)
except OSError:
# Mark any connection which no longer appears valid.
for fno, conn in list(socket_dict.items()):
# If the server socket is invalid, we'll just ignore it and
# wait to be shutdown.
if fno == ss_fileno:
self._remove_invalid_sockets()
continue
try:
os.fstat(fno)
except OSError:
# Socket is invalid, close the connection, insert at
# the front.
self.connections.remove(conn)
self.connections.insert(0, conn)
conn.closeable = True
# Wait for the next tick to occur.
return None
try:
# See if we have a new connection coming in.
rlist.remove(ss_fileno)
except ValueError:
# No new connection, but reuse existing socket.
conn = socket_dict[rlist.pop()]
for (sock_fd, conn) in active_list:
if conn is self.server:
# New connection
new_conn = self._from_server_socket(self.server.socket)
if new_conn is not None:
self.server.process_conn(new_conn)
else:
conn = server_socket
# unregister connection from the selector until the server
# has read from it and returned it via put()
self._selector.unregister(sock_fd)
self.server.process_conn(conn)
# All remaining connections in rlist should be marked as ready.
for fno in rlist:
socket_dict[fno].ready_with_data = True
now = time.time()
if (now - last_expiration_check) > expiration_interval:
self._expire()
last_expiration_check = now
# New connection.
if conn is server_socket:
return self._from_server_socket(server_socket)
def _remove_invalid_sockets(self):
"""Clean up the resources of any broken connections.
self.connections.remove(conn)
return conn
This method attempts to detect any connections in an invalid state,
unregisters them from the selector and closes the file descriptors of
the corresponding network sockets where possible.
"""
invalid_conns = []
for sock_fd, conn in self._selector.connections:
if conn is self.server:
continue
def _from_server_socket(self, server_socket):
try:
os.fstat(sock_fd)
except OSError:
invalid_conns.append((sock_fd, conn))
for sock_fd, conn in invalid_conns:
self._selector.unregister(sock_fd)
# One of the reason on why a socket could cause an error
# is that the socket is already closed, ignore the
# socket error if we try to close it at this point.
# This is equivalent to OSError in Py3
with suppress(socket.error):
conn.close()
def _from_server_socket(self, server_socket): # noqa: C901 # FIXME
try:
s, addr = server_socket.accept()
if self.server.stats['Enabled']:
@ -274,6 +347,23 @@ class ConnectionManager:
def close(self):
"""Close all monitored connections."""
for conn in self.connections[:]:
for (_, conn) in self._selector.connections:
if conn is not self.server: # server closes its own socket
conn.close()
self.connections = []
self._selector.close()
@property
def _num_connections(self):
"""Return the current number of connections.
Includes all connections registered with the selector,
minus one for the server socket, which is always registered
with the selector.
"""
return len(self._selector) - 1
@property
def can_add_keepalive_connection(self):
"""Flag whether it is allowed to add a new keep-alive connection."""
ka_limit = self.server.keep_alive_conn_limit
return ka_limit is None or self._num_connections < ka_limit

View file

@ -1,3 +1,4 @@
# -*- coding: utf-8 -*-
"""Collection of exceptions raised and/or processed by Cheroot."""
from __future__ import absolute_import, division, print_function
@ -23,14 +24,14 @@ class FatalSSLAlert(Exception):
def plat_specific_errors(*errnames):
"""Return error numbers for all errors in errnames on this platform.
"""Return error numbers for all errors in ``errnames`` on this platform.
The 'errno' module contains different global constants depending on
the specific platform (OS). This function will return the list of
numeric values for a given list of potential names.
The :py:mod:`errno` module contains different global constants
depending on the specific platform (OS). This function will return
the list of numeric values for a given list of potential names.
"""
missing_attr = set([None, ])
unique_nums = set(getattr(errno, k, None) for k in errnames)
missing_attr = {None}
unique_nums = {getattr(errno, k, None) for k in errnames}
return list(unique_nums - missing_attr)
@ -56,3 +57,32 @@ socket_errors_nonblocking = plat_specific_errors(
if sys.platform == 'darwin':
socket_errors_to_ignore.extend(plat_specific_errors('EPROTOTYPE'))
socket_errors_nonblocking.extend(plat_specific_errors('EPROTOTYPE'))
acceptable_sock_shutdown_error_codes = {
errno.ENOTCONN,
errno.EPIPE, errno.ESHUTDOWN, # corresponds to BrokenPipeError in Python 3
errno.ECONNRESET, # corresponds to ConnectionResetError in Python 3
}
"""Errors that may happen during the connection close sequence.
* ENOTCONN client is no longer connected
* EPIPE write on a pipe while the other end has been closed
* ESHUTDOWN write on a socket which has been shutdown for writing
* ECONNRESET connection is reset by the peer, we received a TCP RST packet
Refs:
* https://github.com/cherrypy/cheroot/issues/341#issuecomment-735884889
* https://bugs.python.org/issue30319
* https://bugs.python.org/issue30329
* https://github.com/python/cpython/commit/83a2c28
* https://github.com/python/cpython/blob/c39b52f/Lib/poplib.py#L297-L302
* https://docs.microsoft.com/windows/win32/api/winsock/nf-winsock-shutdown
"""
try: # py3
acceptable_sock_shutdown_exceptions = (
BrokenPipeError, ConnectionResetError,
)
except NameError: # py2
acceptable_sock_shutdown_exceptions = ()

View file

@ -68,7 +68,7 @@ class MakeFile_PY2(getattr(socket, '_fileobject', object)):
self._refcount -= 1
def write(self, data):
"""Sendall for non-blocking sockets."""
"""Send entire data contents for non-blocking sockets."""
bytes_sent = 0
data_mv = memoryview(data)
payload_size = len(data_mv)
@ -122,7 +122,7 @@ class MakeFile_PY2(getattr(socket, '_fileobject', object)):
# FauxSocket is no longer needed
del FauxSocket
if not _fileobject_uses_str_type:
if not _fileobject_uses_str_type: # noqa: C901 # FIXME
def read(self, size=-1):
"""Read data from the socket to buffer."""
# Use max, disallow tiny reads in a loop as they are very

View file

@ -7,12 +7,13 @@ sticking incoming connections onto a Queue::
server = HTTPServer(...)
server.start()
-> while True:
tick()
# This blocks until a request comes in:
child = socket.accept()
-> serve()
while ready:
_connections.run()
while not stop_requested:
child = socket.accept() # blocks until a request comes in
conn = HTTPConnection(child, ...)
server.requests.put(conn)
server.process_conn(conn) # adds conn to threadpool
Worker threads are kept in a pool and poll the Queue, popping off and then
handling each connection in turn. Each connection can consist of an arbitrary
@ -58,7 +59,6 @@ And now for a trivial doctest to exercise the test suite
>>> 'HTTPServer' in globals()
True
"""
from __future__ import absolute_import, division, print_function
@ -74,6 +74,8 @@ import time
import traceback as traceback_
import logging
import platform
import contextlib
import threading
try:
from functools import lru_cache
@ -93,6 +95,7 @@ from .makefile import MakeFile, StreamWriter
__all__ = (
'HTTPRequest', 'HTTPConnection', 'HTTPServer',
'HeaderReader', 'DropUnderscoreHeaderReader',
'SizeCheckWrapper', 'KnownLengthRFile', 'ChunkedRFile',
'Gateway', 'get_ssl_adapter_class',
)
@ -156,7 +159,10 @@ EMPTY = b''
ASTERISK = b'*'
FORWARD_SLASH = b'/'
QUOTED_SLASH = b'%2F'
QUOTED_SLASH_REGEX = re.compile(b'(?i)' + QUOTED_SLASH)
QUOTED_SLASH_REGEX = re.compile(b''.join((b'(?i)', QUOTED_SLASH)))
_STOPPING_FOR_INTERRUPT = object() # sentinel used during shutdown
comma_separated_headers = [
@ -179,7 +185,7 @@ class HeaderReader:
Interface and default implementation.
"""
def __call__(self, rfile, hdict=None):
def __call__(self, rfile, hdict=None): # noqa: C901 # FIXME
"""
Read headers from the given stream into the given header dict.
@ -248,15 +254,14 @@ class DropUnderscoreHeaderReader(HeaderReader):
class SizeCheckWrapper:
"""Wraps a file-like object, raising MaxSizeExceeded if too large."""
"""Wraps a file-like object, raising MaxSizeExceeded if too large.
:param rfile: ``file`` of a limited size
:param int maxlen: maximum length of the file being read
"""
def __init__(self, rfile, maxlen):
"""Initialize SizeCheckWrapper instance.
Args:
rfile (file): file of a limited size
maxlen (int): maximum length of the file being read
"""
"""Initialize SizeCheckWrapper instance."""
self.rfile = rfile
self.maxlen = maxlen
self.bytes_read = 0
@ -266,14 +271,12 @@ class SizeCheckWrapper:
raise errors.MaxSizeExceeded()
def read(self, size=None):
"""Read a chunk from rfile buffer and return it.
"""Read a chunk from ``rfile`` buffer and return it.
Args:
size (int): amount of data to read
Returns:
bytes: Chunk from rfile, limited by size if specified.
:param int size: amount of data to read
:returns: chunk from ``rfile``, limited by size if specified
:rtype: bytes
"""
data = self.rfile.read(size)
self.bytes_read += len(data)
@ -281,14 +284,12 @@ class SizeCheckWrapper:
return data
def readline(self, size=None):
"""Read a single line from rfile buffer and return it.
"""Read a single line from ``rfile`` buffer and return it.
Args:
size (int): minimum amount of data to read
Returns:
bytes: One line from rfile.
:param int size: minimum amount of data to read
:returns: one line from ``rfile``
:rtype: bytes
"""
if size is not None:
data = self.rfile.readline(size)
@ -309,14 +310,12 @@ class SizeCheckWrapper:
return EMPTY.join(res)
def readlines(self, sizehint=0):
"""Read all lines from rfile buffer and return them.
"""Read all lines from ``rfile`` buffer and return them.
Args:
sizehint (int): hint of minimum amount of data to read
Returns:
list[bytes]: Lines of bytes read from rfile.
:param int sizehint: hint of minimum amount of data to read
:returns: lines of bytes read from ``rfile``
:rtype: list[bytes]
"""
# Shamelessly stolen from StringIO
total = 0
@ -331,7 +330,7 @@ class SizeCheckWrapper:
return lines
def close(self):
"""Release resources allocated for rfile."""
"""Release resources allocated for ``rfile``."""
self.rfile.close()
def __iter__(self):
@ -349,28 +348,24 @@ class SizeCheckWrapper:
class KnownLengthRFile:
"""Wraps a file-like object, returning an empty string when exhausted."""
"""Wraps a file-like object, returning an empty string when exhausted.
:param rfile: ``file`` of a known size
:param int content_length: length of the file being read
"""
def __init__(self, rfile, content_length):
"""Initialize KnownLengthRFile instance.
Args:
rfile (file): file of a known size
content_length (int): length of the file being read
"""
"""Initialize KnownLengthRFile instance."""
self.rfile = rfile
self.remaining = content_length
def read(self, size=None):
"""Read a chunk from rfile buffer and return it.
"""Read a chunk from ``rfile`` buffer and return it.
Args:
size (int): amount of data to read
Returns:
bytes: Chunk from rfile, limited by size if specified.
:param int size: amount of data to read
:rtype: bytes
:returns: chunk from ``rfile``, limited by size if specified
"""
if self.remaining == 0:
return b''
@ -384,14 +379,12 @@ class KnownLengthRFile:
return data
def readline(self, size=None):
"""Read a single line from rfile buffer and return it.
"""Read a single line from ``rfile`` buffer and return it.
Args:
size (int): minimum amount of data to read
Returns:
bytes: One line from rfile.
:param int size: minimum amount of data to read
:returns: one line from ``rfile``
:rtype: bytes
"""
if self.remaining == 0:
return b''
@ -405,14 +398,12 @@ class KnownLengthRFile:
return data
def readlines(self, sizehint=0):
"""Read all lines from rfile buffer and return them.
"""Read all lines from ``rfile`` buffer and return them.
Args:
sizehint (int): hint of minimum amount of data to read
Returns:
list[bytes]: Lines of bytes read from rfile.
:param int sizehint: hint of minimum amount of data to read
:returns: lines of bytes read from ``rfile``
:rtype: list[bytes]
"""
# Shamelessly stolen from StringIO
total = 0
@ -427,7 +418,7 @@ class KnownLengthRFile:
return lines
def close(self):
"""Release resources allocated for rfile."""
"""Release resources allocated for ``rfile``."""
self.rfile.close()
def __iter__(self):
@ -449,16 +440,14 @@ class ChunkedRFile:
This class is intended to provide a conforming wsgi.input value for
request entities that have been encoded with the 'chunked' transfer
encoding.
:param rfile: file encoded with the 'chunked' transfer encoding
:param int maxlen: maximum length of the file being read
:param int bufsize: size of the buffer used to read the file
"""
def __init__(self, rfile, maxlen, bufsize=8192):
"""Initialize ChunkedRFile instance.
Args:
rfile (file): file encoded with the 'chunked' transfer encoding
maxlen (int): maximum length of the file being read
bufsize (int): size of the buffer used to read the file
"""
"""Initialize ChunkedRFile instance."""
self.rfile = rfile
self.maxlen = maxlen
self.bytes_read = 0
@ -484,7 +473,10 @@ class ChunkedRFile:
chunk_size = line.pop(0)
chunk_size = int(chunk_size, 16)
except ValueError:
raise ValueError('Bad chunked transfer size: ' + repr(chunk_size))
raise ValueError(
'Bad chunked transfer size: {chunk_size!r}'.
format(chunk_size=chunk_size),
)
if chunk_size <= 0:
self.closed = True
@ -507,14 +499,12 @@ class ChunkedRFile:
)
def read(self, size=None):
"""Read a chunk from rfile buffer and return it.
"""Read a chunk from ``rfile`` buffer and return it.
Args:
size (int): amount of data to read
Returns:
bytes: Chunk from rfile, limited by size if specified.
:param int size: amount of data to read
:returns: chunk from ``rfile``, limited by size if specified
:rtype: bytes
"""
data = EMPTY
@ -540,14 +530,12 @@ class ChunkedRFile:
self.buffer = EMPTY
def readline(self, size=None):
"""Read a single line from rfile buffer and return it.
"""Read a single line from ``rfile`` buffer and return it.
Args:
size (int): minimum amount of data to read
Returns:
bytes: One line from rfile.
:param int size: minimum amount of data to read
:returns: one line from ``rfile``
:rtype: bytes
"""
data = EMPTY
@ -583,14 +571,12 @@ class ChunkedRFile:
self.buffer = self.buffer[newline_pos:]
def readlines(self, sizehint=0):
"""Read all lines from rfile buffer and return them.
"""Read all lines from ``rfile`` buffer and return them.
Args:
sizehint (int): hint of minimum amount of data to read
Returns:
list[bytes]: Lines of bytes read from rfile.
:param int sizehint: hint of minimum amount of data to read
:returns: lines of bytes read from ``rfile``
:rtype: list[bytes]
"""
# Shamelessly stolen from StringIO
total = 0
@ -635,7 +621,7 @@ class ChunkedRFile:
yield line
def close(self):
"""Release resources allocated for rfile."""
"""Release resources allocated for ``rfile``."""
self.rfile.close()
@ -744,7 +730,7 @@ class HTTPRequest:
self.ready = True
def read_request_line(self):
def read_request_line(self): # noqa: C901 # FIXME
"""Read and parse first line of the HTTP request.
Returns:
@ -845,7 +831,7 @@ class HTTPRequest:
# `urlsplit()` above parses "example.com:3128" as path part of URI.
# this is a workaround, which makes it detect netloc correctly
uri_split = urllib.parse.urlsplit(b'//' + uri)
uri_split = urllib.parse.urlsplit(b''.join((b'//', uri)))
_scheme, _authority, _path, _qs, _fragment = uri_split
_port = EMPTY
try:
@ -975,8 +961,14 @@ class HTTPRequest:
return True
def read_request_headers(self):
"""Read self.rfile into self.inheaders. Return success."""
def read_request_headers(self): # noqa: C901 # FIXME
"""Read ``self.rfile`` into ``self.inheaders``.
Ref: :py:attr:`self.inheaders <HTTPRequest.outheaders>`.
:returns: success status
:rtype: bool
"""
# then all the http headers
try:
self.header_reader(self.rfile, self.inheaders)
@ -1054,8 +1046,10 @@ class HTTPRequest:
# Don't use simple_response here, because it emits headers
# we don't want. See
# https://github.com/cherrypy/cherrypy/issues/951
msg = self.server.protocol.encode('ascii')
msg += b' 100 Continue\r\n\r\n'
msg = b''.join((
self.server.protocol.encode('ascii'), SPACE, b'100 Continue',
CRLF, CRLF,
))
try:
self.conn.wfile.write(msg)
except socket.error as ex:
@ -1138,10 +1132,11 @@ class HTTPRequest:
else:
self.conn.wfile.write(chunk)
def send_headers(self):
def send_headers(self): # noqa: C901 # FIXME
"""Assert, process, and send the HTTP response message-headers.
You must set self.status, and self.outheaders before calling this.
You must set ``self.status``, and :py:attr:`self.outheaders
<HTTPRequest.outheaders>` before calling this.
"""
hkeys = [key.lower() for key, value in self.outheaders]
status = int(self.status[:3])
@ -1168,6 +1163,12 @@ class HTTPRequest:
# Closing the conn is the only way to determine len.
self.close_connection = True
# Override the decision to not close the connection if the connection
# manager doesn't have space for it.
if not self.close_connection:
can_keep = self.server.can_add_keepalive_connection
self.close_connection = not can_keep
if b'connection' not in hkeys:
if self.response_protocol == 'HTTP/1.1':
# Both server and client are HTTP/1.1 or better
@ -1178,6 +1179,14 @@ class HTTPRequest:
if not self.close_connection:
self.outheaders.append((b'Connection', b'Keep-Alive'))
if (b'Connection', b'Keep-Alive') in self.outheaders:
self.outheaders.append((
b'Keep-Alive',
u'timeout={connection_timeout}'.
format(connection_timeout=self.server.timeout).
encode('ISO-8859-1'),
))
if (not self.close_connection) and (not self.chunked_read):
# Read any remaining request body data on the socket.
# "If an origin server receives a request that does not include an
@ -1228,9 +1237,7 @@ class HTTPConnection:
peercreds_resolve_enabled = False
# Fields set by ConnectionManager.
closeable = False
last_used = None
ready_with_data = False
def __init__(self, server, sock, makefile=MakeFile):
"""Initialize HTTPConnection instance.
@ -1259,7 +1266,7 @@ class HTTPConnection:
lru_cache(maxsize=1)(self.get_peer_creds)
)
def communicate(self):
def communicate(self): # noqa: C901 # FIXME
"""Read each request and respond appropriately.
Returns true if the connection should be kept open.
@ -1351,6 +1358,9 @@ class HTTPConnection:
if not self.linger:
self._close_kernel_socket()
# close the socket file descriptor
# (will be closed in the OS if there is no
# other reference to the underlying socket)
self.socket.close()
else:
# On the other hand, sometimes we want to hang around for a bit
@ -1426,12 +1436,12 @@ class HTTPConnection:
return gid
def resolve_peer_creds(self): # LRU cached on per-instance basis
"""Return the username and group tuple of the peercreds if available.
"""Look up the username and group tuple of the ``PEERCREDS``.
Raises:
NotImplementedError: in case of unsupported OS
RuntimeError: in case of UID/GID lookup unsupported or disabled
:returns: the username and group tuple of the ``PEERCREDS``
:raises NotImplementedError: if the OS is unsupported
:raises RuntimeError: if UID/GID lookup is unsupported or disabled
"""
if not IS_UID_GID_RESOLVABLE:
raise NotImplementedError(
@ -1462,18 +1472,20 @@ class HTTPConnection:
return group
def _close_kernel_socket(self):
"""Close kernel socket in outdated Python versions.
"""Terminate the connection at the transport level."""
# Honor ``sock_shutdown`` for PyOpenSSL connections.
shutdown = getattr(
self.socket, 'sock_shutdown',
self.socket.shutdown,
)
On old Python versions,
Python's socket module does NOT call close on the kernel
socket when you call socket.close(). We do so manually here
because we want this server to send a FIN TCP segment
immediately. Note this must be called *before* calling
socket.close(), because the latter drops its reference to
the kernel socket.
"""
if six.PY2 and hasattr(self.socket, '_sock'):
self.socket._sock.close()
try:
shutdown(socket.SHUT_RDWR) # actually send a TCP FIN
except errors.acceptable_sock_shutdown_exceptions:
pass
except socket.error as e:
if e.errno not in errors.acceptable_sock_shutdown_error_codes:
raise
class HTTPServer:
@ -1515,7 +1527,12 @@ class HTTPServer:
timeout = 10
"""The timeout in seconds for accepted connections (default 10)."""
version = 'Cheroot/' + __version__
expiration_interval = 0.5
"""The interval, in seconds, at which the server checks for
expired connections (default 0.5).
"""
version = 'Cheroot/{version!s}'.format(version=__version__)
"""A version string for the HTTPServer."""
software = None
@ -1540,16 +1557,23 @@ class HTTPServer:
"""The class to use for handling HTTP connections."""
ssl_adapter = None
"""An instance of ssl.Adapter (or a subclass).
"""An instance of ``ssl.Adapter`` (or a subclass).
You must have the corresponding SSL driver library installed.
Ref: :py:class:`ssl.Adapter <cheroot.ssl.Adapter>`.
You must have the corresponding TLS driver library installed.
"""
peercreds_enabled = False
"""If True, peer cred lookup can be performed via UNIX domain socket."""
"""
If :py:data:`True`, peer creds will be looked up via UNIX domain socket.
"""
peercreds_resolve_enabled = False
"""If True, username/group will be looked up in the OS from peercreds."""
"""
If :py:data:`True`, username/group will be looked up in the OS from
``PEERCREDS``-provided IDs.
"""
keep_alive_conn_limit = 10
"""The maximum number of waiting keep-alive connections that will be kept open.
@ -1577,7 +1601,6 @@ class HTTPServer:
self.requests = threadpool.ThreadPool(
self, min=minthreads or 1, max=maxthreads,
)
self.connections = connections.ConnectionManager(self)
if not server_name:
server_name = self.version
@ -1603,25 +1626,29 @@ class HTTPServer:
'Threads Idle': lambda s: getattr(self.requests, 'idle', None),
'Socket Errors': 0,
'Requests': lambda s: (not s['Enabled']) and -1 or sum(
[w['Requests'](w) for w in s['Worker Threads'].values()], 0,
(w['Requests'](w) for w in s['Worker Threads'].values()), 0,
),
'Bytes Read': lambda s: (not s['Enabled']) and -1 or sum(
[w['Bytes Read'](w) for w in s['Worker Threads'].values()], 0,
(w['Bytes Read'](w) for w in s['Worker Threads'].values()), 0,
),
'Bytes Written': lambda s: (not s['Enabled']) and -1 or sum(
[w['Bytes Written'](w) for w in s['Worker Threads'].values()],
(w['Bytes Written'](w) for w in s['Worker Threads'].values()),
0,
),
'Work Time': lambda s: (not s['Enabled']) and -1 or sum(
[w['Work Time'](w) for w in s['Worker Threads'].values()], 0,
(w['Work Time'](w) for w in s['Worker Threads'].values()), 0,
),
'Read Throughput': lambda s: (not s['Enabled']) and -1 or sum(
[w['Bytes Read'](w) / (w['Work Time'](w) or 1e-6)
for w in s['Worker Threads'].values()], 0,
(
w['Bytes Read'](w) / (w['Work Time'](w) or 1e-6)
for w in s['Worker Threads'].values()
), 0,
),
'Write Throughput': lambda s: (not s['Enabled']) and -1 or sum(
[w['Bytes Written'](w) / (w['Work Time'](w) or 1e-6)
for w in s['Worker Threads'].values()], 0,
(
w['Bytes Written'](w) / (w['Work Time'](w) or 1e-6)
for w in s['Worker Threads'].values()
), 0,
),
'Worker Threads': {},
}
@ -1645,17 +1672,27 @@ class HTTPServer:
def bind_addr(self):
"""Return the interface on which to listen for connections.
For TCP sockets, a (host, port) tuple. Host values may be any IPv4
or IPv6 address, or any valid hostname. The string 'localhost' is a
synonym for '127.0.0.1' (or '::1', if your hosts file prefers IPv6).
The string '0.0.0.0' is a special IPv4 entry meaning "any active
interface" (INADDR_ANY), and '::' is the similar IN6ADDR_ANY for
IPv6. The empty string or None are not allowed.
For TCP sockets, a (host, port) tuple. Host values may be any
:term:`IPv4` or :term:`IPv6` address, or any valid hostname.
The string 'localhost' is a synonym for '127.0.0.1' (or '::1',
if your hosts file prefers :term:`IPv6`).
The string '0.0.0.0' is a special :term:`IPv4` entry meaning
"any active interface" (INADDR_ANY), and '::' is the similar
IN6ADDR_ANY for :term:`IPv6`.
The empty string or :py:data:`None` are not allowed.
For UNIX sockets, supply the filename as a string.
For UNIX sockets, supply the file name as a string.
Systemd socket activation is automatic and doesn't require tempering
with this variable.
.. glossary::
:abbr:`IPv4 (Internet Protocol version 4)`
Internet Protocol version 4
:abbr:`IPv6 (Internet Protocol version 6)`
Internet Protocol version 6
"""
return self._bind_addr
@ -1695,7 +1732,7 @@ class HTTPServer:
self.stop()
raise
def prepare(self):
def prepare(self): # noqa: C901 # FIXME
"""Prepare server to serving requests.
It binds a socket's port, setups the socket to ``listen()`` and does
@ -1757,6 +1794,9 @@ class HTTPServer:
self.socket.settimeout(1)
self.socket.listen(self.request_queue_size)
# must not be accessed once stop() has been called
self._connections = connections.ConnectionManager(self)
# Create worker threads
self.requests.start()
@ -1765,20 +1805,21 @@ class HTTPServer:
def serve(self):
"""Serve requests, after invoking :func:`prepare()`."""
while self.ready:
while self.ready and not self.interrupt:
try:
self.tick()
self._connections.run(self.expiration_interval)
except (KeyboardInterrupt, SystemExit):
raise
except Exception:
self.error_log(
'Error in HTTPServer.tick', level=logging.ERROR,
'Error in HTTPServer.serve', level=logging.ERROR,
traceback=True,
)
# raise exceptions reported by any worker threads,
# such that the exception is raised from the serve() thread.
if self.interrupt:
while self.interrupt is True:
# Wait for self.stop() to complete. See _set_interrupt.
while self._stopping_for_interrupt:
time.sleep(0.1)
if self.interrupt:
raise self.interrupt
@ -1795,6 +1836,31 @@ class HTTPServer:
self.prepare()
self.serve()
@contextlib.contextmanager
def _run_in_thread(self):
"""Context manager for running this server in a thread."""
self.prepare()
thread = threading.Thread(target=self.serve)
thread.setDaemon(True)
thread.start()
try:
yield thread
finally:
self.stop()
@property
def can_add_keepalive_connection(self):
"""Flag whether it is allowed to add a new keep-alive connection."""
return self.ready and self._connections.can_add_keepalive_connection
def put_conn(self, conn):
"""Put an idle connection back into the ConnectionManager."""
if self.ready:
self._connections.put(conn)
else:
# server is shutting down, just close it
conn.close()
def error_log(self, msg='', level=20, traceback=False):
"""Write error message to log.
@ -1804,7 +1870,7 @@ class HTTPServer:
traceback (bool): add traceback to output or not
"""
# Override this in subclasses as desired
sys.stderr.write(msg + '\n')
sys.stderr.write('{msg!s}\n'.format(msg=msg))
sys.stderr.flush()
if traceback:
tblines = traceback_.format_exc()
@ -1822,7 +1888,7 @@ class HTTPServer:
self.bind_addr = self.resolve_real_bind_addr(sock)
return sock
def bind_unix_socket(self, bind_addr):
def bind_unix_socket(self, bind_addr): # noqa: C901 # FIXME
"""Create (or recreate) a UNIX socket object."""
if IS_WINDOWS:
"""
@ -1965,7 +2031,7 @@ class HTTPServer:
@staticmethod
def resolve_real_bind_addr(socket_):
"""Retrieve actual bind addr from bound socket."""
"""Retrieve actual bind address from bound socket."""
# FIXME: keep requested bind_addr separate real bound_addr (port
# is different in case of ephemeral port 0)
bind_addr = socket_.getsockname()
@ -1985,40 +2051,49 @@ class HTTPServer:
return bind_addr
def tick(self):
"""Accept a new connection and put it on the Queue."""
if not self.ready:
return
conn = self.connections.get_conn(self.socket)
if conn:
def process_conn(self, conn):
"""Process an incoming HTTPConnection."""
try:
self.requests.put(conn)
except queue.Full:
# Just drop the conn. TODO: write 503 back?
conn.close()
self.connections.expire()
@property
def interrupt(self):
"""Flag interrupt of the server."""
return self._interrupt
@property
def _stopping_for_interrupt(self):
"""Return whether the server is responding to an interrupt."""
return self._interrupt is _STOPPING_FOR_INTERRUPT
@interrupt.setter
def interrupt(self, interrupt):
"""Perform the shutdown of this server and save the exception."""
self._interrupt = True
"""Perform the shutdown of this server and save the exception.
Typically invoked by a worker thread in
:py:mod:`~cheroot.workers.threadpool`, the exception is raised
from the thread running :py:meth:`serve` once :py:meth:`stop`
has completed.
"""
self._interrupt = _STOPPING_FOR_INTERRUPT
self.stop()
self._interrupt = interrupt
def stop(self):
def stop(self): # noqa: C901 # FIXME
"""Gracefully shutdown a server that is serving forever."""
if not self.ready:
return # already stopped
self.ready = False
if self._start_time is not None:
self._run_time += (time.time() - self._start_time)
self._start_time = None
self._connections.stop()
sock = getattr(self, 'socket', None)
if sock:
if not isinstance(
@ -2060,7 +2135,7 @@ class HTTPServer:
sock.close()
self.socket = None
self.connections.close()
self._connections.close()
self.requests.stop(self.shutdown_timeout)
@ -2108,7 +2183,9 @@ def get_ssl_adapter_class(name='builtin'):
try:
adapter = getattr(mod, attr_name)
except AttributeError:
raise AttributeError("'%s' object has no attribute '%s'"
% (mod_path, attr_name))
raise AttributeError(
"'%s' object has no attribute '%s'"
% (mod_path, attr_name),
)
return adapter

View file

@ -1,7 +1,7 @@
"""
A library for integrating Python's builtin ``ssl`` library with Cheroot.
A library for integrating Python's builtin :py:mod:`ssl` library with Cheroot.
The ssl module must be importable for SSL functionality.
The :py:mod:`ssl` module must be importable for SSL functionality.
To use this module, set ``HTTPServer.ssl_adapter`` to an instance of
``BuiltinSSLAdapter``.
@ -10,6 +10,10 @@ To use this module, set ``HTTPServer.ssl_adapter`` to an instance of
from __future__ import absolute_import, division, print_function
__metaclass__ = type
import socket
import sys
import threading
try:
import ssl
except ImportError:
@ -27,13 +31,12 @@ import six
from . import Adapter
from .. import errors
from .._compat import IS_ABOVE_OPENSSL10
from .._compat import IS_ABOVE_OPENSSL10, suppress
from ..makefile import StreamReader, StreamWriter
from ..server import HTTPServer
if six.PY2:
import socket
generic_socket_error = socket.error
del socket
else:
generic_socket_error = OSError
@ -49,37 +52,159 @@ def _assert_ssl_exc_contains(exc, *msgs):
return any(m.lower() in err_msg_lower for m in msgs)
def _loopback_for_cert_thread(context, server):
"""Wrap a socket in ssl and perform the server-side handshake."""
# As we only care about parsing the certificate, the failure of
# which will cause an exception in ``_loopback_for_cert``,
# we can safely ignore connection and ssl related exceptions. Ref:
# https://github.com/cherrypy/cheroot/issues/302#issuecomment-662592030
with suppress(ssl.SSLError, OSError):
with context.wrap_socket(
server, do_handshake_on_connect=True, server_side=True,
) as ssl_sock:
# in TLS 1.3 (Python 3.7+, OpenSSL 1.1.1+), the server
# sends the client session tickets that can be used to
# resume the TLS session on a new connection without
# performing the full handshake again. session tickets are
# sent as a post-handshake message at some _unspecified_
# time and thus a successful connection may be closed
# without the client having received the tickets.
# Unfortunately, on Windows (Python 3.8+), this is treated
# as an incomplete handshake on the server side and a
# ``ConnectionAbortedError`` is raised.
# TLS 1.3 support is still incomplete in Python 3.8;
# there is no way for the client to wait for tickets.
# While not necessary for retrieving the parsed certificate,
# we send a tiny bit of data over the connection in an
# attempt to give the server a chance to send the session
# tickets and close the connection cleanly.
# Note that, as this is essentially a race condition,
# the error may still occur ocasionally.
ssl_sock.send(b'0000')
def _loopback_for_cert(certificate, private_key, certificate_chain):
"""Create a loopback connection to parse a cert with a private key."""
context = ssl.create_default_context(cafile=certificate_chain)
context.load_cert_chain(certificate, private_key)
context.check_hostname = False
context.verify_mode = ssl.CERT_NONE
# Python 3+ Unix, Python 3.5+ Windows
client, server = socket.socketpair()
try:
# `wrap_socket` will block until the ssl handshake is complete.
# it must be called on both ends at the same time -> thread
# openssl will cache the peer's cert during a successful handshake
# and return it via `getpeercert` even after the socket is closed.
# when `close` is called, the SSL shutdown notice will be sent
# and then python will wait to receive the corollary shutdown.
thread = threading.Thread(
target=_loopback_for_cert_thread, args=(context, server),
)
try:
thread.start()
with context.wrap_socket(
client, do_handshake_on_connect=True,
server_side=False,
) as ssl_sock:
ssl_sock.recv(4)
return ssl_sock.getpeercert()
finally:
thread.join()
finally:
client.close()
server.close()
def _parse_cert(certificate, private_key, certificate_chain):
"""Parse a certificate."""
# loopback_for_cert uses socket.socketpair which was only
# introduced in Python 3.0 for *nix and 3.5 for Windows
# and requires OS support (AttributeError, OSError)
# it also requires a private key either in its own file
# or combined with the cert (SSLError)
with suppress(AttributeError, ssl.SSLError, OSError):
return _loopback_for_cert(certificate, private_key, certificate_chain)
# KLUDGE: using an undocumented, private, test method to parse a cert
# unfortunately, it is the only built-in way without a connection
# as a private, undocumented method, it may change at any time
# so be tolerant of *any* possible errors it may raise
with suppress(Exception):
return ssl._ssl._test_decode_cert(certificate)
return {}
def _sni_callback(sock, sni, context):
"""Handle the SNI callback to tag the socket with the SNI."""
sock.sni = sni
# return None to allow the TLS negotiation to continue
class BuiltinSSLAdapter(Adapter):
"""A wrapper for integrating Python's builtin ssl module with Cheroot."""
"""Wrapper for integrating Python's builtin :py:mod:`ssl` with Cheroot."""
certificate = None
"""The filename of the server SSL certificate."""
"""The file name of the server SSL certificate."""
private_key = None
"""The filename of the server's private key file."""
"""The file name of the server's private key file."""
certificate_chain = None
"""The filename of the certificate chain file."""
context = None
"""The ssl.SSLContext that will be used to wrap sockets."""
"""The file name of the certificate chain file."""
ciphers = None
"""The ciphers list of SSL."""
# from mod_ssl/pkg.sslmod/ssl_engine_vars.c ssl_var_lookup_ssl_cert
CERT_KEY_TO_ENV = {
'subject': 'SSL_CLIENT_S_DN',
'issuer': 'SSL_CLIENT_I_DN',
'version': 'M_VERSION',
'serialNumber': 'M_SERIAL',
'notBefore': 'V_START',
'notAfter': 'V_END',
'subject': 'S_DN',
'issuer': 'I_DN',
'subjectAltName': 'SAN',
# not parsed by the Python standard library
# - A_SIG
# - A_KEY
# not provided by mod_ssl
# - OCSP
# - caIssuers
# - crlDistributionPoints
}
# from mod_ssl/pkg.sslmod/ssl_engine_vars.c ssl_var_lookup_ssl_cert_dn_rec
CERT_KEY_TO_LDAP_CODE = {
'countryName': 'C',
'stateOrProvinceName': 'ST',
# NOTE: mod_ssl also provides 'stateOrProvinceName' as 'SP'
# for compatibility with SSLeay
'localityName': 'L',
'organizationName': 'O',
'organizationalUnitName': 'OU',
'commonName': 'CN',
'title': 'T',
'initials': 'I',
'givenName': 'G',
'surname': 'S',
'description': 'D',
'userid': 'UID',
'emailAddress': 'Email',
# not provided by mod_ssl
# - dnQualifier: DNQ
# - domainComponent: DC
# - postalCode: PC
# - streetAddress: STREET
# - serialNumber
# - generationQualifier
# - pseudonym
# - jurisdictionCountryName
# - jurisdictionLocalityName
# - jurisdictionStateOrProvince
# - businessCategory
}
def __init__(
@ -102,6 +227,45 @@ class BuiltinSSLAdapter(Adapter):
if self.ciphers is not None:
self.context.set_ciphers(ciphers)
self._server_env = self._make_env_cert_dict(
'SSL_SERVER',
_parse_cert(certificate, private_key, self.certificate_chain),
)
if not self._server_env:
return
cert = None
with open(certificate, mode='rt') as f:
cert = f.read()
# strip off any keys by only taking the first certificate
cert_start = cert.find(ssl.PEM_HEADER)
if cert_start == -1:
return
cert_end = cert.find(ssl.PEM_FOOTER, cert_start)
if cert_end == -1:
return
cert_end += len(ssl.PEM_FOOTER)
self._server_env['SSL_SERVER_CERT'] = cert[cert_start:cert_end]
@property
def context(self):
""":py:class:`~ssl.SSLContext` that will be used to wrap sockets."""
return self._context
@context.setter
def context(self, context):
"""Set the ssl ``context`` to use."""
self._context = context
# Python 3.7+
# if a context is provided via `cherrypy.config.update` then
# `self.context` will be set after `__init__`
# use a property to intercept it to add an SNI callback
# but don't override the user's callback
# TODO: chain callbacks
with suppress(AttributeError):
if ssl.HAS_SNI and context.sni_callback is None:
context.sni_callback = _sni_callback
def bind(self, sock):
"""Wrap and return the given socket."""
return super(BuiltinSSLAdapter, self).bind(sock)
@ -135,6 +299,8 @@ class BuiltinSSLAdapter(Adapter):
'no shared cipher', 'certificate unknown',
'ccs received early',
'certificate verify failed', # client cert w/o trusted CA
'version too low', # caused by SSL3 connections
'unsupported protocol', # caused by TLS1 connections
)
if _assert_ssl_exc_contains(ex, *_block_errors):
# Accepted error, let's pass
@ -148,7 +314,8 @@ class BuiltinSSLAdapter(Adapter):
except generic_socket_error as exc:
"""It is unclear why exactly this happens.
It's reproducible only with openssl>1.0 and stdlib ``ssl`` wrapper.
It's reproducible only with openssl>1.0 and stdlib
:py:mod:`ssl` wrapper.
In CherryPy it's triggered by Checker plugin, which connects
to the app listening to the socket port in TLS mode via plain
HTTP during startup (from the same process).
@ -163,7 +330,6 @@ class BuiltinSSLAdapter(Adapter):
raise
return s, self.get_environ(s)
# TODO: fill this out more with mod ssl env
def get_environ(self, sock):
"""Create WSGI environ entries to be merged into each request."""
cipher = sock.cipher()
@ -172,22 +338,117 @@ class BuiltinSSLAdapter(Adapter):
'HTTPS': 'on',
'SSL_PROTOCOL': cipher[1],
'SSL_CIPHER': cipher[0],
# SSL_VERSION_INTERFACE string The mod_ssl program version
# SSL_VERSION_LIBRARY string The OpenSSL program version
'SSL_CIPHER_EXPORT': '',
'SSL_CIPHER_USEKEYSIZE': cipher[2],
'SSL_VERSION_INTERFACE': '%s Python/%s' % (
HTTPServer.version, sys.version,
),
'SSL_VERSION_LIBRARY': ssl.OPENSSL_VERSION,
'SSL_CLIENT_VERIFY': 'NONE',
# 'NONE' - client did not provide a cert (overriden below)
}
# Python 3.3+
with suppress(AttributeError):
compression = sock.compression()
if compression is not None:
ssl_environ['SSL_COMPRESS_METHOD'] = compression
# Python 3.6+
with suppress(AttributeError):
ssl_environ['SSL_SESSION_ID'] = sock.session.id.hex()
with suppress(AttributeError):
target_cipher = cipher[:2]
for cip in sock.context.get_ciphers():
if target_cipher == (cip['name'], cip['protocol']):
ssl_environ['SSL_CIPHER_ALGKEYSIZE'] = cip['alg_bits']
break
# Python 3.7+ sni_callback
with suppress(AttributeError):
ssl_environ['SSL_TLS_SNI'] = sock.sni
if self.context and self.context.verify_mode != ssl.CERT_NONE:
client_cert = sock.getpeercert()
if client_cert:
for cert_key, env_var in self.CERT_KEY_TO_ENV.items():
# builtin ssl **ALWAYS** validates client certificates
# and terminates the connection on failure
ssl_environ['SSL_CLIENT_VERIFY'] = 'SUCCESS'
ssl_environ.update(
self.env_dn_dict(env_var, client_cert.get(cert_key)),
self._make_env_cert_dict('SSL_CLIENT', client_cert),
)
ssl_environ['SSL_CLIENT_CERT'] = ssl.DER_cert_to_PEM_cert(
sock.getpeercert(binary_form=True),
).strip()
ssl_environ.update(self._server_env)
# not supplied by the Python standard library (as of 3.8)
# - SSL_SESSION_RESUMED
# - SSL_SECURE_RENEG
# - SSL_CLIENT_CERT_CHAIN_n
# - SRP_USER
# - SRP_USERINFO
return ssl_environ
def env_dn_dict(self, env_prefix, cert_value):
"""Return a dict of WSGI environment variables for a client cert DN.
def _make_env_cert_dict(self, env_prefix, parsed_cert):
"""Return a dict of WSGI environment variables for a certificate.
E.g. SSL_CLIENT_M_VERSION, SSL_CLIENT_M_SERIAL, etc.
See https://httpd.apache.org/docs/2.4/mod/mod_ssl.html#envvars.
"""
if not parsed_cert:
return {}
env = {}
for cert_key, env_var in self.CERT_KEY_TO_ENV.items():
key = '%s_%s' % (env_prefix, env_var)
value = parsed_cert.get(cert_key)
if env_var == 'SAN':
env.update(self._make_env_san_dict(key, value))
elif env_var.endswith('_DN'):
env.update(self._make_env_dn_dict(key, value))
else:
env[key] = str(value)
# mod_ssl 2.1+; Python 3.2+
# number of days until the certificate expires
if 'notBefore' in parsed_cert:
remain = ssl.cert_time_to_seconds(parsed_cert['notAfter'])
remain -= ssl.cert_time_to_seconds(parsed_cert['notBefore'])
remain /= 60 * 60 * 24
env['%s_V_REMAIN' % (env_prefix,)] = str(int(remain))
return env
def _make_env_san_dict(self, env_prefix, cert_value):
"""Return a dict of WSGI environment variables for a certificate DN.
E.g. SSL_CLIENT_SAN_Email_0, SSL_CLIENT_SAN_DNS_0, etc.
See SSL_CLIENT_SAN_* at
https://httpd.apache.org/docs/2.4/mod/mod_ssl.html#envvars.
"""
if not cert_value:
return {}
env = {}
dns_count = 0
email_count = 0
for attr_name, val in cert_value:
if attr_name == 'DNS':
env['%s_DNS_%i' % (env_prefix, dns_count)] = val
dns_count += 1
elif attr_name == 'Email':
env['%s_Email_%i' % (env_prefix, email_count)] = val
email_count += 1
# other mod_ssl SAN vars:
# - SAN_OTHER_msUPN_n
return env
def _make_env_dn_dict(self, env_prefix, cert_value):
"""Return a dict of WSGI environment variables for a certificate DN.
E.g. SSL_CLIENT_S_DN_CN, SSL_CLIENT_S_DN_C, etc.
See SSL_CLIENT_S_DN_x509 at
@ -196,12 +457,26 @@ class BuiltinSSLAdapter(Adapter):
if not cert_value:
return {}
env = {}
dn = []
dn_attrs = {}
for rdn in cert_value:
for attr_name, val in rdn:
attr_code = self.CERT_KEY_TO_LDAP_CODE.get(attr_name)
if attr_code:
env['%s_%s' % (env_prefix, attr_code)] = val
dn.append('%s=%s' % (attr_code or attr_name, val))
if not attr_code:
continue
dn_attrs.setdefault(attr_code, [])
dn_attrs[attr_code].append(val)
env = {
env_prefix: ','.join(dn),
}
for attr_code, values in dn_attrs.items():
env['%s_%s' % (env_prefix, attr_code)] = ','.join(values)
if len(values) == 1:
continue
for i, val in enumerate(values):
env['%s_%s_%i' % (env_prefix, attr_code, i)] = val
return env
def makefile(self, sock, mode='r', bufsize=DEFAULT_BUFFER_SIZE):

View file

@ -1,46 +1,67 @@
"""
A library for integrating pyOpenSSL with Cheroot.
A library for integrating :doc:`pyOpenSSL <pyopenssl:index>` with Cheroot.
The OpenSSL module must be importable for SSL functionality.
You can obtain it from `here <https://launchpad.net/pyopenssl>`_.
The :py:mod:`OpenSSL <pyopenssl:OpenSSL>` module must be importable
for SSL/TLS/HTTPS functionality.
You can obtain it from `here <https://github.com/pyca/pyopenssl>`_.
To use this module, set HTTPServer.ssl_adapter to an instance of
ssl.Adapter. There are two ways to use SSL:
To use this module, set :py:attr:`HTTPServer.ssl_adapter
<cheroot.server.HTTPServer.ssl_adapter>` to an instance of
:py:class:`ssl.Adapter <cheroot.ssl.Adapter>`.
There are two ways to use :abbr:`TLS (Transport-Level Security)`:
Method One
----------
* ``ssl_adapter.context``: an instance of SSL.Context.
* :py:attr:`ssl_adapter.context
<cheroot.ssl.pyopenssl.pyOpenSSLAdapter.context>`: an instance of
:py:class:`SSL.Context <pyopenssl:OpenSSL.SSL.Context>`.
If this is not None, it is assumed to be an SSL.Context instance,
and will be passed to SSL.Connection on bind(). The developer is
responsible for forming a valid Context object. This approach is
to be preferred for more flexibility, e.g. if the cert and key are
streams instead of files, or need decryption, or SSL.SSLv3_METHOD
is desired instead of the default SSL.SSLv23_METHOD, etc. Consult
the pyOpenSSL documentation for complete options.
If this is not None, it is assumed to be an :py:class:`SSL.Context
<pyopenssl:OpenSSL.SSL.Context>` instance, and will be passed to
:py:class:`SSL.Connection <pyopenssl:OpenSSL.SSL.Connection>` on bind().
The developer is responsible for forming a valid :py:class:`Context
<pyopenssl:OpenSSL.SSL.Context>` object. This
approach is to be preferred for more flexibility, e.g. if the cert and
key are streams instead of files, or need decryption, or
:py:data:`SSL.SSLv3_METHOD <pyopenssl:OpenSSL.SSL.SSLv3_METHOD>`
is desired instead of the default :py:data:`SSL.SSLv23_METHOD
<pyopenssl:OpenSSL.SSL.SSLv3_METHOD>`, etc. Consult
the :doc:`pyOpenSSL <pyopenssl:api/ssl>` documentation for
complete options.
Method Two (shortcut)
---------------------
* ``ssl_adapter.certificate``: the filename of the server SSL certificate.
* ``ssl_adapter.private_key``: the filename of the server's private key file.
* :py:attr:`ssl_adapter.certificate
<cheroot.ssl.pyopenssl.pyOpenSSLAdapter.certificate>`: the file name
of the server's TLS certificate.
* :py:attr:`ssl_adapter.private_key
<cheroot.ssl.pyopenssl.pyOpenSSLAdapter.private_key>`: the file name
of the server's private key file.
Both are None by default. If ssl_adapter.context is None, but .private_key
and .certificate are both given and valid, they will be read, and the
context will be automatically created from them.
Both are :py:data:`None` by default. If :py:attr:`ssl_adapter.context
<cheroot.ssl.pyopenssl.pyOpenSSLAdapter.context>` is :py:data:`None`,
but ``.private_key`` and ``.certificate`` are both given and valid, they
will be read, and the context will be automatically created from them.
.. spelling::
pyopenssl
"""
from __future__ import absolute_import, division, print_function
__metaclass__ = type
import socket
import sys
import threading
import time
import six
try:
import OpenSSL.version
from OpenSSL import SSL
from OpenSSL import crypto
@ -57,13 +78,14 @@ from ..makefile import StreamReader, StreamWriter
class SSLFileobjectMixin:
"""Base mixin for an SSL socket stream."""
"""Base mixin for a TLS socket stream."""
ssl_timeout = 3
ssl_retry = .01
def _safe_call(self, is_reader, call, *args, **kwargs):
"""Wrap the given call with SSL error-trapping.
# FIXME:
def _safe_call(self, is_reader, call, *args, **kwargs): # noqa: C901
"""Wrap the given call with TLS error-trapping.
is_reader: if False EOF errors will be raised. If True, EOF errors
will return "" (to emulate normal sockets).
@ -209,9 +231,11 @@ class SSLConnectionProxyMeta:
@six.add_metaclass(SSLConnectionProxyMeta)
class SSLConnection:
"""A thread-safe wrapper for an SSL.Connection.
r"""A thread-safe wrapper for an ``SSL.Connection``.
``*args``: the arguments to create the wrapped ``SSL.Connection(*args)``.
:param tuple args: the arguments to create the wrapped \
:py:class:`SSL.Connection(*args) \
<pyopenssl:OpenSSL.SSL.Connection>`
"""
def __init__(self, *args):
@ -224,22 +248,24 @@ class pyOpenSSLAdapter(Adapter):
"""A wrapper for integrating pyOpenSSL with Cheroot."""
certificate = None
"""The filename of the server SSL certificate."""
"""The file name of the server's TLS certificate."""
private_key = None
"""The filename of the server's private key file."""
"""The file name of the server's private key file."""
certificate_chain = None
"""Optional. The filename of CA's intermediate certificate bundle.
"""Optional. The file name of CA's intermediate certificate bundle.
This is needed for cheaper "chained root" SSL certificates, and should be
left as None if not required."""
This is needed for cheaper "chained root" TLS certificates,
and should be left as :py:data:`None` if not required."""
context = None
"""An instance of SSL.Context."""
"""
An instance of :py:class:`SSL.Context <pyopenssl:OpenSSL.SSL.Context>`.
"""
ciphers = None
"""The ciphers list of SSL."""
"""The ciphers list of TLS."""
def __init__(
self, certificate, private_key, certificate_chain=None,
@ -265,10 +291,16 @@ class pyOpenSSLAdapter(Adapter):
def wrap(self, sock):
"""Wrap and return the given socket, plus WSGI environ entries."""
# pyOpenSSL doesn't perform the handshake until the first read/write
# forcing the handshake to complete tends to result in the connection
# closing so we can't reliably access protocol/client cert for the env
return sock, self._environ.copy()
def get_context(self):
"""Return an SSL.Context from self attributes."""
"""Return an ``SSL.Context`` from self attributes.
Ref: :py:class:`SSL.Context <pyopenssl:OpenSSL.SSL.Context>`
"""
# See https://code.activestate.com/recipes/442473/
c = SSL.Context(SSL.SSLv23_METHOD)
c.use_privatekey_file(self.private_key)
@ -280,18 +312,25 @@ class pyOpenSSLAdapter(Adapter):
def get_environ(self):
"""Return WSGI environ entries to be merged into each request."""
ssl_environ = {
'wsgi.url_scheme': 'https',
'HTTPS': 'on',
# pyOpenSSL doesn't provide access to any of these AFAICT
# 'SSL_PROTOCOL': 'SSLv2',
# SSL_CIPHER string The cipher specification name
# SSL_VERSION_INTERFACE string The mod_ssl program version
# SSL_VERSION_LIBRARY string The OpenSSL program version
'SSL_VERSION_INTERFACE': '%s %s/%s Python/%s' % (
cheroot_server.HTTPServer.version,
OpenSSL.version.__title__, OpenSSL.version.__version__,
sys.version,
),
'SSL_VERSION_LIBRARY': SSL.SSLeay_version(
SSL.SSLEAY_VERSION,
).decode(),
}
if self.certificate:
# Server certificate attributes
cert = open(self.certificate, 'rb').read()
cert = crypto.load_certificate(crypto.FILETYPE_PEM, cert)
with open(self.certificate, 'rb') as cert_file:
cert = crypto.load_certificate(
crypto.FILETYPE_PEM, cert_file.read(),
)
ssl_environ.update({
'SSL_SERVER_M_VERSION': cert.get_version(),
'SSL_SERVER_M_SERIAL': cert.get_serial_number(),

View file

@ -0,0 +1,38 @@
"""Local pytest plugin.
Contains hooks, which are tightly bound to the Cheroot framework
itself, useless for end-users' app testing.
"""
from __future__ import absolute_import, division, print_function
__metaclass__ = type
import pytest
pytest_version = tuple(map(int, pytest.__version__.split('.')))
def pytest_load_initial_conftests(early_config, parser, args):
"""Drop unfilterable warning ignores."""
if pytest_version < (6, 2, 0):
return
# pytest>=6.2.0 under Python 3.8:
# Refs:
# * https://docs.pytest.org/en/stable/usage.html#unraisable
# * https://github.com/pytest-dev/pytest/issues/5299
early_config._inicache['filterwarnings'].extend((
'ignore:Exception in thread CP Server Thread-:'
'pytest.PytestUnhandledThreadExceptionWarning:_pytest.threadexception',
'ignore:Exception in thread Thread-:'
'pytest.PytestUnhandledThreadExceptionWarning:_pytest.threadexception',
'ignore:Exception ignored in. '
'<socket.socket fd=-1, family=AddressFamily.AF_INET, '
'type=SocketKind.SOCK_STREAM, proto=.:'
'pytest.PytestUnraisableExceptionWarning:_pytest.unraisableexception',
'ignore:Exception ignored in. '
'<socket.socket fd=-1, family=AddressFamily.AF_INET6, '
'type=SocketKind.SOCK_STREAM, proto=.:'
'pytest.PytestUnraisableExceptionWarning:_pytest.unraisableexception',
))

View file

@ -55,7 +55,7 @@ def http_server():
def make_http_server(bind_addr):
"""Create and start an HTTP server bound to bind_addr."""
"""Create and start an HTTP server bound to ``bind_addr``."""
httpserver = HTTPServer(
bind_addr=bind_addr,
gateway=Gateway,

View file

@ -99,7 +99,7 @@ class CherootWebCase(webtest.WebCase):
date_tolerance = 2
def assertEqualDates(self, dt1, dt2, seconds=None):
"""Assert abs(dt1 - dt2) is within Y seconds."""
"""Assert ``abs(dt1 - dt2)`` is within ``Y`` seconds."""
if seconds is None:
seconds = self.date_tolerance
@ -108,8 +108,10 @@ class CherootWebCase(webtest.WebCase):
else:
diff = dt2 - dt1
if not diff < datetime.timedelta(seconds=seconds):
raise AssertionError('%r and %r are not within %r seconds.' %
(dt1, dt2, seconds))
raise AssertionError(
'%r and %r are not within %r seconds.' %
(dt1, dt2, seconds),
)
class Request:
@ -155,9 +157,13 @@ class Controller:
resp.status = '404 Not Found'
else:
output = handler(req, resp)
if (output is not None
and not any(resp.status.startswith(status_code)
for status_code in ('204', '304'))):
if (
output is not None
and not any(
resp.status.startswith(status_code)
for status_code in ('204', '304')
)
):
resp.body = output
try:
resp.headers.setdefault('Content-Length', str(len(output)))

View file

@ -11,45 +11,45 @@ from cheroot._compat import extract_bytes, memoryview, ntob, ntou, bton
@pytest.mark.parametrize(
'func,inp,out',
[
('func', 'inp', 'out'),
(
(ntob, 'bar', b'bar'),
(ntou, 'bar', u'bar'),
(bton, b'bar', 'bar'),
],
),
)
def test_compat_functions_positive(func, inp, out):
"""Check that compat functions work with correct input."""
"""Check that compatibility functions work with correct input."""
assert func(inp, encoding='utf-8') == out
@pytest.mark.parametrize(
'func',
[
(
ntob,
ntou,
],
),
)
def test_compat_functions_negative_nonnative(func):
"""Check that compat functions fail loudly for incorrect input."""
"""Check that compatibility functions fail loudly for incorrect input."""
non_native_test_str = u'bar' if six.PY2 else b'bar'
with pytest.raises(TypeError):
func(non_native_test_str, encoding='utf-8')
def test_ntou_escape():
"""Check that ntou supports escape-encoding under Python 2."""
"""Check that ``ntou`` supports escape-encoding under Python 2."""
expected = u'hišřії'
actual = ntou('hi\u0161\u0159\u0456\u0457', encoding='escape')
assert actual == expected
@pytest.mark.parametrize(
'input_argument,expected_result',
[
('input_argument', 'expected_result'),
(
(b'qwerty', b'qwerty'),
(memoryview(b'asdfgh'), b'asdfgh'),
],
),
)
def test_extract_bytes(input_argument, expected_result):
"""Check that legitimate inputs produce bytes."""
@ -58,5 +58,9 @@ def test_extract_bytes(input_argument, expected_result):
def test_extract_bytes_invalid():
"""Ensure that invalid input causes exception to be raised."""
with pytest.raises(ValueError):
with pytest.raises(
ValueError,
match=r'^extract_bytes\(\) only accepts bytes '
'and memoryview/buffer$',
):
extract_bytes(u'some юнікод їїї')

View file

@ -0,0 +1,97 @@
"""Tests to verify the command line interface.
.. spelling::
cli
"""
# -*- coding: utf-8 -*-
# vim: set fileencoding=utf-8 :
import sys
import six
import pytest
from cheroot.cli import (
Application,
parse_wsgi_bind_addr,
)
@pytest.mark.parametrize(
('raw_bind_addr', 'expected_bind_addr'),
(
# tcp/ip
('192.168.1.1:80', ('192.168.1.1', 80)),
# ipv6 ips has to be enclosed in brakets when specified in url form
('[::1]:8000', ('::1', 8000)),
('localhost:5000', ('localhost', 5000)),
# this is a valid input, but foo gets discarted
('foo@bar:5000', ('bar', 5000)),
('foo', ('foo', None)),
('123456789', ('123456789', None)),
# unix sockets
('/tmp/cheroot.sock', '/tmp/cheroot.sock'),
('/tmp/some-random-file-name', '/tmp/some-random-file-name'),
# abstract sockets
('@cheroot', '\x00cheroot'),
),
)
def test_parse_wsgi_bind_addr(raw_bind_addr, expected_bind_addr):
"""Check the parsing of the --bind option.
Verify some of the supported addresses and the expected return value.
"""
assert parse_wsgi_bind_addr(raw_bind_addr) == expected_bind_addr
@pytest.fixture
def wsgi_app(monkeypatch):
"""Return a WSGI app stub."""
class WSGIAppMock:
"""Mock of a wsgi module."""
def application(self):
"""Empty application method.
Default method to be called when no specific callable
is defined in the wsgi application identifier.
It has an empty body because we are expecting to verify that
the same method is return no the actual execution of it.
"""
def main(self):
"""Empty custom method (callable) inside the mocked WSGI app.
It has an empty body because we are expecting to verify that
the same method is return no the actual execution of it.
"""
app = WSGIAppMock()
# patch sys.modules, to include the an instance of WSGIAppMock
# under a specific namespace
if six.PY2:
# python2 requires the previous namespaces to be part of sys.modules
# (e.g. for 'a.b.c' we need to insert 'a', 'a.b' and 'a.b.c')
# otherwise it fails, we're setting the same instance on each level,
# we don't really care about those, just the last one.
monkeypatch.setitem(sys.modules, 'mypkg', app)
monkeypatch.setitem(sys.modules, 'mypkg.wsgi', app)
return app
@pytest.mark.parametrize(
('app_name', 'app_method'),
(
(None, 'application'),
('application', 'application'),
('main', 'main'),
),
)
def test_Aplication_resolve(app_name, app_method, wsgi_app):
"""Check the wsgi application name conversion."""
if app_name is None:
wsgi_app_spec = 'mypkg.wsgi'
else:
wsgi_app_spec = 'mypkg.wsgi:{app_name}'.format(**locals())
expected_app = getattr(wsgi_app, app_method)
assert Application.resolve(wsgi_app_spec).wsgi_app == expected_app

View file

@ -3,15 +3,22 @@
from __future__ import absolute_import, division, print_function
__metaclass__ = type
import errno
import socket
import time
import logging
import traceback as traceback_
from collections import namedtuple
from six.moves import range, http_client, urllib
import six
import pytest
from jaraco.text import trim, unwrap
from cheroot.test import helper, webtest
from cheroot._compat import IS_CI, IS_PYPY, IS_WINDOWS
import cheroot.server
timeout = 1
@ -26,7 +33,7 @@ class Controller(helper.Controller):
return 'Hello, world!'
def pov(req, resp):
"""Render pov value."""
"""Render ``pov`` value."""
return pov
def stream(req, resp):
@ -43,8 +50,10 @@ class Controller(helper.Controller):
def upload(req, resp):
"""Process file upload and render thank."""
if not req.environ['REQUEST_METHOD'] == 'POST':
raise AssertionError("'POST' != request.method %r" %
req.environ['REQUEST_METHOD'])
raise AssertionError(
"'POST' != request.method %r" %
req.environ['REQUEST_METHOD'],
)
return "thanks for '%s'" % req.environ['wsgi.input'].read()
def custom_204(req, resp):
@ -103,9 +112,33 @@ class Controller(helper.Controller):
}
class ErrorLogMonitor:
"""Mock class to access the server error_log calls made by the server."""
ErrorLogCall = namedtuple('ErrorLogCall', ['msg', 'level', 'traceback'])
def __init__(self):
"""Initialize the server error log monitor/interceptor.
If you need to ignore a particular error message use the property
``ignored_msgs`` by appending to the list the expected error messages.
"""
self.calls = []
# to be used the the teardown validation
self.ignored_msgs = []
def __call__(self, msg='', level=logging.INFO, traceback=False):
"""Intercept the call to the server error_log method."""
if traceback:
tblines = traceback_.format_exc()
else:
tblines = ''
self.calls.append(ErrorLogMonitor.ErrorLogCall(msg, level, tblines))
@pytest.fixture
def testing_server(wsgi_server_client):
"""Attach a WSGI app to the given server and pre-configure it."""
def raw_testing_server(wsgi_server_client):
"""Attach a WSGI app to the given server and preconfigure it."""
app = Controller()
def _timeout(req, resp):
@ -117,9 +150,36 @@ def testing_server(wsgi_server_client):
wsgi_server.timeout = timeout
wsgi_server.server_client = wsgi_server_client
wsgi_server.keep_alive_conn_limit = 2
return wsgi_server
@pytest.fixture
def testing_server(raw_testing_server, monkeypatch):
"""Modify the "raw" base server to monitor the error_log messages.
If you need to ignore a particular error message use the property
``testing_server.error_log.ignored_msgs`` by appending to the list
the expected error messages.
"""
# patch the error_log calls of the server instance
monkeypatch.setattr(raw_testing_server, 'error_log', ErrorLogMonitor())
yield raw_testing_server
# Teardown verification, in case that the server logged an
# error that wasn't notified to the client or we just made a mistake.
for c_msg, c_level, c_traceback in raw_testing_server.error_log.calls:
if c_level <= logging.WARNING:
continue
assert c_msg in raw_testing_server.error_log.ignored_msgs, (
'Found error in the error log: '
"message = '{c_msg}', level = '{c_level}'\n"
'{c_traceback}'.format(**locals()),
)
@pytest.fixture
def test_client(testing_server):
"""Get and return a test client out of the given server."""
@ -338,7 +398,14 @@ def test_streaming_10(test_client, set_cl):
'http_server_protocol',
(
'HTTP/1.0',
pytest.param(
'HTTP/1.1',
marks=pytest.mark.xfail(
IS_PYPY and IS_CI,
reason='Fails under PyPy in CI for unknown reason',
strict=False,
),
),
),
)
def test_keepalive(test_client, http_server_protocol):
@ -375,6 +442,11 @@ def test_keepalive(test_client, http_server_protocol):
assert status_line[4:] == 'OK'
assert actual_resp_body == pov.encode()
assert header_has_value('Connection', 'Keep-Alive', actual_headers)
assert header_has_value(
'Keep-Alive',
'timeout={test_client.server_instance.timeout}'.format(**locals()),
actual_headers,
)
# Remove the keep-alive header again.
status_line, actual_headers, actual_resp_body = test_client.get(
@ -386,6 +458,7 @@ def test_keepalive(test_client, http_server_protocol):
assert status_line[4:] == 'OK'
assert actual_resp_body == pov.encode()
assert not header_exists('Connection', actual_headers)
assert not header_exists('Keep-Alive', actual_headers)
test_client.server_instance.protocol = original_server_protocol
@ -401,7 +474,7 @@ def test_keepalive_conn_management(test_client):
http_connection.connect()
return http_connection
def request(conn):
def request(conn, keepalive=True):
status_line, actual_headers, actual_resp_body = test_client.get(
'/page3', headers=[('Connection', 'Keep-Alive')],
http_conn=conn, protocol='HTTP/1.0',
@ -410,7 +483,28 @@ def test_keepalive_conn_management(test_client):
assert actual_status == 200
assert status_line[4:] == 'OK'
assert actual_resp_body == pov.encode()
if keepalive:
assert header_has_value('Connection', 'Keep-Alive', actual_headers)
assert header_has_value(
'Keep-Alive',
'timeout={test_client.server_instance.timeout}'.
format(**locals()),
actual_headers,
)
else:
assert not header_exists('Connection', actual_headers)
assert not header_exists('Keep-Alive', actual_headers)
def check_server_idle_conn_count(count, timeout=1.0):
deadline = time.time() + timeout
while True:
n = test_client.server_instance._connections._num_connections
if n == count:
return
assert time.time() <= deadline, (
'idle conn count mismatch, wanted {count}, got {n}'.
format(**locals()),
)
disconnect_errors = (
http_client.BadStatusLine,
@ -421,50 +515,175 @@ def test_keepalive_conn_management(test_client):
# Make a new connection.
c1 = connection()
request(c1)
check_server_idle_conn_count(1)
# Make a second one.
c2 = connection()
request(c2)
check_server_idle_conn_count(2)
# Reusing the first connection should still work.
request(c1)
check_server_idle_conn_count(2)
# Creating a new connection should still work.
# Creating a new connection should still work, but we should
# have run out of available connections to keep alive, so the
# server should tell us to close.
c3 = connection()
request(c3)
request(c3, keepalive=False)
check_server_idle_conn_count(2)
# Allow a tick.
time.sleep(0.2)
# That's three connections, we should expect the one used less recently
# to be expired.
# Show that the third connection was closed.
with pytest.raises(disconnect_errors):
request(c2)
# But the oldest created one should still be valid.
# (As well as the newest one).
request(c1)
request(c3)
check_server_idle_conn_count(2)
# Wait for some of our timeout.
time.sleep(1.0)
time.sleep(1.2)
# Refresh the third connection.
request(c3)
# Refresh the second connection.
request(c2)
check_server_idle_conn_count(2)
# Wait for the remainder of our timeout, plus one tick.
time.sleep(1.2)
check_server_idle_conn_count(1)
# First connection should now be expired.
with pytest.raises(disconnect_errors):
request(c1)
check_server_idle_conn_count(1)
# But the third one should still be valid.
request(c3)
# But the second one should still be valid.
request(c2)
check_server_idle_conn_count(1)
# Restore original timeout.
test_client.server_instance.timeout = timeout
@pytest.mark.parametrize(
('simulated_exception', 'error_number', 'exception_leaks'),
(
pytest.param(
socket.error, errno.ECONNRESET, False,
id='socket.error(ECONNRESET)',
),
pytest.param(
socket.error, errno.EPIPE, False,
id='socket.error(EPIPE)',
),
pytest.param(
socket.error, errno.ENOTCONN, False,
id='simulated socket.error(ENOTCONN)',
),
pytest.param(
None, # <-- don't raise an artificial exception
errno.ENOTCONN, False,
id='real socket.error(ENOTCONN)',
marks=pytest.mark.xfail(
IS_WINDOWS,
reason='Now reproducible this way on Windows',
),
),
pytest.param(
socket.error, errno.ESHUTDOWN, False,
id='socket.error(ESHUTDOWN)',
),
pytest.param(RuntimeError, 666, True, id='RuntimeError(666)'),
pytest.param(socket.error, -1, True, id='socket.error(-1)'),
) + (
() if six.PY2 else (
pytest.param(
ConnectionResetError, errno.ECONNRESET, False,
id='ConnectionResetError(ECONNRESET)',
),
pytest.param(
BrokenPipeError, errno.EPIPE, False,
id='BrokenPipeError(EPIPE)',
),
pytest.param(
BrokenPipeError, errno.ESHUTDOWN, False,
id='BrokenPipeError(ESHUTDOWN)',
),
)
),
)
def test_broken_connection_during_tcp_fin(
error_number, exception_leaks,
mocker, monkeypatch,
simulated_exception, test_client,
):
"""Test there's no traceback on broken connection during close.
It artificially causes :py:data:`~errno.ECONNRESET` /
:py:data:`~errno.EPIPE` / :py:data:`~errno.ESHUTDOWN` /
:py:data:`~errno.ENOTCONN` as well as unrelated :py:exc:`RuntimeError`
and :py:exc:`socket.error(-1) <socket.error>` on the server socket when
:py:meth:`socket.shutdown() <socket.socket.shutdown>` is called. It's
triggered by closing the client socket before the server had a chance
to respond.
The expectation is that only :py:exc:`RuntimeError` and a
:py:exc:`socket.error` with an unusual error code would leak.
With the :py:data:`None`-parameter, a real non-simulated
:py:exc:`OSError(107, 'Transport endpoint is not connected')
<OSError>` happens.
"""
exc_instance = (
None if simulated_exception is None
else simulated_exception(error_number, 'Simulated socket error')
)
old_close_kernel_socket = (
test_client.server_instance.
ConnectionClass._close_kernel_socket
)
def _close_kernel_socket(self):
monkeypatch.setattr( # `socket.shutdown` is read-only otherwise
self, 'socket',
mocker.mock_module.Mock(wraps=self.socket),
)
if exc_instance is not None:
monkeypatch.setattr(
self.socket, 'shutdown',
mocker.mock_module.Mock(side_effect=exc_instance),
)
_close_kernel_socket.fin_spy = mocker.spy(self.socket, 'shutdown')
_close_kernel_socket.exception_leaked = True
old_close_kernel_socket(self)
_close_kernel_socket.exception_leaked = False
monkeypatch.setattr(
test_client.server_instance.ConnectionClass,
'_close_kernel_socket',
_close_kernel_socket,
)
conn = test_client.get_connection()
conn.auto_open = False
conn.connect()
conn.send(b'GET /hello HTTP/1.1')
conn.send(('Host: %s' % conn.host).encode('ascii'))
conn.close()
for _ in range(10): # Let the server attempt TCP shutdown
time.sleep(0.1)
if hasattr(_close_kernel_socket, 'exception_leaked'):
break
if exc_instance is not None: # simulated by us
assert _close_kernel_socket.fin_spy.spy_exception is exc_instance
else: # real
assert isinstance(
_close_kernel_socket.fin_spy.spy_exception, socket.error,
)
assert _close_kernel_socket.fin_spy.spy_exception.errno == error_number
assert _close_kernel_socket.exception_leaked is exception_leaks
@pytest.mark.parametrize(
'timeout_before_headers',
(
@ -475,7 +694,7 @@ def test_keepalive_conn_management(test_client):
def test_HTTP11_Timeout(test_client, timeout_before_headers):
"""Check timeout without sending any data.
The server will close the conn with a 408.
The server will close the connection with a 408.
"""
conn = test_client.get_connection()
conn.auto_open = False
@ -594,7 +813,7 @@ def test_HTTP11_Timeout_after_request(test_client):
def test_HTTP11_pipelining(test_client):
"""Test HTTP/1.1 pipelining.
httplib doesn't support this directly.
:py:mod:`http.client` doesn't support this directly.
"""
conn = test_client.get_connection()
@ -639,7 +858,7 @@ def test_100_Continue(test_client):
conn = test_client.get_connection()
# Try a page without an Expect request header first.
# Note that httplib's response.begin automatically ignores
# Note that http.client's response.begin automatically ignores
# 100 Continue responses, so we must manually check for it.
conn.putrequest('POST', '/upload', skip_host=True)
conn.putheader('Host', conn.host)
@ -800,11 +1019,13 @@ def test_No_Message_Body(test_client):
@pytest.mark.xfail(
reason='Server does not correctly read trailers/ending of the previous '
'HTTP request, thus the second request fails as the server tries '
r"to parse b'Content-Type: application/json\r\n' as a "
'Request-Line. This results in HTTP status code 400, instead of 413'
'Ref: https://github.com/cherrypy/cheroot/issues/69',
reason=unwrap(
trim("""
Headers from earlier request leak into the request
line for a subsequent request, resulting in 400
instead of 413. See cherrypy/cheroot#69 for details.
"""),
),
)
def test_Chunked_Encoding(test_client):
"""Test HTTP uploads with chunked transfer-encoding."""
@ -837,7 +1058,7 @@ def test_Chunked_Encoding(test_client):
# Try a chunked request that exceeds server.max_request_body_size.
# Note that the delimiters and trailer are included.
body = b'3e3\r\n' + (b'x' * 995) + b'\r\n0\r\n\r\n'
body = b'\r\n'.join((b'3e3', b'x' * 995, b'0', b'', b''))
conn.putrequest('POST', '/upload', skip_host=True)
conn.putheader('Host', conn.host)
conn.putheader('Transfer-Encoding', 'chunked')
@ -895,7 +1116,7 @@ def test_Content_Length_not_int(test_client):
@pytest.mark.parametrize(
'uri,expected_resp_status,expected_resp_body',
('uri', 'expected_resp_status', 'expected_resp_body'),
(
(
'/wrong_cl_buffered', 500,
@ -929,6 +1150,16 @@ def test_Content_Length_out(
conn.close()
# the server logs the exception that we had verified from the
# client perspective. Tell the error_log verification that
# it can ignore that message.
test_client.server_instance.error_log.ignored_msgs.extend((
# Python 3.7+:
"ValueError('Response body exceeds the declared Content-Length.')",
# Python 2.7-3.6 (macOS?):
"ValueError('Response body exceeds the declared Content-Length.',)",
))
@pytest.mark.xfail(
reason='Sometimes this test fails due to low timeout. '
@ -970,11 +1201,94 @@ def test_No_CRLF(test_client, invalid_terminator):
# Initialize a persistent HTTP connection
conn = test_client.get_connection()
# (b'%s' % b'') is not supported in Python 3.4, so just use +
conn.send(b'GET /hello HTTP/1.1' + invalid_terminator)
# (b'%s' % b'') is not supported in Python 3.4, so just use bytes.join()
conn.send(b''.join((b'GET /hello HTTP/1.1', invalid_terminator)))
response = conn.response_class(conn.sock, method='GET')
response.begin()
actual_resp_body = response.read()
expected_resp_body = b'HTTP requires CRLF terminators'
assert actual_resp_body == expected_resp_body
conn.close()
class FaultySelect:
"""Mock class to insert errors in the selector.select method."""
def __init__(self, original_select):
"""Initilize helper class to wrap the selector.select method."""
self.original_select = original_select
self.request_served = False
self.os_error_triggered = False
def __call__(self, timeout):
"""Intercept the calls to selector.select."""
if self.request_served:
self.os_error_triggered = True
raise OSError('Error while selecting the client socket.')
return self.original_select(timeout)
class FaultyGetMap:
"""Mock class to insert errors in the selector.get_map method."""
def __init__(self, original_get_map):
"""Initilize helper class to wrap the selector.get_map method."""
self.original_get_map = original_get_map
self.sabotage_conn = False
self.conn_closed = False
def __call__(self):
"""Intercept the calls to selector.get_map."""
sabotage_targets = (
conn for _, (_, _, _, conn) in self.original_get_map().items()
if isinstance(conn, cheroot.server.HTTPConnection)
) if self.sabotage_conn and not self.conn_closed else ()
for conn in sabotage_targets:
# close the socket to cause OSError
conn.close()
self.conn_closed = True
return self.original_get_map()
def test_invalid_selected_connection(test_client, monkeypatch):
"""Test the error handling segment of HTTP connection selection.
See :py:meth:`cheroot.connections.ConnectionManager.get_conn`.
"""
# patch the select method
faux_select = FaultySelect(
test_client.server_instance._connections._selector.select,
)
monkeypatch.setattr(
test_client.server_instance._connections._selector,
'select',
faux_select,
)
# patch the get_map method
faux_get_map = FaultyGetMap(
test_client.server_instance._connections._selector._selector.get_map,
)
monkeypatch.setattr(
test_client.server_instance._connections._selector._selector,
'get_map',
faux_get_map,
)
# request a page with connection keep-alive to make sure
# we'll have a connection to be modified.
resp_status, resp_headers, resp_body = test_client.request(
'/page1', headers=[('Connection', 'Keep-Alive')],
)
assert resp_status == '200 OK'
# trigger the internal errors
faux_get_map.sabotage_conn = faux_select.request_served = True
# give time to make sure the error gets handled
time.sleep(test_client.server_instance.expiration_interval * 2)
assert faux_select.os_error_triggered
assert faux_get_map.conn_closed

View file

@ -18,6 +18,7 @@ from cheroot.test import helper
HTTP_BAD_REQUEST = 400
HTTP_LENGTH_REQUIRED = 411
HTTP_NOT_FOUND = 404
HTTP_REQUEST_ENTITY_TOO_LARGE = 413
HTTP_OK = 200
HTTP_VERSION_NOT_SUPPORTED = 505
@ -78,7 +79,7 @@ def _get_http_response(connection, method='GET'):
@pytest.fixture
def testing_server(wsgi_server_client):
"""Attach a WSGI app to the given server and pre-configure it."""
"""Attach a WSGI app to the given server and preconfigure it."""
wsgi_server = wsgi_server_client.server_instance
wsgi_server.wsgi_app = HelloController()
wsgi_server.max_request_body_size = 30000000
@ -92,6 +93,21 @@ def test_client(testing_server):
return testing_server.server_client
@pytest.fixture
def testing_server_with_defaults(wsgi_server_client):
"""Attach a WSGI app to the given server and preconfigure it."""
wsgi_server = wsgi_server_client.server_instance
wsgi_server.wsgi_app = HelloController()
wsgi_server.server_client = wsgi_server_client
return wsgi_server
@pytest.fixture
def test_client_with_defaults(testing_server_with_defaults):
"""Get and return a test client out of the given server."""
return testing_server_with_defaults.server_client
def test_http_connect_request(test_client):
"""Check that CONNECT query results in Method Not Allowed status."""
status_line = test_client.connect('/anything')[0]
@ -262,8 +278,30 @@ def test_content_length_required(test_client):
assert actual_status == HTTP_LENGTH_REQUIRED
@pytest.mark.xfail(
reason='https://github.com/cherrypy/cheroot/issues/106',
strict=False, # sometimes it passes
)
def test_large_request(test_client_with_defaults):
"""Test GET query with maliciously large Content-Length."""
# If the server's max_request_body_size is not set (i.e. is set to 0)
# then this will result in an `OverflowError: Python int too large to
# convert to C ssize_t` in the server.
# We expect that this should instead return that the request is too
# large.
c = test_client_with_defaults.get_connection()
c.putrequest('GET', '/hello')
c.putheader('Content-Length', str(2**64))
c.endheaders()
response = c.getresponse()
actual_status = response.status
assert actual_status == HTTP_REQUEST_ENTITY_TOO_LARGE
@pytest.mark.parametrize(
'request_line,status_code,expected_body',
('request_line', 'status_code', 'expected_body'),
(
(
b'GET /', # missing proto
@ -401,7 +439,7 @@ class CloseResponse:
@pytest.fixture
def testing_server_close(wsgi_server_client):
"""Attach a WSGI app to the given server and pre-configure it."""
"""Attach a WSGI app to the given server and preconfigure it."""
wsgi_server = wsgi_server_client.server_instance
wsgi_server.wsgi_app = CloseController()
wsgi_server.max_request_body_size = 30000000

View file

@ -8,7 +8,7 @@ from cheroot.wsgi import PathInfoDispatcher
def wsgi_invoke(app, environ):
"""Serve 1 requeset from a WSGI application."""
"""Serve 1 request from a WSGI application."""
response = {}
def start_response(status, headers):
@ -25,7 +25,7 @@ def wsgi_invoke(app, environ):
def test_dispatch_no_script_name():
"""Despatch despite lack of SCRIPT_NAME in environ."""
"""Dispatch despite lack of ``SCRIPT_NAME`` in environ."""
# Bare bones WSGI hello world app (from PEP 333).
def app(environ, start_response):
start_response(

View file

@ -8,7 +8,7 @@ from .._compat import IS_LINUX, IS_MACOS, IS_WINDOWS
@pytest.mark.parametrize(
'err_names,err_nums',
('err_names', 'err_nums'),
(
(('', 'some-nonsense-name'), []),
(
@ -24,7 +24,7 @@ from .._compat import IS_LINUX, IS_MACOS, IS_WINDOWS
),
)
def test_plat_specific_errors(err_names, err_nums):
"""Test that plat_specific_errors retrieves correct err num list."""
"""Test that ``plat_specific_errors`` gets correct error numbers list."""
actual_err_nums = errors.plat_specific_errors(*err_names)
assert len(actual_err_nums) == len(err_nums)
assert sorted(actual_err_nums) == sorted(err_nums)

View file

@ -1,4 +1,4 @@
"""self-explanatory."""
"""Tests for :py:mod:`cheroot.makefile`."""
from cheroot import makefile
@ -7,14 +7,14 @@ __metaclass__ = type
class MockSocket:
"""Mocks a socket."""
"""A mock socket."""
def __init__(self):
"""Initialize."""
"""Initialize :py:class:`MockSocket`."""
self.messages = []
def recv_into(self, buf):
"""Simulate recv_into for Python 3."""
"""Simulate ``recv_into`` for Python 3."""
if not self.messages:
return 0
msg = self.messages.pop(0)
@ -23,7 +23,7 @@ class MockSocket:
return len(msg)
def recv(self, size):
"""Simulate recv for Python 2."""
"""Simulate ``recv`` for Python 2."""
try:
return self.messages.pop(0)
except IndexError:
@ -44,7 +44,7 @@ def test_bytes_read():
def test_bytes_written():
"""Writer should capture bytes writtten."""
"""Writer should capture bytes written."""
sock = MockSocket()
sock.messages.append(b'foo')
wfile = makefile.MakeFile(sock, 'w')

View file

@ -5,10 +5,12 @@
from __future__ import absolute_import, division, print_function
__metaclass__ = type
from contextlib import closing
import os
import socket
import tempfile
import threading
import time
import uuid
import pytest
@ -16,14 +18,15 @@ import requests
import requests_unixsocket
import six
from six.moves import queue, urllib
from .._compat import bton, ntob
from .._compat import IS_LINUX, IS_MACOS, SYS_PLATFORM
from .._compat import IS_LINUX, IS_MACOS, IS_WINDOWS, SYS_PLATFORM
from ..server import IS_UID_GID_RESOLVABLE, Gateway, HTTPServer
from ..testing import (
ANY_INTERFACE_IPV4,
ANY_INTERFACE_IPV6,
EPHEMERAL_PORT,
get_server_client,
)
@ -42,15 +45,8 @@ non_macos_sock_test = pytest.mark.skipif(
@pytest.fixture(params=('abstract', 'file'))
def unix_sock_file(request):
"""Check that bound UNIX socket address is stored in server."""
if request.param == 'abstract':
yield request.getfixturevalue('unix_abstract_sock')
return
tmp_sock_fh, tmp_sock_fname = tempfile.mkstemp()
yield tmp_sock_fname
os.close(tmp_sock_fh)
os.unlink(tmp_sock_fname)
name = 'unix_{request.param}_sock'.format(**locals())
return request.getfixturevalue(name)
@pytest.fixture
@ -67,6 +63,17 @@ def unix_abstract_sock():
)).decode()
@pytest.fixture
def unix_file_sock():
"""Yield a unix file socket."""
tmp_sock_fh, tmp_sock_fname = tempfile.mkstemp()
yield tmp_sock_fname
os.close(tmp_sock_fh)
os.unlink(tmp_sock_fname)
def test_prepare_makes_server_ready():
"""Check that prepare() makes the server ready, and stop() clears it."""
httpserver = HTTPServer(
@ -110,6 +117,77 @@ def test_stop_interrupts_serve():
assert not serve_thread.is_alive()
@pytest.mark.parametrize(
'exc_cls',
(
IOError,
KeyboardInterrupt,
OSError,
RuntimeError,
),
)
def test_server_interrupt(exc_cls):
"""Check that assigning interrupt stops the server."""
interrupt_msg = 'should catch {uuid!s}'.format(uuid=uuid.uuid4())
raise_marker_sentinel = object()
httpserver = HTTPServer(
bind_addr=(ANY_INTERFACE_IPV4, EPHEMERAL_PORT),
gateway=Gateway,
)
result_q = queue.Queue()
def serve_thread():
# ensure we catch the exception on the serve() thread
try:
httpserver.serve()
except exc_cls as e:
if str(e) == interrupt_msg:
result_q.put(raise_marker_sentinel)
httpserver.prepare()
serve_thread = threading.Thread(target=serve_thread)
serve_thread.start()
serve_thread.join(0.5)
assert serve_thread.is_alive()
# this exception is raised on the serve() thread,
# not in the calling context.
httpserver.interrupt = exc_cls(interrupt_msg)
serve_thread.join(0.5)
assert not serve_thread.is_alive()
assert result_q.get_nowait() is raise_marker_sentinel
def test_serving_is_false_and_stop_returns_after_ctrlc():
"""Check that stop() interrupts running of serve()."""
httpserver = HTTPServer(
bind_addr=(ANY_INTERFACE_IPV4, EPHEMERAL_PORT),
gateway=Gateway,
)
httpserver.prepare()
# Simulate a Ctrl-C on the first call to `run`.
def raise_keyboard_interrupt(*args, **kwargs):
raise KeyboardInterrupt()
httpserver._connections._selector.select = raise_keyboard_interrupt
serve_thread = threading.Thread(target=httpserver.serve)
serve_thread.start()
# The thread should exit right away due to the interrupt.
serve_thread.join(httpserver.expiration_interval * 2)
assert not serve_thread.is_alive()
assert not httpserver._connections._serving
httpserver.stop()
@pytest.mark.parametrize(
'ip_addr',
(
@ -135,7 +213,7 @@ def test_bind_addr_unix(http_server, unix_sock_file):
@unix_only_sock_test
def test_bind_addr_unix_abstract(http_server, unix_abstract_sock):
"""Check that bound UNIX abstract sockaddr is stored in server."""
"""Check that bound UNIX abstract socket address is stored in server."""
httpserver = http_server.send(unix_abstract_sock)
assert httpserver.bind_addr == unix_abstract_sock
@ -167,27 +245,26 @@ class _TestGateway(Gateway):
@pytest.fixture
def peercreds_enabled_server_and_client(http_server, unix_sock_file):
"""Construct a test server with `peercreds_enabled`."""
def peercreds_enabled_server(http_server, unix_sock_file):
"""Construct a test server with ``peercreds_enabled``."""
httpserver = http_server.send(unix_sock_file)
httpserver.gateway = _TestGateway
httpserver.peercreds_enabled = True
return httpserver, get_server_client(httpserver)
return httpserver
@unix_only_sock_test
@non_macos_sock_test
def test_peercreds_unix_sock(peercreds_enabled_server_and_client):
"""Check that peercred lookup works when enabled."""
httpserver, testclient = peercreds_enabled_server_and_client
def test_peercreds_unix_sock(peercreds_enabled_server):
"""Check that ``PEERCRED`` lookup works when enabled."""
httpserver = peercreds_enabled_server
bind_addr = httpserver.bind_addr
if isinstance(bind_addr, six.binary_type):
bind_addr = bind_addr.decode()
unix_base_uri = 'http+unix://{}'.format(
bind_addr.replace('\0', '%00').replace('/', '%2F'),
)
quoted = urllib.parse.quote(bind_addr, safe='')
unix_base_uri = 'http+unix://{quoted}'.format(**locals())
expected_peercreds = os.getpid(), os.getuid(), os.getgid()
expected_peercreds = '|'.join(map(str, expected_peercreds))
@ -208,9 +285,9 @@ def test_peercreds_unix_sock(peercreds_enabled_server_and_client):
)
@unix_only_sock_test
@non_macos_sock_test
def test_peercreds_unix_sock_with_lookup(peercreds_enabled_server_and_client):
"""Check that peercred resolution works when enabled."""
httpserver, testclient = peercreds_enabled_server_and_client
def test_peercreds_unix_sock_with_lookup(peercreds_enabled_server):
"""Check that ``PEERCRED`` resolution works when enabled."""
httpserver = peercreds_enabled_server
httpserver.peercreds_resolve_enabled = True
bind_addr = httpserver.bind_addr
@ -218,9 +295,8 @@ def test_peercreds_unix_sock_with_lookup(peercreds_enabled_server_and_client):
if isinstance(bind_addr, six.binary_type):
bind_addr = bind_addr.decode()
unix_base_uri = 'http+unix://{}'.format(
bind_addr.replace('\0', '%00').replace('/', '%2F'),
)
quoted = urllib.parse.quote(bind_addr, safe='')
unix_base_uri = 'http+unix://{quoted}'.format(**locals())
import grp
import pwd
@ -233,3 +309,111 @@ def test_peercreds_unix_sock_with_lookup(peercreds_enabled_server_and_client):
peercreds_text_resp = requests.get(unix_base_uri + PEERCRED_TEXTS_URI)
peercreds_text_resp.raise_for_status()
assert peercreds_text_resp.text == expected_textcreds
@pytest.mark.skipif(
IS_WINDOWS,
reason='This regression test is for a Linux bug, '
'and the resource module is not available on Windows',
)
@pytest.mark.parametrize(
'resource_limit',
(
1024,
2048,
),
indirect=('resource_limit',),
)
@pytest.mark.usefixtures('many_open_sockets')
def test_high_number_of_file_descriptors(resource_limit):
"""Test the server does not crash with a high file-descriptor value.
This test shouldn't cause a server crash when trying to access
file-descriptor higher than 1024.
The earlier implementation used to rely on ``select()`` syscall that
doesn't support file descriptors with numbers higher than 1024.
"""
# We want to force the server to use a file-descriptor with
# a number above resource_limit
# Create our server
httpserver = HTTPServer(
bind_addr=(ANY_INTERFACE_IPV4, EPHEMERAL_PORT), gateway=Gateway,
)
try:
# This will trigger a crash if select() is used in the implementation
with httpserver._run_in_thread():
# allow server to run long enough to invoke select()
time.sleep(1.0)
except: # noqa: E722
raise # only needed for `else` to work
else:
# We use closing here for py2-compat
with closing(socket.socket()) as sock:
# Check new sockets created are still above our target number
assert sock.fileno() >= resource_limit
finally:
# Stop our server
httpserver.stop()
if not IS_WINDOWS:
test_high_number_of_file_descriptors = pytest.mark.forked(
test_high_number_of_file_descriptors,
)
@pytest.fixture
def resource_limit(request):
"""Set the resource limit two times bigger then requested."""
resource = pytest.importorskip(
'resource',
reason='The "resource" module is Unix-specific',
)
# Get current resource limits to restore them later
soft_limit, hard_limit = resource.getrlimit(resource.RLIMIT_NOFILE)
# We have to increase the nofile limit above 1024
# Otherwise we see a 'Too many files open' error, instead of
# an error due to the file descriptor number being too high
resource.setrlimit(
resource.RLIMIT_NOFILE,
(request.param * 2, hard_limit),
)
try: # noqa: WPS501
yield request.param
finally:
# Reset the resource limit back to the original soft limit
resource.setrlimit(resource.RLIMIT_NOFILE, (soft_limit, hard_limit))
@pytest.fixture
def many_open_sockets(resource_limit):
"""Allocate a lot of file descriptors by opening dummy sockets."""
# Hoard a lot of file descriptors by opening and storing a lot of sockets
test_sockets = []
# Open a lot of file descriptors, so the next one the server
# opens is a high number
try:
for i in range(resource_limit):
sock = socket.socket()
test_sockets.append(sock)
# NOTE: We used to interrupt the loop early but this doesn't seem
# NOTE: to work well in envs with indeterministic runtimes like
# NOTE: PyPy. It looks like sometimes it frees some file
# NOTE: descriptors in between running this fixture and the actual
# NOTE: test code so the early break has been removed to try
# NOTE: address that. The approach may need to be rethought if the
# NOTE: issue reoccurs. Another approach may be disabling the GC.
# Check we opened enough descriptors to reach a high number
the_highest_fileno = max(sock.fileno() for sock in test_sockets)
assert the_highest_fileno >= resource_limit
yield the_highest_fileno
finally:
# Close our open resources
for test_socket in test_sockets:
test_socket.close()

View file

@ -1,4 +1,4 @@
"""Tests for TLS/SSL support."""
"""Tests for TLS support."""
# -*- coding: utf-8 -*-
# vim: set fileencoding=utf-8 :
@ -6,11 +6,14 @@ from __future__ import absolute_import, division, print_function
__metaclass__ = type
import functools
import json
import os
import ssl
import subprocess
import sys
import threading
import time
import traceback
import OpenSSL.SSL
import pytest
@ -21,7 +24,7 @@ import trustme
from .._compat import bton, ntob, ntou
from .._compat import IS_ABOVE_OPENSSL10, IS_PYPY
from .._compat import IS_LINUX, IS_MACOS, IS_WINDOWS
from ..server import Gateway, HTTPServer, get_ssl_adapter_class
from ..server import HTTPServer, get_ssl_adapter_class
from ..testing import (
ANY_INTERFACE_IPV4,
ANY_INTERFACE_IPV6,
@ -30,9 +33,17 @@ from ..testing import (
_get_conn_data,
_probe_ipv6_sock,
)
from ..wsgi import Gateway_10
IS_GITHUB_ACTIONS_WORKFLOW = bool(os.getenv('GITHUB_WORKFLOW'))
IS_WIN2016 = (
IS_WINDOWS
# pylint: disable=unsupported-membership-test
and b'Microsoft Windows Server 2016 Datacenter' in subprocess.check_output(
('systeminfo',),
)
)
IS_LIBRESSL_BACKEND = ssl.OPENSSL_VERSION.startswith('LibreSSL')
IS_PYOPENSSL_SSL_VERSION_1_0 = (
OpenSSL.SSL.SSLeay_version(OpenSSL.SSL.SSLEAY_VERSION).
@ -40,6 +51,7 @@ IS_PYOPENSSL_SSL_VERSION_1_0 = (
)
PY27 = sys.version_info[:2] == (2, 7)
PY34 = sys.version_info[:2] == (3, 4)
PY3 = not six.PY2
_stdlib_to_openssl_verify = {
@ -71,7 +83,7 @@ missing_ipv6 = pytest.mark.skipif(
)
class HelloWorldGateway(Gateway):
class HelloWorldGateway(Gateway_10):
"""Gateway responding with Hello World to root URI."""
def respond(self):
@ -83,11 +95,21 @@ class HelloWorldGateway(Gateway):
req.ensure_headers_sent()
req.write(b'Hello world!')
return
if req_uri == '/env':
req.status = b'200 OK'
req.ensure_headers_sent()
env = self.get_environ()
# drop files so that it can be json dumped
env.pop('wsgi.errors')
env.pop('wsgi.input')
print(env)
req.write(json.dumps(env).encode('utf-8'))
return
return super(HelloWorldGateway, self).respond()
def make_tls_http_server(bind_addr, ssl_adapter, request):
"""Create and start an HTTP server bound to bind_addr."""
"""Create and start an HTTP server bound to ``bind_addr``."""
httpserver = HTTPServer(
bind_addr=bind_addr,
gateway=HelloWorldGateway,
@ -128,7 +150,7 @@ def tls_ca_certificate_pem_path(ca):
def tls_certificate(ca):
"""Provide a leaf certificate via fixture."""
interface, host, port = _get_conn_data(ANY_INTERFACE_IPV4)
return ca.issue_server_cert(ntou(interface), )
return ca.issue_server_cert(ntou(interface))
@pytest.fixture
@ -145,6 +167,43 @@ def tls_certificate_private_key_pem_path(tls_certificate):
yield cert_key_pem
def _thread_except_hook(exceptions, args):
"""Append uncaught exception ``args`` in threads to ``exceptions``."""
if issubclass(args.exc_type, SystemExit):
return
# cannot store the exception, it references the thread's stack
exceptions.append((
args.exc_type,
str(args.exc_value),
''.join(
traceback.format_exception(
args.exc_type, args.exc_value, args.exc_traceback,
),
),
))
@pytest.fixture
def thread_exceptions():
"""Provide a list of uncaught exceptions from threads via a fixture.
Only catches exceptions on Python 3.8+.
The list contains: ``(type, str(value), str(traceback))``
"""
exceptions = []
# Python 3.8+
orig_hook = getattr(threading, 'excepthook', None)
if orig_hook is not None:
threading.excepthook = functools.partial(
_thread_except_hook, exceptions,
)
try:
yield exceptions
finally:
if orig_hook is not None:
threading.excepthook = orig_hook
@pytest.mark.parametrize(
'adapter_type',
(
@ -180,7 +239,7 @@ def test_ssl_adapters(
)
resp = requests.get(
'https://' + interface + ':' + str(port) + '/',
'https://{host!s}:{port!s}/'.format(host=interface, port=port),
verify=tls_ca_certificate_pem_path,
)
@ -188,7 +247,7 @@ def test_ssl_adapters(
assert resp.text == 'Hello world!'
@pytest.mark.parametrize(
@pytest.mark.parametrize( # noqa: C901 # FIXME
'adapter_type',
(
'builtin',
@ -196,7 +255,7 @@ def test_ssl_adapters(
),
)
@pytest.mark.parametrize(
'is_trusted_cert,tls_client_identity',
('is_trusted_cert', 'tls_client_identity'),
(
(True, 'localhost'), (True, '127.0.0.1'),
(True, '*.localhost'), (True, 'not_localhost'),
@ -211,7 +270,7 @@ def test_ssl_adapters(
ssl.CERT_REQUIRED, # server should validate if client cert CA is OK
),
)
def test_tls_client_auth(
def test_tls_client_auth( # noqa: C901 # FIXME
# FIXME: remove twisted logic, separate tests
mocker,
tls_http_server, adapter_type,
@ -265,7 +324,7 @@ def test_tls_client_auth(
make_https_request = functools.partial(
requests.get,
'https://' + interface + ':' + str(port) + '/',
'https://{host!s}:{port!s}/'.format(host=interface, port=port),
# Server TLS certificate verification:
verify=tls_ca_certificate_pem_path,
@ -324,21 +383,23 @@ def test_tls_client_auth(
except AttributeError:
if PY34:
pytest.xfail('OpenSSL behaves wierdly under Python 3.4')
elif not six.PY2 and IS_WINDOWS:
elif IS_WINDOWS or IS_GITHUB_ACTIONS_WORKFLOW:
err_text = str(ssl_err.value)
else:
raise
if isinstance(err_text, int):
err_text = str(ssl_err.value)
expected_substrings = (
'sslv3 alert bad certificate' if IS_LIBRESSL_BACKEND
else 'tlsv1 alert unknown ca',
)
if not six.PY2:
if IS_MACOS and IS_PYPY and adapter_type == 'pyopenssl':
expected_substrings = ('tlsv1 alert unknown ca', )
expected_substrings = ('tlsv1 alert unknown ca',)
if (
IS_WINDOWS
and tls_verify_mode in (
tls_verify_mode in (
ssl.CERT_REQUIRED,
ssl.CERT_OPTIONAL,
)
@ -350,10 +411,172 @@ def test_tls_client_auth(
"SysCallError(10054, 'WSAECONNRESET')",
"('Connection aborted.', "
'OSError("(10054, \'WSAECONNRESET\')"))',
"('Connection aborted.', "
'OSError("(10054, \'WSAECONNRESET\')",))',
"('Connection aborted.', "
'error("(10054, \'WSAECONNRESET\')",))',
"('Connection aborted.', "
'ConnectionResetError(10054, '
"'An existing connection was forcibly closed "
"by the remote host', None, 10054, None))",
) if IS_WINDOWS else (
"('Connection aborted.', "
'OSError("(104, \'ECONNRESET\')"))',
"('Connection aborted.', "
'OSError("(104, \'ECONNRESET\')",))',
"('Connection aborted.', "
'error("(104, \'ECONNRESET\')",))',
"('Connection aborted.', "
"ConnectionResetError(104, 'Connection reset by peer'))",
"('Connection aborted.', "
"error(104, 'Connection reset by peer'))",
) if (
IS_GITHUB_ACTIONS_WORKFLOW
and IS_LINUX
) else (
"('Connection aborted.', "
"BrokenPipeError(32, 'Broken pipe'))",
)
assert any(e in err_text for e in expected_substrings)
@pytest.mark.parametrize( # noqa: C901 # FIXME
'adapter_type',
(
'builtin',
'pyopenssl',
),
)
@pytest.mark.parametrize(
('tls_verify_mode', 'use_client_cert'),
(
(ssl.CERT_NONE, False),
(ssl.CERT_NONE, True),
(ssl.CERT_OPTIONAL, False),
(ssl.CERT_OPTIONAL, True),
(ssl.CERT_REQUIRED, True),
),
)
def test_ssl_env( # noqa: C901 # FIXME
thread_exceptions,
recwarn,
mocker,
tls_http_server, adapter_type,
ca, tls_verify_mode, tls_certificate,
tls_certificate_chain_pem_path,
tls_certificate_private_key_pem_path,
tls_ca_certificate_pem_path,
use_client_cert,
):
"""Test the SSL environment generated by the SSL adapters."""
interface, _host, port = _get_conn_data(ANY_INTERFACE_IPV4)
with mocker.mock_module.patch(
'idna.core.ulabel',
return_value=ntob('127.0.0.1'),
):
client_cert = ca.issue_cert(ntou('127.0.0.1'))
with client_cert.private_key_and_cert_chain_pem.tempfile() as cl_pem:
tls_adapter_cls = get_ssl_adapter_class(name=adapter_type)
tls_adapter = tls_adapter_cls(
tls_certificate_chain_pem_path,
tls_certificate_private_key_pem_path,
)
if adapter_type == 'pyopenssl':
tls_adapter.context = tls_adapter.get_context()
tls_adapter.context.set_verify(
_stdlib_to_openssl_verify[tls_verify_mode],
lambda conn, cert, errno, depth, preverify_ok: preverify_ok,
)
else:
tls_adapter.context.verify_mode = tls_verify_mode
ca.configure_trust(tls_adapter.context)
tls_certificate.configure_cert(tls_adapter.context)
tlswsgiserver = tls_http_server((interface, port), tls_adapter)
interface, _host, port = _get_conn_data(tlswsgiserver.bind_addr)
resp = requests.get(
'https://' + interface + ':' + str(port) + '/env',
verify=tls_ca_certificate_pem_path,
cert=cl_pem if use_client_cert else None,
)
if PY34 and resp.status_code != 200:
pytest.xfail(
'Python 3.4 has problems with verifying client certs',
)
env = json.loads(resp.content.decode('utf-8'))
# hard coded env
assert env['wsgi.url_scheme'] == 'https'
assert env['HTTPS'] == 'on'
# ensure these are present
for key in {'SSL_VERSION_INTERFACE', 'SSL_VERSION_LIBRARY'}:
assert key in env
# pyOpenSSL generates the env before the handshake completes
if adapter_type == 'pyopenssl':
return
for key in {'SSL_PROTOCOL', 'SSL_CIPHER'}:
assert key in env
# client certificate env
if tls_verify_mode == ssl.CERT_NONE or not use_client_cert:
assert env['SSL_CLIENT_VERIFY'] == 'NONE'
else:
assert env['SSL_CLIENT_VERIFY'] == 'SUCCESS'
with open(cl_pem, 'rt') as f:
assert env['SSL_CLIENT_CERT'] in f.read()
for key in {
'SSL_CLIENT_M_VERSION', 'SSL_CLIENT_M_SERIAL',
'SSL_CLIENT_I_DN', 'SSL_CLIENT_S_DN',
}:
assert key in env
# builtin ssl environment generation may use a loopback socket
# ensure no ResourceWarning was raised during the test
# NOTE: python 2.7 does not emit ResourceWarning for ssl sockets
if IS_PYPY:
# NOTE: PyPy doesn't have ResourceWarning
# Ref: https://doc.pypy.org/en/latest/cpython_differences.html
return
for warn in recwarn:
if not issubclass(warn.category, ResourceWarning):
continue
# the tests can sporadically generate resource warnings
# due to timing issues
# all of these sporadic warnings appear to be about socket.socket
# and have been observed to come from requests connection pool
msg = str(warn.message)
if 'socket.socket' in msg:
pytest.xfail(
'\n'.join((
'Sometimes this test fails due to '
'a socket.socket ResourceWarning:',
msg,
)),
)
pytest.fail(msg)
# to perform the ssl handshake over that loopback socket,
# the builtin ssl environment generation uses a thread
for _, _, trace in thread_exceptions:
print(trace, file=sys.stderr)
assert not thread_exceptions, ': '.join((
thread_exceptions[0][0].__name__,
thread_exceptions[0][1],
))
@pytest.mark.parametrize(
'ip_addr',
(
@ -382,7 +605,16 @@ def test_https_over_http_error(http_server, ip_addr):
@pytest.mark.parametrize(
'adapter_type',
(
pytest.param(
'builtin',
marks=pytest.mark.xfail(
IS_WINDOWS and six.PY2,
raises=requests.exceptions.ConnectionError,
reason='Stdlib `ssl` module behaves weirdly '
'on Windows under Python 2',
strict=False,
),
),
'pyopenssl',
),
)
@ -428,16 +660,41 @@ def test_http_over_https_error(
fqdn = interface
if ip_addr is ANY_INTERFACE_IPV6:
fqdn = '[{}]'.format(fqdn)
fqdn = '[{fqdn}]'.format(**locals())
expect_fallback_response_over_plain_http = (
(adapter_type == 'pyopenssl'
and (IS_ABOVE_OPENSSL10 or not six.PY2))
or PY27
(
adapter_type == 'pyopenssl'
and (IS_ABOVE_OPENSSL10 or not six.PY2)
)
or PY27
) or (
IS_GITHUB_ACTIONS_WORKFLOW
and IS_WINDOWS
and six.PY2
and not IS_WIN2016
)
if (
IS_GITHUB_ACTIONS_WORKFLOW
and IS_WINDOWS
and six.PY2
and IS_WIN2016
and adapter_type == 'builtin'
and ip_addr is ANY_INTERFACE_IPV6
):
expect_fallback_response_over_plain_http = True
if (
IS_GITHUB_ACTIONS_WORKFLOW
and IS_WINDOWS
and six.PY2
and not IS_WIN2016
and adapter_type == 'builtin'
and ip_addr is not ANY_INTERFACE_IPV6
):
expect_fallback_response_over_plain_http = False
if expect_fallback_response_over_plain_http:
resp = requests.get(
'http://' + fqdn + ':' + str(port) + '/',
'http://{host!s}:{port!s}/'.format(host=fqdn, port=port),
)
assert resp.status_code == 400
assert resp.text == (
@ -448,7 +705,7 @@ def test_http_over_https_error(
with pytest.raises(requests.exceptions.ConnectionError) as ssl_err:
requests.get( # FIXME: make stdlib ssl behave like PyOpenSSL
'http://' + fqdn + ':' + str(port) + '/',
'http://{host!s}:{port!s}/'.format(host=fqdn, port=port),
)
if IS_LINUX:
@ -468,7 +725,7 @@ def test_http_over_https_error(
underlying_error = ssl_err.value.args[0].args[-1]
err_text = str(underlying_error)
assert underlying_error.errno == expected_error_code, (
'The underlying error is {!r}'.
format(underlying_error)
'The underlying error is {underlying_error!r}'.
format(**locals())
)
assert expected_error_text in err_text

View file

@ -0,0 +1,58 @@
"""Test wsgi."""
from concurrent.futures.thread import ThreadPoolExecutor
import pytest
import portend
import requests
from requests_toolbelt.sessions import BaseUrlSession as Session
from jaraco.context import ExceptionTrap
from cheroot import wsgi
from cheroot._compat import IS_MACOS, IS_WINDOWS
IS_SLOW_ENV = IS_MACOS or IS_WINDOWS
@pytest.fixture
def simple_wsgi_server():
"""Fucking simple wsgi server fixture (duh)."""
port = portend.find_available_local_port()
def app(environ, start_response):
status = '200 OK'
response_headers = [('Content-type', 'text/plain')]
start_response(status, response_headers)
return [b'Hello world!']
host = '::'
addr = host, port
server = wsgi.Server(addr, app, timeout=600 if IS_SLOW_ENV else 20)
url = 'http://localhost:{port}/'.format(**locals())
with server._run_in_thread() as thread:
yield locals()
def test_connection_keepalive(simple_wsgi_server):
"""Test the connection keepalive works (duh)."""
session = Session(base_url=simple_wsgi_server['url'])
pooled = requests.adapters.HTTPAdapter(
pool_connections=1, pool_maxsize=1000,
)
session.mount('http://', pooled)
def do_request():
with ExceptionTrap(requests.exceptions.ConnectionError) as trap:
resp = session.get('info')
resp.raise_for_status()
return bool(trap)
with ThreadPoolExecutor(max_workers=10 if IS_SLOW_ENV else 50) as pool:
tasks = [
pool.submit(do_request)
for n in range(250 if IS_SLOW_ENV else 1000)
]
failures = sum(task.result() for task in tasks)
assert not failures

View file

@ -1,12 +1,14 @@
"""Extensions to unittest for web frameworks.
Use the WebCase.getPage method to request a page from your HTTP server.
Use the :py:meth:`WebCase.getPage` method to request a page
from your HTTP server.
Framework Integration
=====================
If you have control over your server process, you can handle errors
in the server-side of the HTTP conversation a bit better. You must run
both the client (your WebCase tests) and the server in the same process
(but in separate threads, obviously).
both the client (your :py:class:`WebCase` tests) and the server in the
same process (but in separate threads, obviously).
When an error occurs in the framework, call server_error. It will print
the traceback to stdout, and keep any assertions you have from running
(the assumption is that, if the server errors, the page output will not
@ -122,7 +124,7 @@ class WebCase(unittest.TestCase):
def _Conn(self):
"""Return HTTPConnection or HTTPSConnection based on self.scheme.
* from http.client.
* from :py:mod:`python:http.client`.
"""
cls_name = '{scheme}Connection'.format(scheme=self.scheme.upper())
return getattr(http_client, cls_name)
@ -157,7 +159,7 @@ class WebCase(unittest.TestCase):
@property
def persistent(self):
"""Presense of the persistent HTTP connection."""
"""Presence of the persistent HTTP connection."""
return hasattr(self.HTTP_CONN, '__class__')
@persistent.setter
@ -176,7 +178,9 @@ class WebCase(unittest.TestCase):
self, url, headers=None, method='GET', body=None,
protocol=None, raise_subcls=(),
):
"""Open the url with debugging support. Return status, headers, body.
"""Open the url with debugging support.
Return status, headers, body.
url should be the identifier passed to the server, typically a
server-absolute path and query string (sent between method and
@ -184,16 +188,16 @@ class WebCase(unittest.TestCase):
enabled in the server.
If the application under test generates absolute URIs, be sure
to wrap them first with strip_netloc::
to wrap them first with :py:func:`strip_netloc`::
class MyAppWebCase(WebCase):
def getPage(url, *args, **kwargs):
super(MyAppWebCase, self).getPage(
cheroot.test.webtest.strip_netloc(url),
*args, **kwargs
)
>>> class MyAppWebCase(WebCase):
... def getPage(url, *args, **kwargs):
... super(MyAppWebCase, self).getPage(
... cheroot.test.webtest.strip_netloc(url),
... *args, **kwargs
... )
`raise_subcls` is passed through to openURL.
``raise_subcls`` is passed through to :py:func:`openURL`.
"""
ServerError.on = False
@ -247,7 +251,7 @@ class WebCase(unittest.TestCase):
console_height = 30
def _handlewebError(self, msg):
def _handlewebError(self, msg): # noqa: C901 # FIXME
print('')
print(' ERROR: %s' % msg)
@ -487,7 +491,7 @@ def openURL(*args, **kwargs):
"""
Open a URL, retrying when it fails.
Specify `raise_subcls` (class or tuple of classes) to exclude
Specify ``raise_subcls`` (class or tuple of classes) to exclude
those socket.error subclasses from being suppressed and retried.
"""
raise_subcls = kwargs.pop('raise_subcls', ())
@ -553,7 +557,7 @@ def strip_netloc(url):
server-absolute portion.
Useful for wrapping an absolute-URI for which only the
path is expected (such as in calls to getPage).
path is expected (such as in calls to :py:meth:`WebCase.getPage`).
>>> strip_netloc('https://google.com/foo/bar?bing#baz')
'/foo/bar?bing'

View file

@ -61,14 +61,14 @@ def cheroot_server(server_factory):
httpserver.stop() # destroy it
@pytest.fixture(scope='module')
@pytest.fixture
def wsgi_server():
"""Set up and tear down a Cheroot WSGI server instance."""
for srv in cheroot_server(cheroot.wsgi.Server):
yield srv
@pytest.fixture(scope='module')
@pytest.fixture
def native_server():
"""Set up and tear down a Cheroot HTTP server instance."""
for srv in cheroot_server(cheroot.server.HTTPServer):

View file

@ -1,4 +1,9 @@
"""A thread-based worker pool."""
"""A thread-based worker pool.
.. spelling::
joinable
"""
from __future__ import absolute_import, division, print_function
__metaclass__ = type
@ -111,11 +116,6 @@ class WorkerThread(threading.Thread):
if conn is _SHUTDOWNREQUEST:
return
# Just close the connection and move on.
if conn.closeable:
conn.close()
continue
self.conn = conn
is_stats_enabled = self.server.stats['Enabled']
if is_stats_enabled:
@ -125,7 +125,7 @@ class WorkerThread(threading.Thread):
keep_conn_open = conn.communicate()
finally:
if keep_conn_open:
self.server.connections.put(conn)
self.server.put_conn(conn)
else:
conn.close()
if is_stats_enabled:
@ -176,7 +176,10 @@ class ThreadPool:
for i in range(self.min):
self._threads.append(WorkerThread(self.server))
for worker in self._threads:
worker.setName('CP Server ' + worker.getName())
worker.setName(
'CP Server {worker_name!s}'.
format(worker_name=worker.getName()),
)
worker.start()
for worker in self._threads:
while not worker.ready:
@ -192,7 +195,7 @@ class ThreadPool:
"""Put request into queue.
Args:
obj (cheroot.server.HTTPConnection): HTTP connection
obj (:py:class:`~cheroot.server.HTTPConnection`): HTTP connection
waiting to be processed
"""
self._queue.put(obj, block=True, timeout=self._queue_put_timeout)
@ -223,7 +226,10 @@ class ThreadPool:
def _spawn_worker(self):
worker = WorkerThread(self.server)
worker.setName('CP Server ' + worker.getName())
worker.setName(
'CP Server {worker_name!s}'.
format(worker_name=worker.getName()),
)
worker.start()
return worker

View file

@ -119,10 +119,7 @@ class Gateway(server.Gateway):
corresponding class
"""
return dict(
(gw.version, gw)
for gw in cls.__subclasses__()
)
return {gw.version: gw for gw in cls.__subclasses__()}
def get_environ(self):
"""Return a new environ dict targeting the given wsgi.version."""
@ -195,9 +192,9 @@ class Gateway(server.Gateway):
"""Cast status to bytes representation of current Python version.
According to :pep:`3333`, when using Python 3, the response status
and headers must be bytes masquerading as unicode; that is, they
and headers must be bytes masquerading as Unicode; that is, they
must be of type "str" but are restricted to code points in the
"latin-1" set.
"Latin-1" set.
"""
if six.PY2:
return status
@ -300,7 +297,11 @@ class Gateway_10(Gateway):
# Request headers
env.update(
('HTTP_' + bton(k).upper().replace('-', '_'), bton(v))
(
'HTTP_{header_name!s}'.
format(header_name=bton(k).upper().replace('-', '_')),
bton(v),
)
for k, v in req.inheaders.items()
)
@ -321,7 +322,7 @@ class Gateway_10(Gateway):
class Gateway_u0(Gateway_10):
"""A Gateway class to interface HTTPServer with WSGI u.0.
WSGI u.0 is an experimental protocol, which uses unicode for keys
WSGI u.0 is an experimental protocol, which uses Unicode for keys
and values in both Python 2 and Python 3.
"""
@ -409,7 +410,7 @@ class PathInfoDispatcher:
path = environ['PATH_INFO'] or '/'
for p, app in self.apps:
# The apps list should be sorted by length, descending.
if path.startswith(p + '/') or path == p:
if path.startswith('{path!s}/'.format(path=p)) or path == p:
environ = environ.copy()
environ['SCRIPT_NAME'] = environ.get('SCRIPT_NAME', '') + p
environ['PATH_INFO'] = path[len(p):]