Add cheroot-8.2.1

This commit is contained in:
JonnyWong16 2019-11-23 19:03:04 -08:00
commit 8f6639028f
27 changed files with 7925 additions and 0 deletions

View file

@ -0,0 +1 @@
"""Cheroot test suite."""

View file

@ -0,0 +1,69 @@
"""Pytest configuration module.
Contains fixtures, which are tightly bound to the Cheroot framework
itself, useless for end-users' app testing.
"""
from __future__ import absolute_import, division, print_function
__metaclass__ = type
import threading
import time
import pytest
from ..server import Gateway, HTTPServer
from ..testing import ( # noqa: F401
native_server, wsgi_server,
)
from ..testing import get_server_client
@pytest.fixture
def wsgi_server_client(wsgi_server): # noqa: F811
"""Create a test client out of given WSGI server."""
return get_server_client(wsgi_server)
@pytest.fixture
def native_server_client(native_server): # noqa: F811
"""Create a test client out of given HTTP server."""
return get_server_client(native_server)
@pytest.fixture
def http_server():
"""Provision a server creator as a fixture."""
def start_srv():
bind_addr = yield
if bind_addr is None:
return
httpserver = make_http_server(bind_addr)
yield httpserver
yield httpserver
srv_creator = iter(start_srv())
next(srv_creator)
yield srv_creator
try:
while True:
httpserver = next(srv_creator)
if httpserver is not None:
httpserver.stop()
except StopIteration:
pass
def make_http_server(bind_addr):
"""Create and start an HTTP server bound to bind_addr."""
httpserver = HTTPServer(
bind_addr=bind_addr,
gateway=Gateway,
)
threading.Thread(target=httpserver.safe_start).start()
while not httpserver.ready:
time.sleep(0.1)
return httpserver

168
lib/cheroot/test/helper.py Normal file
View file

@ -0,0 +1,168 @@
"""A library of helper functions for the Cheroot test suite."""
from __future__ import absolute_import, division, print_function
__metaclass__ = type
import datetime
import logging
import os
import sys
import time
import threading
import types
from six.moves import http_client
import six
import cheroot.server
import cheroot.wsgi
from cheroot.test import webtest
log = logging.getLogger(__name__)
thisdir = os.path.abspath(os.path.dirname(__file__))
config = {
'bind_addr': ('127.0.0.1', 54583),
'server': 'wsgi',
'wsgi_app': None,
}
class CherootWebCase(webtest.WebCase):
"""Helper class for a web app test suite."""
script_name = ''
scheme = 'http'
available_servers = {
'wsgi': cheroot.wsgi.Server,
'native': cheroot.server.HTTPServer,
}
@classmethod
def setup_class(cls):
"""Create and run one HTTP server per class."""
conf = config.copy()
conf.update(getattr(cls, 'config', {}))
s_class = conf.pop('server', 'wsgi')
server_factory = cls.available_servers.get(s_class)
if server_factory is None:
raise RuntimeError('Unknown server in config: %s' % conf['server'])
cls.httpserver = server_factory(**conf)
cls.HOST, cls.PORT = cls.httpserver.bind_addr
if cls.httpserver.ssl_adapter is None:
ssl = ''
cls.scheme = 'http'
else:
ssl = ' (ssl)'
cls.HTTP_CONN = http_client.HTTPSConnection
cls.scheme = 'https'
v = sys.version.split()[0]
log.info('Python version used to run this test script: %s' % v)
log.info('Cheroot version: %s' % cheroot.__version__)
log.info('HTTP server version: %s%s' % (cls.httpserver.protocol, ssl))
log.info('PID: %s' % os.getpid())
if hasattr(cls, 'setup_server'):
# Clear the wsgi server so that
# it can be updated with the new root
cls.setup_server()
cls.start()
@classmethod
def teardown_class(cls):
"""Cleanup HTTP server."""
if hasattr(cls, 'setup_server'):
cls.stop()
@classmethod
def start(cls):
"""Load and start the HTTP server."""
threading.Thread(target=cls.httpserver.safe_start).start()
while not cls.httpserver.ready:
time.sleep(0.1)
@classmethod
def stop(cls):
"""Terminate HTTP server."""
cls.httpserver.stop()
td = getattr(cls, 'teardown', None)
if td:
td()
date_tolerance = 2
def assertEqualDates(self, dt1, dt2, seconds=None):
"""Assert abs(dt1 - dt2) is within Y seconds."""
if seconds is None:
seconds = self.date_tolerance
if dt1 > dt2:
diff = dt1 - dt2
else:
diff = dt2 - dt1
if not diff < datetime.timedelta(seconds=seconds):
raise AssertionError('%r and %r are not within %r seconds.' %
(dt1, dt2, seconds))
class Request:
"""HTTP request container."""
def __init__(self, environ):
"""Initialize HTTP request."""
self.environ = environ
class Response:
"""HTTP response container."""
def __init__(self):
"""Initialize HTTP response."""
self.status = '200 OK'
self.headers = {'Content-Type': 'text/html'}
self.body = None
def output(self):
"""Generate iterable response body object."""
if self.body is None:
return []
elif isinstance(self.body, six.text_type):
return [self.body.encode('iso-8859-1')]
elif isinstance(self.body, six.binary_type):
return [self.body]
else:
return [x.encode('iso-8859-1') for x in self.body]
class Controller:
"""WSGI app for tests."""
def __call__(self, environ, start_response):
"""WSGI request handler."""
req, resp = Request(environ), Response()
try:
# Python 3 supports unicode attribute names
# Python 2 encodes them
handler = self.handlers[environ['PATH_INFO']]
except KeyError:
resp.status = '404 Not Found'
else:
output = handler(req, resp)
if (output is not None
and not any(resp.status.startswith(status_code)
for status_code in ('204', '304'))):
resp.body = output
try:
resp.headers.setdefault('Content-Length', str(len(output)))
except TypeError:
if not isinstance(output, types.GeneratorType):
raise
start_response(resp.status, resp.headers.items())
return resp.output()

View file

@ -0,0 +1,62 @@
# -*- coding: utf-8 -*-
"""Test suite for cross-python compatibility helpers."""
from __future__ import absolute_import, division, print_function
__metaclass__ = type
import pytest
import six
from cheroot._compat import extract_bytes, memoryview, ntob, ntou, bton
@pytest.mark.parametrize(
'func,inp,out',
[
(ntob, 'bar', b'bar'),
(ntou, 'bar', u'bar'),
(bton, b'bar', 'bar'),
],
)
def test_compat_functions_positive(func, inp, out):
"""Check that compat functions work with correct input."""
assert func(inp, encoding='utf-8') == out
@pytest.mark.parametrize(
'func',
[
ntob,
ntou,
],
)
def test_compat_functions_negative_nonnative(func):
"""Check that compat functions fail loudly for incorrect input."""
non_native_test_str = u'bar' if six.PY2 else b'bar'
with pytest.raises(TypeError):
func(non_native_test_str, encoding='utf-8')
def test_ntou_escape():
"""Check that ntou supports escape-encoding under Python 2."""
expected = u'hišřії'
actual = ntou('hi\u0161\u0159\u0456\u0457', encoding='escape')
assert actual == expected
@pytest.mark.parametrize(
'input_argument,expected_result',
[
(b'qwerty', b'qwerty'),
(memoryview(b'asdfgh'), b'asdfgh'),
],
)
def test_extract_bytes(input_argument, expected_result):
"""Check that legitimate inputs produce bytes."""
assert extract_bytes(input_argument) == expected_result
def test_extract_bytes_invalid():
"""Ensure that invalid input causes exception to be raised."""
with pytest.raises(ValueError):
extract_bytes(u'some юнікод їїї')

View file

