mirror of https://github.com/Tautulli/Tautulli.git (synced 2025-07-06 13:11:15 -07:00)

Update cheroot-8.5.2

This commit is contained in: parent 4ac151d7de, commit 182e5f553e
25 changed files with 2171 additions and 602 deletions
@@ -1,13 +1,20 @@
# pylint: disable=unused-import
"""Compatibility code for using Cheroot with various versions of Python."""
from __future__ import absolute_import, division, print_function
__metaclass__ = type
import os
import platform
import re
import six
try:
import selectors  # lgtm [py/unused-import]
except ImportError:
import selectors2 as selectors  # noqa: F401 # lgtm [py/unused-import]
try:
import ssl
IS_ABOVE_OPENSSL10 = ssl.OPENSSL_VERSION_INFO >= (1, 1)

@@ -15,6 +22,24 @@ try:
except ImportError:
IS_ABOVE_OPENSSL10 = None
# contextlib.suppress was added in Python 3.4
try:
from contextlib import suppress
except ImportError:
from contextlib import contextmanager
@contextmanager
def suppress(*exceptions):
"""Return a context manager that suppresses the `exceptions`."""
try:
yield
except exceptions:
pass
IS_CI = bool(os.getenv('CI'))
IS_GITHUB_ACTIONS_WORKFLOW = bool(os.getenv('GITHUB_WORKFLOW'))
IS_PYPY = platform.python_implementation() == 'PyPy'

@@ -36,7 +61,7 @@ if not six.PY2:
return n.encode(encoding)
def ntou(n, encoding='ISO-8859-1'):
"""Return the native string as unicode with the given encoding."""
"""Return the native string as Unicode with the given encoding."""
assert_native(n)
# In Python 3, the native string type is unicode
return n

@@ -55,7 +80,7 @@ else:
return n
def ntou(n, encoding='ISO-8859-1'):
"""Return the native string as unicode with the given encoding."""
"""Return the native string as Unicode with the given encoding."""
assert_native(n)
# In Python 2, the native string type is bytes.
# First, check for the special encoding 'escape'. The test suite uses

@@ -78,7 +103,7 @@ else:
def assert_native(n):
"""Check whether the input is of nativ ``str`` type.
"""Check whether the input is of native :py:class:`str` type.
Raises:
TypeError: in case of failed check

@@ -89,22 +114,35 @@ def assert_native(n):
if not six.PY2:
"""Python 3 has memoryview builtin."""
"""Python 3 has :py:class:`memoryview` builtin."""
# Python 2.7 has it backported, but socket.write() does
# str(memoryview(b'0' * 100)) -> <memory at 0x7fb6913a5588>
# instead of accessing it correctly.
memoryview = memoryview
else:
"""Link memoryview to buffer under Python 2."""
"""Link :py:class:`memoryview` to buffer under Python 2."""
memoryview = buffer  # noqa: F821
def extract_bytes(mv):
"""Retrieve bytes out of memoryview/buffer or bytes."""
r"""Retrieve bytes out of the given input buffer.
:param mv: input :py:func:`buffer`
:type mv: memoryview or bytes
:return: unwrapped bytes
:rtype: bytes
:raises ValueError: if the input is not one of \
:py:class:`memoryview`/:py:func:`buffer` \
or :py:class:`bytes`
"""
if isinstance(mv, memoryview):
return bytes(mv) if six.PY2 else mv.tobytes()
if isinstance(mv, bytes):
return mv
raise ValueError
raise ValueError(
'extract_bytes() only accepts bytes and memoryview/buffer',
)
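For context, a minimal sketch of how the two compatibility helpers touched by this hunk behave, assuming the ``cheroot._compat`` module shown above is importable:

.. code-block:: python

    # Sketch only; relies on the helpers defined in cheroot._compat.
    from cheroot._compat import extract_bytes, memoryview, suppress

    # ``suppress`` swallows only the listed exceptions, matching
    # ``contextlib.suppress`` on Python 3.4+.
    with suppress(ZeroDivisionError):
        1 / 0  # silently ignored

    # ``extract_bytes`` unwraps a memoryview (or Py2 buffer) into bytes
    # and passes plain bytes through unchanged.
    assert extract_bytes(memoryview(b'abc')) == b'abc'
    assert extract_bytes(b'abc') == b'abc'
    # extract_bytes(u'abc') would raise the ValueError added above.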
@@ -1,36 +1,42 @@
"""Command line tool for starting a Cheroot WSGI/HTTP server instance.
Basic usage::
Basic usage:
# Start a server on 127.0.0.1:8000 with the default settings
# for the WSGI app myapp/wsgi.py:application()
cheroot myapp.wsgi
.. code-block:: shell-session
# Start a server on 0.0.0.0:9000 with 8 threads
# for the WSGI app myapp/wsgi.py:main_app()
cheroot myapp.wsgi:main_app --bind 0.0.0.0:9000 --threads 8
$ # Start a server on 127.0.0.1:8000 with the default settings
$ # for the WSGI app myapp/wsgi.py:application()
$ cheroot myapp.wsgi
# Start a server for the cheroot.server.Gateway subclass
# myapp/gateway.py:HTTPGateway
cheroot myapp.gateway:HTTPGateway
$ # Start a server on 0.0.0.0:9000 with 8 threads
$ # for the WSGI app myapp/wsgi.py:main_app()
$ cheroot myapp.wsgi:main_app --bind 0.0.0.0:9000 --threads 8
# Start a server on the UNIX socket /var/spool/myapp.sock
cheroot myapp.wsgi --bind /var/spool/myapp.sock
$ # Start a server for the cheroot.server.Gateway subclass
$ # myapp/gateway.py:HTTPGateway
$ cheroot myapp.gateway:HTTPGateway
# Start a server on the abstract UNIX socket CherootServer
cheroot myapp.wsgi --bind @CherootServer
$ # Start a server on the UNIX socket /var/spool/myapp.sock
$ cheroot myapp.wsgi --bind /var/spool/myapp.sock
$ # Start a server on the abstract UNIX socket CherootServer
$ cheroot myapp.wsgi --bind @CherootServer
.. spelling::
cli
"""
import argparse
from importlib import import_module
import os
import sys
import contextlib
import six
from . import server
from . import wsgi
from ._compat import suppress
__metaclass__ = type

@@ -49,6 +55,7 @@ class TCPSocket(BindLocation):
Args:
address (str): Host name or IP address
port (int): TCP port number
"""
self.bind_addr = address, port

@@ -64,9 +71,9 @@ class UnixSocket(BindLocation):
class AbstractSocket(BindLocation):
"""AbstractSocket."""
def __init__(self, addr):
def __init__(self, abstract_socket):
"""Initialize."""
self.bind_addr = '\0{}'.format(self.abstract_socket)
self.bind_addr = '\x00{sock_path}'.format(sock_path=abstract_socket)
class Application:

@@ -77,8 +84,8 @@ class Application:
"""Read WSGI app/Gateway path string and import application module."""
mod_path, _, app_path = full_path.partition(':')
app = getattr(import_module(mod_path), app_path or 'application')
with contextlib.suppress(TypeError):
# suppress the `TypeError` exception, just in case `app` is not a class
with suppress(TypeError):
if issubclass(app, server.Gateway):
return GatewayYo(app)

@@ -128,8 +135,17 @@ class GatewayYo:
def parse_wsgi_bind_location(bind_addr_string):
"""Convert bind address string to a BindLocation."""
# if the string begins with an @ symbol, use an abstract socket,
# this is the first condition to verify, otherwise the urlparse
# validation would detect //@<value> as a valid url with a hostname
# with value: "<value>" and port: None
if bind_addr_string.startswith('@'):
return AbstractSocket(bind_addr_string[1:])
# try and match for an IP/hostname and port
match = six.moves.urllib.parse.urlparse('//{}'.format(bind_addr_string))
match = six.moves.urllib.parse.urlparse(
'//{addr}'.format(addr=bind_addr_string),
)
try:
addr = match.hostname
port = match.port

@@ -139,9 +155,6 @@ def parse_wsgi_bind_location(bind_addr_string):
pass
# else, assume a UNIX socket path
# if the string begins with an @ symbol, use an abstract socket
if bind_addr_string.startswith('@'):
return AbstractSocket(bind_addr_string[1:])
return UnixSocket(path=bind_addr_string)
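A rough illustration of how the three address forms are dispatched by the parser above; this is a sketch, and the resulting attribute values assume the ``BindLocation`` classes defined earlier in this file:

.. code-block:: python

    # Sketch: how the supported --bind forms are classified.
    from cheroot.cli import parse_wsgi_bind_location

    loc = parse_wsgi_bind_location('@CherootServer')
    # -> AbstractSocket; bind_addr starts with a NUL byte ('\x00CherootServer')

    loc = parse_wsgi_bind_location('[::1]:8000')
    # -> TCPSocket; bind_addr == ('::1', 8000)

    loc = parse_wsgi_bind_location('/var/spool/myapp.sock')
    # -> UnixSocket wrapping the filesystem path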

@@ -151,70 +164,70 @@ def parse_wsgi_bind_addr(bind_addr_string):
_arg_spec = {
'_wsgi_app': dict(
metavar='APP_MODULE',
type=Application.resolve,
help='WSGI application callable or cheroot.server.Gateway subclass',
),
'--bind': dict(
metavar='ADDRESS',
dest='bind_addr',
type=parse_wsgi_bind_addr,
default='[::1]:8000',
help='Network interface to listen on (default: [::1]:8000)',
),
'--chdir': dict(
metavar='PATH',
type=os.chdir,
help='Set the working directory',
),
'--server-name': dict(
dest='server_name',
type=str,
help='Web server name to be advertised via Server HTTP header',
),
'--threads': dict(
metavar='INT',
dest='numthreads',
type=int,
help='Minimum number of worker threads',
),
'--max-threads': dict(
metavar='INT',
dest='max',
type=int,
help='Maximum number of worker threads',
),
'--timeout': dict(
metavar='INT',
dest='timeout',
type=int,
help='Timeout in seconds for accepted connections',
),
'--shutdown-timeout': dict(
metavar='INT',
dest='shutdown_timeout',
type=int,
help='Time in seconds to wait for worker threads to cleanly exit',
),
'--request-queue-size': dict(
metavar='INT',
dest='request_queue_size',
type=int,
help='Maximum number of queued connections',
),
'--accepted-queue-size': dict(
metavar='INT',
dest='accepted_queue_size',
type=int,
help='Maximum number of active requests in queue',
),
'--accepted-queue-timeout': dict(
metavar='INT',
dest='accepted_queue_timeout',
type=int,
help='Timeout in seconds for putting requests into queue',
),
'_wsgi_app': {
'metavar': 'APP_MODULE',
'type': Application.resolve,
'help': 'WSGI application callable or cheroot.server.Gateway subclass',
},
'--bind': {
'metavar': 'ADDRESS',
'dest': 'bind_addr',
'type': parse_wsgi_bind_addr,
'default': '[::1]:8000',
'help': 'Network interface to listen on (default: [::1]:8000)',
},
'--chdir': {
'metavar': 'PATH',
'type': os.chdir,
'help': 'Set the working directory',
},
'--server-name': {
'dest': 'server_name',
'type': str,
'help': 'Web server name to be advertised via Server HTTP header',
},
'--threads': {
'metavar': 'INT',
'dest': 'numthreads',
'type': int,
'help': 'Minimum number of worker threads',
},
'--max-threads': {
'metavar': 'INT',
'dest': 'max',
'type': int,
'help': 'Maximum number of worker threads',
},
'--timeout': {
'metavar': 'INT',
'dest': 'timeout',
'type': int,
'help': 'Timeout in seconds for accepted connections',
},
'--shutdown-timeout': {
'metavar': 'INT',
'dest': 'shutdown_timeout',
'type': int,
'help': 'Time in seconds to wait for worker threads to cleanly exit',
},
'--request-queue-size': {
'metavar': 'INT',
'dest': 'request_queue_size',
'type': int,
'help': 'Maximum number of queued connections',
},
'--accepted-queue-size': {
'metavar': 'INT',
'dest': 'accepted_queue_size',
'type': int,
'help': 'Maximum number of active requests in queue',
},
'--accepted-queue-timeout': {
'metavar': 'INT',
'dest': 'accepted_queue_timeout',
'type': int,
'help': 'Timeout in seconds for putting requests into queue',
},
}
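The dict-literal form above carries the same keyword arguments as the earlier ``dict()`` calls; each entry is ultimately expanded into an argparse option, roughly as in this simplified sketch (not the exact cheroot code):

.. code-block:: python

    # Simplified sketch of how ``_arg_spec`` is consumed.
    import argparse

    parser = argparse.ArgumentParser(
        description='Start an instance of the Cheroot WSGI/HTTP server.',
    )
    for arg_name, spec in _arg_spec.items():
        parser.add_argument(arg_name, **spec)
    # e.g. ``cheroot myapp.wsgi --bind [::1]:8000 --threads 8`` parses into
    # a Namespace with bind_addr, numthreads, etc., handed to the server.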

@@ -5,11 +5,13 @@ __metaclass__ = type
import io
import os
import select
import socket
import threading
import time
from . import errors
from ._compat import selectors
from ._compat import suppress
from .makefile import MakeFile
import six

@@ -47,6 +49,69 @@ else:
fcntl.fcntl(fd, fcntl.F_SETFD, old_flags | fcntl.FD_CLOEXEC)
class _ThreadsafeSelector:
"""Thread-safe wrapper around a DefaultSelector.
There are 2 thread contexts in which it may be accessed:
* the selector thread
* one of the worker threads in workers/threadpool.py
The expected read/write patterns are:
* :py:func:`~iter`: selector thread
* :py:meth:`register`: selector thread and threadpool,
via :py:meth:`~cheroot.workers.threadpool.ThreadPool.put`
* :py:meth:`unregister`: selector thread only
Notably, this means :py:class:`_ThreadsafeSelector` never needs to worry
that connections will be removed behind its back.
The lock is held when iterating or modifying the selector but is not
required when :py:meth:`select()ing <selectors.BaseSelector.select>` on it.
"""
def __init__(self):
self._selector = selectors.DefaultSelector()
self._lock = threading.Lock()
def __len__(self):
with self._lock:
return len(self._selector.get_map() or {})
@property
def connections(self):
"""Retrieve connections registered with the selector."""
with self._lock:
mapping = self._selector.get_map() or {}
for _, (_, sock_fd, _, conn) in mapping.items():
yield (sock_fd, conn)
def register(self, fileobj, events, data=None):
"""Register ``fileobj`` with the selector."""
with self._lock:
return self._selector.register(fileobj, events, data)
def unregister(self, fileobj):
"""Unregister ``fileobj`` from the selector."""
with self._lock:
return self._selector.unregister(fileobj)
def select(self, timeout=None):
"""Return socket fd and data pairs from selectors.select call.
Returns entries ready to read in the form:
(socket_file_descriptor, connection)
"""
return (
(key.fd, key.data)
for key, _ in self._selector.select(timeout=timeout)
)
def close(self):
"""Close the selector."""
with self._lock:
self._selector.close()
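A minimal sketch of the wrapper's intended use, with a hypothetical listening socket standing in for the real callers (``ConnectionManager.put()`` and the selector loop in ``_run()``):

.. code-block:: python

    # Minimal sketch of the locking wrapper around DefaultSelector.
    import socket
    from cheroot._compat import selectors
    from cheroot.connections import _ThreadsafeSelector

    sel = _ThreadsafeSelector()
    srv = socket.socket()
    srv.bind(('127.0.0.1', 0))
    srv.listen(5)

    # Worker threads may register while the selector thread iterates.
    sel.register(srv.fileno(), selectors.EVENT_READ, data=srv)
    print(len(sel))                       # -> 1 registered file object
    for fd, data in sel.select(timeout=0.01):
        pass                              # (fd, data) pairs ready to read
    sel.unregister(srv.fileno())
    sel.close()
    srv.close()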
class ConnectionManager:
"""Class which manages HTTPConnection objects.

@@ -60,21 +125,34 @@ class ConnectionManager:
server (cheroot.server.HTTPServer): web server object
that uses this ConnectionManager instance.
"""
self._serving = False
self._stop_requested = False
self.server = server
self.connections = []
self._selector = _ThreadsafeSelector()
self._selector.register(
server.socket.fileno(),
selectors.EVENT_READ, data=server,
)
def put(self, conn):
"""Put idle connection into the ConnectionManager to be managed.
Args:
conn (cheroot.server.HTTPConnection): HTTP connection
to be managed.
:param conn: HTTP connection to be managed
:type conn: cheroot.server.HTTPConnection
"""
conn.last_used = time.time()
conn.ready_with_data = conn.rfile.has_data()
self.connections.append(conn)
# if this conn doesn't have any more data waiting to be read,
# register it with the selector.
if conn.rfile.has_data():
self.server.process_conn(conn)
else:
self._selector.register(
conn.socket.fileno(), selectors.EVENT_READ, data=conn,
)
def expire(self):
def _expire(self):
"""Expire least recently used connections.
This happens if there are either too many open connections, or if the

@@ -82,107 +160,102 @@ class ConnectionManager:
This should be called periodically.
"""
if not self.connections:
return
# find any connections still registered with the selector
# that have not been active recently enough.
threshold = time.time() - self.server.timeout
timed_out_connections = [
(sock_fd, conn)
for (sock_fd, conn) in self._selector.connections
if conn != self.server and conn.last_used < threshold
]
for sock_fd, conn in timed_out_connections:
self._selector.unregister(sock_fd)
conn.close()
# Look at the first connection - if it can be closed, then do
# that, and wait for get_conn to return it.
conn = self.connections[0]
if conn.closeable:
return
def stop(self):
"""Stop the selector loop in run() synchronously.
# Too many connections?
ka_limit = self.server.keep_alive_conn_limit
if ka_limit is not None and len(self.connections) > ka_limit:
conn.closeable = True
return
May take up to half a second.
"""
self._stop_requested = True
while self._serving:
time.sleep(0.01)
# Connection too old?
if (conn.last_used + self.server.timeout) < time.time():
conn.closeable = True
return
def get_conn(self, server_socket):
"""Return a HTTPConnection object which is ready to be handled.
A connection returned by this method should be ready for a worker
to handle it. If there are no connections ready, None will be
returned.
Any connection returned by this method will need to be `put`
back if it should be examined again for another request.
def run(self, expiration_interval):
"""Run the connections selector indefinitely.
Args:
server_socket (socket.socket): Socket to listen to for new
connections.
Returns:
cheroot.server.HTTPConnection instance, or None.
expiration_interval (float): Interval, in seconds, at which
connections will be checked for expiration.
Connections that are ready to process are submitted via
self.server.process_conn()
Connections submitted for processing must be `put()`
back if they should be examined again for another request.
Can be shut down by calling `stop()`.
"""
# Grab file descriptors from sockets, but stop if we find a
# connection which is already marked as ready.
socket_dict = {}
for conn in self.connections:
if conn.closeable or conn.ready_with_data:
break
socket_dict[conn.socket.fileno()] = conn
else:
# No ready connection.
conn = None
# We have a connection ready for use.
if conn:
self.connections.remove(conn)
return conn
# Will require a select call.
ss_fileno = server_socket.fileno()
socket_dict[ss_fileno] = server_socket
self._serving = True
try:
rlist, _, _ = select.select(list(socket_dict), [], [], 0.1)
# No available socket.
if not rlist:
return None
self._run(expiration_interval)
finally:
self._serving = False
def _run(self, expiration_interval):
last_expiration_check = time.time()
while not self._stop_requested:
try:
active_list = self._selector.select(timeout=0.01)
except OSError:
# Mark any connection which no longer appears valid.
for fno, conn in list(socket_dict.items()):
# If the server socket is invalid, we'll just ignore it and
# wait to be shutdown.
if fno == ss_fileno:
self._remove_invalid_sockets()
continue
try:
os.fstat(fno)
except OSError:
# Socket is invalid, close the connection, insert at
# the front.
self.connections.remove(conn)
self.connections.insert(0, conn)
conn.closeable = True
# Wait for the next tick to occur.
return None
try:
# See if we have a new connection coming in.
rlist.remove(ss_fileno)
except ValueError:
# No new connection, but reuse existing socket.
conn = socket_dict[rlist.pop()]
for (sock_fd, conn) in active_list:
if conn is self.server:
# New connection
new_conn = self._from_server_socket(self.server.socket)
if new_conn is not None:
self.server.process_conn(new_conn)
else:
conn = server_socket
# unregister connection from the selector until the server
# has read from it and returned it via put()
self._selector.unregister(sock_fd)
self.server.process_conn(conn)
# All remaining connections in rlist should be marked as ready.
for fno in rlist:
socket_dict[fno].ready_with_data = True
now = time.time()
if (now - last_expiration_check) > expiration_interval:
self._expire()
last_expiration_check = now
# New connection.
if conn is server_socket:
return self._from_server_socket(server_socket)
def _remove_invalid_sockets(self):
"""Clean up the resources of any broken connections.
self.connections.remove(conn)
return conn
This method attempts to detect any connections in an invalid state,
unregisters them from the selector and closes the file descriptors of
the corresponding network sockets where possible.
"""
invalid_conns = []
for sock_fd, conn in self._selector.connections:
if conn is self.server:
continue
def _from_server_socket(self, server_socket):
try:
os.fstat(sock_fd)
except OSError:
invalid_conns.append((sock_fd, conn))
for sock_fd, conn in invalid_conns:
self._selector.unregister(sock_fd)
# One of the reason on why a socket could cause an error
# is that the socket is already closed, ignore the
# socket error if we try to close it at this point.
# This is equivalent to OSError in Py3
with suppress(socket.error):
conn.close()
def _from_server_socket(self, server_socket):  # noqa: C901  # FIXME
try:
s, addr = server_socket.accept()
if self.server.stats['Enabled']:

@@ -274,6 +347,23 @@ class ConnectionManager:
def close(self):
"""Close all monitored connections."""
for conn in self.connections[:]:
for (_, conn) in self._selector.connections:
if conn is not self.server:  # server closes its own socket
conn.close()
self.connections = []
self._selector.close()
@property
def _num_connections(self):
"""Return the current number of connections.
Includes all connections registered with the selector,
minus one for the server socket, which is always registered
with the selector.
"""
return len(self._selector) - 1
@property
def can_add_keepalive_connection(self):
"""Flag whether it is allowed to add a new keep-alive connection."""
ka_limit = self.server.keep_alive_conn_limit
return ka_limit is None or self._num_connections < ka_limit

@@ -1,3 +1,4 @@
# -*- coding: utf-8 -*-
"""Collection of exceptions raised and/or processed by Cheroot."""
from __future__ import absolute_import, division, print_function

@@ -23,14 +24,14 @@ class FatalSSLAlert(Exception):
def plat_specific_errors(*errnames):
"""Return error numbers for all errors in errnames on this platform.
"""Return error numbers for all errors in ``errnames`` on this platform.
The 'errno' module contains different global constants depending on
the specific platform (OS). This function will return the list of
numeric values for a given list of potential names.
The :py:mod:`errno` module contains different global constants
depending on the specific platform (OS). This function will return
the list of numeric values for a given list of potential names.
"""
missing_attr = set([None, ])
unique_nums = set(getattr(errno, k, None) for k in errnames)
missing_attr = {None}
unique_nums = {getattr(errno, k, None) for k in errnames}
return list(unique_nums - missing_attr)
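For instance, the helper simply maps whichever of the requested errno names exist on the running platform to their numeric values; an illustrative call (results vary per OS):

.. code-block:: python

    import errno
    from cheroot.errors import plat_specific_errors

    # Only names that exist in ``errno`` on this platform survive;
    # unknown names are silently dropped via the ``{None}`` set.
    nums = plat_specific_errors('EPROTOTYPE', 'EAGAIN', 'WSAEWOULDBLOCK')
    assert errno.EAGAIN in nums   # defined everywhere Python runs
    assert None not in nums       # missing names never leak through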

@@ -56,3 +57,32 @@ socket_errors_nonblocking = plat_specific_errors(
if sys.platform == 'darwin':
socket_errors_to_ignore.extend(plat_specific_errors('EPROTOTYPE'))
socket_errors_nonblocking.extend(plat_specific_errors('EPROTOTYPE'))
acceptable_sock_shutdown_error_codes = {
errno.ENOTCONN,
errno.EPIPE, errno.ESHUTDOWN,  # corresponds to BrokenPipeError in Python 3
errno.ECONNRESET,  # corresponds to ConnectionResetError in Python 3
}
"""Errors that may happen during the connection close sequence.
* ENOTCONN — client is no longer connected
* EPIPE — write on a pipe while the other end has been closed
* ESHUTDOWN — write on a socket which has been shutdown for writing
* ECONNRESET — connection is reset by the peer, we received a TCP RST packet
Refs:
* https://github.com/cherrypy/cheroot/issues/341#issuecomment-735884889
* https://bugs.python.org/issue30319
* https://bugs.python.org/issue30329
* https://github.com/python/cpython/commit/83a2c28
* https://github.com/python/cpython/blob/c39b52f/Lib/poplib.py#L297-L302
* https://docs.microsoft.com/windows/win32/api/winsock/nf-winsock-shutdown
"""
try:  # py3
acceptable_sock_shutdown_exceptions = (
BrokenPipeError, ConnectionResetError,
)
except NameError:  # py2
acceptable_sock_shutdown_exceptions = ()
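The exception tuple and the errno set are meant to be used together when shutting a socket down, roughly as the server code does later in this commit; a sketch assuming ``sock`` is a connected socket object:

.. code-block:: python

    # Sketch: tolerate the benign errors documented above during shutdown.
    import socket
    from cheroot import errors

    def best_effort_shutdown(sock):
        """Send a TCP FIN, ignoring peers that already went away."""
        try:
            sock.shutdown(socket.SHUT_RDWR)
        except errors.acceptable_sock_shutdown_exceptions:
            pass  # BrokenPipeError/ConnectionResetError on Python 3
        except socket.error as e:
            if e.errno not in errors.acceptable_sock_shutdown_error_codes:
                raise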

@@ -68,7 +68,7 @@ class MakeFile_PY2(getattr(socket, '_fileobject', object)):
self._refcount -= 1
def write(self, data):
"""Sendall for non-blocking sockets."""
"""Send entire data contents for non-blocking sockets."""
bytes_sent = 0
data_mv = memoryview(data)
payload_size = len(data_mv)

@@ -122,7 +122,7 @@ class MakeFile_PY2(getattr(socket, '_fileobject', object)):
# FauxSocket is no longer needed
del FauxSocket
if not _fileobject_uses_str_type:
if not _fileobject_uses_str_type:  # noqa: C901  # FIXME
def read(self, size=-1):
"""Read data from the socket to buffer."""
# Use max, disallow tiny reads in a loop as they are very

@@ -7,12 +7,13 @@ sticking incoming connections onto a Queue::
server = HTTPServer(...)
server.start()
-> while True:
tick()
# This blocks until a request comes in:
child = socket.accept()
-> serve()
while ready:
_connections.run()
while not stop_requested:
child = socket.accept()  # blocks until a request comes in
conn = HTTPConnection(child, ...)
server.requests.put(conn)
server.process_conn(conn)  # adds conn to threadpool
Worker threads are kept in a pool and poll the Queue, popping off and then
handling each connection in turn. Each connection can consist of an arbitrary

@@ -58,7 +59,6 @@ And now for a trivial doctest to exercise the test suite
>>> 'HTTPServer' in globals()
True
"""
from __future__ import absolute_import, division, print_function

@@ -74,6 +74,8 @@ import time
import traceback as traceback_
import logging
import platform
import contextlib
import threading
try:
from functools import lru_cache

@@ -93,6 +95,7 @@ from .makefile import MakeFile, StreamWriter
__all__ = (
'HTTPRequest', 'HTTPConnection', 'HTTPServer',
'HeaderReader', 'DropUnderscoreHeaderReader',
'SizeCheckWrapper', 'KnownLengthRFile', 'ChunkedRFile',
'Gateway', 'get_ssl_adapter_class',
)

@@ -156,7 +159,10 @@ EMPTY = b''
ASTERISK = b'*'
FORWARD_SLASH = b'/'
QUOTED_SLASH = b'%2F'
QUOTED_SLASH_REGEX = re.compile(b'(?i)' + QUOTED_SLASH)
QUOTED_SLASH_REGEX = re.compile(b''.join((b'(?i)', QUOTED_SLASH)))
_STOPPING_FOR_INTERRUPT = object()  # sentinel used during shutdown
comma_separated_headers = [

@@ -179,7 +185,7 @@ class HeaderReader:
Interface and default implementation.
"""
def __call__(self, rfile, hdict=None):
def __call__(self, rfile, hdict=None):  # noqa: C901  # FIXME
"""
Read headers from the given stream into the given header dict.

@@ -248,15 +254,14 @@ class DropUnderscoreHeaderReader(HeaderReader):
class SizeCheckWrapper:
"""Wraps a file-like object, raising MaxSizeExceeded if too large."""
"""Wraps a file-like object, raising MaxSizeExceeded if too large.
:param rfile: ``file`` of a limited size
:param int maxlen: maximum length of the file being read
"""
def __init__(self, rfile, maxlen):
"""Initialize SizeCheckWrapper instance.
Args:
rfile (file): file of a limited size
maxlen (int): maximum length of the file being read
"""
"""Initialize SizeCheckWrapper instance."""
self.rfile = rfile
self.maxlen = maxlen
self.bytes_read = 0

@@ -266,14 +271,12 @@ class SizeCheckWrapper:
raise errors.MaxSizeExceeded()
def read(self, size=None):
"""Read a chunk from rfile buffer and return it.
"""Read a chunk from ``rfile`` buffer and return it.
Args:
size (int): amount of data to read
Returns:
bytes: Chunk from rfile, limited by size if specified.
:param int size: amount of data to read
:returns: chunk from ``rfile``, limited by size if specified
:rtype: bytes
"""
data = self.rfile.read(size)
self.bytes_read += len(data)

@@ -281,14 +284,12 @@ class SizeCheckWrapper:
return data
def readline(self, size=None):
"""Read a single line from rfile buffer and return it.
"""Read a single line from ``rfile`` buffer and return it.
Args:
size (int): minimum amount of data to read
Returns:
bytes: One line from rfile.
:param int size: minimum amount of data to read
:returns: one line from ``rfile``
:rtype: bytes
"""
if size is not None:
data = self.rfile.readline(size)

@@ -309,14 +310,12 @@ class SizeCheckWrapper:
return EMPTY.join(res)
def readlines(self, sizehint=0):
"""Read all lines from rfile buffer and return them.
"""Read all lines from ``rfile`` buffer and return them.
Args:
sizehint (int): hint of minimum amount of data to read
Returns:
list[bytes]: Lines of bytes read from rfile.
:param int sizehint: hint of minimum amount of data to read
:returns: lines of bytes read from ``rfile``
:rtype: list[bytes]
"""
# Shamelessly stolen from StringIO
total = 0

@@ -331,7 +330,7 @@ class SizeCheckWrapper:
return lines
def close(self):
"""Release resources allocated for rfile."""
"""Release resources allocated for ``rfile``."""
self.rfile.close()
def __iter__(self):

@@ -349,28 +348,24 @@ class SizeCheckWrapper:
class KnownLengthRFile:
"""Wraps a file-like object, returning an empty string when exhausted."""
"""Wraps a file-like object, returning an empty string when exhausted.
:param rfile: ``file`` of a known size
:param int content_length: length of the file being read
"""
def __init__(self, rfile, content_length):
"""Initialize KnownLengthRFile instance.
Args:
rfile (file): file of a known size
content_length (int): length of the file being read
"""
"""Initialize KnownLengthRFile instance."""
self.rfile = rfile
self.remaining = content_length
def read(self, size=None):
"""Read a chunk from rfile buffer and return it.
"""Read a chunk from ``rfile`` buffer and return it.
Args:
size (int): amount of data to read
Returns:
bytes: Chunk from rfile, limited by size if specified.
:param int size: amount of data to read
:rtype: bytes
:returns: chunk from ``rfile``, limited by size if specified
"""
if self.remaining == 0:
return b''

@@ -384,14 +379,12 @@ class KnownLengthRFile:
return data
def readline(self, size=None):
"""Read a single line from rfile buffer and return it.
"""Read a single line from ``rfile`` buffer and return it.
Args:
size (int): minimum amount of data to read
Returns:
bytes: One line from rfile.
:param int size: minimum amount of data to read
:returns: one line from ``rfile``
:rtype: bytes
"""
if self.remaining == 0:
return b''

@@ -405,14 +398,12 @@ class KnownLengthRFile:
return data
def readlines(self, sizehint=0):
"""Read all lines from rfile buffer and return them.
"""Read all lines from ``rfile`` buffer and return them.
Args:
sizehint (int): hint of minimum amount of data to read
Returns:
list[bytes]: Lines of bytes read from rfile.
:param int sizehint: hint of minimum amount of data to read
:returns: lines of bytes read from ``rfile``
:rtype: list[bytes]
"""
# Shamelessly stolen from StringIO
total = 0

@@ -427,7 +418,7 @@ class KnownLengthRFile:
return lines
def close(self):
"""Release resources allocated for rfile."""
"""Release resources allocated for ``rfile``."""
self.rfile.close()
def __iter__(self):

@@ -449,16 +440,14 @@ class ChunkedRFile:
This class is intended to provide a conforming wsgi.input value for
request entities that have been encoded with the 'chunked' transfer
encoding.
:param rfile: file encoded with the 'chunked' transfer encoding
:param int maxlen: maximum length of the file being read
:param int bufsize: size of the buffer used to read the file
"""
def __init__(self, rfile, maxlen, bufsize=8192):
"""Initialize ChunkedRFile instance.
Args:
rfile (file): file encoded with the 'chunked' transfer encoding
maxlen (int): maximum length of the file being read
bufsize (int): size of the buffer used to read the file
"""
"""Initialize ChunkedRFile instance."""
self.rfile = rfile
self.maxlen = maxlen
self.bytes_read = 0

@@ -484,7 +473,10 @@ class ChunkedRFile:
chunk_size = line.pop(0)
chunk_size = int(chunk_size, 16)
except ValueError:
raise ValueError('Bad chunked transfer size: ' + repr(chunk_size))
raise ValueError(
'Bad chunked transfer size: {chunk_size!r}'.
format(chunk_size=chunk_size),
)
if chunk_size <= 0:
self.closed = True

@@ -507,14 +499,12 @@ class ChunkedRFile:
)
def read(self, size=None):
"""Read a chunk from rfile buffer and return it.
"""Read a chunk from ``rfile`` buffer and return it.
Args:
size (int): amount of data to read
Returns:
bytes: Chunk from rfile, limited by size if specified.
:param int size: amount of data to read
:returns: chunk from ``rfile``, limited by size if specified
:rtype: bytes
"""
data = EMPTY

@@ -540,14 +530,12 @@ class ChunkedRFile:
self.buffer = EMPTY
def readline(self, size=None):
"""Read a single line from rfile buffer and return it.
"""Read a single line from ``rfile`` buffer and return it.
Args:
size (int): minimum amount of data to read
Returns:
bytes: One line from rfile.
:param int size: minimum amount of data to read
:returns: one line from ``rfile``
:rtype: bytes
"""
data = EMPTY

@@ -583,14 +571,12 @@ class ChunkedRFile:
self.buffer = self.buffer[newline_pos:]
def readlines(self, sizehint=0):
"""Read all lines from rfile buffer and return them.
"""Read all lines from ``rfile`` buffer and return them.
Args:
sizehint (int): hint of minimum amount of data to read
Returns:
list[bytes]: Lines of bytes read from rfile.
:param int sizehint: hint of minimum amount of data to read
:returns: lines of bytes read from ``rfile``
:rtype: list[bytes]
"""
# Shamelessly stolen from StringIO
total = 0

@@ -635,7 +621,7 @@ class ChunkedRFile:
yield line
def close(self):
"""Release resources allocated for rfile."""
"""Release resources allocated for ``rfile``."""
self.rfile.close()

@@ -744,7 +730,7 @@ class HTTPRequest:
self.ready = True
def read_request_line(self):
def read_request_line(self):  # noqa: C901  # FIXME
"""Read and parse first line of the HTTP request.
Returns:

@@ -845,7 +831,7 @@ class HTTPRequest:
# `urlsplit()` above parses "example.com:3128" as path part of URI.
# this is a workaround, which makes it detect netloc correctly
uri_split = urllib.parse.urlsplit(b'//' + uri)
uri_split = urllib.parse.urlsplit(b''.join((b'//', uri)))
_scheme, _authority, _path, _qs, _fragment = uri_split
_port = EMPTY
try:

@@ -975,8 +961,14 @@ class HTTPRequest:
return True
def read_request_headers(self):
"""Read self.rfile into self.inheaders. Return success."""
def read_request_headers(self):  # noqa: C901  # FIXME
"""Read ``self.rfile`` into ``self.inheaders``.
Ref: :py:attr:`self.inheaders <HTTPRequest.outheaders>`.
:returns: success status
:rtype: bool
"""
# then all the http headers
try:
self.header_reader(self.rfile, self.inheaders)

@@ -1054,8 +1046,10 @@ class HTTPRequest:
# Don't use simple_response here, because it emits headers
# we don't want. See
# https://github.com/cherrypy/cherrypy/issues/951
msg = self.server.protocol.encode('ascii')
msg += b' 100 Continue\r\n\r\n'
msg = b''.join((
self.server.protocol.encode('ascii'), SPACE, b'100 Continue',
CRLF, CRLF,
))
try:
self.conn.wfile.write(msg)
except socket.error as ex:

@@ -1138,10 +1132,11 @@ class HTTPRequest:
else:
self.conn.wfile.write(chunk)
def send_headers(self):
def send_headers(self):  # noqa: C901  # FIXME
"""Assert, process, and send the HTTP response message-headers.
You must set self.status, and self.outheaders before calling this.
You must set ``self.status``, and :py:attr:`self.outheaders
<HTTPRequest.outheaders>` before calling this.
"""
hkeys = [key.lower() for key, value in self.outheaders]
status = int(self.status[:3])

@@ -1168,6 +1163,12 @@ class HTTPRequest:
# Closing the conn is the only way to determine len.
self.close_connection = True
# Override the decision to not close the connection if the connection
# manager doesn't have space for it.
if not self.close_connection:
can_keep = self.server.can_add_keepalive_connection
self.close_connection = not can_keep
if b'connection' not in hkeys:
if self.response_protocol == 'HTTP/1.1':
# Both server and client are HTTP/1.1 or better

@@ -1178,6 +1179,14 @@ class HTTPRequest:
if not self.close_connection:
self.outheaders.append((b'Connection', b'Keep-Alive'))
if (b'Connection', b'Keep-Alive') in self.outheaders:
self.outheaders.append((
b'Keep-Alive',
u'timeout={connection_timeout}'.
format(connection_timeout=self.server.timeout).
encode('ISO-8859-1'),
))
if (not self.close_connection) and (not self.chunked_read):
# Read any remaining request body data on the socket.
# "If an origin server receives a request that does not include an

@@ -1228,9 +1237,7 @@ class HTTPConnection:
peercreds_resolve_enabled = False
# Fields set by ConnectionManager.
closeable = False
last_used = None
ready_with_data = False
def __init__(self, server, sock, makefile=MakeFile):
"""Initialize HTTPConnection instance.

@@ -1259,7 +1266,7 @@ class HTTPConnection:
lru_cache(maxsize=1)(self.get_peer_creds)
)
def communicate(self):
def communicate(self):  # noqa: C901  # FIXME
"""Read each request and respond appropriately.
Returns true if the connection should be kept open.

@@ -1351,6 +1358,9 @@ class HTTPConnection:
if not self.linger:
self._close_kernel_socket()
# close the socket file descriptor
# (will be closed in the OS if there is no
# other reference to the underlying socket)
self.socket.close()
else:
# On the other hand, sometimes we want to hang around for a bit

@@ -1426,12 +1436,12 @@ class HTTPConnection:
return gid
def resolve_peer_creds(self):  # LRU cached on per-instance basis
"""Return the username and group tuple of the peercreds if available.
"""Look up the username and group tuple of the ``PEERCREDS``.
Raises:
NotImplementedError: in case of unsupported OS
RuntimeError: in case of UID/GID lookup unsupported or disabled
:returns: the username and group tuple of the ``PEERCREDS``
:raises NotImplementedError: if the OS is unsupported
:raises RuntimeError: if UID/GID lookup is unsupported or disabled
"""
if not IS_UID_GID_RESOLVABLE:
raise NotImplementedError(

@@ -1462,18 +1472,20 @@ class HTTPConnection:
return group
def _close_kernel_socket(self):
"""Close kernel socket in outdated Python versions.
"""Terminate the connection at the transport level."""
# Honor ``sock_shutdown`` for PyOpenSSL connections.
shutdown = getattr(
self.socket, 'sock_shutdown',
self.socket.shutdown,
)
On old Python versions,
Python's socket module does NOT call close on the kernel
socket when you call socket.close(). We do so manually here
because we want this server to send a FIN TCP segment
immediately. Note this must be called *before* calling
socket.close(), because the latter drops its reference to
the kernel socket.
"""
if six.PY2 and hasattr(self.socket, '_sock'):
self.socket._sock.close()
try:
shutdown(socket.SHUT_RDWR)  # actually send a TCP FIN
except errors.acceptable_sock_shutdown_exceptions:
pass
except socket.error as e:
if e.errno not in errors.acceptable_sock_shutdown_error_codes:
raise
class HTTPServer:

@@ -1515,7 +1527,12 @@ class HTTPServer:
timeout = 10
"""The timeout in seconds for accepted connections (default 10)."""
version = 'Cheroot/' + __version__
expiration_interval = 0.5
"""The interval, in seconds, at which the server checks for
expired connections (default 0.5).
"""
version = 'Cheroot/{version!s}'.format(version=__version__)
"""A version string for the HTTPServer."""
software = None

@@ -1540,16 +1557,23 @@ class HTTPServer:
"""The class to use for handling HTTP connections."""
ssl_adapter = None
"""An instance of ssl.Adapter (or a subclass).
"""An instance of ``ssl.Adapter`` (or a subclass).
You must have the corresponding SSL driver library installed.
Ref: :py:class:`ssl.Adapter <cheroot.ssl.Adapter>`.
You must have the corresponding TLS driver library installed.
"""
peercreds_enabled = False
"""If True, peer cred lookup can be performed via UNIX domain socket."""
"""
If :py:data:`True`, peer creds will be looked up via UNIX domain socket.
"""
peercreds_resolve_enabled = False
"""If True, username/group will be looked up in the OS from peercreds."""
"""
If :py:data:`True`, username/group will be looked up in the OS from
``PEERCREDS``-provided IDs.
"""
keep_alive_conn_limit = 10
"""The maximum number of waiting keep-alive connections that will be kept open.

@@ -1577,7 +1601,6 @@ class HTTPServer:
self.requests = threadpool.ThreadPool(
self, min=minthreads or 1, max=maxthreads,
)
self.connections = connections.ConnectionManager(self)
if not server_name:
server_name = self.version

@@ -1603,25 +1626,29 @@ class HTTPServer:
'Threads Idle': lambda s: getattr(self.requests, 'idle', None),
'Socket Errors': 0,
'Requests': lambda s: (not s['Enabled']) and -1 or sum(
[w['Requests'](w) for w in s['Worker Threads'].values()], 0,
(w['Requests'](w) for w in s['Worker Threads'].values()), 0,
),
'Bytes Read': lambda s: (not s['Enabled']) and -1 or sum(
[w['Bytes Read'](w) for w in s['Worker Threads'].values()], 0,
(w['Bytes Read'](w) for w in s['Worker Threads'].values()), 0,
),
'Bytes Written': lambda s: (not s['Enabled']) and -1 or sum(
[w['Bytes Written'](w) for w in s['Worker Threads'].values()],
(w['Bytes Written'](w) for w in s['Worker Threads'].values()),
0,
),
'Work Time': lambda s: (not s['Enabled']) and -1 or sum(
[w['Work Time'](w) for w in s['Worker Threads'].values()], 0,
(w['Work Time'](w) for w in s['Worker Threads'].values()), 0,
),
'Read Throughput': lambda s: (not s['Enabled']) and -1 or sum(
[w['Bytes Read'](w) / (w['Work Time'](w) or 1e-6)
for w in s['Worker Threads'].values()], 0,
(
w['Bytes Read'](w) / (w['Work Time'](w) or 1e-6)
for w in s['Worker Threads'].values()
), 0,
),
'Write Throughput': lambda s: (not s['Enabled']) and -1 or sum(
[w['Bytes Written'](w) / (w['Work Time'](w) or 1e-6)
for w in s['Worker Threads'].values()], 0,
(
w['Bytes Written'](w) / (w['Work Time'](w) or 1e-6)
for w in s['Worker Threads'].values()
), 0,
),
'Worker Threads': {},
}
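The only change in this hunk is feeding ``sum()`` a generator expression instead of a list comprehension, which avoids building an intermediate list; the second argument remains the start value. A tiny illustration with hypothetical worker entries:

.. code-block:: python

    # Equivalent results; the generator form skips the intermediate list.
    workers = {'a': {'Requests': lambda w: 3}, 'b': {'Requests': lambda w: 4}}
    total_list = sum([w['Requests'](w) for w in workers.values()], 0)
    total_gen = sum((w['Requests'](w) for w in workers.values()), 0)
    assert total_list == total_gen == 7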

@@ -1645,17 +1672,27 @@ class HTTPServer:
def bind_addr(self):
"""Return the interface on which to listen for connections.
For TCP sockets, a (host, port) tuple. Host values may be any IPv4
or IPv6 address, or any valid hostname. The string 'localhost' is a
synonym for '127.0.0.1' (or '::1', if your hosts file prefers IPv6).
The string '0.0.0.0' is a special IPv4 entry meaning "any active
interface" (INADDR_ANY), and '::' is the similar IN6ADDR_ANY for
IPv6. The empty string or None are not allowed.
For TCP sockets, a (host, port) tuple. Host values may be any
:term:`IPv4` or :term:`IPv6` address, or any valid hostname.
The string 'localhost' is a synonym for '127.0.0.1' (or '::1',
if your hosts file prefers :term:`IPv6`).
The string '0.0.0.0' is a special :term:`IPv4` entry meaning
"any active interface" (INADDR_ANY), and '::' is the similar
IN6ADDR_ANY for :term:`IPv6`.
The empty string or :py:data:`None` are not allowed.
For UNIX sockets, supply the filename as a string.
For UNIX sockets, supply the file name as a string.
Systemd socket activation is automatic and doesn't require tempering
with this variable.
.. glossary::
:abbr:`IPv4 (Internet Protocol version 4)`
Internet Protocol version 4
:abbr:`IPv6 (Internet Protocol version 6)`
Internet Protocol version 6
"""
return self._bind_addr

@@ -1695,7 +1732,7 @@ class HTTPServer:
self.stop()
raise
def prepare(self):
def prepare(self):  # noqa: C901  # FIXME
"""Prepare server to serving requests.
It binds a socket's port, setups the socket to ``listen()`` and does

@@ -1757,6 +1794,9 @@ class HTTPServer:
self.socket.settimeout(1)
self.socket.listen(self.request_queue_size)
# must not be accessed once stop() has been called
self._connections = connections.ConnectionManager(self)
# Create worker threads
self.requests.start()

@@ -1765,20 +1805,21 @@ class HTTPServer:
def serve(self):
"""Serve requests, after invoking :func:`prepare()`."""
while self.ready:
while self.ready and not self.interrupt:
try:
self.tick()
self._connections.run(self.expiration_interval)
except (KeyboardInterrupt, SystemExit):
raise
except Exception:
self.error_log(
'Error in HTTPServer.tick', level=logging.ERROR,
'Error in HTTPServer.serve', level=logging.ERROR,
traceback=True,
)
# raise exceptions reported by any worker threads,
# such that the exception is raised from the serve() thread.
if self.interrupt:
while self.interrupt is True:
# Wait for self.stop() to complete. See _set_interrupt.
while self._stopping_for_interrupt:
time.sleep(0.1)
if self.interrupt:
raise self.interrupt

@@ -1795,6 +1836,31 @@ class HTTPServer:
self.prepare()
self.serve()
@contextlib.contextmanager
def _run_in_thread(self):
"""Context manager for running this server in a thread."""
self.prepare()
thread = threading.Thread(target=self.serve)
thread.setDaemon(True)
thread.start()
try:
yield thread
finally:
self.stop()
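A hedged sketch of how the new helper is meant to be used (primarily by a test harness); it assumes a ``cheroot.wsgi.Server`` instance and relies on the private ``_run_in_thread()`` added above:

.. code-block:: python

    # Sketch: run a Cheroot server in a background thread for a test.
    from cheroot import wsgi

    def app(environ, start_response):
        start_response('200 OK', [('Content-Type', 'text/plain')])
        return [b'hello']

    server = wsgi.Server(('127.0.0.1', 0), app)
    with server._run_in_thread() as thread:
        # prepare() has run, the port is bound and serve() loops in
        # ``thread``; issue test requests against server.bind_addr here.
        print(server.bind_addr)
    # on exit the context manager calls server.stop()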
|
||||
|
||||
@property
|
||||
def can_add_keepalive_connection(self):
|
||||
"""Flag whether it is allowed to add a new keep-alive connection."""
|
||||
return self.ready and self._connections.can_add_keepalive_connection
|
||||
|
||||
def put_conn(self, conn):
|
||||
"""Put an idle connection back into the ConnectionManager."""
|
||||
if self.ready:
|
||||
self._connections.put(conn)
|
||||
else:
|
||||
# server is shutting down, just close it
|
||||
conn.close()
|
||||
|
||||
def error_log(self, msg='', level=20, traceback=False):
|
||||
"""Write error message to log.
|
||||
|
||||
|
@ -1804,7 +1870,7 @@ class HTTPServer:
|
|||
traceback (bool): add traceback to output or not
|
||||
"""
|
||||
# Override this in subclasses as desired
|
||||
sys.stderr.write(msg + '\n')
|
||||
sys.stderr.write('{msg!s}\n'.format(msg=msg))
|
||||
sys.stderr.flush()
|
||||
if traceback:
|
||||
tblines = traceback_.format_exc()
|
||||
|
@ -1822,7 +1888,7 @@ class HTTPServer:
|
|||
self.bind_addr = self.resolve_real_bind_addr(sock)
|
||||
return sock
|
||||
|
||||
def bind_unix_socket(self, bind_addr):
|
||||
def bind_unix_socket(self, bind_addr): # noqa: C901 # FIXME
|
||||
"""Create (or recreate) a UNIX socket object."""
|
||||
if IS_WINDOWS:
|
||||
"""
|
||||
|
@ -1965,7 +2031,7 @@ class HTTPServer:
|
|||
|
||||
@staticmethod
|
||||
def resolve_real_bind_addr(socket_):
|
||||
"""Retrieve actual bind addr from bound socket."""
|
||||
"""Retrieve actual bind address from bound socket."""
|
||||
# FIXME: keep requested bind_addr separate real bound_addr (port
|
||||
# is different in case of ephemeral port 0)
|
||||
bind_addr = socket_.getsockname()
|
||||
|
@ -1985,40 +2051,49 @@ class HTTPServer:
|
|||
|
||||
return bind_addr
|
||||
|
||||
def tick(self):
|
||||
"""Accept a new connection and put it on the Queue."""
|
||||
if not self.ready:
|
||||
return
|
||||
|
||||
conn = self.connections.get_conn(self.socket)
|
||||
if conn:
|
||||
def process_conn(self, conn):
|
||||
"""Process an incoming HTTPConnection."""
|
||||
try:
|
||||
self.requests.put(conn)
|
||||
except queue.Full:
|
||||
# Just drop the conn. TODO: write 503 back?
|
||||
conn.close()
|
||||
|
||||
self.connections.expire()
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
"""Flag interrupt of the server."""
|
||||
return self._interrupt
|
||||
|
||||
@property
|
||||
def _stopping_for_interrupt(self):
|
||||
"""Return whether the server is responding to an interrupt."""
|
||||
return self._interrupt is _STOPPING_FOR_INTERRUPT
|
||||
|
||||
@interrupt.setter
|
||||
def interrupt(self, interrupt):
|
||||
"""Perform the shutdown of this server and save the exception."""
|
||||
self._interrupt = True
|
||||
"""Perform the shutdown of this server and save the exception.
|
||||
|
||||
Typically invoked by a worker thread in
|
||||
:py:mod:`~cheroot.workers.threadpool`, the exception is raised
|
||||
from the thread running :py:meth:`serve` once :py:meth:`stop`
|
||||
has completed.
|
||||
"""
|
||||
self._interrupt = _STOPPING_FOR_INTERRUPT
|
||||
self.stop()
|
||||
self._interrupt = interrupt
|
||||
|
||||
def stop(self):
|
||||
def stop(self): # noqa: C901 # FIXME
|
||||
"""Gracefully shutdown a server that is serving forever."""
|
||||
if not self.ready:
|
||||
return # already stopped
|
||||
|
||||
self.ready = False
|
||||
if self._start_time is not None:
|
||||
self._run_time += (time.time() - self._start_time)
|
||||
self._start_time = None
|
||||
|
||||
self._connections.stop()
|
||||
|
||||
sock = getattr(self, 'socket', None)
|
||||
if sock:
|
||||
if not isinstance(
|
||||
|
@ -2060,7 +2135,7 @@ class HTTPServer:
|
|||
sock.close()
|
||||
self.socket = None
|
||||
|
||||
self.connections.close()
|
||||
self._connections.close()
|
||||
self.requests.stop(self.shutdown_timeout)
|
||||
|
||||
|
||||
|
@ -2108,7 +2183,9 @@ def get_ssl_adapter_class(name='builtin'):
|
|||
try:
|
||||
adapter = getattr(mod, attr_name)
|
||||
except AttributeError:
|
||||
raise AttributeError("'%s' object has no attribute '%s'"
|
||||
% (mod_path, attr_name))
|
||||
raise AttributeError(
|
||||
"'%s' object has no attribute '%s'"
|
||||
% (mod_path, attr_name),
|
||||
)
|
||||
|
||||
return adapter
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
"""
|
||||
A library for integrating Python's builtin ``ssl`` library with Cheroot.
|
||||
A library for integrating Python's builtin :py:mod:`ssl` library with Cheroot.
|
||||
|
||||
The ssl module must be importable for SSL functionality.
|
||||
The :py:mod:`ssl` module must be importable for SSL functionality.
|
||||
|
||||
To use this module, set ``HTTPServer.ssl_adapter`` to an instance of
|
||||
``BuiltinSSLAdapter``.
|
||||
|
@ -10,6 +10,10 @@ To use this module, set ``HTTPServer.ssl_adapter`` to an instance of
|
|||
from __future__ import absolute_import, division, print_function
|
||||
__metaclass__ = type
|
||||
|
||||
import socket
|
||||
import sys
|
||||
import threading
|
||||
|
||||
try:
|
||||
import ssl
|
||||
except ImportError:
|
||||
|
@ -27,13 +31,12 @@ import six
|
|||
|
||||
from . import Adapter
|
||||
from .. import errors
|
||||
from .._compat import IS_ABOVE_OPENSSL10
|
||||
from .._compat import IS_ABOVE_OPENSSL10, suppress
|
||||
from ..makefile import StreamReader, StreamWriter
|
||||
from ..server import HTTPServer
|
||||
|
||||
if six.PY2:
|
||||
import socket
|
||||
generic_socket_error = socket.error
|
||||
del socket
|
||||
else:
|
||||
generic_socket_error = OSError
|
||||
|
||||
|
@ -49,37 +52,159 @@ def _assert_ssl_exc_contains(exc, *msgs):
|
|||
return any(m.lower() in err_msg_lower for m in msgs)
|
||||
|
||||
|
||||
def _loopback_for_cert_thread(context, server):
|
||||
"""Wrap a socket in ssl and perform the server-side handshake."""
|
||||
# As we only care about parsing the certificate, the failure of
|
||||
# which will cause an exception in ``_loopback_for_cert``,
|
||||
# we can safely ignore connection and ssl related exceptions. Ref:
|
||||
# https://github.com/cherrypy/cheroot/issues/302#issuecomment-662592030
|
||||
with suppress(ssl.SSLError, OSError):
|
||||
with context.wrap_socket(
|
||||
server, do_handshake_on_connect=True, server_side=True,
|
||||
) as ssl_sock:
|
||||
# in TLS 1.3 (Python 3.7+, OpenSSL 1.1.1+), the server
|
||||
# sends the client session tickets that can be used to
|
||||
# resume the TLS session on a new connection without
|
||||
# performing the full handshake again. session tickets are
|
||||
# sent as a post-handshake message at some _unspecified_
|
||||
# time and thus a successful connection may be closed
|
||||
# without the client having received the tickets.
|
||||
# Unfortunately, on Windows (Python 3.8+), this is treated
|
||||
# as an incomplete handshake on the server side and a
|
||||
# ``ConnectionAbortedError`` is raised.
|
||||
# TLS 1.3 support is still incomplete in Python 3.8;
|
||||
# there is no way for the client to wait for tickets.
|
||||
# While not necessary for retrieving the parsed certificate,
|
||||
# we send a tiny bit of data over the connection in an
|
||||
# attempt to give the server a chance to send the session
|
||||
# tickets and close the connection cleanly.
|
||||
# Note that, as this is essentially a race condition,
|
||||
# the error may still occur occasionally.
|
||||
ssl_sock.send(b'0000')
|
||||
|
||||
|
||||
def _loopback_for_cert(certificate, private_key, certificate_chain):
|
||||
"""Create a loopback connection to parse a cert with a private key."""
|
||||
context = ssl.create_default_context(cafile=certificate_chain)
|
||||
context.load_cert_chain(certificate, private_key)
|
||||
context.check_hostname = False
|
||||
context.verify_mode = ssl.CERT_NONE
|
||||
|
||||
# Python 3+ Unix, Python 3.5+ Windows
|
||||
client, server = socket.socketpair()
|
||||
try:
|
||||
# `wrap_socket` will block until the ssl handshake is complete.
|
||||
# it must be called on both ends at the same time -> thread
|
||||
# openssl will cache the peer's cert during a successful handshake
|
||||
# and return it via `getpeercert` even after the socket is closed.
|
||||
# when `close` is called, the SSL shutdown notice will be sent
|
||||
# and then python will wait to receive the corollary shutdown.
|
||||
thread = threading.Thread(
|
||||
target=_loopback_for_cert_thread, args=(context, server),
|
||||
)
|
||||
try:
|
||||
thread.start()
|
||||
with context.wrap_socket(
|
||||
client, do_handshake_on_connect=True,
|
||||
server_side=False,
|
||||
) as ssl_sock:
|
||||
ssl_sock.recv(4)
|
||||
return ssl_sock.getpeercert()
|
||||
finally:
|
||||
thread.join()
|
||||
finally:
|
||||
client.close()
|
||||
server.close()
|
||||
|
||||
|
||||
def _parse_cert(certificate, private_key, certificate_chain):
|
||||
"""Parse a certificate."""
|
||||
# loopback_for_cert uses socket.socketpair which was only
|
||||
# introduced in Python 3.0 for *nix and 3.5 for Windows
|
||||
# and requires OS support (AttributeError, OSError)
|
||||
# it also requires a private key either in its own file
|
||||
# or combined with the cert (SSLError)
|
||||
with suppress(AttributeError, ssl.SSLError, OSError):
|
||||
return _loopback_for_cert(certificate, private_key, certificate_chain)
|
||||
|
||||
# KLUDGE: using an undocumented, private, test method to parse a cert
|
||||
# unfortunately, it is the only built-in way without a connection
|
||||
# as a private, undocumented method, it may change at any time
|
||||
# so be tolerant of *any* possible errors it may raise
|
||||
with suppress(Exception):
|
||||
return ssl._ssl._test_decode_cert(certificate)
|
||||
|
||||
return {}
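For orientation (not part of the patch): a hedged sketch of the mapping ``_parse_cert`` is expected to return on success, mirroring the structure of ``ssl.SSLSocket.getpeercert()``; the concrete values are invented.

parsed_cert = {
    'subject': ((('commonName', 'localhost'),),),
    'issuer': ((('commonName', 'localhost'),),),
    'version': 3,
    'serialNumber': '0BAF',
    'notBefore': 'Jan  1 00:00:00 2021 GMT',
    'notAfter': 'Jan  1 00:00:00 2031 GMT',
    'subjectAltName': (('DNS', 'localhost'),),
}

# The adapter later flattens these keys into mod_ssl-style variables,
# e.g. 'notBefore' -> SSL_SERVER_V_START and 'subject' -> SSL_SERVER_S_DN_*.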
|
||||
|
||||
|
||||
def _sni_callback(sock, sni, context):
|
||||
"""Handle the SNI callback to tag the socket with the SNI."""
|
||||
sock.sni = sni
|
||||
# return None to allow the TLS negotiation to continue
|
||||
|
||||
|
||||
class BuiltinSSLAdapter(Adapter):
|
||||
"""A wrapper for integrating Python's builtin ssl module with Cheroot."""
|
||||
"""Wrapper for integrating Python's builtin :py:mod:`ssl` with Cheroot."""
|
||||
|
||||
certificate = None
|
||||
"""The filename of the server SSL certificate."""
|
||||
"""The file name of the server SSL certificate."""
|
||||
|
||||
private_key = None
|
||||
"""The filename of the server's private key file."""
|
||||
"""The file name of the server's private key file."""
|
||||
|
||||
certificate_chain = None
|
||||
"""The filename of the certificate chain file."""
|
||||
|
||||
context = None
|
||||
"""The ssl.SSLContext that will be used to wrap sockets."""
|
||||
"""The file name of the certificate chain file."""
|
||||
|
||||
ciphers = None
|
||||
"""The ciphers list of SSL."""
|
||||
|
||||
# from mod_ssl/pkg.sslmod/ssl_engine_vars.c ssl_var_lookup_ssl_cert
|
||||
CERT_KEY_TO_ENV = {
|
||||
'subject': 'SSL_CLIENT_S_DN',
|
||||
'issuer': 'SSL_CLIENT_I_DN',
|
||||
'version': 'M_VERSION',
|
||||
'serialNumber': 'M_SERIAL',
|
||||
'notBefore': 'V_START',
|
||||
'notAfter': 'V_END',
|
||||
'subject': 'S_DN',
|
||||
'issuer': 'I_DN',
|
||||
'subjectAltName': 'SAN',
|
||||
# not parsed by the Python standard library
|
||||
# - A_SIG
|
||||
# - A_KEY
|
||||
# not provided by mod_ssl
|
||||
# - OCSP
|
||||
# - caIssuers
|
||||
# - crlDistributionPoints
|
||||
}
|
||||
|
||||
# from mod_ssl/pkg.sslmod/ssl_engine_vars.c ssl_var_lookup_ssl_cert_dn_rec
|
||||
CERT_KEY_TO_LDAP_CODE = {
|
||||
'countryName': 'C',
|
||||
'stateOrProvinceName': 'ST',
|
||||
# NOTE: mod_ssl also provides 'stateOrProvinceName' as 'SP'
|
||||
# for compatibility with SSLeay
|
||||
'localityName': 'L',
|
||||
'organizationName': 'O',
|
||||
'organizationalUnitName': 'OU',
|
||||
'commonName': 'CN',
|
||||
'title': 'T',
|
||||
'initials': 'I',
|
||||
'givenName': 'G',
|
||||
'surname': 'S',
|
||||
'description': 'D',
|
||||
'userid': 'UID',
|
||||
'emailAddress': 'Email',
|
||||
# not provided by mod_ssl
|
||||
# - dnQualifier: DNQ
|
||||
# - domainComponent: DC
|
||||
# - postalCode: PC
|
||||
# - streetAddress: STREET
|
||||
# - serialNumber
|
||||
# - generationQualifier
|
||||
# - pseudonym
|
||||
# - jurisdictionCountryName
|
||||
# - jurisdictionLocalityName
|
||||
# - jurisdictionStateOrProvince
|
||||
# - businessCategory
|
||||
}
|
||||
|
||||
def __init__(
|
||||
|
@ -102,6 +227,45 @@ class BuiltinSSLAdapter(Adapter):
|
|||
if self.ciphers is not None:
|
||||
self.context.set_ciphers(ciphers)
|
||||
|
||||
self._server_env = self._make_env_cert_dict(
|
||||
'SSL_SERVER',
|
||||
_parse_cert(certificate, private_key, self.certificate_chain),
|
||||
)
|
||||
if not self._server_env:
|
||||
return
|
||||
cert = None
|
||||
with open(certificate, mode='rt') as f:
|
||||
cert = f.read()
|
||||
|
||||
# strip off any keys by only taking the first certificate
|
||||
cert_start = cert.find(ssl.PEM_HEADER)
|
||||
if cert_start == -1:
|
||||
return
|
||||
cert_end = cert.find(ssl.PEM_FOOTER, cert_start)
|
||||
if cert_end == -1:
|
||||
return
|
||||
cert_end += len(ssl.PEM_FOOTER)
|
||||
self._server_env['SSL_SERVER_CERT'] = cert[cert_start:cert_end]
|
||||
|
||||
@property
|
||||
def context(self):
|
||||
""":py:class:`~ssl.SSLContext` that will be used to wrap sockets."""
|
||||
return self._context
|
||||
|
||||
@context.setter
|
||||
def context(self, context):
|
||||
"""Set the ssl ``context`` to use."""
|
||||
self._context = context
|
||||
# Python 3.7+
|
||||
# if a context is provided via `cherrypy.config.update` then
|
||||
# `self.context` will be set after `__init__`
|
||||
# use a property to intercept it to add an SNI callback
|
||||
# but don't override the user's callback
|
||||
# TODO: chain callbacks
|
||||
with suppress(AttributeError):
|
||||
if ssl.HAS_SNI and context.sni_callback is None:
|
||||
context.sni_callback = _sni_callback
|
||||
|
||||
def bind(self, sock):
|
||||
"""Wrap and return the given socket."""
|
||||
return super(BuiltinSSLAdapter, self).bind(sock)
|
||||
|
@ -135,6 +299,8 @@ class BuiltinSSLAdapter(Adapter):
|
|||
'no shared cipher', 'certificate unknown',
|
||||
'ccs received early',
|
||||
'certificate verify failed', # client cert w/o trusted CA
|
||||
'version too low', # caused by SSL3 connections
|
||||
'unsupported protocol', # caused by TLS1 connections
|
||||
)
|
||||
if _assert_ssl_exc_contains(ex, *_block_errors):
|
||||
# Accepted error, let's pass
|
||||
|
@ -148,7 +314,8 @@ class BuiltinSSLAdapter(Adapter):
|
|||
except generic_socket_error as exc:
|
||||
"""It is unclear why exactly this happens.
|
||||
|
||||
It's reproducible only with openssl>1.0 and stdlib ``ssl`` wrapper.
|
||||
It's reproducible only with openssl>1.0 and stdlib
|
||||
:py:mod:`ssl` wrapper.
|
||||
In CherryPy it's triggered by Checker plugin, which connects
|
||||
to the app listening to the socket port in TLS mode via plain
|
||||
HTTP during startup (from the same process).
|
||||
|
@ -163,7 +330,6 @@ class BuiltinSSLAdapter(Adapter):
|
|||
raise
|
||||
return s, self.get_environ(s)
|
||||
|
||||
# TODO: fill this out more with mod ssl env
|
||||
def get_environ(self, sock):
|
||||
"""Create WSGI environ entries to be merged into each request."""
|
||||
cipher = sock.cipher()
|
||||
|
@ -172,22 +338,117 @@ class BuiltinSSLAdapter(Adapter):
|
|||
'HTTPS': 'on',
|
||||
'SSL_PROTOCOL': cipher[1],
|
||||
'SSL_CIPHER': cipher[0],
|
||||
# SSL_VERSION_INTERFACE string The mod_ssl program version
|
||||
# SSL_VERSION_LIBRARY string The OpenSSL program version
|
||||
'SSL_CIPHER_EXPORT': '',
|
||||
'SSL_CIPHER_USEKEYSIZE': cipher[2],
|
||||
'SSL_VERSION_INTERFACE': '%s Python/%s' % (
|
||||
HTTPServer.version, sys.version,
|
||||
),
|
||||
'SSL_VERSION_LIBRARY': ssl.OPENSSL_VERSION,
|
||||
'SSL_CLIENT_VERIFY': 'NONE',
|
||||
# 'NONE' - client did not provide a cert (overridden below)
|
||||
}
|
||||
|
||||
# Python 3.3+
|
||||
with suppress(AttributeError):
|
||||
compression = sock.compression()
|
||||
if compression is not None:
|
||||
ssl_environ['SSL_COMPRESS_METHOD'] = compression
|
||||
|
||||
# Python 3.6+
|
||||
with suppress(AttributeError):
|
||||
ssl_environ['SSL_SESSION_ID'] = sock.session.id.hex()
|
||||
with suppress(AttributeError):
|
||||
target_cipher = cipher[:2]
|
||||
for cip in sock.context.get_ciphers():
|
||||
if target_cipher == (cip['name'], cip['protocol']):
|
||||
ssl_environ['SSL_CIPHER_ALGKEYSIZE'] = cip['alg_bits']
|
||||
break
|
||||
|
||||
# Python 3.7+ sni_callback
|
||||
with suppress(AttributeError):
|
||||
ssl_environ['SSL_TLS_SNI'] = sock.sni
|
||||
|
||||
if self.context and self.context.verify_mode != ssl.CERT_NONE:
|
||||
client_cert = sock.getpeercert()
|
||||
if client_cert:
|
||||
for cert_key, env_var in self.CERT_KEY_TO_ENV.items():
|
||||
# builtin ssl **ALWAYS** validates client certificates
|
||||
# and terminates the connection on failure
|
||||
ssl_environ['SSL_CLIENT_VERIFY'] = 'SUCCESS'
|
||||
ssl_environ.update(
|
||||
self.env_dn_dict(env_var, client_cert.get(cert_key)),
|
||||
self._make_env_cert_dict('SSL_CLIENT', client_cert),
|
||||
)
|
||||
ssl_environ['SSL_CLIENT_CERT'] = ssl.DER_cert_to_PEM_cert(
|
||||
sock.getpeercert(binary_form=True),
|
||||
).strip()
|
||||
|
||||
ssl_environ.update(self._server_env)
|
||||
|
||||
# not supplied by the Python standard library (as of 3.8)
|
||||
# - SSL_SESSION_RESUMED
|
||||
# - SSL_SECURE_RENEG
|
||||
# - SSL_CLIENT_CERT_CHAIN_n
|
||||
# - SRP_USER
|
||||
# - SRP_USERINFO
|
||||
|
||||
return ssl_environ
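Not part of the patch: a hedged illustration (all values invented) of the environ fragment this method could produce for a TLS connection without a client certificate.

ssl_environ_example = {
    'wsgi.url_scheme': 'https',
    'HTTPS': 'on',
    'SSL_PROTOCOL': 'TLSv1.2',
    'SSL_CIPHER': 'ECDHE-RSA-AES256-GCM-SHA384',
    'SSL_CIPHER_EXPORT': '',
    'SSL_CIPHER_USEKEYSIZE': 256,
    'SSL_VERSION_LIBRARY': 'OpenSSL 1.1.1k  25 Mar 2021',
    'SSL_CLIENT_VERIFY': 'NONE',
    'SSL_TLS_SNI': 'example.com',
}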
|
||||
|
||||
def env_dn_dict(self, env_prefix, cert_value):
|
||||
"""Return a dict of WSGI environment variables for a client cert DN.
|
||||
def _make_env_cert_dict(self, env_prefix, parsed_cert):
|
||||
"""Return a dict of WSGI environment variables for a certificate.
|
||||
|
||||
E.g. SSL_CLIENT_M_VERSION, SSL_CLIENT_M_SERIAL, etc.
|
||||
See https://httpd.apache.org/docs/2.4/mod/mod_ssl.html#envvars.
|
||||
"""
|
||||
if not parsed_cert:
|
||||
return {}
|
||||
|
||||
env = {}
|
||||
for cert_key, env_var in self.CERT_KEY_TO_ENV.items():
|
||||
key = '%s_%s' % (env_prefix, env_var)
|
||||
value = parsed_cert.get(cert_key)
|
||||
if env_var == 'SAN':
|
||||
env.update(self._make_env_san_dict(key, value))
|
||||
elif env_var.endswith('_DN'):
|
||||
env.update(self._make_env_dn_dict(key, value))
|
||||
else:
|
||||
env[key] = str(value)
|
||||
|
||||
# mod_ssl 2.1+; Python 3.2+
|
||||
# number of days until the certificate expires
|
||||
if 'notBefore' in parsed_cert:
|
||||
remain = ssl.cert_time_to_seconds(parsed_cert['notAfter'])
|
||||
remain -= ssl.cert_time_to_seconds(parsed_cert['notBefore'])
|
||||
remain /= 60 * 60 * 24
|
||||
env['%s_V_REMAIN' % (env_prefix,)] = str(int(remain))
|
||||
|
||||
return env
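A standalone sketch (not part of the patch) of the ``V_REMAIN`` arithmetic above, using ``ssl.cert_time_to_seconds`` on invented validity dates.

import ssl

not_before = ssl.cert_time_to_seconds('Jan  1 00:00:00 2021 GMT')
not_after = ssl.cert_time_to_seconds('Jan  1 00:00:00 2022 GMT')

# Difference between the validity bounds, converted to whole days.
remain_days = int((not_after - not_before) / (60 * 60 * 24))
print(remain_days)  # 365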
|
||||
|
||||
def _make_env_san_dict(self, env_prefix, cert_value):
|
||||
"""Return a dict of WSGI environment variables for a certificate DN.
|
||||
|
||||
E.g. SSL_CLIENT_SAN_Email_0, SSL_CLIENT_SAN_DNS_0, etc.
|
||||
See SSL_CLIENT_SAN_* at
|
||||
https://httpd.apache.org/docs/2.4/mod/mod_ssl.html#envvars.
|
||||
"""
|
||||
if not cert_value:
|
||||
return {}
|
||||
|
||||
env = {}
|
||||
dns_count = 0
|
||||
email_count = 0
|
||||
for attr_name, val in cert_value:
|
||||
if attr_name == 'DNS':
|
||||
env['%s_DNS_%i' % (env_prefix, dns_count)] = val
|
||||
dns_count += 1
|
||||
elif attr_name == 'Email':
|
||||
env['%s_Email_%i' % (env_prefix, email_count)] = val
|
||||
email_count += 1
|
||||
|
||||
# other mod_ssl SAN vars:
|
||||
# - SAN_OTHER_msUPN_n
|
||||
return env
|
||||
|
||||
def _make_env_dn_dict(self, env_prefix, cert_value):
|
||||
"""Return a dict of WSGI environment variables for a certificate DN.
|
||||
|
||||
E.g. SSL_CLIENT_S_DN_CN, SSL_CLIENT_S_DN_C, etc.
|
||||
See SSL_CLIENT_S_DN_x509 at
|
||||
|
@ -196,12 +457,26 @@ class BuiltinSSLAdapter(Adapter):
|
|||
if not cert_value:
|
||||
return {}
|
||||
|
||||
env = {}
|
||||
dn = []
|
||||
dn_attrs = {}
|
||||
for rdn in cert_value:
|
||||
for attr_name, val in rdn:
|
||||
attr_code = self.CERT_KEY_TO_LDAP_CODE.get(attr_name)
|
||||
if attr_code:
|
||||
env['%s_%s' % (env_prefix, attr_code)] = val
|
||||
dn.append('%s=%s' % (attr_code or attr_name, val))
|
||||
if not attr_code:
|
||||
continue
|
||||
dn_attrs.setdefault(attr_code, [])
|
||||
dn_attrs[attr_code].append(val)
|
||||
|
||||
env = {
|
||||
env_prefix: ','.join(dn),
|
||||
}
|
||||
for attr_code, values in dn_attrs.items():
|
||||
env['%s_%s' % (env_prefix, attr_code)] = ','.join(values)
|
||||
if len(values) == 1:
|
||||
continue
|
||||
for i, val in enumerate(values):
|
||||
env['%s_%s_%i' % (env_prefix, attr_code, i)] = val
|
||||
return env
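Not part of the patch: a hedged input/output illustration for the DN flattening above. The RDN structure mirrors what ``getpeercert()['subject']`` yields; the values are invented.

cert_value = (
    (('countryName', 'US'),),
    (('organizationName', 'Example Org'),),
    (('organizationalUnitName', 'Eng'),),
    (('organizationalUnitName', 'Ops'),),
    (('commonName', 'example.com'),),
)

# With env_prefix='SSL_CLIENT_S_DN' the method would yield roughly:
# {
#     'SSL_CLIENT_S_DN': 'C=US,O=Example Org,OU=Eng,OU=Ops,CN=example.com',
#     'SSL_CLIENT_S_DN_C': 'US',
#     'SSL_CLIENT_S_DN_O': 'Example Org',
#     'SSL_CLIENT_S_DN_OU': 'Eng,Ops',
#     'SSL_CLIENT_S_DN_OU_0': 'Eng',
#     'SSL_CLIENT_S_DN_OU_1': 'Ops',
#     'SSL_CLIENT_S_DN_CN': 'example.com',
# }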
|
||||
|
||||
def makefile(self, sock, mode='r', bufsize=DEFAULT_BUFFER_SIZE):
|
||||
|
|
|
@ -1,46 +1,67 @@
|
|||
"""
|
||||
A library for integrating pyOpenSSL with Cheroot.
|
||||
A library for integrating :doc:`pyOpenSSL <pyopenssl:index>` with Cheroot.
|
||||
|
||||
The OpenSSL module must be importable for SSL functionality.
|
||||
You can obtain it from `here <https://launchpad.net/pyopenssl>`_.
|
||||
The :py:mod:`OpenSSL <pyopenssl:OpenSSL>` module must be importable
|
||||
for SSL/TLS/HTTPS functionality.
|
||||
You can obtain it from `here <https://github.com/pyca/pyopenssl>`_.
|
||||
|
||||
To use this module, set HTTPServer.ssl_adapter to an instance of
|
||||
ssl.Adapter. There are two ways to use SSL:
|
||||
To use this module, set :py:attr:`HTTPServer.ssl_adapter
|
||||
<cheroot.server.HTTPServer.ssl_adapter>` to an instance of
|
||||
:py:class:`ssl.Adapter <cheroot.ssl.Adapter>`.
|
||||
There are two ways to use :abbr:`TLS (Transport-Level Security)`:
|
||||
|
||||
Method One
|
||||
----------
|
||||
|
||||
* ``ssl_adapter.context``: an instance of SSL.Context.
|
||||
* :py:attr:`ssl_adapter.context
|
||||
<cheroot.ssl.pyopenssl.pyOpenSSLAdapter.context>`: an instance of
|
||||
:py:class:`SSL.Context <pyopenssl:OpenSSL.SSL.Context>`.
|
||||
|
||||
If this is not None, it is assumed to be an SSL.Context instance,
|
||||
and will be passed to SSL.Connection on bind(). The developer is
|
||||
responsible for forming a valid Context object. This approach is
|
||||
to be preferred for more flexibility, e.g. if the cert and key are
|
||||
streams instead of files, or need decryption, or SSL.SSLv3_METHOD
|
||||
is desired instead of the default SSL.SSLv23_METHOD, etc. Consult
|
||||
the pyOpenSSL documentation for complete options.
|
||||
If this is not None, it is assumed to be an :py:class:`SSL.Context
|
||||
<pyopenssl:OpenSSL.SSL.Context>` instance, and will be passed to
|
||||
:py:class:`SSL.Connection <pyopenssl:OpenSSL.SSL.Connection>` on bind().
|
||||
The developer is responsible for forming a valid :py:class:`Context
|
||||
<pyopenssl:OpenSSL.SSL.Context>` object. This
|
||||
approach is to be preferred for more flexibility, e.g. if the cert and
|
||||
key are streams instead of files, or need decryption, or
|
||||
:py:data:`SSL.SSLv3_METHOD <pyopenssl:OpenSSL.SSL.SSLv3_METHOD>`
|
||||
is desired instead of the default :py:data:`SSL.SSLv23_METHOD
|
||||
<pyopenssl:OpenSSL.SSL.SSLv3_METHOD>`, etc. Consult
|
||||
the :doc:`pyOpenSSL <pyopenssl:api/ssl>` documentation for
|
||||
complete options.
|
||||
|
||||
Method Two (shortcut)
|
||||
---------------------
|
||||
|
||||
* ``ssl_adapter.certificate``: the filename of the server SSL certificate.
|
||||
* ``ssl_adapter.private_key``: the filename of the server's private key file.
|
||||
* :py:attr:`ssl_adapter.certificate
|
||||
<cheroot.ssl.pyopenssl.pyOpenSSLAdapter.certificate>`: the file name
|
||||
of the server's TLS certificate.
|
||||
* :py:attr:`ssl_adapter.private_key
|
||||
<cheroot.ssl.pyopenssl.pyOpenSSLAdapter.private_key>`: the file name
|
||||
of the server's private key file.
|
||||
|
||||
Both are None by default. If ssl_adapter.context is None, but .private_key
|
||||
and .certificate are both given and valid, they will be read, and the
|
||||
context will be automatically created from them.
|
||||
Both are :py:data:`None` by default. If :py:attr:`ssl_adapter.context
|
||||
<cheroot.ssl.pyopenssl.pyOpenSSLAdapter.context>` is :py:data:`None`,
|
||||
but ``.private_key`` and ``.certificate`` are both given and valid, they
|
||||
will be read, and the context will be automatically created from them.
|
||||
|
||||
.. spelling::
|
||||
|
||||
pyopenssl
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
__metaclass__ = type
|
||||
|
||||
import socket
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
|
||||
import six
|
||||
|
||||
try:
|
||||
import OpenSSL.version
|
||||
from OpenSSL import SSL
|
||||
from OpenSSL import crypto
|
||||
|
||||
|
@ -57,13 +78,14 @@ from ..makefile import StreamReader, StreamWriter
|
|||
|
||||
|
||||
class SSLFileobjectMixin:
|
||||
"""Base mixin for an SSL socket stream."""
|
||||
"""Base mixin for a TLS socket stream."""
|
||||
|
||||
ssl_timeout = 3
|
||||
ssl_retry = .01
|
||||
|
||||
def _safe_call(self, is_reader, call, *args, **kwargs):
|
||||
"""Wrap the given call with SSL error-trapping.
|
||||
# FIXME:
|
||||
def _safe_call(self, is_reader, call, *args, **kwargs): # noqa: C901
|
||||
"""Wrap the given call with TLS error-trapping.
|
||||
|
||||
is_reader: if False EOF errors will be raised. If True, EOF errors
|
||||
will return "" (to emulate normal sockets).
|
||||
|
@ -209,9 +231,11 @@ class SSLConnectionProxyMeta:
|
|||
|
||||
@six.add_metaclass(SSLConnectionProxyMeta)
|
||||
class SSLConnection:
|
||||
"""A thread-safe wrapper for an SSL.Connection.
|
||||
r"""A thread-safe wrapper for an ``SSL.Connection``.
|
||||
|
||||
``*args``: the arguments to create the wrapped ``SSL.Connection(*args)``.
|
||||
:param tuple args: the arguments to create the wrapped \
|
||||
:py:class:`SSL.Connection(*args) \
|
||||
<pyopenssl:OpenSSL.SSL.Connection>`
|
||||
"""
|
||||
|
||||
def __init__(self, *args):
|
||||
|
@ -224,22 +248,24 @@ class pyOpenSSLAdapter(Adapter):
|
|||
"""A wrapper for integrating pyOpenSSL with Cheroot."""
|
||||
|
||||
certificate = None
|
||||
"""The filename of the server SSL certificate."""
|
||||
"""The file name of the server's TLS certificate."""
|
||||
|
||||
private_key = None
|
||||
"""The filename of the server's private key file."""
|
||||
"""The file name of the server's private key file."""
|
||||
|
||||
certificate_chain = None
|
||||
"""Optional. The filename of CA's intermediate certificate bundle.
|
||||
"""Optional. The file name of CA's intermediate certificate bundle.
|
||||
|
||||
This is needed for cheaper "chained root" SSL certificates, and should be
|
||||
left as None if not required."""
|
||||
This is needed for cheaper "chained root" TLS certificates,
|
||||
and should be left as :py:data:`None` if not required."""
|
||||
|
||||
context = None
|
||||
"""An instance of SSL.Context."""
|
||||
"""
|
||||
An instance of :py:class:`SSL.Context <pyopenssl:OpenSSL.SSL.Context>`.
|
||||
"""
|
||||
|
||||
ciphers = None
|
||||
"""The ciphers list of SSL."""
|
||||
"""The ciphers list of TLS."""
|
||||
|
||||
def __init__(
|
||||
self, certificate, private_key, certificate_chain=None,
|
||||
|
@ -265,10 +291,16 @@ class pyOpenSSLAdapter(Adapter):
|
|||
|
||||
def wrap(self, sock):
|
||||
"""Wrap and return the given socket, plus WSGI environ entries."""
|
||||
# pyOpenSSL doesn't perform the handshake until the first read/write
|
||||
# forcing the handshake to complete tends to result in the connection
|
||||
# closing so we can't reliably access protocol/client cert for the env
|
||||
return sock, self._environ.copy()
|
||||
|
||||
def get_context(self):
|
||||
"""Return an SSL.Context from self attributes."""
|
||||
"""Return an ``SSL.Context`` from self attributes.
|
||||
|
||||
Ref: :py:class:`SSL.Context <pyopenssl:OpenSSL.SSL.Context>`
|
||||
"""
|
||||
# See https://code.activestate.com/recipes/442473/
|
||||
c = SSL.Context(SSL.SSLv23_METHOD)
|
||||
c.use_privatekey_file(self.private_key)
|
||||
|
@ -280,18 +312,25 @@ class pyOpenSSLAdapter(Adapter):
|
|||
def get_environ(self):
|
||||
"""Return WSGI environ entries to be merged into each request."""
|
||||
ssl_environ = {
|
||||
'wsgi.url_scheme': 'https',
|
||||
'HTTPS': 'on',
|
||||
# pyOpenSSL doesn't provide access to any of these AFAICT
|
||||
# 'SSL_PROTOCOL': 'SSLv2',
|
||||
# SSL_CIPHER string The cipher specification name
|
||||
# SSL_VERSION_INTERFACE string The mod_ssl program version
|
||||
# SSL_VERSION_LIBRARY string The OpenSSL program version
|
||||
'SSL_VERSION_INTERFACE': '%s %s/%s Python/%s' % (
|
||||
cheroot_server.HTTPServer.version,
|
||||
OpenSSL.version.__title__, OpenSSL.version.__version__,
|
||||
sys.version,
|
||||
),
|
||||
'SSL_VERSION_LIBRARY': SSL.SSLeay_version(
|
||||
SSL.SSLEAY_VERSION,
|
||||
).decode(),
|
||||
}
|
||||
|
||||
if self.certificate:
|
||||
# Server certificate attributes
|
||||
cert = open(self.certificate, 'rb').read()
|
||||
cert = crypto.load_certificate(crypto.FILETYPE_PEM, cert)
|
||||
with open(self.certificate, 'rb') as cert_file:
|
||||
cert = crypto.load_certificate(
|
||||
crypto.FILETYPE_PEM, cert_file.read(),
|
||||
)
|
||||
|
||||
ssl_environ.update({
|
||||
'SSL_SERVER_M_VERSION': cert.get_version(),
|
||||
'SSL_SERVER_M_SERIAL': cert.get_serial_number(),
|
||||
|
|
38
lib/cheroot/test/_pytest_plugin.py
Normal file
|
@ -0,0 +1,38 @@
|
|||
"""Local pytest plugin.
|
||||
|
||||
Contains hooks, which are tightly bound to the Cheroot framework
|
||||
itself, useless for end-users' app testing.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
__metaclass__ = type
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
pytest_version = tuple(map(int, pytest.__version__.split('.')))
|
||||
|
||||
|
||||
def pytest_load_initial_conftests(early_config, parser, args):
|
||||
"""Drop unfilterable warning ignores."""
|
||||
if pytest_version < (6, 2, 0):
|
||||
return
|
||||
|
||||
# pytest>=6.2.0 under Python 3.8:
|
||||
# Refs:
|
||||
# * https://docs.pytest.org/en/stable/usage.html#unraisable
|
||||
# * https://github.com/pytest-dev/pytest/issues/5299
|
||||
early_config._inicache['filterwarnings'].extend((
|
||||
'ignore:Exception in thread CP Server Thread-:'
|
||||
'pytest.PytestUnhandledThreadExceptionWarning:_pytest.threadexception',
|
||||
'ignore:Exception in thread Thread-:'
|
||||
'pytest.PytestUnhandledThreadExceptionWarning:_pytest.threadexception',
|
||||
'ignore:Exception ignored in. '
|
||||
'<socket.socket fd=-1, family=AddressFamily.AF_INET, '
|
||||
'type=SocketKind.SOCK_STREAM, proto=.:'
|
||||
'pytest.PytestUnraisableExceptionWarning:_pytest.unraisableexception',
|
||||
'ignore:Exception ignored in. '
|
||||
'<socket.socket fd=-1, family=AddressFamily.AF_INET6, '
|
||||
'type=SocketKind.SOCK_STREAM, proto=.:'
|
||||
'pytest.PytestUnraisableExceptionWarning:_pytest.unraisableexception',
|
||||
))
|
|
@ -55,7 +55,7 @@ def http_server():
|
|||
|
||||
|
||||
def make_http_server(bind_addr):
|
||||
"""Create and start an HTTP server bound to bind_addr."""
|
||||
"""Create and start an HTTP server bound to ``bind_addr``."""
|
||||
httpserver = HTTPServer(
|
||||
bind_addr=bind_addr,
|
||||
gateway=Gateway,
|
||||
|
|
|
@ -99,7 +99,7 @@ class CherootWebCase(webtest.WebCase):
|
|||
date_tolerance = 2
|
||||
|
||||
def assertEqualDates(self, dt1, dt2, seconds=None):
|
||||
"""Assert abs(dt1 - dt2) is within Y seconds."""
|
||||
"""Assert ``abs(dt1 - dt2)`` is within ``Y`` seconds."""
|
||||
if seconds is None:
|
||||
seconds = self.date_tolerance
|
||||
|
||||
|
@ -108,8 +108,10 @@ class CherootWebCase(webtest.WebCase):
|
|||
else:
|
||||
diff = dt2 - dt1
|
||||
if not diff < datetime.timedelta(seconds=seconds):
|
||||
raise AssertionError('%r and %r are not within %r seconds.' %
|
||||
(dt1, dt2, seconds))
|
||||
raise AssertionError(
|
||||
'%r and %r are not within %r seconds.' %
|
||||
(dt1, dt2, seconds),
|
||||
)
|
||||
|
||||
|
||||
class Request:
|
||||
|
@ -155,9 +157,13 @@ class Controller:
|
|||
resp.status = '404 Not Found'
|
||||
else:
|
||||
output = handler(req, resp)
|
||||
if (output is not None
|
||||
and not any(resp.status.startswith(status_code)
|
||||
for status_code in ('204', '304'))):
|
||||
if (
|
||||
output is not None
|
||||
and not any(
|
||||
resp.status.startswith(status_code)
|
||||
for status_code in ('204', '304')
|
||||
)
|
||||
):
|
||||
resp.body = output
|
||||
try:
|
||||
resp.headers.setdefault('Content-Length', str(len(output)))
|
||||
|
|
|
@ -11,45 +11,45 @@ from cheroot._compat import extract_bytes, memoryview, ntob, ntou, bton
|
|||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'func,inp,out',
|
||||
[
|
||||
('func', 'inp', 'out'),
|
||||
(
|
||||
(ntob, 'bar', b'bar'),
|
||||
(ntou, 'bar', u'bar'),
|
||||
(bton, b'bar', 'bar'),
|
||||
],
|
||||
),
|
||||
)
|
||||
def test_compat_functions_positive(func, inp, out):
|
||||
"""Check that compat functions work with correct input."""
|
||||
"""Check that compatibility functions work with correct input."""
|
||||
assert func(inp, encoding='utf-8') == out
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'func',
|
||||
[
|
||||
(
|
||||
ntob,
|
||||
ntou,
|
||||
],
|
||||
),
|
||||
)
|
||||
def test_compat_functions_negative_nonnative(func):
|
||||
"""Check that compat functions fail loudly for incorrect input."""
|
||||
"""Check that compatibility functions fail loudly for incorrect input."""
|
||||
non_native_test_str = u'bar' if six.PY2 else b'bar'
|
||||
with pytest.raises(TypeError):
|
||||
func(non_native_test_str, encoding='utf-8')
|
||||
|
||||
|
||||
def test_ntou_escape():
|
||||
"""Check that ntou supports escape-encoding under Python 2."""
|
||||
"""Check that ``ntou`` supports escape-encoding under Python 2."""
|
||||
expected = u'hišřії'
|
||||
actual = ntou('hi\u0161\u0159\u0456\u0457', encoding='escape')
|
||||
assert actual == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'input_argument,expected_result',
|
||||
[
|
||||
('input_argument', 'expected_result'),
|
||||
(
|
||||
(b'qwerty', b'qwerty'),
|
||||
(memoryview(b'asdfgh'), b'asdfgh'),
|
||||
],
|
||||
),
|
||||
)
|
||||
def test_extract_bytes(input_argument, expected_result):
|
||||
"""Check that legitimate inputs produce bytes."""
|
||||
|
@ -58,5 +58,9 @@ def test_extract_bytes(input_argument, expected_result):
|
|||
|
||||
def test_extract_bytes_invalid():
|
||||
"""Ensure that invalid input causes exception to be raised."""
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=r'^extract_bytes\(\) only accepts bytes '
|
||||
'and memoryview/buffer$',
|
||||
):
|
||||
extract_bytes(u'some юнікод їїї')
|
||||
|
|
97
lib/cheroot/test/test_cli.py
Normal file
|
@ -0,0 +1,97 @@
|
|||
"""Tests to verify the command line interface.
|
||||
|
||||
.. spelling::
|
||||
|
||||
cli
|
||||
"""
|
||||
# -*- coding: utf-8 -*-
|
||||
# vim: set fileencoding=utf-8 :
|
||||
import sys
|
||||
|
||||
import six
|
||||
import pytest
|
||||
|
||||
from cheroot.cli import (
|
||||
Application,
|
||||
parse_wsgi_bind_addr,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
('raw_bind_addr', 'expected_bind_addr'),
|
||||
(
|
||||
# tcp/ip
|
||||
('192.168.1.1:80', ('192.168.1.1', 80)),
|
||||
# IPv6 IPs have to be enclosed in brackets when specified in URL form
|
||||
('[::1]:8000', ('::1', 8000)),
|
||||
('localhost:5000', ('localhost', 5000)),
|
||||
# this is a valid input, but foo gets discarded
|
||||
('foo@bar:5000', ('bar', 5000)),
|
||||
('foo', ('foo', None)),
|
||||
('123456789', ('123456789', None)),
|
||||
# unix sockets
|
||||
('/tmp/cheroot.sock', '/tmp/cheroot.sock'),
|
||||
('/tmp/some-random-file-name', '/tmp/some-random-file-name'),
|
||||
# abstract sockets
|
||||
('@cheroot', '\x00cheroot'),
|
||||
),
|
||||
)
|
||||
def test_parse_wsgi_bind_addr(raw_bind_addr, expected_bind_addr):
|
||||
"""Check the parsing of the --bind option.
|
||||
|
||||
Verify some of the supported addresses and the expected return value.
|
||||
"""
|
||||
assert parse_wsgi_bind_addr(raw_bind_addr) == expected_bind_addr
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def wsgi_app(monkeypatch):
|
||||
"""Return a WSGI app stub."""
|
||||
class WSGIAppMock:
|
||||
"""Mock of a wsgi module."""
|
||||
|
||||
def application(self):
|
||||
"""Empty application method.
|
||||
|
||||
Default method to be called when no specific callable
|
||||
is defined in the wsgi application identifier.
|
||||
|
||||
It has an empty body because we are expecting to verify that
|
||||
the same method is returned by the actual resolution.
|
||||
"""
|
||||
|
||||
def main(self):
|
||||
"""Empty custom method (callable) inside the mocked WSGI app.
|
||||
|
||||
It has an empty body because we are expecting to verify that
|
||||
the same method is returned by the actual resolution.
|
||||
"""
|
||||
app = WSGIAppMock()
|
||||
# patch sys.modules, to include an instance of WSGIAppMock
|
||||
# under a specific namespace
|
||||
if six.PY2:
|
||||
# python2 requires the previous namespaces to be part of sys.modules
|
||||
# (e.g. for 'a.b.c' we need to insert 'a', 'a.b' and 'a.b.c')
|
||||
# otherwise it fails, we're setting the same instance on each level,
|
||||
# we don't really care about those, just the last one.
|
||||
monkeypatch.setitem(sys.modules, 'mypkg', app)
|
||||
monkeypatch.setitem(sys.modules, 'mypkg.wsgi', app)
|
||||
return app
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
('app_name', 'app_method'),
|
||||
(
|
||||
(None, 'application'),
|
||||
('application', 'application'),
|
||||
('main', 'main'),
|
||||
),
|
||||
)
|
||||
def test_Aplication_resolve(app_name, app_method, wsgi_app):
|
||||
"""Check the wsgi application name conversion."""
|
||||
if app_name is None:
|
||||
wsgi_app_spec = 'mypkg.wsgi'
|
||||
else:
|
||||
wsgi_app_spec = 'mypkg.wsgi:{app_name}'.format(**locals())
|
||||
expected_app = getattr(wsgi_app, app_method)
|
||||
assert Application.resolve(wsgi_app_spec).wsgi_app == expected_app
|
|
@ -3,15 +3,22 @@
|
|||
from __future__ import absolute_import, division, print_function
|
||||
__metaclass__ = type
|
||||
|
||||
import errno
|
||||
import socket
|
||||
import time
|
||||
import logging
|
||||
import traceback as traceback_
|
||||
from collections import namedtuple
|
||||
|
||||
from six.moves import range, http_client, urllib
|
||||
|
||||
import six
|
||||
import pytest
|
||||
from jaraco.text import trim, unwrap
|
||||
|
||||
from cheroot.test import helper, webtest
|
||||
from cheroot._compat import IS_CI, IS_PYPY, IS_WINDOWS
|
||||
import cheroot.server
|
||||
|
||||
|
||||
timeout = 1
|
||||
|
@ -26,7 +33,7 @@ class Controller(helper.Controller):
|
|||
return 'Hello, world!'
|
||||
|
||||
def pov(req, resp):
|
||||
"""Render pov value."""
|
||||
"""Render ``pov`` value."""
|
||||
return pov
|
||||
|
||||
def stream(req, resp):
|
||||
|
@ -43,8 +50,10 @@ class Controller(helper.Controller):
|
|||
def upload(req, resp):
|
||||
"""Process file upload and render thank."""
|
||||
if not req.environ['REQUEST_METHOD'] == 'POST':
|
||||
raise AssertionError("'POST' != request.method %r" %
|
||||
req.environ['REQUEST_METHOD'])
|
||||
raise AssertionError(
|
||||
"'POST' != request.method %r" %
|
||||
req.environ['REQUEST_METHOD'],
|
||||
)
|
||||
return "thanks for '%s'" % req.environ['wsgi.input'].read()
|
||||
|
||||
def custom_204(req, resp):
|
||||
|
@ -103,9 +112,33 @@ class Controller(helper.Controller):
|
|||
}
|
||||
|
||||
|
||||
class ErrorLogMonitor:
|
||||
"""Mock class to access the server error_log calls made by the server."""
|
||||
|
||||
ErrorLogCall = namedtuple('ErrorLogCall', ['msg', 'level', 'traceback'])
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the server error log monitor/interceptor.
|
||||
|
||||
If you need to ignore a particular error message use the property
|
||||
``ignored_msgs`` by appending to the list the expected error messages.
|
||||
"""
|
||||
self.calls = []
|
||||
# to be used in the teardown validation
|
||||
self.ignored_msgs = []
|
||||
|
||||
def __call__(self, msg='', level=logging.INFO, traceback=False):
|
||||
"""Intercept the call to the server error_log method."""
|
||||
if traceback:
|
||||
tblines = traceback_.format_exc()
|
||||
else:
|
||||
tblines = ''
|
||||
self.calls.append(ErrorLogMonitor.ErrorLogCall(msg, level, tblines))
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def testing_server(wsgi_server_client):
|
||||
"""Attach a WSGI app to the given server and pre-configure it."""
|
||||
def raw_testing_server(wsgi_server_client):
|
||||
"""Attach a WSGI app to the given server and preconfigure it."""
|
||||
app = Controller()
|
||||
|
||||
def _timeout(req, resp):
|
||||
|
@ -117,9 +150,36 @@ def testing_server(wsgi_server_client):
|
|||
wsgi_server.timeout = timeout
|
||||
wsgi_server.server_client = wsgi_server_client
|
||||
wsgi_server.keep_alive_conn_limit = 2
|
||||
|
||||
return wsgi_server
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def testing_server(raw_testing_server, monkeypatch):
|
||||
"""Modify the "raw" base server to monitor the error_log messages.
|
||||
|
||||
If you need to ignore a particular error message use the property
|
||||
``testing_server.error_log.ignored_msgs`` by appending to the list
|
||||
the expected error messages.
|
||||
"""
|
||||
# patch the error_log calls of the server instance
|
||||
monkeypatch.setattr(raw_testing_server, 'error_log', ErrorLogMonitor())
|
||||
|
||||
yield raw_testing_server
|
||||
|
||||
# Teardown verification, in case that the server logged an
|
||||
# error that wasn't notified to the client or we just made a mistake.
|
||||
for c_msg, c_level, c_traceback in raw_testing_server.error_log.calls:
|
||||
if c_level <= logging.WARNING:
|
||||
continue
|
||||
|
||||
assert c_msg in raw_testing_server.error_log.ignored_msgs, (
|
||||
'Found error in the error log: '
|
||||
"message = '{c_msg}', level = '{c_level}'\n"
|
||||
'{c_traceback}'.format(**locals()),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_client(testing_server):
|
||||
"""Get and return a test client out of the given server."""
|
||||
|
@ -338,7 +398,14 @@ def test_streaming_10(test_client, set_cl):
|
|||
'http_server_protocol',
|
||||
(
|
||||
'HTTP/1.0',
|
||||
pytest.param(
|
||||
'HTTP/1.1',
|
||||
marks=pytest.mark.xfail(
|
||||
IS_PYPY and IS_CI,
|
||||
reason='Fails under PyPy in CI for unknown reason',
|
||||
strict=False,
|
||||
),
|
||||
),
|
||||
),
|
||||
)
|
||||
def test_keepalive(test_client, http_server_protocol):
|
||||
|
@ -375,6 +442,11 @@ def test_keepalive(test_client, http_server_protocol):
|
|||
assert status_line[4:] == 'OK'
|
||||
assert actual_resp_body == pov.encode()
|
||||
assert header_has_value('Connection', 'Keep-Alive', actual_headers)
|
||||
assert header_has_value(
|
||||
'Keep-Alive',
|
||||
'timeout={test_client.server_instance.timeout}'.format(**locals()),
|
||||
actual_headers,
|
||||
)
|
||||
|
||||
# Remove the keep-alive header again.
|
||||
status_line, actual_headers, actual_resp_body = test_client.get(
|
||||
|
@ -386,6 +458,7 @@ def test_keepalive(test_client, http_server_protocol):
|
|||
assert status_line[4:] == 'OK'
|
||||
assert actual_resp_body == pov.encode()
|
||||
assert not header_exists('Connection', actual_headers)
|
||||
assert not header_exists('Keep-Alive', actual_headers)
|
||||
|
||||
test_client.server_instance.protocol = original_server_protocol
|
||||
|
||||
|
@ -401,7 +474,7 @@ def test_keepalive_conn_management(test_client):
|
|||
http_connection.connect()
|
||||
return http_connection
|
||||
|
||||
def request(conn):
|
||||
def request(conn, keepalive=True):
|
||||
status_line, actual_headers, actual_resp_body = test_client.get(
|
||||
'/page3', headers=[('Connection', 'Keep-Alive')],
|
||||
http_conn=conn, protocol='HTTP/1.0',
|
||||
|
@ -410,7 +483,28 @@ def test_keepalive_conn_management(test_client):
|
|||
assert actual_status == 200
|
||||
assert status_line[4:] == 'OK'
|
||||
assert actual_resp_body == pov.encode()
|
||||
if keepalive:
|
||||
assert header_has_value('Connection', 'Keep-Alive', actual_headers)
|
||||
assert header_has_value(
|
||||
'Keep-Alive',
|
||||
'timeout={test_client.server_instance.timeout}'.
|
||||
format(**locals()),
|
||||
actual_headers,
|
||||
)
|
||||
else:
|
||||
assert not header_exists('Connection', actual_headers)
|
||||
assert not header_exists('Keep-Alive', actual_headers)
|
||||
|
||||
def check_server_idle_conn_count(count, timeout=1.0):
|
||||
deadline = time.time() + timeout
|
||||
while True:
|
||||
n = test_client.server_instance._connections._num_connections
|
||||
if n == count:
|
||||
return
|
||||
assert time.time() <= deadline, (
|
||||
'idle conn count mismatch, wanted {count}, got {n}'.
|
||||
format(**locals()),
|
||||
)
|
||||
|
||||
disconnect_errors = (
|
||||
http_client.BadStatusLine,
|
||||
|
@ -421,50 +515,175 @@ def test_keepalive_conn_management(test_client):
|
|||
# Make a new connection.
|
||||
c1 = connection()
|
||||
request(c1)
|
||||
check_server_idle_conn_count(1)
|
||||
|
||||
# Make a second one.
|
||||
c2 = connection()
|
||||
request(c2)
|
||||
check_server_idle_conn_count(2)
|
||||
|
||||
# Reusing the first connection should still work.
|
||||
request(c1)
|
||||
check_server_idle_conn_count(2)
|
||||
|
||||
# Creating a new connection should still work.
|
||||
# Creating a new connection should still work, but we should
|
||||
# have run out of available connections to keep alive, so the
|
||||
# server should tell us to close.
|
||||
c3 = connection()
|
||||
request(c3)
|
||||
request(c3, keepalive=False)
|
||||
check_server_idle_conn_count(2)
|
||||
|
||||
# Allow a tick.
|
||||
time.sleep(0.2)
|
||||
|
||||
# That's three connections, we should expect the one used less recently
|
||||
# to be expired.
|
||||
# Show that the third connection was closed.
|
||||
with pytest.raises(disconnect_errors):
|
||||
request(c2)
|
||||
|
||||
# But the oldest created one should still be valid.
|
||||
# (As well as the newest one).
|
||||
request(c1)
|
||||
request(c3)
|
||||
check_server_idle_conn_count(2)
|
||||
|
||||
# Wait for some of our timeout.
|
||||
time.sleep(1.0)
|
||||
time.sleep(1.2)
|
||||
|
||||
# Refresh the third connection.
|
||||
request(c3)
|
||||
# Refresh the second connection.
|
||||
request(c2)
|
||||
check_server_idle_conn_count(2)
|
||||
|
||||
# Wait for the remainder of our timeout, plus one tick.
|
||||
time.sleep(1.2)
|
||||
check_server_idle_conn_count(1)
|
||||
|
||||
# First connection should now be expired.
|
||||
with pytest.raises(disconnect_errors):
|
||||
request(c1)
|
||||
check_server_idle_conn_count(1)
|
||||
|
||||
# But the third one should still be valid.
|
||||
request(c3)
|
||||
# But the second one should still be valid.
|
||||
request(c2)
|
||||
check_server_idle_conn_count(1)
|
||||
|
||||
# Restore original timeout.
|
||||
test_client.server_instance.timeout = timeout
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
('simulated_exception', 'error_number', 'exception_leaks'),
|
||||
(
|
||||
pytest.param(
|
||||
socket.error, errno.ECONNRESET, False,
|
||||
id='socket.error(ECONNRESET)',
|
||||
),
|
||||
pytest.param(
|
||||
socket.error, errno.EPIPE, False,
|
||||
id='socket.error(EPIPE)',
|
||||
),
|
||||
pytest.param(
|
||||
socket.error, errno.ENOTCONN, False,
|
||||
id='simulated socket.error(ENOTCONN)',
|
||||
),
|
||||
pytest.param(
|
||||
None, # <-- don't raise an artificial exception
|
||||
errno.ENOTCONN, False,
|
||||
id='real socket.error(ENOTCONN)',
|
||||
marks=pytest.mark.xfail(
|
||||
IS_WINDOWS,
|
||||
reason='Now reproducible this way on Windows',
|
||||
),
|
||||
),
|
||||
pytest.param(
|
||||
socket.error, errno.ESHUTDOWN, False,
|
||||
id='socket.error(ESHUTDOWN)',
|
||||
),
|
||||
pytest.param(RuntimeError, 666, True, id='RuntimeError(666)'),
|
||||
pytest.param(socket.error, -1, True, id='socket.error(-1)'),
|
||||
) + (
|
||||
() if six.PY2 else (
|
||||
pytest.param(
|
||||
ConnectionResetError, errno.ECONNRESET, False,
|
||||
id='ConnectionResetError(ECONNRESET)',
|
||||
),
|
||||
pytest.param(
|
||||
BrokenPipeError, errno.EPIPE, False,
|
||||
id='BrokenPipeError(EPIPE)',
|
||||
),
|
||||
pytest.param(
|
||||
BrokenPipeError, errno.ESHUTDOWN, False,
|
||||
id='BrokenPipeError(ESHUTDOWN)',
|
||||
),
|
||||
)
|
||||
),
|
||||
)
|
||||
def test_broken_connection_during_tcp_fin(
|
||||
error_number, exception_leaks,
|
||||
mocker, monkeypatch,
|
||||
simulated_exception, test_client,
|
||||
):
|
||||
"""Test there's no traceback on broken connection during close.
|
||||
|
||||
It artificially causes :py:data:`~errno.ECONNRESET` /
|
||||
:py:data:`~errno.EPIPE` / :py:data:`~errno.ESHUTDOWN` /
|
||||
:py:data:`~errno.ENOTCONN` as well as unrelated :py:exc:`RuntimeError`
|
||||
and :py:exc:`socket.error(-1) <socket.error>` on the server socket when
|
||||
:py:meth:`socket.shutdown() <socket.socket.shutdown>` is called. It's
|
||||
triggered by closing the client socket before the server had a chance
|
||||
to respond.
|
||||
|
||||
The expectation is that only :py:exc:`RuntimeError` and a
|
||||
:py:exc:`socket.error` with an unusual error code would leak.
|
||||
|
||||
With the :py:data:`None`-parameter, a real non-simulated
|
||||
:py:exc:`OSError(107, 'Transport endpoint is not connected')
|
||||
<OSError>` happens.
|
||||
"""
|
||||
exc_instance = (
|
||||
None if simulated_exception is None
|
||||
else simulated_exception(error_number, 'Simulated socket error')
|
||||
)
|
||||
old_close_kernel_socket = (
|
||||
test_client.server_instance.
|
||||
ConnectionClass._close_kernel_socket
|
||||
)
|
||||
|
||||
def _close_kernel_socket(self):
|
||||
monkeypatch.setattr( # `socket.shutdown` is read-only otherwise
|
||||
self, 'socket',
|
||||
mocker.mock_module.Mock(wraps=self.socket),
|
||||
)
|
||||
if exc_instance is not None:
|
||||
monkeypatch.setattr(
|
||||
self.socket, 'shutdown',
|
||||
mocker.mock_module.Mock(side_effect=exc_instance),
|
||||
)
|
||||
_close_kernel_socket.fin_spy = mocker.spy(self.socket, 'shutdown')
|
||||
_close_kernel_socket.exception_leaked = True
|
||||
old_close_kernel_socket(self)
|
||||
_close_kernel_socket.exception_leaked = False
|
||||
|
||||
monkeypatch.setattr(
|
||||
test_client.server_instance.ConnectionClass,
|
||||
'_close_kernel_socket',
|
||||
_close_kernel_socket,
|
||||
)
|
||||
|
||||
conn = test_client.get_connection()
|
||||
conn.auto_open = False
|
||||
conn.connect()
|
||||
conn.send(b'GET /hello HTTP/1.1')
|
||||
conn.send(('Host: %s' % conn.host).encode('ascii'))
|
||||
conn.close()
|
||||
|
||||
for _ in range(10): # Let the server attempt TCP shutdown
|
||||
time.sleep(0.1)
|
||||
if hasattr(_close_kernel_socket, 'exception_leaked'):
|
||||
break
|
||||
|
||||
if exc_instance is not None: # simulated by us
|
||||
assert _close_kernel_socket.fin_spy.spy_exception is exc_instance
|
||||
else: # real
|
||||
assert isinstance(
|
||||
_close_kernel_socket.fin_spy.spy_exception, socket.error,
|
||||
)
|
||||
assert _close_kernel_socket.fin_spy.spy_exception.errno == error_number
|
||||
|
||||
assert _close_kernel_socket.exception_leaked is exception_leaks
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'timeout_before_headers',
|
||||
(
|
||||
|
@ -475,7 +694,7 @@ def test_keepalive_conn_management(test_client):
|
|||
def test_HTTP11_Timeout(test_client, timeout_before_headers):
|
||||
"""Check timeout without sending any data.
|
||||
|
||||
The server will close the conn with a 408.
|
||||
The server will close the connection with a 408.
|
||||
"""
|
||||
conn = test_client.get_connection()
|
||||
conn.auto_open = False
|
||||
|
@ -594,7 +813,7 @@ def test_HTTP11_Timeout_after_request(test_client):
|
|||
def test_HTTP11_pipelining(test_client):
|
||||
"""Test HTTP/1.1 pipelining.
|
||||
|
||||
httplib doesn't support this directly.
|
||||
:py:mod:`http.client` doesn't support this directly.
|
||||
"""
|
||||
conn = test_client.get_connection()
|
||||
|
||||
|
@ -639,7 +858,7 @@ def test_100_Continue(test_client):
|
|||
conn = test_client.get_connection()
|
||||
|
||||
# Try a page without an Expect request header first.
|
||||
# Note that httplib's response.begin automatically ignores
|
||||
# Note that http.client's response.begin automatically ignores
|
||||
# 100 Continue responses, so we must manually check for it.
|
||||
conn.putrequest('POST', '/upload', skip_host=True)
|
||||
conn.putheader('Host', conn.host)
|
||||
|
@ -800,11 +1019,13 @@ def test_No_Message_Body(test_client):
|
|||
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason='Server does not correctly read trailers/ending of the previous '
|
||||
'HTTP request, thus the second request fails as the server tries '
|
||||
r"to parse b'Content-Type: application/json\r\n' as a "
|
||||
'Request-Line. This results in HTTP status code 400, instead of 413'
|
||||
'Ref: https://github.com/cherrypy/cheroot/issues/69',
|
||||
reason=unwrap(
|
||||
trim("""
|
||||
Headers from earlier request leak into the request
|
||||
line for a subsequent request, resulting in 400
|
||||
instead of 413. See cherrypy/cheroot#69 for details.
|
||||
"""),
|
||||
),
|
||||
)
|
||||
def test_Chunked_Encoding(test_client):
|
||||
"""Test HTTP uploads with chunked transfer-encoding."""
|
||||
|
@ -837,7 +1058,7 @@ def test_Chunked_Encoding(test_client):
|
|||
|
||||
# Try a chunked request that exceeds server.max_request_body_size.
|
||||
# Note that the delimiters and trailer are included.
|
||||
body = b'3e3\r\n' + (b'x' * 995) + b'\r\n0\r\n\r\n'
|
||||
body = b'\r\n'.join((b'3e3', b'x' * 995, b'0', b'', b''))
|
||||
conn.putrequest('POST', '/upload', skip_host=True)
|
||||
conn.putheader('Host', conn.host)
|
||||
conn.putheader('Transfer-Encoding', 'chunked')
|
||||
|
@ -895,7 +1116,7 @@ def test_Content_Length_not_int(test_client):
|
|||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'uri,expected_resp_status,expected_resp_body',
|
||||
('uri', 'expected_resp_status', 'expected_resp_body'),
|
||||
(
|
||||
(
|
||||
'/wrong_cl_buffered', 500,
|
||||
|
@ -929,6 +1150,16 @@ def test_Content_Length_out(
|
|||
|
||||
conn.close()
|
||||
|
||||
# the server logs the exception that we had verified from the
|
||||
# client perspective. Tell the error_log verification that
|
||||
# it can ignore that message.
|
||||
test_client.server_instance.error_log.ignored_msgs.extend((
|
||||
# Python 3.7+:
|
||||
"ValueError('Response body exceeds the declared Content-Length.')",
|
||||
# Python 2.7-3.6 (macOS?):
|
||||
"ValueError('Response body exceeds the declared Content-Length.',)",
|
||||
))
|
||||
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason='Sometimes this test fails due to low timeout. '
|
||||
|
@ -970,11 +1201,94 @@ def test_No_CRLF(test_client, invalid_terminator):
|
|||
# Initialize a persistent HTTP connection
|
||||
conn = test_client.get_connection()
|
||||
|
||||
# (b'%s' % b'') is not supported in Python 3.4, so just use +
|
||||
conn.send(b'GET /hello HTTP/1.1' + invalid_terminator)
|
||||
# (b'%s' % b'') is not supported in Python 3.4, so just use bytes.join()
|
||||
conn.send(b''.join((b'GET /hello HTTP/1.1', invalid_terminator)))
|
||||
response = conn.response_class(conn.sock, method='GET')
|
||||
response.begin()
|
||||
actual_resp_body = response.read()
|
||||
expected_resp_body = b'HTTP requires CRLF terminators'
|
||||
assert actual_resp_body == expected_resp_body
|
||||
conn.close()
|
||||
|
||||
|
||||
class FaultySelect:
|
||||
"""Mock class to insert errors in the selector.select method."""
|
||||
|
||||
def __init__(self, original_select):
|
||||
"""Initilize helper class to wrap the selector.select method."""
|
||||
self.original_select = original_select
|
||||
self.request_served = False
|
||||
self.os_error_triggered = False
|
||||
|
||||
def __call__(self, timeout):
|
||||
"""Intercept the calls to selector.select."""
|
||||
if self.request_served:
|
||||
self.os_error_triggered = True
|
||||
raise OSError('Error while selecting the client socket.')
|
||||
|
||||
return self.original_select(timeout)
|
||||
|
||||
|
||||
class FaultyGetMap:
|
||||
"""Mock class to insert errors in the selector.get_map method."""
|
||||
|
||||
def __init__(self, original_get_map):
|
||||
"""Initilize helper class to wrap the selector.get_map method."""
|
||||
self.original_get_map = original_get_map
|
||||
self.sabotage_conn = False
|
||||
self.conn_closed = False
|
||||
|
||||
def __call__(self):
|
||||
"""Intercept the calls to selector.get_map."""
|
||||
sabotage_targets = (
|
||||
conn for _, (_, _, _, conn) in self.original_get_map().items()
|
||||
if isinstance(conn, cheroot.server.HTTPConnection)
|
||||
) if self.sabotage_conn and not self.conn_closed else ()
|
||||
|
||||
for conn in sabotage_targets:
|
||||
# close the socket to cause OSError
|
||||
conn.close()
|
||||
self.conn_closed = True
|
||||
|
||||
return self.original_get_map()
|
||||
|
||||
|
||||
def test_invalid_selected_connection(test_client, monkeypatch):
|
||||
"""Test the error handling segment of HTTP connection selection.
|
||||
|
||||
See :py:meth:`cheroot.connections.ConnectionManager.get_conn`.
|
||||
"""
|
||||
# patch the select method
|
||||
faux_select = FaultySelect(
|
||||
test_client.server_instance._connections._selector.select,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
test_client.server_instance._connections._selector,
|
||||
'select',
|
||||
faux_select,
|
||||
)
|
||||
|
||||
# patch the get_map method
|
||||
faux_get_map = FaultyGetMap(
|
||||
test_client.server_instance._connections._selector._selector.get_map,
|
||||
)
|
||||
|
||||
monkeypatch.setattr(
|
||||
test_client.server_instance._connections._selector._selector,
|
||||
'get_map',
|
||||
faux_get_map,
|
||||
)
|
||||
|
||||
# request a page with connection keep-alive to make sure
|
||||
# we'll have a connection to be modified.
|
||||
resp_status, resp_headers, resp_body = test_client.request(
|
||||
'/page1', headers=[('Connection', 'Keep-Alive')],
|
||||
)
|
||||
|
||||
assert resp_status == '200 OK'
|
||||
# trigger the internal errors
|
||||
faux_get_map.sabotage_conn = faux_select.request_served = True
|
||||
# give time to make sure the error gets handled
|
||||
time.sleep(test_client.server_instance.expiration_interval * 2)
|
||||
assert faux_select.os_error_triggered
|
||||
assert faux_get_map.conn_closed
|
||||
|
|
|
@ -18,6 +18,7 @@ from cheroot.test import helper
|
|||
HTTP_BAD_REQUEST = 400
|
||||
HTTP_LENGTH_REQUIRED = 411
|
||||
HTTP_NOT_FOUND = 404
|
||||
HTTP_REQUEST_ENTITY_TOO_LARGE = 413
|
||||
HTTP_OK = 200
|
||||
HTTP_VERSION_NOT_SUPPORTED = 505
|
||||
|
||||
|
@ -78,7 +79,7 @@ def _get_http_response(connection, method='GET'):
|
|||
|
||||
@pytest.fixture
|
||||
def testing_server(wsgi_server_client):
|
||||
"""Attach a WSGI app to the given server and pre-configure it."""
|
||||
"""Attach a WSGI app to the given server and preconfigure it."""
|
||||
wsgi_server = wsgi_server_client.server_instance
|
||||
wsgi_server.wsgi_app = HelloController()
|
||||
wsgi_server.max_request_body_size = 30000000
|
||||
|
@ -92,6 +93,21 @@ def test_client(testing_server):
|
|||
return testing_server.server_client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def testing_server_with_defaults(wsgi_server_client):
|
||||
"""Attach a WSGI app to the given server and preconfigure it."""
|
||||
wsgi_server = wsgi_server_client.server_instance
|
||||
wsgi_server.wsgi_app = HelloController()
|
||||
wsgi_server.server_client = wsgi_server_client
|
||||
return wsgi_server
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_client_with_defaults(testing_server_with_defaults):
|
||||
"""Get and return a test client out of the given server."""
|
||||
return testing_server_with_defaults.server_client
|
||||
|
||||
|
||||
def test_http_connect_request(test_client):
|
||||
"""Check that CONNECT query results in Method Not Allowed status."""
|
||||
status_line = test_client.connect('/anything')[0]
|
||||
|
@ -262,8 +278,30 @@ def test_content_length_required(test_client):
|
|||
assert actual_status == HTTP_LENGTH_REQUIRED
|
||||
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason='https://github.com/cherrypy/cheroot/issues/106',
|
||||
strict=False, # sometimes it passes
|
||||
)
|
||||
def test_large_request(test_client_with_defaults):
|
||||
"""Test GET query with maliciously large Content-Length."""
|
||||
# If the server's max_request_body_size is not set (i.e. is set to 0)
|
||||
# then this will result in an `OverflowError: Python int too large to
|
||||
# convert to C ssize_t` in the server.
|
||||
# We expect that this should instead return that the request is too
|
||||
# large.
|
||||
c = test_client_with_defaults.get_connection()
|
||||
c.putrequest('GET', '/hello')
|
||||
c.putheader('Content-Length', str(2**64))
|
||||
c.endheaders()
|
||||
|
||||
response = c.getresponse()
|
||||
actual_status = response.status
|
||||
|
||||
assert actual_status == HTTP_REQUEST_ENTITY_TOO_LARGE
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'request_line,status_code,expected_body',
|
||||
('request_line', 'status_code', 'expected_body'),
|
||||
(
|
||||
(
|
||||
b'GET /', # missing proto
|
||||
|
@ -401,7 +439,7 @@ class CloseResponse:
|
|||
|
||||
@pytest.fixture
|
||||
def testing_server_close(wsgi_server_client):
|
||||
"""Attach a WSGI app to the given server and pre-configure it."""
|
||||
"""Attach a WSGI app to the given server and preconfigure it."""
|
||||
wsgi_server = wsgi_server_client.server_instance
|
||||
wsgi_server.wsgi_app = CloseController()
|
||||
wsgi_server.max_request_body_size = 30000000
|
||||
|
|
|
@@ -8,7 +8,7 @@ from cheroot.wsgi import PathInfoDispatcher


def wsgi_invoke(app, environ):
    """Serve 1 requeset from a WSGI application."""
    """Serve 1 request from a WSGI application."""
    response = {}

    def start_response(status, headers):

@@ -25,7 +25,7 @@ def wsgi_invoke(app, environ):


def test_dispatch_no_script_name():
    """Despatch despite lack of SCRIPT_NAME in environ."""
    """Dispatch despite lack of ``SCRIPT_NAME`` in environ."""
    # Bare bones WSGI hello world app (from PEP 333).
    def app(environ, start_response):
        start_response(
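Aside (illustrative, not part of the commit): wsgi_invoke above follows the standard pattern of calling a WSGI app in-process with a hand-rolled start_response. A self-contained sketch of that pattern, with hypothetical names, looks like this:

# Sketch only: invoke a WSGI app directly and capture status, headers, body.
def invoke(app, environ):
    captured = {}

    def start_response(status, headers):
        captured['status'] = status
        captured['headers'] = headers

    captured['body'] = b''.join(app(environ, start_response))
    return captured


def hello_app(environ, start_response):  # minimal PEP 3333-style app
    start_response('200 OK', [('Content-Type', 'text/plain')])
    return [b'Hello world!']


environ = {'REQUEST_METHOD': 'GET', 'PATH_INFO': '/'}  # minimal environ for illustration
print(invoke(hello_app, environ))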
@@ -8,7 +8,7 @@ from .._compat import IS_LINUX, IS_MACOS, IS_WINDOWS


@pytest.mark.parametrize(
    'err_names,err_nums',
    ('err_names', 'err_nums'),
    (
        (('', 'some-nonsense-name'), []),
        (

@@ -24,7 +24,7 @@ from .._compat import IS_LINUX, IS_MACOS, IS_WINDOWS
    ),
)
def test_plat_specific_errors(err_names, err_nums):
    """Test that plat_specific_errors retrieves correct err num list."""
    """Test that ``plat_specific_errors`` gets correct error numbers list."""
    actual_err_nums = errors.plat_specific_errors(*err_names)
    assert len(actual_err_nums) == len(err_nums)
    assert sorted(actual_err_nums) == sorted(err_nums)
@@ -1,4 +1,4 @@
"""self-explanatory."""
"""Tests for :py:mod:`cheroot.makefile`."""

from cheroot import makefile

@@ -7,14 +7,14 @@ __metaclass__ = type


class MockSocket:
    """Mocks a socket."""
    """A mock socket."""

    def __init__(self):
        """Initialize."""
        """Initialize :py:class:`MockSocket`."""
        self.messages = []

    def recv_into(self, buf):
        """Simulate recv_into for Python 3."""
        """Simulate ``recv_into`` for Python 3."""
        if not self.messages:
            return 0
        msg = self.messages.pop(0)

@@ -23,7 +23,7 @@ class MockSocket:
        return len(msg)

    def recv(self, size):
        """Simulate recv for Python 2."""
        """Simulate ``recv`` for Python 2."""
        try:
            return self.messages.pop(0)
        except IndexError:

@@ -44,7 +44,7 @@ def test_bytes_read():


def test_bytes_written():
    """Writer should capture bytes writtten."""
    """Writer should capture bytes written."""
    sock = MockSocket()
    sock.messages.append(b'foo')
    wfile = makefile.MakeFile(sock, 'w')
@@ -5,10 +5,12 @@
from __future__ import absolute_import, division, print_function
__metaclass__ = type

from contextlib import closing
import os
import socket
import tempfile
import threading
import time
import uuid

import pytest

@@ -16,14 +18,15 @@ import requests
import requests_unixsocket
import six

from six.moves import queue, urllib

from .._compat import bton, ntob
from .._compat import IS_LINUX, IS_MACOS, SYS_PLATFORM
from .._compat import IS_LINUX, IS_MACOS, IS_WINDOWS, SYS_PLATFORM
from ..server import IS_UID_GID_RESOLVABLE, Gateway, HTTPServer
from ..testing import (
    ANY_INTERFACE_IPV4,
    ANY_INTERFACE_IPV6,
    EPHEMERAL_PORT,
    get_server_client,
)


@@ -42,15 +45,8 @@ non_macos_sock_test = pytest.mark.skipif(
@pytest.fixture(params=('abstract', 'file'))
def unix_sock_file(request):
    """Check that bound UNIX socket address is stored in server."""
    if request.param == 'abstract':
        yield request.getfixturevalue('unix_abstract_sock')
        return
    tmp_sock_fh, tmp_sock_fname = tempfile.mkstemp()

    yield tmp_sock_fname

    os.close(tmp_sock_fh)
    os.unlink(tmp_sock_fname)
    name = 'unix_{request.param}_sock'.format(**locals())
    return request.getfixturevalue(name)


@pytest.fixture
@@ -67,6 +63,17 @@ def unix_abstract_sock():
    )).decode()


@pytest.fixture
def unix_file_sock():
    """Yield a unix file socket."""
    tmp_sock_fh, tmp_sock_fname = tempfile.mkstemp()

    yield tmp_sock_fname

    os.close(tmp_sock_fh)
    os.unlink(tmp_sock_fname)


def test_prepare_makes_server_ready():
    """Check that prepare() makes the server ready, and stop() clears it."""
    httpserver = HTTPServer(
@@ -110,6 +117,77 @@ def test_stop_interrupts_serve():
    assert not serve_thread.is_alive()


@pytest.mark.parametrize(
    'exc_cls',
    (
        IOError,
        KeyboardInterrupt,
        OSError,
        RuntimeError,
    ),
)
def test_server_interrupt(exc_cls):
    """Check that assigning interrupt stops the server."""
    interrupt_msg = 'should catch {uuid!s}'.format(uuid=uuid.uuid4())
    raise_marker_sentinel = object()

    httpserver = HTTPServer(
        bind_addr=(ANY_INTERFACE_IPV4, EPHEMERAL_PORT),
        gateway=Gateway,
    )

    result_q = queue.Queue()

    def serve_thread():
        # ensure we catch the exception on the serve() thread
        try:
            httpserver.serve()
        except exc_cls as e:
            if str(e) == interrupt_msg:
                result_q.put(raise_marker_sentinel)

    httpserver.prepare()
    serve_thread = threading.Thread(target=serve_thread)
    serve_thread.start()

    serve_thread.join(0.5)
    assert serve_thread.is_alive()

    # this exception is raised on the serve() thread,
    # not in the calling context.
    httpserver.interrupt = exc_cls(interrupt_msg)

    serve_thread.join(0.5)
    assert not serve_thread.is_alive()
    assert result_q.get_nowait() is raise_marker_sentinel


def test_serving_is_false_and_stop_returns_after_ctrlc():
    """Check that stop() interrupts running of serve()."""
    httpserver = HTTPServer(
        bind_addr=(ANY_INTERFACE_IPV4, EPHEMERAL_PORT),
        gateway=Gateway,
    )

    httpserver.prepare()

    # Simulate a Ctrl-C on the first call to `run`.
    def raise_keyboard_interrupt(*args, **kwargs):
        raise KeyboardInterrupt()

    httpserver._connections._selector.select = raise_keyboard_interrupt

    serve_thread = threading.Thread(target=httpserver.serve)
    serve_thread.start()

    # The thread should exit right away due to the interrupt.
    serve_thread.join(httpserver.expiration_interval * 2)
    assert not serve_thread.is_alive()

    assert not httpserver._connections._serving
    httpserver.stop()


@pytest.mark.parametrize(
    'ip_addr',
    (
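Aside (illustrative, not part of the commit): the behaviour exercised by test_server_interrupt -- assigning an exception instance to the server's interrupt attribute so it is raised on the serve() thread -- can be sketched outside the test suite roughly as follows (the app and address are hypothetical):

# Sketch only: stop a background Cheroot server via its `interrupt` attribute.
import threading

from cheroot import wsgi


def app(environ, start_response):
    start_response('200 OK', [('Content-Type', 'text/plain')])
    return [b'ok']


server = wsgi.Server(('127.0.0.1', 0), app)  # port 0: let the OS pick one
server.prepare()

worker = threading.Thread(target=server.serve)
worker.start()

# ... handle traffic for a while ...

server.interrupt = KeyboardInterrupt('shutting down')  # re-raised on the serve() thread
worker.join()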
@@ -135,7 +213,7 @@ def test_bind_addr_unix(http_server, unix_sock_file):

@unix_only_sock_test
def test_bind_addr_unix_abstract(http_server, unix_abstract_sock):
    """Check that bound UNIX abstract sockaddr is stored in server."""
    """Check that bound UNIX abstract socket address is stored in server."""
    httpserver = http_server.send(unix_abstract_sock)

    assert httpserver.bind_addr == unix_abstract_sock

@@ -167,27 +245,26 @@ class _TestGateway(Gateway):


@pytest.fixture
def peercreds_enabled_server_and_client(http_server, unix_sock_file):
    """Construct a test server with `peercreds_enabled`."""
def peercreds_enabled_server(http_server, unix_sock_file):
    """Construct a test server with ``peercreds_enabled``."""
    httpserver = http_server.send(unix_sock_file)
    httpserver.gateway = _TestGateway
    httpserver.peercreds_enabled = True
    return httpserver, get_server_client(httpserver)
    return httpserver


@unix_only_sock_test
@non_macos_sock_test
def test_peercreds_unix_sock(peercreds_enabled_server_and_client):
    """Check that peercred lookup works when enabled."""
    httpserver, testclient = peercreds_enabled_server_and_client
def test_peercreds_unix_sock(peercreds_enabled_server):
    """Check that ``PEERCRED`` lookup works when enabled."""
    httpserver = peercreds_enabled_server
    bind_addr = httpserver.bind_addr

    if isinstance(bind_addr, six.binary_type):
        bind_addr = bind_addr.decode()

    unix_base_uri = 'http+unix://{}'.format(
        bind_addr.replace('\0', '%00').replace('/', '%2F'),
    )
    quoted = urllib.parse.quote(bind_addr, safe='')
    unix_base_uri = 'http+unix://{quoted}'.format(**locals())

    expected_peercreds = os.getpid(), os.getuid(), os.getgid()
    expected_peercreds = '|'.join(map(str, expected_peercreds))
@@ -208,9 +285,9 @@ def test_peercreds_unix_sock(peercreds_enabled_server_and_client):
)
@unix_only_sock_test
@non_macos_sock_test
def test_peercreds_unix_sock_with_lookup(peercreds_enabled_server_and_client):
    """Check that peercred resolution works when enabled."""
    httpserver, testclient = peercreds_enabled_server_and_client
def test_peercreds_unix_sock_with_lookup(peercreds_enabled_server):
    """Check that ``PEERCRED`` resolution works when enabled."""
    httpserver = peercreds_enabled_server
    httpserver.peercreds_resolve_enabled = True

    bind_addr = httpserver.bind_addr

@@ -218,9 +295,8 @@ def test_peercreds_unix_sock_with_lookup(peercreds_enabled_server_and_client):
    if isinstance(bind_addr, six.binary_type):
        bind_addr = bind_addr.decode()

    unix_base_uri = 'http+unix://{}'.format(
        bind_addr.replace('\0', '%00').replace('/', '%2F'),
    )
    quoted = urllib.parse.quote(bind_addr, safe='')
    unix_base_uri = 'http+unix://{quoted}'.format(**locals())

    import grp
    import pwd
@@ -233,3 +309,111 @@ def test_peercreds_unix_sock_with_lookup(peercreds_enabled_server_and_client):
    peercreds_text_resp = requests.get(unix_base_uri + PEERCRED_TEXTS_URI)
    peercreds_text_resp.raise_for_status()
    assert peercreds_text_resp.text == expected_textcreds


@pytest.mark.skipif(
    IS_WINDOWS,
    reason='This regression test is for a Linux bug, '
    'and the resource module is not available on Windows',
)
@pytest.mark.parametrize(
    'resource_limit',
    (
        1024,
        2048,
    ),
    indirect=('resource_limit',),
)
@pytest.mark.usefixtures('many_open_sockets')
def test_high_number_of_file_descriptors(resource_limit):
    """Test the server does not crash with a high file-descriptor value.

    This test shouldn't cause a server crash when trying to access
    file-descriptor higher than 1024.

    The earlier implementation used to rely on ``select()`` syscall that
    doesn't support file descriptors with numbers higher than 1024.
    """
    # We want to force the server to use a file-descriptor with
    # a number above resource_limit

    # Create our server
    httpserver = HTTPServer(
        bind_addr=(ANY_INTERFACE_IPV4, EPHEMERAL_PORT), gateway=Gateway,
    )

    try:
        # This will trigger a crash if select() is used in the implementation
        with httpserver._run_in_thread():
            # allow server to run long enough to invoke select()
            time.sleep(1.0)
    except:  # noqa: E722
        raise  # only needed for `else` to work
    else:
        # We use closing here for py2-compat
        with closing(socket.socket()) as sock:
            # Check new sockets created are still above our target number
            assert sock.fileno() >= resource_limit
    finally:
        # Stop our server
        httpserver.stop()


if not IS_WINDOWS:
    test_high_number_of_file_descriptors = pytest.mark.forked(
        test_high_number_of_file_descriptors,
    )


@pytest.fixture
def resource_limit(request):
    """Set the resource limit two times bigger then requested."""
    resource = pytest.importorskip(
        'resource',
        reason='The "resource" module is Unix-specific',
    )

    # Get current resource limits to restore them later
    soft_limit, hard_limit = resource.getrlimit(resource.RLIMIT_NOFILE)

    # We have to increase the nofile limit above 1024
    # Otherwise we see a 'Too many files open' error, instead of
    # an error due to the file descriptor number being too high
    resource.setrlimit(
        resource.RLIMIT_NOFILE,
        (request.param * 2, hard_limit),
    )

    try:  # noqa: WPS501
        yield request.param
    finally:
        # Reset the resource limit back to the original soft limit
        resource.setrlimit(resource.RLIMIT_NOFILE, (soft_limit, hard_limit))


@pytest.fixture
def many_open_sockets(resource_limit):
    """Allocate a lot of file descriptors by opening dummy sockets."""
    # Hoard a lot of file descriptors by opening and storing a lot of sockets
    test_sockets = []
    # Open a lot of file descriptors, so the next one the server
    # opens is a high number
    try:
        for i in range(resource_limit):
            sock = socket.socket()
            test_sockets.append(sock)
        # NOTE: We used to interrupt the loop early but this doesn't seem
        # NOTE: to work well in envs with indeterministic runtimes like
        # NOTE: PyPy. It looks like sometimes it frees some file
        # NOTE: descriptors in between running this fixture and the actual
        # NOTE: test code so the early break has been removed to try
        # NOTE: address that. The approach may need to be rethought if the
        # NOTE: issue reoccurs. Another approach may be disabling the GC.
        # Check we opened enough descriptors to reach a high number
        the_highest_fileno = max(sock.fileno() for sock in test_sockets)
        assert the_highest_fileno >= resource_limit
        yield the_highest_fileno
    finally:
        # Close our open resources
        for test_socket in test_sockets:
            test_socket.close()
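Aside (illustrative, not part of the commit): the resource_limit fixture above works by raising the soft RLIMIT_NOFILE ceiling. A standalone sketch of inspecting and raising that limit on a Unix system (the target value is arbitrary) might look like:

# Sketch only: raise the soft open-files limit so high-numbered file
# descriptors can be allocated; Unix-only (`resource` is absent on Windows).
import resource

soft_limit, hard_limit = resource.getrlimit(resource.RLIMIT_NOFILE)

target = 4096  # arbitrary illustrative value
new_soft = target if hard_limit == resource.RLIM_INFINITY else min(target, hard_limit)
resource.setrlimit(resource.RLIMIT_NOFILE, (new_soft, hard_limit))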
@@ -1,4 +1,4 @@
"""Tests for TLS/SSL support."""
"""Tests for TLS support."""
# -*- coding: utf-8 -*-
# vim: set fileencoding=utf-8 :

@@ -6,11 +6,14 @@ from __future__ import absolute_import, division, print_function
__metaclass__ = type

import functools
import json
import os
import ssl
import subprocess
import sys
import threading
import time
import traceback

import OpenSSL.SSL
import pytest

@@ -21,7 +24,7 @@ import trustme
from .._compat import bton, ntob, ntou
from .._compat import IS_ABOVE_OPENSSL10, IS_PYPY
from .._compat import IS_LINUX, IS_MACOS, IS_WINDOWS
from ..server import Gateway, HTTPServer, get_ssl_adapter_class
from ..server import HTTPServer, get_ssl_adapter_class
from ..testing import (
    ANY_INTERFACE_IPV4,
    ANY_INTERFACE_IPV6,
@@ -30,9 +33,17 @@ from ..testing import (
    _get_conn_data,
    _probe_ipv6_sock,
)
from ..wsgi import Gateway_10


IS_GITHUB_ACTIONS_WORKFLOW = bool(os.getenv('GITHUB_WORKFLOW'))
IS_WIN2016 = (
    IS_WINDOWS
    # pylint: disable=unsupported-membership-test
    and b'Microsoft Windows Server 2016 Datacenter' in subprocess.check_output(
        ('systeminfo',),
    )
)
IS_LIBRESSL_BACKEND = ssl.OPENSSL_VERSION.startswith('LibreSSL')
IS_PYOPENSSL_SSL_VERSION_1_0 = (
    OpenSSL.SSL.SSLeay_version(OpenSSL.SSL.SSLEAY_VERSION).

@@ -40,6 +51,7 @@ IS_PYOPENSSL_SSL_VERSION_1_0 = (
)
PY27 = sys.version_info[:2] == (2, 7)
PY34 = sys.version_info[:2] == (3, 4)
PY3 = not six.PY2


_stdlib_to_openssl_verify = {

@@ -71,7 +83,7 @@ missing_ipv6 = pytest.mark.skipif(
)


class HelloWorldGateway(Gateway):
class HelloWorldGateway(Gateway_10):
    """Gateway responding with Hello World to root URI."""

    def respond(self):
@@ -83,11 +95,21 @@ class HelloWorldGateway(Gateway):
            req.ensure_headers_sent()
            req.write(b'Hello world!')
            return
        if req_uri == '/env':
            req.status = b'200 OK'
            req.ensure_headers_sent()
            env = self.get_environ()
            # drop files so that it can be json dumped
            env.pop('wsgi.errors')
            env.pop('wsgi.input')
            print(env)
            req.write(json.dumps(env).encode('utf-8'))
            return
        return super(HelloWorldGateway, self).respond()


def make_tls_http_server(bind_addr, ssl_adapter, request):
    """Create and start an HTTP server bound to bind_addr."""
    """Create and start an HTTP server bound to ``bind_addr``."""
    httpserver = HTTPServer(
        bind_addr=bind_addr,
        gateway=HelloWorldGateway,

@@ -128,7 +150,7 @@ def tls_ca_certificate_pem_path(ca):
def tls_certificate(ca):
    """Provide a leaf certificate via fixture."""
    interface, host, port = _get_conn_data(ANY_INTERFACE_IPV4)
    return ca.issue_server_cert(ntou(interface), )
    return ca.issue_server_cert(ntou(interface))


@pytest.fixture
@@ -145,6 +167,43 @@ def tls_certificate_private_key_pem_path(tls_certificate):
        yield cert_key_pem


def _thread_except_hook(exceptions, args):
    """Append uncaught exception ``args`` in threads to ``exceptions``."""
    if issubclass(args.exc_type, SystemExit):
        return
    # cannot store the exception, it references the thread's stack
    exceptions.append((
        args.exc_type,
        str(args.exc_value),
        ''.join(
            traceback.format_exception(
                args.exc_type, args.exc_value, args.exc_traceback,
            ),
        ),
    ))


@pytest.fixture
def thread_exceptions():
    """Provide a list of uncaught exceptions from threads via a fixture.

    Only catches exceptions on Python 3.8+.
    The list contains: ``(type, str(value), str(traceback))``
    """
    exceptions = []
    # Python 3.8+
    orig_hook = getattr(threading, 'excepthook', None)
    if orig_hook is not None:
        threading.excepthook = functools.partial(
            _thread_except_hook, exceptions,
        )
    try:
        yield exceptions
    finally:
        if orig_hook is not None:
            threading.excepthook = orig_hook
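Aside (illustrative, not part of the commit): the fixture above relies on threading.excepthook, which exists on Python 3.8+. A minimal standalone sketch of collecting uncaught exceptions from worker threads with that hook:

# Sketch only: record uncaught exceptions raised in threads (Python 3.8+).
import threading

caught = []


def record(args):
    # args carries exc_type, exc_value, exc_traceback and the thread object
    caught.append((args.exc_type, str(args.exc_value)))


threading.excepthook = record

worker = threading.Thread(target=lambda: 1 / 0)
worker.start()
worker.join()

print(caught)  # [(<class 'ZeroDivisionError'>, 'division by zero')]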
@pytest.mark.parametrize(
    'adapter_type',
    (

@@ -180,7 +239,7 @@ def test_ssl_adapters(
    )

    resp = requests.get(
        'https://' + interface + ':' + str(port) + '/',
        'https://{host!s}:{port!s}/'.format(host=interface, port=port),
        verify=tls_ca_certificate_pem_path,
    )

@@ -188,7 +247,7 @@ def test_ssl_adapters(
    assert resp.text == 'Hello world!'


@pytest.mark.parametrize(
@pytest.mark.parametrize(  # noqa: C901  # FIXME
    'adapter_type',
    (
        'builtin',

@@ -196,7 +255,7 @@ def test_ssl_adapters(
    ),
)
@pytest.mark.parametrize(
    'is_trusted_cert,tls_client_identity',
    ('is_trusted_cert', 'tls_client_identity'),
    (
        (True, 'localhost'), (True, '127.0.0.1'),
        (True, '*.localhost'), (True, 'not_localhost'),

@@ -211,7 +270,7 @@ def test_ssl_adapters(
        ssl.CERT_REQUIRED,  # server should validate if client cert CA is OK
    ),
)
def test_tls_client_auth(
def test_tls_client_auth(  # noqa: C901  # FIXME
    # FIXME: remove twisted logic, separate tests
    mocker,
    tls_http_server, adapter_type,

@@ -265,7 +324,7 @@ def test_tls_client_auth(

    make_https_request = functools.partial(
        requests.get,
        'https://' + interface + ':' + str(port) + '/',
        'https://{host!s}:{port!s}/'.format(host=interface, port=port),

        # Server TLS certificate verification:
        verify=tls_ca_certificate_pem_path,

@@ -324,21 +383,23 @@ def test_tls_client_auth(
    except AttributeError:
        if PY34:
            pytest.xfail('OpenSSL behaves wierdly under Python 3.4')
        elif not six.PY2 and IS_WINDOWS:
        elif IS_WINDOWS or IS_GITHUB_ACTIONS_WORKFLOW:
            err_text = str(ssl_err.value)
        else:
            raise

    if isinstance(err_text, int):
        err_text = str(ssl_err.value)

    expected_substrings = (
        'sslv3 alert bad certificate' if IS_LIBRESSL_BACKEND
        else 'tlsv1 alert unknown ca',
    )
    if not six.PY2:
        if IS_MACOS and IS_PYPY and adapter_type == 'pyopenssl':
            expected_substrings = ('tlsv1 alert unknown ca', )
            expected_substrings = ('tlsv1 alert unknown ca',)
        if (
            IS_WINDOWS
            and tls_verify_mode in (
            tls_verify_mode in (
                ssl.CERT_REQUIRED,
                ssl.CERT_OPTIONAL,
            )
@@ -350,10 +411,172 @@ def test_tls_client_auth(
            "SysCallError(10054, 'WSAECONNRESET')",
            "('Connection aborted.', "
            'OSError("(10054, \'WSAECONNRESET\')"))',
            "('Connection aborted.', "
            'OSError("(10054, \'WSAECONNRESET\')",))',
            "('Connection aborted.', "
            'error("(10054, \'WSAECONNRESET\')",))',
            "('Connection aborted.', "
            'ConnectionResetError(10054, '
            "'An existing connection was forcibly closed "
            "by the remote host', None, 10054, None))",
        ) if IS_WINDOWS else (
            "('Connection aborted.', "
            'OSError("(104, \'ECONNRESET\')"))',
            "('Connection aborted.', "
            'OSError("(104, \'ECONNRESET\')",))',
            "('Connection aborted.', "
            'error("(104, \'ECONNRESET\')",))',
            "('Connection aborted.', "
            "ConnectionResetError(104, 'Connection reset by peer'))",
            "('Connection aborted.', "
            "error(104, 'Connection reset by peer'))",
        ) if (
            IS_GITHUB_ACTIONS_WORKFLOW
            and IS_LINUX
        ) else (
            "('Connection aborted.', "
            "BrokenPipeError(32, 'Broken pipe'))",
        )
        assert any(e in err_text for e in expected_substrings)


@pytest.mark.parametrize(  # noqa: C901  # FIXME
    'adapter_type',
    (
        'builtin',
        'pyopenssl',
    ),
)
@pytest.mark.parametrize(
    ('tls_verify_mode', 'use_client_cert'),
    (
        (ssl.CERT_NONE, False),
        (ssl.CERT_NONE, True),
        (ssl.CERT_OPTIONAL, False),
        (ssl.CERT_OPTIONAL, True),
        (ssl.CERT_REQUIRED, True),
    ),
)
def test_ssl_env(  # noqa: C901  # FIXME
    thread_exceptions,
    recwarn,
    mocker,
    tls_http_server, adapter_type,
    ca, tls_verify_mode, tls_certificate,
    tls_certificate_chain_pem_path,
    tls_certificate_private_key_pem_path,
    tls_ca_certificate_pem_path,
    use_client_cert,
):
    """Test the SSL environment generated by the SSL adapters."""
    interface, _host, port = _get_conn_data(ANY_INTERFACE_IPV4)

    with mocker.mock_module.patch(
        'idna.core.ulabel',
        return_value=ntob('127.0.0.1'),
    ):
        client_cert = ca.issue_cert(ntou('127.0.0.1'))

    with client_cert.private_key_and_cert_chain_pem.tempfile() as cl_pem:
        tls_adapter_cls = get_ssl_adapter_class(name=adapter_type)
        tls_adapter = tls_adapter_cls(
            tls_certificate_chain_pem_path,
            tls_certificate_private_key_pem_path,
        )
        if adapter_type == 'pyopenssl':
            tls_adapter.context = tls_adapter.get_context()
            tls_adapter.context.set_verify(
                _stdlib_to_openssl_verify[tls_verify_mode],
                lambda conn, cert, errno, depth, preverify_ok: preverify_ok,
            )
        else:
            tls_adapter.context.verify_mode = tls_verify_mode

        ca.configure_trust(tls_adapter.context)
        tls_certificate.configure_cert(tls_adapter.context)

        tlswsgiserver = tls_http_server((interface, port), tls_adapter)

        interface, _host, port = _get_conn_data(tlswsgiserver.bind_addr)

        resp = requests.get(
            'https://' + interface + ':' + str(port) + '/env',
            verify=tls_ca_certificate_pem_path,
            cert=cl_pem if use_client_cert else None,
        )
        if PY34 and resp.status_code != 200:
            pytest.xfail(
                'Python 3.4 has problems with verifying client certs',
            )

        env = json.loads(resp.content.decode('utf-8'))

        # hard coded env
        assert env['wsgi.url_scheme'] == 'https'
        assert env['HTTPS'] == 'on'

        # ensure these are present
        for key in {'SSL_VERSION_INTERFACE', 'SSL_VERSION_LIBRARY'}:
            assert key in env

        # pyOpenSSL generates the env before the handshake completes
        if adapter_type == 'pyopenssl':
            return

        for key in {'SSL_PROTOCOL', 'SSL_CIPHER'}:
            assert key in env

        # client certificate env
        if tls_verify_mode == ssl.CERT_NONE or not use_client_cert:
            assert env['SSL_CLIENT_VERIFY'] == 'NONE'
        else:
            assert env['SSL_CLIENT_VERIFY'] == 'SUCCESS'

            with open(cl_pem, 'rt') as f:
                assert env['SSL_CLIENT_CERT'] in f.read()

            for key in {
                'SSL_CLIENT_M_VERSION', 'SSL_CLIENT_M_SERIAL',
                'SSL_CLIENT_I_DN', 'SSL_CLIENT_S_DN',
            }:
                assert key in env

        # builtin ssl environment generation may use a loopback socket
        # ensure no ResourceWarning was raised during the test
        # NOTE: python 2.7 does not emit ResourceWarning for ssl sockets
        if IS_PYPY:
            # NOTE: PyPy doesn't have ResourceWarning
            # Ref: https://doc.pypy.org/en/latest/cpython_differences.html
            return
        for warn in recwarn:
            if not issubclass(warn.category, ResourceWarning):
                continue

            # the tests can sporadically generate resource warnings
            # due to timing issues
            # all of these sporadic warnings appear to be about socket.socket
            # and have been observed to come from requests connection pool
            msg = str(warn.message)
            if 'socket.socket' in msg:
                pytest.xfail(
                    '\n'.join((
                        'Sometimes this test fails due to '
                        'a socket.socket ResourceWarning:',
                        msg,
                    )),
                )
            pytest.fail(msg)

        # to perform the ssl handshake over that loopback socket,
        # the builtin ssl environment generation uses a thread
        for _, _, trace in thread_exceptions:
            print(trace, file=sys.stderr)
        assert not thread_exceptions, ': '.join((
            thread_exceptions[0][0].__name__,
            thread_exceptions[0][1],
        ))


@pytest.mark.parametrize(
    'ip_addr',
    (
@@ -382,7 +605,16 @@ def test_https_over_http_error(http_server, ip_addr):
@pytest.mark.parametrize(
    'adapter_type',
    (
        pytest.param(
            'builtin',
            marks=pytest.mark.xfail(
                IS_WINDOWS and six.PY2,
                raises=requests.exceptions.ConnectionError,
                reason='Stdlib `ssl` module behaves weirdly '
                'on Windows under Python 2',
                strict=False,
            ),
        ),
        'pyopenssl',
    ),
)
@@ -428,16 +660,41 @@ def test_http_over_https_error(

    fqdn = interface
    if ip_addr is ANY_INTERFACE_IPV6:
        fqdn = '[{}]'.format(fqdn)
        fqdn = '[{fqdn}]'.format(**locals())

    expect_fallback_response_over_plain_http = (
        (adapter_type == 'pyopenssl'
         and (IS_ABOVE_OPENSSL10 or not six.PY2))
        or PY27
        (
            adapter_type == 'pyopenssl'
            and (IS_ABOVE_OPENSSL10 or not six.PY2)
        )
        or PY27
    ) or (
        IS_GITHUB_ACTIONS_WORKFLOW
        and IS_WINDOWS
        and six.PY2
        and not IS_WIN2016
    )
    if (
        IS_GITHUB_ACTIONS_WORKFLOW
        and IS_WINDOWS
        and six.PY2
        and IS_WIN2016
        and adapter_type == 'builtin'
        and ip_addr is ANY_INTERFACE_IPV6
    ):
        expect_fallback_response_over_plain_http = True
    if (
        IS_GITHUB_ACTIONS_WORKFLOW
        and IS_WINDOWS
        and six.PY2
        and not IS_WIN2016
        and adapter_type == 'builtin'
        and ip_addr is not ANY_INTERFACE_IPV6
    ):
        expect_fallback_response_over_plain_http = False
    if expect_fallback_response_over_plain_http:
        resp = requests.get(
            'http://' + fqdn + ':' + str(port) + '/',
            'http://{host!s}:{port!s}/'.format(host=fqdn, port=port),
        )
        assert resp.status_code == 400
        assert resp.text == (

@@ -448,7 +705,7 @@ def test_http_over_https_error(

    with pytest.raises(requests.exceptions.ConnectionError) as ssl_err:
        requests.get(  # FIXME: make stdlib ssl behave like PyOpenSSL
            'http://' + fqdn + ':' + str(port) + '/',
            'http://{host!s}:{port!s}/'.format(host=fqdn, port=port),
        )

    if IS_LINUX:

@@ -468,7 +725,7 @@ def test_http_over_https_error(
        underlying_error = ssl_err.value.args[0].args[-1]
        err_text = str(underlying_error)
        assert underlying_error.errno == expected_error_code, (
            'The underlying error is {!r}'.
            format(underlying_error)
            'The underlying error is {underlying_error!r}'.
            format(**locals())
        )
        assert expected_error_text in err_text
lib/cheroot/test/test_wsgi.py (new file, 58 lines)

@@ -0,0 +1,58 @@
"""Test wsgi."""
|
||||
|
||||
from concurrent.futures.thread import ThreadPoolExecutor
|
||||
|
||||
import pytest
|
||||
import portend
|
||||
import requests
|
||||
from requests_toolbelt.sessions import BaseUrlSession as Session
|
||||
from jaraco.context import ExceptionTrap
|
||||
|
||||
from cheroot import wsgi
|
||||
from cheroot._compat import IS_MACOS, IS_WINDOWS
|
||||
|
||||
|
||||
IS_SLOW_ENV = IS_MACOS or IS_WINDOWS
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def simple_wsgi_server():
|
||||
"""Fucking simple wsgi server fixture (duh)."""
|
||||
port = portend.find_available_local_port()
|
||||
|
||||
def app(environ, start_response):
|
||||
status = '200 OK'
|
||||
response_headers = [('Content-type', 'text/plain')]
|
||||
start_response(status, response_headers)
|
||||
return [b'Hello world!']
|
||||
|
||||
host = '::'
|
||||
addr = host, port
|
||||
server = wsgi.Server(addr, app, timeout=600 if IS_SLOW_ENV else 20)
|
||||
url = 'http://localhost:{port}/'.format(**locals())
|
||||
with server._run_in_thread() as thread:
|
||||
yield locals()
|
||||
|
||||
|
||||
def test_connection_keepalive(simple_wsgi_server):
|
||||
"""Test the connection keepalive works (duh)."""
|
||||
session = Session(base_url=simple_wsgi_server['url'])
|
||||
pooled = requests.adapters.HTTPAdapter(
|
||||
pool_connections=1, pool_maxsize=1000,
|
||||
)
|
||||
session.mount('http://', pooled)
|
||||
|
||||
def do_request():
|
||||
with ExceptionTrap(requests.exceptions.ConnectionError) as trap:
|
||||
resp = session.get('info')
|
||||
resp.raise_for_status()
|
||||
return bool(trap)
|
||||
|
||||
with ThreadPoolExecutor(max_workers=10 if IS_SLOW_ENV else 50) as pool:
|
||||
tasks = [
|
||||
pool.submit(do_request)
|
||||
for n in range(250 if IS_SLOW_ENV else 1000)
|
||||
]
|
||||
failures = sum(task.result() for task in tasks)
|
||||
|
||||
assert not failures
|
|
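Aside (illustrative, not part of the commit): the new test drives keep-alive by reusing one pooled requests session across many requests. The same client-side pattern, against any running HTTP server (the URL is hypothetical), can be sketched as:

# Sketch only: reuse pooled TCP connections across many GETs.
import requests

session = requests.Session()
adapter = requests.adapters.HTTPAdapter(pool_connections=1, pool_maxsize=10)
session.mount('http://', adapter)

base_url = 'http://localhost:8080/'  # assumed address of a running server

for _ in range(100):
    resp = session.get(base_url)
    resp.raise_for_status()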
@@ -1,12 +1,14 @@
"""Extensions to unittest for web frameworks.

Use the WebCase.getPage method to request a page from your HTTP server.
Use the :py:meth:`WebCase.getPage` method to request a page
from your HTTP server.

Framework Integration
=====================
If you have control over your server process, you can handle errors
in the server-side of the HTTP conversation a bit better. You must run
both the client (your WebCase tests) and the server in the same process
(but in separate threads, obviously).
both the client (your :py:class:`WebCase` tests) and the server in the
same process (but in separate threads, obviously).
When an error occurs in the framework, call server_error. It will print
the traceback to stdout, and keep any assertions you have from running
(the assumption is that, if the server errors, the page output will not

@@ -122,7 +124,7 @@ class WebCase(unittest.TestCase):
    def _Conn(self):
        """Return HTTPConnection or HTTPSConnection based on self.scheme.

        * from http.client.
        * from :py:mod:`python:http.client`.
        """
        cls_name = '{scheme}Connection'.format(scheme=self.scheme.upper())
        return getattr(http_client, cls_name)

@@ -157,7 +159,7 @@ class WebCase(unittest.TestCase):

    @property
    def persistent(self):
        """Presense of the persistent HTTP connection."""
        """Presence of the persistent HTTP connection."""
        return hasattr(self.HTTP_CONN, '__class__')

    @persistent.setter

@@ -176,7 +178,9 @@ class WebCase(unittest.TestCase):
        self, url, headers=None, method='GET', body=None,
        protocol=None, raise_subcls=(),
    ):
        """Open the url with debugging support. Return status, headers, body.
        """Open the url with debugging support.

        Return status, headers, body.

        url should be the identifier passed to the server, typically a
        server-absolute path and query string (sent between method and

@@ -184,16 +188,16 @@ class WebCase(unittest.TestCase):
        enabled in the server.

        If the application under test generates absolute URIs, be sure
        to wrap them first with strip_netloc::
        to wrap them first with :py:func:`strip_netloc`::

            class MyAppWebCase(WebCase):
                def getPage(url, *args, **kwargs):
                    super(MyAppWebCase, self).getPage(
                        cheroot.test.webtest.strip_netloc(url),
                        *args, **kwargs
                    )
        >>> class MyAppWebCase(WebCase):
        ...     def getPage(url, *args, **kwargs):
        ...         super(MyAppWebCase, self).getPage(
        ...             cheroot.test.webtest.strip_netloc(url),
        ...             *args, **kwargs
        ...         )

        `raise_subcls` is passed through to openURL.
        ``raise_subcls`` is passed through to :py:func:`openURL`.
        """
        ServerError.on = False

@@ -247,7 +251,7 @@ class WebCase(unittest.TestCase):

    console_height = 30

    def _handlewebError(self, msg):
    def _handlewebError(self, msg):  # noqa: C901  # FIXME
        print('')
        print(' ERROR: %s' % msg)

@@ -487,7 +491,7 @@ def openURL(*args, **kwargs):
    """
    Open a URL, retrying when it fails.

    Specify `raise_subcls` (class or tuple of classes) to exclude
    Specify ``raise_subcls`` (class or tuple of classes) to exclude
    those socket.error subclasses from being suppressed and retried.
    """
    raise_subcls = kwargs.pop('raise_subcls', ())

@@ -553,7 +557,7 @@ def strip_netloc(url):
    server-absolute portion.

    Useful for wrapping an absolute-URI for which only the
    path is expected (such as in calls to getPage).
    path is expected (such as in calls to :py:meth:`WebCase.getPage`).

    >>> strip_netloc('https://google.com/foo/bar?bing#baz')
    '/foo/bar?bing'
@@ -61,14 +61,14 @@ def cheroot_server(server_factory):
    httpserver.stop()  # destroy it


@pytest.fixture(scope='module')
@pytest.fixture
def wsgi_server():
    """Set up and tear down a Cheroot WSGI server instance."""
    for srv in cheroot_server(cheroot.wsgi.Server):
        yield srv


@pytest.fixture(scope='module')
@pytest.fixture
def native_server():
    """Set up and tear down a Cheroot HTTP server instance."""
    for srv in cheroot_server(cheroot.server.HTTPServer):
@@ -1,4 +1,9 @@
"""A thread-based worker pool."""
"""A thread-based worker pool.

.. spelling::

   joinable
"""

from __future__ import absolute_import, division, print_function
__metaclass__ = type

@@ -111,11 +116,6 @@ class WorkerThread(threading.Thread):
                if conn is _SHUTDOWNREQUEST:
                    return

                # Just close the connection and move on.
                if conn.closeable:
                    conn.close()
                    continue

                self.conn = conn
                is_stats_enabled = self.server.stats['Enabled']
                if is_stats_enabled:

@@ -125,7 +125,7 @@ class WorkerThread(threading.Thread):
                    keep_conn_open = conn.communicate()
                finally:
                    if keep_conn_open:
                        self.server.connections.put(conn)
                        self.server.put_conn(conn)
                    else:
                        conn.close()
                    if is_stats_enabled:

@@ -176,7 +176,10 @@ class ThreadPool:
        for i in range(self.min):
            self._threads.append(WorkerThread(self.server))
        for worker in self._threads:
            worker.setName('CP Server ' + worker.getName())
            worker.setName(
                'CP Server {worker_name!s}'.
                format(worker_name=worker.getName()),
            )
            worker.start()
        for worker in self._threads:
            while not worker.ready:

@@ -192,7 +195,7 @@ class ThreadPool:
        """Put request into queue.

        Args:
            obj (cheroot.server.HTTPConnection): HTTP connection
            obj (:py:class:`~cheroot.server.HTTPConnection`): HTTP connection
                waiting to be processed
        """
        self._queue.put(obj, block=True, timeout=self._queue_put_timeout)

@@ -223,7 +226,10 @@ class ThreadPool:

    def _spawn_worker(self):
        worker = WorkerThread(self.server)
        worker.setName('CP Server ' + worker.getName())
        worker.setName(
            'CP Server {worker_name!s}'.
            format(worker_name=worker.getName()),
        )
        worker.start()
        return worker
@@ -119,10 +119,7 @@ class Gateway(server.Gateway):
            corresponding class

        """
        return dict(
            (gw.version, gw)
            for gw in cls.__subclasses__()
        )
        return {gw.version: gw for gw in cls.__subclasses__()}

    def get_environ(self):
        """Return a new environ dict targeting the given wsgi.version."""

@@ -195,9 +192,9 @@ class Gateway(server.Gateway):
        """Cast status to bytes representation of current Python version.

        According to :pep:`3333`, when using Python 3, the response status
        and headers must be bytes masquerading as unicode; that is, they
        and headers must be bytes masquerading as Unicode; that is, they
        must be of type "str" but are restricted to code points in the
        "latin-1" set.
        "Latin-1" set.
        """
        if six.PY2:
            return status

@@ -300,7 +297,11 @@ class Gateway_10(Gateway):

        # Request headers
        env.update(
            ('HTTP_' + bton(k).upper().replace('-', '_'), bton(v))
            (
                'HTTP_{header_name!s}'.
                format(header_name=bton(k).upper().replace('-', '_')),
                bton(v),
            )
            for k, v in req.inheaders.items()
        )

@@ -321,7 +322,7 @@ class Gateway_10(Gateway):
class Gateway_u0(Gateway_10):
    """A Gateway class to interface HTTPServer with WSGI u.0.

    WSGI u.0 is an experimental protocol, which uses unicode for keys
    WSGI u.0 is an experimental protocol, which uses Unicode for keys
    and values in both Python 2 and Python 3.
    """

@@ -409,7 +410,7 @@ class PathInfoDispatcher:
        path = environ['PATH_INFO'] or '/'
        for p, app in self.apps:
            # The apps list should be sorted by length, descending.
            if path.startswith(p + '/') or path == p:
            if path.startswith('{path!s}/'.format(path=p)) or path == p:
                environ = environ.copy()
                environ['SCRIPT_NAME'] = environ.get('SCRIPT_NAME', '') + p
                environ['PATH_INFO'] = path[len(p):]
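Aside (illustrative, not part of the commit): PathInfoDispatcher, whose prefix matching is touched above, routes requests to sub-applications by PATH_INFO prefix. A minimal sketch of mounting two hypothetical apps under different prefixes:

# Sketch only: dispatch by path prefix; longer prefixes are listed first
# so '/api' wins over the catch-all root app.
from cheroot import wsgi
from cheroot.wsgi import PathInfoDispatcher


def root_app(environ, start_response):
    start_response('200 OK', [('Content-Type', 'text/plain')])
    return [b'root']


def api_app(environ, start_response):
    start_response('200 OK', [('Content-Type', 'text/plain')])
    return [b'api']


dispatcher = PathInfoDispatcher([('/api', api_app), ('', root_app)])
server = wsgi.Server(('127.0.0.1', 8080), dispatcher)  # address is illustrative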