Update cheroot-8.5.2

This commit is contained in:
JonnyWong16 2021-10-14 21:14:02 -07:00
commit 182e5f553e
No known key found for this signature in database
GPG key ID: B1F1F9807184697A
25 changed files with 2171 additions and 602 deletions

View file

@ -0,0 +1,38 @@
"""Local pytest plugin.
Contains hooks, which are tightly bound to the Cheroot framework
itself, useless for end-users' app testing.
"""
from __future__ import absolute_import, division, print_function
__metaclass__ = type
import pytest
pytest_version = tuple(map(int, pytest.__version__.split('.')))
def pytest_load_initial_conftests(early_config, parser, args):
"""Drop unfilterable warning ignores."""
if pytest_version < (6, 2, 0):
return
# pytest>=6.2.0 under Python 3.8:
# Refs:
# * https://docs.pytest.org/en/stable/usage.html#unraisable
# * https://github.com/pytest-dev/pytest/issues/5299
early_config._inicache['filterwarnings'].extend((
'ignore:Exception in thread CP Server Thread-:'
'pytest.PytestUnhandledThreadExceptionWarning:_pytest.threadexception',
'ignore:Exception in thread Thread-:'
'pytest.PytestUnhandledThreadExceptionWarning:_pytest.threadexception',
'ignore:Exception ignored in. '
'<socket.socket fd=-1, family=AddressFamily.AF_INET, '
'type=SocketKind.SOCK_STREAM, proto=.:'
'pytest.PytestUnraisableExceptionWarning:_pytest.unraisableexception',
'ignore:Exception ignored in. '
'<socket.socket fd=-1, family=AddressFamily.AF_INET6, '
'type=SocketKind.SOCK_STREAM, proto=.:'
'pytest.PytestUnraisableExceptionWarning:_pytest.unraisableexception',
))

View file