@ -0,0 +1,980 @@
"""Tests for TCP connection handling, including proper and timely close."""
from __future__ import absolute_import, division, print_function
__metaclass__ = type
import socket
import time
from six.moves import range, http_client, urllib
import six
import pytest
from cheroot.test import helper, webtest
timeout = 1
pov = 'pPeErRsSiIsStTeEnNcCeE oOfF vViIsSiIoOnN'
class Controller(helper.Controller):
"""Controller for serving WSGI apps."""
def hello(req, resp):
"""Render Hello world."""
return 'Hello, world!'
def pov(req, resp):
"""Render pov value."""
return pov
def stream(req, resp):
"""Render streaming response."""
if 'set_cl' in req.environ['QUERY_STRING']:
resp.headers['Content-Length'] = str(10)
def content():
for x in range(10):
yield str(x)
return content()
def upload(req, resp):
"""Process file upload and render thank."""
if not req.environ['REQUEST_METHOD'] == 'POST':
raise AssertionError("'POST' != request.method %r" %
req.environ['REQUEST_METHOD'])
return "thanks for '%s'" % req.environ['wsgi.input'].read()
def custom_204(req, resp):
"""Render response with status 204."""
resp.status = '204'
return 'Code = 204'
def custom_304(req, resp):
"""Render response with status 304."""
resp.status = '304'
return 'Code = 304'
def err_before_read(req, resp):
"""Render response with status 500."""
resp.status = '500 Internal Server Error'
return 'ok'
def one_megabyte_of_a(req, resp):
"""Render 1MB response."""
return ['a' * 1024] * 1024
def wrong_cl_buffered(req, resp):
"""Render buffered response with invalid length value."""
resp.headers['Content-Length'] = '5'
return 'I have too many bytes'
def wrong_cl_unbuffered(req, resp):
"""Render unbuffered response with invalid length value."""
resp.headers['Content-Length'] = '5'
return ['I too', ' have too many bytes']
def _munge(string):
"""Encode PATH_INFO correctly depending on Python version.
WSGI 1.0 is a mess around unicode. Create endpoints
that match the PATH_INFO that it produces.
"""
if six.PY2:
return string
return string.encode('utf-8').decode('latin-1')
handlers = {
'/hello': hello,
'/pov': pov,
'/page1': pov,
'/page2': pov,
'/page3': pov,
'/stream': stream,
'/upload': upload,
'/custom/204': custom_204,
'/custom/304': custom_304,
'/err_before_read': err_before_read,
'/one_megabyte_of_a': one_megabyte_of_a,
'/wrong_cl_buffered': wrong_cl_buffered,
'/wrong_cl_unbuffered': wrong_cl_unbuffered,
}
@pytest.fixture
def testing_server(wsgi_server_client):
"""Attach a WSGI app to the given server and pre-configure it."""
app = Controller()
def _timeout(req, resp):
return str(wsgi_server.timeout)
app.handlers['/timeout'] = _timeout
wsgi_server = wsgi_server_client.server_instance
wsgi_server.wsgi_app = app
wsgi_server.max_request_body_size = 1001
wsgi_server.timeout = timeout
wsgi_server.server_client = wsgi_server_client
wsgi_server.keep_alive_conn_limit = 2
return wsgi_server
@pytest.fixture
def test_client(testing_server):
"""Get and return a test client out of the given server."""
return testing_server.server_client
def header_exists(header_name, headers):
"""Check that a header is present."""
return header_name.lower() in (k.lower() for (k, _) in headers)
def header_has_value(header_name, header_value, headers):
"""Check that a header with a given value is present."""
return header_name.lower() in (
k.lower() for (k, v) in headers
if v == header_value
)
def test_HTTP11_persistent_connections(test_client):
"""Test persistent HTTP/1.1 connections."""
# Initialize a persistent HTTP connection
http_connection = test_client.get_connection()
http_connection.auto_open = False
http_connection.connect()
# Make the first request and assert there's no "Connection: close".
status_line, actual_headers, actual_resp_body = test_client.get(
'/pov', http_conn=http_connection,
)
actual_status = int(status_line[:3])
assert actual_status == 200
assert status_line[4:] == 'OK'
assert actual_resp_body == pov.encode()
assert not header_exists('Connection', actual_headers)
# Make another request on the same connection.
status_line, actual_headers, actual_resp_body = test_client.get(
'/page1', http_conn=http_connection,
)
actual_status = int(status_line[:3])
assert actual_status == 200
assert status_line[4:] == 'OK'
assert actual_resp_body == pov.encode()
assert not header_exists('Connection', actual_headers)
# Test client-side close.
status_line, actual_headers, actual_resp_body = test_client.get(
'/page2', http_conn=http_connection,
headers=[('Connection', 'close')],
)
actual_status = int(status_line[:3])
assert actual_status == 200
assert status_line[4:] == 'OK'
assert actual_resp_body == pov.encode()
assert header_has_value('Connection', 'close', actual_headers)
# Make another request on the same connection, which should error.
with pytest.raises(http_client.NotConnected):
test_client.get('/pov', http_conn=http_connection)
@pytest.mark.parametrize(
'set_cl',
(
False, # Without Content-Length
True, # With Content-Length
),
)
def test_streaming_11(test_client, set_cl):
"""Test serving of streaming responses with HTTP/1.1 protocol."""
# Initialize a persistent HTTP connection
http_connection = test_client.get_connection()
http_connection.auto_open = False
http_connection.connect()
# Make the first request and assert there's no "Connection: close".
status_line, actual_headers, actual_resp_body = test_client.get(
'/pov', http_conn=http_connection,
)
actual_status = int(status_line[:3])
assert actual_status == 200
assert status_line[4:] == 'OK'
assert actual_resp_body == pov.encode()
assert not header_exists('Connection', actual_headers)
# Make another, streamed request on the same connection.
if set_cl:
# When a Content-Length is provided, the content should stream
# without closing the connection.
status_line, actual_headers, actual_resp_body = test_client.get(
'/stream?set_cl=Yes', http_conn=http_connection,
)
assert header_exists('Content-Length', actual_headers)
assert not header_has_value('Connection', 'close', actual_headers)
assert not header_exists('Transfer-Encoding', actual_headers)
assert actual_status == 200
assert status_line[4:] == 'OK'
assert actual_resp_body == b'0123456789'
else:
# When no Content-Length response header is provided,
# streamed output will either close the connection, or use
# chunked encoding, to determine transfer-length.
status_line, actual_headers, actual_resp_body = test_client.get(
'/stream', http_conn=http_connection,
)
assert not header_exists('Content-Length', actual_headers)
assert actual_status == 200
assert status_line[4:] == 'OK'
assert actual_resp_body == b'0123456789'
chunked_response = False
for k, v in actual_headers:
if k.lower() == 'transfer-encoding':
if str(v) == 'chunked':
chunked_response = True
if chunked_response:
assert not header_has_value('Connection', 'close', actual_headers)
else:
assert header_has_value('Connection', 'close', actual_headers)
# Make another request on the same connection, which should
# error.
with pytest.raises(http_client.NotConnected):
test_client.get('/pov', http_conn=http_connection)
# Try HEAD.
# See https://www.bitbucket.org/cherrypy/cherrypy/issue/864.
# TODO: figure out how can this be possible on an closed connection
# (chunked_response case)
status_line, actual_headers, actual_resp_body = test_client.head(
'/stream', http_conn=http_connection,
)
assert actual_status == 200
assert status_line[4:] == 'OK'
assert actual_resp_body == b''
assert not header_exists('Transfer-Encoding', actual_headers)
@pytest.mark.parametrize(
'set_cl',
(
False, # Without Content-Length
True, # With Content-Length
),
)
def test_streaming_10(test_client, set_cl):
"""Test serving of streaming responses with HTTP/1.0 protocol."""
original_server_protocol = test_client.server_instance.protocol
test_client.server_instance.protocol = 'HTTP/1.0'
# Initialize a persistent HTTP connection
http_connection = test_client.get_connection()
http_connection.auto_open = False
http_connection.connect()
# Make the first request and assert Keep-Alive.
status_line, actual_headers, actual_resp_body = test_client.get(
'/pov', http_conn=http_connection,
headers=[('Connection', 'Keep-Alive')],
protocol='HTTP/1.0',
)
actual_status = int(status_line[:3])
assert actual_status == 200
assert status_line[4:] == 'OK'
assert actual_resp_body == pov.encode()
assert header_has_value('Connection', 'Keep-Alive', actual_headers)
# Make another, streamed request on the same connection.
if set_cl:
# When a Content-Length is provided, the content should
# stream without closing the connection.
status_line, actual_headers, actual_resp_body = test_client.get(
'/stream?set_cl=Yes', http_conn=http_connection,
headers=[('Connection', 'Keep-Alive')],
protocol='HTTP/1.0',
)
actual_status = int(status_line[:3])
assert actual_status == 200
assert status_line[4:] == 'OK'
assert actual_resp_body == b'0123456789'
assert header_exists('Content-Length', actual_headers)
assert header_has_value('Connection', 'Keep-Alive', actual_headers)
assert not header_exists('Transfer-Encoding', actual_headers)
else:
# When a Content-Length is not provided,
# the server should close the connection.
status_line, actual_headers, actual_resp_body = test_client.get(
'/stream', http_conn=http_connection,
headers=[('Connection', 'Keep-Alive')],
protocol='HTTP/1.0',
)
actual_status = int(status_line[:3])
assert actual_status == 200
assert status_line[4:] == 'OK'
assert actual_resp_body == b'0123456789'
assert not header_exists('Content-Length', actual_headers)
assert not header_has_value('Connection', 'Keep-Alive', actual_headers)
assert not header_exists('Transfer-Encoding', actual_headers)
# Make another request on the same connection, which should error.
with pytest.raises(http_client.NotConnected):
test_client.get(
'/pov', http_conn=http_connection,
protocol='HTTP/1.0',
)
test_client.server_instance.protocol = original_server_protocol
@pytest.mark.parametrize(
'http_server_protocol',
(
'HTTP/1.0',
'HTTP/1.1',
),
)
def test_keepalive(test_client, http_server_protocol):
"""Test Keep-Alive enabled connections."""
original_server_protocol = test_client.server_instance.protocol
test_client.server_instance.protocol = http_server_protocol
http_client_protocol = 'HTTP/1.0'
# Initialize a persistent HTTP connection
http_connection = test_client.get_connection()
http_connection.auto_open = False
http_connection.connect()
# Test a normal HTTP/1.0 request.
status_line, actual_headers, actual_resp_body = test_client.get(
'/page2',
protocol=http_client_protocol,
)
actual_status = int(status_line[:3])
assert actual_status == 200
assert status_line[4:] == 'OK'
assert actual_resp_body == pov.encode()
assert not header_exists('Connection', actual_headers)
# Test a keep-alive HTTP/1.0 request.
status_line, actual_headers, actual_resp_body = test_client.get(
'/page3', headers=[('Connection', 'Keep-Alive')],
http_conn=http_connection, protocol=http_client_protocol,
)
actual_status = int(status_line[:3])
assert actual_status == 200
assert status_line[4:] == 'OK'
assert actual_resp_body == pov.encode()
assert header_has_value('Connection', 'Keep-Alive', actual_headers)
# Remove the keep-alive header again.
status_line, actual_headers, actual_resp_body = test_client.get(
'/page3', http_conn=http_connection,
protocol=http_client_protocol,
)
actual_status = int(status_line[:3])
assert actual_status == 200
assert status_line[4:] == 'OK'
assert actual_resp_body == pov.encode()
assert not header_exists('Connection', actual_headers)
test_client.server_instance.protocol = original_server_protocol
def test_keepalive_conn_management(test_client):
"""Test management of Keep-Alive connections."""
test_client.server_instance.timeout = 2
def connection():
# Initialize a persistent HTTP connection
http_connection = test_client.get_connection()
http_connection.auto_open = False
http_connection.connect()
return http_connection
def request(conn):
status_line, actual_headers, actual_resp_body = test_client.get(
'/page3', headers=[('Connection', 'Keep-Alive')],
http_conn=conn, protocol='HTTP/1.0',
)
actual_status = int(status_line[:3])
assert actual_status == 200
assert status_line[4:] == 'OK'
assert actual_resp_body == pov.encode()
assert header_has_value('Connection', 'Keep-Alive', actual_headers)
disconnect_errors = (
http_client.BadStatusLine,
http_client.CannotSendRequest,
http_client.NotConnected,
)
# Make a new connection.
c1 = connection()
request(c1)
# Make a second one.
c2 = connection()
request(c2)
# Reusing the first connection should still work.
request(c1)
# Creating a new connection should still work.
c3 = connection()
request(c3)
# Allow a tick.
time.sleep(0.2)
# That's three connections, we should expect the one used less recently
# to be expired.
with pytest.raises(disconnect_errors):
request(c2)
# But the oldest created one should still be valid.
# (As well as the newest one).
request(c1)
request(c3)
# Wait for some of our timeout.
time.sleep(1.0)
# Refresh the third connection.
request(c3)
# Wait for the remainder of our timeout, plus one tick.
time.sleep(1.2)
# First connection should now be expired.
with pytest.raises(disconnect_errors):
request(c1)
# But the third one should still be valid.
request(c3)
test_client.server_instance.timeout = timeout
@pytest.mark.parametrize(
'timeout_before_headers',
(
True,
False,
),
)
def test_HTTP11_Timeout(test_client, timeout_before_headers):
"""Check timeout without sending any data.
The server will close the conn with a 408.
"""
conn = test_client.get_connection()
conn.auto_open = False
conn.connect()
if not timeout_before_headers:
# Connect but send half the headers only.
conn.send(b'GET /hello HTTP/1.1')
conn.send(('Host: %s' % conn.host).encode('ascii'))
# else: Connect but send nothing.
# Wait for our socket timeout
time.sleep(timeout * 2)
# The request should have returned 408 already.
response = conn.response_class(conn.sock, method='GET')
response.begin()
assert response.status == 408
conn.close()
def test_HTTP11_Timeout_after_request(test_client):
"""Check timeout after at least one request has succeeded.
The server should close the connection without 408.
"""
fail_msg = "Writing to timed out socket didn't fail as it should have: %s"
# Make an initial request
conn = test_client.get_connection()
conn.putrequest('GET', '/timeout?t=%s' % timeout, skip_host=True)
conn.putheader('Host', conn.host)
conn.endheaders()
response = conn.response_class(conn.sock, method='GET')
response.begin()
assert response.status == 200
actual_body = response.read()
expected_body = str(timeout).encode()
assert actual_body == expected_body
# Make a second request on the same socket
conn._output(b'GET /hello HTTP/1.1')
conn._output(('Host: %s' % conn.host).encode('ascii'))
conn._send_output()
response = conn.response_class(conn.sock, method='GET')
response.begin()
assert response.status == 200
actual_body = response.read()
expected_body = b'Hello, world!'
assert actual_body == expected_body
# Wait for our socket timeout
time.sleep(timeout * 2)
# Make another request on the same socket, which should error
conn._output(b'GET /hello HTTP/1.1')
conn._output(('Host: %s' % conn.host).encode('ascii'))
conn._send_output()
response = conn.response_class(conn.sock, method='GET')
try:
response.begin()
except (socket.error, http_client.BadStatusLine):
pass
except Exception as ex:
pytest.fail(fail_msg % ex)
else:
if response.status != 408:
pytest.fail(fail_msg % response.read())
conn.close()
# Make another request on a new socket, which should work
conn = test_client.get_connection()
conn.putrequest('GET', '/pov', skip_host=True)
conn.putheader('Host', conn.host)
conn.endheaders()
response = conn.response_class(conn.sock, method='GET')
response.begin()
assert response.status == 200
actual_body = response.read()
expected_body = pov.encode()
assert actual_body == expected_body
# Make another request on the same socket,
# but timeout on the headers
conn.send(b'GET /hello HTTP/1.1')
# Wait for our socket timeout
time.sleep(timeout * 2)
response = conn.response_class(conn.sock, method='GET')
try:
response.begin()
except (socket.error, http_client.BadStatusLine):
pass
except Exception as ex:
pytest.fail(fail_msg % ex)
else:
if response.status != 408:
pytest.fail(fail_msg % response.read())
conn.close()
# Retry the request on a new connection, which should work
conn = test_client.get_connection()
conn.putrequest('GET', '/pov', skip_host=True)
conn.putheader('Host', conn.host)
conn.endheaders()
response = conn.response_class(conn.sock, method='GET')
response.begin()
assert response.status == 200
actual_body = response.read()
expected_body = pov.encode()
assert actual_body == expected_body
conn.close()
def test_HTTP11_pipelining(test_client):
"""Test HTTP/1.1 pipelining.
httplib doesn't support this directly.
"""
conn = test_client.get_connection()
# Put request 1
conn.putrequest('GET', '/hello', skip_host=True)
conn.putheader('Host', conn.host)
conn.endheaders()
for trial in range(5):
# Put next request
conn._output(
('GET /hello?%s HTTP/1.1' % trial).encode('iso-8859-1'),
)
conn._output(('Host: %s' % conn.host).encode('ascii'))
conn._send_output()
# Retrieve previous response
response = conn.response_class(conn.sock, method='GET')
# there is a bug in python3 regarding the buffering of
# ``conn.sock``. Until that bug get's fixed we will
# monkey patch the ``response`` instance.
# https://bugs.python.org/issue23377
if not six.PY2:
response.fp = conn.sock.makefile('rb', 0)
response.begin()
body = response.read(13)
assert response.status == 200
assert body == b'Hello, world!'
# Retrieve final response
response = conn.response_class(conn.sock, method='GET')
response.begin()
body = response.read()
assert response.status == 200
assert body == b'Hello, world!'
conn.close()
def test_100_Continue(test_client):
"""Test 100-continue header processing."""
conn = test_client.get_connection()
# Try a page without an Expect request header first.
# Note that httplib's response.begin automatically ignores
# 100 Continue responses, so we must manually check for it.
conn.putrequest('POST', '/upload', skip_host=True)
conn.putheader('Host', conn.host)
conn.putheader('Content-Type', 'text/plain')
conn.putheader('Content-Length', '4')
conn.endheaders()
conn.send(b"d'oh")
response = conn.response_class(conn.sock, method='POST')
version, status, reason = response._read_status()
assert status != 100
conn.close()
# Now try a page with an Expect header...
conn.connect()
conn.putrequest('POST', '/upload', skip_host=True)
conn.putheader('Host', conn.host)
conn.putheader('Content-Type', 'text/plain')
conn.putheader('Content-Length', '17')
conn.putheader('Expect', '100-continue')
conn.endheaders()
response = conn.response_class(conn.sock, method='POST')
# ...assert and then skip the 100 response
version, status, reason = response._read_status()
assert status == 100
while True:
line = response.fp.readline().strip()
if line:
pytest.fail(
'100 Continue should not output any headers. Got %r' %
line,
)
else:
break
# ...send the body
body = b'I am a small file'
conn.send(body)
# ...get the final response
response.begin()
status_line, actual_headers, actual_resp_body = webtest.shb(response)
actual_status = int(status_line[:3])
assert actual_status == 200
expected_resp_body = ("thanks for '%s'" % body).encode()
assert actual_resp_body == expected_resp_body
conn.close()
@pytest.mark.parametrize(
'max_request_body_size',
(
0,
1001,
),
)
def test_readall_or_close(test_client, max_request_body_size):
"""Test a max_request_body_size of 0 (the default) and 1001."""
old_max = test_client.server_instance.max_request_body_size
test_client.server_instance.max_request_body_size = max_request_body_size
conn = test_client.get_connection()
# Get a POST page with an error
conn.putrequest('POST', '/err_before_read', skip_host=True)
conn.putheader('Host', conn.host)
conn.putheader('Content-Type', 'text/plain')
conn.putheader('Content-Length', '1000')
conn.putheader('Expect', '100-continue')
conn.endheaders()
response = conn.response_class(conn.sock, method='POST')
# ...assert and then skip the 100 response
version, status, reason = response._read_status()
assert status == 100
skip = True
while skip:
skip = response.fp.readline().strip()
# ...send the body
conn.send(b'x' * 1000)
# ...get the final response
response.begin()
status_line, actual_headers, actual_resp_body = webtest.shb(response)
actual_status = int(status_line[:3])
assert actual_status == 500
# Now try a working page with an Expect header...
conn._output(b'POST /upload HTTP/1.1')
conn._output(('Host: %s' % conn.host).encode('ascii'))
conn._output(b'Content-Type: text/plain')
conn._output(b'Content-Length: 17')
conn._output(b'Expect: 100-continue')
conn._send_output()
response = conn.response_class(conn.sock, method='POST')
# ...assert and then skip the 100 response
version, status, reason = response._read_status()
assert status == 100
skip = True
while skip:
skip = response.fp.readline().strip()
# ...send the body
body = b'I am a small file'
conn.send(body)
# ...get the final response
response.begin()
status_line, actual_headers, actual_resp_body = webtest.shb(response)
actual_status = int(status_line[:3])
assert actual_status == 200
expected_resp_body = ("thanks for '%s'" % body).encode()
assert actual_resp_body == expected_resp_body
conn.close()
test_client.server_instance.max_request_body_size = old_max
def test_No_Message_Body(test_client):
"""Test HTTP queries with an empty response body."""
# Initialize a persistent HTTP connection
http_connection = test_client.get_connection()
http_connection.auto_open = False
http_connection.connect()
# Make the first request and assert there's no "Connection: close".
status_line, actual_headers, actual_resp_body = test_client.get(
'/pov', http_conn=http_connection,
)
actual_status = int(status_line[:3])
assert actual_status == 200
assert status_line[4:] == 'OK'
assert actual_resp_body == pov.encode()
assert not header_exists('Connection', actual_headers)
# Make a 204 request on the same connection.
status_line, actual_headers, actual_resp_body = test_client.get(
'/custom/204', http_conn=http_connection,
)
actual_status = int(status_line[:3])
assert actual_status == 204
assert not header_exists('Content-Length', actual_headers)
assert actual_resp_body == b''
assert not header_exists('Connection', actual_headers)
# Make a 304 request on the same connection.
status_line, actual_headers, actual_resp_body = test_client.get(
'/custom/304', http_conn=http_connection,
)
actual_status = int(status_line[:3])
assert actual_status == 304
assert not header_exists('Content-Length', actual_headers)
assert actual_resp_body == b''
assert not header_exists('Connection', actual_headers)
@pytest.mark.xfail(
reason='Server does not correctly read trailers/ending of the previous '
'HTTP request, thus the second request fails as the server tries '
r"to parse b'Content-Type: application/json\r\n' as a "
'Request-Line. This results in HTTP status code 400, instead of 413'
'Ref: https://github.com/cherrypy/cheroot/issues/69',
)
def test_Chunked_Encoding(test_client):
"""Test HTTP uploads with chunked transfer-encoding."""
# Initialize a persistent HTTP connection
conn = test_client.get_connection()
# Try a normal chunked request (with extensions)
body = (
b'8;key=value\r\nxx\r\nxxxx\r\n5\r\nyyyyy\r\n0\r\n'
b'Content-Type: application/json\r\n'
b'\r\n'
)
conn.putrequest('POST', '/upload', skip_host=True)
conn.putheader('Host', conn.host)
conn.putheader('Transfer-Encoding', 'chunked')
conn.putheader('Trailer', 'Content-Type')
# Note that this is somewhat malformed:
# we shouldn't be sending Content-Length.
# RFC 2616 says the server should ignore it.
conn.putheader('Content-Length', '3')
conn.endheaders()
conn.send(body)
response = conn.getresponse()
status_line, actual_headers, actual_resp_body = webtest.shb(response)
actual_status = int(status_line[:3])
assert actual_status == 200
assert status_line[4:] == 'OK'
expected_resp_body = ("thanks for '%s'" % b'xx\r\nxxxxyyyyy').encode()
assert actual_resp_body == expected_resp_body
# Try a chunked request that exceeds server.max_request_body_size.
# Note that the delimiters and trailer are included.
body = b'3e3\r\n' + (b'x' * 995) + b'\r\n0\r\n\r\n'
conn.putrequest('POST', '/upload', skip_host=True)
conn.putheader('Host', conn.host)
conn.putheader('Transfer-Encoding', 'chunked')
conn.putheader('Content-Type', 'text/plain')
# Chunked requests don't need a content-length
# conn.putheader("Content-Length", len(body))
conn.endheaders()
conn.send(body)
response = conn.getresponse()
status_line, actual_headers, actual_resp_body = webtest.shb(response)
actual_status = int(status_line[:3])
assert actual_status == 413
conn.close()
def test_Content_Length_in(test_client):
"""Try a non-chunked request where Content-Length exceeds limit.
(server.max_request_body_size).
Assert error before body send.
"""
# Initialize a persistent HTTP connection
conn = test_client.get_connection()
conn.putrequest('POST', '/upload', skip_host=True)
conn.putheader('Host', conn.host)
conn.putheader('Content-Type', 'text/plain')
conn.putheader('Content-Length', '9999')
conn.endheaders()
response = conn.getresponse()
status_line, actual_headers, actual_resp_body = webtest.shb(response)
actual_status = int(status_line[:3])
assert actual_status == 413
expected_resp_body = (
b'The entity sent with the request exceeds '
b'the maximum allowed bytes.'
)
assert actual_resp_body == expected_resp_body
conn.close()
def test_Content_Length_not_int(test_client):
"""Test that malicious Content-Length header returns 400."""
status_line, actual_headers, actual_resp_body = test_client.post(
'/upload',
headers=[
('Content-Type', 'text/plain'),
('Content-Length', 'not-an-integer'),
],
)
actual_status = int(status_line[:3])
assert actual_status == 400
assert actual_resp_body == b'Malformed Content-Length Header.'
@pytest.mark.parametrize(
'uri,expected_resp_status,expected_resp_body',
(
(
'/wrong_cl_buffered', 500,
(
b'The requested resource returned more bytes than the '
b'declared Content-Length.'
),
),
('/wrong_cl_unbuffered', 200, b'I too'),
),
)
def test_Content_Length_out(
test_client,
uri, expected_resp_status, expected_resp_body,
):
"""Test response with Content-Length less than the response body.
(non-chunked response)
"""
conn = test_client.get_connection()
conn.putrequest('GET', uri, skip_host=True)
conn.putheader('Host', conn.host)
conn.endheaders()
response = conn.getresponse()
status_line, actual_headers, actual_resp_body = webtest.shb(response)
actual_status = int(status_line[:3])
assert actual_status == expected_resp_status
assert actual_resp_body == expected_resp_body
conn.close()
@pytest.mark.xfail(
reason='Sometimes this test fails due to low timeout. '
'Ref: https://github.com/cherrypy/cherrypy/issues/598',
)
def test_598(test_client):
"""Test serving large file with a read timeout in place."""
# Initialize a persistent HTTP connection
conn = test_client.get_connection()
remote_data_conn = urllib.request.urlopen(
'%s://%s:%s/one_megabyte_of_a'
% ('http', conn.host, conn.port),
)
buf = remote_data_conn.read(512)
time.sleep(timeout * 0.6)
remaining = (1024 * 1024) - 512
while remaining:
data = remote_data_conn.read(remaining)
if not data:
break
buf += data
remaining -= len(data)
assert len(buf) == 1024 * 1024
assert buf == b'a' * 1024 * 1024
assert remaining == 0
remote_data_conn.close()
@pytest.mark.parametrize(
'invalid_terminator',
(
b'\n\n',
b'\r\n\n',
),
)
def test_No_CRLF(test_client, invalid_terminator):
"""Test HTTP queries with no valid CRLF terminators."""
# Initialize a persistent HTTP connection
conn = test_client.get_connection()
# (b'%s' % b'') is not supported in Python 3.4, so just use +
conn.send(b'GET /hello HTTP/1.1' + invalid_terminator)
response = conn.response_class(conn.sock, method='GET')
response.begin()
actual_resp_body = response.read()
expected_resp_body = b'HTTP requires CRLF terminators'
assert actual_resp_body == expected_resp_body
conn.close()

View file

@ -0,0 +1,415 @@
"""Tests for managing HTTP issues (malformed requests, etc)."""
# -*- coding: utf-8 -*-
# vim: set fileencoding=utf-8 :
from __future__ import absolute_import, division, print_function
__metaclass__ = type
import errno
import socket
import pytest
import six
from six.moves import urllib
from cheroot.test import helper
HTTP_BAD_REQUEST = 400
HTTP_LENGTH_REQUIRED = 411
HTTP_NOT_FOUND = 404
HTTP_OK = 200
HTTP_VERSION_NOT_SUPPORTED = 505
class HelloController(helper.Controller):
"""Controller for serving WSGI apps."""
def hello(req, resp):
"""Render Hello world."""
return 'Hello world!'
def body_required(req, resp):
"""Render Hello world or set 411."""
if req.environ.get('Content-Length', None) is None:
resp.status = '411 Length Required'
return
return 'Hello world!'
def query_string(req, resp):
"""Render QUERY_STRING value."""
return req.environ.get('QUERY_STRING', '')
def asterisk(req, resp):
"""Render request method value."""
method = req.environ.get('REQUEST_METHOD', 'NO METHOD FOUND')
tmpl = 'Got asterisk URI path with {method} method'
return tmpl.format(**locals())
def _munge(string):
"""Encode PATH_INFO correctly depending on Python version.
WSGI 1.0 is a mess around unicode. Create endpoints
that match the PATH_INFO that it produces.
"""
if six.PY2:
return string
return string.encode('utf-8').decode('latin-1')
handlers = {
'/hello': hello,
'/no_body': hello,
'/body_required': body_required,
'/query_string': query_string,
_munge('/привіт'): hello,
_munge('/Юххууу'): hello,
'/\xa0Ðblah key 0 900 4 data': hello,
'/*': asterisk,
}
def _get_http_response(connection, method='GET'):
c = connection
kwargs = {'strict': c.strict} if hasattr(c, 'strict') else {}
# Python 3.2 removed the 'strict' feature, saying:
# "http.client now always assumes HTTP/1.x compliant servers."
return c.response_class(c.sock, method=method, **kwargs)
@pytest.fixture
def testing_server(wsgi_server_client):
"""Attach a WSGI app to the given server and pre-configure it."""
wsgi_server = wsgi_server_client.server_instance
wsgi_server.wsgi_app = HelloController()
wsgi_server.max_request_body_size = 30000000
wsgi_server.server_client = wsgi_server_client
return wsgi_server
@pytest.fixture
def test_client(testing_server):
"""Get and return a test client out of the given server."""
return testing_server.server_client
def test_http_connect_request(test_client):
"""Check that CONNECT query results in Method Not Allowed status."""
status_line = test_client.connect('/anything')[0]
actual_status = int(status_line[:3])
assert actual_status == 405
def test_normal_request(test_client):
"""Check that normal GET query succeeds."""
status_line, _, actual_resp_body = test_client.get('/hello')
actual_status = int(status_line[:3])
assert actual_status == HTTP_OK
assert actual_resp_body == b'Hello world!'
def test_query_string_request(test_client):
"""Check that GET param is parsed well."""
status_line, _, actual_resp_body = test_client.get(
'/query_string?test=True',
)
actual_status = int(status_line[:3])
assert actual_status == HTTP_OK
assert actual_resp_body == b'test=True'
@pytest.mark.parametrize(
'uri',
(
'/hello', # plain
'/query_string?test=True', # query
'/{0}?{1}={2}'.format( # quoted unicode
*map(urllib.parse.quote, ('Юххууу', 'ї', 'йо'))
),
),
)
def test_parse_acceptable_uri(test_client, uri):
"""Check that server responds with OK to valid GET queries."""
status_line = test_client.get(uri)[0]
actual_status = int(status_line[:3])
assert actual_status == HTTP_OK
@pytest.mark.xfail(six.PY2, reason='Fails on Python 2')
def test_parse_uri_unsafe_uri(test_client):
"""Test that malicious URI does not allow HTTP injection.
This effectively checks that sending GET request with URL
/%A0%D0blah%20key%200%20900%204%20data
is not converted into
GET /
blah key 0 900 4 data
HTTP/1.1
which would be a security issue otherwise.
"""
c = test_client.get_connection()
resource = '/\xa0Ðblah key 0 900 4 data'.encode('latin-1')
quoted = urllib.parse.quote(resource)
assert quoted == '/%A0%D0blah%20key%200%20900%204%20data'
request = 'GET {quoted} HTTP/1.1'.format(**locals())
c._output(request.encode('utf-8'))
c._send_output()
response = _get_http_response(c, method='GET')
response.begin()
assert response.status == HTTP_OK
assert response.read(12) == b'Hello world!'
c.close()
def test_parse_uri_invalid_uri(test_client):
"""Check that server responds with Bad Request to invalid GET queries.
Invalid request line test case: it should only contain US-ASCII.
"""
c = test_client.get_connection()
c._output(u'GET /йопта! HTTP/1.1'.encode('utf-8'))
c._send_output()
response = _get_http_response(c, method='GET')
response.begin()
assert response.status == HTTP_BAD_REQUEST
assert response.read(21) == b'Malformed Request-URI'
c.close()
@pytest.mark.parametrize(
'uri',
(
'hello', # ascii
'привіт', # non-ascii
),
)
def test_parse_no_leading_slash_invalid(test_client, uri):
"""Check that server responds with Bad Request to invalid GET queries.
Invalid request line test case: it should have leading slash (be absolute).
"""
status_line, _, actual_resp_body = test_client.get(
urllib.parse.quote(uri),
)
actual_status = int(status_line[:3])
assert actual_status == HTTP_BAD_REQUEST
assert b'starting with a slash' in actual_resp_body
def test_parse_uri_absolute_uri(test_client):
"""Check that server responds with Bad Request to Absolute URI.
Only proxy servers should allow this.
"""
status_line, _, actual_resp_body = test_client.get('http://google.com/')
actual_status = int(status_line[:3])
assert actual_status == HTTP_BAD_REQUEST
expected_body = b'Absolute URI not allowed if server is not a proxy.'
assert actual_resp_body == expected_body
def test_parse_uri_asterisk_uri(test_client):
"""Check that server responds with OK to OPTIONS with "*" Absolute URI."""
status_line, _, actual_resp_body = test_client.options('*')
actual_status = int(status_line[:3])
assert actual_status == HTTP_OK
expected_body = b'Got asterisk URI path with OPTIONS method'
assert actual_resp_body == expected_body
def test_parse_uri_fragment_uri(test_client):
"""Check that server responds with Bad Request to URI with fragment."""
status_line, _, actual_resp_body = test_client.get(
'/hello?test=something#fake',
)
actual_status = int(status_line[:3])
assert actual_status == HTTP_BAD_REQUEST
expected_body = b'Illegal #fragment in Request-URI.'
assert actual_resp_body == expected_body
def test_no_content_length(test_client):
"""Test POST query with an empty body being successful."""
# "The presence of a message-body in a request is signaled by the
# inclusion of a Content-Length or Transfer-Encoding header field in
# the request's message-headers."
#
# Send a message with neither header and no body.
c = test_client.get_connection()
c.request('POST', '/no_body')
response = c.getresponse()
actual_resp_body = response.read()
actual_status = response.status
assert actual_status == HTTP_OK
assert actual_resp_body == b'Hello world!'
def test_content_length_required(test_client):
"""Test POST query with body failing because of missing Content-Length."""
# Now send a message that has no Content-Length, but does send a body.
# Verify that CP times out the socket and responds
# with 411 Length Required.
c = test_client.get_connection()
c.request('POST', '/body_required')
response = c.getresponse()
response.read()
actual_status = response.status
assert actual_status == HTTP_LENGTH_REQUIRED
@pytest.mark.parametrize(
'request_line,status_code,expected_body',
(
(
b'GET /', # missing proto
HTTP_BAD_REQUEST, b'Malformed Request-Line',
),
(
b'GET / HTTPS/1.1', # invalid proto
HTTP_BAD_REQUEST, b'Malformed Request-Line: bad protocol',
),
(
b'GET / HTTP/1', # invalid version
HTTP_BAD_REQUEST, b'Malformed Request-Line: bad version',
),
(
b'GET / HTTP/2.15', # invalid ver
HTTP_VERSION_NOT_SUPPORTED, b'Cannot fulfill request',
),
),
)
def test_malformed_request_line(
test_client, request_line,
status_code, expected_body,
):
"""Test missing or invalid HTTP version in Request-Line."""
c = test_client.get_connection()
c._output(request_line)
c._send_output()
response = _get_http_response(c, method='GET')
response.begin()
assert response.status == status_code
assert response.read(len(expected_body)) == expected_body
c.close()
def test_malformed_http_method(test_client):
"""Test non-uppercase HTTP method."""
c = test_client.get_connection()
c.putrequest('GeT', '/malformed_method_case')
c.putheader('Content-Type', 'text/plain')
c.endheaders()
response = c.getresponse()
actual_status = response.status
assert actual_status == HTTP_BAD_REQUEST
actual_resp_body = response.read(21)
assert actual_resp_body == b'Malformed method name'
def test_malformed_header(test_client):
"""Check that broken HTTP header results in Bad Request."""
c = test_client.get_connection()
c.putrequest('GET', '/')
c.putheader('Content-Type', 'text/plain')
# See https://www.bitbucket.org/cherrypy/cherrypy/issue/941
c._output(b'Re, 1.2.3.4#015#012')
c.endheaders()
response = c.getresponse()
actual_status = response.status
assert actual_status == HTTP_BAD_REQUEST
actual_resp_body = response.read(20)
assert actual_resp_body == b'Illegal header line.'
def test_request_line_split_issue_1220(test_client):
"""Check that HTTP request line of exactly 256 chars length is OK."""
Request_URI = (
'/hello?'
'intervenant-entreprise-evenement_classaction='
'evenement-mailremerciements'
'&_path=intervenant-entreprise-evenement'
'&intervenant-entreprise-evenement_action-id=19404'
'&intervenant-entreprise-evenement_id=19404'
'&intervenant-entreprise_id=28092'
)
assert len('GET %s HTTP/1.1\r\n' % Request_URI) == 256
actual_resp_body = test_client.get(Request_URI)[2]
assert actual_resp_body == b'Hello world!'
def test_garbage_in(test_client):
"""Test that server sends an error for garbage received over TCP."""
# Connect without SSL regardless of server.scheme
c = test_client.get_connection()
c._output(b'gjkgjklsgjklsgjkljklsg')
c._send_output()
response = c.response_class(c.sock, method='GET')
try:
response.begin()
actual_status = response.status
assert actual_status == HTTP_BAD_REQUEST
actual_resp_body = response.read(22)
assert actual_resp_body == b'Malformed Request-Line'
c.close()
except socket.error as ex:
# "Connection reset by peer" is also acceptable.
if ex.errno != errno.ECONNRESET:
raise
class CloseController:
"""Controller for testing the close callback."""
def __call__(self, environ, start_response):
"""Get the req to know header sent status."""
self.req = start_response.__self__.req
resp = CloseResponse(self.close)
start_response(resp.status, resp.headers.items())
return resp
def close(self):
"""Close, writing hello."""
self.req.write(b'hello')
class CloseResponse:
"""Dummy empty response to trigger the no body status."""
def __init__(self, close):
"""Use some defaults to ensure we have a header."""
self.status = '200 OK'
self.headers = {'Content-Type': 'text/html'}
self.close = close
def __getitem__(self, index):
"""Ensure we don't have a body."""
raise IndexError()
def output(self):
"""Return self to hook the close method."""
return self
@pytest.fixture
def testing_server_close(wsgi_server_client):
"""Attach a WSGI app to the given server and pre-configure it."""
wsgi_server = wsgi_server_client.server_instance
wsgi_server.wsgi_app = CloseController()
wsgi_server.max_request_body_size = 30000000
wsgi_server.server_client = wsgi_server_client
return wsgi_server
def test_send_header_before_closing(testing_server_close):
"""Test we are actually sending the headers before calling 'close'."""
_, _, resp_body = testing_server_close.server_client.get('/')
assert resp_body == b'hello'