@ -55,7 +55,7 @@ def http_server():
def make_http_server(bind_addr):
"""Create and start an HTTP server bound to bind_addr."""
"""Create and start an HTTP server bound to ``bind_addr``."""
httpserver = HTTPServer(
bind_addr=bind_addr,
gateway=Gateway,

View file

@ -99,7 +99,7 @@ class CherootWebCase(webtest.WebCase):
date_tolerance = 2
def assertEqualDates(self, dt1, dt2, seconds=None):
"""Assert abs(dt1 - dt2) is within Y seconds."""
"""Assert ``abs(dt1 - dt2)`` is within ``Y`` seconds."""
if seconds is None:
seconds = self.date_tolerance
@ -108,8 +108,10 @@ class CherootWebCase(webtest.WebCase):
else:
diff = dt2 - dt1
if not diff < datetime.timedelta(seconds=seconds):
raise AssertionError('%r and %r are not within %r seconds.' %
(dt1, dt2, seconds))
raise AssertionError(
'%r and %r are not within %r seconds.' %
(dt1, dt2, seconds),
)
class Request:
@ -155,9 +157,13 @@ class Controller:
resp.status = '404 Not Found'
else:
output = handler(req, resp)
if (output is not None
and not any(resp.status.startswith(status_code)
for status_code in ('204', '304'))):
if (
output is not None
and not any(
resp.status.startswith(status_code)
for status_code in ('204', '304')
)
):
resp.body = output
try:
resp.headers.setdefault('Content-Length', str(len(output)))

View file

@ -11,45 +11,45 @@ from cheroot._compat import extract_bytes, memoryview, ntob, ntou, bton
@pytest.mark.parametrize(
'func,inp,out',
[
('func', 'inp', 'out'),
(
(ntob, 'bar', b'bar'),
(ntou, 'bar', u'bar'),
(bton, b'bar', 'bar'),
],
),
)
def test_compat_functions_positive(func, inp, out):
"""Check that compat functions work with correct input."""
"""Check that compatibility functions work with correct input."""
assert func(inp, encoding='utf-8') == out
@pytest.mark.parametrize(
'func',
[
(
ntob,
ntou,
],
),
)
def test_compat_functions_negative_nonnative(func):
"""Check that compat functions fail loudly for incorrect input."""
"""Check that compatibility functions fail loudly for incorrect input."""
non_native_test_str = u'bar' if six.PY2 else b'bar'
with pytest.raises(TypeError):
func(non_native_test_str, encoding='utf-8')
def test_ntou_escape():
"""Check that ntou supports escape-encoding under Python 2."""
"""Check that ``ntou`` supports escape-encoding under Python 2."""
expected = u'hišřії'
actual = ntou('hi\u0161\u0159\u0456\u0457', encoding='escape')
assert actual == expected
@pytest.mark.parametrize(
'input_argument,expected_result',
[
('input_argument', 'expected_result'),
(
(b'qwerty', b'qwerty'),
(memoryview(b'asdfgh'), b'asdfgh'),
],
),
)
def test_extract_bytes(input_argument, expected_result):
"""Check that legitimate inputs produce bytes."""
@ -58,5 +58,9 @@ def test_extract_bytes(input_argument, expected_result):
def test_extract_bytes_invalid():
"""Ensure that invalid input causes exception to be raised."""
with pytest.raises(ValueError):
with pytest.raises(
ValueError,
match=r'^extract_bytes\(\) only accepts bytes '
'and memoryview/buffer$',
):
extract_bytes(u'some юнікод їїї')

View file

@ -0,0 +1,97 @@
"""Tests to verify the command line interface.
.. spelling::
cli
"""
# -*- coding: utf-8 -*-
# vim: set fileencoding=utf-8 :
import sys
import six
import pytest
from cheroot.cli import (
Application,
parse_wsgi_bind_addr,
)
@pytest.mark.parametrize(
('raw_bind_addr', 'expected_bind_addr'),
(
# tcp/ip
('192.168.1.1:80', ('192.168.1.1', 80)),
# ipv6 ips has to be enclosed in brakets when specified in url form
('[::1]:8000', ('::1', 8000)),
('localhost:5000', ('localhost', 5000)),
# this is a valid input, but foo gets discarted
('foo@bar:5000', ('bar', 5000)),
('foo', ('foo', None)),
('123456789', ('123456789', None)),
# unix sockets
('/tmp/cheroot.sock', '/tmp/cheroot.sock'),
('/tmp/some-random-file-name', '/tmp/some-random-file-name'),
# abstract sockets
('@cheroot', '\x00cheroot'),
),
)
def test_parse_wsgi_bind_addr(raw_bind_addr, expected_bind_addr):
"""Check the parsing of the --bind option.
Verify some of the supported addresses and the expected return value.
"""
assert parse_wsgi_bind_addr(raw_bind_addr) == expected_bind_addr
@pytest.fixture
def wsgi_app(monkeypatch):
"""Return a WSGI app stub."""
class WSGIAppMock:
"""Mock of a wsgi module."""
def application(self):
"""Empty application method.
Default method to be called when no specific callable
is defined in the wsgi application identifier.
It has an empty body because we are expecting to verify that
the same method is return no the actual execution of it.
"""
def main(self):
"""Empty custom method (callable) inside the mocked WSGI app.
It has an empty body because we are expecting to verify that
the same method is return no the actual execution of it.
"""
app = WSGIAppMock()
# patch sys.modules, to include the an instance of WSGIAppMock
# under a specific namespace
if six.PY2:
# python2 requires the previous namespaces to be part of sys.modules
# (e.g. for 'a.b.c' we need to insert 'a', 'a.b' and 'a.b.c')
# otherwise it fails, we're setting the same instance on each level,
# we don't really care about those, just the last one.
monkeypatch.setitem(sys.modules, 'mypkg', app)
monkeypatch.setitem(sys.modules, 'mypkg.wsgi', app)
return app
@pytest.mark.parametrize(
('app_name', 'app_method'),
(
(None, 'application'),
('application', 'application'),
('main', 'main'),
),
)
def test_Aplication_resolve(app_name, app_method, wsgi_app):
"""Check the wsgi application name conversion."""
if app_name is None:
wsgi_app_spec = 'mypkg.wsgi'
else:
wsgi_app_spec = 'mypkg.wsgi:{app_name}'.format(**locals())
expected_app = getattr(wsgi_app, app_method)
assert Application.resolve(wsgi_app_spec).wsgi_app == expected_app

View file

@ -3,15 +3,22 @@
from __future__ import absolute_import, division, print_function
__metaclass__ = type
import errno
import socket
import time
import logging
import traceback as traceback_
from collections import namedtuple
from six.moves import range, http_client, urllib
import six
import pytest
from jaraco.text import trim, unwrap
from cheroot.test import helper, webtest
from cheroot._compat import IS_CI, IS_PYPY, IS_WINDOWS
import cheroot.server
timeout = 1
@ -26,7 +33,7 @@ class Controller(helper.Controller):
return 'Hello, world!'
def pov(req, resp):
"""Render pov value."""
"""Render ``pov`` value."""
return pov
def stream(req, resp):
@ -43,8 +50,10 @@ class Controller(helper.Controller):
def upload(req, resp):
"""Process file upload and render thank."""
if not req.environ['REQUEST_METHOD'] == 'POST':
raise AssertionError("'POST' != request.method %r" %
req.environ['REQUEST_METHOD'])
raise AssertionError(
"'POST' != request.method %r" %
req.environ['REQUEST_METHOD'],
)
return "thanks for '%s'" % req.environ['wsgi.input'].read()
def custom_204(req, resp):
@ -103,9 +112,33 @@ class Controller(helper.Controller):
}
class ErrorLogMonitor:
"""Mock class to access the server error_log calls made by the server."""
ErrorLogCall = namedtuple('ErrorLogCall', ['msg', 'level', 'traceback'])
def __init__(self):
"""Initialize the server error log monitor/interceptor.
If you need to ignore a particular error message use the property
``ignored_msgs`` by appending to the list the expected error messages.
"""
self.calls = []
# to be used the the teardown validation
self.ignored_msgs = []
def __call__(self, msg='', level=logging.INFO, traceback=False):
"""Intercept the call to the server error_log method."""
if traceback:
tblines = traceback_.format_exc()
else:
tblines = ''
self.calls.append(ErrorLogMonitor.ErrorLogCall(msg, level, tblines))
@pytest.fixture
def testing_server(wsgi_server_client):
"""Attach a WSGI app to the given server and pre-configure it."""
def raw_testing_server(wsgi_server_client):
"""Attach a WSGI app to the given server and preconfigure it."""
app = Controller()
def _timeout(req, resp):
@ -117,9 +150,36 @@ def testing_server(wsgi_server_client):
wsgi_server.timeout = timeout
wsgi_server.server_client = wsgi_server_client
wsgi_server.keep_alive_conn_limit = 2
return wsgi_server
@pytest.fixture
def testing_server(raw_testing_server, monkeypatch):
"""Modify the "raw" base server to monitor the error_log messages.
If you need to ignore a particular error message use the property
``testing_server.error_log.ignored_msgs`` by appending to the list
the expected error messages.
"""
# patch the error_log calls of the server instance
monkeypatch.setattr(raw_testing_server, 'error_log', ErrorLogMonitor())
yield raw_testing_server
# Teardown verification, in case that the server logged an
# error that wasn't notified to the client or we just made a mistake.
for c_msg, c_level, c_traceback in raw_testing_server.error_log.calls:
if c_level <= logging.WARNING:
continue
assert c_msg in raw_testing_server.error_log.ignored_msgs, (
'Found error in the error log: '
"message = '{c_msg}', level = '{c_level}'\n"
'{c_traceback}'.format(**locals()),
)
@pytest.fixture
def test_client(testing_server):
"""Get and return a test client out of the given server."""
@ -338,7 +398,14 @@ def test_streaming_10(test_client, set_cl):
'http_server_protocol',
(
'HTTP/1.0',
'HTTP/1.1',
pytest.param(
'HTTP/1.1',
marks=pytest.mark.xfail(
IS_PYPY and IS_CI,
reason='Fails under PyPy in CI for unknown reason',
strict=False,
),
),
),
)
def test_keepalive(test_client, http_server_protocol):
@ -375,6 +442,11 @@ def test_keepalive(test_client, http_server_protocol):
assert status_line[4:] == 'OK'
assert actual_resp_body == pov.encode()
assert header_has_value('Connection', 'Keep-Alive', actual_headers)
assert header_has_value(
'Keep-Alive',
'timeout={test_client.server_instance.timeout}'.format(**locals()),
actual_headers,
)
# Remove the keep-alive header again.
status_line, actual_headers, actual_resp_body = test_client.get(
@ -386,6 +458,7 @@ def test_keepalive(test_client, http_server_protocol):
assert status_line[4:] == 'OK'
assert actual_resp_body == pov.encode()
assert not header_exists('Connection', actual_headers)
assert not header_exists('Keep-Alive', actual_headers)
test_client.server_instance.protocol = original_server_protocol
@ -401,7 +474,7 @@ def test_keepalive_conn_management(test_client):
http_connection.connect()
return http_connection
def request(conn):
def request(conn, keepalive=True):
status_line, actual_headers, actual_resp_body = test_client.get(
'/page3', headers=[('Connection', 'Keep-Alive')],
http_conn=conn, protocol='HTTP/1.0',
@ -410,7 +483,28 @@ def test_keepalive_conn_management(test_client):
assert actual_status == 200
assert status_line[4:] == 'OK'
assert actual_resp_body == pov.encode()
assert header_has_value('Connection', 'Keep-Alive', actual_headers)
if keepalive:
assert header_has_value('Connection', 'Keep-Alive', actual_headers)
assert header_has_value(
'Keep-Alive',
'timeout={test_client.server_instance.timeout}'.
format(**locals()),
actual_headers,
)
else:
assert not header_exists('Connection', actual_headers)
assert not header_exists('Keep-Alive', actual_headers)
def check_server_idle_conn_count(count, timeout=1.0):
deadline = time.time() + timeout
while True:
n = test_client.server_instance._connections._num_connections
if n == count:
return
assert time.time() <= deadline, (
'idle conn count mismatch, wanted {count}, got {n}'.
format(**locals()),
)
disconnect_errors = (
http_client.BadStatusLine,
@ -421,50 +515,175 @@ def test_keepalive_conn_management(test_client):
# Make a new connection.
c1 = connection()
request(c1)
check_server_idle_conn_count(1)
# Make a second one.
c2 = connection()
request(c2)
check_server_idle_conn_count(2)
# Reusing the first connection should still work.
request(c1)
check_server_idle_conn_count(2)
# Creating a new connection should still work.
# Creating a new connection should still work, but we should
# have run out of available connections to keep alive, so the
# server should tell us to close.
c3 = connection()
request(c3)
request(c3, keepalive=False)
check_server_idle_conn_count(2)
# Allow a tick.
time.sleep(0.2)
# That's three connections, we should expect the one used less recently
# to be expired.
# Show that the third connection was closed.
with pytest.raises(disconnect_errors):
request(c2)
# But the oldest created one should still be valid.
# (As well as the newest one).
request(c1)
request(c3)
request(c3)
check_server_idle_conn_count(2)
# Wait for some of our timeout.
time.sleep(1.0)
time.sleep(1.2)
# Refresh the third connection.
request(c3)
# Refresh the second connection.
request(c2)
check_server_idle_conn_count(2)
# Wait for the remainder of our timeout, plus one tick.
time.sleep(1.2)
check_server_idle_conn_count(1)
# First connection should now be expired.
with pytest.raises(disconnect_errors):
request(c1)
check_server_idle_conn_count(1)
# But the third one should still be valid.
request(c3)
# But the second one should still be valid.
request(c2)
check_server_idle_conn_count(1)
# Restore original timeout.
test_client.server_instance.timeout = timeout
@pytest.mark.parametrize(
('simulated_exception', 'error_number', 'exception_leaks'),
(
pytest.param(
socket.error, errno.ECONNRESET, False,
id='socket.error(ECONNRESET)',
),
pytest.param(
socket.error, errno.EPIPE, False,
id='socket.error(EPIPE)',
),
pytest.param(
socket.error, errno.ENOTCONN, False,
id='simulated socket.error(ENOTCONN)',
),
pytest.param(
None, # <-- don't raise an artificial exception
errno.ENOTCONN, False,
id='real socket.error(ENOTCONN)',
marks=pytest.mark.xfail(
IS_WINDOWS,
reason='Now reproducible this way on Windows',
),
),
pytest.param(
socket.error, errno.ESHUTDOWN, False,
id='socket.error(ESHUTDOWN)',
),
pytest.param(RuntimeError, 666, True, id='RuntimeError(666)'),
pytest.param(socket.error, -1, True, id='socket.error(-1)'),
) + (
() if six.PY2 else (
pytest.param(
ConnectionResetError, errno.ECONNRESET, False,
id='ConnectionResetError(ECONNRESET)',
),
pytest.param(
BrokenPipeError, errno.EPIPE, False,
id='BrokenPipeError(EPIPE)',
),
pytest.param(
BrokenPipeError, errno.ESHUTDOWN, False,
id='BrokenPipeError(ESHUTDOWN)',
),
)
),
)
def test_broken_connection_during_tcp_fin(
error_number, exception_leaks,
mocker, monkeypatch,
simulated_exception, test_client,
):
"""Test there's no traceback on broken connection during close.
It artificially causes :py:data:`~errno.ECONNRESET` /
:py:data:`~errno.EPIPE` / :py:data:`~errno.ESHUTDOWN` /
:py:data:`~errno.ENOTCONN` as well as unrelated :py:exc:`RuntimeError`
and :py:exc:`socket.error(-1) <socket.error>` on the server socket when
:py:meth:`socket.shutdown() <socket.socket.shutdown>` is called. It's
triggered by closing the client socket before the server had a chance
to respond.
The expectation is that only :py:exc:`RuntimeError` and a
:py:exc:`socket.error` with an unusual error code would leak.
With the :py:data:`None`-parameter, a real non-simulated
:py:exc:`OSError(107, 'Transport endpoint is not connected')
<OSError>` happens.
"""
exc_instance = (
None if simulated_exception is None
else simulated_exception(error_number, 'Simulated socket error')
)
old_close_kernel_socket = (
test_client.server_instance.
ConnectionClass._close_kernel_socket
)
def _close_kernel_socket(self):
monkeypatch.setattr( # `socket.shutdown` is read-only otherwise
self, 'socket',
mocker.mock_module.Mock(wraps=self.socket),
)
if exc_instance is not None:
monkeypatch.setattr(
self.socket, 'shutdown',
mocker.mock_module.Mock(side_effect=exc_instance),
)
_close_kernel_socket.fin_spy = mocker.spy(self.socket, 'shutdown')
_close_kernel_socket.exception_leaked = True
old_close_kernel_socket(self)
_close_kernel_socket.exception_leaked = False
monkeypatch.setattr(
test_client.server_instance.ConnectionClass,
'_close_kernel_socket',
_close_kernel_socket,
)
conn = test_client.get_connection()
conn.auto_open = False
conn.connect()
conn.send(b'GET /hello HTTP/1.1')
conn.send(('Host: %s' % conn.host).encode('ascii'))
conn.close()
for _ in range(10): # Let the server attempt TCP shutdown
time.sleep(0.1)
if hasattr(_close_kernel_socket, 'exception_leaked'):
break
if exc_instance is not None: # simulated by us
assert _close_kernel_socket.fin_spy.spy_exception is exc_instance
else: # real
assert isinstance(
_close_kernel_socket.fin_spy.spy_exception, socket.error,
)
assert _close_kernel_socket.fin_spy.spy_exception.errno == error_number
assert _close_kernel_socket.exception_leaked is exception_leaks
@pytest.mark.parametrize(
'timeout_before_headers',
(
@ -475,7 +694,7 @@ def test_keepalive_conn_management(test_client):
def test_HTTP11_Timeout(test_client, timeout_before_headers):
"""Check timeout without sending any data.
The server will close the conn with a 408.
The server will close the connection with a 408.
"""
conn = test_client.get_connection()
conn.auto_open = False
@ -594,7 +813,7 @@ def test_HTTP11_Timeout_after_request(test_client):
def test_HTTP11_pipelining(test_client):
"""Test HTTP/1.1 pipelining.
httplib doesn't support this directly.
:py:mod:`http.client` doesn't support this directly.
"""
conn = test_client.get_connection()
@ -639,7 +858,7 @@ def test_100_Continue(test_client):
conn = test_client.get_connection()
# Try a page without an Expect request header first.
# Note that httplib's response.begin automatically ignores
# Note that http.client's response.begin automatically ignores
# 100 Continue responses, so we must manually check for it.
conn.putrequest('POST', '/upload', skip_host=True)
conn.putheader('Host', conn.host)
@ -800,11 +1019,13 @@ def test_No_Message_Body(test_client):
@pytest.mark.xfail(
reason='Server does not correctly read trailers/ending of the previous '
'HTTP request, thus the second request fails as the server tries '
r"to parse b'Content-Type: application/json\r\n' as a "
'Request-Line. This results in HTTP status code 400, instead of 413'
'Ref: https://github.com/cherrypy/cheroot/issues/69',
reason=unwrap(
trim("""
Headers from earlier request leak into the request
line for a subsequent request, resulting in 400
instead of 413. See cherrypy/cheroot#69 for details.
"""),
),
)
def test_Chunked_Encoding(test_client):
"""Test HTTP uploads with chunked transfer-encoding."""
@ -837,7 +1058,7 @@ def test_Chunked_Encoding(test_client):
# Try a chunked request that exceeds server.max_request_body_size.
# Note that the delimiters and trailer are included.
body = b'3e3\r\n' + (b'x' * 995) + b'\r\n0\r\n\r\n'
body = b'\r\n'.join((b'3e3', b'x' * 995, b'0', b'', b''))
conn.putrequest('POST', '/upload', skip_host=True)
conn.putheader('Host', conn.host)
conn.putheader('Transfer-Encoding', 'chunked')
@ -895,7 +1116,7 @@ def test_Content_Length_not_int(test_client):
@pytest.mark.parametrize(
'uri,expected_resp_status,expected_resp_body',
('uri', 'expected_resp_status', 'expected_resp_body'),
(
(
'/wrong_cl_buffered', 500,
@ -929,6 +1150,16 @@ def test_Content_Length_out(
conn.close()
# the server logs the exception that we had verified from the
# client perspective. Tell the error_log verification that
# it can ignore that message.
test_client.server_instance.error_log.ignored_msgs.extend((
# Python 3.7+:
"ValueError('Response body exceeds the declared Content-Length.')",
# Python 2.7-3.6 (macOS?):
"ValueError('Response body exceeds the declared Content-Length.',)",
))
@pytest.mark.xfail(
reason='Sometimes this test fails due to low timeout. '
@ -970,11 +1201,94 @@ def test_No_CRLF(test_client, invalid_terminator):
# Initialize a persistent HTTP connection
conn = test_client.get_connection()
# (b'%s' % b'') is not supported in Python 3.4, so just use +
conn.send(b'GET /hello HTTP/1.1' + invalid_terminator)
# (b'%s' % b'') is not supported in Python 3.4, so just use bytes.join()
conn.send(b''.join((b'GET /hello HTTP/1.1', invalid_terminator)))
response = conn.response_class(conn.sock, method='GET')
response.begin()
actual_resp_body = response.read()
expected_resp_body = b'HTTP requires CRLF terminators'
assert actual_resp_body == expected_resp_body
conn.close()
class FaultySelect:
"""Mock class to insert errors in the selector.select method."""
def __init__(self, original_select):
"""Initilize helper class to wrap the selector.select method."""
self.original_select = original_select
self.request_served = False
self.os_error_triggered = False
def __call__(self, timeout):
"""Intercept the calls to selector.select."""
if self.request_served:
self.os_error_triggered = True
raise OSError('Error while selecting the client socket.')
return self.original_select(timeout)
class FaultyGetMap:
"""Mock class to insert errors in the selector.get_map method."""
def __init__(self, original_get_map):
"""Initilize helper class to wrap the selector.get_map method."""
self.original_get_map = original_get_map
self.sabotage_conn = False
self.conn_closed = False
def __call__(self):
"""Intercept the calls to selector.get_map."""
sabotage_targets = (
conn for _, (_, _, _, conn) in self.original_get_map().items()
if isinstance(conn, cheroot.server.HTTPConnection)
) if self.sabotage_conn and not self.conn_closed else ()
for conn in sabotage_targets:
# close the socket to cause OSError
conn.close()
self.conn_closed = True
return self.original_get_map()
def test_invalid_selected_connection(test_client, monkeypatch):
"""Test the error handling segment of HTTP connection selection.
See :py:meth:`cheroot.connections.ConnectionManager.get_conn`.
"""
# patch the select method
faux_select = FaultySelect(
test_client.server_instance._connections._selector.select,
)
monkeypatch.setattr(
test_client.server_instance._connections._selector,
'select',
faux_select,
)
# patch the get_map method
faux_get_map = FaultyGetMap(
test_client.server_instance._connections._selector._selector.get_map,
)
monkeypatch.setattr(
test_client.server_instance._connections._selector._selector,
'get_map',
faux_get_map,
)
# request a page with connection keep-alive to make sure
# we'll have a connection to be modified.
resp_status, resp_headers, resp_body = test_client.request(
'/page1', headers=[('Connection', 'Keep-Alive')],
)
assert resp_status == '200 OK'
# trigger the internal errors
faux_get_map.sabotage_conn = faux_select.request_served = True
# give time to make sure the error gets handled
time.sleep(test_client.server_instance.expiration_interval * 2)
assert faux_select.os_error_triggered
assert faux_get_map.conn_closed

View file

@ -18,6 +18,7 @@ from cheroot.test import helper
HTTP_BAD_REQUEST = 400
HTTP_LENGTH_REQUIRED = 411
HTTP_NOT_FOUND = 404
HTTP_REQUEST_ENTITY_TOO_LARGE = 413
HTTP_OK = 200
HTTP_VERSION_NOT_SUPPORTED = 505
@ -78,7 +79,7 @@ def _get_http_response(connection, method='GET'):
@pytest.fixture
def testing_server(wsgi_server_client):
"""Attach a WSGI app to the given server and pre-configure it."""
"""Attach a WSGI app to the given server and preconfigure it."""
wsgi_server = wsgi_server_client.server_instance
wsgi_server.wsgi_app = HelloController()
wsgi_server.max_request_body_size = 30000000
@ -92,6 +93,21 @@ def test_client(testing_server):
return testing_server.server_client
@pytest.fixture
def testing_server_with_defaults(wsgi_server_client):
"""Attach a WSGI app to the given server and preconfigure it."""
wsgi_server = wsgi_server_client.server_instance
wsgi_server.wsgi_app = HelloController()
wsgi_server.server_client = wsgi_server_client
return wsgi_server
@pytest.fixture
def test_client_with_defaults(testing_server_with_defaults):
"""Get and return a test client out of the given server."""
return testing_server_with_defaults.server_client
def test_http_connect_request(test_client):
"""Check that CONNECT query results in Method Not Allowed status."""
status_line = test_client.connect('/anything')[0]
@ -262,8 +278,30 @@ def test_content_length_required(test_client):
assert actual_status == HTTP_LENGTH_REQUIRED
@pytest.mark.xfail(
reason='https://github.com/cherrypy/cheroot/issues/106',
strict=False, # sometimes it passes
)
def test_large_request(test_client_with_defaults):
"""Test GET query with maliciously large Content-Length."""
# If the server's max_request_body_size is not set (i.e. is set to 0)
# then this will result in an `OverflowError: Python int too large to
# convert to C ssize_t` in the server.
# We expect that this should instead return that the request is too
# large.
c = test_client_with_defaults.get_connection()
c.putrequest('GET', '/hello')
c.putheader('Content-Length', str(2**64))
c.endheaders()
response = c.getresponse()
actual_status = response.status
assert actual_status == HTTP_REQUEST_ENTITY_TOO_LARGE
@pytest.mark.parametrize(
'request_line,status_code,expected_body',
('request_line', 'status_code', 'expected_body'),
(
(
b'GET /', # missing proto
@ -401,7 +439,7 @@ class CloseResponse:
@pytest.fixture
def testing_server_close(wsgi_server_client):
"""Attach a WSGI app to the given server and pre-configure it."""
"""Attach a WSGI app to the given server and preconfigure it."""
wsgi_server = wsgi_server_client.server_instance
wsgi_server.wsgi_app = CloseController()
wsgi_server.max_request_body_size = 30000000

View file

@ -8,7 +8,7 @@ from cheroot.wsgi import PathInfoDispatcher
def wsgi_invoke(app, environ):
"""Serve 1 requeset from a WSGI application."""
"""Serve 1 request from a WSGI application."""
response = {}
def start_response(status, headers):
@ -25,7 +25,7 @@ def wsgi_invoke(app, environ):
def test_dispatch_no_script_name():
"""Despatch despite lack of SCRIPT_NAME in environ."""
"""Dispatch despite lack of ``SCRIPT_NAME`` in environ."""
# Bare bones WSGI hello world app (from PEP 333).
def app(environ, start_response):
start_response(

View file

@ -8,7 +8,7 @@ from .._compat import IS_LINUX, IS_MACOS, IS_WINDOWS
@pytest.mark.parametrize(
'err_names,err_nums',
('err_names', 'err_nums'),
(
(('', 'some-nonsense-name'), []),
(
@ -24,7 +24,7 @@ from .._compat import IS_LINUX, IS_MACOS, IS_WINDOWS
),
)
def test_plat_specific_errors(err_names, err_nums):
"""Test that plat_specific_errors retrieves correct err num list."""
"""Test that ``plat_specific_errors`` gets correct error numbers list."""
actual_err_nums = errors.plat_specific_errors(*err_names)
assert len(actual_err_nums) == len(err_nums)
assert sorted(actual_err_nums) == sorted(err_nums)

View file

@ -1,4 +1,4 @@
"""self-explanatory."""
"""Tests for :py:mod:`cheroot.makefile`."""
from cheroot import makefile
@ -7,14 +7,14 @@ __metaclass__ = type
class MockSocket:
"""Mocks a socket."""
"""A mock socket."""
def __init__(self):
"""Initialize."""
"""Initialize :py:class:`MockSocket`."""
self.messages = []
def recv_into(self, buf):
"""Simulate recv_into for Python 3."""
"""Simulate ``recv_into`` for Python 3."""
if not self.messages:
return 0
msg = self.messages.pop(0)
@ -23,7 +23,7 @@ class MockSocket:
return len(msg)
def recv(self, size):
"""Simulate recv for Python 2."""
"""Simulate ``recv`` for Python 2."""
try:
return self.messages.pop(0)
except IndexError:
@ -44,7 +44,7 @@ def test_bytes_read():
def test_bytes_written():
"""Writer should capture bytes writtten."""
"""Writer should capture bytes written."""
sock = MockSocket()
sock.messages.append(b'foo')
wfile = makefile.MakeFile(sock, 'w')

View file

@ -5,10 +5,12 @@
from __future__ import absolute_import, division, print_function
__metaclass__ = type
from contextlib import closing
import os
import socket
import tempfile
import threading
import time
import uuid
import pytest
@ -16,14 +18,15 @@ import requests
import requests_unixsocket
import six
from six.moves import queue, urllib
from .._compat import bton, ntob
from .._compat import IS_LINUX, IS_MACOS, SYS_PLATFORM
from .._compat import IS_LINUX, IS_MACOS, IS_WINDOWS, SYS_PLATFORM
from ..server import IS_UID_GID_RESOLVABLE, Gateway, HTTPServer
from ..testing import (
ANY_INTERFACE_IPV4,
ANY_INTERFACE_IPV6,
EPHEMERAL_PORT,
get_server_client,
)
@ -42,15 +45,8 @@ non_macos_sock_test = pytest.mark.skipif(
@pytest.fixture(params=('abstract', 'file'))
def unix_sock_file(request):
"""Check that bound UNIX socket address is stored in server."""
if request.param == 'abstract':
yield request.getfixturevalue('unix_abstract_sock')
return
tmp_sock_fh, tmp_sock_fname = tempfile.mkstemp()
yield tmp_sock_fname
os.close(tmp_sock_fh)
os.unlink(tmp_sock_fname)
name = 'unix_{request.param}_sock'.format(**locals())
return request.getfixturevalue(name)
@pytest.fixture
@ -67,6 +63,17 @@ def unix_abstract_sock():
)).decode()
@pytest.fixture
def unix_file_sock():
"""Yield a unix file socket."""
tmp_sock_fh, tmp_sock_fname = tempfile.mkstemp()
yield tmp_sock_fname
os.close(tmp_sock_fh)
os.unlink(tmp_sock_fname)
def test_prepare_makes_server_ready():
"""Check that prepare() makes the server ready, and stop() clears it."""
httpserver = HTTPServer(
@ -110,6 +117,77 @@ def test_stop_interrupts_serve():
assert not serve_thread.is_alive()
@pytest.mark.parametrize(
'exc_cls',
(
IOError,
KeyboardInterrupt,
OSError,
RuntimeError,
),
)
def test_server_interrupt(exc_cls):
"""Check that assigning interrupt stops the server."""
interrupt_msg = 'should catch {uuid!s}'.format(uuid=uuid.uuid4())
raise_marker_sentinel = object()
httpserver = HTTPServer(
bind_addr=(ANY_INTERFACE_IPV4, EPHEMERAL_PORT),
gateway=Gateway,
)
result_q = queue.Queue()
def serve_thread():
# ensure we catch the exception on the serve() thread
try:
httpserver.serve()
except exc_cls as e:
if str(e) == interrupt_msg:
result_q.put(raise_marker_sentinel)
httpserver.prepare()
serve_thread = threading.Thread(target=serve_thread)
serve_thread.start()
serve_thread.join(0.5)
assert serve_thread.is_alive()
# this exception is raised on the serve() thread,
# not in the calling context.
httpserver.interrupt = exc_cls(interrupt_msg)
serve_thread.join(0.5)
assert not serve_thread.is_alive()
assert result_q.get_nowait() is raise_marker_sentinel
def test_serving_is_false_and_stop_returns_after_ctrlc():
"""Check that stop() interrupts running of serve()."""
httpserver = HTTPServer(
bind_addr=(ANY_INTERFACE_IPV4, EPHEMERAL_PORT),
gateway=Gateway,
)
httpserver.prepare()
# Simulate a Ctrl-C on the first call to `run`.
def raise_keyboard_interrupt(*args, **kwargs):
raise KeyboardInterrupt()
httpserver._connections._selector.select = raise_keyboard_interrupt
serve_thread = threading.Thread(target=httpserver.serve)
serve_thread.start()
# The thread should exit right away due to the interrupt.
serve_thread.join(httpserver.expiration_interval * 2)
assert not serve_thread.is_alive()
assert not httpserver._connections._serving
httpserver.stop()
@pytest.mark.parametrize(
'ip_addr',
(
@ -135,7 +213,7 @@ def test_bind_addr_unix(http_server, unix_sock_file):
@unix_only_sock_test
def test_bind_addr_unix_abstract(http_server, unix_abstract_sock):
"""Check that bound UNIX abstract sockaddr is stored in server."""
"""Check that bound UNIX abstract socket address is stored in server."""
httpserver = http_server.send(unix_abstract_sock)
assert httpserver.bind_addr == unix_abstract_sock
@ -167,27 +245,26 @@ class _TestGateway(Gateway):
@pytest.fixture
def peercreds_enabled_server_and_client(http_server, unix_sock_file):
"""Construct a test server with `peercreds_enabled`."""
def peercreds_enabled_server(http_server, unix_sock_file):
"""Construct a test server with ``peercreds_enabled``."""
httpserver = http_server.send(unix_sock_file)
httpserver.gateway = _TestGateway
httpserver.peercreds_enabled = True
return httpserver, get_server_client(httpserver)
return httpserver
@unix_only_sock_test
@non_macos_sock_test
def test_peercreds_unix_sock(peercreds_enabled_server_and_client):
"""Check that peercred lookup works when enabled."""
httpserver, testclient = peercreds_enabled_server_and_client
def test_peercreds_unix_sock(peercreds_enabled_server):
"""Check that ``PEERCRED`` lookup works when enabled."""
httpserver = peercreds_enabled_server
bind_addr = httpserver.bind_addr
if isinstance(bind_addr, six.binary_type):
bind_addr = bind_addr.decode()
unix_base_uri = 'http+unix://{}'.format(
bind_addr.replace('\0', '%00').replace('/', '%2F'),
)
quoted = urllib.parse.quote(bind_addr, safe='')
unix_base_uri = 'http+unix://{quoted}'.format(**locals())
expected_peercreds = os.getpid(), os.getuid(), os.getgid()
expected_peercreds = '|'.join(map(str, expected_peercreds))
@ -208,9 +285,9 @@ def test_peercreds_unix_sock(peercreds_enabled_server_and_client):
)
@unix_only_sock_test
@non_macos_sock_test
def test_peercreds_unix_sock_with_lookup(peercreds_enabled_server_and_client):
"""Check that peercred resolution works when enabled."""
httpserver, testclient = peercreds_enabled_server_and_client
def test_peercreds_unix_sock_with_lookup(peercreds_enabled_server):
"""Check that ``PEERCRED`` resolution works when enabled."""
httpserver = peercreds_enabled_server
httpserver.peercreds_resolve_enabled = True
bind_addr = httpserver.bind_addr
@ -218,9 +295,8 @@ def test_peercreds_unix_sock_with_lookup(peercreds_enabled_server_and_client):
if isinstance(bind_addr, six.binary_type):
bind_addr = bind_addr.decode()
unix_base_uri = 'http+unix://{}'.format(
bind_addr.replace('\0', '%00').replace('/', '%2F'),
)
quoted = urllib.parse.quote(bind_addr, safe='')
unix_base_uri = 'http+unix://{quoted}'.format(**locals())
import grp
import pwd
@ -233,3 +309,111 @@ def test_peercreds_unix_sock_with_lookup(peercreds_enabled_server_and_client):
peercreds_text_resp = requests.get(unix_base_uri + PEERCRED_TEXTS_URI)
peercreds_text_resp.raise_for_status()
assert peercreds_text_resp.text == expected_textcreds
@pytest.mark.skipif(
IS_WINDOWS,
reason='This regression test is for a Linux bug, '
'and the resource module is not available on Windows',
)
@pytest.mark.parametrize(
'resource_limit',
(
1024,
2048,
),
indirect=('resource_limit',),
)
@pytest.mark.usefixtures('many_open_sockets')
def test_high_number_of_file_descriptors(resource_limit):
"""Test the server does not crash with a high file-descriptor value.
This test shouldn't cause a server crash when trying to access
file-descriptor higher than 1024.
The earlier implementation used to rely on ``select()`` syscall that
doesn't support file descriptors with numbers higher than 1024.
"""
# We want to force the server to use a file-descriptor with
# a number above resource_limit
# Create our server
httpserver = HTTPServer(
bind_addr=(ANY_INTERFACE_IPV4, EPHEMERAL_PORT), gateway=Gateway,
)
try:
# This will trigger a crash if select() is used in the implementation
with httpserver._run_in_thread():
# allow server to run long enough to invoke select()
time.sleep(1.0)
except: # noqa: E722
raise # only needed for `else` to work
else:
# We use closing here for py2-compat
with closing(socket.socket()) as sock:
# Check new sockets created are still above our target number
assert sock.fileno() >= resource_limit
finally:
# Stop our server
httpserver.stop()
if not IS_WINDOWS:
test_high_number_of_file_descriptors = pytest.mark.forked(
test_high_number_of_file_descriptors,
)
@pytest.fixture
def resource_limit(request):
"""Set the resource limit two times bigger then requested."""
resource = pytest.importorskip(
'resource',
reason='The "resource" module is Unix-specific',
)
# Get current resource limits to restore them later
soft_limit, hard_limit = resource.getrlimit(resource.RLIMIT_NOFILE)
# We have to increase the nofile limit above 1024
# Otherwise we see a 'Too many files open' error, instead of
# an error due to the file descriptor number being too high
resource.setrlimit(
resource.RLIMIT_NOFILE,
(request.param * 2, hard_limit),
)
try: # noqa: WPS501
yield request.param
finally:
# Reset the resource limit back to the original soft limit
resource.setrlimit(resource.RLIMIT_NOFILE, (soft_limit, hard_limit))
@pytest.fixture
def many_open_sockets(resource_limit):
"""Allocate a lot of file descriptors by opening dummy sockets."""
# Hoard a lot of file descriptors by opening and storing a lot of sockets
test_sockets = []
# Open a lot of file descriptors, so the next one the server
# opens is a high number
try:
for i in range(resource_limit):
sock = socket.socket()
test_sockets.append(sock)
# NOTE: We used to interrupt the loop early but this doesn't seem
# NOTE: to work well in envs with indeterministic runtimes like
# NOTE: PyPy. It looks like sometimes it frees some file
# NOTE: descriptors in between running this fixture and the actual
# NOTE: test code so the early break has been removed to try
# NOTE: address that. The approach may need to be rethought if the
# NOTE: issue reoccurs. Another approach may be disabling the GC.
# Check we opened enough descriptors to reach a high number
the_highest_fileno = max(sock.fileno() for sock in test_sockets)
assert the_highest_fileno >= resource_limit
yield the_highest_fileno
finally:
# Close our open resources
for test_socket in test_sockets:
test_socket.close()

View file

@ -1,4 +1,4 @@
"""Tests for TLS/SSL support."""
"""Tests for TLS support."""
# -*- coding: utf-8 -*-
# vim: set fileencoding=utf-8 :
@ -6,11 +6,14 @@ from __future__ import absolute_import, division, print_function
__metaclass__ = type
import functools
import json
import os
import ssl
import subprocess
import sys
import threading
import time
import traceback
import OpenSSL.SSL
import pytest
@ -21,7 +24,7 @@ import trustme
from .._compat import bton, ntob, ntou
from .._compat import IS_ABOVE_OPENSSL10, IS_PYPY
from .._compat import IS_LINUX, IS_MACOS, IS_WINDOWS
from ..server import Gateway, HTTPServer, get_ssl_adapter_class
from ..server import HTTPServer, get_ssl_adapter_class
from ..testing import (
ANY_INTERFACE_IPV4,
ANY_INTERFACE_IPV6,
@ -30,9 +33,17 @@ from ..testing import (
_get_conn_data,
_probe_ipv6_sock,
)
from ..wsgi import Gateway_10
IS_GITHUB_ACTIONS_WORKFLOW = bool(os.getenv('GITHUB_WORKFLOW'))
IS_WIN2016 = (
IS_WINDOWS
# pylint: disable=unsupported-membership-test
and b'Microsoft Windows Server 2016 Datacenter' in subprocess.check_output(
('systeminfo',),
)
)
IS_LIBRESSL_BACKEND = ssl.OPENSSL_VERSION.startswith('LibreSSL')
IS_PYOPENSSL_SSL_VERSION_1_0 = (
OpenSSL.SSL.SSLeay_version(OpenSSL.SSL.SSLEAY_VERSION).
@ -40,6 +51,7 @@ IS_PYOPENSSL_SSL_VERSION_1_0 = (
)
PY27 = sys.version_info[:2] == (2, 7)
PY34 = sys.version_info[:2] == (3, 4)
PY3 = not six.PY2
_stdlib_to_openssl_verify = {
@ -71,7 +83,7 @@ missing_ipv6 = pytest.mark.skipif(
)
class HelloWorldGateway(Gateway):
class HelloWorldGateway(Gateway_10):
"""Gateway responding with Hello World to root URI."""
def respond(self):
@ -83,11 +95,21 @@ class HelloWorldGateway(Gateway):
req.ensure_headers_sent()
req.write(b'Hello world!')
return
if req_uri == '/env':
req.status = b'200 OK'
req.ensure_headers_sent()
env = self.get_environ()
# drop files so that it can be json dumped
env.pop('wsgi.errors')
env.pop('wsgi.input')
print(env)
req.write(json.dumps(env).encode('utf-8'))
return
return super(HelloWorldGateway, self).respond()
def make_tls_http_server(bind_addr, ssl_adapter, request):
"""Create and start an HTTP server bound to bind_addr."""
"""Create and start an HTTP server bound to ``bind_addr``."""
httpserver = HTTPServer(
bind_addr=bind_addr,
gateway=HelloWorldGateway,
@ -128,7 +150,7 @@ def tls_ca_certificate_pem_path(ca):
def tls_certificate(ca):
"""Provide a leaf certificate via fixture."""
interface, host, port = _get_conn_data(ANY_INTERFACE_IPV4)
return ca.issue_server_cert(ntou(interface), )
return ca.issue_server_cert(ntou(interface))
@pytest.fixture
@ -145,6 +167,43 @@ def tls_certificate_private_key_pem_path(tls_certificate):
yield cert_key_pem
def _thread_except_hook(exceptions, args):
"""Append uncaught exception ``args`` in threads to ``exceptions``."""
if issubclass(args.exc_type, SystemExit):
return
# cannot store the exception, it references the thread's stack
exceptions.append((
args.exc_type,
str(args.exc_value),
''.join(
traceback.format_exception(
args.exc_type, args.exc_value, args.exc_traceback,
),
),
))
@pytest.fixture
def thread_exceptions():
"""Provide a list of uncaught exceptions from threads via a fixture.
Only catches exceptions on Python 3.8+.
The list contains: ``(type, str(value), str(traceback))``
"""
exceptions = []
# Python 3.8+
orig_hook = getattr(threading, 'excepthook', None)
if orig_hook is not None:
threading.excepthook = functools.partial(
_thread_except_hook, exceptions,
)
try:
yield exceptions
finally:
if orig_hook is not None:
threading.excepthook = orig_hook
@pytest.mark.parametrize(
'adapter_type',
(
@ -180,7 +239,7 @@ def test_ssl_adapters(
)
resp = requests.get(
'https://' + interface + ':' + str(port) + '/',
'https://{host!s}:{port!s}/'.format(host=interface, port=port),
verify=tls_ca_certificate_pem_path,
)
@ -188,7 +247,7 @@ def test_ssl_adapters(
assert resp.text == 'Hello world!'
@pytest.mark.parametrize(
@pytest.mark.parametrize( # noqa: C901 # FIXME
'adapter_type',
(
'builtin',
@ -196,7 +255,7 @@ def test_ssl_adapters(
),
)
@pytest.mark.parametrize(
'is_trusted_cert,tls_client_identity',
('is_trusted_cert', 'tls_client_identity'),
(
(True, 'localhost'), (True, '127.0.0.1'),
(True, '*.localhost'), (True, 'not_localhost'),
@ -211,7 +270,7 @@ def test_ssl_adapters(
ssl.CERT_REQUIRED, # server should validate if client cert CA is OK
),
)
def test_tls_client_auth(
def test_tls_client_auth( # noqa: C901 # FIXME
# FIXME: remove twisted logic, separate tests
mocker,
tls_http_server, adapter_type,
@ -265,7 +324,7 @@ def test_tls_client_auth(
make_https_request = functools.partial(
requests.get,
'https://' + interface + ':' + str(port) + '/',
'https://{host!s}:{port!s}/'.format(host=interface, port=port),
# Server TLS certificate verification:
verify=tls_ca_certificate_pem_path,
@ -324,36 +383,200 @@ def test_tls_client_auth(
except AttributeError:
if PY34:
pytest.xfail('OpenSSL behaves wierdly under Python 3.4')
elif not six.PY2 and IS_WINDOWS:
elif IS_WINDOWS or IS_GITHUB_ACTIONS_WORKFLOW:
err_text = str(ssl_err.value)
else:
raise
if isinstance(err_text, int):
err_text = str(ssl_err.value)
expected_substrings = (
'sslv3 alert bad certificate' if IS_LIBRESSL_BACKEND
else 'tlsv1 alert unknown ca',
)
if not six.PY2:
if IS_MACOS and IS_PYPY and adapter_type == 'pyopenssl':
expected_substrings = ('tlsv1 alert unknown ca', )
if (
IS_WINDOWS
and tls_verify_mode in (
ssl.CERT_REQUIRED,
ssl.CERT_OPTIONAL,
)
and not is_trusted_cert
and tls_client_identity == 'localhost'
):
expected_substrings += (
'bad handshake: '
"SysCallError(10054, 'WSAECONNRESET')",
"('Connection aborted.', "
'OSError("(10054, \'WSAECONNRESET\')"))',
expected_substrings = ('tlsv1 alert unknown ca',)
if (
tls_verify_mode in (
ssl.CERT_REQUIRED,
ssl.CERT_OPTIONAL,
)
and not is_trusted_cert
and tls_client_identity == 'localhost'
):
expected_substrings += (
'bad handshake: '
"SysCallError(10054, 'WSAECONNRESET')",
"('Connection aborted.', "
'OSError("(10054, \'WSAECONNRESET\')"))',
"('Connection aborted.', "
'OSError("(10054, \'WSAECONNRESET\')",))',
"('Connection aborted.', "
'error("(10054, \'WSAECONNRESET\')",))',
"('Connection aborted.', "
'ConnectionResetError(10054, '
"'An existing connection was forcibly closed "
"by the remote host', None, 10054, None))",
) if IS_WINDOWS else (
"('Connection aborted.', "
'OSError("(104, \'ECONNRESET\')"))',
"('Connection aborted.', "
'OSError("(104, \'ECONNRESET\')",))',
"('Connection aborted.', "
'error("(104, \'ECONNRESET\')",))',
"('Connection aborted.', "
"ConnectionResetError(104, 'Connection reset by peer'))",
"('Connection aborted.', "
"error(104, 'Connection reset by peer'))",
) if (
IS_GITHUB_ACTIONS_WORKFLOW
and IS_LINUX
) else (
"('Connection aborted.', "
"BrokenPipeError(32, 'Broken pipe'))",
)
assert any(e in err_text for e in expected_substrings)
@pytest.mark.parametrize( # noqa: C901 # FIXME
'adapter_type',
(
'builtin',
'pyopenssl',
),
)
@pytest.mark.parametrize(
('tls_verify_mode', 'use_client_cert'),
(
(ssl.CERT_NONE, False),
(ssl.CERT_NONE, True),
(ssl.CERT_OPTIONAL, False),
(ssl.CERT_OPTIONAL, True),
(ssl.CERT_REQUIRED, True),
),
)
def test_ssl_env( # noqa: C901 # FIXME
thread_exceptions,
recwarn,
mocker,
tls_http_server, adapter_type,
ca, tls_verify_mode, tls_certificate,
tls_certificate_chain_pem_path,
tls_certificate_private_key_pem_path,
tls_ca_certificate_pem_path,
use_client_cert,
):
"""Test the SSL environment generated by the SSL adapters."""
interface, _host, port = _get_conn_data(ANY_INTERFACE_IPV4)
with mocker.mock_module.patch(
'idna.core.ulabel',
return_value=ntob('127.0.0.1'),
):
client_cert = ca.issue_cert(ntou('127.0.0.1'))
with client_cert.private_key_and_cert_chain_pem.tempfile() as cl_pem:
tls_adapter_cls = get_ssl_adapter_class(name=adapter_type)
tls_adapter = tls_adapter_cls(
tls_certificate_chain_pem_path,
tls_certificate_private_key_pem_path,
)
if adapter_type == 'pyopenssl':
tls_adapter.context = tls_adapter.get_context()
tls_adapter.context.set_verify(
_stdlib_to_openssl_verify[tls_verify_mode],
lambda conn, cert, errno, depth, preverify_ok: preverify_ok,
)
else:
tls_adapter.context.verify_mode = tls_verify_mode
ca.configure_trust(tls_adapter.context)
tls_certificate.configure_cert(tls_adapter.context)
tlswsgiserver = tls_http_server((interface, port), tls_adapter)
interface, _host, port = _get_conn_data(tlswsgiserver.bind_addr)
resp = requests.get(
'https://' + interface + ':' + str(port) + '/env',
verify=tls_ca_certificate_pem_path,
cert=cl_pem if use_client_cert else None,
)
if PY34 and resp.status_code != 200:
pytest.xfail(
'Python 3.4 has problems with verifying client certs',
)
env = json.loads(resp.content.decode('utf-8'))
# hard coded env
assert env['wsgi.url_scheme'] == 'https'
assert env['HTTPS'] == 'on'
# ensure these are present
for key in {'SSL_VERSION_INTERFACE', 'SSL_VERSION_LIBRARY'}:
assert key in env
# pyOpenSSL generates the env before the handshake completes
if adapter_type == 'pyopenssl':
return
for key in {'SSL_PROTOCOL', 'SSL_CIPHER'}:
assert key in env
# client certificate env
if tls_verify_mode == ssl.CERT_NONE or not use_client_cert:
assert env['SSL_CLIENT_VERIFY'] == 'NONE'
else:
assert env['SSL_CLIENT_VERIFY'] == 'SUCCESS'
with open(cl_pem, 'rt') as f:
assert env['SSL_CLIENT_CERT'] in f.read()
for key in {
'SSL_CLIENT_M_VERSION', 'SSL_CLIENT_M_SERIAL',
'SSL_CLIENT_I_DN', 'SSL_CLIENT_S_DN',
}:
assert key in env
# builtin ssl environment generation may use a loopback socket
# ensure no ResourceWarning was raised during the test
# NOTE: python 2.7 does not emit ResourceWarning for ssl sockets
if IS_PYPY:
# NOTE: PyPy doesn't have ResourceWarning
# Ref: https://doc.pypy.org/en/latest/cpython_differences.html
return
for warn in recwarn:
if not issubclass(warn.category, ResourceWarning):
continue
# the tests can sporadically generate resource warnings
# due to timing issues
# all of these sporadic warnings appear to be about socket.socket
# and have been observed to come from requests connection pool
msg = str(warn.message)
if 'socket.socket' in msg:
pytest.xfail(
'\n'.join((
'Sometimes this test fails due to '
'a socket.socket ResourceWarning:',
msg,
)),
)
pytest.fail(msg)
# to perform the ssl handshake over that loopback socket,
# the builtin ssl environment generation uses a thread
for _, _, trace in thread_exceptions:
print(trace, file=sys.stderr)
assert not thread_exceptions, ': '.join((
thread_exceptions[0][0].__name__,
thread_exceptions[0][1],
))
@pytest.mark.parametrize(
'ip_addr',
(
@ -382,7 +605,16 @@ def test_https_over_http_error(http_server, ip_addr):
@pytest.mark.parametrize(
'adapter_type',
(
'builtin',
pytest.param(
'builtin',
marks=pytest.mark.xfail(
IS_WINDOWS and six.PY2,
raises=requests.exceptions.ConnectionError,
reason='Stdlib `ssl` module behaves weirdly '
'on Windows under Python 2',
strict=False,
),
),
'pyopenssl',
),
)
@ -428,16 +660,41 @@ def test_http_over_https_error(
fqdn = interface
if ip_addr is ANY_INTERFACE_IPV6:
fqdn = '[{}]'.format(fqdn)
fqdn = '[{fqdn}]'.format(**locals())
expect_fallback_response_over_plain_http = (
(adapter_type == 'pyopenssl'
and (IS_ABOVE_OPENSSL10 or not six.PY2))
(
adapter_type == 'pyopenssl'
and (IS_ABOVE_OPENSSL10 or not six.PY2)
)
or PY27
) or (
IS_GITHUB_ACTIONS_WORKFLOW
and IS_WINDOWS
and six.PY2
and not IS_WIN2016
)
if (
IS_GITHUB_ACTIONS_WORKFLOW
and IS_WINDOWS
and six.PY2
and IS_WIN2016
and adapter_type == 'builtin'
and ip_addr is ANY_INTERFACE_IPV6
):
expect_fallback_response_over_plain_http = True
if (
IS_GITHUB_ACTIONS_WORKFLOW
and IS_WINDOWS
and six.PY2
and not IS_WIN2016
and adapter_type == 'builtin'
and ip_addr is not ANY_INTERFACE_IPV6
):
expect_fallback_response_over_plain_http = False
if expect_fallback_response_over_plain_http:
resp = requests.get(
'http://' + fqdn + ':' + str(port) + '/',
'http://{host!s}:{port!s}/'.format(host=fqdn, port=port),
)
assert resp.status_code == 400
assert resp.text == (
@ -448,7 +705,7 @@ def test_http_over_https_error(
with pytest.raises(requests.exceptions.ConnectionError) as ssl_err:
requests.get( # FIXME: make stdlib ssl behave like PyOpenSSL
'http://' + fqdn + ':' + str(port) + '/',
'http://{host!s}:{port!s}/'.format(host=fqdn, port=port),
)
if IS_LINUX:
@ -468,7 +725,7 @@ def test_http_over_https_error(
underlying_error = ssl_err.value.args[0].args[-1]
err_text = str(underlying_error)
assert underlying_error.errno == expected_error_code, (
'The underlying error is {!r}'.
format(underlying_error)
'The underlying error is {underlying_error!r}'.
format(**locals())
)
assert expected_error_text in err_text

View file

@ -0,0 +1,58 @@
"""Test wsgi."""
from concurrent.futures.thread import ThreadPoolExecutor
import pytest
import portend
import requests
from requests_toolbelt.sessions import BaseUrlSession as Session
from jaraco.context import ExceptionTrap
from cheroot import wsgi
from cheroot._compat import IS_MACOS, IS_WINDOWS
IS_SLOW_ENV = IS_MACOS or IS_WINDOWS
@pytest.fixture
def simple_wsgi_server():
"""Fucking simple wsgi server fixture (duh)."""
port = portend.find_available_local_port()
def app(environ, start_response):
status = '200 OK'
response_headers = [('Content-type', 'text/plain')]
start_response(status, response_headers)
return [b'Hello world!']
host = '::'
addr = host, port
server = wsgi.Server(addr, app, timeout=600 if IS_SLOW_ENV else 20)
url = 'http://localhost:{port}/'.format(**locals())
with server._run_in_thread() as thread:
yield locals()
def test_connection_keepalive(simple_wsgi_server):
"""Test the connection keepalive works (duh)."""
session = Session(base_url=simple_wsgi_server['url'])
pooled = requests.adapters.HTTPAdapter(
pool_connections=1, pool_maxsize=1000,
)
session.mount('http://', pooled)
def do_request():
with ExceptionTrap(requests.exceptions.ConnectionError) as trap:
resp = session.get('info')
resp.raise_for_status()
return bool(trap)
with ThreadPoolExecutor(max_workers=10 if IS_SLOW_ENV else 50) as pool:
tasks = [
pool.submit(do_request)
for n in range(250 if IS_SLOW_ENV else 1000)
]
failures = sum(task.result() for task in tasks)
assert not failures

View file

@ -1,12 +1,14 @@
"""Extensions to unittest for web frameworks.
Use the WebCase.getPage method to request a page from your HTTP server.
Use the :py:meth:`WebCase.getPage` method to request a page
from your HTTP server.
Framework Integration
=====================
If you have control over your server process, you can handle errors
in the server-side of the HTTP conversation a bit better. You must run
both the client (your WebCase tests) and the server in the same process
(but in separate threads, obviously).
both the client (your :py:class:`WebCase` tests) and the server in the
same process (but in separate threads, obviously).
When an error occurs in the framework, call server_error. It will print
the traceback to stdout, and keep any assertions you have from running
(the assumption is that, if the server errors, the page output will not
@ -122,7 +124,7 @@ class WebCase(unittest.TestCase):
def _Conn(self):
"""Return HTTPConnection or HTTPSConnection based on self.scheme.
* from http.client.
* from :py:mod:`python:http.client`.
"""
cls_name = '{scheme}Connection'.format(scheme=self.scheme.upper())
return getattr(http_client, cls_name)
@ -157,7 +159,7 @@ class WebCase(unittest.TestCase):
@property
def persistent(self):
"""Presense of the persistent HTTP connection."""
"""Presence of the persistent HTTP connection."""
return hasattr(self.HTTP_CONN, '__class__')
@persistent.setter
@ -176,7 +178,9 @@ class WebCase(unittest.TestCase):
self, url, headers=None, method='GET', body=None,
protocol=None, raise_subcls=(),
):
"""Open the url with debugging support. Return status, headers, body.
"""Open the url with debugging support.
Return status, headers, body.
url should be the identifier passed to the server, typically a
server-absolute path and query string (sent between method and
@ -184,16 +188,16 @@ class WebCase(unittest.TestCase):
enabled in the server.
If the application under test generates absolute URIs, be sure
to wrap them first with strip_netloc::
to wrap them first with :py:func:`strip_netloc`::
class MyAppWebCase(WebCase):
def getPage(url, *args, **kwargs):
super(MyAppWebCase, self).getPage(
cheroot.test.webtest.strip_netloc(url),
*args, **kwargs
)
>>> class MyAppWebCase(WebCase):
... def getPage(url, *args, **kwargs):
... super(MyAppWebCase, self).getPage(
... cheroot.test.webtest.strip_netloc(url),
... *args, **kwargs
... )
`raise_subcls` is passed through to openURL.
``raise_subcls`` is passed through to :py:func:`openURL`.
"""
ServerError.on = False
@ -247,7 +251,7 @@ class WebCase(unittest.TestCase):
console_height = 30
def _handlewebError(self, msg):
def _handlewebError(self, msg): # noqa: C901 # FIXME
print('')
print(' ERROR: %s' % msg)
@ -487,7 +491,7 @@ def openURL(*args, **kwargs):
"""
Open a URL, retrying when it fails.
Specify `raise_subcls` (class or tuple of classes) to exclude
Specify ``raise_subcls`` (class or tuple of classes) to exclude
those socket.error subclasses from being suppressed and retried.
"""
raise_subcls = kwargs.pop('raise_subcls', ())
@ -553,7 +557,7 @@ def strip_netloc(url):
server-absolute portion.
Useful for wrapping an absolute-URI for which only the
path is expected (such as in calls to getPage).
path is expected (such as in calls to :py:meth:`WebCase.getPage`).
>>> strip_netloc('https://google.com/foo/bar?bing#baz')
'/foo/bar?bing'