View file

@ -0,0 +1,55 @@
"""Tests for the HTTP server."""
# -*- coding: utf-8 -*-
# vim: set fileencoding=utf-8 :
from __future__ import absolute_import, division, print_function
from cheroot.wsgi import PathInfoDispatcher
def wsgi_invoke(app, environ):
"""Serve 1 requeset from a WSGI application."""
response = {}
def start_response(status, headers):
response.update({
'status': status,
'headers': headers,
})
response['body'] = b''.join(
app(environ, start_response),
)
return response
def test_dispatch_no_script_name():
"""Despatch despite lack of SCRIPT_NAME in environ."""
# Bare bones WSGI hello world app (from PEP 333).
def app(environ, start_response):
start_response(
'200 OK', [
('Content-Type', 'text/plain; charset=utf-8'),
],
)
return [u'Hello, world!'.encode('utf-8')]
# Build a dispatch table.
d = PathInfoDispatcher([
('/', app),
])
# Dispatch a request without `SCRIPT_NAME`.
response = wsgi_invoke(
d, {
'PATH_INFO': '/foo',
},
)
assert response == {
'status': '200 OK',
'headers': [
('Content-Type', 'text/plain; charset=utf-8'),
],
'body': b'Hello, world!',
}

View file

@ -0,0 +1,30 @@
"""Test suite for ``cheroot.errors``."""
import pytest
from cheroot import errors
from .._compat import IS_LINUX, IS_MACOS, IS_WINDOWS
@pytest.mark.parametrize(
'err_names,err_nums',
(
(('', 'some-nonsense-name'), []),
(
(
'EPROTOTYPE', 'EAGAIN', 'EWOULDBLOCK',
'WSAEWOULDBLOCK', 'EPIPE',
),
(91, 11, 32) if IS_LINUX else
(32, 35, 41) if IS_MACOS else
(32, 10041, 11, 10035) if IS_WINDOWS else
(),
),
),
)
def test_plat_specific_errors(err_names, err_nums):
"""Test that plat_specific_errors retrieves correct err num list."""
actual_err_nums = errors.plat_specific_errors(*err_names)
assert len(actual_err_nums) == len(err_nums)
assert sorted(actual_err_nums) == sorted(err_nums)

View file

@ -0,0 +1,52 @@
"""self-explanatory."""
from cheroot import makefile
__metaclass__ = type
class MockSocket:
"""Mocks a socket."""
def __init__(self):
"""Initialize."""
self.messages = []
def recv_into(self, buf):
"""Simulate recv_into for Python 3."""
if not self.messages:
return 0
msg = self.messages.pop(0)
for index, byte in enumerate(msg):
buf[index] = byte
return len(msg)
def recv(self, size):
"""Simulate recv for Python 2."""
try:
return self.messages.pop(0)
except IndexError:
return ''
def send(self, val):
"""Simulate a send."""
return len(val)
def test_bytes_read():
"""Reader should capture bytes read."""
sock = MockSocket()
sock.messages.append(b'foo')
rfile = makefile.MakeFile(sock, 'r')
rfile.read()
assert rfile.bytes_read == 3
def test_bytes_written():
"""Writer should capture bytes writtten."""
sock = MockSocket()
sock.messages.append(b'foo')
wfile = makefile.MakeFile(sock, 'w')
wfile.write(b'bar')
assert wfile.bytes_written == 3

View file

@ -0,0 +1,235 @@
"""Tests for the HTTP server."""
# -*- coding: utf-8 -*-
# vim: set fileencoding=utf-8 :
from __future__ import absolute_import, division, print_function
__metaclass__ = type
import os
import socket
import tempfile
import threading
import uuid
import pytest
import requests
import requests_unixsocket
import six
from .._compat import bton, ntob
from .._compat import IS_LINUX, IS_MACOS, SYS_PLATFORM
from ..server import IS_UID_GID_RESOLVABLE, Gateway, HTTPServer
from ..testing import (
ANY_INTERFACE_IPV4,
ANY_INTERFACE_IPV6,
EPHEMERAL_PORT,
get_server_client,
)
unix_only_sock_test = pytest.mark.skipif(
not hasattr(socket, 'AF_UNIX'),
reason='UNIX domain sockets are only available under UNIX-based OS',
)
non_macos_sock_test = pytest.mark.skipif(
IS_MACOS,
reason='Peercreds lookup does not work under macOS/BSD currently.',
)
@pytest.fixture(params=('abstract', 'file'))
def unix_sock_file(request):
"""Check that bound UNIX socket address is stored in server."""
if request.param == 'abstract':
yield request.getfixturevalue('unix_abstract_sock')
return
tmp_sock_fh, tmp_sock_fname = tempfile.mkstemp()
yield tmp_sock_fname
os.close(tmp_sock_fh)
os.unlink(tmp_sock_fname)
@pytest.fixture
def unix_abstract_sock():
"""Return an abstract UNIX socket address."""
if not IS_LINUX:
pytest.skip(
'{os} does not support an abstract '
'socket namespace'.format(os=SYS_PLATFORM),
)
return b''.join((
b'\x00cheroot-test-socket',
ntob(str(uuid.uuid4())),
)).decode()
def test_prepare_makes_server_ready():
"""Check that prepare() makes the server ready, and stop() clears it."""
httpserver = HTTPServer(
bind_addr=(ANY_INTERFACE_IPV4, EPHEMERAL_PORT),
gateway=Gateway,
)
assert not httpserver.ready
assert not httpserver.requests._threads
httpserver.prepare()
assert httpserver.ready
assert httpserver.requests._threads
for thr in httpserver.requests._threads:
assert thr.ready
httpserver.stop()
assert not httpserver.requests._threads
assert not httpserver.ready
def test_stop_interrupts_serve():
"""Check that stop() interrupts running of serve()."""
httpserver = HTTPServer(
bind_addr=(ANY_INTERFACE_IPV4, EPHEMERAL_PORT),
gateway=Gateway,
)
httpserver.prepare()
serve_thread = threading.Thread(target=httpserver.serve)
serve_thread.start()
serve_thread.join(0.5)
assert serve_thread.is_alive()
httpserver.stop()
serve_thread.join(0.5)
assert not serve_thread.is_alive()
@pytest.mark.parametrize(
'ip_addr',
(
ANY_INTERFACE_IPV4,
ANY_INTERFACE_IPV6,
),
)
def test_bind_addr_inet(http_server, ip_addr):
"""Check that bound IP address is stored in server."""
httpserver = http_server.send((ip_addr, EPHEMERAL_PORT))
assert httpserver.bind_addr[0] == ip_addr
assert httpserver.bind_addr[1] != EPHEMERAL_PORT
@unix_only_sock_test
def test_bind_addr_unix(http_server, unix_sock_file):
"""Check that bound UNIX socket address is stored in server."""
httpserver = http_server.send(unix_sock_file)
assert httpserver.bind_addr == unix_sock_file
@unix_only_sock_test
def test_bind_addr_unix_abstract(http_server, unix_abstract_sock):
"""Check that bound UNIX abstract sockaddr is stored in server."""
httpserver = http_server.send(unix_abstract_sock)
assert httpserver.bind_addr == unix_abstract_sock
PEERCRED_IDS_URI = '/peer_creds/ids'
PEERCRED_TEXTS_URI = '/peer_creds/texts'
class _TestGateway(Gateway):
def respond(self):
req = self.req
conn = req.conn
req_uri = bton(req.uri)
if req_uri == PEERCRED_IDS_URI:
peer_creds = conn.peer_pid, conn.peer_uid, conn.peer_gid
self.send_payload('|'.join(map(str, peer_creds)))
return
elif req_uri == PEERCRED_TEXTS_URI:
self.send_payload('!'.join((conn.peer_user, conn.peer_group)))
return
return super(_TestGateway, self).respond()
def send_payload(self, payload):
req = self.req
req.status = b'200 OK'
req.ensure_headers_sent()
req.write(ntob(payload))
@pytest.fixture
def peercreds_enabled_server_and_client(http_server, unix_sock_file):
"""Construct a test server with `peercreds_enabled`."""
httpserver = http_server.send(unix_sock_file)
httpserver.gateway = _TestGateway
httpserver.peercreds_enabled = True
return httpserver, get_server_client(httpserver)
@unix_only_sock_test
@non_macos_sock_test
def test_peercreds_unix_sock(peercreds_enabled_server_and_client):
"""Check that peercred lookup works when enabled."""
httpserver, testclient = peercreds_enabled_server_and_client
bind_addr = httpserver.bind_addr
if isinstance(bind_addr, six.binary_type):
bind_addr = bind_addr.decode()
unix_base_uri = 'http+unix://{}'.format(
bind_addr.replace('\0', '%00').replace('/', '%2F'),
)
expected_peercreds = os.getpid(), os.getuid(), os.getgid()
expected_peercreds = '|'.join(map(str, expected_peercreds))
with requests_unixsocket.monkeypatch():
peercreds_resp = requests.get(unix_base_uri + PEERCRED_IDS_URI)
peercreds_resp.raise_for_status()
assert peercreds_resp.text == expected_peercreds
peercreds_text_resp = requests.get(unix_base_uri + PEERCRED_TEXTS_URI)
assert peercreds_text_resp.status_code == 500
@pytest.mark.skipif(
not IS_UID_GID_RESOLVABLE,
reason='Modules `grp` and `pwd` are not available '
'under the current platform',
)
@unix_only_sock_test
@non_macos_sock_test
def test_peercreds_unix_sock_with_lookup(peercreds_enabled_server_and_client):
"""Check that peercred resolution works when enabled."""
httpserver, testclient = peercreds_enabled_server_and_client
httpserver.peercreds_resolve_enabled = True
bind_addr = httpserver.bind_addr
if isinstance(bind_addr, six.binary_type):
bind_addr = bind_addr.decode()
unix_base_uri = 'http+unix://{}'.format(
bind_addr.replace('\0', '%00').replace('/', '%2F'),
)
import grp
import pwd
expected_textcreds = (
pwd.getpwuid(os.getuid()).pw_name,
grp.getgrgid(os.getgid()).gr_name,
)
expected_textcreds = '!'.join(map(str, expected_textcreds))
with requests_unixsocket.monkeypatch():
peercreds_text_resp = requests.get(unix_base_uri + PEERCRED_TEXTS_URI)
peercreds_text_resp.raise_for_status()
assert peercreds_text_resp.text == expected_textcreds

View file

@ -0,0 +1,474 @@
"""Tests for TLS/SSL support."""
# -*- coding: utf-8 -*-
# vim: set fileencoding=utf-8 :
from __future__ import absolute_import, division, print_function
__metaclass__ = type
import functools
import os
import ssl
import sys
import threading
import time
import OpenSSL.SSL
import pytest
import requests
import six
import trustme
from .._compat import bton, ntob, ntou
from .._compat import IS_ABOVE_OPENSSL10, IS_PYPY
from .._compat import IS_LINUX, IS_MACOS, IS_WINDOWS
from ..server import Gateway, HTTPServer, get_ssl_adapter_class
from ..testing import (
ANY_INTERFACE_IPV4,
ANY_INTERFACE_IPV6,
EPHEMERAL_PORT,
# get_server_client,
_get_conn_data,
_probe_ipv6_sock,
)
IS_GITHUB_ACTIONS_WORKFLOW = bool(os.getenv('GITHUB_WORKFLOW'))
IS_LIBRESSL_BACKEND = ssl.OPENSSL_VERSION.startswith('LibreSSL')
IS_PYOPENSSL_SSL_VERSION_1_0 = (
OpenSSL.SSL.SSLeay_version(OpenSSL.SSL.SSLEAY_VERSION).
startswith(b'OpenSSL 1.0.')
)
PY27 = sys.version_info[:2] == (2, 7)
PY34 = sys.version_info[:2] == (3, 4)
_stdlib_to_openssl_verify = {
ssl.CERT_NONE: OpenSSL.SSL.VERIFY_NONE,
ssl.CERT_OPTIONAL: OpenSSL.SSL.VERIFY_PEER,
ssl.CERT_REQUIRED:
OpenSSL.SSL.VERIFY_PEER + OpenSSL.SSL.VERIFY_FAIL_IF_NO_PEER_CERT,
}
fails_under_py3 = pytest.mark.xfail(
not six.PY2,
reason='Fails under Python 3+',
)
fails_under_py3_in_pypy = pytest.mark.xfail(
not six.PY2 and IS_PYPY,
reason='Fails under PyPy3',
)
missing_ipv6 = pytest.mark.skipif(
not _probe_ipv6_sock('::1'),
reason=''
'IPv6 is disabled '
'(for example, under Travis CI '
'which runs under GCE supporting only IPv4)',
)
class HelloWorldGateway(Gateway):
"""Gateway responding with Hello World to root URI."""
def respond(self):
"""Respond with dummy content via HTTP."""
req = self.req
req_uri = bton(req.uri)
if req_uri == '/':
req.status = b'200 OK'
req.ensure_headers_sent()
req.write(b'Hello world!')
return
return super(HelloWorldGateway, self).respond()
def make_tls_http_server(bind_addr, ssl_adapter, request):
"""Create and start an HTTP server bound to bind_addr."""
httpserver = HTTPServer(
bind_addr=bind_addr,
gateway=HelloWorldGateway,
)
# httpserver.gateway = HelloWorldGateway
httpserver.ssl_adapter = ssl_adapter
threading.Thread(target=httpserver.safe_start).start()
while not httpserver.ready:
time.sleep(0.1)
request.addfinalizer(httpserver.stop)
return httpserver
@pytest.fixture
def tls_http_server(request):
"""Provision a server creator as a fixture."""
return functools.partial(make_tls_http_server, request=request)
@pytest.fixture
def ca():
"""Provide a certificate authority via fixture."""
return trustme.CA()
@pytest.fixture
def tls_ca_certificate_pem_path(ca):
"""Provide a certificate authority certificate file via fixture."""
with ca.cert_pem.tempfile() as ca_cert_pem:
yield ca_cert_pem
@pytest.fixture
def tls_certificate(ca):
"""Provide a leaf certificate via fixture."""
interface, host, port = _get_conn_data(ANY_INTERFACE_IPV4)
return ca.issue_server_cert(ntou(interface), )
@pytest.fixture
def tls_certificate_chain_pem_path(tls_certificate):
"""Provide a certificate chain PEM file path via fixture."""
with tls_certificate.private_key_and_cert_chain_pem.tempfile() as cert_pem:
yield cert_pem
@pytest.fixture
def tls_certificate_private_key_pem_path(tls_certificate):
"""Provide a certificate private key PEM file path via fixture."""
with tls_certificate.private_key_pem.tempfile() as cert_key_pem:
yield cert_key_pem
@pytest.mark.parametrize(
'adapter_type',
(
'builtin',
'pyopenssl',
),
)
def test_ssl_adapters(
tls_http_server, adapter_type,
tls_certificate,
tls_certificate_chain_pem_path,
tls_certificate_private_key_pem_path,
tls_ca_certificate_pem_path,
):
"""Test ability to connect to server via HTTPS using adapters."""
interface, _host, port = _get_conn_data(ANY_INTERFACE_IPV4)
tls_adapter_cls = get_ssl_adapter_class(name=adapter_type)
tls_adapter = tls_adapter_cls(
tls_certificate_chain_pem_path, tls_certificate_private_key_pem_path,
)
if adapter_type == 'pyopenssl':
tls_adapter.context = tls_adapter.get_context()
tls_certificate.configure_cert(tls_adapter.context)
tlshttpserver = tls_http_server((interface, port), tls_adapter)
# testclient = get_server_client(tlshttpserver)
# testclient.get('/')
interface, _host, port = _get_conn_data(
tlshttpserver.bind_addr,
)
resp = requests.get(
'https://' + interface + ':' + str(port) + '/',
verify=tls_ca_certificate_pem_path,
)
assert resp.status_code == 200
assert resp.text == 'Hello world!'
@pytest.mark.parametrize(
'adapter_type',
(
'builtin',
'pyopenssl',
),
)
@pytest.mark.parametrize(
'is_trusted_cert,tls_client_identity',
(
(True, 'localhost'), (True, '127.0.0.1'),
(True, '*.localhost'), (True, 'not_localhost'),
(False, 'localhost'),
),
)
@pytest.mark.parametrize(
'tls_verify_mode',
(
ssl.CERT_NONE, # server shouldn't validate client cert
ssl.CERT_OPTIONAL, # same as CERT_REQUIRED in client mode, don't use
ssl.CERT_REQUIRED, # server should validate if client cert CA is OK
),
)
def test_tls_client_auth(
# FIXME: remove twisted logic, separate tests
mocker,
tls_http_server, adapter_type,
ca,
tls_certificate,
tls_certificate_chain_pem_path,
tls_certificate_private_key_pem_path,
tls_ca_certificate_pem_path,
is_trusted_cert, tls_client_identity,
tls_verify_mode,
):
"""Verify that client TLS certificate auth works correctly."""
test_cert_rejection = (
tls_verify_mode != ssl.CERT_NONE
and not is_trusted_cert
)
interface, _host, port = _get_conn_data(ANY_INTERFACE_IPV4)
client_cert_root_ca = ca if is_trusted_cert else trustme.CA()
with mocker.mock_module.patch(
'idna.core.ulabel',
return_value=ntob(tls_client_identity),
):
client_cert = client_cert_root_ca.issue_server_cert(
# FIXME: change to issue_cert once new trustme is out
ntou(tls_client_identity),
)
del client_cert_root_ca
with client_cert.private_key_and_cert_chain_pem.tempfile() as cl_pem:
tls_adapter_cls = get_ssl_adapter_class(name=adapter_type)
tls_adapter = tls_adapter_cls(
tls_certificate_chain_pem_path,
tls_certificate_private_key_pem_path,
)
if adapter_type == 'pyopenssl':
tls_adapter.context = tls_adapter.get_context()
tls_adapter.context.set_verify(
_stdlib_to_openssl_verify[tls_verify_mode],
lambda conn, cert, errno, depth, preverify_ok: preverify_ok,
)
else:
tls_adapter.context.verify_mode = tls_verify_mode
ca.configure_trust(tls_adapter.context)
tls_certificate.configure_cert(tls_adapter.context)
tlshttpserver = tls_http_server((interface, port), tls_adapter)
interface, _host, port = _get_conn_data(tlshttpserver.bind_addr)
make_https_request = functools.partial(
requests.get,
'https://' + interface + ':' + str(port) + '/',
# Server TLS certificate verification:
verify=tls_ca_certificate_pem_path,
# Client TLS certificate verification:
cert=cl_pem,
)
if not test_cert_rejection:
resp = make_https_request()
is_req_successful = resp.status_code == 200
if (
not is_req_successful
and IS_PYOPENSSL_SSL_VERSION_1_0
and adapter_type == 'builtin'
and tls_verify_mode == ssl.CERT_REQUIRED
and tls_client_identity == 'localhost'
and is_trusted_cert
) or PY34:
pytest.xfail(
'OpenSSL 1.0 has problems with verifying client certs',
)
assert is_req_successful
assert resp.text == 'Hello world!'
return
# xfail some flaky tests
# https://github.com/cherrypy/cheroot/issues/237
issue_237 = (
IS_MACOS
and adapter_type == 'builtin'
and tls_verify_mode != ssl.CERT_NONE
)
if issue_237:
pytest.xfail('Test sometimes fails')
expected_ssl_errors = (
requests.exceptions.SSLError,
OpenSSL.SSL.Error,
) if PY34 else (
requests.exceptions.SSLError,
)
if IS_WINDOWS or IS_GITHUB_ACTIONS_WORKFLOW:
expected_ssl_errors += requests.exceptions.ConnectionError,
with pytest.raises(expected_ssl_errors) as ssl_err:
make_https_request()
if PY34 and isinstance(ssl_err, OpenSSL.SSL.Error):
pytest.xfail(
'OpenSSL behaves wierdly under Python 3.4 '
'because of an outdated urllib3',
)
try:
err_text = ssl_err.value.args[0].reason.args[0].args[0]
except AttributeError:
if PY34:
pytest.xfail('OpenSSL behaves wierdly under Python 3.4')
elif not six.PY2 and IS_WINDOWS:
err_text = str(ssl_err.value)
else:
raise
expected_substrings = (
'sslv3 alert bad certificate' if IS_LIBRESSL_BACKEND
else 'tlsv1 alert unknown ca',
)
if not six.PY2:
if IS_MACOS and IS_PYPY and adapter_type == 'pyopenssl':
expected_substrings = ('tlsv1 alert unknown ca', )
if (
IS_WINDOWS
and tls_verify_mode in (
ssl.CERT_REQUIRED,
ssl.CERT_OPTIONAL,
)
and not is_trusted_cert
and tls_client_identity == 'localhost'
):
expected_substrings += (
'bad handshake: '
"SysCallError(10054, 'WSAECONNRESET')",
"('Connection aborted.', "
'OSError("(10054, \'WSAECONNRESET\')"))',
)
assert any(e in err_text for e in expected_substrings)
@pytest.mark.parametrize(
'ip_addr',
(
ANY_INTERFACE_IPV4,
ANY_INTERFACE_IPV6,
),
)
def test_https_over_http_error(http_server, ip_addr):
"""Ensure that connecting over HTTPS to HTTP port is handled."""
httpserver = http_server.send((ip_addr, EPHEMERAL_PORT))
interface, _host, port = _get_conn_data(httpserver.bind_addr)
with pytest.raises(ssl.SSLError) as ssl_err:
six.moves.http_client.HTTPSConnection(
'{interface}:{port}'.format(
interface=interface,
port=port,
),
).request('GET', '/')
expected_substring = (
'wrong version number' if IS_ABOVE_OPENSSL10
else 'unknown protocol'
)
assert expected_substring in ssl_err.value.args[-1]
@pytest.mark.parametrize(
'adapter_type',
(
'builtin',
'pyopenssl',
),
)
@pytest.mark.parametrize(
'ip_addr',
(
ANY_INTERFACE_IPV4,
pytest.param(ANY_INTERFACE_IPV6, marks=missing_ipv6),
),
)
def test_http_over_https_error(
tls_http_server, adapter_type,
ca, ip_addr,
tls_certificate,
tls_certificate_chain_pem_path,
tls_certificate_private_key_pem_path,
):
"""Ensure that connecting over HTTP to HTTPS port is handled."""
# disable some flaky tests
# https://github.com/cherrypy/cheroot/issues/225
issue_225 = (
IS_MACOS
and adapter_type == 'builtin'
)
if issue_225:
pytest.xfail('Test fails in Travis-CI')
tls_adapter_cls = get_ssl_adapter_class(name=adapter_type)
tls_adapter = tls_adapter_cls(
tls_certificate_chain_pem_path, tls_certificate_private_key_pem_path,
)
if adapter_type == 'pyopenssl':
tls_adapter.context = tls_adapter.get_context()
tls_certificate.configure_cert(tls_adapter.context)
interface, _host, port = _get_conn_data(ip_addr)
tlshttpserver = tls_http_server((interface, port), tls_adapter)
interface, host, port = _get_conn_data(
tlshttpserver.bind_addr,
)
fqdn = interface
if ip_addr is ANY_INTERFACE_IPV6:
fqdn = '[{}]'.format(fqdn)
expect_fallback_response_over_plain_http = (
(adapter_type == 'pyopenssl'
and (IS_ABOVE_OPENSSL10 or not six.PY2))
or PY27
)
if expect_fallback_response_over_plain_http:
resp = requests.get(
'http://' + fqdn + ':' + str(port) + '/',
)
assert resp.status_code == 400
assert resp.text == (
'The client sent a plain HTTP request, '
'but this server only speaks HTTPS on this port.'
)
return
with pytest.raises(requests.exceptions.ConnectionError) as ssl_err:
requests.get( # FIXME: make stdlib ssl behave like PyOpenSSL
'http://' + fqdn + ':' + str(port) + '/',
)
if IS_LINUX:
expected_error_code, expected_error_text = (
104, 'Connection reset by peer',
)
if IS_MACOS:
expected_error_code, expected_error_text = (
54, 'Connection reset by peer',
)
if IS_WINDOWS:
expected_error_code, expected_error_text = (
10054,
'An existing connection was forcibly closed by the remote host',
)
underlying_error = ssl_err.value.args[0].args[-1]
err_text = str(underlying_error)
assert underlying_error.errno == expected_error_code, (
'The underlying error is {!r}'.
format(underlying_error)
)
assert expected_error_text in err_text

605
lib/cheroot/test/webtest.py Normal file
View file

@ -0,0 +1,605 @@
"""Extensions to unittest for web frameworks.
Use the WebCase.getPage method to request a page from your HTTP server.
Framework Integration
=====================
If you have control over your server process, you can handle errors
in the server-side of the HTTP conversation a bit better. You must run
both the client (your WebCase tests) and the server in the same process
(but in separate threads, obviously).
When an error occurs in the framework, call server_error. It will print
the traceback to stdout, and keep any assertions you have from running
(the assumption is that, if the server errors, the page output will not
be of further significance to your tests).
"""
from __future__ import absolute_import, division, print_function
__metaclass__ = type
import pprint
import re
import socket
import sys
import time
import traceback
import os
import json
import unittest
import warnings
import functools
from six.moves import http_client, map, urllib_parse
import six
from more_itertools.more import always_iterable
import jaraco.functools
def interface(host):
"""Return an IP address for a client connection given the server host.
If the server is listening on '0.0.0.0' (INADDR_ANY)
or '::' (IN6ADDR_ANY), this will return the proper localhost.
"""
if host == '0.0.0.0':
# INADDR_ANY, which should respond on localhost.
return '127.0.0.1'
if host == '::':
# IN6ADDR_ANY, which should respond on localhost.
return '::1'
return host
try:
# Jython support
if sys.platform[:4] == 'java':
def getchar():
"""Get a key press."""
# Hopefully this is enough
return sys.stdin.read(1)
else:
# On Windows, msvcrt.getch reads a single char without output.
import msvcrt
def getchar():
"""Get a key press."""
return msvcrt.getch()
except ImportError:
# Unix getchr
import tty
import termios
def getchar():
"""Get a key press."""
fd = sys.stdin.fileno()
old_settings = termios.tcgetattr(fd)
try:
tty.setraw(sys.stdin.fileno())
ch = sys.stdin.read(1)
finally:
termios.tcsetattr(fd, termios.TCSADRAIN, old_settings)
return ch
# from jaraco.properties
class NonDataProperty:
"""Non-data property decorator."""
def __init__(self, fget):
"""Initialize a non-data property."""
assert fget is not None, 'fget cannot be none'
assert callable(fget), 'fget must be callable'
self.fget = fget
def __get__(self, obj, objtype=None):
"""Return a class property."""
if obj is None:
return self
return self.fget(obj)
class WebCase(unittest.TestCase):
"""Helper web test suite base."""
HOST = '127.0.0.1'
PORT = 8000
HTTP_CONN = http_client.HTTPConnection
PROTOCOL = 'HTTP/1.1'
scheme = 'http'
url = None
ssl_context = None
status = None
headers = None
body = None
encoding = 'utf-8'
time = None
@property
def _Conn(self):
"""Return HTTPConnection or HTTPSConnection based on self.scheme.
* from http.client.
"""
cls_name = '{scheme}Connection'.format(scheme=self.scheme.upper())
return getattr(http_client, cls_name)
def get_conn(self, auto_open=False):
"""Return a connection to our HTTP server."""
conn = self._Conn(self.interface(), self.PORT)
# Automatically re-connect?
conn.auto_open = auto_open
conn.connect()
return conn
def set_persistent(self, on=True, auto_open=False):
"""Make our HTTP_CONN persistent (or not).
If the 'on' argument is True (the default), then self.HTTP_CONN
will be set to an instance of HTTP(S)?Connection
to persist across requests.
As this class only allows for a single open connection, if
self already has an open connection, it will be closed.
"""
try:
self.HTTP_CONN.close()
except (TypeError, AttributeError):
pass
self.HTTP_CONN = (
self.get_conn(auto_open=auto_open)
if on
else self._Conn
)
@property
def persistent(self):
"""Presense of the persistent HTTP connection."""
return hasattr(self.HTTP_CONN, '__class__')
@persistent.setter
def persistent(self, on):
self.set_persistent(on)
def interface(self):
"""Return an IP address for a client connection.
If the server is listening on '0.0.0.0' (INADDR_ANY)
or '::' (IN6ADDR_ANY), this will return the proper localhost.
"""
return interface(self.HOST)
def getPage(
self, url, headers=None, method='GET', body=None,
protocol=None, raise_subcls=(),
):
"""Open the url with debugging support. Return status, headers, body.
url should be the identifier passed to the server, typically a
server-absolute path and query string (sent between method and
protocol), and should only be an absolute URI if proxy support is
enabled in the server.
If the application under test generates absolute URIs, be sure
to wrap them first with strip_netloc::
class MyAppWebCase(WebCase):
def getPage(url, *args, **kwargs):
super(MyAppWebCase, self).getPage(
cheroot.test.webtest.strip_netloc(url),
*args, **kwargs
)
`raise_subcls` is passed through to openURL.
"""
ServerError.on = False
if isinstance(url, six.text_type):
url = url.encode('utf-8')
if isinstance(body, six.text_type):
body = body.encode('utf-8')
# for compatibility, support raise_subcls is None
raise_subcls = raise_subcls or ()
self.url = url
self.time = None
start = time.time()
result = openURL(
url, headers, method, body, self.HOST, self.PORT,
self.HTTP_CONN, protocol or self.PROTOCOL,
raise_subcls=raise_subcls,
ssl_context=self.ssl_context,
)
self.time = time.time() - start
self.status, self.headers, self.body = result
# Build a list of request cookies from the previous response cookies.
self.cookies = [
('Cookie', v) for k, v in self.headers
if k.lower() == 'set-cookie'
]
if ServerError.on:
raise ServerError()
return result
@NonDataProperty
def interactive(self):
"""Determine whether tests are run in interactive mode.
Load interactivity setting from environment, where
the value can be numeric or a string like true or
False or 1 or 0.
"""
env_str = os.environ.get('WEBTEST_INTERACTIVE', 'True')
is_interactive = bool(json.loads(env_str.lower()))
if is_interactive:
warnings.warn(
'Interactive test failure interceptor support via '
'WEBTEST_INTERACTIVE environment variable is deprecated.',
DeprecationWarning,
)
return is_interactive
console_height = 30
def _handlewebError(self, msg):
print('')
print(' ERROR: %s' % msg)
if not self.interactive:
raise self.failureException(msg)
p = (
' Show: '
'[B]ody [H]eaders [S]tatus [U]RL; '
'[I]gnore, [R]aise, or sys.e[X]it >> '
)
sys.stdout.write(p)
sys.stdout.flush()
while True:
i = getchar().upper()
if not isinstance(i, type('')):
i = i.decode('ascii')
if i not in 'BHSUIRX':
continue
print(i.upper()) # Also prints new line
if i == 'B':
for x, line in enumerate(self.body.splitlines()):
if (x + 1) % self.console_height == 0:
# The \r and comma should make the next line overwrite
sys.stdout.write('<-- More -->\r')
m = getchar().lower()
# Erase our "More" prompt
sys.stdout.write(' \r')
if m == 'q':
break
print(line)
elif i == 'H':
pprint.pprint(self.headers)
elif i == 'S':
print(self.status)
elif i == 'U':
print(self.url)
elif i == 'I':
# return without raising the normal exception
return
elif i == 'R':
raise self.failureException(msg)
elif i == 'X':
sys.exit()
sys.stdout.write(p)
sys.stdout.flush()
@property
def status_code(self): # noqa: D401; irrelevant for properties
"""Integer HTTP status code."""
return int(self.status[:3])
def status_matches(self, expected):
"""Check whether actual status matches expected."""
actual = (
self.status_code
if isinstance(expected, int) else
self.status
)
return expected == actual
def assertStatus(self, status, msg=None):
"""Fail if self.status != status.
status may be integer code, exact string status, or
iterable of allowed possibilities.
"""
if any(map(self.status_matches, always_iterable(status))):
return
tmpl = 'Status {self.status} does not match {status}'
msg = msg or tmpl.format(**locals())
self._handlewebError(msg)
def assertHeader(self, key, value=None, msg=None):
"""Fail if (key, [value]) not in self.headers."""
lowkey = key.lower()
for k, v in self.headers:
if k.lower() == lowkey:
if value is None or str(value) == v:
return v
if msg is None:
if value is None:
msg = '%r not in headers' % key
else:
msg = '%r:%r not in headers' % (key, value)
self._handlewebError(msg)
def assertHeaderIn(self, key, values, msg=None):
"""Fail if header indicated by key doesn't have one of the values."""
lowkey = key.lower()
for k, v in self.headers:
if k.lower() == lowkey:
matches = [value for value in values if str(value) == v]
if matches:
return matches
if msg is None:
msg = '%(key)r not in %(values)r' % vars()
self._handlewebError(msg)
def assertHeaderItemValue(self, key, value, msg=None):
"""Fail if the header does not contain the specified value."""
actual_value = self.assertHeader(key, msg=msg)
header_values = map(str.strip, actual_value.split(','))
if value in header_values:
return value
if msg is None:
msg = '%r not in %r' % (value, header_values)
self._handlewebError(msg)
def assertNoHeader(self, key, msg=None):
"""Fail if key in self.headers."""
lowkey = key.lower()
matches = [k for k, v in self.headers if k.lower() == lowkey]
if matches:
if msg is None:
msg = '%r in headers' % key
self._handlewebError(msg)
def assertNoHeaderItemValue(self, key, value, msg=None):
"""Fail if the header contains the specified value."""
lowkey = key.lower()
hdrs = self.headers
matches = [k for k, v in hdrs if k.lower() == lowkey and v == value]
if matches:
if msg is None:
msg = '%r:%r in %r' % (key, value, hdrs)
self._handlewebError(msg)
def assertBody(self, value, msg=None):
"""Fail if value != self.body."""
if isinstance(value, six.text_type):
value = value.encode(self.encoding)
if value != self.body:
if msg is None:
msg = 'expected body:\n%r\n\nactual body:\n%r' % (
value, self.body,
)
self._handlewebError(msg)
def assertInBody(self, value, msg=None):
"""Fail if value not in self.body."""
if isinstance(value, six.text_type):
value = value.encode(self.encoding)
if value not in self.body:
if msg is None:
msg = '%r not in body: %s' % (value, self.body)
self._handlewebError(msg)
def assertNotInBody(self, value, msg=None):
"""Fail if value in self.body."""
if isinstance(value, six.text_type):
value = value.encode(self.encoding)
if value in self.body:
if msg is None:
msg = '%r found in body' % value
self._handlewebError(msg)
def assertMatchesBody(self, pattern, msg=None, flags=0):
"""Fail if value (a regex pattern) is not in self.body."""
if isinstance(pattern, six.text_type):
pattern = pattern.encode(self.encoding)
if re.search(pattern, self.body, flags) is None:
if msg is None:
msg = 'No match for %r in body' % pattern
self._handlewebError(msg)
methods_with_bodies = ('POST', 'PUT', 'PATCH')
def cleanHeaders(headers, method, body, host, port):
"""Return request headers, with required headers added (if missing)."""
if headers is None:
headers = []
# Add the required Host request header if not present.
# [This specifies the host:port of the server, not the client.]
found = False
for k, v in headers:
if k.lower() == 'host':
found = True
break
if not found:
if port == 80:
headers.append(('Host', host))
else:
headers.append(('Host', '%s:%s' % (host, port)))
if method in methods_with_bodies:
# Stick in default type and length headers if not present
found = False
for k, v in headers:
if k.lower() == 'content-type':
found = True
break
if not found:
headers.append(
('Content-Type', 'application/x-www-form-urlencoded'),
)
headers.append(('Content-Length', str(len(body or ''))))
return headers
def shb(response):
"""Return status, headers, body the way we like from a response."""
resp_status_line = '%s %s' % (response.status, response.reason)
if not six.PY2:
return resp_status_line, response.getheaders(), response.read()
h = []
key, value = None, None
for line in response.msg.headers:
if line:
if line[0] in ' \t':
value += line.strip()
else:
if key and value:
h.append((key, value))
key, value = line.split(':', 1)
key = key.strip()
value = value.strip()
if key and value:
h.append((key, value))
return resp_status_line, h, response.read()
# def openURL(*args, raise_subcls=(), **kwargs):
# py27 compatible signature:
def openURL(*args, **kwargs):
"""
Open a URL, retrying when it fails.
Specify `raise_subcls` (class or tuple of classes) to exclude
those socket.error subclasses from being suppressed and retried.
"""
raise_subcls = kwargs.pop('raise_subcls', ())
opener = functools.partial(_open_url_once, *args, **kwargs)
def on_exception():
type_, exc = sys.exc_info()[:2]
if isinstance(exc, raise_subcls):
raise
time.sleep(0.5)
# Try up to 10 times
return jaraco.functools.retry_call(
opener,
retries=9,
cleanup=on_exception,
trap=socket.error,
)
def _open_url_once(
url, headers=None, method='GET', body=None,
host='127.0.0.1', port=8000, http_conn=http_client.HTTPConnection,
protocol='HTTP/1.1', ssl_context=None,
):
"""Open the given HTTP resource and return status, headers, and body."""
headers = cleanHeaders(headers, method, body, host, port)
# Allow http_conn to be a class or an instance
if hasattr(http_conn, 'host'):
conn = http_conn
else:
kw = {}
if ssl_context:
kw['context'] = ssl_context
conn = http_conn(interface(host), port, **kw)
conn._http_vsn_str = protocol
conn._http_vsn = int(''.join([x for x in protocol if x.isdigit()]))
if not six.PY2 and isinstance(url, bytes):
url = url.decode()
conn.putrequest(
method.upper(), url, skip_host=True,
skip_accept_encoding=True,
)
for key, value in headers:
conn.putheader(key, value.encode('Latin-1'))
conn.endheaders()
if body is not None:
conn.send(body)
# Handle response
response = conn.getresponse()
s, h, b = shb(response)
if not hasattr(http_conn, 'host'):
# We made our own conn instance. Close it.
conn.close()
return s, h, b
def strip_netloc(url):
"""Return absolute-URI path from URL.
Strip the scheme and host from the URL, returning the
server-absolute portion.
Useful for wrapping an absolute-URI for which only the
path is expected (such as in calls to getPage).
>>> strip_netloc('https://google.com/foo/bar?bing#baz')
'/foo/bar?bing'
>>> strip_netloc('//google.com/foo/bar?bing#baz')
'/foo/bar?bing'
>>> strip_netloc('/foo/bar?bing#baz')
'/foo/bar?bing'
"""
parsed = urllib_parse.urlparse(url)
scheme, netloc, path, params, query, fragment = parsed
stripped = '', '', path, params, query, ''
return urllib_parse.urlunparse(stripped)
# Add any exceptions which your web framework handles
# normally (that you don't want server_error to trap).
ignored_exceptions = []
# You'll want set this to True when you can't guarantee
# that each response will immediately follow each request;
# for example, when handling requests via multiple threads.
ignore_all = False
class ServerError(Exception):
"""Exception for signalling server error."""
on = False
def server_error(exc=None):
"""Server debug hook.
Return True if exception handled, False if ignored.
You probably want to wrap this, so you can still handle an error using
your framework when it's ignored.
"""
if exc is None:
exc = sys.exc_info()
if ignore_all or exc[0] in ignored_exceptions:
return False
else:
ServerError.on = True
print('')
print(''.join(traceback.format_exception(*exc)))
return True