Update urllib3-1.26.7

This commit is contained in:
JonnyWong16 2021-10-14 21:00:02 -07:00
parent a3bfabb5f6
commit b6595232d2
No known key found for this signature in database
GPG key ID: B1F1F9807184697A
38 changed files with 4375 additions and 2823 deletions

View file

@ -1,54 +1,43 @@
""" """
urllib3 - Thread-safe connection pooling and re-using. Python HTTP library with thread-safe connection pooling, file post support, user friendly, and more
""" """
from __future__ import absolute_import from __future__ import absolute_import
import warnings
from .connectionpool import ( # Set default logging handler to avoid "No handler found" warnings.
HTTPConnectionPool, import logging
HTTPSConnectionPool, import warnings
connection_from_url from logging import NullHandler
)
from . import exceptions from . import exceptions
from ._version import __version__
from .connectionpool import HTTPConnectionPool, HTTPSConnectionPool, connection_from_url
from .filepost import encode_multipart_formdata from .filepost import encode_multipart_formdata
from .poolmanager import PoolManager, ProxyManager, proxy_from_url from .poolmanager import PoolManager, ProxyManager, proxy_from_url
from .response import HTTPResponse from .response import HTTPResponse
from .util.request import make_headers from .util.request import make_headers
from .util.url import get_host
from .util.timeout import Timeout
from .util.retry import Retry from .util.retry import Retry
from .util.timeout import Timeout
from .util.url import get_host
__author__ = "Andrey Petrov (andrey.petrov@shazow.net)"
# Set default logging handler to avoid "No handler found" warnings. __license__ = "MIT"
import logging __version__ = __version__
try: # Python 2.7+
from logging import NullHandler
except ImportError:
class NullHandler(logging.Handler):
def emit(self, record):
pass
__author__ = 'Andrey Petrov (andrey.petrov@shazow.net)'
__license__ = 'MIT'
__version__ = '1.22'
__all__ = ( __all__ = (
'HTTPConnectionPool', "HTTPConnectionPool",
'HTTPSConnectionPool', "HTTPSConnectionPool",
'PoolManager', "PoolManager",
'ProxyManager', "ProxyManager",
'HTTPResponse', "HTTPResponse",
'Retry', "Retry",
'Timeout', "Timeout",
'add_stderr_logger', "add_stderr_logger",
'connection_from_url', "connection_from_url",
'disable_warnings', "disable_warnings",
'encode_multipart_formdata', "encode_multipart_formdata",
'get_host', "get_host",
'make_headers', "make_headers",
'proxy_from_url', "proxy_from_url",
) )
logging.getLogger(__name__).addHandler(NullHandler()) logging.getLogger(__name__).addHandler(NullHandler())
@ -65,10 +54,10 @@ def add_stderr_logger(level=logging.DEBUG):
# even if urllib3 is vendored within another package. # even if urllib3 is vendored within another package.
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
handler = logging.StreamHandler() handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter('%(asctime)s %(levelname)s %(message)s')) handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(message)s"))
logger.addHandler(handler) logger.addHandler(handler)
logger.setLevel(level) logger.setLevel(level)
logger.debug('Added a stderr logging handler to logger: %s', __name__) logger.debug("Added a stderr logging handler to logger: %s", __name__)
return handler return handler
@ -80,18 +69,17 @@ del NullHandler
# shouldn't be: otherwise, it's very hard for users to use most Python # shouldn't be: otherwise, it's very hard for users to use most Python
# mechanisms to silence them. # mechanisms to silence them.
# SecurityWarning's always go off by default. # SecurityWarning's always go off by default.
warnings.simplefilter('always', exceptions.SecurityWarning, append=True) warnings.simplefilter("always", exceptions.SecurityWarning, append=True)
# SubjectAltNameWarning's should go off once per host # SubjectAltNameWarning's should go off once per host
warnings.simplefilter('default', exceptions.SubjectAltNameWarning, append=True) warnings.simplefilter("default", exceptions.SubjectAltNameWarning, append=True)
# InsecurePlatformWarning's don't vary between requests, so we keep it default. # InsecurePlatformWarning's don't vary between requests, so we keep it default.
warnings.simplefilter('default', exceptions.InsecurePlatformWarning, warnings.simplefilter("default", exceptions.InsecurePlatformWarning, append=True)
append=True)
# SNIMissingWarnings should go off only once. # SNIMissingWarnings should go off only once.
warnings.simplefilter('default', exceptions.SNIMissingWarning, append=True) warnings.simplefilter("default", exceptions.SNIMissingWarning, append=True)
def disable_warnings(category=exceptions.HTTPWarning): def disable_warnings(category=exceptions.HTTPWarning):
""" """
Helper for quickly disabling all urllib3 warnings. Helper for quickly disabling all urllib3 warnings.
""" """
warnings.simplefilter('ignore', category) warnings.simplefilter("ignore", category)

View file

@ -1,8 +1,13 @@
from __future__ import absolute_import from __future__ import absolute_import
try:
from collections.abc import Mapping, MutableMapping
except ImportError:
from collections import Mapping, MutableMapping from collections import Mapping, MutableMapping
try: try:
from threading import RLock from threading import RLock
except ImportError: # Platform-specific: No threads available except ImportError: # Platform-specific: No threads available
class RLock: class RLock:
def __enter__(self): def __enter__(self):
pass pass
@ -11,14 +16,13 @@ except ImportError: # Platform-specific: No threads available
pass pass
try: # Python 2.7+
from collections import OrderedDict from collections import OrderedDict
except ImportError:
from .packages.ordered_dict import OrderedDict
from .packages.six import iterkeys, itervalues, PY3
from .exceptions import InvalidHeader
from .packages import six
from .packages.six import iterkeys, itervalues
__all__ = ['RecentlyUsedContainer', 'HTTPHeaderDict'] __all__ = ["RecentlyUsedContainer", "HTTPHeaderDict"]
_Null = object() _Null = object()
@ -81,7 +85,9 @@ class RecentlyUsedContainer(MutableMapping):
return len(self._container) return len(self._container)
def __iter__(self): def __iter__(self):
raise NotImplementedError('Iteration over this class is unlikely to be threadsafe.') raise NotImplementedError(
"Iteration over this class is unlikely to be threadsafe."
)
def clear(self): def clear(self):
with self.lock: with self.lock:
@ -149,7 +155,7 @@ class HTTPHeaderDict(MutableMapping):
def __getitem__(self, key): def __getitem__(self, key):
val = self._container[key.lower()] val = self._container[key.lower()]
return ', '.join(val[1:]) return ", ".join(val[1:])
def __delitem__(self, key): def __delitem__(self, key):
del self._container[key.lower()] del self._container[key.lower()]
@ -158,17 +164,18 @@ class HTTPHeaderDict(MutableMapping):
return key.lower() in self._container return key.lower() in self._container
def __eq__(self, other): def __eq__(self, other):
if not isinstance(other, Mapping) and not hasattr(other, 'keys'): if not isinstance(other, Mapping) and not hasattr(other, "keys"):
return False return False
if not isinstance(other, type(self)): if not isinstance(other, type(self)):
other = type(self)(other) other = type(self)(other)
return (dict((k.lower(), v) for k, v in self.itermerged()) == return dict((k.lower(), v) for k, v in self.itermerged()) == dict(
dict((k.lower(), v) for k, v in other.itermerged())) (k.lower(), v) for k, v in other.itermerged()
)
def __ne__(self, other): def __ne__(self, other):
return not self.__eq__(other) return not self.__eq__(other)
if not PY3: # Python 2 if six.PY2: # Python 2
iterkeys = MutableMapping.iterkeys iterkeys = MutableMapping.iterkeys
itervalues = MutableMapping.itervalues itervalues = MutableMapping.itervalues
@ -183,9 +190,9 @@ class HTTPHeaderDict(MutableMapping):
yield vals[0] yield vals[0]
def pop(self, key, default=__marker): def pop(self, key, default=__marker):
'''D.pop(k[,d]) -> v, remove specified key and return the corresponding value. """D.pop(k[,d]) -> v, remove specified key and return the corresponding value.
If key is not found, d is returned if given, otherwise KeyError is raised. If key is not found, d is returned if given, otherwise KeyError is raised.
''' """
# Using the MutableMapping function directly fails due to the private marker. # Using the MutableMapping function directly fails due to the private marker.
# Using ordinary dict.pop would expose the internal structures. # Using ordinary dict.pop would expose the internal structures.
# So let's reinvent the wheel. # So let's reinvent the wheel.
@ -227,8 +234,10 @@ class HTTPHeaderDict(MutableMapping):
with self.add instead of self.__setitem__ with self.add instead of self.__setitem__
""" """
if len(args) > 1: if len(args) > 1:
raise TypeError("extend() takes at most 1 positional " raise TypeError(
"arguments ({0} given)".format(len(args))) "extend() takes at most 1 positional "
"arguments ({0} given)".format(len(args))
)
other = args[0] if len(args) >= 1 else () other = args[0] if len(args) >= 1 else ()
if isinstance(other, HTTPHeaderDict): if isinstance(other, HTTPHeaderDict):
@ -294,7 +303,7 @@ class HTTPHeaderDict(MutableMapping):
"""Iterate over all headers, merging duplicate ones together.""" """Iterate over all headers, merging duplicate ones together."""
for key in self: for key in self:
val = self._container[key.lower()] val = self._container[key.lower()]
yield val[0], ', '.join(val[1:]) yield val[0], ", ".join(val[1:])
def items(self): def items(self):
return list(self.iteritems()) return list(self.iteritems())
@ -305,15 +314,24 @@ class HTTPHeaderDict(MutableMapping):
# python2.7 does not expose a proper API for exporting multiheaders # python2.7 does not expose a proper API for exporting multiheaders
# efficiently. This function re-reads raw lines from the message # efficiently. This function re-reads raw lines from the message
# object and extracts the multiheaders properly. # object and extracts the multiheaders properly.
obs_fold_continued_leaders = (" ", "\t")
headers = [] headers = []
for line in message.headers: for line in message.headers:
if line.startswith((' ', '\t')): if line.startswith(obs_fold_continued_leaders):
if not headers:
# We received a header line that starts with OWS as described
# in RFC-7230 S3.2.4. This indicates a multiline header, but
# there exists no previous header to which we can attach it.
raise InvalidHeader(
"Header continuation with no previous header: %s" % line
)
else:
key, value = headers[-1] key, value = headers[-1]
headers[-1] = (key, value + '\r\n' + line.rstrip()) headers[-1] = (key, value + " " + line.strip())
continue continue
key, value = line.split(':', 1) key, value = line.split(":", 1)
headers.append((key, value.strip())) headers.append((key, value.strip()))
return cls(headers) return cls(headers)

2
lib/urllib3/_version.py Normal file
View file

@ -0,0 +1,2 @@
# This file is protected via CODEOWNERS
__version__ = "1.26.7"

View file

@ -1,17 +1,22 @@
from __future__ import absolute_import from __future__ import absolute_import
import datetime import datetime
import logging import logging
import os import os
import sys import re
import socket import socket
from socket import error as SocketError, timeout as SocketTimeout
import warnings import warnings
from socket import error as SocketError
from socket import timeout as SocketTimeout
from .packages import six from .packages import six
from .packages.six.moves.http_client import HTTPConnection as _HTTPConnection from .packages.six.moves.http_client import HTTPConnection as _HTTPConnection
from .packages.six.moves.http_client import HTTPException # noqa: F401 from .packages.six.moves.http_client import HTTPException # noqa: F401
from .util.proxy import create_proxy_ssl_context
try: # Compiled with SSL? try: # Compiled with SSL?
import ssl import ssl
BaseSSLError = ssl.SSLError BaseSSLError = ssl.SSLError
except (ImportError, AttributeError): # Platform-specific: No SSL. except (ImportError, AttributeError): # Platform-specific: No SSL.
ssl = None ssl = None
@ -20,56 +25,57 @@ except (ImportError, AttributeError): # Platform-specific: No SSL.
pass pass
try: # Python 3: try:
# Not a no-op, we're adding this to the namespace so it can be imported. # Python 3: not a no-op, we're adding this to the namespace so it can be imported.
ConnectionError = ConnectionError ConnectionError = ConnectionError
except NameError: # Python 2: except NameError:
# Python 2
class ConnectionError(Exception): class ConnectionError(Exception):
pass pass
try: # Python 3:
# Not a no-op, we're adding this to the namespace so it can be imported.
BrokenPipeError = BrokenPipeError
except NameError: # Python 2:
class BrokenPipeError(Exception):
pass
from ._collections import HTTPHeaderDict # noqa (historical, removed in v2)
from ._version import __version__
from .exceptions import ( from .exceptions import (
NewConnectionError,
ConnectTimeoutError, ConnectTimeoutError,
NewConnectionError,
SubjectAltNameWarning, SubjectAltNameWarning,
SystemTimeWarning, SystemTimeWarning,
) )
from .packages.ssl_match_hostname import match_hostname, CertificateError from .packages.ssl_match_hostname import CertificateError, match_hostname
from .util import SKIP_HEADER, SKIPPABLE_HEADERS, connection
from .util.ssl_ import ( from .util.ssl_ import (
resolve_cert_reqs,
resolve_ssl_version,
assert_fingerprint, assert_fingerprint,
create_urllib3_context, create_urllib3_context,
ssl_wrap_socket is_ipaddress,
resolve_cert_reqs,
resolve_ssl_version,
ssl_wrap_socket,
) )
from .util import connection
from ._collections import HTTPHeaderDict
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
port_by_scheme = { port_by_scheme = {"http": 80, "https": 443}
'http': 80,
'https': 443,
}
# When updating RECENT_DATE, move it to # When it comes time to update this value as a part of regular maintenance
# within two years of the current date, and no # (ie test_recent_date is failing) update it to ~6 months before the current date.
# earlier than 6 months ago. RECENT_DATE = datetime.date(2020, 7, 1)
RECENT_DATE = datetime.date(2016, 1, 1)
_CONTAINS_CONTROL_CHAR_RE = re.compile(r"[^-!#$%&'*+.^_`|~0-9a-zA-Z]")
class DummyConnection(object):
"""Used to detect a failed ConnectionCls import."""
pass
class HTTPConnection(_HTTPConnection, object): class HTTPConnection(_HTTPConnection, object):
""" """
Based on httplib.HTTPConnection but provides an extra constructor Based on :class:`http.client.HTTPConnection` but provides an extra constructor
backwards-compatibility layer between older and newer Pythons. backwards-compatibility layer between older and newer Pythons.
Additional keyword parameters are used to configure attributes of the connection. Additional keyword parameters are used to configure attributes of the connection.
@ -77,15 +83,14 @@ class HTTPConnection(_HTTPConnection, object):
- ``strict``: See the documentation on :class:`urllib3.connectionpool.HTTPConnectionPool` - ``strict``: See the documentation on :class:`urllib3.connectionpool.HTTPConnectionPool`
- ``source_address``: Set the source address for the current connection. - ``source_address``: Set the source address for the current connection.
.. note:: This is ignored for Python 2.6. It is only applied for 2.7 and 3.x
- ``socket_options``: Set specific options on the underlying socket. If not specified, then - ``socket_options``: Set specific options on the underlying socket. If not specified, then
defaults are loaded from ``HTTPConnection.default_socket_options`` which includes disabling defaults are loaded from ``HTTPConnection.default_socket_options`` which includes disabling
Nagle's algorithm (sets TCP_NODELAY to 1) unless the connection is behind a proxy. Nagle's algorithm (sets TCP_NODELAY to 1) unless the connection is behind a proxy.
For example, if you wish to enable TCP Keep Alive in addition to the defaults, For example, if you wish to enable TCP Keep Alive in addition to the defaults,
you might pass:: you might pass:
.. code-block:: python
HTTPConnection.default_socket_options + [ HTTPConnection.default_socket_options + [
(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1), (socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1),
@ -94,7 +99,7 @@ class HTTPConnection(_HTTPConnection, object):
Or you may want to disable the defaults by passing an empty list (e.g., ``[]``). Or you may want to disable the defaults by passing an empty list (e.g., ``[]``).
""" """
default_port = port_by_scheme['http'] default_port = port_by_scheme["http"]
#: Disable Nagle's algorithm by default. #: Disable Nagle's algorithm by default.
#: ``[(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)]`` #: ``[(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)]``
@ -103,27 +108,56 @@ class HTTPConnection(_HTTPConnection, object):
#: Whether this connection verifies the host's certificate. #: Whether this connection verifies the host's certificate.
is_verified = False is_verified = False
#: Whether this proxy connection (if used) verifies the proxy host's
#: certificate.
proxy_is_verified = None
def __init__(self, *args, **kw): def __init__(self, *args, **kw):
if six.PY3: # Python 3 if not six.PY2:
kw.pop('strict', None) kw.pop("strict", None)
# Pre-set source_address in case we have an older Python like 2.6. # Pre-set source_address.
self.source_address = kw.get('source_address') self.source_address = kw.get("source_address")
if sys.version_info < (2, 7): # Python 2.6
# _HTTPConnection on Python 2.6 will balk at this keyword arg, but
# not newer versions. We can still use it when creating a
# connection though, so we pop it *after* we have saved it as
# self.source_address.
kw.pop('source_address', None)
#: The socket options provided by the user. If no options are #: The socket options provided by the user. If no options are
#: provided, we use the default options. #: provided, we use the default options.
self.socket_options = kw.pop('socket_options', self.default_socket_options) self.socket_options = kw.pop("socket_options", self.default_socket_options)
# Proxy options provided by the user.
self.proxy = kw.pop("proxy", None)
self.proxy_config = kw.pop("proxy_config", None)
# Superclass also sets self.source_address in Python 2.7+.
_HTTPConnection.__init__(self, *args, **kw) _HTTPConnection.__init__(self, *args, **kw)
@property
def host(self):
"""
Getter method to remove any trailing dots that indicate the hostname is an FQDN.
In general, SSL certificates don't include the trailing dot indicating a
fully-qualified domain name, and thus, they don't validate properly when
checked against a domain name that includes the dot. In addition, some
servers may not expect to receive the trailing dot when provided.
However, the hostname with trailing dot is critical to DNS resolution; doing a
lookup with the trailing dot will properly only resolve the appropriate FQDN,
whereas a lookup without a trailing dot will search the system's search domain
list. Thus, it's important to keep the original host around for use only in
those cases where it's appropriate (i.e., when doing DNS lookup to establish the
actual TCP connection across which we're going to send HTTP requests).
"""
return self._dns_host.rstrip(".")
@host.setter
def host(self, value):
"""
Setter for the `host` property.
We assume that only urllib3 uses the _dns_host attribute; httplib itself
only uses `host`, and it seems reasonable that other libraries follow suit.
"""
self._dns_host = value
def _new_conn(self): def _new_conn(self):
"""Establish a socket connection and set nodelay settings on it. """Establish a socket connection and set nodelay settings on it.
@ -131,32 +165,37 @@ class HTTPConnection(_HTTPConnection, object):
""" """
extra_kw = {} extra_kw = {}
if self.source_address: if self.source_address:
extra_kw['source_address'] = self.source_address extra_kw["source_address"] = self.source_address
if self.socket_options: if self.socket_options:
extra_kw['socket_options'] = self.socket_options extra_kw["socket_options"] = self.socket_options
try: try:
conn = connection.create_connection( conn = connection.create_connection(
(self.host, self.port), self.timeout, **extra_kw) (self._dns_host, self.port), self.timeout, **extra_kw
)
except SocketTimeout as e: except SocketTimeout:
raise ConnectTimeoutError( raise ConnectTimeoutError(
self, "Connection to %s timed out. (connect timeout=%s)" % self,
(self.host, self.timeout)) "Connection to %s timed out. (connect timeout=%s)"
% (self.host, self.timeout),
)
except SocketError as e: except SocketError as e:
raise NewConnectionError( raise NewConnectionError(
self, "Failed to establish a new connection: %s" % e) self, "Failed to establish a new connection: %s" % e
)
return conn return conn
def _is_using_tunnel(self):
# Google App Engine's httplib does not define _tunnel_host
return getattr(self, "_tunnel_host", None)
def _prepare_conn(self, conn): def _prepare_conn(self, conn):
self.sock = conn self.sock = conn
# the _tunnel_host attribute was added in python 2.6.3 (via if self._is_using_tunnel():
# http://hg.python.org/cpython/rev/0f57b30a152f) so pythons 2.6(0-2) do
# not have them.
if getattr(self, '_tunnel_host', None):
# TODO: Fix tunnel so it doesn't depend on self.sock state. # TODO: Fix tunnel so it doesn't depend on self.sock state.
self._tunnel() self._tunnel()
# Mark this connection as not reusable # Mark this connection as not reusable
@ -166,129 +205,167 @@ class HTTPConnection(_HTTPConnection, object):
conn = self._new_conn() conn = self._new_conn()
self._prepare_conn(conn) self._prepare_conn(conn)
def putrequest(self, method, url, *args, **kwargs):
""" """
# Empty docstring because the indentation of CPython's implementation
# is broken but we don't want this method in our documentation.
match = _CONTAINS_CONTROL_CHAR_RE.search(method)
if match:
raise ValueError(
"Method cannot contain non-token characters %r (found at least %r)"
% (method, match.group())
)
return _HTTPConnection.putrequest(self, method, url, *args, **kwargs)
def putheader(self, header, *values):
""" """
if not any(isinstance(v, str) and v == SKIP_HEADER for v in values):
_HTTPConnection.putheader(self, header, *values)
elif six.ensure_str(header.lower()) not in SKIPPABLE_HEADERS:
raise ValueError(
"urllib3.util.SKIP_HEADER only supports '%s'"
% ("', '".join(map(str.title, sorted(SKIPPABLE_HEADERS))),)
)
def request(self, method, url, body=None, headers=None):
if headers is None:
headers = {}
else:
# Avoid modifying the headers passed into .request()
headers = headers.copy()
if "user-agent" not in (six.ensure_str(k.lower()) for k in headers):
headers["User-Agent"] = _get_default_user_agent()
super(HTTPConnection, self).request(method, url, body=body, headers=headers)
def request_chunked(self, method, url, body=None, headers=None): def request_chunked(self, method, url, body=None, headers=None):
""" """
Alternative to the common request method, which sends the Alternative to the common request method, which sends the
body with chunked encoding and not as one block body with chunked encoding and not as one block
""" """
headers = HTTPHeaderDict(headers if headers is not None else {}) headers = headers or {}
skip_accept_encoding = 'accept-encoding' in headers header_keys = set([six.ensure_str(k.lower()) for k in headers])
skip_host = 'host' in headers skip_accept_encoding = "accept-encoding" in header_keys
skip_host = "host" in header_keys
self.putrequest( self.putrequest(
method, method, url, skip_accept_encoding=skip_accept_encoding, skip_host=skip_host
url,
skip_accept_encoding=skip_accept_encoding,
skip_host=skip_host
) )
if "user-agent" not in header_keys:
self.putheader("User-Agent", _get_default_user_agent())
for header, value in headers.items(): for header, value in headers.items():
self.putheader(header, value) self.putheader(header, value)
if 'transfer-encoding' not in headers: if "transfer-encoding" not in header_keys:
self.putheader('Transfer-Encoding', 'chunked') self.putheader("Transfer-Encoding", "chunked")
self.endheaders() self.endheaders()
if body is not None: if body is not None:
stringish_types = six.string_types + (six.binary_type,) stringish_types = six.string_types + (bytes,)
if isinstance(body, stringish_types): if isinstance(body, stringish_types):
body = (body,) body = (body,)
for chunk in body: for chunk in body:
if not chunk: if not chunk:
continue continue
if not isinstance(chunk, six.binary_type): if not isinstance(chunk, bytes):
chunk = chunk.encode('utf8') chunk = chunk.encode("utf8")
len_str = hex(len(chunk))[2:] len_str = hex(len(chunk))[2:]
self.send(len_str.encode('utf-8')) to_send = bytearray(len_str.encode())
self.send(b'\r\n') to_send += b"\r\n"
self.send(chunk) to_send += chunk
self.send(b'\r\n') to_send += b"\r\n"
self.send(to_send)
# After the if clause, to always have a closed body # After the if clause, to always have a closed body
self.send(b'0\r\n\r\n') self.send(b"0\r\n\r\n")
class HTTPSConnection(HTTPConnection): class HTTPSConnection(HTTPConnection):
default_port = port_by_scheme['https']
ssl_version = None
def __init__(self, host, port=None, key_file=None, cert_file=None,
strict=None, timeout=socket._GLOBAL_DEFAULT_TIMEOUT,
ssl_context=None, **kw):
HTTPConnection.__init__(self, host, port, strict=strict,
timeout=timeout, **kw)
self.key_file = key_file
self.cert_file = cert_file
self.ssl_context = ssl_context
# Required property for Google AppEngine 1.9.0 which otherwise causes
# HTTPS requests to go out as HTTP. (See Issue #356)
self._protocol = 'https'
def connect(self):
conn = self._new_conn()
self._prepare_conn(conn)
if self.ssl_context is None:
self.ssl_context = create_urllib3_context(
ssl_version=resolve_ssl_version(None),
cert_reqs=resolve_cert_reqs(None),
)
self.sock = ssl_wrap_socket(
sock=conn,
keyfile=self.key_file,
certfile=self.cert_file,
ssl_context=self.ssl_context,
)
class VerifiedHTTPSConnection(HTTPSConnection):
""" """
Based on httplib.HTTPSConnection but wraps the socket with Many of the parameters to this constructor are passed to the underlying SSL
SSL certification. socket by means of :py:func:`urllib3.util.ssl_wrap_socket`.
""" """
default_port = port_by_scheme["https"]
cert_reqs = None cert_reqs = None
ca_certs = None ca_certs = None
ca_cert_dir = None ca_cert_dir = None
ca_cert_data = None
ssl_version = None ssl_version = None
assert_fingerprint = None assert_fingerprint = None
tls_in_tls_required = False
def set_cert(self, key_file=None, cert_file=None, def __init__(
cert_reqs=None, ca_certs=None, self,
assert_hostname=None, assert_fingerprint=None, host,
ca_cert_dir=None): port=None,
key_file=None,
cert_file=None,
key_password=None,
strict=None,
timeout=socket._GLOBAL_DEFAULT_TIMEOUT,
ssl_context=None,
server_hostname=None,
**kw
):
HTTPConnection.__init__(self, host, port, strict=strict, timeout=timeout, **kw)
self.key_file = key_file
self.cert_file = cert_file
self.key_password = key_password
self.ssl_context = ssl_context
self.server_hostname = server_hostname
# Required property for Google AppEngine 1.9.0 which otherwise causes
# HTTPS requests to go out as HTTP. (See Issue #356)
self._protocol = "https"
def set_cert(
self,
key_file=None,
cert_file=None,
cert_reqs=None,
key_password=None,
ca_certs=None,
assert_hostname=None,
assert_fingerprint=None,
ca_cert_dir=None,
ca_cert_data=None,
):
""" """
This method should only be called once, before the connection is used. This method should only be called once, before the connection is used.
""" """
# If cert_reqs is not provided, we can try to guess. If the user gave # If cert_reqs is not provided we'll assume CERT_REQUIRED unless we also
# us a cert database, we assume they want to use it: otherwise, if # have an SSLContext object in which case we'll use its verify_mode.
# they gave us an SSL Context object we should use whatever is set for
# it.
if cert_reqs is None: if cert_reqs is None:
if ca_certs or ca_cert_dir: if self.ssl_context is not None:
cert_reqs = 'CERT_REQUIRED'
elif self.ssl_context is not None:
cert_reqs = self.ssl_context.verify_mode cert_reqs = self.ssl_context.verify_mode
else:
cert_reqs = resolve_cert_reqs(None)
self.key_file = key_file self.key_file = key_file
self.cert_file = cert_file self.cert_file = cert_file
self.cert_reqs = cert_reqs self.cert_reqs = cert_reqs
self.key_password = key_password
self.assert_hostname = assert_hostname self.assert_hostname = assert_hostname
self.assert_fingerprint = assert_fingerprint self.assert_fingerprint = assert_fingerprint
self.ca_certs = ca_certs and os.path.expanduser(ca_certs) self.ca_certs = ca_certs and os.path.expanduser(ca_certs)
self.ca_cert_dir = ca_cert_dir and os.path.expanduser(ca_cert_dir) self.ca_cert_dir = ca_cert_dir and os.path.expanduser(ca_cert_dir)
self.ca_cert_data = ca_cert_data
def connect(self): def connect(self):
# Add certificate verification # Add certificate verification
conn = self._new_conn() conn = self._new_conn()
hostname = self.host hostname = self.host
if getattr(self, '_tunnel_host', None): tls_in_tls = False
# _tunnel_host was added in Python 2.6.3
# (See: http://hg.python.org/cpython/rev/0f57b30a152f) if self._is_using_tunnel():
if self.tls_in_tls_required:
conn = self._connect_tls_proxy(hostname, conn)
tls_in_tls = True
self.sock = conn self.sock = conn
# Calls self._set_hostport(), so self.host is # Calls self._set_hostport(), so self.host is
# self._tunnel_host below. # self._tunnel_host below.
self._tunnel() self._tunnel()
@ -298,17 +375,25 @@ class VerifiedHTTPSConnection(HTTPSConnection):
# Override the host with the one we're requesting data from. # Override the host with the one we're requesting data from.
hostname = self._tunnel_host hostname = self._tunnel_host
server_hostname = hostname
if self.server_hostname is not None:
server_hostname = self.server_hostname
is_time_off = datetime.date.today() < RECENT_DATE is_time_off = datetime.date.today() < RECENT_DATE
if is_time_off: if is_time_off:
warnings.warn(( warnings.warn(
'System time is way off (before {0}). This will probably ' (
'lead to SSL verification errors').format(RECENT_DATE), "System time is way off (before {0}). This will probably "
SystemTimeWarning "lead to SSL verification errors"
).format(RECENT_DATE),
SystemTimeWarning,
) )
# Wrap socket using verification with the root certs in # Wrap socket using verification with the root certs in
# trusted_root_certs # trusted_root_certs
default_ssl_context = False
if self.ssl_context is None: if self.ssl_context is None:
default_ssl_context = True
self.ssl_context = create_urllib3_context( self.ssl_context = create_urllib3_context(
ssl_version=resolve_ssl_version(self.ssl_version), ssl_version=resolve_ssl_version(self.ssl_version),
cert_reqs=resolve_cert_reqs(self.cert_reqs), cert_reqs=resolve_cert_reqs(self.cert_reqs),
@ -316,48 +401,150 @@ class VerifiedHTTPSConnection(HTTPSConnection):
context = self.ssl_context context = self.ssl_context
context.verify_mode = resolve_cert_reqs(self.cert_reqs) context.verify_mode = resolve_cert_reqs(self.cert_reqs)
# Try to load OS default certs if none are given.
# Works well on Windows (requires Python3.4+)
if (
not self.ca_certs
and not self.ca_cert_dir
and not self.ca_cert_data
and default_ssl_context
and hasattr(context, "load_default_certs")
):
context.load_default_certs()
self.sock = ssl_wrap_socket( self.sock = ssl_wrap_socket(
sock=conn, sock=conn,
keyfile=self.key_file, keyfile=self.key_file,
certfile=self.cert_file, certfile=self.cert_file,
key_password=self.key_password,
ca_certs=self.ca_certs, ca_certs=self.ca_certs,
ca_cert_dir=self.ca_cert_dir, ca_cert_dir=self.ca_cert_dir,
server_hostname=hostname, ca_cert_data=self.ca_cert_data,
ssl_context=context) server_hostname=server_hostname,
ssl_context=context,
tls_in_tls=tls_in_tls,
)
# If we're using all defaults and the connection
# is TLSv1 or TLSv1.1 we throw a DeprecationWarning
# for the host.
if (
default_ssl_context
and self.ssl_version is None
and hasattr(self.sock, "version")
and self.sock.version() in {"TLSv1", "TLSv1.1"}
):
warnings.warn(
"Negotiating TLSv1/TLSv1.1 by default is deprecated "
"and will be disabled in urllib3 v2.0.0. Connecting to "
"'%s' with '%s' can be enabled by explicitly opting-in "
"with 'ssl_version'" % (self.host, self.sock.version()),
DeprecationWarning,
)
if self.assert_fingerprint: if self.assert_fingerprint:
assert_fingerprint(self.sock.getpeercert(binary_form=True), assert_fingerprint(
self.assert_fingerprint) self.sock.getpeercert(binary_form=True), self.assert_fingerprint
elif context.verify_mode != ssl.CERT_NONE \ )
and not getattr(context, 'check_hostname', False) \ elif (
and self.assert_hostname is not False: context.verify_mode != ssl.CERT_NONE
and not getattr(context, "check_hostname", False)
and self.assert_hostname is not False
):
# While urllib3 attempts to always turn off hostname matching from # While urllib3 attempts to always turn off hostname matching from
# the TLS library, this cannot always be done. So we check whether # the TLS library, this cannot always be done. So we check whether
# the TLS Library still thinks it's matching hostnames. # the TLS Library still thinks it's matching hostnames.
cert = self.sock.getpeercert() cert = self.sock.getpeercert()
if not cert.get('subjectAltName', ()): if not cert.get("subjectAltName", ()):
warnings.warn(( warnings.warn(
'Certificate for {0} has no `subjectAltName`, falling back to check for a ' (
'`commonName` for now. This feature is being removed by major browsers and ' "Certificate for {0} has no `subjectAltName`, falling back to check for a "
'deprecated by RFC 2818. (See https://github.com/shazow/urllib3/issues/497 ' "`commonName` for now. This feature is being removed by major browsers and "
'for details.)'.format(hostname)), "deprecated by RFC 2818. (See https://github.com/urllib3/urllib3/issues/497 "
SubjectAltNameWarning "for details.)".format(hostname)
),
SubjectAltNameWarning,
) )
_match_hostname(cert, self.assert_hostname or hostname) _match_hostname(cert, self.assert_hostname or server_hostname)
self.is_verified = ( self.is_verified = (
context.verify_mode == ssl.CERT_REQUIRED or context.verify_mode == ssl.CERT_REQUIRED
self.assert_fingerprint is not None or self.assert_fingerprint is not None
) )
def _connect_tls_proxy(self, hostname, conn):
"""
Establish a TLS connection to the proxy using the provided SSL context.
"""
proxy_config = self.proxy_config
ssl_context = proxy_config.ssl_context
if ssl_context:
# If the user provided a proxy context, we assume CA and client
# certificates have already been set
return ssl_wrap_socket(
sock=conn,
server_hostname=hostname,
ssl_context=ssl_context,
)
ssl_context = create_proxy_ssl_context(
self.ssl_version,
self.cert_reqs,
self.ca_certs,
self.ca_cert_dir,
self.ca_cert_data,
)
# If no cert was provided, use only the default options for server
# certificate validation
socket = ssl_wrap_socket(
sock=conn,
ca_certs=self.ca_certs,
ca_cert_dir=self.ca_cert_dir,
ca_cert_data=self.ca_cert_data,
server_hostname=hostname,
ssl_context=ssl_context,
)
if ssl_context.verify_mode != ssl.CERT_NONE and not getattr(
ssl_context, "check_hostname", False
):
# While urllib3 attempts to always turn off hostname matching from
# the TLS library, this cannot always be done. So we check whether
# the TLS Library still thinks it's matching hostnames.
cert = socket.getpeercert()
if not cert.get("subjectAltName", ()):
warnings.warn(
(
"Certificate for {0} has no `subjectAltName`, falling back to check for a "
"`commonName` for now. This feature is being removed by major browsers and "
"deprecated by RFC 2818. (See https://github.com/urllib3/urllib3/issues/497 "
"for details.)".format(hostname)
),
SubjectAltNameWarning,
)
_match_hostname(cert, hostname)
self.proxy_is_verified = ssl_context.verify_mode == ssl.CERT_REQUIRED
return socket
def _match_hostname(cert, asserted_hostname): def _match_hostname(cert, asserted_hostname):
# Our upstream implementation of ssl.match_hostname()
# only applies this normalization to IP addresses so it doesn't
# match DNS SANs so we do the same thing!
stripped_hostname = asserted_hostname.strip("u[]")
if is_ipaddress(stripped_hostname):
asserted_hostname = stripped_hostname
try: try:
match_hostname(cert, asserted_hostname) match_hostname(cert, asserted_hostname)
except CertificateError as e: except CertificateError as e:
log.error( log.warning(
'Certificate did not match expected hostname: %s. ' "Certificate did not match expected hostname: %s. Certificate: %s",
'Certificate: %s', asserted_hostname, cert asserted_hostname,
cert,
) )
# Add cert to exception and reraise so client code can inspect # Add cert to exception and reraise so client code can inspect
# the cert when catching the exception, if they want to # the cert when catching the exception, if they want to
@ -365,9 +552,18 @@ def _match_hostname(cert, asserted_hostname):
raise raise
if ssl: def _get_default_user_agent():
# Make a copy for testing. return "python-urllib3/%s" % __version__
UnverifiedHTTPSConnection = HTTPSConnection
HTTPSConnection = VerifiedHTTPSConnection
else: class DummyConnection(object):
HTTPSConnection = DummyConnection """Used to detect a failed ConnectionCls import."""
pass
if not ssl:
HTTPSConnection = DummyConnection # noqa: F811
VerifiedHTTPSConnection = HTTPSConnection

View file

@ -1,51 +1,53 @@
from __future__ import absolute_import from __future__ import absolute_import
import errno import errno
import logging import logging
import socket
import sys import sys
import warnings import warnings
from socket import error as SocketError
from socket import timeout as SocketTimeout
from socket import error as SocketError, timeout as SocketTimeout from .connection import (
import socket BaseSSLError,
BrokenPipeError,
DummyConnection,
HTTPConnection,
HTTPException,
HTTPSConnection,
VerifiedHTTPSConnection,
port_by_scheme,
)
from .exceptions import ( from .exceptions import (
ClosedPoolError, ClosedPoolError,
ProtocolError,
EmptyPoolError, EmptyPoolError,
HeaderParsingError, HeaderParsingError,
HostChangedError, HostChangedError,
InsecureRequestWarning,
LocationValueError, LocationValueError,
MaxRetryError, MaxRetryError,
NewConnectionError,
ProtocolError,
ProxyError, ProxyError,
ReadTimeoutError, ReadTimeoutError,
SSLError, SSLError,
TimeoutError, TimeoutError,
InsecureRequestWarning,
NewConnectionError,
) )
from .packages.ssl_match_hostname import CertificateError
from .packages import six from .packages import six
from .packages.six.moves import queue from .packages.six.moves import queue
from .connection import ( from .packages.ssl_match_hostname import CertificateError
port_by_scheme,
DummyConnection,
HTTPConnection, HTTPSConnection, VerifiedHTTPSConnection,
HTTPException, BaseSSLError,
)
from .request import RequestMethods from .request import RequestMethods
from .response import HTTPResponse from .response import HTTPResponse
from .util.connection import is_connection_dropped from .util.connection import is_connection_dropped
from .util.proxy import connection_requires_http_tunnel
from .util.queue import LifoQueue
from .util.request import set_file_position from .util.request import set_file_position
from .util.response import assert_header_parsing from .util.response import assert_header_parsing
from .util.retry import Retry from .util.retry import Retry
from .util.timeout import Timeout from .util.timeout import Timeout
from .util.url import get_host, Url from .util.url import Url, _encode_target
from .util.url import _normalize_host as normalize_host
from .util.url import get_host, parse_url
if six.PY2:
# Queue is imported for side effects on MS Windows
import Queue as _unused_module_Queue # noqa: F401
xrange = six.moves.xrange xrange = six.moves.xrange
@ -59,22 +61,26 @@ class ConnectionPool(object):
""" """
Base class for all connection pools, such as Base class for all connection pools, such as
:class:`.HTTPConnectionPool` and :class:`.HTTPSConnectionPool`. :class:`.HTTPConnectionPool` and :class:`.HTTPSConnectionPool`.
.. note::
ConnectionPool.urlopen() does not normalize or percent-encode target URIs
which is useful if your target server doesn't support percent-encoded
target URIs.
""" """
scheme = None scheme = None
QueueCls = queue.LifoQueue QueueCls = LifoQueue
def __init__(self, host, port=None): def __init__(self, host, port=None):
if not host: if not host:
raise LocationValueError("No host specified.") raise LocationValueError("No host specified.")
self.host = _ipv6_host(host).lower() self.host = _normalize_host(host, scheme=self.scheme)
self._proxy_host = host.lower() self._proxy_host = host.lower()
self.port = port self.port = port
def __str__(self): def __str__(self):
return '%s(host=%r, port=%r)' % (type(self).__name__, return "%s(host=%r, port=%r)" % (type(self).__name__, self.host, self.port)
self.host, self.port)
def __enter__(self): def __enter__(self):
return self return self
@ -92,7 +98,7 @@ class ConnectionPool(object):
# This is taken from http://hg.python.org/cpython/file/7aaba721ebc0/Lib/socket.py#l252 # This is taken from http://hg.python.org/cpython/file/7aaba721ebc0/Lib/socket.py#l252
_blocking_errnos = set([errno.EAGAIN, errno.EWOULDBLOCK]) _blocking_errnos = {errno.EAGAIN, errno.EWOULDBLOCK}
class HTTPConnectionPool(ConnectionPool, RequestMethods): class HTTPConnectionPool(ConnectionPool, RequestMethods):
@ -101,16 +107,16 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
:param host: :param host:
Host used for this HTTP Connection (e.g. "localhost"), passed into Host used for this HTTP Connection (e.g. "localhost"), passed into
:class:`httplib.HTTPConnection`. :class:`http.client.HTTPConnection`.
:param port: :param port:
Port used for this HTTP Connection (None is equivalent to 80), passed Port used for this HTTP Connection (None is equivalent to 80), passed
into :class:`httplib.HTTPConnection`. into :class:`http.client.HTTPConnection`.
:param strict: :param strict:
Causes BadStatusLine to be raised if the status line can't be parsed Causes BadStatusLine to be raised if the status line can't be parsed
as a valid HTTP/1.0 or 1.1 status line, passed into as a valid HTTP/1.0 or 1.1 status line, passed into
:class:`httplib.HTTPConnection`. :class:`http.client.HTTPConnection`.
.. note:: .. note::
Only works in Python 2. This parameter is ignored in Python 3. Only works in Python 2. This parameter is ignored in Python 3.
@ -144,26 +150,36 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
:param _proxy: :param _proxy:
Parsed proxy URL, should not be used directly, instead, see Parsed proxy URL, should not be used directly, instead, see
:class:`urllib3.connectionpool.ProxyManager`" :class:`urllib3.ProxyManager`
:param _proxy_headers: :param _proxy_headers:
A dictionary with proxy headers, should not be used directly, A dictionary with proxy headers, should not be used directly,
instead, see :class:`urllib3.connectionpool.ProxyManager`" instead, see :class:`urllib3.ProxyManager`
:param \\**conn_kw: :param \\**conn_kw:
Additional parameters are used to create fresh :class:`urllib3.connection.HTTPConnection`, Additional parameters are used to create fresh :class:`urllib3.connection.HTTPConnection`,
:class:`urllib3.connection.HTTPSConnection` instances. :class:`urllib3.connection.HTTPSConnection` instances.
""" """
scheme = 'http' scheme = "http"
ConnectionCls = HTTPConnection ConnectionCls = HTTPConnection
ResponseCls = HTTPResponse ResponseCls = HTTPResponse
def __init__(self, host, port=None, strict=False, def __init__(
timeout=Timeout.DEFAULT_TIMEOUT, maxsize=1, block=False, self,
headers=None, retries=None, host,
_proxy=None, _proxy_headers=None, port=None,
**conn_kw): strict=False,
timeout=Timeout.DEFAULT_TIMEOUT,
maxsize=1,
block=False,
headers=None,
retries=None,
_proxy=None,
_proxy_headers=None,
_proxy_config=None,
**conn_kw
):
ConnectionPool.__init__(self, host, port) ConnectionPool.__init__(self, host, port)
RequestMethods.__init__(self, headers) RequestMethods.__init__(self, headers)
@ -183,6 +199,7 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
self.proxy = _proxy self.proxy = _proxy
self.proxy_headers = _proxy_headers or {} self.proxy_headers = _proxy_headers or {}
self.proxy_config = _proxy_config
# Fill the queue up so that doing get() on it will block properly # Fill the queue up so that doing get() on it will block properly
for _ in xrange(maxsize): for _ in xrange(maxsize):
@ -197,19 +214,30 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
# Enable Nagle's algorithm for proxies, to avoid packet fragmentation. # Enable Nagle's algorithm for proxies, to avoid packet fragmentation.
# We cannot know if the user has added default socket options, so we cannot replace the # We cannot know if the user has added default socket options, so we cannot replace the
# list. # list.
self.conn_kw.setdefault('socket_options', []) self.conn_kw.setdefault("socket_options", [])
self.conn_kw["proxy"] = self.proxy
self.conn_kw["proxy_config"] = self.proxy_config
def _new_conn(self): def _new_conn(self):
""" """
Return a fresh :class:`HTTPConnection`. Return a fresh :class:`HTTPConnection`.
""" """
self.num_connections += 1 self.num_connections += 1
log.debug("Starting new HTTP connection (%d): %s", log.debug(
self.num_connections, self.host) "Starting new HTTP connection (%d): %s:%s",
self.num_connections,
self.host,
self.port or "80",
)
conn = self.ConnectionCls(host=self.host, port=self.port, conn = self.ConnectionCls(
host=self.host,
port=self.port,
timeout=self.timeout.connect_timeout, timeout=self.timeout.connect_timeout,
strict=self.strict, **self.conn_kw) strict=self.strict,
**self.conn_kw
)
return conn return conn
def _get_conn(self, timeout=None): def _get_conn(self, timeout=None):
@ -233,18 +261,19 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
except queue.Empty: except queue.Empty:
if self.block: if self.block:
raise EmptyPoolError(self, raise EmptyPoolError(
"Pool reached maximum size and no more " self,
"connections are allowed.") "Pool reached maximum size and no more connections are allowed.",
)
pass # Oh well, we'll create a new connection then pass # Oh well, we'll create a new connection then
# If this is a persistent connection, check if it got disconnected # If this is a persistent connection, check if it got disconnected
if conn and is_connection_dropped(conn): if conn and is_connection_dropped(conn):
log.debug("Resetting dropped connection: %s", self.host) log.debug("Resetting dropped connection: %s", self.host)
conn.close() conn.close()
if getattr(conn, 'auto_open', 1) == 0: if getattr(conn, "auto_open", 1) == 0:
# This is a proxied connection that has been mutated by # This is a proxied connection that has been mutated by
# httplib._tunnel() and cannot be reused (since it would # http.client._tunnel() and cannot be reused (since it would
# attempt to bypass the proxy) # attempt to bypass the proxy)
conn = None conn = None
@ -272,9 +301,7 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
pass pass
except queue.Full: except queue.Full:
# This should never happen if self.block == True # This should never happen if self.block == True
log.warning( log.warning("Connection pool is full, discarding connection: %s", self.host)
"Connection pool is full, discarding connection: %s",
self.host)
# Connection never got put back into the pool, close it. # Connection never got put back into the pool, close it.
if conn: if conn:
@ -306,21 +333,30 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
"""Is the error actually a timeout? Will raise a ReadTimeout or pass""" """Is the error actually a timeout? Will raise a ReadTimeout or pass"""
if isinstance(err, SocketTimeout): if isinstance(err, SocketTimeout):
raise ReadTimeoutError(self, url, "Read timed out. (read timeout=%s)" % timeout_value) raise ReadTimeoutError(
self, url, "Read timed out. (read timeout=%s)" % timeout_value
)
# See the above comment about EAGAIN in Python 3. In Python 2 we have # See the above comment about EAGAIN in Python 3. In Python 2 we have
# to specifically catch it and throw the timeout error # to specifically catch it and throw the timeout error
if hasattr(err, 'errno') and err.errno in _blocking_errnos: if hasattr(err, "errno") and err.errno in _blocking_errnos:
raise ReadTimeoutError(self, url, "Read timed out. (read timeout=%s)" % timeout_value) raise ReadTimeoutError(
self, url, "Read timed out. (read timeout=%s)" % timeout_value
)
# Catch possible read timeouts thrown as SSL errors. If not the # Catch possible read timeouts thrown as SSL errors. If not the
# case, rethrow the original. We need to do this because of: # case, rethrow the original. We need to do this because of:
# http://bugs.python.org/issue10272 # http://bugs.python.org/issue10272
if 'timed out' in str(err) or 'did not complete (read)' in str(err): # Python 2.6 if "timed out" in str(err) or "did not complete (read)" in str(
raise ReadTimeoutError(self, url, "Read timed out. (read timeout=%s)" % timeout_value) err
): # Python < 2.7.4
raise ReadTimeoutError(
self, url, "Read timed out. (read timeout=%s)" % timeout_value
)
def _make_request(self, conn, method, url, timeout=_Default, chunked=False, def _make_request(
**httplib_request_kw): self, conn, method, url, timeout=_Default, chunked=False, **httplib_request_kw
):
""" """
Perform a request on a given urllib connection object taken from our Perform a request on a given urllib connection object taken from our
pool. pool.
@ -349,18 +385,36 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
self._raise_timeout(err=e, url=url, timeout_value=conn.timeout) self._raise_timeout(err=e, url=url, timeout_value=conn.timeout)
raise raise
# conn.request() calls httplib.*.request, not the method in # conn.request() calls http.client.*.request, not the method in
# urllib3.request. It also calls makefile (recv) on the socket. # urllib3.request. It also calls makefile (recv) on the socket.
try:
if chunked: if chunked:
conn.request_chunked(method, url, **httplib_request_kw) conn.request_chunked(method, url, **httplib_request_kw)
else: else:
conn.request(method, url, **httplib_request_kw) conn.request(method, url, **httplib_request_kw)
# We are swallowing BrokenPipeError (errno.EPIPE) since the server is
# legitimately able to close the connection after sending a valid response.
# With this behaviour, the received response is still readable.
except BrokenPipeError:
# Python 3
pass
except IOError as e:
# Python 2 and macOS/Linux
# EPIPE and ESHUTDOWN are BrokenPipeError on Python 2, and EPROTOTYPE is needed on macOS
# https://erickt.github.io/blog/2014/11/19/adventures-in-debugging-a-potential-osx-kernel-bug/
if e.errno not in {
errno.EPIPE,
errno.ESHUTDOWN,
errno.EPROTOTYPE,
}:
raise
# Reset the timeout for the recv() on the socket # Reset the timeout for the recv() on the socket
read_timeout = timeout_obj.read_timeout read_timeout = timeout_obj.read_timeout
# App Engine doesn't have a sock attr # App Engine doesn't have a sock attr
if getattr(conn, 'sock', None): if getattr(conn, "sock", None):
# In Python 3 socket.py will catch EAGAIN and return None when you # In Python 3 socket.py will catch EAGAIN and return None when you
# try and read into the file pointer created by http.client, which # try and read into the file pointer created by http.client, which
# instead raises a BadStatusLine exception. Instead of catching # instead raises a BadStatusLine exception. Instead of catching
@ -368,7 +422,8 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
# timeouts, check for a zero timeout before making the request. # timeouts, check for a zero timeout before making the request.
if read_timeout == 0: if read_timeout == 0:
raise ReadTimeoutError( raise ReadTimeoutError(
self, url, "Read timed out. (read timeout=%s)" % read_timeout) self, url, "Read timed out. (read timeout=%s)" % read_timeout
)
if read_timeout is Timeout.DEFAULT_TIMEOUT: if read_timeout is Timeout.DEFAULT_TIMEOUT:
conn.sock.settimeout(socket.getdefaulttimeout()) conn.sock.settimeout(socket.getdefaulttimeout())
else: # None or a value else: # None or a value
@ -376,31 +431,45 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
# Receive the response from the server # Receive the response from the server
try: try:
try: # Python 2.7, use buffering of HTTP responses try:
# Python 2.7, use buffering of HTTP responses
httplib_response = conn.getresponse(buffering=True) httplib_response = conn.getresponse(buffering=True)
except TypeError: # Python 2.6 and older, Python 3 except TypeError:
# Python 3
try: try:
httplib_response = conn.getresponse() httplib_response = conn.getresponse()
except Exception as e: except BaseException as e:
# Remove the TypeError from the exception chain in Python 3; # Remove the TypeError from the exception chain in
# otherwise it looks like a programming error was the cause. # Python 3 (including for exceptions like SystemExit).
# Otherwise it looks like a bug in the code.
six.raise_from(e, None) six.raise_from(e, None)
except (SocketTimeout, BaseSSLError, SocketError) as e: except (SocketTimeout, BaseSSLError, SocketError) as e:
self._raise_timeout(err=e, url=url, timeout_value=read_timeout) self._raise_timeout(err=e, url=url, timeout_value=read_timeout)
raise raise
# AppEngine doesn't have a version attr. # AppEngine doesn't have a version attr.
http_version = getattr(conn, '_http_vsn_str', 'HTTP/?') http_version = getattr(conn, "_http_vsn_str", "HTTP/?")
log.debug("%s://%s:%s \"%s %s %s\" %s %s", self.scheme, self.host, self.port, log.debug(
method, url, http_version, httplib_response.status, '%s://%s:%s "%s %s %s" %s %s',
httplib_response.length) self.scheme,
self.host,
self.port,
method,
url,
http_version,
httplib_response.status,
httplib_response.length,
)
try: try:
assert_header_parsing(httplib_response.msg) assert_header_parsing(httplib_response.msg)
except (HeaderParsingError, TypeError) as hpe: # Platform-specific: Python 3 except (HeaderParsingError, TypeError) as hpe: # Platform-specific: Python 3
log.warning( log.warning(
'Failed to parse headers (url=%s): %s', "Failed to parse headers (url=%s): %s",
self._absolute_url(url), hpe, exc_info=True) self._absolute_url(url),
hpe,
exc_info=True,
)
return httplib_response return httplib_response
@ -411,6 +480,8 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
""" """
Close all pooled connections and disable the pool. Close all pooled connections and disable the pool.
""" """
if self.pool is None:
return
# Disable access to the pool # Disable access to the pool
old_pool, self.pool = self.pool, None old_pool, self.pool = self.pool, None
@ -428,13 +499,13 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
Check if the given ``url`` is a member of the same host as this Check if the given ``url`` is a member of the same host as this
connection pool. connection pool.
""" """
if url.startswith('/'): if url.startswith("/"):
return True return True
# TODO: Add optional support for socket.gethostbyname checking. # TODO: Add optional support for socket.gethostbyname checking.
scheme, host, port = get_host(url) scheme, host, port = get_host(url)
if host is not None:
host = _ipv6_host(host).lower() host = _normalize_host(host, scheme=scheme)
# Use explicit default port for comparison when none is given # Use explicit default port for comparison when none is given
if self.port and not port: if self.port and not port:
@ -444,10 +515,22 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
return (scheme, host, port) == (self.scheme, self.host, self.port) return (scheme, host, port) == (self.scheme, self.host, self.port)
def urlopen(self, method, url, body=None, headers=None, retries=None, def urlopen(
redirect=True, assert_same_host=True, timeout=_Default, self,
pool_timeout=None, release_conn=None, chunked=False, method,
body_pos=None, **response_kw): url,
body=None,
headers=None,
retries=None,
redirect=True,
assert_same_host=True,
timeout=_Default,
pool_timeout=None,
release_conn=None,
chunked=False,
body_pos=None,
**response_kw
):
""" """
Get a connection from the pool and perform an HTTP request. This is the Get a connection from the pool and perform an HTTP request. This is the
lowest level call for making a request, so you'll need to specify all lowest level call for making a request, so you'll need to specify all
@ -468,10 +551,12 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
:param method: :param method:
HTTP request method (such as GET, POST, PUT, etc.) HTTP request method (such as GET, POST, PUT, etc.)
:param url:
The URL to perform the request on.
:param body: :param body:
Data to send in the request body (useful for creating Data to send in the request body, either :class:`str`, :class:`bytes`,
POST requests, see HTTPConnectionPool.post_url for an iterable of :class:`str`/:class:`bytes`, or a file-like object.
more convenience).
:param headers: :param headers:
Dictionary of custom headers to send, such as User-Agent, Dictionary of custom headers to send, such as User-Agent,
@ -501,7 +586,7 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
:param assert_same_host: :param assert_same_host:
If ``True``, will make sure that the host of the pool requests is If ``True``, will make sure that the host of the pool requests is
consistent else will raise HostChangedError. When False, you can consistent else will raise HostChangedError. When ``False``, you can
use the pool on an HTTP proxy and request foreign hosts. use the pool on an HTTP proxy and request foreign hosts.
:param timeout: :param timeout:
@ -538,6 +623,10 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
Additional parameters are passed to Additional parameters are passed to
:meth:`urllib3.response.HTTPResponse.from_httplib` :meth:`urllib3.response.HTTPResponse.from_httplib`
""" """
parsed_url = parse_url(url)
destination_scheme = parsed_url.scheme
if headers is None: if headers is None:
headers = self.headers headers = self.headers
@ -545,12 +634,18 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
retries = Retry.from_int(retries, redirect=redirect, default=self.retries) retries = Retry.from_int(retries, redirect=redirect, default=self.retries)
if release_conn is None: if release_conn is None:
release_conn = response_kw.get('preload_content', True) release_conn = response_kw.get("preload_content", True)
# Check host # Check host
if assert_same_host and not self.is_same_host(url): if assert_same_host and not self.is_same_host(url):
raise HostChangedError(self, url, retries) raise HostChangedError(self, url, retries)
# Ensure that the URL we're connecting to is properly encoded
if url.startswith("/"):
url = six.ensure_str(_encode_target(url))
else:
url = six.ensure_str(parsed_url.url)
conn = None conn = None
# Track whether `conn` needs to be released before # Track whether `conn` needs to be released before
@ -561,13 +656,17 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
# #
# See issue #651 [1] for details. # See issue #651 [1] for details.
# #
# [1] <https://github.com/shazow/urllib3/issues/651> # [1] <https://github.com/urllib3/urllib3/issues/651>
release_this_conn = release_conn release_this_conn = release_conn
# Merge the proxy headers. Only do this in HTTP. We have to copy the http_tunnel_required = connection_requires_http_tunnel(
# headers dict so we can safely change it without those changes being self.proxy, self.proxy_config, destination_scheme
# reflected in anyone else's copy. )
if self.scheme == 'http':
# Merge the proxy headers. Only done when not using HTTP CONNECT. We
# have to copy the headers dict so we can safely change it without those
# changes being reflected in anyone else's copy.
if not http_tunnel_required:
headers = headers.copy() headers = headers.copy()
headers.update(self.proxy_headers) headers.update(self.proxy_headers)
@ -590,15 +689,22 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
conn.timeout = timeout_obj.connect_timeout conn.timeout = timeout_obj.connect_timeout
is_new_proxy_conn = self.proxy is not None and not getattr(conn, 'sock', None) is_new_proxy_conn = self.proxy is not None and not getattr(
if is_new_proxy_conn: conn, "sock", None
)
if is_new_proxy_conn and http_tunnel_required:
self._prepare_proxy(conn) self._prepare_proxy(conn)
# Make the request on the httplib connection object. # Make the request on the httplib connection object.
httplib_response = self._make_request(conn, method, url, httplib_response = self._make_request(
conn,
method,
url,
timeout=timeout_obj, timeout=timeout_obj,
body=body, headers=headers, body=body,
chunked=chunked) headers=headers,
chunked=chunked,
)
# If we're going to release the connection in ``finally:``, then # If we're going to release the connection in ``finally:``, then
# the response doesn't need to know about the connection. Otherwise # the response doesn't need to know about the connection. Otherwise
@ -607,36 +713,48 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
response_conn = conn if not release_conn else None response_conn = conn if not release_conn else None
# Pass method to Response for length checking # Pass method to Response for length checking
response_kw['request_method'] = method response_kw["request_method"] = method
# Import httplib's response into our own wrapper object # Import httplib's response into our own wrapper object
response = self.ResponseCls.from_httplib(httplib_response, response = self.ResponseCls.from_httplib(
httplib_response,
pool=self, pool=self,
connection=response_conn, connection=response_conn,
retries=retries, retries=retries,
**response_kw) **response_kw
)
# Everything went great! # Everything went great!
clean_exit = True clean_exit = True
except queue.Empty: except EmptyPoolError:
# Timed out by queue. # Didn't get a connection from the pool, no need to clean up
raise EmptyPoolError(self, "No pool connections are available.") clean_exit = True
release_this_conn = False
raise
except (TimeoutError, HTTPException, SocketError, ProtocolError, except (
BaseSSLError, SSLError, CertificateError) as e: TimeoutError,
HTTPException,
SocketError,
ProtocolError,
BaseSSLError,
SSLError,
CertificateError,
) as e:
# Discard the connection for these exceptions. It will be # Discard the connection for these exceptions. It will be
# replaced during the next _get_conn() call. # replaced during the next _get_conn() call.
clean_exit = False clean_exit = False
if isinstance(e, (BaseSSLError, CertificateError)): if isinstance(e, (BaseSSLError, CertificateError)):
e = SSLError(e) e = SSLError(e)
elif isinstance(e, (SocketError, NewConnectionError)) and self.proxy: elif isinstance(e, (SocketError, NewConnectionError)) and self.proxy:
e = ProxyError('Cannot connect to proxy.', e) e = ProxyError("Cannot connect to proxy.", e)
elif isinstance(e, (SocketError, HTTPException)): elif isinstance(e, (SocketError, HTTPException)):
e = ProtocolError('Connection aborted.', e) e = ProtocolError("Connection aborted.", e)
retries = retries.increment(method, url, error=e, _pool=self, retries = retries.increment(
_stacktrace=sys.exc_info()[2]) method, url, error=e, _pool=self, _stacktrace=sys.exc_info()[2]
)
retries.sleep() retries.sleep()
# Keep track of the error for the retry warning. # Keep track of the error for the retry warning.
@ -659,77 +777,87 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
if not conn: if not conn:
# Try again # Try again
log.warning("Retrying (%r) after connection " log.warning(
"broken by '%r': %s", retries, err, url) "Retrying (%r) after connection broken by '%r': %s", retries, err, url
return self.urlopen(method, url, body, headers, retries, )
redirect, assert_same_host, return self.urlopen(
timeout=timeout, pool_timeout=pool_timeout, method,
release_conn=release_conn, body_pos=body_pos, url,
**response_kw) body,
headers,
def drain_and_release_conn(response): retries,
try: redirect,
# discard any remaining response body, the connection will be assert_same_host,
# released back to the pool once the entire response is read timeout=timeout,
response.read() pool_timeout=pool_timeout,
except (TimeoutError, HTTPException, SocketError, ProtocolError, release_conn=release_conn,
BaseSSLError, SSLError) as e: chunked=chunked,
pass body_pos=body_pos,
**response_kw
)
# Handle redirect? # Handle redirect?
redirect_location = redirect and response.get_redirect_location() redirect_location = redirect and response.get_redirect_location()
if redirect_location: if redirect_location:
if response.status == 303: if response.status == 303:
method = 'GET' method = "GET"
try: try:
retries = retries.increment(method, url, response=response, _pool=self) retries = retries.increment(method, url, response=response, _pool=self)
except MaxRetryError: except MaxRetryError:
if retries.raise_on_redirect: if retries.raise_on_redirect:
# Drain and release the connection for this response, since response.drain_conn()
# we're not returning it to be released manually.
drain_and_release_conn(response)
raise raise
return response return response
# drain and return the connection to the pool before recursing response.drain_conn()
drain_and_release_conn(response)
retries.sleep_for_retry(response) retries.sleep_for_retry(response)
log.debug("Redirecting %s -> %s", url, redirect_location) log.debug("Redirecting %s -> %s", url, redirect_location)
return self.urlopen( return self.urlopen(
method, redirect_location, body, headers, method,
retries=retries, redirect=redirect, redirect_location,
body,
headers,
retries=retries,
redirect=redirect,
assert_same_host=assert_same_host, assert_same_host=assert_same_host,
timeout=timeout, pool_timeout=pool_timeout, timeout=timeout,
release_conn=release_conn, body_pos=body_pos, pool_timeout=pool_timeout,
**response_kw) release_conn=release_conn,
chunked=chunked,
body_pos=body_pos,
**response_kw
)
# Check if we should retry the HTTP response. # Check if we should retry the HTTP response.
has_retry_after = bool(response.getheader('Retry-After')) has_retry_after = bool(response.getheader("Retry-After"))
if retries.is_retry(method, response.status, has_retry_after): if retries.is_retry(method, response.status, has_retry_after):
try: try:
retries = retries.increment(method, url, response=response, _pool=self) retries = retries.increment(method, url, response=response, _pool=self)
except MaxRetryError: except MaxRetryError:
if retries.raise_on_status: if retries.raise_on_status:
# Drain and release the connection for this response, since response.drain_conn()
# we're not returning it to be released manually.
drain_and_release_conn(response)
raise raise
return response return response
# drain and return the connection to the pool before recursing response.drain_conn()
drain_and_release_conn(response)
retries.sleep(response) retries.sleep(response)
log.debug("Retry: %s", url) log.debug("Retry: %s", url)
return self.urlopen( return self.urlopen(
method, url, body, headers, method,
retries=retries, redirect=redirect, url,
body,
headers,
retries=retries,
redirect=redirect,
assert_same_host=assert_same_host, assert_same_host=assert_same_host,
timeout=timeout, pool_timeout=pool_timeout, timeout=timeout,
pool_timeout=pool_timeout,
release_conn=release_conn, release_conn=release_conn,
body_pos=body_pos, **response_kw) chunked=chunked,
body_pos=body_pos,
**response_kw
)
return response return response
@ -738,42 +866,62 @@ class HTTPSConnectionPool(HTTPConnectionPool):
""" """
Same as :class:`.HTTPConnectionPool`, but HTTPS. Same as :class:`.HTTPConnectionPool`, but HTTPS.
When Python is compiled with the :mod:`ssl` module, then :class:`.HTTPSConnection` uses one of ``assert_fingerprint``,
:class:`.VerifiedHTTPSConnection` is used, which *can* verify certificates,
instead of :class:`.HTTPSConnection`.
:class:`.VerifiedHTTPSConnection` uses one of ``assert_fingerprint``,
``assert_hostname`` and ``host`` in this order to verify connections. ``assert_hostname`` and ``host`` in this order to verify connections.
If ``assert_hostname`` is False, no verification is done. If ``assert_hostname`` is False, no verification is done.
The ``key_file``, ``cert_file``, ``cert_reqs``, ``ca_certs``, The ``key_file``, ``cert_file``, ``cert_reqs``, ``ca_certs``,
``ca_cert_dir``, and ``ssl_version`` are only used if :mod:`ssl` is ``ca_cert_dir``, ``ssl_version``, ``key_password`` are only used if :mod:`ssl`
available and are fed into :meth:`urllib3.util.ssl_wrap_socket` to upgrade is available and are fed into :meth:`urllib3.util.ssl_wrap_socket` to upgrade
the connection socket into an SSL socket. the connection socket into an SSL socket.
""" """
scheme = 'https' scheme = "https"
ConnectionCls = HTTPSConnection ConnectionCls = HTTPSConnection
def __init__(self, host, port=None, def __init__(
strict=False, timeout=Timeout.DEFAULT_TIMEOUT, maxsize=1, self,
block=False, headers=None, retries=None, host,
_proxy=None, _proxy_headers=None, port=None,
key_file=None, cert_file=None, cert_reqs=None, strict=False,
ca_certs=None, ssl_version=None, timeout=Timeout.DEFAULT_TIMEOUT,
assert_hostname=None, assert_fingerprint=None, maxsize=1,
ca_cert_dir=None, **conn_kw): block=False,
headers=None,
retries=None,
_proxy=None,
_proxy_headers=None,
key_file=None,
cert_file=None,
cert_reqs=None,
key_password=None,
ca_certs=None,
ssl_version=None,
assert_hostname=None,
assert_fingerprint=None,
ca_cert_dir=None,
**conn_kw
):
HTTPConnectionPool.__init__(self, host, port, strict, timeout, maxsize, HTTPConnectionPool.__init__(
block, headers, retries, _proxy, _proxy_headers, self,
**conn_kw) host,
port,
if ca_certs and cert_reqs is None: strict,
cert_reqs = 'CERT_REQUIRED' timeout,
maxsize,
block,
headers,
retries,
_proxy,
_proxy_headers,
**conn_kw
)
self.key_file = key_file self.key_file = key_file
self.cert_file = cert_file self.cert_file = cert_file
self.cert_reqs = cert_reqs self.cert_reqs = cert_reqs
self.key_password = key_password
self.ca_certs = ca_certs self.ca_certs = ca_certs
self.ca_cert_dir = ca_cert_dir self.ca_cert_dir = ca_cert_dir
self.ssl_version = ssl_version self.ssl_version = ssl_version
@ -787,45 +935,50 @@ class HTTPSConnectionPool(HTTPConnectionPool):
""" """
if isinstance(conn, VerifiedHTTPSConnection): if isinstance(conn, VerifiedHTTPSConnection):
conn.set_cert(key_file=self.key_file, conn.set_cert(
key_file=self.key_file,
key_password=self.key_password,
cert_file=self.cert_file, cert_file=self.cert_file,
cert_reqs=self.cert_reqs, cert_reqs=self.cert_reqs,
ca_certs=self.ca_certs, ca_certs=self.ca_certs,
ca_cert_dir=self.ca_cert_dir, ca_cert_dir=self.ca_cert_dir,
assert_hostname=self.assert_hostname, assert_hostname=self.assert_hostname,
assert_fingerprint=self.assert_fingerprint) assert_fingerprint=self.assert_fingerprint,
)
conn.ssl_version = self.ssl_version conn.ssl_version = self.ssl_version
return conn return conn
def _prepare_proxy(self, conn): def _prepare_proxy(self, conn):
""" """
Establish tunnel connection early, because otherwise httplib Establishes a tunnel connection through HTTP CONNECT.
would improperly set Host: header to proxy's IP:port.
"""
# Python 2.7+
try:
set_tunnel = conn.set_tunnel
except AttributeError: # Platform-specific: Python 2.6
set_tunnel = conn._set_tunnel
if sys.version_info <= (2, 6, 4) and not self.proxy_headers: # Python 2.6.4 and older Tunnel connection is established early because otherwise httplib would
set_tunnel(self._proxy_host, self.port) improperly set Host: header to proxy's IP:port.
else: """
set_tunnel(self._proxy_host, self.port, self.proxy_headers)
conn.set_tunnel(self._proxy_host, self.port, self.proxy_headers)
if self.proxy.scheme == "https":
conn.tls_in_tls_required = True
conn.connect() conn.connect()
def _new_conn(self): def _new_conn(self):
""" """
Return a fresh :class:`httplib.HTTPSConnection`. Return a fresh :class:`http.client.HTTPSConnection`.
""" """
self.num_connections += 1 self.num_connections += 1
log.debug("Starting new HTTPS connection (%d): %s", log.debug(
self.num_connections, self.host) "Starting new HTTPS connection (%d): %s:%s",
self.num_connections,
self.host,
self.port or "443",
)
if not self.ConnectionCls or self.ConnectionCls is DummyConnection: if not self.ConnectionCls or self.ConnectionCls is DummyConnection:
raise SSLError("Can't connect to HTTPS URL because the SSL " raise SSLError(
"module is not available.") "Can't connect to HTTPS URL because the SSL module is not available."
)
actual_host = self.host actual_host = self.host
actual_port = self.port actual_port = self.port
@ -833,9 +986,16 @@ class HTTPSConnectionPool(HTTPConnectionPool):
actual_host = self.proxy.host actual_host = self.proxy.host
actual_port = self.proxy.port actual_port = self.proxy.port
conn = self.ConnectionCls(host=actual_host, port=actual_port, conn = self.ConnectionCls(
host=actual_host,
port=actual_port,
timeout=self.timeout.connect_timeout, timeout=self.timeout.connect_timeout,
strict=self.strict, **self.conn_kw) strict=self.strict,
cert_file=self.cert_file,
key_file=self.key_file,
key_password=self.key_password,
**self.conn_kw
)
return self._prepare_conn(conn) return self._prepare_conn(conn)
@ -846,16 +1006,30 @@ class HTTPSConnectionPool(HTTPConnectionPool):
super(HTTPSConnectionPool, self)._validate_conn(conn) super(HTTPSConnectionPool, self)._validate_conn(conn)
# Force connect early to allow us to validate the connection. # Force connect early to allow us to validate the connection.
if not getattr(conn, 'sock', None): # AppEngine might not have `.sock` if not getattr(conn, "sock", None): # AppEngine might not have `.sock`
conn.connect() conn.connect()
if not conn.is_verified: if not conn.is_verified:
warnings.warn(( warnings.warn(
'Unverified HTTPS request is being made. ' (
'Adding certificate verification is strongly advised. See: ' "Unverified HTTPS request is being made to host '%s'. "
'https://urllib3.readthedocs.io/en/latest/advanced-usage.html' "Adding certificate verification is strongly advised. See: "
'#ssl-warnings'), "https://urllib3.readthedocs.io/en/1.26.x/advanced-usage.html"
InsecureRequestWarning) "#ssl-warnings" % conn.host
),
InsecureRequestWarning,
)
if getattr(conn, "proxy_is_verified", None) is False:
warnings.warn(
(
"Unverified HTTPS connection done to an HTTPS proxy. "
"Adding certificate verification is strongly advised. See: "
"https://urllib3.readthedocs.io/en/1.26.x/advanced-usage.html"
"#ssl-warnings"
),
InsecureRequestWarning,
)
def connection_from_url(url, **kw): def connection_from_url(url, **kw):
@ -880,26 +1054,25 @@ def connection_from_url(url, **kw):
""" """
scheme, host, port = get_host(url) scheme, host, port = get_host(url)
port = port or port_by_scheme.get(scheme, 80) port = port or port_by_scheme.get(scheme, 80)
if scheme == 'https': if scheme == "https":
return HTTPSConnectionPool(host, port=port, **kw) return HTTPSConnectionPool(host, port=port, **kw)
else: else:
return HTTPConnectionPool(host, port=port, **kw) return HTTPConnectionPool(host, port=port, **kw)
def _ipv6_host(host): def _normalize_host(host, scheme):
""" """
Process IPv6 address literals Normalize hosts for comparisons and use with sockets.
""" """
host = normalize_host(host, scheme)
# httplib doesn't like it when we include brackets in IPv6 addresses # httplib doesn't like it when we include brackets in IPv6 addresses
# Specifically, if we include brackets but also pass the port then # Specifically, if we include brackets but also pass the port then
# httplib crazily doubles up the square brackets on the Host header. # httplib crazily doubles up the square brackets on the Host header.
# Instead, we need to make sure we never pass ``None`` as the port. # Instead, we need to make sure we never pass ``None`` as the port.
# However, for backward compatibility reasons we can't actually # However, for backward compatibility reasons we can't actually
# *assert* that. See http://bugs.python.org/issue28539 # *assert* that. See http://bugs.python.org/issue28539
# if host.startswith("[") and host.endswith("]"):
# Also if an IPv6 address literal has a zone identifier, the host = host[1:-1]
# percent sign might be URIencoded, convert it back into ASCII
if host.startswith('[') and host.endswith(']'):
host = host.replace('%25', '%').strip('[]')
return host return host

View file

@ -0,0 +1,36 @@
"""
This module provides means to detect the App Engine environment.
"""
import os
def is_appengine():
return is_local_appengine() or is_prod_appengine()
def is_appengine_sandbox():
"""Reports if the app is running in the first generation sandbox.
The second generation runtimes are technically still in a sandbox, but it
is much less restrictive, so generally you shouldn't need to check for it.
see https://cloud.google.com/appengine/docs/standard/runtimes
"""
return is_appengine() and os.environ["APPENGINE_RUNTIME"] == "python27"
def is_local_appengine():
return "APPENGINE_RUNTIME" in os.environ and os.environ.get(
"SERVER_SOFTWARE", ""
).startswith("Development/")
def is_prod_appengine():
return "APPENGINE_RUNTIME" in os.environ and os.environ.get(
"SERVER_SOFTWARE", ""
).startswith("Google App Engine/")
def is_prod_appengine_mvms():
"""Deprecated."""
return False

View file

@ -32,35 +32,60 @@ license and by oscrypto's:
from __future__ import absolute_import from __future__ import absolute_import
import platform import platform
from ctypes.util import find_library
from ctypes import ( from ctypes import (
c_void_p, c_int32, c_char_p, c_size_t, c_byte, c_uint32, c_ulong, c_long, CDLL,
c_bool CFUNCTYPE,
POINTER,
c_bool,
c_byte,
c_char_p,
c_int32,
c_long,
c_size_t,
c_uint32,
c_ulong,
c_void_p,
) )
from ctypes import CDLL, POINTER, CFUNCTYPE from ctypes.util import find_library
from urllib3.packages.six import raise_from
security_path = find_library('Security') if platform.system() != "Darwin":
if not security_path: raise ImportError("Only macOS is supported")
raise ImportError('The library Security could not be found')
core_foundation_path = find_library('CoreFoundation')
if not core_foundation_path:
raise ImportError('The library CoreFoundation could not be found')
version = platform.mac_ver()[0] version = platform.mac_ver()[0]
version_info = tuple(map(int, version.split('.'))) version_info = tuple(map(int, version.split(".")))
if version_info < (10, 8): if version_info < (10, 8):
raise OSError( raise OSError(
'Only OS X 10.8 and newer are supported, not %s.%s' % ( "Only OS X 10.8 and newer are supported, not %s.%s"
version_info[0], version_info[1] % (version_info[0], version_info[1])
) )
def load_cdll(name, macos10_16_path):
"""Loads a CDLL by name, falling back to known path on 10.16+"""
try:
# Big Sur is technically 11 but we use 10.16 due to the Big Sur
# beta being labeled as 10.16.
if version_info >= (10, 16):
path = macos10_16_path
else:
path = find_library(name)
if not path:
raise OSError # Caught and reraised as 'ImportError'
return CDLL(path, use_errno=True)
except OSError:
raise_from(ImportError("The library %s failed to load" % name), None)
Security = load_cdll(
"Security", "/System/Library/Frameworks/Security.framework/Security"
)
CoreFoundation = load_cdll(
"CoreFoundation",
"/System/Library/Frameworks/CoreFoundation.framework/CoreFoundation",
) )
Security = CDLL(security_path, use_errno=True)
CoreFoundation = CDLL(core_foundation_path, use_errno=True)
Boolean = c_bool Boolean = c_bool
CFIndex = c_long CFIndex = c_long
@ -129,27 +154,19 @@ try:
Security.SecKeyGetTypeID.argtypes = [] Security.SecKeyGetTypeID.argtypes = []
Security.SecKeyGetTypeID.restype = CFTypeID Security.SecKeyGetTypeID.restype = CFTypeID
Security.SecCertificateCreateWithData.argtypes = [ Security.SecCertificateCreateWithData.argtypes = [CFAllocatorRef, CFDataRef]
CFAllocatorRef,
CFDataRef
]
Security.SecCertificateCreateWithData.restype = SecCertificateRef Security.SecCertificateCreateWithData.restype = SecCertificateRef
Security.SecCertificateCopyData.argtypes = [ Security.SecCertificateCopyData.argtypes = [SecCertificateRef]
SecCertificateRef
]
Security.SecCertificateCopyData.restype = CFDataRef Security.SecCertificateCopyData.restype = CFDataRef
Security.SecCopyErrorMessageString.argtypes = [ Security.SecCopyErrorMessageString.argtypes = [OSStatus, c_void_p]
OSStatus,
c_void_p
]
Security.SecCopyErrorMessageString.restype = CFStringRef Security.SecCopyErrorMessageString.restype = CFStringRef
Security.SecIdentityCreateWithCertificate.argtypes = [ Security.SecIdentityCreateWithCertificate.argtypes = [
CFTypeRef, CFTypeRef,
SecCertificateRef, SecCertificateRef,
POINTER(SecIdentityRef) POINTER(SecIdentityRef),
] ]
Security.SecIdentityCreateWithCertificate.restype = OSStatus Security.SecIdentityCreateWithCertificate.restype = OSStatus
@ -159,201 +176,133 @@ try:
c_void_p, c_void_p,
Boolean, Boolean,
c_void_p, c_void_p,
POINTER(SecKeychainRef) POINTER(SecKeychainRef),
] ]
Security.SecKeychainCreate.restype = OSStatus Security.SecKeychainCreate.restype = OSStatus
Security.SecKeychainDelete.argtypes = [ Security.SecKeychainDelete.argtypes = [SecKeychainRef]
SecKeychainRef
]
Security.SecKeychainDelete.restype = OSStatus Security.SecKeychainDelete.restype = OSStatus
Security.SecPKCS12Import.argtypes = [ Security.SecPKCS12Import.argtypes = [
CFDataRef, CFDataRef,
CFDictionaryRef, CFDictionaryRef,
POINTER(CFArrayRef) POINTER(CFArrayRef),
] ]
Security.SecPKCS12Import.restype = OSStatus Security.SecPKCS12Import.restype = OSStatus
SSLReadFunc = CFUNCTYPE(OSStatus, SSLConnectionRef, c_void_p, POINTER(c_size_t)) SSLReadFunc = CFUNCTYPE(OSStatus, SSLConnectionRef, c_void_p, POINTER(c_size_t))
SSLWriteFunc = CFUNCTYPE(OSStatus, SSLConnectionRef, POINTER(c_byte), POINTER(c_size_t)) SSLWriteFunc = CFUNCTYPE(
OSStatus, SSLConnectionRef, POINTER(c_byte), POINTER(c_size_t)
)
Security.SSLSetIOFuncs.argtypes = [ Security.SSLSetIOFuncs.argtypes = [SSLContextRef, SSLReadFunc, SSLWriteFunc]
SSLContextRef,
SSLReadFunc,
SSLWriteFunc
]
Security.SSLSetIOFuncs.restype = OSStatus Security.SSLSetIOFuncs.restype = OSStatus
Security.SSLSetPeerID.argtypes = [ Security.SSLSetPeerID.argtypes = [SSLContextRef, c_char_p, c_size_t]
SSLContextRef,
c_char_p,
c_size_t
]
Security.SSLSetPeerID.restype = OSStatus Security.SSLSetPeerID.restype = OSStatus
Security.SSLSetCertificate.argtypes = [ Security.SSLSetCertificate.argtypes = [SSLContextRef, CFArrayRef]
SSLContextRef,
CFArrayRef
]
Security.SSLSetCertificate.restype = OSStatus Security.SSLSetCertificate.restype = OSStatus
Security.SSLSetCertificateAuthorities.argtypes = [ Security.SSLSetCertificateAuthorities.argtypes = [SSLContextRef, CFTypeRef, Boolean]
SSLContextRef,
CFTypeRef,
Boolean
]
Security.SSLSetCertificateAuthorities.restype = OSStatus Security.SSLSetCertificateAuthorities.restype = OSStatus
Security.SSLSetConnection.argtypes = [ Security.SSLSetConnection.argtypes = [SSLContextRef, SSLConnectionRef]
SSLContextRef,
SSLConnectionRef
]
Security.SSLSetConnection.restype = OSStatus Security.SSLSetConnection.restype = OSStatus
Security.SSLSetPeerDomainName.argtypes = [ Security.SSLSetPeerDomainName.argtypes = [SSLContextRef, c_char_p, c_size_t]
SSLContextRef,
c_char_p,
c_size_t
]
Security.SSLSetPeerDomainName.restype = OSStatus Security.SSLSetPeerDomainName.restype = OSStatus
Security.SSLHandshake.argtypes = [ Security.SSLHandshake.argtypes = [SSLContextRef]
SSLContextRef
]
Security.SSLHandshake.restype = OSStatus Security.SSLHandshake.restype = OSStatus
Security.SSLRead.argtypes = [ Security.SSLRead.argtypes = [SSLContextRef, c_char_p, c_size_t, POINTER(c_size_t)]
SSLContextRef,
c_char_p,
c_size_t,
POINTER(c_size_t)
]
Security.SSLRead.restype = OSStatus Security.SSLRead.restype = OSStatus
Security.SSLWrite.argtypes = [ Security.SSLWrite.argtypes = [SSLContextRef, c_char_p, c_size_t, POINTER(c_size_t)]
SSLContextRef,
c_char_p,
c_size_t,
POINTER(c_size_t)
]
Security.SSLWrite.restype = OSStatus Security.SSLWrite.restype = OSStatus
Security.SSLClose.argtypes = [ Security.SSLClose.argtypes = [SSLContextRef]
SSLContextRef
]
Security.SSLClose.restype = OSStatus Security.SSLClose.restype = OSStatus
Security.SSLGetNumberSupportedCiphers.argtypes = [ Security.SSLGetNumberSupportedCiphers.argtypes = [SSLContextRef, POINTER(c_size_t)]
SSLContextRef,
POINTER(c_size_t)
]
Security.SSLGetNumberSupportedCiphers.restype = OSStatus Security.SSLGetNumberSupportedCiphers.restype = OSStatus
Security.SSLGetSupportedCiphers.argtypes = [ Security.SSLGetSupportedCiphers.argtypes = [
SSLContextRef, SSLContextRef,
POINTER(SSLCipherSuite), POINTER(SSLCipherSuite),
POINTER(c_size_t) POINTER(c_size_t),
] ]
Security.SSLGetSupportedCiphers.restype = OSStatus Security.SSLGetSupportedCiphers.restype = OSStatus
Security.SSLSetEnabledCiphers.argtypes = [ Security.SSLSetEnabledCiphers.argtypes = [
SSLContextRef, SSLContextRef,
POINTER(SSLCipherSuite), POINTER(SSLCipherSuite),
c_size_t c_size_t,
] ]
Security.SSLSetEnabledCiphers.restype = OSStatus Security.SSLSetEnabledCiphers.restype = OSStatus
Security.SSLGetNumberEnabledCiphers.argtype = [ Security.SSLGetNumberEnabledCiphers.argtype = [SSLContextRef, POINTER(c_size_t)]
SSLContextRef,
POINTER(c_size_t)
]
Security.SSLGetNumberEnabledCiphers.restype = OSStatus Security.SSLGetNumberEnabledCiphers.restype = OSStatus
Security.SSLGetEnabledCiphers.argtypes = [ Security.SSLGetEnabledCiphers.argtypes = [
SSLContextRef, SSLContextRef,
POINTER(SSLCipherSuite), POINTER(SSLCipherSuite),
POINTER(c_size_t) POINTER(c_size_t),
] ]
Security.SSLGetEnabledCiphers.restype = OSStatus Security.SSLGetEnabledCiphers.restype = OSStatus
Security.SSLGetNegotiatedCipher.argtypes = [ Security.SSLGetNegotiatedCipher.argtypes = [SSLContextRef, POINTER(SSLCipherSuite)]
SSLContextRef,
POINTER(SSLCipherSuite)
]
Security.SSLGetNegotiatedCipher.restype = OSStatus Security.SSLGetNegotiatedCipher.restype = OSStatus
Security.SSLGetNegotiatedProtocolVersion.argtypes = [ Security.SSLGetNegotiatedProtocolVersion.argtypes = [
SSLContextRef, SSLContextRef,
POINTER(SSLProtocol) POINTER(SSLProtocol),
] ]
Security.SSLGetNegotiatedProtocolVersion.restype = OSStatus Security.SSLGetNegotiatedProtocolVersion.restype = OSStatus
Security.SSLCopyPeerTrust.argtypes = [ Security.SSLCopyPeerTrust.argtypes = [SSLContextRef, POINTER(SecTrustRef)]
SSLContextRef,
POINTER(SecTrustRef)
]
Security.SSLCopyPeerTrust.restype = OSStatus Security.SSLCopyPeerTrust.restype = OSStatus
Security.SecTrustSetAnchorCertificates.argtypes = [ Security.SecTrustSetAnchorCertificates.argtypes = [SecTrustRef, CFArrayRef]
SecTrustRef,
CFArrayRef
]
Security.SecTrustSetAnchorCertificates.restype = OSStatus Security.SecTrustSetAnchorCertificates.restype = OSStatus
Security.SecTrustSetAnchorCertificatesOnly.argstypes = [ Security.SecTrustSetAnchorCertificatesOnly.argstypes = [SecTrustRef, Boolean]
SecTrustRef,
Boolean
]
Security.SecTrustSetAnchorCertificatesOnly.restype = OSStatus Security.SecTrustSetAnchorCertificatesOnly.restype = OSStatus
Security.SecTrustEvaluate.argtypes = [ Security.SecTrustEvaluate.argtypes = [SecTrustRef, POINTER(SecTrustResultType)]
SecTrustRef,
POINTER(SecTrustResultType)
]
Security.SecTrustEvaluate.restype = OSStatus Security.SecTrustEvaluate.restype = OSStatus
Security.SecTrustGetCertificateCount.argtypes = [ Security.SecTrustGetCertificateCount.argtypes = [SecTrustRef]
SecTrustRef
]
Security.SecTrustGetCertificateCount.restype = CFIndex Security.SecTrustGetCertificateCount.restype = CFIndex
Security.SecTrustGetCertificateAtIndex.argtypes = [ Security.SecTrustGetCertificateAtIndex.argtypes = [SecTrustRef, CFIndex]
SecTrustRef,
CFIndex
]
Security.SecTrustGetCertificateAtIndex.restype = SecCertificateRef Security.SecTrustGetCertificateAtIndex.restype = SecCertificateRef
Security.SSLCreateContext.argtypes = [ Security.SSLCreateContext.argtypes = [
CFAllocatorRef, CFAllocatorRef,
SSLProtocolSide, SSLProtocolSide,
SSLConnectionType SSLConnectionType,
] ]
Security.SSLCreateContext.restype = SSLContextRef Security.SSLCreateContext.restype = SSLContextRef
Security.SSLSetSessionOption.argtypes = [ Security.SSLSetSessionOption.argtypes = [SSLContextRef, SSLSessionOption, Boolean]
SSLContextRef,
SSLSessionOption,
Boolean
]
Security.SSLSetSessionOption.restype = OSStatus Security.SSLSetSessionOption.restype = OSStatus
Security.SSLSetProtocolVersionMin.argtypes = [ Security.SSLSetProtocolVersionMin.argtypes = [SSLContextRef, SSLProtocol]
SSLContextRef,
SSLProtocol
]
Security.SSLSetProtocolVersionMin.restype = OSStatus Security.SSLSetProtocolVersionMin.restype = OSStatus
Security.SSLSetProtocolVersionMax.argtypes = [ Security.SSLSetProtocolVersionMax.argtypes = [SSLContextRef, SSLProtocol]
SSLContextRef,
SSLProtocol
]
Security.SSLSetProtocolVersionMax.restype = OSStatus Security.SSLSetProtocolVersionMax.restype = OSStatus
Security.SecCopyErrorMessageString.argtypes = [ try:
OSStatus, Security.SSLSetALPNProtocols.argtypes = [SSLContextRef, CFArrayRef]
c_void_p Security.SSLSetALPNProtocols.restype = OSStatus
] except AttributeError:
# Supported only in 10.12+
pass
Security.SecCopyErrorMessageString.argtypes = [OSStatus, c_void_p]
Security.SecCopyErrorMessageString.restype = CFStringRef Security.SecCopyErrorMessageString.restype = CFStringRef
Security.SSLReadFunc = SSLReadFunc Security.SSLReadFunc = SSLReadFunc
@ -369,64 +318,47 @@ try:
Security.OSStatus = OSStatus Security.OSStatus = OSStatus
Security.kSecImportExportPassphrase = CFStringRef.in_dll( Security.kSecImportExportPassphrase = CFStringRef.in_dll(
Security, 'kSecImportExportPassphrase' Security, "kSecImportExportPassphrase"
) )
Security.kSecImportItemIdentity = CFStringRef.in_dll( Security.kSecImportItemIdentity = CFStringRef.in_dll(
Security, 'kSecImportItemIdentity' Security, "kSecImportItemIdentity"
) )
# CoreFoundation time! # CoreFoundation time!
CoreFoundation.CFRetain.argtypes = [ CoreFoundation.CFRetain.argtypes = [CFTypeRef]
CFTypeRef
]
CoreFoundation.CFRetain.restype = CFTypeRef CoreFoundation.CFRetain.restype = CFTypeRef
CoreFoundation.CFRelease.argtypes = [ CoreFoundation.CFRelease.argtypes = [CFTypeRef]
CFTypeRef
]
CoreFoundation.CFRelease.restype = None CoreFoundation.CFRelease.restype = None
CoreFoundation.CFGetTypeID.argtypes = [ CoreFoundation.CFGetTypeID.argtypes = [CFTypeRef]
CFTypeRef
]
CoreFoundation.CFGetTypeID.restype = CFTypeID CoreFoundation.CFGetTypeID.restype = CFTypeID
CoreFoundation.CFStringCreateWithCString.argtypes = [ CoreFoundation.CFStringCreateWithCString.argtypes = [
CFAllocatorRef, CFAllocatorRef,
c_char_p, c_char_p,
CFStringEncoding CFStringEncoding,
] ]
CoreFoundation.CFStringCreateWithCString.restype = CFStringRef CoreFoundation.CFStringCreateWithCString.restype = CFStringRef
CoreFoundation.CFStringGetCStringPtr.argtypes = [ CoreFoundation.CFStringGetCStringPtr.argtypes = [CFStringRef, CFStringEncoding]
CFStringRef,
CFStringEncoding
]
CoreFoundation.CFStringGetCStringPtr.restype = c_char_p CoreFoundation.CFStringGetCStringPtr.restype = c_char_p
CoreFoundation.CFStringGetCString.argtypes = [ CoreFoundation.CFStringGetCString.argtypes = [
CFStringRef, CFStringRef,
c_char_p, c_char_p,
CFIndex, CFIndex,
CFStringEncoding CFStringEncoding,
] ]
CoreFoundation.CFStringGetCString.restype = c_bool CoreFoundation.CFStringGetCString.restype = c_bool
CoreFoundation.CFDataCreate.argtypes = [ CoreFoundation.CFDataCreate.argtypes = [CFAllocatorRef, c_char_p, CFIndex]
CFAllocatorRef,
c_char_p,
CFIndex
]
CoreFoundation.CFDataCreate.restype = CFDataRef CoreFoundation.CFDataCreate.restype = CFDataRef
CoreFoundation.CFDataGetLength.argtypes = [ CoreFoundation.CFDataGetLength.argtypes = [CFDataRef]
CFDataRef
]
CoreFoundation.CFDataGetLength.restype = CFIndex CoreFoundation.CFDataGetLength.restype = CFIndex
CoreFoundation.CFDataGetBytePtr.argtypes = [ CoreFoundation.CFDataGetBytePtr.argtypes = [CFDataRef]
CFDataRef
]
CoreFoundation.CFDataGetBytePtr.restype = c_void_p CoreFoundation.CFDataGetBytePtr.restype = c_void_p
CoreFoundation.CFDictionaryCreate.argtypes = [ CoreFoundation.CFDictionaryCreate.argtypes = [
@ -435,14 +367,11 @@ try:
POINTER(CFTypeRef), POINTER(CFTypeRef),
CFIndex, CFIndex,
CFDictionaryKeyCallBacks, CFDictionaryKeyCallBacks,
CFDictionaryValueCallBacks CFDictionaryValueCallBacks,
] ]
CoreFoundation.CFDictionaryCreate.restype = CFDictionaryRef CoreFoundation.CFDictionaryCreate.restype = CFDictionaryRef
CoreFoundation.CFDictionaryGetValue.argtypes = [ CoreFoundation.CFDictionaryGetValue.argtypes = [CFDictionaryRef, CFTypeRef]
CFDictionaryRef,
CFTypeRef
]
CoreFoundation.CFDictionaryGetValue.restype = CFTypeRef CoreFoundation.CFDictionaryGetValue.restype = CFTypeRef
CoreFoundation.CFArrayCreate.argtypes = [ CoreFoundation.CFArrayCreate.argtypes = [
@ -456,36 +385,30 @@ try:
CoreFoundation.CFArrayCreateMutable.argtypes = [ CoreFoundation.CFArrayCreateMutable.argtypes = [
CFAllocatorRef, CFAllocatorRef,
CFIndex, CFIndex,
CFArrayCallBacks CFArrayCallBacks,
] ]
CoreFoundation.CFArrayCreateMutable.restype = CFMutableArrayRef CoreFoundation.CFArrayCreateMutable.restype = CFMutableArrayRef
CoreFoundation.CFArrayAppendValue.argtypes = [ CoreFoundation.CFArrayAppendValue.argtypes = [CFMutableArrayRef, c_void_p]
CFMutableArrayRef,
c_void_p
]
CoreFoundation.CFArrayAppendValue.restype = None CoreFoundation.CFArrayAppendValue.restype = None
CoreFoundation.CFArrayGetCount.argtypes = [ CoreFoundation.CFArrayGetCount.argtypes = [CFArrayRef]
CFArrayRef
]
CoreFoundation.CFArrayGetCount.restype = CFIndex CoreFoundation.CFArrayGetCount.restype = CFIndex
CoreFoundation.CFArrayGetValueAtIndex.argtypes = [ CoreFoundation.CFArrayGetValueAtIndex.argtypes = [CFArrayRef, CFIndex]
CFArrayRef,
CFIndex
]
CoreFoundation.CFArrayGetValueAtIndex.restype = c_void_p CoreFoundation.CFArrayGetValueAtIndex.restype = c_void_p
CoreFoundation.kCFAllocatorDefault = CFAllocatorRef.in_dll( CoreFoundation.kCFAllocatorDefault = CFAllocatorRef.in_dll(
CoreFoundation, 'kCFAllocatorDefault' CoreFoundation, "kCFAllocatorDefault"
)
CoreFoundation.kCFTypeArrayCallBacks = c_void_p.in_dll(
CoreFoundation, "kCFTypeArrayCallBacks"
) )
CoreFoundation.kCFTypeArrayCallBacks = c_void_p.in_dll(CoreFoundation, 'kCFTypeArrayCallBacks')
CoreFoundation.kCFTypeDictionaryKeyCallBacks = c_void_p.in_dll( CoreFoundation.kCFTypeDictionaryKeyCallBacks = c_void_p.in_dll(
CoreFoundation, 'kCFTypeDictionaryKeyCallBacks' CoreFoundation, "kCFTypeDictionaryKeyCallBacks"
) )
CoreFoundation.kCFTypeDictionaryValueCallBacks = c_void_p.in_dll( CoreFoundation.kCFTypeDictionaryValueCallBacks = c_void_p.in_dll(
CoreFoundation, 'kCFTypeDictionaryValueCallBacks' CoreFoundation, "kCFTypeDictionaryValueCallBacks"
) )
CoreFoundation.CFTypeRef = CFTypeRef CoreFoundation.CFTypeRef = CFTypeRef
@ -494,7 +417,7 @@ try:
CoreFoundation.CFDictionaryRef = CFDictionaryRef CoreFoundation.CFDictionaryRef = CFDictionaryRef
except (AttributeError): except (AttributeError):
raise ImportError('Error initializing ctypes') raise ImportError("Error initializing ctypes")
class CFConst(object): class CFConst(object):
@ -502,6 +425,7 @@ class CFConst(object):
A class object that acts as essentially a namespace for CoreFoundation A class object that acts as essentially a namespace for CoreFoundation
constants. constants.
""" """
kCFStringEncodingUTF8 = CFStringEncoding(0x08000100) kCFStringEncodingUTF8 = CFStringEncoding(0x08000100)
@ -509,6 +433,7 @@ class SecurityConst(object):
""" """
A class object that acts as essentially a namespace for Security constants. A class object that acts as essentially a namespace for Security constants.
""" """
kSSLSessionOptionBreakOnServerAuth = 0 kSSLSessionOptionBreakOnServerAuth = 0
kSSLProtocol2 = 1 kSSLProtocol2 = 1
@ -516,6 +441,9 @@ class SecurityConst(object):
kTLSProtocol1 = 4 kTLSProtocol1 = 4
kTLSProtocol11 = 7 kTLSProtocol11 = 7
kTLSProtocol12 = 8 kTLSProtocol12 = 8
# SecureTransport does not support TLS 1.3 even if there's a constant for it
kTLSProtocol13 = 10
kTLSProtocolMaxSupported = 999
kSSLClientSide = 1 kSSLClientSide = 1
kSSLStreamType = 0 kSSLStreamType = 0
@ -558,30 +486,27 @@ class SecurityConst(object):
errSecInvalidTrustSettings = -25262 errSecInvalidTrustSettings = -25262
# Cipher suites. We only pick the ones our default cipher string allows. # Cipher suites. We only pick the ones our default cipher string allows.
# Source: https://developer.apple.com/documentation/security/1550981-ssl_cipher_suite_values
TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 = 0xC02C TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 = 0xC02C
TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 = 0xC030 TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 = 0xC030
TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 = 0xC02B TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 = 0xC02B
TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 = 0xC02F TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 = 0xC02F
TLS_DHE_DSS_WITH_AES_256_GCM_SHA384 = 0x00A3 TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256 = 0xCCA9
TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 = 0xCCA8
TLS_DHE_RSA_WITH_AES_256_GCM_SHA384 = 0x009F TLS_DHE_RSA_WITH_AES_256_GCM_SHA384 = 0x009F
TLS_DHE_DSS_WITH_AES_128_GCM_SHA256 = 0x00A2
TLS_DHE_RSA_WITH_AES_128_GCM_SHA256 = 0x009E TLS_DHE_RSA_WITH_AES_128_GCM_SHA256 = 0x009E
TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA384 = 0xC024 TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA384 = 0xC024
TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384 = 0xC028 TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384 = 0xC028
TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA = 0xC00A TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA = 0xC00A
TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA = 0xC014 TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA = 0xC014
TLS_DHE_RSA_WITH_AES_256_CBC_SHA256 = 0x006B TLS_DHE_RSA_WITH_AES_256_CBC_SHA256 = 0x006B
TLS_DHE_DSS_WITH_AES_256_CBC_SHA256 = 0x006A
TLS_DHE_RSA_WITH_AES_256_CBC_SHA = 0x0039 TLS_DHE_RSA_WITH_AES_256_CBC_SHA = 0x0039
TLS_DHE_DSS_WITH_AES_256_CBC_SHA = 0x0038
TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256 = 0xC023 TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256 = 0xC023
TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256 = 0xC027 TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256 = 0xC027
TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA = 0xC009 TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA = 0xC009
TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA = 0xC013 TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA = 0xC013
TLS_DHE_RSA_WITH_AES_128_CBC_SHA256 = 0x0067 TLS_DHE_RSA_WITH_AES_128_CBC_SHA256 = 0x0067
TLS_DHE_DSS_WITH_AES_128_CBC_SHA256 = 0x0040
TLS_DHE_RSA_WITH_AES_128_CBC_SHA = 0x0033 TLS_DHE_RSA_WITH_AES_128_CBC_SHA = 0x0033
TLS_DHE_DSS_WITH_AES_128_CBC_SHA = 0x0032
TLS_RSA_WITH_AES_256_GCM_SHA384 = 0x009D TLS_RSA_WITH_AES_256_GCM_SHA384 = 0x009D
TLS_RSA_WITH_AES_128_GCM_SHA256 = 0x009C TLS_RSA_WITH_AES_128_GCM_SHA256 = 0x009C
TLS_RSA_WITH_AES_256_CBC_SHA256 = 0x003D TLS_RSA_WITH_AES_256_CBC_SHA256 = 0x003D
@ -590,4 +515,5 @@ class SecurityConst(object):
TLS_RSA_WITH_AES_128_CBC_SHA = 0x002F TLS_RSA_WITH_AES_128_CBC_SHA = 0x002F
TLS_AES_128_GCM_SHA256 = 0x1301 TLS_AES_128_GCM_SHA256 = 0x1301
TLS_AES_256_GCM_SHA384 = 0x1302 TLS_AES_256_GCM_SHA384 = 0x1302
TLS_CHACHA20_POLY1305_SHA256 = 0x1303 TLS_AES_128_CCM_8_SHA256 = 0x1305
TLS_AES_128_CCM_SHA256 = 0x1304

View file

@ -10,13 +10,13 @@ appropriate and useful assistance to the higher-level code.
import base64 import base64
import ctypes import ctypes
import itertools import itertools
import re
import os import os
import re
import ssl import ssl
import struct
import tempfile import tempfile
from .bindings import Security, CoreFoundation, CFConst from .bindings import CFConst, CoreFoundation, Security
# This regular expression is used to grab PEM data out of a PEM bundle. # This regular expression is used to grab PEM data out of a PEM bundle.
_PEM_CERTS_RE = re.compile( _PEM_CERTS_RE = re.compile(
@ -56,6 +56,51 @@ def _cf_dictionary_from_tuples(tuples):
) )
def _cfstr(py_bstr):
"""
Given a Python binary data, create a CFString.
The string must be CFReleased by the caller.
"""
c_str = ctypes.c_char_p(py_bstr)
cf_str = CoreFoundation.CFStringCreateWithCString(
CoreFoundation.kCFAllocatorDefault,
c_str,
CFConst.kCFStringEncodingUTF8,
)
return cf_str
def _create_cfstring_array(lst):
"""
Given a list of Python binary data, create an associated CFMutableArray.
The array must be CFReleased by the caller.
Raises an ssl.SSLError on failure.
"""
cf_arr = None
try:
cf_arr = CoreFoundation.CFArrayCreateMutable(
CoreFoundation.kCFAllocatorDefault,
0,
ctypes.byref(CoreFoundation.kCFTypeArrayCallBacks),
)
if not cf_arr:
raise MemoryError("Unable to allocate memory!")
for item in lst:
cf_str = _cfstr(item)
if not cf_str:
raise MemoryError("Unable to allocate memory!")
try:
CoreFoundation.CFArrayAppendValue(cf_arr, cf_str)
finally:
CoreFoundation.CFRelease(cf_str)
except BaseException as e:
if cf_arr:
CoreFoundation.CFRelease(cf_arr)
raise ssl.SSLError("Unable to allocate array: %s" % (e,))
return cf_arr
def _cf_string_to_unicode(value): def _cf_string_to_unicode(value):
""" """
Creates a Unicode string from a CFString object. Used entirely for error Creates a Unicode string from a CFString object. Used entirely for error
@ -66,22 +111,18 @@ def _cf_string_to_unicode(value):
value_as_void_p = ctypes.cast(value, ctypes.POINTER(ctypes.c_void_p)) value_as_void_p = ctypes.cast(value, ctypes.POINTER(ctypes.c_void_p))
string = CoreFoundation.CFStringGetCStringPtr( string = CoreFoundation.CFStringGetCStringPtr(
value_as_void_p, value_as_void_p, CFConst.kCFStringEncodingUTF8
CFConst.kCFStringEncodingUTF8
) )
if string is None: if string is None:
buffer = ctypes.create_string_buffer(1024) buffer = ctypes.create_string_buffer(1024)
result = CoreFoundation.CFStringGetCString( result = CoreFoundation.CFStringGetCString(
value_as_void_p, value_as_void_p, buffer, 1024, CFConst.kCFStringEncodingUTF8
buffer,
1024,
CFConst.kCFStringEncodingUTF8
) )
if not result: if not result:
raise OSError('Error copying C string from CFStringRef') raise OSError("Error copying C string from CFStringRef")
string = buffer.value string = buffer.value
if string is not None: if string is not None:
string = string.decode('utf-8') string = string.decode("utf-8")
return string return string
@ -97,8 +138,8 @@ def _assert_no_error(error, exception_class=None):
output = _cf_string_to_unicode(cf_error_string) output = _cf_string_to_unicode(cf_error_string)
CoreFoundation.CFRelease(cf_error_string) CoreFoundation.CFRelease(cf_error_string)
if output is None or output == u'': if output is None or output == u"":
output = u'OSStatus %s' % error output = u"OSStatus %s" % error
if exception_class is None: if exception_class is None:
exception_class = ssl.SSLError exception_class = ssl.SSLError
@ -111,9 +152,11 @@ def _cert_array_from_pem(pem_bundle):
Given a bundle of certs in PEM format, turns them into a CFArray of certs Given a bundle of certs in PEM format, turns them into a CFArray of certs
that can be used to validate a cert chain. that can be used to validate a cert chain.
""" """
# Normalize the PEM bundle's line endings.
pem_bundle = pem_bundle.replace(b"\r\n", b"\n")
der_certs = [ der_certs = [
base64.b64decode(match.group(1)) base64.b64decode(match.group(1)) for match in _PEM_CERTS_RE.finditer(pem_bundle)
for match in _PEM_CERTS_RE.finditer(pem_bundle)
] ]
if not der_certs: if not der_certs:
raise ssl.SSLError("No root certificates specified") raise ssl.SSLError("No root certificates specified")
@ -121,7 +164,7 @@ def _cert_array_from_pem(pem_bundle):
cert_array = CoreFoundation.CFArrayCreateMutable( cert_array = CoreFoundation.CFArrayCreateMutable(
CoreFoundation.kCFAllocatorDefault, CoreFoundation.kCFAllocatorDefault,
0, 0,
ctypes.byref(CoreFoundation.kCFTypeArrayCallBacks) ctypes.byref(CoreFoundation.kCFTypeArrayCallBacks),
) )
if not cert_array: if not cert_array:
raise ssl.SSLError("Unable to allocate memory!") raise ssl.SSLError("Unable to allocate memory!")
@ -145,6 +188,7 @@ def _cert_array_from_pem(pem_bundle):
# We only want to do that if an error occurs: otherwise, the caller # We only want to do that if an error occurs: otherwise, the caller
# should free. # should free.
CoreFoundation.CFRelease(cert_array) CoreFoundation.CFRelease(cert_array)
raise
return cert_array return cert_array
@ -183,21 +227,16 @@ def _temporary_keychain():
# some random bytes to password-protect the keychain we're creating, so we # some random bytes to password-protect the keychain we're creating, so we
# ask for 40 random bytes. # ask for 40 random bytes.
random_bytes = os.urandom(40) random_bytes = os.urandom(40)
filename = base64.b64encode(random_bytes[:8]).decode('utf-8') filename = base64.b16encode(random_bytes[:8]).decode("utf-8")
password = base64.b64encode(random_bytes[8:]) # Must be valid UTF-8 password = base64.b16encode(random_bytes[8:]) # Must be valid UTF-8
tempdirectory = tempfile.mkdtemp() tempdirectory = tempfile.mkdtemp()
keychain_path = os.path.join(tempdirectory, filename).encode('utf-8') keychain_path = os.path.join(tempdirectory, filename).encode("utf-8")
# We now want to create the keychain itself. # We now want to create the keychain itself.
keychain = Security.SecKeychainRef() keychain = Security.SecKeychainRef()
status = Security.SecKeychainCreate( status = Security.SecKeychainCreate(
keychain_path, keychain_path, len(password), password, False, None, ctypes.byref(keychain)
len(password),
password,
False,
None,
ctypes.byref(keychain)
) )
_assert_no_error(status) _assert_no_error(status)
@ -216,14 +255,12 @@ def _load_items_from_file(keychain, path):
identities = [] identities = []
result_array = None result_array = None
with open(path, 'rb') as f: with open(path, "rb") as f:
raw_filedata = f.read() raw_filedata = f.read()
try: try:
filedata = CoreFoundation.CFDataCreate( filedata = CoreFoundation.CFDataCreate(
CoreFoundation.kCFAllocatorDefault, CoreFoundation.kCFAllocatorDefault, raw_filedata, len(raw_filedata)
raw_filedata,
len(raw_filedata)
) )
result_array = CoreFoundation.CFArrayRef() result_array = CoreFoundation.CFArrayRef()
result = Security.SecItemImport( result = Security.SecItemImport(
@ -234,7 +271,7 @@ def _load_items_from_file(keychain, path):
0, # import flags 0, # import flags
None, # key params, can include passphrase in the future None, # key params, can include passphrase in the future
keychain, # The keychain to insert into keychain, # The keychain to insert into
ctypes.byref(result_array) # Results ctypes.byref(result_array), # Results
) )
_assert_no_error(result) _assert_no_error(result)
@ -244,9 +281,7 @@ def _load_items_from_file(keychain, path):
# keychain already has them! # keychain already has them!
result_count = CoreFoundation.CFArrayGetCount(result_array) result_count = CoreFoundation.CFArrayGetCount(result_array)
for index in range(result_count): for index in range(result_count):
item = CoreFoundation.CFArrayGetValueAtIndex( item = CoreFoundation.CFArrayGetValueAtIndex(result_array, index)
result_array, index
)
item = ctypes.cast(item, CoreFoundation.CFTypeRef) item = ctypes.cast(item, CoreFoundation.CFTypeRef)
if _is_cert(item): if _is_cert(item):
@ -304,9 +339,7 @@ def _load_client_cert_chain(keychain, *paths):
try: try:
for file_path in paths: for file_path in paths:
new_identities, new_certs = _load_items_from_file( new_identities, new_certs = _load_items_from_file(keychain, file_path)
keychain, file_path
)
identities.extend(new_identities) identities.extend(new_identities)
certificates.extend(new_certs) certificates.extend(new_certs)
@ -315,9 +348,7 @@ def _load_client_cert_chain(keychain, *paths):
if not identities: if not identities:
new_identity = Security.SecIdentityRef() new_identity = Security.SecIdentityRef()
status = Security.SecIdentityCreateWithCertificate( status = Security.SecIdentityCreateWithCertificate(
keychain, keychain, certificates[0], ctypes.byref(new_identity)
certificates[0],
ctypes.byref(new_identity)
) )
_assert_no_error(status) _assert_no_error(status)
identities.append(new_identity) identities.append(new_identity)
@ -341,3 +372,26 @@ def _load_client_cert_chain(keychain, *paths):
finally: finally:
for obj in itertools.chain(identities, certificates): for obj in itertools.chain(identities, certificates):
CoreFoundation.CFRelease(obj) CoreFoundation.CFRelease(obj)
TLS_PROTOCOL_VERSIONS = {
"SSLv2": (0, 2),
"SSLv3": (3, 0),
"TLSv1": (3, 1),
"TLSv1.1": (3, 2),
"TLSv1.2": (3, 3),
}
def _build_tls_unknown_ca_alert(version):
"""
Builds a TLS alert record for an unknown CA.
"""
ver_maj, ver_min = TLS_PROTOCOL_VERSIONS[version]
severity_fatal = 0x02
description_unknown_ca = 0x30
msg = struct.pack(">BB", severity_fatal, description_unknown_ca)
msg_len = len(msg)
record_type_alert = 0x15
record = struct.pack(">BBBH", record_type_alert, ver_maj, ver_min, msg_len) + msg
return record

View file

@ -39,25 +39,25 @@ urllib3 on Google App Engine:
""" """
from __future__ import absolute_import from __future__ import absolute_import
import io
import logging import logging
import os
import warnings import warnings
from ..packages.six.moves.urllib.parse import urljoin
from ..exceptions import ( from ..exceptions import (
HTTPError, HTTPError,
HTTPWarning, HTTPWarning,
MaxRetryError, MaxRetryError,
ProtocolError, ProtocolError,
SSLError,
TimeoutError, TimeoutError,
SSLError
) )
from ..packages.six.moves.urllib.parse import urljoin
from ..packages.six import BytesIO
from ..request import RequestMethods from ..request import RequestMethods
from ..response import HTTPResponse from ..response import HTTPResponse
from ..util.timeout import Timeout
from ..util.retry import Retry from ..util.retry import Retry
from ..util.timeout import Timeout
from . import _appengine_environ
try: try:
from google.appengine.api import urlfetch from google.appengine.api import urlfetch
@ -90,29 +90,30 @@ class AppEngineManager(RequestMethods):
* If you attempt to use this on App Engine Flexible, as full socket * If you attempt to use this on App Engine Flexible, as full socket
support is available. support is available.
* If a request size is more than 10 megabytes. * If a request size is more than 10 megabytes.
* If a response size is more than 32 megabtyes. * If a response size is more than 32 megabytes.
* If you use an unsupported request method such as OPTIONS. * If you use an unsupported request method such as OPTIONS.
Beyond those cases, it will raise normal urllib3 errors. Beyond those cases, it will raise normal urllib3 errors.
""" """
def __init__(self, headers=None, retries=None, validate_certificate=True, def __init__(
urlfetch_retries=True): self,
headers=None,
retries=None,
validate_certificate=True,
urlfetch_retries=True,
):
if not urlfetch: if not urlfetch:
raise AppEnginePlatformError( raise AppEnginePlatformError(
"URLFetch is not available in this environment.") "URLFetch is not available in this environment."
)
if is_prod_appengine_mvms():
raise AppEnginePlatformError(
"Use normal urllib3.PoolManager instead of AppEngineManager"
"on Managed VMs, as using URLFetch is not necessary in "
"this environment.")
warnings.warn( warnings.warn(
"urllib3 is using URLFetch on Google App Engine sandbox instead " "urllib3 is using URLFetch on Google App Engine sandbox instead "
"of sockets. To use sockets directly instead of URLFetch see " "of sockets. To use sockets directly instead of URLFetch see "
"https://urllib3.readthedocs.io/en/latest/reference/urllib3.contrib.html.", "https://urllib3.readthedocs.io/en/1.26.x/reference/urllib3.contrib.html.",
AppEnginePlatformWarning) AppEnginePlatformWarning,
)
RequestMethods.__init__(self, headers) RequestMethods.__init__(self, headers)
self.validate_certificate = validate_certificate self.validate_certificate = validate_certificate
@ -127,17 +128,22 @@ class AppEngineManager(RequestMethods):
# Return False to re-raise any potential exceptions # Return False to re-raise any potential exceptions
return False return False
def urlopen(self, method, url, body=None, headers=None, def urlopen(
retries=None, redirect=True, timeout=Timeout.DEFAULT_TIMEOUT, self,
**response_kw): method,
url,
body=None,
headers=None,
retries=None,
redirect=True,
timeout=Timeout.DEFAULT_TIMEOUT,
**response_kw
):
retries = self._get_retries(retries, redirect) retries = self._get_retries(retries, redirect)
try: try:
follow_redirects = ( follow_redirects = redirect and retries.redirect != 0 and retries.total
redirect and
retries.redirect != 0 and
retries.total)
response = urlfetch.fetch( response = urlfetch.fetch(
url, url,
payload=body, payload=body,
@ -152,44 +158,52 @@ class AppEngineManager(RequestMethods):
raise TimeoutError(self, e) raise TimeoutError(self, e)
except urlfetch.InvalidURLError as e: except urlfetch.InvalidURLError as e:
if 'too large' in str(e): if "too large" in str(e):
raise AppEnginePlatformError( raise AppEnginePlatformError(
"URLFetch request too large, URLFetch only " "URLFetch request too large, URLFetch only "
"supports requests up to 10mb in size.", e) "supports requests up to 10mb in size.",
e,
)
raise ProtocolError(e) raise ProtocolError(e)
except urlfetch.DownloadError as e: except urlfetch.DownloadError as e:
if 'Too many redirects' in str(e): if "Too many redirects" in str(e):
raise MaxRetryError(self, url, reason=e) raise MaxRetryError(self, url, reason=e)
raise ProtocolError(e) raise ProtocolError(e)
except urlfetch.ResponseTooLargeError as e: except urlfetch.ResponseTooLargeError as e:
raise AppEnginePlatformError( raise AppEnginePlatformError(
"URLFetch response too large, URLFetch only supports" "URLFetch response too large, URLFetch only supports"
"responses up to 32mb in size.", e) "responses up to 32mb in size.",
e,
)
except urlfetch.SSLCertificateError as e: except urlfetch.SSLCertificateError as e:
raise SSLError(e) raise SSLError(e)
except urlfetch.InvalidMethodError as e: except urlfetch.InvalidMethodError as e:
raise AppEnginePlatformError( raise AppEnginePlatformError(
"URLFetch does not support method: %s" % method, e) "URLFetch does not support method: %s" % method, e
)
http_response = self._urlfetch_response_to_http_response( http_response = self._urlfetch_response_to_http_response(
response, retries=retries, **response_kw) response, retries=retries, **response_kw
)
# Handle redirect? # Handle redirect?
redirect_location = redirect and http_response.get_redirect_location() redirect_location = redirect and http_response.get_redirect_location()
if redirect_location: if redirect_location:
# Check for redirect response # Check for redirect response
if (self.urlfetch_retries and retries.raise_on_redirect): if self.urlfetch_retries and retries.raise_on_redirect:
raise MaxRetryError(self, url, "too many redirects") raise MaxRetryError(self, url, "too many redirects")
else: else:
if http_response.status == 303: if http_response.status == 303:
method = 'GET' method = "GET"
try: try:
retries = retries.increment(method, url, response=http_response, _pool=self) retries = retries.increment(
method, url, response=http_response, _pool=self
)
except MaxRetryError: except MaxRetryError:
if retries.raise_on_redirect: if retries.raise_on_redirect:
raise MaxRetryError(self, url, "too many redirects") raise MaxRetryError(self, url, "too many redirects")
@ -199,22 +213,32 @@ class AppEngineManager(RequestMethods):
log.debug("Redirecting %s -> %s", url, redirect_location) log.debug("Redirecting %s -> %s", url, redirect_location)
redirect_url = urljoin(url, redirect_location) redirect_url = urljoin(url, redirect_location)
return self.urlopen( return self.urlopen(
method, redirect_url, body, headers, method,
retries=retries, redirect=redirect, redirect_url,
timeout=timeout, **response_kw) body,
headers,
retries=retries,
redirect=redirect,
timeout=timeout,
**response_kw
)
# Check if we should retry the HTTP response. # Check if we should retry the HTTP response.
has_retry_after = bool(http_response.getheader('Retry-After')) has_retry_after = bool(http_response.getheader("Retry-After"))
if retries.is_retry(method, http_response.status, has_retry_after): if retries.is_retry(method, http_response.status, has_retry_after):
retries = retries.increment( retries = retries.increment(method, url, response=http_response, _pool=self)
method, url, response=http_response, _pool=self)
log.debug("Retry: %s", url) log.debug("Retry: %s", url)
retries.sleep(http_response) retries.sleep(http_response)
return self.urlopen( return self.urlopen(
method, url, method,
body=body, headers=headers, url,
retries=retries, redirect=redirect, body=body,
timeout=timeout, **response_kw) headers=headers,
retries=retries,
redirect=redirect,
timeout=timeout,
**response_kw
)
return http_response return http_response
@ -223,28 +247,37 @@ class AppEngineManager(RequestMethods):
if is_prod_appengine(): if is_prod_appengine():
# Production GAE handles deflate encoding automatically, but does # Production GAE handles deflate encoding automatically, but does
# not remove the encoding header. # not remove the encoding header.
content_encoding = urlfetch_resp.headers.get('content-encoding') content_encoding = urlfetch_resp.headers.get("content-encoding")
if content_encoding == 'deflate': if content_encoding == "deflate":
del urlfetch_resp.headers['content-encoding'] del urlfetch_resp.headers["content-encoding"]
transfer_encoding = urlfetch_resp.headers.get('transfer-encoding') transfer_encoding = urlfetch_resp.headers.get("transfer-encoding")
# We have a full response's content, # We have a full response's content,
# so let's make sure we don't report ourselves as chunked data. # so let's make sure we don't report ourselves as chunked data.
if transfer_encoding == 'chunked': if transfer_encoding == "chunked":
encodings = transfer_encoding.split(",") encodings = transfer_encoding.split(",")
encodings.remove('chunked') encodings.remove("chunked")
urlfetch_resp.headers['transfer-encoding'] = ','.join(encodings) urlfetch_resp.headers["transfer-encoding"] = ",".join(encodings)
return HTTPResponse( original_response = HTTPResponse(
# In order for decoding to work, we must present the content as # In order for decoding to work, we must present the content as
# a file-like object. # a file-like object.
body=BytesIO(urlfetch_resp.content), body=io.BytesIO(urlfetch_resp.content),
msg=urlfetch_resp.header_msg,
headers=urlfetch_resp.headers, headers=urlfetch_resp.headers,
status=urlfetch_resp.status_code, status=urlfetch_resp.status_code,
**response_kw **response_kw
) )
return HTTPResponse(
body=io.BytesIO(urlfetch_resp.content),
headers=urlfetch_resp.headers,
status=urlfetch_resp.status_code,
original_response=original_response,
**response_kw
)
def _get_absolute_timeout(self, timeout): def _get_absolute_timeout(self, timeout):
if timeout is Timeout.DEFAULT_TIMEOUT: if timeout is Timeout.DEFAULT_TIMEOUT:
return None # Defer to URLFetch's default. return None # Defer to URLFetch's default.
@ -253,44 +286,29 @@ class AppEngineManager(RequestMethods):
warnings.warn( warnings.warn(
"URLFetch does not support granular timeout settings, " "URLFetch does not support granular timeout settings, "
"reverting to total or default URLFetch timeout.", "reverting to total or default URLFetch timeout.",
AppEnginePlatformWarning) AppEnginePlatformWarning,
)
return timeout.total return timeout.total
return timeout return timeout
def _get_retries(self, retries, redirect): def _get_retries(self, retries, redirect):
if not isinstance(retries, Retry): if not isinstance(retries, Retry):
retries = Retry.from_int( retries = Retry.from_int(retries, redirect=redirect, default=self.retries)
retries, redirect=redirect, default=self.retries)
if retries.connect or retries.read or retries.redirect: if retries.connect or retries.read or retries.redirect:
warnings.warn( warnings.warn(
"URLFetch only supports total retries and does not " "URLFetch only supports total retries and does not "
"recognize connect, read, or redirect retry parameters.", "recognize connect, read, or redirect retry parameters.",
AppEnginePlatformWarning) AppEnginePlatformWarning,
)
return retries return retries
def is_appengine(): # Alias methods from _appengine_environ to maintain public API interface.
return (is_local_appengine() or
is_prod_appengine() or
is_prod_appengine_mvms())
is_appengine = _appengine_environ.is_appengine
def is_appengine_sandbox(): is_appengine_sandbox = _appengine_environ.is_appengine_sandbox
return is_appengine() and not is_prod_appengine_mvms() is_local_appengine = _appengine_environ.is_local_appengine
is_prod_appengine = _appengine_environ.is_prod_appengine
is_prod_appengine_mvms = _appengine_environ.is_prod_appengine_mvms
def is_local_appengine():
return ('APPENGINE_RUNTIME' in os.environ and
'Development/' in os.environ['SERVER_SOFTWARE'])
def is_prod_appengine():
return ('APPENGINE_RUNTIME' in os.environ and
'Google App Engine/' in os.environ['SERVER_SOFTWARE'] and
not is_prod_appengine_mvms())
def is_prod_appengine_mvms():
return os.environ.get('GAE_VM', False) == 'true'

View file

@ -5,12 +5,21 @@ Issue #10, see: http://code.google.com/p/urllib3/issues/detail?id=10
""" """
from __future__ import absolute_import from __future__ import absolute_import
import warnings
from logging import getLogger from logging import getLogger
from ntlm import ntlm from ntlm import ntlm
from .. import HTTPSConnectionPool from .. import HTTPSConnectionPool
from ..packages.six.moves.http_client import HTTPSConnection from ..packages.six.moves.http_client import HTTPSConnection
warnings.warn(
"The 'urllib3.contrib.ntlmpool' module is deprecated and will be removed "
"in urllib3 v2.0 release, urllib3 is not able to support it properly due "
"to reasons listed in issue: https://github.com/urllib3/urllib3/issues/2282. "
"If you are a user of this module please comment in the mentioned issue.",
DeprecationWarning,
)
log = getLogger(__name__) log = getLogger(__name__)
@ -20,7 +29,7 @@ class NTLMConnectionPool(HTTPSConnectionPool):
Implements an NTLM authentication version of an urllib3 connection pool Implements an NTLM authentication version of an urllib3 connection pool
""" """
scheme = 'https' scheme = "https"
def __init__(self, user, pw, authurl, *args, **kwargs): def __init__(self, user, pw, authurl, *args, **kwargs):
""" """
@ -31,7 +40,7 @@ class NTLMConnectionPool(HTTPSConnectionPool):
super(NTLMConnectionPool, self).__init__(*args, **kwargs) super(NTLMConnectionPool, self).__init__(*args, **kwargs)
self.authurl = authurl self.authurl = authurl
self.rawuser = user self.rawuser = user
user_parts = user.split('\\', 1) user_parts = user.split("\\", 1)
self.domain = user_parts[0].upper() self.domain = user_parts[0].upper()
self.user = user_parts[1] self.user = user_parts[1]
self.pw = pw self.pw = pw
@ -40,73 +49,82 @@ class NTLMConnectionPool(HTTPSConnectionPool):
# Performs the NTLM handshake that secures the connection. The socket # Performs the NTLM handshake that secures the connection. The socket
# must be kept open while requests are performed. # must be kept open while requests are performed.
self.num_connections += 1 self.num_connections += 1
log.debug('Starting NTLM HTTPS connection no. %d: https://%s%s', log.debug(
self.num_connections, self.host, self.authurl) "Starting NTLM HTTPS connection no. %d: https://%s%s",
self.num_connections,
self.host,
self.authurl,
)
headers = {} headers = {"Connection": "Keep-Alive"}
headers['Connection'] = 'Keep-Alive' req_header = "Authorization"
req_header = 'Authorization' resp_header = "www-authenticate"
resp_header = 'www-authenticate'
conn = HTTPSConnection(host=self.host, port=self.port) conn = HTTPSConnection(host=self.host, port=self.port)
# Send negotiation message # Send negotiation message
headers[req_header] = ( headers[req_header] = "NTLM %s" % ntlm.create_NTLM_NEGOTIATE_MESSAGE(
'NTLM %s' % ntlm.create_NTLM_NEGOTIATE_MESSAGE(self.rawuser)) self.rawuser
log.debug('Request headers: %s', headers) )
conn.request('GET', self.authurl, None, headers) log.debug("Request headers: %s", headers)
conn.request("GET", self.authurl, None, headers)
res = conn.getresponse() res = conn.getresponse()
reshdr = dict(res.getheaders()) reshdr = dict(res.getheaders())
log.debug('Response status: %s %s', res.status, res.reason) log.debug("Response status: %s %s", res.status, res.reason)
log.debug('Response headers: %s', reshdr) log.debug("Response headers: %s", reshdr)
log.debug('Response data: %s [...]', res.read(100)) log.debug("Response data: %s [...]", res.read(100))
# Remove the reference to the socket, so that it can not be closed by # Remove the reference to the socket, so that it can not be closed by
# the response object (we want to keep the socket open) # the response object (we want to keep the socket open)
res.fp = None res.fp = None
# Server should respond with a challenge message # Server should respond with a challenge message
auth_header_values = reshdr[resp_header].split(', ') auth_header_values = reshdr[resp_header].split(", ")
auth_header_value = None auth_header_value = None
for s in auth_header_values: for s in auth_header_values:
if s[:5] == 'NTLM ': if s[:5] == "NTLM ":
auth_header_value = s[5:] auth_header_value = s[5:]
if auth_header_value is None: if auth_header_value is None:
raise Exception('Unexpected %s response header: %s' % raise Exception(
(resp_header, reshdr[resp_header])) "Unexpected %s response header: %s" % (resp_header, reshdr[resp_header])
)
# Send authentication message # Send authentication message
ServerChallenge, NegotiateFlags = \ ServerChallenge, NegotiateFlags = ntlm.parse_NTLM_CHALLENGE_MESSAGE(
ntlm.parse_NTLM_CHALLENGE_MESSAGE(auth_header_value) auth_header_value
auth_msg = ntlm.create_NTLM_AUTHENTICATE_MESSAGE(ServerChallenge, )
self.user, auth_msg = ntlm.create_NTLM_AUTHENTICATE_MESSAGE(
self.domain, ServerChallenge, self.user, self.domain, self.pw, NegotiateFlags
self.pw, )
NegotiateFlags) headers[req_header] = "NTLM %s" % auth_msg
headers[req_header] = 'NTLM %s' % auth_msg log.debug("Request headers: %s", headers)
log.debug('Request headers: %s', headers) conn.request("GET", self.authurl, None, headers)
conn.request('GET', self.authurl, None, headers)
res = conn.getresponse() res = conn.getresponse()
log.debug('Response status: %s %s', res.status, res.reason) log.debug("Response status: %s %s", res.status, res.reason)
log.debug('Response headers: %s', dict(res.getheaders())) log.debug("Response headers: %s", dict(res.getheaders()))
log.debug('Response data: %s [...]', res.read()[:100]) log.debug("Response data: %s [...]", res.read()[:100])
if res.status != 200: if res.status != 200:
if res.status == 401: if res.status == 401:
raise Exception('Server rejected request: wrong ' raise Exception("Server rejected request: wrong username or password")
'username or password') raise Exception("Wrong server response: %s %s" % (res.status, res.reason))
raise Exception('Wrong server response: %s %s' %
(res.status, res.reason))
res.fp = None res.fp = None
log.debug('Connection established') log.debug("Connection established")
return conn return conn
def urlopen(self, method, url, body=None, headers=None, retries=3, def urlopen(
redirect=True, assert_same_host=True): self,
method,
url,
body=None,
headers=None,
retries=3,
redirect=True,
assert_same_host=True,
):
if headers is None: if headers is None:
headers = {} headers = {}
headers['Connection'] = 'Keep-Alive' headers["Connection"] = "Keep-Alive"
return super(NTLMConnectionPool, self).urlopen(method, url, body, return super(NTLMConnectionPool, self).urlopen(
headers, retries, method, url, body, headers, retries, redirect, assert_same_host
redirect, )
assert_same_host)

View file

@ -1,27 +1,31 @@
""" """
SSL with SNI_-support for Python 2. Follow these instructions if you would TLS with SNI_-support for Python 2. Follow these instructions if you would
like to verify SSL certificates in Python 2. Note, the default libraries do like to verify TLS certificates in Python 2. Note, the default libraries do
*not* do certificate checking; you need to do additional work to validate *not* do certificate checking; you need to do additional work to validate
certificates yourself. certificates yourself.
This needs the following packages installed: This needs the following packages installed:
* pyOpenSSL (tested with 16.0.0) * `pyOpenSSL`_ (tested with 16.0.0)
* cryptography (minimum 1.3.4, from pyopenssl) * `cryptography`_ (minimum 1.3.4, from pyopenssl)
* idna (minimum 2.0, from cryptography) * `idna`_ (minimum 2.0, from cryptography)
However, pyopenssl depends on cryptography, which depends on idna, so while we However, pyopenssl depends on cryptography, which depends on idna, so while we
use all three directly here we end up having relatively few packages required. use all three directly here we end up having relatively few packages required.
You can install them with the following command: You can install them with the following command:
pip install pyopenssl cryptography idna .. code-block:: bash
$ python -m pip install pyopenssl cryptography idna
To activate certificate checking, call To activate certificate checking, call
:func:`~urllib3.contrib.pyopenssl.inject_into_urllib3` from your Python code :func:`~urllib3.contrib.pyopenssl.inject_into_urllib3` from your Python code
before you begin making HTTP requests. This can be done in a ``sitecustomize`` before you begin making HTTP requests. This can be done in a ``sitecustomize``
module, or at any other time before your application begins using ``urllib3``, module, or at any other time before your application begins using ``urllib3``,
like this:: like this:
.. code-block:: python
try: try:
import urllib3.contrib.pyopenssl import urllib3.contrib.pyopenssl
@ -35,11 +39,11 @@ when the required modules are installed.
Activating this module also has the positive side effect of disabling SSL/TLS Activating this module also has the positive side effect of disabling SSL/TLS
compression in Python 2 (see `CRIME attack`_). compression in Python 2 (see `CRIME attack`_).
If you want to configure the default list of supported cipher suites, you can
set the ``urllib3.contrib.pyopenssl.DEFAULT_SSL_CIPHER_LIST`` variable.
.. _sni: https://en.wikipedia.org/wiki/Server_Name_Indication .. _sni: https://en.wikipedia.org/wiki/Server_Name_Indication
.. _crime attack: https://en.wikipedia.org/wiki/CRIME_(security_exploit) .. _crime attack: https://en.wikipedia.org/wiki/CRIME_(security_exploit)
.. _pyopenssl: https://www.pyopenssl.org
.. _cryptography: https://cryptography.io
.. _idna: https://github.com/kjd/idna
""" """
from __future__ import absolute_import from __future__ import absolute_import
@ -48,8 +52,17 @@ from cryptography import x509
from cryptography.hazmat.backends.openssl import backend as openssl_backend from cryptography.hazmat.backends.openssl import backend as openssl_backend
from cryptography.hazmat.backends.openssl.x509 import _Certificate from cryptography.hazmat.backends.openssl.x509 import _Certificate
from socket import timeout, error as SocketError try:
from cryptography.x509 import UnsupportedExtension
except ImportError:
# UnsupportedExtension is gone in cryptography >= 2.1.0
class UnsupportedExtension(Exception):
pass
from io import BytesIO from io import BytesIO
from socket import error as SocketError
from socket import timeout
try: # Platform-specific: Python 2 try: # Platform-specific: Python 2
from socket import _fileobject from socket import _fileobject
@ -59,42 +72,41 @@ except ImportError: # Platform-specific: Python 3
import logging import logging
import ssl import ssl
from ..packages import six
import sys import sys
from .. import util from .. import util
from ..packages import six
from ..util.ssl_ import PROTOCOL_TLS_CLIENT
__all__ = ['inject_into_urllib3', 'extract_from_urllib3'] __all__ = ["inject_into_urllib3", "extract_from_urllib3"]
# SNI always works. # SNI always works.
HAS_SNI = True HAS_SNI = True
# Map from urllib3 to PyOpenSSL compatible parameter-values. # Map from urllib3 to PyOpenSSL compatible parameter-values.
_openssl_versions = { _openssl_versions = {
ssl.PROTOCOL_SSLv23: OpenSSL.SSL.SSLv23_METHOD, util.PROTOCOL_TLS: OpenSSL.SSL.SSLv23_METHOD,
PROTOCOL_TLS_CLIENT: OpenSSL.SSL.SSLv23_METHOD,
ssl.PROTOCOL_TLSv1: OpenSSL.SSL.TLSv1_METHOD, ssl.PROTOCOL_TLSv1: OpenSSL.SSL.TLSv1_METHOD,
} }
if hasattr(ssl, 'PROTOCOL_TLSv1_1') and hasattr(OpenSSL.SSL, 'TLSv1_1_METHOD'): if hasattr(ssl, "PROTOCOL_SSLv3") and hasattr(OpenSSL.SSL, "SSLv3_METHOD"):
_openssl_versions[ssl.PROTOCOL_SSLv3] = OpenSSL.SSL.SSLv3_METHOD
if hasattr(ssl, "PROTOCOL_TLSv1_1") and hasattr(OpenSSL.SSL, "TLSv1_1_METHOD"):
_openssl_versions[ssl.PROTOCOL_TLSv1_1] = OpenSSL.SSL.TLSv1_1_METHOD _openssl_versions[ssl.PROTOCOL_TLSv1_1] = OpenSSL.SSL.TLSv1_1_METHOD
if hasattr(ssl, 'PROTOCOL_TLSv1_2') and hasattr(OpenSSL.SSL, 'TLSv1_2_METHOD'): if hasattr(ssl, "PROTOCOL_TLSv1_2") and hasattr(OpenSSL.SSL, "TLSv1_2_METHOD"):
_openssl_versions[ssl.PROTOCOL_TLSv1_2] = OpenSSL.SSL.TLSv1_2_METHOD _openssl_versions[ssl.PROTOCOL_TLSv1_2] = OpenSSL.SSL.TLSv1_2_METHOD
try:
_openssl_versions.update({ssl.PROTOCOL_SSLv3: OpenSSL.SSL.SSLv3_METHOD})
except AttributeError:
pass
_stdlib_to_openssl_verify = { _stdlib_to_openssl_verify = {
ssl.CERT_NONE: OpenSSL.SSL.VERIFY_NONE, ssl.CERT_NONE: OpenSSL.SSL.VERIFY_NONE,
ssl.CERT_OPTIONAL: OpenSSL.SSL.VERIFY_PEER, ssl.CERT_OPTIONAL: OpenSSL.SSL.VERIFY_PEER,
ssl.CERT_REQUIRED: ssl.CERT_REQUIRED: OpenSSL.SSL.VERIFY_PEER
OpenSSL.SSL.VERIFY_PEER + OpenSSL.SSL.VERIFY_FAIL_IF_NO_PEER_CERT, + OpenSSL.SSL.VERIFY_FAIL_IF_NO_PEER_CERT,
} }
_openssl_to_stdlib_verify = dict( _openssl_to_stdlib_verify = dict((v, k) for k, v in _stdlib_to_openssl_verify.items())
(v, k) for k, v in _stdlib_to_openssl_verify.items()
)
# OpenSSL will only write 16K at a time # OpenSSL will only write 16K at a time
SSL_WRITE_BLOCKSIZE = 16384 SSL_WRITE_BLOCKSIZE = 16384
@ -107,10 +119,11 @@ log = logging.getLogger(__name__)
def inject_into_urllib3(): def inject_into_urllib3():
'Monkey-patch urllib3 with PyOpenSSL-backed SSL-support.' "Monkey-patch urllib3 with PyOpenSSL-backed SSL-support."
_validate_dependencies_met() _validate_dependencies_met()
util.SSLContext = PyOpenSSLContext
util.ssl_.SSLContext = PyOpenSSLContext util.ssl_.SSLContext = PyOpenSSLContext
util.HAS_SNI = HAS_SNI util.HAS_SNI = HAS_SNI
util.ssl_.HAS_SNI = HAS_SNI util.ssl_.HAS_SNI = HAS_SNI
@ -119,8 +132,9 @@ def inject_into_urllib3():
def extract_from_urllib3(): def extract_from_urllib3():
'Undo monkey-patching by :func:`inject_into_urllib3`.' "Undo monkey-patching by :func:`inject_into_urllib3`."
util.SSLContext = orig_util_SSLContext
util.ssl_.SSLContext = orig_util_SSLContext util.ssl_.SSLContext = orig_util_SSLContext
util.HAS_SNI = orig_util_HAS_SNI util.HAS_SNI = orig_util_HAS_SNI
util.ssl_.HAS_SNI = orig_util_HAS_SNI util.ssl_.HAS_SNI = orig_util_HAS_SNI
@ -135,17 +149,23 @@ def _validate_dependencies_met():
""" """
# Method added in `cryptography==1.1`; not available in older versions # Method added in `cryptography==1.1`; not available in older versions
from cryptography.x509.extensions import Extensions from cryptography.x509.extensions import Extensions
if getattr(Extensions, "get_extension_for_class", None) is None: if getattr(Extensions, "get_extension_for_class", None) is None:
raise ImportError("'cryptography' module missing required functionality. " raise ImportError(
"Try upgrading to v1.3.4 or newer.") "'cryptography' module missing required functionality. "
"Try upgrading to v1.3.4 or newer."
)
# pyOpenSSL 0.14 and above use cryptography for OpenSSL bindings. The _x509 # pyOpenSSL 0.14 and above use cryptography for OpenSSL bindings. The _x509
# attribute is only present on those versions. # attribute is only present on those versions.
from OpenSSL.crypto import X509 from OpenSSL.crypto import X509
x509 = X509() x509 = X509()
if getattr(x509, "_x509", None) is None: if getattr(x509, "_x509", None) is None:
raise ImportError("'pyOpenSSL' module missing required functionality. " raise ImportError(
"Try upgrading to v0.14 or newer.") "'pyOpenSSL' module missing required functionality. "
"Try upgrading to v0.14 or newer."
)
def _dnsname_to_stdlib(name): def _dnsname_to_stdlib(name):
@ -157,7 +177,11 @@ def _dnsname_to_stdlib(name):
from ASCII bytes. We need to idna-encode that string to get it back, and from ASCII bytes. We need to idna-encode that string to get it back, and
then on Python 3 we also need to convert to unicode via UTF-8 (the stdlib then on Python 3 we also need to convert to unicode via UTF-8 (the stdlib
uses PyUnicode_FromStringAndSize on it, which decodes via UTF-8). uses PyUnicode_FromStringAndSize on it, which decodes via UTF-8).
If the name cannot be idna-encoded then we return None signalling that
the name given should be skipped.
""" """
def idna_encode(name): def idna_encode(name):
""" """
Borrowed wholesale from the Python Cryptography Project. It turns out Borrowed wholesale from the Python Cryptography Project. It turns out
@ -166,15 +190,24 @@ def _dnsname_to_stdlib(name):
""" """
import idna import idna
for prefix in [u'*.', u'.']: try:
for prefix in [u"*.", u"."]:
if name.startswith(prefix): if name.startswith(prefix):
name = name[len(prefix) :] name = name[len(prefix) :]
return prefix.encode('ascii') + idna.encode(name) return prefix.encode("ascii") + idna.encode(name)
return idna.encode(name) return idna.encode(name)
except idna.core.IDNAError:
return None
# Don't send IPv6 addresses through the IDNA encoder.
if ":" in name:
return name
name = idna_encode(name) name = idna_encode(name)
if sys.version_info >= (3, 0): if name is None:
name = name.decode('utf-8') return None
elif sys.version_info >= (3, 0):
name = name.decode("utf-8")
return name return name
@ -193,14 +226,16 @@ def get_subj_alt_name(peer_cert):
# We want to find the SAN extension. Ask Cryptography to locate it (it's # We want to find the SAN extension. Ask Cryptography to locate it (it's
# faster than looping in Python) # faster than looping in Python)
try: try:
ext = cert.extensions.get_extension_for_class( ext = cert.extensions.get_extension_for_class(x509.SubjectAlternativeName).value
x509.SubjectAlternativeName
).value
except x509.ExtensionNotFound: except x509.ExtensionNotFound:
# No such extension, return the empty list. # No such extension, return the empty list.
return [] return []
except (x509.DuplicateExtension, x509.UnsupportedExtension, except (
x509.UnsupportedGeneralNameType, UnicodeError) as e: x509.DuplicateExtension,
UnsupportedExtension,
x509.UnsupportedGeneralNameType,
UnicodeError,
) as e:
# A problem has been found with the quality of the certificate. Assume # A problem has been found with the quality of the certificate. Assume
# no SAN field is present. # no SAN field is present.
log.warning( log.warning(
@ -217,24 +252,25 @@ def get_subj_alt_name(peer_cert):
# Sadly the DNS names need to be idna encoded and then, on Python 3, UTF-8 # Sadly the DNS names need to be idna encoded and then, on Python 3, UTF-8
# decoded. This is pretty frustrating, but that's what the standard library # decoded. This is pretty frustrating, but that's what the standard library
# does with certificates, and so we need to attempt to do the same. # does with certificates, and so we need to attempt to do the same.
# We also want to skip over names which cannot be idna encoded.
names = [ names = [
('DNS', _dnsname_to_stdlib(name)) ("DNS", name)
for name in ext.get_values_for_type(x509.DNSName) for name in map(_dnsname_to_stdlib, ext.get_values_for_type(x509.DNSName))
if name is not None
] ]
names.extend( names.extend(
('IP Address', str(name)) ("IP Address", str(name)) for name in ext.get_values_for_type(x509.IPAddress)
for name in ext.get_values_for_type(x509.IPAddress)
) )
return names return names
class WrappedSocket(object): class WrappedSocket(object):
'''API-compatibility wrapper for Python OpenSSL's Connection-class. """API-compatibility wrapper for Python OpenSSL's Connection-class.
Note: _makefile_refs, _drop() and _reuse() are needed for the garbage Note: _makefile_refs, _drop() and _reuse() are needed for the garbage
collector of pypy. collector of pypy.
''' """
def __init__(self, connection, socket, suppress_ragged_eofs=True): def __init__(self, connection, socket, suppress_ragged_eofs=True):
self.connection = connection self.connection = connection
@ -257,21 +293,24 @@ class WrappedSocket(object):
try: try:
data = self.connection.recv(*args, **kwargs) data = self.connection.recv(*args, **kwargs)
except OpenSSL.SSL.SysCallError as e: except OpenSSL.SSL.SysCallError as e:
if self.suppress_ragged_eofs and e.args == (-1, 'Unexpected EOF'): if self.suppress_ragged_eofs and e.args == (-1, "Unexpected EOF"):
return b'' return b""
else: else:
raise SocketError(str(e)) raise SocketError(str(e))
except OpenSSL.SSL.ZeroReturnError as e: except OpenSSL.SSL.ZeroReturnError:
if self.connection.get_shutdown() == OpenSSL.SSL.RECEIVED_SHUTDOWN: if self.connection.get_shutdown() == OpenSSL.SSL.RECEIVED_SHUTDOWN:
return b'' return b""
else: else:
raise raise
except OpenSSL.SSL.WantReadError: except OpenSSL.SSL.WantReadError:
rd = util.wait_for_read(self.socket, self.socket.gettimeout()) if not util.wait_for_read(self.socket, self.socket.gettimeout()):
if not rd: raise timeout("The read operation timed out")
raise timeout('The read operation timed out')
else: else:
return self.recv(*args, **kwargs) return self.recv(*args, **kwargs)
# TLS 1.3 post-handshake authentication
except OpenSSL.SSL.Error as e:
raise ssl.SSLError("read error: %r" % e)
else: else:
return data return data
@ -279,22 +318,25 @@ class WrappedSocket(object):
try: try:
return self.connection.recv_into(*args, **kwargs) return self.connection.recv_into(*args, **kwargs)
except OpenSSL.SSL.SysCallError as e: except OpenSSL.SSL.SysCallError as e:
if self.suppress_ragged_eofs and e.args == (-1, 'Unexpected EOF'): if self.suppress_ragged_eofs and e.args == (-1, "Unexpected EOF"):
return 0 return 0
else: else:
raise SocketError(str(e)) raise SocketError(str(e))
except OpenSSL.SSL.ZeroReturnError as e: except OpenSSL.SSL.ZeroReturnError:
if self.connection.get_shutdown() == OpenSSL.SSL.RECEIVED_SHUTDOWN: if self.connection.get_shutdown() == OpenSSL.SSL.RECEIVED_SHUTDOWN:
return 0 return 0
else: else:
raise raise
except OpenSSL.SSL.WantReadError: except OpenSSL.SSL.WantReadError:
rd = util.wait_for_read(self.socket, self.socket.gettimeout()) if not util.wait_for_read(self.socket, self.socket.gettimeout()):
if not rd: raise timeout("The read operation timed out")
raise timeout('The read operation timed out')
else: else:
return self.recv_into(*args, **kwargs) return self.recv_into(*args, **kwargs)
# TLS 1.3 post-handshake authentication
except OpenSSL.SSL.Error as e:
raise ssl.SSLError("read error: %r" % e)
def settimeout(self, timeout): def settimeout(self, timeout):
return self.socket.settimeout(timeout) return self.socket.settimeout(timeout)
@ -303,8 +345,7 @@ class WrappedSocket(object):
try: try:
return self.connection.send(data) return self.connection.send(data)
except OpenSSL.SSL.WantWriteError: except OpenSSL.SSL.WantWriteError:
wr = util.wait_for_write(self.socket, self.socket.gettimeout()) if not util.wait_for_write(self.socket, self.socket.gettimeout()):
if not wr:
raise timeout() raise timeout()
continue continue
except OpenSSL.SSL.SysCallError as e: except OpenSSL.SSL.SysCallError as e:
@ -313,7 +354,9 @@ class WrappedSocket(object):
def sendall(self, data): def sendall(self, data):
total_sent = 0 total_sent = 0
while total_sent < len(data): while total_sent < len(data):
sent = self._send_until_done(data[total_sent:total_sent + SSL_WRITE_BLOCKSIZE]) sent = self._send_until_done(
data[total_sent : total_sent + SSL_WRITE_BLOCKSIZE]
)
total_sent += sent total_sent += sent
def shutdown(self): def shutdown(self):
@ -337,17 +380,16 @@ class WrappedSocket(object):
return x509 return x509
if binary_form: if binary_form:
return OpenSSL.crypto.dump_certificate( return OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_ASN1, x509)
OpenSSL.crypto.FILETYPE_ASN1,
x509)
return { return {
'subject': ( "subject": ((("commonName", x509.get_subject().CN),),),
(('commonName', x509.get_subject().CN),), "subjectAltName": get_subj_alt_name(x509),
),
'subjectAltName': get_subj_alt_name(x509)
} }
def version(self):
return self.connection.get_protocol_version_name()
def _reuse(self): def _reuse(self):
self._makefile_refs += 1 self._makefile_refs += 1
@ -359,9 +401,12 @@ class WrappedSocket(object):
if _fileobject: # Platform-specific: Python 2 if _fileobject: # Platform-specific: Python 2
def makefile(self, mode, bufsize=-1): def makefile(self, mode, bufsize=-1):
self._makefile_refs += 1 self._makefile_refs += 1
return _fileobject(self, mode, bufsize, close=True) return _fileobject(self, mode, bufsize, close=True)
else: # Platform-specific: Python 3 else: # Platform-specific: Python 3
makefile = backport_makefile makefile = backport_makefile
@ -374,6 +419,7 @@ class PyOpenSSLContext(object):
for translating the interface of the standard library ``SSLContext`` object for translating the interface of the standard library ``SSLContext`` object
to calls into PyOpenSSL. to calls into PyOpenSSL.
""" """
def __init__(self, protocol): def __init__(self, protocol):
self.protocol = _openssl_versions[protocol] self.protocol = _openssl_versions[protocol]
self._ctx = OpenSSL.SSL.Context(self.protocol) self._ctx = OpenSSL.SSL.Context(self.protocol)
@ -395,41 +441,52 @@ class PyOpenSSLContext(object):
@verify_mode.setter @verify_mode.setter
def verify_mode(self, value): def verify_mode(self, value):
self._ctx.set_verify( self._ctx.set_verify(_stdlib_to_openssl_verify[value], _verify_callback)
_stdlib_to_openssl_verify[value],
_verify_callback
)
def set_default_verify_paths(self): def set_default_verify_paths(self):
self._ctx.set_default_verify_paths() self._ctx.set_default_verify_paths()
def set_ciphers(self, ciphers): def set_ciphers(self, ciphers):
if isinstance(ciphers, six.text_type): if isinstance(ciphers, six.text_type):
ciphers = ciphers.encode('utf-8') ciphers = ciphers.encode("utf-8")
self._ctx.set_cipher_list(ciphers) self._ctx.set_cipher_list(ciphers)
def load_verify_locations(self, cafile=None, capath=None, cadata=None): def load_verify_locations(self, cafile=None, capath=None, cadata=None):
if cafile is not None: if cafile is not None:
cafile = cafile.encode('utf-8') cafile = cafile.encode("utf-8")
if capath is not None: if capath is not None:
capath = capath.encode('utf-8') capath = capath.encode("utf-8")
try:
self._ctx.load_verify_locations(cafile, capath) self._ctx.load_verify_locations(cafile, capath)
if cadata is not None: if cadata is not None:
self._ctx.load_verify_locations(BytesIO(cadata)) self._ctx.load_verify_locations(BytesIO(cadata))
except OpenSSL.SSL.Error as e:
raise ssl.SSLError("unable to load trusted certificates: %r" % e)
def load_cert_chain(self, certfile, keyfile=None, password=None): def load_cert_chain(self, certfile, keyfile=None, password=None):
self._ctx.use_certificate_file(certfile) self._ctx.use_certificate_chain_file(certfile)
if password is not None: if password is not None:
self._ctx.set_passwd_cb(lambda max_length, prompt_twice, userdata: password) if not isinstance(password, six.binary_type):
password = password.encode("utf-8")
self._ctx.set_passwd_cb(lambda *_: password)
self._ctx.use_privatekey_file(keyfile or certfile) self._ctx.use_privatekey_file(keyfile or certfile)
def wrap_socket(self, sock, server_side=False, def set_alpn_protocols(self, protocols):
do_handshake_on_connect=True, suppress_ragged_eofs=True, protocols = [six.ensure_binary(p) for p in protocols]
server_hostname=None): return self._ctx.set_alpn_protos(protocols)
def wrap_socket(
self,
sock,
server_side=False,
do_handshake_on_connect=True,
suppress_ragged_eofs=True,
server_hostname=None,
):
cnx = OpenSSL.SSL.Connection(self._ctx, sock) cnx = OpenSSL.SSL.Connection(self._ctx, sock)
if isinstance(server_hostname, six.text_type): # Platform-specific: Python 3 if isinstance(server_hostname, six.text_type): # Platform-specific: Python 3
server_hostname = server_hostname.encode('utf-8') server_hostname = server_hostname.encode("utf-8")
if server_hostname is not None: if server_hostname is not None:
cnx.set_tlsext_host_name(server_hostname) cnx.set_tlsext_host_name(server_hostname)
@ -440,12 +497,11 @@ class PyOpenSSLContext(object):
try: try:
cnx.do_handshake() cnx.do_handshake()
except OpenSSL.SSL.WantReadError: except OpenSSL.SSL.WantReadError:
rd = util.wait_for_read(sock, sock.gettimeout()) if not util.wait_for_read(sock, sock.gettimeout()):
if not rd: raise timeout("select timed out")
raise timeout('select timed out')
continue continue
except OpenSSL.SSL.Error as e: except OpenSSL.SSL.Error as e:
raise ssl.SSLError('bad handshake: %r' % e) raise ssl.SSLError("bad handshake: %r" % e)
break break
return WrappedSocket(cnx, sock) return WrappedSocket(cnx, sock)

View file

@ -23,6 +23,33 @@ To use this module, simply import and inject it::
urllib3.contrib.securetransport.inject_into_urllib3() urllib3.contrib.securetransport.inject_into_urllib3()
Happy TLSing! Happy TLSing!
This code is a bastardised version of the code found in Will Bond's oscrypto
library. An enormous debt is owed to him for blazing this trail for us. For
that reason, this code should be considered to be covered both by urllib3's
license and by oscrypto's:
.. code-block::
Copyright (c) 2015-2016 Will Bond <will@wbond.net>
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
""" """
from __future__ import absolute_import from __future__ import absolute_import
@ -33,16 +60,22 @@ import os.path
import shutil import shutil
import socket import socket
import ssl import ssl
import struct
import threading import threading
import weakref import weakref
import six
from .. import util from .. import util
from ._securetransport.bindings import ( from ..util.ssl_ import PROTOCOL_TLS_CLIENT
Security, SecurityConst, CoreFoundation from ._securetransport.bindings import CoreFoundation, Security, SecurityConst
)
from ._securetransport.low_level import ( from ._securetransport.low_level import (
_assert_no_error, _cert_array_from_pem, _temporary_keychain, _assert_no_error,
_load_client_cert_chain _build_tls_unknown_ca_alert,
_cert_array_from_pem,
_create_cfstring_array,
_load_client_cert_chain,
_temporary_keychain,
) )
try: # Platform-specific: Python 2 try: # Platform-specific: Python 2
@ -51,12 +84,7 @@ except ImportError: # Platform-specific: Python 3
_fileobject = None _fileobject = None
from ..packages.backports.makefile import backport_makefile from ..packages.backports.makefile import backport_makefile
try: __all__ = ["inject_into_urllib3", "extract_from_urllib3"]
memoryview(b'')
except NameError:
raise ImportError("SecureTransport only works on Pythons with memoryview")
__all__ = ['inject_into_urllib3', 'extract_from_urllib3']
# SNI always works # SNI always works
HAS_SNI = True HAS_SNI = True
@ -88,38 +116,35 @@ _connection_ref_lock = threading.Lock()
SSL_WRITE_BLOCKSIZE = 16384 SSL_WRITE_BLOCKSIZE = 16384
# This is our equivalent of util.ssl_.DEFAULT_CIPHERS, but expanded out to # This is our equivalent of util.ssl_.DEFAULT_CIPHERS, but expanded out to
# individual cipher suites. We need to do this becuase this is how # individual cipher suites. We need to do this because this is how
# SecureTransport wants them. # SecureTransport wants them.
CIPHER_SUITES = [ CIPHER_SUITES = [
SecurityConst.TLS_AES_256_GCM_SHA384,
SecurityConst.TLS_CHACHA20_POLY1305_SHA256,
SecurityConst.TLS_AES_128_GCM_SHA256,
SecurityConst.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, SecurityConst.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
SecurityConst.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
SecurityConst.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, SecurityConst.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
SecurityConst.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
SecurityConst.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, SecurityConst.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
SecurityConst.TLS_DHE_DSS_WITH_AES_256_GCM_SHA384, SecurityConst.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256,
SecurityConst.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
SecurityConst.TLS_DHE_RSA_WITH_AES_256_GCM_SHA384, SecurityConst.TLS_DHE_RSA_WITH_AES_256_GCM_SHA384,
SecurityConst.TLS_DHE_DSS_WITH_AES_128_GCM_SHA256,
SecurityConst.TLS_DHE_RSA_WITH_AES_128_GCM_SHA256, SecurityConst.TLS_DHE_RSA_WITH_AES_128_GCM_SHA256,
SecurityConst.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA384, SecurityConst.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA384,
SecurityConst.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384,
SecurityConst.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, SecurityConst.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
SecurityConst.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
SecurityConst.TLS_DHE_RSA_WITH_AES_256_CBC_SHA256,
SecurityConst.TLS_DHE_DSS_WITH_AES_256_CBC_SHA256,
SecurityConst.TLS_DHE_RSA_WITH_AES_256_CBC_SHA,
SecurityConst.TLS_DHE_DSS_WITH_AES_256_CBC_SHA,
SecurityConst.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, SecurityConst.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256,
SecurityConst.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256,
SecurityConst.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, SecurityConst.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
SecurityConst.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384,
SecurityConst.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
SecurityConst.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256,
SecurityConst.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, SecurityConst.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
SecurityConst.TLS_DHE_RSA_WITH_AES_256_CBC_SHA256,
SecurityConst.TLS_DHE_RSA_WITH_AES_256_CBC_SHA,
SecurityConst.TLS_DHE_RSA_WITH_AES_128_CBC_SHA256, SecurityConst.TLS_DHE_RSA_WITH_AES_128_CBC_SHA256,
SecurityConst.TLS_DHE_DSS_WITH_AES_128_CBC_SHA256,
SecurityConst.TLS_DHE_RSA_WITH_AES_128_CBC_SHA, SecurityConst.TLS_DHE_RSA_WITH_AES_128_CBC_SHA,
SecurityConst.TLS_DHE_DSS_WITH_AES_128_CBC_SHA, SecurityConst.TLS_AES_256_GCM_SHA384,
SecurityConst.TLS_AES_128_GCM_SHA256,
SecurityConst.TLS_RSA_WITH_AES_256_GCM_SHA384, SecurityConst.TLS_RSA_WITH_AES_256_GCM_SHA384,
SecurityConst.TLS_RSA_WITH_AES_128_GCM_SHA256, SecurityConst.TLS_RSA_WITH_AES_128_GCM_SHA256,
SecurityConst.TLS_AES_128_CCM_8_SHA256,
SecurityConst.TLS_AES_128_CCM_SHA256,
SecurityConst.TLS_RSA_WITH_AES_256_CBC_SHA256, SecurityConst.TLS_RSA_WITH_AES_256_CBC_SHA256,
SecurityConst.TLS_RSA_WITH_AES_128_CBC_SHA256, SecurityConst.TLS_RSA_WITH_AES_128_CBC_SHA256,
SecurityConst.TLS_RSA_WITH_AES_256_CBC_SHA, SecurityConst.TLS_RSA_WITH_AES_256_CBC_SHA,
@ -128,38 +153,44 @@ CIPHER_SUITES = [
# Basically this is simple: for PROTOCOL_SSLv23 we turn it into a low of # Basically this is simple: for PROTOCOL_SSLv23 we turn it into a low of
# TLSv1 and a high of TLSv1.2. For everything else, we pin to that version. # TLSv1 and a high of TLSv1.2. For everything else, we pin to that version.
# TLSv1 to 1.2 are supported on macOS 10.8+
_protocol_to_min_max = { _protocol_to_min_max = {
ssl.PROTOCOL_SSLv23: (SecurityConst.kTLSProtocol1, SecurityConst.kTLSProtocol12), util.PROTOCOL_TLS: (SecurityConst.kTLSProtocol1, SecurityConst.kTLSProtocol12),
PROTOCOL_TLS_CLIENT: (SecurityConst.kTLSProtocol1, SecurityConst.kTLSProtocol12),
} }
if hasattr(ssl, "PROTOCOL_SSLv2"): if hasattr(ssl, "PROTOCOL_SSLv2"):
_protocol_to_min_max[ssl.PROTOCOL_SSLv2] = ( _protocol_to_min_max[ssl.PROTOCOL_SSLv2] = (
SecurityConst.kSSLProtocol2, SecurityConst.kSSLProtocol2 SecurityConst.kSSLProtocol2,
SecurityConst.kSSLProtocol2,
) )
if hasattr(ssl, "PROTOCOL_SSLv3"): if hasattr(ssl, "PROTOCOL_SSLv3"):
_protocol_to_min_max[ssl.PROTOCOL_SSLv3] = ( _protocol_to_min_max[ssl.PROTOCOL_SSLv3] = (
SecurityConst.kSSLProtocol3, SecurityConst.kSSLProtocol3 SecurityConst.kSSLProtocol3,
SecurityConst.kSSLProtocol3,
) )
if hasattr(ssl, "PROTOCOL_TLSv1"): if hasattr(ssl, "PROTOCOL_TLSv1"):
_protocol_to_min_max[ssl.PROTOCOL_TLSv1] = ( _protocol_to_min_max[ssl.PROTOCOL_TLSv1] = (
SecurityConst.kTLSProtocol1, SecurityConst.kTLSProtocol1 SecurityConst.kTLSProtocol1,
SecurityConst.kTLSProtocol1,
) )
if hasattr(ssl, "PROTOCOL_TLSv1_1"): if hasattr(ssl, "PROTOCOL_TLSv1_1"):
_protocol_to_min_max[ssl.PROTOCOL_TLSv1_1] = ( _protocol_to_min_max[ssl.PROTOCOL_TLSv1_1] = (
SecurityConst.kTLSProtocol11, SecurityConst.kTLSProtocol11 SecurityConst.kTLSProtocol11,
SecurityConst.kTLSProtocol11,
) )
if hasattr(ssl, "PROTOCOL_TLSv1_2"): if hasattr(ssl, "PROTOCOL_TLSv1_2"):
_protocol_to_min_max[ssl.PROTOCOL_TLSv1_2] = ( _protocol_to_min_max[ssl.PROTOCOL_TLSv1_2] = (
SecurityConst.kTLSProtocol12, SecurityConst.kTLSProtocol12 SecurityConst.kTLSProtocol12,
SecurityConst.kTLSProtocol12,
) )
if hasattr(ssl, "PROTOCOL_TLS"):
_protocol_to_min_max[ssl.PROTOCOL_TLS] = _protocol_to_min_max[ssl.PROTOCOL_SSLv23]
def inject_into_urllib3(): def inject_into_urllib3():
""" """
Monkey-patch urllib3 with SecureTransport-backed SSL-support. Monkey-patch urllib3 with SecureTransport-backed SSL-support.
""" """
util.SSLContext = SecureTransportContext
util.ssl_.SSLContext = SecureTransportContext util.ssl_.SSLContext = SecureTransportContext
util.HAS_SNI = HAS_SNI util.HAS_SNI = HAS_SNI
util.ssl_.HAS_SNI = HAS_SNI util.ssl_.HAS_SNI = HAS_SNI
@ -171,6 +202,7 @@ def extract_from_urllib3():
""" """
Undo monkey-patching by :func:`inject_into_urllib3`. Undo monkey-patching by :func:`inject_into_urllib3`.
""" """
util.SSLContext = orig_util_SSLContext
util.ssl_.SSLContext = orig_util_SSLContext util.ssl_.SSLContext = orig_util_SSLContext
util.HAS_SNI = orig_util_HAS_SNI util.HAS_SNI = orig_util_HAS_SNI
util.ssl_.HAS_SNI = orig_util_HAS_SNI util.ssl_.HAS_SNI = orig_util_HAS_SNI
@ -195,21 +227,18 @@ def _read_callback(connection_id, data_buffer, data_length_pointer):
timeout = wrapped_socket.gettimeout() timeout = wrapped_socket.gettimeout()
error = None error = None
read_count = 0 read_count = 0
buffer = (ctypes.c_char * requested_length).from_address(data_buffer)
buffer_view = memoryview(buffer)
try: try:
while read_count < requested_length: while read_count < requested_length:
if timeout is None or timeout >= 0: if timeout is None or timeout >= 0:
readables = util.wait_for_read([base_socket], timeout) if not util.wait_for_read(base_socket, timeout):
if not readables: raise socket.error(errno.EAGAIN, "timed out")
raise socket.error(errno.EAGAIN, 'timed out')
# We need to tell ctypes that we have a buffer that can be remaining = requested_length - read_count
# written to. Upsettingly, we do that like this: buffer = (ctypes.c_char * remaining).from_address(
chunk_size = base_socket.recv_into( data_buffer + read_count
buffer_view[read_count:requested_length]
) )
chunk_size = base_socket.recv_into(buffer, remaining)
read_count += chunk_size read_count += chunk_size
if not chunk_size: if not chunk_size:
if not read_count: if not read_count:
@ -219,7 +248,8 @@ def _read_callback(connection_id, data_buffer, data_length_pointer):
error = e.errno error = e.errno
if error is not None and error != errno.EAGAIN: if error is not None and error != errno.EAGAIN:
if error == errno.ECONNRESET: data_length_pointer[0] = read_count
if error == errno.ECONNRESET or error == errno.EPIPE:
return SecurityConst.errSSLClosedAbort return SecurityConst.errSSLClosedAbort
raise raise
@ -257,9 +287,8 @@ def _write_callback(connection_id, data_buffer, data_length_pointer):
try: try:
while sent < bytes_to_write: while sent < bytes_to_write:
if timeout is None or timeout >= 0: if timeout is None or timeout >= 0:
writables = util.wait_for_write([base_socket], timeout) if not util.wait_for_write(base_socket, timeout):
if not writables: raise socket.error(errno.EAGAIN, "timed out")
raise socket.error(errno.EAGAIN, 'timed out')
chunk_sent = base_socket.send(data) chunk_sent = base_socket.send(data)
sent += chunk_sent sent += chunk_sent
@ -270,11 +299,13 @@ def _write_callback(connection_id, data_buffer, data_length_pointer):
error = e.errno error = e.errno
if error is not None and error != errno.EAGAIN: if error is not None and error != errno.EAGAIN:
if error == errno.ECONNRESET: data_length_pointer[0] = sent
if error == errno.ECONNRESET or error == errno.EPIPE:
return SecurityConst.errSSLClosedAbort return SecurityConst.errSSLClosedAbort
raise raise
data_length_pointer[0] = sent data_length_pointer[0] = sent
if sent != bytes_to_write: if sent != bytes_to_write:
return SecurityConst.errSSLWouldBlock return SecurityConst.errSSLWouldBlock
@ -299,6 +330,7 @@ class WrappedSocket(object):
Note: _makefile_refs, _drop(), and _reuse() are needed for the garbage Note: _makefile_refs, _drop(), and _reuse() are needed for the garbage
collector of PyPy. collector of PyPy.
""" """
def __init__(self, socket): def __init__(self, socket):
self.socket = socket self.socket = socket
self.context = None self.context = None
@ -351,19 +383,58 @@ class WrappedSocket(object):
) )
_assert_no_error(result) _assert_no_error(result)
def _set_alpn_protocols(self, protocols):
"""
Sets up the ALPN protocols on the context.
"""
if not protocols:
return
protocols_arr = _create_cfstring_array(protocols)
try:
result = Security.SSLSetALPNProtocols(self.context, protocols_arr)
_assert_no_error(result)
finally:
CoreFoundation.CFRelease(protocols_arr)
def _custom_validate(self, verify, trust_bundle): def _custom_validate(self, verify, trust_bundle):
""" """
Called when we have set custom validation. We do this in two cases: Called when we have set custom validation. We do this in two cases:
first, when cert validation is entirely disabled; and second, when first, when cert validation is entirely disabled; and second, when
using a custom trust DB. using a custom trust DB.
Raises an SSLError if the connection is not trusted.
""" """
# If we disabled cert validation, just say: cool. # If we disabled cert validation, just say: cool.
if not verify: if not verify:
return return
successes = (
SecurityConst.kSecTrustResultUnspecified,
SecurityConst.kSecTrustResultProceed,
)
try:
trust_result = self._evaluate_trust(trust_bundle)
if trust_result in successes:
return
reason = "error code: %d" % (trust_result,)
except Exception as e:
# Do not trust on error
reason = "exception: %r" % (e,)
# SecureTransport does not send an alert nor shuts down the connection.
rec = _build_tls_unknown_ca_alert(self.version())
self.socket.sendall(rec)
# close the connection immediately
# l_onoff = 1, activate linger
# l_linger = 0, linger for 0 seoncds
opts = struct.pack("ii", 1, 0)
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, opts)
self.close()
raise ssl.SSLError("certificate verify failed, %s" % reason)
def _evaluate_trust(self, trust_bundle):
# We want data in memory, so load it up. # We want data in memory, so load it up.
if os.path.isfile(trust_bundle): if os.path.isfile(trust_bundle):
with open(trust_bundle, 'rb') as f: with open(trust_bundle, "rb") as f:
trust_bundle = f.read() trust_bundle = f.read()
cert_array = None cert_array = None
@ -377,9 +448,7 @@ class WrappedSocket(object):
# created for this connection, shove our CAs into it, tell ST to # created for this connection, shove our CAs into it, tell ST to
# ignore everything else it knows, and then ask if it can build a # ignore everything else it knows, and then ask if it can build a
# chain. This is a buuuunch of code. # chain. This is a buuuunch of code.
result = Security.SSLCopyPeerTrust( result = Security.SSLCopyPeerTrust(self.context, ctypes.byref(trust))
self.context, ctypes.byref(trust)
)
_assert_no_error(result) _assert_no_error(result)
if not trust: if not trust:
raise ssl.SSLError("Failed to copy trust reference") raise ssl.SSLError("Failed to copy trust reference")
@ -391,29 +460,19 @@ class WrappedSocket(object):
_assert_no_error(result) _assert_no_error(result)
trust_result = Security.SecTrustResultType() trust_result = Security.SecTrustResultType()
result = Security.SecTrustEvaluate( result = Security.SecTrustEvaluate(trust, ctypes.byref(trust_result))
trust, ctypes.byref(trust_result)
)
_assert_no_error(result) _assert_no_error(result)
finally: finally:
if trust: if trust:
CoreFoundation.CFRelease(trust) CoreFoundation.CFRelease(trust)
if cert_array is None: if cert_array is not None:
CoreFoundation.CFRelease(cert_array) CoreFoundation.CFRelease(cert_array)
# Ok, now we can look at what the result was. return trust_result.value
successes = (
SecurityConst.kSecTrustResultUnspecified,
SecurityConst.kSecTrustResultProceed
)
if trust_result.value not in successes:
raise ssl.SSLError(
"certificate verify failed, error code: %d" %
trust_result.value
)
def handshake(self, def handshake(
self,
server_hostname, server_hostname,
verify, verify,
trust_bundle, trust_bundle,
@ -421,7 +480,9 @@ class WrappedSocket(object):
max_version, max_version,
client_cert, client_cert,
client_key, client_key,
client_key_passphrase): client_key_passphrase,
alpn_protocols,
):
""" """
Actually performs the TLS handshake. This is run automatically by Actually performs the TLS handshake. This is run automatically by
wrapped socket, and shouldn't be needed in user code. wrapped socket, and shouldn't be needed in user code.
@ -451,7 +512,7 @@ class WrappedSocket(object):
# If we have a server hostname, we should set that too. # If we have a server hostname, we should set that too.
if server_hostname: if server_hostname:
if not isinstance(server_hostname, bytes): if not isinstance(server_hostname, bytes):
server_hostname = server_hostname.encode('utf-8') server_hostname = server_hostname.encode("utf-8")
result = Security.SSLSetPeerDomainName( result = Security.SSLSetPeerDomainName(
self.context, server_hostname, len(server_hostname) self.context, server_hostname, len(server_hostname)
@ -461,9 +522,13 @@ class WrappedSocket(object):
# Setup the ciphers. # Setup the ciphers.
self._set_ciphers() self._set_ciphers()
# Setup the ALPN protocols.
self._set_alpn_protocols(alpn_protocols)
# Set the minimum and maximum TLS versions. # Set the minimum and maximum TLS versions.
result = Security.SSLSetProtocolVersionMin(self.context, min_version) result = Security.SSLSetProtocolVersionMin(self.context, min_version)
_assert_no_error(result) _assert_no_error(result)
result = Security.SSLSetProtocolVersionMax(self.context, max_version) result = Security.SSLSetProtocolVersionMax(self.context, max_version)
_assert_no_error(result) _assert_no_error(result)
@ -473,9 +538,7 @@ class WrappedSocket(object):
# authing in that case. # authing in that case.
if not verify or trust_bundle is not None: if not verify or trust_bundle is not None:
result = Security.SSLSetSessionOption( result = Security.SSLSetSessionOption(
self.context, self.context, SecurityConst.kSSLSessionOptionBreakOnServerAuth, True
SecurityConst.kSSLSessionOptionBreakOnServerAuth,
True
) )
_assert_no_error(result) _assert_no_error(result)
@ -485,9 +548,7 @@ class WrappedSocket(object):
self._client_cert_chain = _load_client_cert_chain( self._client_cert_chain = _load_client_cert_chain(
self._keychain, client_cert, client_key self._keychain, client_cert, client_key
) )
result = Security.SSLSetCertificate( result = Security.SSLSetCertificate(self.context, self._client_cert_chain)
self.context, self._client_cert_chain
)
_assert_no_error(result) _assert_no_error(result)
while True: while True:
@ -538,7 +599,7 @@ class WrappedSocket(object):
# There are some result codes that we want to treat as "not always # There are some result codes that we want to treat as "not always
# errors". Specifically, those are errSSLWouldBlock, # errors". Specifically, those are errSSLWouldBlock,
# errSSLClosedGraceful, and errSSLClosedNoNotify. # errSSLClosedGraceful, and errSSLClosedNoNotify.
if (result == SecurityConst.errSSLWouldBlock): if result == SecurityConst.errSSLWouldBlock:
# If we didn't process any bytes, then this was just a time out. # If we didn't process any bytes, then this was just a time out.
# However, we can get errSSLWouldBlock in situations when we *did* # However, we can get errSSLWouldBlock in situations when we *did*
# read some data, and in those cases we should just read "short" # read some data, and in those cases we should just read "short"
@ -546,7 +607,10 @@ class WrappedSocket(object):
if processed_bytes.value == 0: if processed_bytes.value == 0:
# Timed out, no data read. # Timed out, no data read.
raise socket.timeout("recv timed out") raise socket.timeout("recv timed out")
elif result in (SecurityConst.errSSLClosedGraceful, SecurityConst.errSSLClosedNoNotify): elif result in (
SecurityConst.errSSLClosedGraceful,
SecurityConst.errSSLClosedNoNotify,
):
# The remote peer has closed this connection. We should do so as # The remote peer has closed this connection. We should do so as
# well. Note that we don't actually return here because in # well. Note that we don't actually return here because in
# principle this could actually be fired along with return data. # principle this could actually be fired along with return data.
@ -632,18 +696,14 @@ class WrappedSocket(object):
# instead to just flag to urllib3 that it shouldn't do its own hostname # instead to just flag to urllib3 that it shouldn't do its own hostname
# validation when using SecureTransport. # validation when using SecureTransport.
if not binary_form: if not binary_form:
raise ValueError( raise ValueError("SecureTransport only supports dumping binary certs")
"SecureTransport only supports dumping binary certs"
)
trust = Security.SecTrustRef() trust = Security.SecTrustRef()
certdata = None certdata = None
der_bytes = None der_bytes = None
try: try:
# Grab the trust store. # Grab the trust store.
result = Security.SSLCopyPeerTrust( result = Security.SSLCopyPeerTrust(self.context, ctypes.byref(trust))
self.context, ctypes.byref(trust)
)
_assert_no_error(result) _assert_no_error(result)
if not trust: if not trust:
# Probably we haven't done the handshake yet. No biggie. # Probably we haven't done the handshake yet. No biggie.
@ -673,6 +733,27 @@ class WrappedSocket(object):
return der_bytes return der_bytes
def version(self):
protocol = Security.SSLProtocol()
result = Security.SSLGetNegotiatedProtocolVersion(
self.context, ctypes.byref(protocol)
)
_assert_no_error(result)
if protocol.value == SecurityConst.kTLSProtocol13:
raise ssl.SSLError("SecureTransport does not support TLS 1.3")
elif protocol.value == SecurityConst.kTLSProtocol12:
return "TLSv1.2"
elif protocol.value == SecurityConst.kTLSProtocol11:
return "TLSv1.1"
elif protocol.value == SecurityConst.kTLSProtocol1:
return "TLSv1"
elif protocol.value == SecurityConst.kSSLProtocol3:
return "SSLv3"
elif protocol.value == SecurityConst.kSSLProtocol2:
return "SSLv2"
else:
raise ssl.SSLError("Unknown TLS version: %r" % protocol)
def _reuse(self): def _reuse(self):
self._makefile_refs += 1 self._makefile_refs += 1
@ -684,16 +765,21 @@ class WrappedSocket(object):
if _fileobject: # Platform-specific: Python 2 if _fileobject: # Platform-specific: Python 2
def makefile(self, mode, bufsize=-1): def makefile(self, mode, bufsize=-1):
self._makefile_refs += 1 self._makefile_refs += 1
return _fileobject(self, mode, bufsize, close=True) return _fileobject(self, mode, bufsize, close=True)
else: # Platform-specific: Python 3 else: # Platform-specific: Python 3
def makefile(self, mode="r", buffering=None, *args, **kwargs): def makefile(self, mode="r", buffering=None, *args, **kwargs):
# We disable buffering with SecureTransport because it conflicts with # We disable buffering with SecureTransport because it conflicts with
# the buffering that ST does internally (see issue #1153 for more). # the buffering that ST does internally (see issue #1153 for more).
buffering = 0 buffering = 0
return backport_makefile(self, mode, buffering, *args, **kwargs) return backport_makefile(self, mode, buffering, *args, **kwargs)
WrappedSocket.makefile = makefile WrappedSocket.makefile = makefile
@ -703,6 +789,7 @@ class SecureTransportContext(object):
interface of the standard library ``SSLContext`` object to calls into interface of the standard library ``SSLContext`` object to calls into
SecureTransport. SecureTransport.
""" """
def __init__(self, protocol): def __init__(self, protocol):
self._min_version, self._max_version = _protocol_to_min_max[protocol] self._min_version, self._max_version = _protocol_to_min_max[protocol]
self._options = 0 self._options = 0
@ -711,6 +798,7 @@ class SecureTransportContext(object):
self._client_cert = None self._client_cert = None
self._client_key = None self._client_key = None
self._client_key_passphrase = None self._client_key_passphrase = None
self._alpn_protocols = None
@property @property
def check_hostname(self): def check_hostname(self):
@ -769,16 +857,17 @@ class SecureTransportContext(object):
def set_ciphers(self, ciphers): def set_ciphers(self, ciphers):
# For now, we just require the default cipher string. # For now, we just require the default cipher string.
if ciphers != util.ssl_.DEFAULT_CIPHERS: if ciphers != util.ssl_.DEFAULT_CIPHERS:
raise ValueError( raise ValueError("SecureTransport doesn't support custom cipher strings")
"SecureTransport doesn't support custom cipher strings"
)
def load_verify_locations(self, cafile=None, capath=None, cadata=None): def load_verify_locations(self, cafile=None, capath=None, cadata=None):
# OK, we only really support cadata and cafile. # OK, we only really support cadata and cafile.
if capath is not None: if capath is not None:
raise ValueError( raise ValueError("SecureTransport does not support cert directories")
"SecureTransport does not support cert directories"
) # Raise if cafile does not exist.
if cafile is not None:
with open(cafile):
pass
self._trust_bundle = cafile or cadata self._trust_bundle = cafile or cadata
@ -787,9 +876,26 @@ class SecureTransportContext(object):
self._client_key = keyfile self._client_key = keyfile
self._client_cert_passphrase = password self._client_cert_passphrase = password
def wrap_socket(self, sock, server_side=False, def set_alpn_protocols(self, protocols):
do_handshake_on_connect=True, suppress_ragged_eofs=True, """
server_hostname=None): Sets the ALPN protocols that will later be set on the context.
Raises a NotImplementedError if ALPN is not supported.
"""
if not hasattr(Security, "SSLSetALPNProtocols"):
raise NotImplementedError(
"SecureTransport supports ALPN only in macOS 10.12+"
)
self._alpn_protocols = [six.ensure_binary(p) for p in protocols]
def wrap_socket(
self,
sock,
server_side=False,
do_handshake_on_connect=True,
suppress_ragged_eofs=True,
server_hostname=None,
):
# So, what do we do here? Firstly, we assert some properties. This is a # So, what do we do here? Firstly, we assert some properties. This is a
# stripped down shim, so there is some functionality we don't support. # stripped down shim, so there is some functionality we don't support.
# See PEP 543 for the real deal. # See PEP 543 for the real deal.
@ -803,8 +909,14 @@ class SecureTransportContext(object):
# Now we can handshake # Now we can handshake
wrapped_socket.handshake( wrapped_socket.handshake(
server_hostname, self._verify, self._trust_bundle, server_hostname,
self._min_version, self._max_version, self._client_cert, self._verify,
self._client_key, self._client_key_passphrase self._trust_bundle,
self._min_version,
self._max_version,
self._client_cert,
self._client_key,
self._client_key_passphrase,
self._alpn_protocols,
) )
return wrapped_socket return wrapped_socket

View file

@ -1,25 +1,42 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
""" """
This module contains provisional support for SOCKS proxies from within This module contains provisional support for SOCKS proxies from within
urllib3. This module supports SOCKS4 (specifically the SOCKS4A variant) and urllib3. This module supports SOCKS4, SOCKS4A (an extension of SOCKS4), and
SOCKS5. To enable its functionality, either install PySocks or install this SOCKS5. To enable its functionality, either install PySocks or install this
module with the ``socks`` extra. module with the ``socks`` extra.
The SOCKS implementation supports the full range of urllib3 features. It also The SOCKS implementation supports the full range of urllib3 features. It also
supports the following SOCKS features: supports the following SOCKS features:
- SOCKS4 - SOCKS4A (``proxy_url='socks4a://...``)
- SOCKS4a - SOCKS4 (``proxy_url='socks4://...``)
- SOCKS5 - SOCKS5 with remote DNS (``proxy_url='socks5h://...``)
- SOCKS5 with local DNS (``proxy_url='socks5://...``)
- Usernames and passwords for the SOCKS proxy - Usernames and passwords for the SOCKS proxy
Known Limitations: .. note::
It is recommended to use ``socks5h://`` or ``socks4a://`` schemes in
your ``proxy_url`` to ensure that DNS resolution is done from the remote
server instead of client-side when connecting to a domain name.
SOCKS4 supports IPv4 and domain names with the SOCKS4A extension. SOCKS5
supports IPv4, IPv6, and domain names.
When connecting to a SOCKS4 proxy the ``username`` portion of the ``proxy_url``
will be sent as the ``userid`` section of the SOCKS request:
.. code-block:: python
proxy_url="socks4a://<userid>@proxy-host"
When connecting to a SOCKS5 proxy the ``username`` and ``password`` portion
of the ``proxy_url`` will be sent as the username/password to authenticate
with the proxy:
.. code-block:: python
proxy_url="socks5h://<username>:<password>@proxy-host"
- Currently PySocks does not support contacting remote websites via literal
IPv6 addresses. Any such connection attempt will fail. You must use a domain
name.
- Currently PySocks does not support IPv6 connections to the SOCKS proxy. Any
such connection attempt will fail.
""" """
from __future__ import absolute_import from __future__ import absolute_import
@ -27,25 +44,24 @@ try:
import socks import socks
except ImportError: except ImportError:
import warnings import warnings
from ..exceptions import DependencyWarning from ..exceptions import DependencyWarning
warnings.warn(( warnings.warn(
'SOCKS support in urllib3 requires the installation of optional ' (
'dependencies: specifically, PySocks. For more information, see ' "SOCKS support in urllib3 requires the installation of optional "
'https://urllib3.readthedocs.io/en/latest/contrib.html#socks-proxies' "dependencies: specifically, PySocks. For more information, see "
"https://urllib3.readthedocs.io/en/1.26.x/contrib.html#socks-proxies"
), ),
DependencyWarning DependencyWarning,
) )
raise raise
from socket import error as SocketError, timeout as SocketTimeout from socket import error as SocketError
from socket import timeout as SocketTimeout
from ..connection import ( from ..connection import HTTPConnection, HTTPSConnection
HTTPConnection, HTTPSConnection from ..connectionpool import HTTPConnectionPool, HTTPSConnectionPool
)
from ..connectionpool import (
HTTPConnectionPool, HTTPSConnectionPool
)
from ..exceptions import ConnectTimeoutError, NewConnectionError from ..exceptions import ConnectTimeoutError, NewConnectionError
from ..poolmanager import PoolManager from ..poolmanager import PoolManager
from ..util.url import parse_url from ..util.url import parse_url
@ -60,8 +76,9 @@ class SOCKSConnection(HTTPConnection):
""" """
A plain-text HTTP connection that connects via a SOCKS proxy. A plain-text HTTP connection that connects via a SOCKS proxy.
""" """
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
self._socks_options = kwargs.pop('_socks_options') self._socks_options = kwargs.pop("_socks_options")
super(SOCKSConnection, self).__init__(*args, **kwargs) super(SOCKSConnection, self).__init__(*args, **kwargs)
def _new_conn(self): def _new_conn(self):
@ -70,28 +87,30 @@ class SOCKSConnection(HTTPConnection):
""" """
extra_kw = {} extra_kw = {}
if self.source_address: if self.source_address:
extra_kw['source_address'] = self.source_address extra_kw["source_address"] = self.source_address
if self.socket_options: if self.socket_options:
extra_kw['socket_options'] = self.socket_options extra_kw["socket_options"] = self.socket_options
try: try:
conn = socks.create_connection( conn = socks.create_connection(
(self.host, self.port), (self.host, self.port),
proxy_type=self._socks_options['socks_version'], proxy_type=self._socks_options["socks_version"],
proxy_addr=self._socks_options['proxy_host'], proxy_addr=self._socks_options["proxy_host"],
proxy_port=self._socks_options['proxy_port'], proxy_port=self._socks_options["proxy_port"],
proxy_username=self._socks_options['username'], proxy_username=self._socks_options["username"],
proxy_password=self._socks_options['password'], proxy_password=self._socks_options["password"],
proxy_rdns=self._socks_options['rdns'], proxy_rdns=self._socks_options["rdns"],
timeout=self.timeout, timeout=self.timeout,
**extra_kw **extra_kw
) )
except SocketTimeout as e: except SocketTimeout:
raise ConnectTimeoutError( raise ConnectTimeoutError(
self, "Connection to %s timed out. (connect timeout=%s)" % self,
(self.host, self.timeout)) "Connection to %s timed out. (connect timeout=%s)"
% (self.host, self.timeout),
)
except socks.ProxyError as e: except socks.ProxyError as e:
# This is fragile as hell, but it seems to be the only way to raise # This is fragile as hell, but it seems to be the only way to raise
@ -101,23 +120,22 @@ class SOCKSConnection(HTTPConnection):
if isinstance(error, SocketTimeout): if isinstance(error, SocketTimeout):
raise ConnectTimeoutError( raise ConnectTimeoutError(
self, self,
"Connection to %s timed out. (connect timeout=%s)" % "Connection to %s timed out. (connect timeout=%s)"
(self.host, self.timeout) % (self.host, self.timeout),
) )
else: else:
raise NewConnectionError( raise NewConnectionError(
self, self, "Failed to establish a new connection: %s" % error
"Failed to establish a new connection: %s" % error
) )
else: else:
raise NewConnectionError( raise NewConnectionError(
self, self, "Failed to establish a new connection: %s" % e
"Failed to establish a new connection: %s" % e
) )
except SocketError as e: # Defensive: PySocks should catch all these. except SocketError as e: # Defensive: PySocks should catch all these.
raise NewConnectionError( raise NewConnectionError(
self, "Failed to establish a new connection: %s" % e) self, "Failed to establish a new connection: %s" % e
)
return conn return conn
@ -143,43 +161,53 @@ class SOCKSProxyManager(PoolManager):
A version of the urllib3 ProxyManager that routes connections via the A version of the urllib3 ProxyManager that routes connections via the
defined SOCKS proxy. defined SOCKS proxy.
""" """
pool_classes_by_scheme = { pool_classes_by_scheme = {
'http': SOCKSHTTPConnectionPool, "http": SOCKSHTTPConnectionPool,
'https': SOCKSHTTPSConnectionPool, "https": SOCKSHTTPSConnectionPool,
} }
def __init__(self, proxy_url, username=None, password=None, def __init__(
num_pools=10, headers=None, **connection_pool_kw): self,
proxy_url,
username=None,
password=None,
num_pools=10,
headers=None,
**connection_pool_kw
):
parsed = parse_url(proxy_url) parsed = parse_url(proxy_url)
if parsed.scheme == 'socks5': if username is None and password is None and parsed.auth is not None:
split = parsed.auth.split(":")
if len(split) == 2:
username, password = split
if parsed.scheme == "socks5":
socks_version = socks.PROXY_TYPE_SOCKS5 socks_version = socks.PROXY_TYPE_SOCKS5
rdns = False rdns = False
elif parsed.scheme == 'socks5h': elif parsed.scheme == "socks5h":
socks_version = socks.PROXY_TYPE_SOCKS5 socks_version = socks.PROXY_TYPE_SOCKS5
rdns = True rdns = True
elif parsed.scheme == 'socks4': elif parsed.scheme == "socks4":
socks_version = socks.PROXY_TYPE_SOCKS4 socks_version = socks.PROXY_TYPE_SOCKS4
rdns = False rdns = False
elif parsed.scheme == 'socks4a': elif parsed.scheme == "socks4a":
socks_version = socks.PROXY_TYPE_SOCKS4 socks_version = socks.PROXY_TYPE_SOCKS4
rdns = True rdns = True
else: else:
raise ValueError( raise ValueError("Unable to determine SOCKS version from %s" % proxy_url)
"Unable to determine SOCKS version from %s" % proxy_url
)
self.proxy_url = proxy_url self.proxy_url = proxy_url
socks_options = { socks_options = {
'socks_version': socks_version, "socks_version": socks_version,
'proxy_host': parsed.host, "proxy_host": parsed.host,
'proxy_port': parsed.port, "proxy_port": parsed.port,
'username': username, "username": username,
'password': password, "password": password,
'rdns': rdns "rdns": rdns,
} }
connection_pool_kw['_socks_options'] = socks_options connection_pool_kw["_socks_options"] = socks_options
super(SOCKSProxyManager, self).__init__( super(SOCKSProxyManager, self).__init__(
num_pools, headers, **connection_pool_kw num_pools, headers, **connection_pool_kw

View file

@ -1,22 +1,25 @@
from __future__ import absolute_import from __future__ import absolute_import
from .packages.six.moves.http_client import (
IncompleteRead as httplib_IncompleteRead from .packages.six.moves.http_client import IncompleteRead as httplib_IncompleteRead
)
# Base Exceptions # Base Exceptions
class HTTPError(Exception): class HTTPError(Exception):
"Base exception used by this module." """Base exception used by this module."""
pass pass
class HTTPWarning(Warning): class HTTPWarning(Warning):
"Base warning used by this module." """Base warning used by this module."""
pass pass
class PoolError(HTTPError): class PoolError(HTTPError):
"Base exception for errors caused within a pool." """Base exception for errors caused within a pool."""
def __init__(self, pool, message): def __init__(self, pool, message):
self.pool = pool self.pool = pool
HTTPError.__init__(self, "%s: %s" % (pool, message)) HTTPError.__init__(self, "%s: %s" % (pool, message))
@ -27,7 +30,8 @@ class PoolError(HTTPError):
class RequestError(PoolError): class RequestError(PoolError):
"Base exception for PoolErrors that have associated URLs." """Base exception for PoolErrors that have associated URLs."""
def __init__(self, pool, url, message): def __init__(self, pool, url, message):
self.url = url self.url = url
PoolError.__init__(self, pool, message) PoolError.__init__(self, pool, message)
@ -38,22 +42,28 @@ class RequestError(PoolError):
class SSLError(HTTPError): class SSLError(HTTPError):
"Raised when SSL certificate fails in an HTTPS connection." """Raised when SSL certificate fails in an HTTPS connection."""
pass pass
class ProxyError(HTTPError): class ProxyError(HTTPError):
"Raised when the connection to a proxy fails." """Raised when the connection to a proxy fails."""
pass
def __init__(self, message, error, *args):
super(ProxyError, self).__init__(message, error, *args)
self.original_error = error
class DecodeError(HTTPError): class DecodeError(HTTPError):
"Raised when automatic decoding based on Content-Type fails." """Raised when automatic decoding based on Content-Type fails."""
pass pass
class ProtocolError(HTTPError): class ProtocolError(HTTPError):
"Raised when something unexpected happens mid-request/response." """Raised when something unexpected happens mid-request/response."""
pass pass
@ -63,6 +73,7 @@ ConnectionError = ProtocolError
# Leaf Exceptions # Leaf Exceptions
class MaxRetryError(RequestError): class MaxRetryError(RequestError):
"""Raised when the maximum number of retries is exceeded. """Raised when the maximum number of retries is exceeded.
@ -76,14 +87,13 @@ class MaxRetryError(RequestError):
def __init__(self, pool, url, reason=None): def __init__(self, pool, url, reason=None):
self.reason = reason self.reason = reason
message = "Max retries exceeded with url: %s (Caused by %r)" % ( message = "Max retries exceeded with url: %s (Caused by %r)" % (url, reason)
url, reason)
RequestError.__init__(self, pool, url, message) RequestError.__init__(self, pool, url, message)
class HostChangedError(RequestError): class HostChangedError(RequestError):
"Raised when an existing pool gets a request for a foreign host." """Raised when an existing pool gets a request for a foreign host."""
def __init__(self, pool, url, retries=3): def __init__(self, pool, url, retries=3):
message = "Tried to open a foreign host with url: %s" % url message = "Tried to open a foreign host with url: %s" % url
@ -93,6 +103,7 @@ class HostChangedError(RequestError):
class TimeoutStateError(HTTPError): class TimeoutStateError(HTTPError):
"""Raised when passing an invalid state to a timeout""" """Raised when passing an invalid state to a timeout"""
pass pass
@ -102,43 +113,50 @@ class TimeoutError(HTTPError):
Catching this error will catch both :exc:`ReadTimeoutErrors Catching this error will catch both :exc:`ReadTimeoutErrors
<ReadTimeoutError>` and :exc:`ConnectTimeoutErrors <ConnectTimeoutError>`. <ReadTimeoutError>` and :exc:`ConnectTimeoutErrors <ConnectTimeoutError>`.
""" """
pass pass
class ReadTimeoutError(TimeoutError, RequestError): class ReadTimeoutError(TimeoutError, RequestError):
"Raised when a socket timeout occurs while receiving data from a server" """Raised when a socket timeout occurs while receiving data from a server"""
pass pass
# This timeout error does not have a URL attached and needs to inherit from the # This timeout error does not have a URL attached and needs to inherit from the
# base HTTPError # base HTTPError
class ConnectTimeoutError(TimeoutError): class ConnectTimeoutError(TimeoutError):
"Raised when a socket timeout occurs while connecting to a server" """Raised when a socket timeout occurs while connecting to a server"""
pass pass
class NewConnectionError(ConnectTimeoutError, PoolError): class NewConnectionError(ConnectTimeoutError, PoolError):
"Raised when we fail to establish a new connection. Usually ECONNREFUSED." """Raised when we fail to establish a new connection. Usually ECONNREFUSED."""
pass pass
class EmptyPoolError(PoolError): class EmptyPoolError(PoolError):
"Raised when a pool runs out of connections and no more are allowed." """Raised when a pool runs out of connections and no more are allowed."""
pass pass
class ClosedPoolError(PoolError): class ClosedPoolError(PoolError):
"Raised when a request enters a pool after the pool has been closed." """Raised when a request enters a pool after the pool has been closed."""
pass pass
class LocationValueError(ValueError, HTTPError): class LocationValueError(ValueError, HTTPError):
"Raised when there is something wrong with a given URL input." """Raised when there is something wrong with a given URL input."""
pass pass
class LocationParseError(LocationValueError): class LocationParseError(LocationValueError):
"Raised when get_host or similar fails to parse the URL input." """Raised when get_host or similar fails to parse the URL input."""
def __init__(self, location): def __init__(self, location):
message = "Failed to parse: %s" % location message = "Failed to parse: %s" % location
@ -147,39 +165,56 @@ class LocationParseError(LocationValueError):
self.location = location self.location = location
class URLSchemeUnknown(LocationValueError):
"""Raised when a URL input has an unsupported scheme."""
def __init__(self, scheme):
message = "Not supported URL scheme %s" % scheme
super(URLSchemeUnknown, self).__init__(message)
self.scheme = scheme
class ResponseError(HTTPError): class ResponseError(HTTPError):
"Used as a container for an error reason supplied in a MaxRetryError." """Used as a container for an error reason supplied in a MaxRetryError."""
GENERIC_ERROR = 'too many error responses'
SPECIFIC_ERROR = 'too many {status_code} error responses' GENERIC_ERROR = "too many error responses"
SPECIFIC_ERROR = "too many {status_code} error responses"
class SecurityWarning(HTTPWarning): class SecurityWarning(HTTPWarning):
"Warned when perfoming security reducing actions" """Warned when performing security reducing actions"""
pass pass
class SubjectAltNameWarning(SecurityWarning): class SubjectAltNameWarning(SecurityWarning):
"Warned when connecting to a host with a certificate missing a SAN." """Warned when connecting to a host with a certificate missing a SAN."""
pass pass
class InsecureRequestWarning(SecurityWarning): class InsecureRequestWarning(SecurityWarning):
"Warned when making an unverified HTTPS request." """Warned when making an unverified HTTPS request."""
pass pass
class SystemTimeWarning(SecurityWarning): class SystemTimeWarning(SecurityWarning):
"Warned when system time is suspected to be wrong" """Warned when system time is suspected to be wrong"""
pass pass
class InsecurePlatformWarning(SecurityWarning): class InsecurePlatformWarning(SecurityWarning):
"Warned when certain SSL configuration is not available on a platform." """Warned when certain TLS/SSL configuration is not available on a platform."""
pass pass
class SNIMissingWarning(HTTPWarning): class SNIMissingWarning(HTTPWarning):
"Warned when making a HTTPS request without SNI available." """Warned when making a HTTPS request without SNI available."""
pass pass
@ -188,19 +223,22 @@ class DependencyWarning(HTTPWarning):
Warned when an attempt is made to import a module with missing optional Warned when an attempt is made to import a module with missing optional
dependencies. dependencies.
""" """
pass pass
class ResponseNotChunked(ProtocolError, ValueError): class ResponseNotChunked(ProtocolError, ValueError):
"Response needs to be chunked in order to read it as chunks." """Response needs to be chunked in order to read it as chunks."""
pass pass
class BodyNotHttplibCompatible(HTTPError): class BodyNotHttplibCompatible(HTTPError):
""" """
Body should be httplib.HTTPResponse like (have an fp attribute which Body should be :class:`http.client.HTTPResponse` like
returns raw chunks) for read_chunked(). (have an fp attribute which returns raw chunks) for read_chunked().
""" """
pass pass
@ -208,39 +246,78 @@ class IncompleteRead(HTTPError, httplib_IncompleteRead):
""" """
Response length doesn't match expected Content-Length Response length doesn't match expected Content-Length
Subclass of http_client.IncompleteRead to allow int value Subclass of :class:`http.client.IncompleteRead` to allow int value
for `partial` to avoid creating large objects on streamed for ``partial`` to avoid creating large objects on streamed reads.
reads.
""" """
def __init__(self, partial, expected): def __init__(self, partial, expected):
super(IncompleteRead, self).__init__(partial, expected) super(IncompleteRead, self).__init__(partial, expected)
def __repr__(self): def __repr__(self):
return ('IncompleteRead(%i bytes read, ' return "IncompleteRead(%i bytes read, %i more expected)" % (
'%i more expected)' % (self.partial, self.expected)) self.partial,
self.expected,
)
class InvalidChunkLength(HTTPError, httplib_IncompleteRead):
"""Invalid chunk length in a chunked response."""
def __init__(self, response, length):
super(InvalidChunkLength, self).__init__(
response.tell(), response.length_remaining
)
self.response = response
self.length = length
def __repr__(self):
return "InvalidChunkLength(got length %r, %i bytes read)" % (
self.length,
self.partial,
)
class InvalidHeader(HTTPError): class InvalidHeader(HTTPError):
"The header provided was somehow invalid." """The header provided was somehow invalid."""
pass pass
class ProxySchemeUnknown(AssertionError, ValueError): class ProxySchemeUnknown(AssertionError, URLSchemeUnknown):
"ProxyManager does not support the supplied scheme" """ProxyManager does not support the supplied scheme"""
# TODO(t-8ch): Stop inheriting from AssertionError in v2.0. # TODO(t-8ch): Stop inheriting from AssertionError in v2.0.
def __init__(self, scheme): def __init__(self, scheme):
message = "Not supported proxy scheme %s" % scheme # 'localhost' is here because our URL parser parses
# localhost:8080 -> scheme=localhost, remove if we fix this.
if scheme == "localhost":
scheme = None
if scheme is None:
message = "Proxy URL had no scheme, should start with http:// or https://"
else:
message = (
"Proxy URL had unsupported scheme %s, should use http:// or https://"
% scheme
)
super(ProxySchemeUnknown, self).__init__(message) super(ProxySchemeUnknown, self).__init__(message)
class ProxySchemeUnsupported(ValueError):
"""Fetching HTTPS resources through HTTPS proxies is unsupported"""
pass
class HeaderParsingError(HTTPError): class HeaderParsingError(HTTPError):
"Raised by assert_header_parsing, but we convert it to a log.warning statement." """Raised by assert_header_parsing, but we convert it to a log.warning statement."""
def __init__(self, defects, unparsed_data): def __init__(self, defects, unparsed_data):
message = '%s, unparsed data: %r' % (defects or 'Unknown', unparsed_data) message = "%s, unparsed data: %r" % (defects or "Unknown", unparsed_data)
super(HeaderParsingError, self).__init__(message) super(HeaderParsingError, self).__init__(message)
class UnrewindableBodyError(HTTPError): class UnrewindableBodyError(HTTPError):
"urllib3 encountered an error when trying to rewind a body" """urllib3 encountered an error when trying to rewind a body"""
pass pass

View file

@ -1,11 +1,13 @@
from __future__ import absolute_import from __future__ import absolute_import
import email.utils import email.utils
import mimetypes import mimetypes
import re
from .packages import six from .packages import six
def guess_content_type(filename, default='application/octet-stream'): def guess_content_type(filename, default="application/octet-stream"):
""" """
Guess the "Content-Type" of a file. Guess the "Content-Type" of a file.
@ -19,57 +21,143 @@ def guess_content_type(filename, default='application/octet-stream'):
return default return default
def format_header_param(name, value): def format_header_param_rfc2231(name, value):
""" """
Helper function to format and quote a single header parameter. Helper function to format and quote a single header parameter using the
strategy defined in RFC 2231.
Particularly useful for header parameters which might contain Particularly useful for header parameters which might contain
non-ASCII values, like file names. This follows RFC 2231, as non-ASCII values, like file names. This follows
suggested by RFC 2388 Section 4.4. `RFC 2388 Section 4.4 <https://tools.ietf.org/html/rfc2388#section-4.4>`_.
:param name: :param name:
The name of the parameter, a string expected to be ASCII only. The name of the parameter, a string expected to be ASCII only.
:param value: :param value:
The value of the parameter, provided as a unicode string. The value of the parameter, provided as ``bytes`` or `str``.
:ret:
An RFC-2231-formatted unicode string.
""" """
if isinstance(value, six.binary_type):
value = value.decode("utf-8")
if not any(ch in value for ch in '"\\\r\n'): if not any(ch in value for ch in '"\\\r\n'):
result = '%s="%s"' % (name, value) result = u'%s="%s"' % (name, value)
try: try:
result.encode('ascii') result.encode("ascii")
except (UnicodeEncodeError, UnicodeDecodeError): except (UnicodeEncodeError, UnicodeDecodeError):
pass pass
else: else:
return result return result
if not six.PY3 and isinstance(value, six.text_type): # Python 2:
value = value.encode('utf-8') if six.PY2: # Python 2:
value = email.utils.encode_rfc2231(value, 'utf-8') value = value.encode("utf-8")
value = '%s*=%s' % (name, value)
# encode_rfc2231 accepts an encoded string and returns an ascii-encoded
# string in Python 2 but accepts and returns unicode strings in Python 3
value = email.utils.encode_rfc2231(value, "utf-8")
value = "%s*=%s" % (name, value)
if six.PY2: # Python 2:
value = value.decode("utf-8")
return value return value
_HTML5_REPLACEMENTS = {
u"\u0022": u"%22",
# Replace "\" with "\\".
u"\u005C": u"\u005C\u005C",
}
# All control characters from 0x00 to 0x1F *except* 0x1B.
_HTML5_REPLACEMENTS.update(
{
six.unichr(cc): u"%{:02X}".format(cc)
for cc in range(0x00, 0x1F + 1)
if cc not in (0x1B,)
}
)
def _replace_multiple(value, needles_and_replacements):
def replacer(match):
return needles_and_replacements[match.group(0)]
pattern = re.compile(
r"|".join([re.escape(needle) for needle in needles_and_replacements.keys()])
)
result = pattern.sub(replacer, value)
return result
def format_header_param_html5(name, value):
"""
Helper function to format and quote a single header parameter using the
HTML5 strategy.
Particularly useful for header parameters which might contain
non-ASCII values, like file names. This follows the `HTML5 Working Draft
Section 4.10.22.7`_ and matches the behavior of curl and modern browsers.
.. _HTML5 Working Draft Section 4.10.22.7:
https://w3c.github.io/html/sec-forms.html#multipart-form-data
:param name:
The name of the parameter, a string expected to be ASCII only.
:param value:
The value of the parameter, provided as ``bytes`` or `str``.
:ret:
A unicode string, stripped of troublesome characters.
"""
if isinstance(value, six.binary_type):
value = value.decode("utf-8")
value = _replace_multiple(value, _HTML5_REPLACEMENTS)
return u'%s="%s"' % (name, value)
# For backwards-compatibility.
format_header_param = format_header_param_html5
class RequestField(object): class RequestField(object):
""" """
A data container for request body parameters. A data container for request body parameters.
:param name: :param name:
The name of this request field. The name of this request field. Must be unicode.
:param data: :param data:
The data/value body. The data/value body.
:param filename: :param filename:
An optional filename of the request field. An optional filename of the request field. Must be unicode.
:param headers: :param headers:
An optional dict-like object of headers to initially use for the field. An optional dict-like object of headers to initially use for the field.
:param header_formatter:
An optional callable that is used to encode and format the headers. By
default, this is :func:`format_header_param_html5`.
""" """
def __init__(self, name, data, filename=None, headers=None):
def __init__(
self,
name,
data,
filename=None,
headers=None,
header_formatter=format_header_param_html5,
):
self._name = name self._name = name
self._filename = filename self._filename = filename
self.data = data self.data = data
self.headers = {} self.headers = {}
if headers: if headers:
self.headers = dict(headers) self.headers = dict(headers)
self.header_formatter = header_formatter
@classmethod @classmethod
def from_tuples(cls, fieldname, value): def from_tuples(cls, fieldname, value, header_formatter=format_header_param_html5):
""" """
A :class:`~urllib3.fields.RequestField` factory from old-style tuple parameters. A :class:`~urllib3.fields.RequestField` factory from old-style tuple parameters.
@ -97,21 +185,25 @@ class RequestField(object):
content_type = None content_type = None
data = value data = value
request_param = cls(fieldname, data, filename=filename) request_param = cls(
fieldname, data, filename=filename, header_formatter=header_formatter
)
request_param.make_multipart(content_type=content_type) request_param.make_multipart(content_type=content_type)
return request_param return request_param
def _render_part(self, name, value): def _render_part(self, name, value):
""" """
Overridable helper function to format a single header parameter. Overridable helper function to format a single header parameter. By
default, this calls ``self.header_formatter``.
:param name: :param name:
The name of the parameter, a string expected to be ASCII only. The name of the parameter, a string expected to be ASCII only.
:param value: :param value:
The value of the parameter, provided as a unicode string. The value of the parameter, provided as a unicode string.
""" """
return format_header_param(name, value)
return self.header_formatter(name, value)
def _render_parts(self, header_parts): def _render_parts(self, header_parts):
""" """
@ -121,7 +213,7 @@ class RequestField(object):
'Content-Disposition' fields. 'Content-Disposition' fields.
:param header_parts: :param header_parts:
A sequence of (k, v) typles or a :class:`dict` of (k, v) to format A sequence of (k, v) tuples or a :class:`dict` of (k, v) to format
as `k1="v1"; k2="v2"; ...`. as `k1="v1"; k2="v2"; ...`.
""" """
parts = [] parts = []
@ -133,7 +225,7 @@ class RequestField(object):
if value is not None: if value is not None:
parts.append(self._render_part(name, value)) parts.append(self._render_part(name, value))
return '; '.join(parts) return u"; ".join(parts)
def render_headers(self): def render_headers(self):
""" """
@ -141,21 +233,22 @@ class RequestField(object):
""" """
lines = [] lines = []
sort_keys = ['Content-Disposition', 'Content-Type', 'Content-Location'] sort_keys = ["Content-Disposition", "Content-Type", "Content-Location"]
for sort_key in sort_keys: for sort_key in sort_keys:
if self.headers.get(sort_key, False): if self.headers.get(sort_key, False):
lines.append('%s: %s' % (sort_key, self.headers[sort_key])) lines.append(u"%s: %s" % (sort_key, self.headers[sort_key]))
for header_name, header_value in self.headers.items(): for header_name, header_value in self.headers.items():
if header_name not in sort_keys: if header_name not in sort_keys:
if header_value: if header_value:
lines.append('%s: %s' % (header_name, header_value)) lines.append(u"%s: %s" % (header_name, header_value))
lines.append('\r\n') lines.append(u"\r\n")
return '\r\n'.join(lines) return u"\r\n".join(lines)
def make_multipart(self, content_disposition=None, content_type=None, def make_multipart(
content_location=None): self, content_disposition=None, content_type=None, content_location=None
):
""" """
Makes this request field into a multipart request field. Makes this request field into a multipart request field.
@ -168,11 +261,14 @@ class RequestField(object):
The 'Content-Location' of the request body. The 'Content-Location' of the request body.
""" """
self.headers['Content-Disposition'] = content_disposition or 'form-data' self.headers["Content-Disposition"] = content_disposition or u"form-data"
self.headers['Content-Disposition'] += '; '.join([ self.headers["Content-Disposition"] += u"; ".join(
'', self._render_parts( [
(('name', self._name), ('filename', self._filename)) u"",
self._render_parts(
((u"name", self._name), (u"filename", self._filename))
),
]
) )
]) self.headers["Content-Type"] = content_type
self.headers['Content-Type'] = content_type self.headers["Content-Location"] = content_location
self.headers['Content-Location'] = content_location

View file

@ -1,21 +1,25 @@
from __future__ import absolute_import from __future__ import absolute_import
import codecs
from uuid import uuid4 import binascii
import codecs
import os
from io import BytesIO from io import BytesIO
from .fields import RequestField
from .packages import six from .packages import six
from .packages.six import b from .packages.six import b
from .fields import RequestField
writer = codecs.lookup('utf-8')[3] writer = codecs.lookup("utf-8")[3]
def choose_boundary(): def choose_boundary():
""" """
Our embarrassingly-simple replacement for mimetools.choose_boundary. Our embarrassingly-simple replacement for mimetools.choose_boundary.
""" """
return uuid4().hex boundary = binascii.hexlify(os.urandom(16))
if not six.PY2:
boundary = boundary.decode("ascii")
return boundary
def iter_field_objects(fields): def iter_field_objects(fields):
@ -65,14 +69,14 @@ def encode_multipart_formdata(fields, boundary=None):
:param boundary: :param boundary:
If not specified, then a random boundary will be generated using If not specified, then a random boundary will be generated using
:func:`mimetools.choose_boundary`. :func:`urllib3.filepost.choose_boundary`.
""" """
body = BytesIO() body = BytesIO()
if boundary is None: if boundary is None:
boundary = choose_boundary() boundary = choose_boundary()
for field in iter_field_objects(fields): for field in iter_field_objects(fields):
body.write(b('--%s\r\n' % (boundary))) body.write(b("--%s\r\n" % (boundary)))
writer(body).write(field.render_headers()) writer(body).write(field.render_headers())
data = field.data data = field.data
@ -85,10 +89,10 @@ def encode_multipart_formdata(fields, boundary=None):
else: else:
body.write(data) body.write(data)
body.write(b'\r\n') body.write(b"\r\n")
body.write(b('--%s--\r\n' % (boundary))) body.write(b("--%s--\r\n" % (boundary)))
content_type = str('multipart/form-data; boundary=%s' % boundary) content_type = str("multipart/form-data; boundary=%s" % boundary)
return body.getvalue(), content_type return body.getvalue(), content_type

View file

@ -2,4 +2,4 @@ from __future__ import absolute_import
from . import ssl_match_hostname from . import ssl_match_hostname
__all__ = ('ssl_match_hostname', ) __all__ = ("ssl_match_hostname",)

View file

@ -7,19 +7,17 @@ Backports the Python 3 ``socket.makefile`` method for use with anything that
wants to create a "fake" socket object. wants to create a "fake" socket object.
""" """
import io import io
from socket import SocketIO from socket import SocketIO
def backport_makefile(self, mode="r", buffering=None, encoding=None, def backport_makefile(
errors=None, newline=None): self, mode="r", buffering=None, encoding=None, errors=None, newline=None
):
""" """
Backport of ``socket.makefile`` from Python 3.5. Backport of ``socket.makefile`` from Python 3.5.
""" """
if not set(mode) <= set(["r", "w", "b"]): if not set(mode) <= {"r", "w", "b"}:
raise ValueError( raise ValueError("invalid mode %r (only r, w, b allowed)" % (mode,))
"invalid mode %r (only r, w, b allowed)" % (mode,)
)
writing = "w" in mode writing = "w" in mode
reading = "r" in mode or not writing reading = "r" in mode or not writing
assert reading or writing assert reading or writing

View file

@ -1,259 +0,0 @@
# Backport of OrderedDict() class that runs on Python 2.4, 2.5, 2.6, 2.7 and pypy.
# Passes Python2.7's test suite and incorporates all the latest updates.
# Copyright 2009 Raymond Hettinger, released under the MIT License.
# http://code.activestate.com/recipes/576693/
try:
from thread import get_ident as _get_ident
except ImportError:
from dummy_thread import get_ident as _get_ident
try:
from _abcoll import KeysView, ValuesView, ItemsView
except ImportError:
pass
class OrderedDict(dict):
'Dictionary that remembers insertion order'
# An inherited dict maps keys to values.
# The inherited dict provides __getitem__, __len__, __contains__, and get.
# The remaining methods are order-aware.
# Big-O running times for all methods are the same as for regular dictionaries.
# The internal self.__map dictionary maps keys to links in a doubly linked list.
# The circular doubly linked list starts and ends with a sentinel element.
# The sentinel element never gets deleted (this simplifies the algorithm).
# Each link is stored as a list of length three: [PREV, NEXT, KEY].
def __init__(self, *args, **kwds):
'''Initialize an ordered dictionary. Signature is the same as for
regular dictionaries, but keyword arguments are not recommended
because their insertion order is arbitrary.
'''
if len(args) > 1:
raise TypeError('expected at most 1 arguments, got %d' % len(args))
try:
self.__root
except AttributeError:
self.__root = root = [] # sentinel node
root[:] = [root, root, None]
self.__map = {}
self.__update(*args, **kwds)
def __setitem__(self, key, value, dict_setitem=dict.__setitem__):
'od.__setitem__(i, y) <==> od[i]=y'
# Setting a new item creates a new link which goes at the end of the linked
# list, and the inherited dictionary is updated with the new key/value pair.
if key not in self:
root = self.__root
last = root[0]
last[1] = root[0] = self.__map[key] = [last, root, key]
dict_setitem(self, key, value)
def __delitem__(self, key, dict_delitem=dict.__delitem__):
'od.__delitem__(y) <==> del od[y]'
# Deleting an existing item uses self.__map to find the link which is
# then removed by updating the links in the predecessor and successor nodes.
dict_delitem(self, key)
link_prev, link_next, key = self.__map.pop(key)
link_prev[1] = link_next
link_next[0] = link_prev
def __iter__(self):
'od.__iter__() <==> iter(od)'
root = self.__root
curr = root[1]
while curr is not root:
yield curr[2]
curr = curr[1]
def __reversed__(self):
'od.__reversed__() <==> reversed(od)'
root = self.__root
curr = root[0]
while curr is not root:
yield curr[2]
curr = curr[0]
def clear(self):
'od.clear() -> None. Remove all items from od.'
try:
for node in self.__map.itervalues():
del node[:]
root = self.__root
root[:] = [root, root, None]
self.__map.clear()
except AttributeError:
pass
dict.clear(self)
def popitem(self, last=True):
'''od.popitem() -> (k, v), return and remove a (key, value) pair.
Pairs are returned in LIFO order if last is true or FIFO order if false.
'''
if not self:
raise KeyError('dictionary is empty')
root = self.__root
if last:
link = root[0]
link_prev = link[0]
link_prev[1] = root
root[0] = link_prev
else:
link = root[1]
link_next = link[1]
root[1] = link_next
link_next[0] = root
key = link[2]
del self.__map[key]
value = dict.pop(self, key)
return key, value
# -- the following methods do not depend on the internal structure --
def keys(self):
'od.keys() -> list of keys in od'
return list(self)
def values(self):
'od.values() -> list of values in od'
return [self[key] for key in self]
def items(self):
'od.items() -> list of (key, value) pairs in od'
return [(key, self[key]) for key in self]
def iterkeys(self):
'od.iterkeys() -> an iterator over the keys in od'
return iter(self)
def itervalues(self):
'od.itervalues -> an iterator over the values in od'
for k in self:
yield self[k]
def iteritems(self):
'od.iteritems -> an iterator over the (key, value) items in od'
for k in self:
yield (k, self[k])
def update(*args, **kwds):
'''od.update(E, **F) -> None. Update od from dict/iterable E and F.
If E is a dict instance, does: for k in E: od[k] = E[k]
If E has a .keys() method, does: for k in E.keys(): od[k] = E[k]
Or if E is an iterable of items, does: for k, v in E: od[k] = v
In either case, this is followed by: for k, v in F.items(): od[k] = v
'''
if len(args) > 2:
raise TypeError('update() takes at most 2 positional '
'arguments (%d given)' % (len(args),))
elif not args:
raise TypeError('update() takes at least 1 argument (0 given)')
self = args[0]
# Make progressively weaker assumptions about "other"
other = ()
if len(args) == 2:
other = args[1]
if isinstance(other, dict):
for key in other:
self[key] = other[key]
elif hasattr(other, 'keys'):
for key in other.keys():
self[key] = other[key]
else:
for key, value in other:
self[key] = value
for key, value in kwds.items():
self[key] = value
__update = update # let subclasses override update without breaking __init__
__marker = object()
def pop(self, key, default=__marker):
'''od.pop(k[,d]) -> v, remove specified key and return the corresponding value.
If key is not found, d is returned if given, otherwise KeyError is raised.
'''
if key in self:
result = self[key]
del self[key]
return result
if default is self.__marker:
raise KeyError(key)
return default
def setdefault(self, key, default=None):
'od.setdefault(k[,d]) -> od.get(k,d), also set od[k]=d if k not in od'
if key in self:
return self[key]
self[key] = default
return default
def __repr__(self, _repr_running={}):
'od.__repr__() <==> repr(od)'
call_key = id(self), _get_ident()
if call_key in _repr_running:
return '...'
_repr_running[call_key] = 1
try:
if not self:
return '%s()' % (self.__class__.__name__,)
return '%s(%r)' % (self.__class__.__name__, self.items())
finally:
del _repr_running[call_key]
def __reduce__(self):
'Return state information for pickling'
items = [[k, self[k]] for k in self]
inst_dict = vars(self).copy()
for k in vars(OrderedDict()):
inst_dict.pop(k, None)
if inst_dict:
return (self.__class__, (items,), inst_dict)
return self.__class__, (items,)
def copy(self):
'od.copy() -> a shallow copy of od'
return self.__class__(self)
@classmethod
def fromkeys(cls, iterable, value=None):
'''OD.fromkeys(S[, v]) -> New ordered dictionary with keys from S
and values equal to v (which defaults to None).
'''
d = cls()
for key in iterable:
d[key] = value
return d
def __eq__(self, other):
'''od.__eq__(y) <==> od==y. Comparison to another OD is order-sensitive
while comparison to a regular mapping is order-insensitive.
'''
if isinstance(other, OrderedDict):
return len(self)==len(other) and self.items() == other.items()
return dict.__eq__(self, other)
def __ne__(self, other):
return not self == other
# -- the following methods are only used in Python 2.7 --
def viewkeys(self):
"od.viewkeys() -> a set-like object providing a view on od's keys"
return KeysView(self)
def viewvalues(self):
"od.viewvalues() -> an object providing a view on od's values"
return ValuesView(self)
def viewitems(self):
"od.viewitems() -> a set-like object providing a view on od's items"
return ItemsView(self)

View file

@ -1,6 +1,4 @@
"""Utilities for writing code that runs on Python 2 and 3""" # Copyright (c) 2010-2020 Benjamin Peterson
# Copyright (c) 2010-2015 Benjamin Peterson
# #
# Permission is hereby granted, free of charge, to any person obtaining a copy # Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal # of this software and associated documentation files (the "Software"), to deal
@ -20,6 +18,8 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE. # SOFTWARE.
"""Utilities for writing code that runs on Python 2 and 3"""
from __future__ import absolute_import from __future__ import absolute_import
import functools import functools
@ -29,7 +29,7 @@ import sys
import types import types
__author__ = "Benjamin Peterson <benjamin@python.org>" __author__ = "Benjamin Peterson <benjamin@python.org>"
__version__ = "1.10.0" __version__ = "1.16.0"
# Useful for very coarse version differentiation. # Useful for very coarse version differentiation.
@ -38,15 +38,15 @@ PY3 = sys.version_info[0] == 3
PY34 = sys.version_info[0:2] >= (3, 4) PY34 = sys.version_info[0:2] >= (3, 4)
if PY3: if PY3:
string_types = str, string_types = (str,)
integer_types = int, integer_types = (int,)
class_types = type, class_types = (type,)
text_type = str text_type = str
binary_type = bytes binary_type = bytes
MAXSIZE = sys.maxsize MAXSIZE = sys.maxsize
else: else:
string_types = basestring, string_types = (basestring,)
integer_types = (int, long) integer_types = (int, long)
class_types = (type, types.ClassType) class_types = (type, types.ClassType)
text_type = unicode text_type = unicode
@ -58,9 +58,9 @@ else:
else: else:
# It's possible to have sizeof(long) != sizeof(Py_ssize_t). # It's possible to have sizeof(long) != sizeof(Py_ssize_t).
class X(object): class X(object):
def __len__(self): def __len__(self):
return 1 << 31 return 1 << 31
try: try:
len(X()) len(X())
except OverflowError: except OverflowError:
@ -71,6 +71,11 @@ else:
MAXSIZE = int((1 << 63) - 1) MAXSIZE = int((1 << 63) - 1)
del X del X
if PY34:
from importlib.util import spec_from_loader
else:
spec_from_loader = None
def _add_doc(func, doc): def _add_doc(func, doc):
"""Add documentation to a function.""" """Add documentation to a function."""
@ -84,7 +89,6 @@ def _import_module(name):
class _LazyDescr(object): class _LazyDescr(object):
def __init__(self, name): def __init__(self, name):
self.name = name self.name = name
@ -101,7 +105,6 @@ class _LazyDescr(object):
class MovedModule(_LazyDescr): class MovedModule(_LazyDescr):
def __init__(self, name, old, new=None): def __init__(self, name, old, new=None):
super(MovedModule, self).__init__(name) super(MovedModule, self).__init__(name)
if PY3: if PY3:
@ -122,7 +125,6 @@ class MovedModule(_LazyDescr):
class _LazyModule(types.ModuleType): class _LazyModule(types.ModuleType):
def __init__(self, name): def __init__(self, name):
super(_LazyModule, self).__init__(name) super(_LazyModule, self).__init__(name)
self.__doc__ = self.__class__.__doc__ self.__doc__ = self.__class__.__doc__
@ -137,7 +139,6 @@ class _LazyModule(types.ModuleType):
class MovedAttribute(_LazyDescr): class MovedAttribute(_LazyDescr):
def __init__(self, name, old_mod, new_mod, old_attr=None, new_attr=None): def __init__(self, name, old_mod, new_mod, old_attr=None, new_attr=None):
super(MovedAttribute, self).__init__(name) super(MovedAttribute, self).__init__(name)
if PY3: if PY3:
@ -186,6 +187,11 @@ class _SixMetaPathImporter(object):
return self return self
return None return None
def find_spec(self, fullname, path, target=None):
if fullname in self.known_modules:
return spec_from_loader(fullname, self)
return None
def __get_module(self, fullname): def __get_module(self, fullname):
try: try:
return self.known_modules[fullname] return self.known_modules[fullname]
@ -221,28 +227,42 @@ class _SixMetaPathImporter(object):
Required, if is_package is implemented""" Required, if is_package is implemented"""
self.__get_module(fullname) # eventually raises ImportError self.__get_module(fullname) # eventually raises ImportError
return None return None
get_source = get_code # same as get_code get_source = get_code # same as get_code
def create_module(self, spec):
return self.load_module(spec.name)
def exec_module(self, module):
pass
_importer = _SixMetaPathImporter(__name__) _importer = _SixMetaPathImporter(__name__)
class _MovedItems(_LazyModule): class _MovedItems(_LazyModule):
"""Lazy loading of moved objects""" """Lazy loading of moved objects"""
__path__ = [] # mark as package __path__ = [] # mark as package
_moved_attributes = [ _moved_attributes = [
MovedAttribute("cStringIO", "cStringIO", "io", "StringIO"), MovedAttribute("cStringIO", "cStringIO", "io", "StringIO"),
MovedAttribute("filter", "itertools", "builtins", "ifilter", "filter"), MovedAttribute("filter", "itertools", "builtins", "ifilter", "filter"),
MovedAttribute("filterfalse", "itertools", "itertools", "ifilterfalse", "filterfalse"), MovedAttribute(
"filterfalse", "itertools", "itertools", "ifilterfalse", "filterfalse"
),
MovedAttribute("input", "__builtin__", "builtins", "raw_input", "input"), MovedAttribute("input", "__builtin__", "builtins", "raw_input", "input"),
MovedAttribute("intern", "__builtin__", "sys"), MovedAttribute("intern", "__builtin__", "sys"),
MovedAttribute("map", "itertools", "builtins", "imap", "map"), MovedAttribute("map", "itertools", "builtins", "imap", "map"),
MovedAttribute("getcwd", "os", "os", "getcwdu", "getcwd"), MovedAttribute("getcwd", "os", "os", "getcwdu", "getcwd"),
MovedAttribute("getcwdb", "os", "os", "getcwd", "getcwdb"), MovedAttribute("getcwdb", "os", "os", "getcwd", "getcwdb"),
MovedAttribute("getoutput", "commands", "subprocess"),
MovedAttribute("range", "__builtin__", "builtins", "xrange", "range"), MovedAttribute("range", "__builtin__", "builtins", "xrange", "range"),
MovedAttribute("reload_module", "__builtin__", "importlib" if PY34 else "imp", "reload"), MovedAttribute(
"reload_module", "__builtin__", "importlib" if PY34 else "imp", "reload"
),
MovedAttribute("reduce", "__builtin__", "functools"), MovedAttribute("reduce", "__builtin__", "functools"),
MovedAttribute("shlex_quote", "pipes", "shlex", "quote"), MovedAttribute("shlex_quote", "pipes", "shlex", "quote"),
MovedAttribute("StringIO", "StringIO", "io"), MovedAttribute("StringIO", "StringIO", "io"),
@ -251,21 +271,36 @@ _moved_attributes = [
MovedAttribute("UserString", "UserString", "collections"), MovedAttribute("UserString", "UserString", "collections"),
MovedAttribute("xrange", "__builtin__", "builtins", "xrange", "range"), MovedAttribute("xrange", "__builtin__", "builtins", "xrange", "range"),
MovedAttribute("zip", "itertools", "builtins", "izip", "zip"), MovedAttribute("zip", "itertools", "builtins", "izip", "zip"),
MovedAttribute("zip_longest", "itertools", "itertools", "izip_longest", "zip_longest"), MovedAttribute(
"zip_longest", "itertools", "itertools", "izip_longest", "zip_longest"
),
MovedModule("builtins", "__builtin__"), MovedModule("builtins", "__builtin__"),
MovedModule("configparser", "ConfigParser"), MovedModule("configparser", "ConfigParser"),
MovedModule(
"collections_abc",
"collections",
"collections.abc" if sys.version_info >= (3, 3) else "collections",
),
MovedModule("copyreg", "copy_reg"), MovedModule("copyreg", "copy_reg"),
MovedModule("dbm_gnu", "gdbm", "dbm.gnu"), MovedModule("dbm_gnu", "gdbm", "dbm.gnu"),
MovedModule("_dummy_thread", "dummy_thread", "_dummy_thread"), MovedModule("dbm_ndbm", "dbm", "dbm.ndbm"),
MovedModule(
"_dummy_thread",
"dummy_thread",
"_dummy_thread" if sys.version_info < (3, 9) else "_thread",
),
MovedModule("http_cookiejar", "cookielib", "http.cookiejar"), MovedModule("http_cookiejar", "cookielib", "http.cookiejar"),
MovedModule("http_cookies", "Cookie", "http.cookies"), MovedModule("http_cookies", "Cookie", "http.cookies"),
MovedModule("html_entities", "htmlentitydefs", "html.entities"), MovedModule("html_entities", "htmlentitydefs", "html.entities"),
MovedModule("html_parser", "HTMLParser", "html.parser"), MovedModule("html_parser", "HTMLParser", "html.parser"),
MovedModule("http_client", "httplib", "http.client"), MovedModule("http_client", "httplib", "http.client"),
MovedModule("email_mime_multipart", "email.MIMEMultipart", "email.mime.multipart"),
MovedModule("email_mime_nonmultipart", "email.MIMENonMultipart", "email.mime.nonmultipart"),
MovedModule("email_mime_text", "email.MIMEText", "email.mime.text"),
MovedModule("email_mime_base", "email.MIMEBase", "email.mime.base"), MovedModule("email_mime_base", "email.MIMEBase", "email.mime.base"),
MovedModule("email_mime_image", "email.MIMEImage", "email.mime.image"),
MovedModule("email_mime_multipart", "email.MIMEMultipart", "email.mime.multipart"),
MovedModule(
"email_mime_nonmultipart", "email.MIMENonMultipart", "email.mime.nonmultipart"
),
MovedModule("email_mime_text", "email.MIMEText", "email.mime.text"),
MovedModule("BaseHTTPServer", "BaseHTTPServer", "http.server"), MovedModule("BaseHTTPServer", "BaseHTTPServer", "http.server"),
MovedModule("CGIHTTPServer", "CGIHTTPServer", "http.server"), MovedModule("CGIHTTPServer", "CGIHTTPServer", "http.server"),
MovedModule("SimpleHTTPServer", "SimpleHTTPServer", "http.server"), MovedModule("SimpleHTTPServer", "SimpleHTTPServer", "http.server"),
@ -283,15 +318,12 @@ _moved_attributes = [
MovedModule("tkinter_ttk", "ttk", "tkinter.ttk"), MovedModule("tkinter_ttk", "ttk", "tkinter.ttk"),
MovedModule("tkinter_constants", "Tkconstants", "tkinter.constants"), MovedModule("tkinter_constants", "Tkconstants", "tkinter.constants"),
MovedModule("tkinter_dnd", "Tkdnd", "tkinter.dnd"), MovedModule("tkinter_dnd", "Tkdnd", "tkinter.dnd"),
MovedModule("tkinter_colorchooser", "tkColorChooser", MovedModule("tkinter_colorchooser", "tkColorChooser", "tkinter.colorchooser"),
"tkinter.colorchooser"), MovedModule("tkinter_commondialog", "tkCommonDialog", "tkinter.commondialog"),
MovedModule("tkinter_commondialog", "tkCommonDialog",
"tkinter.commondialog"),
MovedModule("tkinter_tkfiledialog", "tkFileDialog", "tkinter.filedialog"), MovedModule("tkinter_tkfiledialog", "tkFileDialog", "tkinter.filedialog"),
MovedModule("tkinter_font", "tkFont", "tkinter.font"), MovedModule("tkinter_font", "tkFont", "tkinter.font"),
MovedModule("tkinter_messagebox", "tkMessageBox", "tkinter.messagebox"), MovedModule("tkinter_messagebox", "tkMessageBox", "tkinter.messagebox"),
MovedModule("tkinter_tksimpledialog", "tkSimpleDialog", MovedModule("tkinter_tksimpledialog", "tkSimpleDialog", "tkinter.simpledialog"),
"tkinter.simpledialog"),
MovedModule("urllib_parse", __name__ + ".moves.urllib_parse", "urllib.parse"), MovedModule("urllib_parse", __name__ + ".moves.urllib_parse", "urllib.parse"),
MovedModule("urllib_error", __name__ + ".moves.urllib_error", "urllib.error"), MovedModule("urllib_error", __name__ + ".moves.urllib_error", "urllib.error"),
MovedModule("urllib", __name__ + ".moves.urllib", __name__ + ".moves.urllib"), MovedModule("urllib", __name__ + ".moves.urllib", __name__ + ".moves.urllib"),
@ -337,10 +369,14 @@ _urllib_parse_moved_attributes = [
MovedAttribute("quote_plus", "urllib", "urllib.parse"), MovedAttribute("quote_plus", "urllib", "urllib.parse"),
MovedAttribute("unquote", "urllib", "urllib.parse"), MovedAttribute("unquote", "urllib", "urllib.parse"),
MovedAttribute("unquote_plus", "urllib", "urllib.parse"), MovedAttribute("unquote_plus", "urllib", "urllib.parse"),
MovedAttribute(
"unquote_to_bytes", "urllib", "urllib.parse", "unquote", "unquote_to_bytes"
),
MovedAttribute("urlencode", "urllib", "urllib.parse"), MovedAttribute("urlencode", "urllib", "urllib.parse"),
MovedAttribute("splitquery", "urllib", "urllib.parse"), MovedAttribute("splitquery", "urllib", "urllib.parse"),
MovedAttribute("splittag", "urllib", "urllib.parse"), MovedAttribute("splittag", "urllib", "urllib.parse"),
MovedAttribute("splituser", "urllib", "urllib.parse"), MovedAttribute("splituser", "urllib", "urllib.parse"),
MovedAttribute("splitvalue", "urllib", "urllib.parse"),
MovedAttribute("uses_fragment", "urlparse", "urllib.parse"), MovedAttribute("uses_fragment", "urlparse", "urllib.parse"),
MovedAttribute("uses_netloc", "urlparse", "urllib.parse"), MovedAttribute("uses_netloc", "urlparse", "urllib.parse"),
MovedAttribute("uses_params", "urlparse", "urllib.parse"), MovedAttribute("uses_params", "urlparse", "urllib.parse"),
@ -353,8 +389,11 @@ del attr
Module_six_moves_urllib_parse._moved_attributes = _urllib_parse_moved_attributes Module_six_moves_urllib_parse._moved_attributes = _urllib_parse_moved_attributes
_importer._add_module(Module_six_moves_urllib_parse(__name__ + ".moves.urllib_parse"), _importer._add_module(
"moves.urllib_parse", "moves.urllib.parse") Module_six_moves_urllib_parse(__name__ + ".moves.urllib_parse"),
"moves.urllib_parse",
"moves.urllib.parse",
)
class Module_six_moves_urllib_error(_LazyModule): class Module_six_moves_urllib_error(_LazyModule):
@ -373,8 +412,11 @@ del attr
Module_six_moves_urllib_error._moved_attributes = _urllib_error_moved_attributes Module_six_moves_urllib_error._moved_attributes = _urllib_error_moved_attributes
_importer._add_module(Module_six_moves_urllib_error(__name__ + ".moves.urllib.error"), _importer._add_module(
"moves.urllib_error", "moves.urllib.error") Module_six_moves_urllib_error(__name__ + ".moves.urllib.error"),
"moves.urllib_error",
"moves.urllib.error",
)
class Module_six_moves_urllib_request(_LazyModule): class Module_six_moves_urllib_request(_LazyModule):
@ -416,6 +458,8 @@ _urllib_request_moved_attributes = [
MovedAttribute("URLopener", "urllib", "urllib.request"), MovedAttribute("URLopener", "urllib", "urllib.request"),
MovedAttribute("FancyURLopener", "urllib", "urllib.request"), MovedAttribute("FancyURLopener", "urllib", "urllib.request"),
MovedAttribute("proxy_bypass", "urllib", "urllib.request"), MovedAttribute("proxy_bypass", "urllib", "urllib.request"),
MovedAttribute("parse_http_list", "urllib2", "urllib.request"),
MovedAttribute("parse_keqv_list", "urllib2", "urllib.request"),
] ]
for attr in _urllib_request_moved_attributes: for attr in _urllib_request_moved_attributes:
setattr(Module_six_moves_urllib_request, attr.name, attr) setattr(Module_six_moves_urllib_request, attr.name, attr)
@ -423,8 +467,11 @@ del attr
Module_six_moves_urllib_request._moved_attributes = _urllib_request_moved_attributes Module_six_moves_urllib_request._moved_attributes = _urllib_request_moved_attributes
_importer._add_module(Module_six_moves_urllib_request(__name__ + ".moves.urllib.request"), _importer._add_module(
"moves.urllib_request", "moves.urllib.request") Module_six_moves_urllib_request(__name__ + ".moves.urllib.request"),
"moves.urllib_request",
"moves.urllib.request",
)
class Module_six_moves_urllib_response(_LazyModule): class Module_six_moves_urllib_response(_LazyModule):
@ -444,8 +491,11 @@ del attr
Module_six_moves_urllib_response._moved_attributes = _urllib_response_moved_attributes Module_six_moves_urllib_response._moved_attributes = _urllib_response_moved_attributes
_importer._add_module(Module_six_moves_urllib_response(__name__ + ".moves.urllib.response"), _importer._add_module(
"moves.urllib_response", "moves.urllib.response") Module_six_moves_urllib_response(__name__ + ".moves.urllib.response"),
"moves.urllib_response",
"moves.urllib.response",
)
class Module_six_moves_urllib_robotparser(_LazyModule): class Module_six_moves_urllib_robotparser(_LazyModule):
@ -460,15 +510,21 @@ for attr in _urllib_robotparser_moved_attributes:
setattr(Module_six_moves_urllib_robotparser, attr.name, attr) setattr(Module_six_moves_urllib_robotparser, attr.name, attr)
del attr del attr
Module_six_moves_urllib_robotparser._moved_attributes = _urllib_robotparser_moved_attributes Module_six_moves_urllib_robotparser._moved_attributes = (
_urllib_robotparser_moved_attributes
)
_importer._add_module(Module_six_moves_urllib_robotparser(__name__ + ".moves.urllib.robotparser"), _importer._add_module(
"moves.urllib_robotparser", "moves.urllib.robotparser") Module_six_moves_urllib_robotparser(__name__ + ".moves.urllib.robotparser"),
"moves.urllib_robotparser",
"moves.urllib.robotparser",
)
class Module_six_moves_urllib(types.ModuleType): class Module_six_moves_urllib(types.ModuleType):
"""Create a six.moves.urllib namespace that resembles the Python 3 namespace""" """Create a six.moves.urllib namespace that resembles the Python 3 namespace"""
__path__ = [] # mark as package __path__ = [] # mark as package
parse = _importer._get_module("moves.urllib_parse") parse = _importer._get_module("moves.urllib_parse")
error = _importer._get_module("moves.urllib_error") error = _importer._get_module("moves.urllib_error")
@ -477,10 +533,12 @@ class Module_six_moves_urllib(types.ModuleType):
robotparser = _importer._get_module("moves.urllib_robotparser") robotparser = _importer._get_module("moves.urllib_robotparser")
def __dir__(self): def __dir__(self):
return ['parse', 'error', 'request', 'response', 'robotparser'] return ["parse", "error", "request", "response", "robotparser"]
_importer._add_module(Module_six_moves_urllib(__name__ + ".moves.urllib"),
"moves.urllib") _importer._add_module(
Module_six_moves_urllib(__name__ + ".moves.urllib"), "moves.urllib"
)
def add_move(move): def add_move(move):
@ -520,19 +578,24 @@ else:
try: try:
advance_iterator = next advance_iterator = next
except NameError: except NameError:
def advance_iterator(it): def advance_iterator(it):
return it.next() return it.next()
next = advance_iterator next = advance_iterator
try: try:
callable = callable callable = callable
except NameError: except NameError:
def callable(obj): def callable(obj):
return any("__call__" in klass.__dict__ for klass in type(obj).__mro__) return any("__call__" in klass.__dict__ for klass in type(obj).__mro__)
if PY3: if PY3:
def get_unbound_function(unbound): def get_unbound_function(unbound):
return unbound return unbound
@ -543,6 +606,7 @@ if PY3:
Iterator = object Iterator = object
else: else:
def get_unbound_function(unbound): def get_unbound_function(unbound):
return unbound.im_func return unbound.im_func
@ -553,13 +617,13 @@ else:
return types.MethodType(func, None, cls) return types.MethodType(func, None, cls)
class Iterator(object): class Iterator(object):
def next(self): def next(self):
return type(self).__next__(self) return type(self).__next__(self)
callable = callable callable = callable
_add_doc(get_unbound_function, _add_doc(
"""Get the function out of a possibly unbound function""") get_unbound_function, """Get the function out of a possibly unbound function"""
)
get_method_function = operator.attrgetter(_meth_func) get_method_function = operator.attrgetter(_meth_func)
@ -571,6 +635,7 @@ get_function_globals = operator.attrgetter(_func_globals)
if PY3: if PY3:
def iterkeys(d, **kw): def iterkeys(d, **kw):
return iter(d.keys(**kw)) return iter(d.keys(**kw))
@ -589,6 +654,7 @@ if PY3:
viewitems = operator.methodcaller("items") viewitems = operator.methodcaller("items")
else: else:
def iterkeys(d, **kw): def iterkeys(d, **kw):
return d.iterkeys(**kw) return d.iterkeys(**kw)
@ -609,42 +675,52 @@ else:
_add_doc(iterkeys, "Return an iterator over the keys of a dictionary.") _add_doc(iterkeys, "Return an iterator over the keys of a dictionary.")
_add_doc(itervalues, "Return an iterator over the values of a dictionary.") _add_doc(itervalues, "Return an iterator over the values of a dictionary.")
_add_doc(iteritems, _add_doc(iteritems, "Return an iterator over the (key, value) pairs of a dictionary.")
"Return an iterator over the (key, value) pairs of a dictionary.") _add_doc(
_add_doc(iterlists, iterlists, "Return an iterator over the (key, [values]) pairs of a dictionary."
"Return an iterator over the (key, [values]) pairs of a dictionary.") )
if PY3: if PY3:
def b(s): def b(s):
return s.encode("latin-1") return s.encode("latin-1")
def u(s): def u(s):
return s return s
unichr = chr unichr = chr
import struct import struct
int2byte = struct.Struct(">B").pack int2byte = struct.Struct(">B").pack
del struct del struct
byte2int = operator.itemgetter(0) byte2int = operator.itemgetter(0)
indexbytes = operator.getitem indexbytes = operator.getitem
iterbytes = iter iterbytes = iter
import io import io
StringIO = io.StringIO StringIO = io.StringIO
BytesIO = io.BytesIO BytesIO = io.BytesIO
del io
_assertCountEqual = "assertCountEqual" _assertCountEqual = "assertCountEqual"
if sys.version_info[1] <= 1: if sys.version_info[1] <= 1:
_assertRaisesRegex = "assertRaisesRegexp" _assertRaisesRegex = "assertRaisesRegexp"
_assertRegex = "assertRegexpMatches" _assertRegex = "assertRegexpMatches"
_assertNotRegex = "assertNotRegexpMatches"
else: else:
_assertRaisesRegex = "assertRaisesRegex" _assertRaisesRegex = "assertRaisesRegex"
_assertRegex = "assertRegex" _assertRegex = "assertRegex"
_assertNotRegex = "assertNotRegex"
else: else:
def b(s): def b(s):
return s return s
# Workaround for standalone backslash # Workaround for standalone backslash
def u(s): def u(s):
return unicode(s.replace(r'\\', r'\\\\'), "unicode_escape") return unicode(s.replace(r"\\", r"\\\\"), "unicode_escape")
unichr = unichr unichr = unichr
int2byte = chr int2byte = chr
@ -653,12 +729,15 @@ else:
def indexbytes(buf, i): def indexbytes(buf, i):
return ord(buf[i]) return ord(buf[i])
iterbytes = functools.partial(itertools.imap, ord) iterbytes = functools.partial(itertools.imap, ord)
import StringIO import StringIO
StringIO = BytesIO = StringIO.StringIO StringIO = BytesIO = StringIO.StringIO
_assertCountEqual = "assertItemsEqual" _assertCountEqual = "assertItemsEqual"
_assertRaisesRegex = "assertRaisesRegexp" _assertRaisesRegex = "assertRaisesRegexp"
_assertRegex = "assertRegexpMatches" _assertRegex = "assertRegexpMatches"
_assertNotRegex = "assertNotRegexpMatches"
_add_doc(b, """Byte literal""") _add_doc(b, """Byte literal""")
_add_doc(u, """Text literal""") _add_doc(u, """Text literal""")
@ -675,17 +754,27 @@ def assertRegex(self, *args, **kwargs):
return getattr(self, _assertRegex)(*args, **kwargs) return getattr(self, _assertRegex)(*args, **kwargs)
def assertNotRegex(self, *args, **kwargs):
return getattr(self, _assertNotRegex)(*args, **kwargs)
if PY3: if PY3:
exec_ = getattr(moves.builtins, "exec") exec_ = getattr(moves.builtins, "exec")
def reraise(tp, value, tb=None): def reraise(tp, value, tb=None):
try:
if value is None: if value is None:
value = tp() value = tp()
if value.__traceback__ is not tb: if value.__traceback__ is not tb:
raise value.with_traceback(tb) raise value.with_traceback(tb)
raise value raise value
finally:
value = None
tb = None
else: else:
def exec_(_code_, _globs_=None, _locs_=None): def exec_(_code_, _globs_=None, _locs_=None):
"""Execute code in a namespace.""" """Execute code in a namespace."""
if _globs_ is None: if _globs_ is None:
@ -698,28 +787,34 @@ else:
_locs_ = _globs_ _locs_ = _globs_
exec ("""exec _code_ in _globs_, _locs_""") exec ("""exec _code_ in _globs_, _locs_""")
exec_("""def reraise(tp, value, tb=None): exec_(
"""def reraise(tp, value, tb=None):
try:
raise tp, value, tb raise tp, value, tb
""") finally:
tb = None
"""
)
if sys.version_info[:2] == (3, 2): if sys.version_info[:2] > (3,):
exec_("""def raise_from(value, from_value): exec_(
if from_value is None: """def raise_from(value, from_value):
raise value try:
raise value from from_value raise value from from_value
""") finally:
elif sys.version_info[:2] > (3, 2): value = None
exec_("""def raise_from(value, from_value): """
raise value from from_value )
""")
else: else:
def raise_from(value, from_value): def raise_from(value, from_value):
raise value raise value
print_ = getattr(moves.builtins, "print", None) print_ = getattr(moves.builtins, "print", None)
if print_ is None: if print_ is None:
def print_(*args, **kwargs): def print_(*args, **kwargs):
"""The new-style print function for Python 2.4 and 2.5.""" """The new-style print function for Python 2.4 and 2.5."""
fp = kwargs.pop("file", sys.stdout) fp = kwargs.pop("file", sys.stdout)
@ -730,14 +825,17 @@ if print_ is None:
if not isinstance(data, basestring): if not isinstance(data, basestring):
data = str(data) data = str(data)
# If the file has an encoding, encode unicode with it. # If the file has an encoding, encode unicode with it.
if (isinstance(fp, file) and if (
isinstance(data, unicode) and isinstance(fp, file)
fp.encoding is not None): and isinstance(data, unicode)
and fp.encoding is not None
):
errors = getattr(fp, "errors", None) errors = getattr(fp, "errors", None)
if errors is None: if errors is None:
errors = "strict" errors = "strict"
data = data.encode(fp.encoding, errors) data = data.encode(fp.encoding, errors)
fp.write(data) fp.write(data)
want_unicode = False want_unicode = False
sep = kwargs.pop("sep", None) sep = kwargs.pop("sep", None)
if sep is not None: if sep is not None:
@ -773,6 +871,8 @@ if print_ is None:
write(sep) write(sep)
write(arg) write(arg)
write(end) write(end)
if sys.version_info[:2] < (3, 3): if sys.version_info[:2] < (3, 3):
_print = print_ _print = print_
@ -783,16 +883,46 @@ if sys.version_info[:2] < (3, 3):
if flush and fp is not None: if flush and fp is not None:
fp.flush() fp.flush()
_add_doc(reraise, """Reraise an exception.""") _add_doc(reraise, """Reraise an exception.""")
if sys.version_info[0:2] < (3, 4): if sys.version_info[0:2] < (3, 4):
def wraps(wrapped, assigned=functools.WRAPPER_ASSIGNMENTS, # This does exactly the same what the :func:`py3:functools.update_wrapper`
updated=functools.WRAPPER_UPDATES): # function does on Python versions after 3.2. It sets the ``__wrapped__``
def wrapper(f): # attribute on ``wrapper`` object and it doesn't raise an error if any of
f = functools.wraps(wrapped, assigned, updated)(f) # the attributes mentioned in ``assigned`` and ``updated`` are missing on
f.__wrapped__ = wrapped # ``wrapped`` object.
return f def _update_wrapper(
wrapper,
wrapped,
assigned=functools.WRAPPER_ASSIGNMENTS,
updated=functools.WRAPPER_UPDATES,
):
for attr in assigned:
try:
value = getattr(wrapped, attr)
except AttributeError:
continue
else:
setattr(wrapper, attr, value)
for attr in updated:
getattr(wrapper, attr).update(getattr(wrapped, attr, {}))
wrapper.__wrapped__ = wrapped
return wrapper return wrapper
_update_wrapper.__doc__ = functools.update_wrapper.__doc__
def wraps(
wrapped,
assigned=functools.WRAPPER_ASSIGNMENTS,
updated=functools.WRAPPER_UPDATES,
):
return functools.partial(
_update_wrapper, wrapped=wrapped, assigned=assigned, updated=updated
)
wraps.__doc__ = functools.wraps.__doc__
else: else:
wraps = functools.wraps wraps = functools.wraps
@ -802,44 +932,121 @@ def with_metaclass(meta, *bases):
# This requires a bit of explanation: the basic idea is to make a dummy # This requires a bit of explanation: the basic idea is to make a dummy
# metaclass for one level of class instantiation that replaces itself with # metaclass for one level of class instantiation that replaces itself with
# the actual metaclass. # the actual metaclass.
class metaclass(meta): class metaclass(type):
def __new__(cls, name, this_bases, d): def __new__(cls, name, this_bases, d):
return meta(name, bases, d) if sys.version_info[:2] >= (3, 7):
return type.__new__(metaclass, 'temporary_class', (), {}) # This version introduced PEP 560 that requires a bit
# of extra care (we mimic what is done by __build_class__).
resolved_bases = types.resolve_bases(bases)
if resolved_bases is not bases:
d["__orig_bases__"] = bases
else:
resolved_bases = bases
return meta(name, resolved_bases, d)
@classmethod
def __prepare__(cls, name, this_bases):
return meta.__prepare__(name, bases)
return type.__new__(metaclass, "temporary_class", (), {})
def add_metaclass(metaclass): def add_metaclass(metaclass):
"""Class decorator for creating a class with a metaclass.""" """Class decorator for creating a class with a metaclass."""
def wrapper(cls): def wrapper(cls):
orig_vars = cls.__dict__.copy() orig_vars = cls.__dict__.copy()
slots = orig_vars.get('__slots__') slots = orig_vars.get("__slots__")
if slots is not None: if slots is not None:
if isinstance(slots, str): if isinstance(slots, str):
slots = [slots] slots = [slots]
for slots_var in slots: for slots_var in slots:
orig_vars.pop(slots_var) orig_vars.pop(slots_var)
orig_vars.pop('__dict__', None) orig_vars.pop("__dict__", None)
orig_vars.pop('__weakref__', None) orig_vars.pop("__weakref__", None)
if hasattr(cls, "__qualname__"):
orig_vars["__qualname__"] = cls.__qualname__
return metaclass(cls.__name__, cls.__bases__, orig_vars) return metaclass(cls.__name__, cls.__bases__, orig_vars)
return wrapper return wrapper
def ensure_binary(s, encoding="utf-8", errors="strict"):
"""Coerce **s** to six.binary_type.
For Python 2:
- `unicode` -> encoded to `str`
- `str` -> `str`
For Python 3:
- `str` -> encoded to `bytes`
- `bytes` -> `bytes`
"""
if isinstance(s, binary_type):
return s
if isinstance(s, text_type):
return s.encode(encoding, errors)
raise TypeError("not expecting type '%s'" % type(s))
def ensure_str(s, encoding="utf-8", errors="strict"):
"""Coerce *s* to `str`.
For Python 2:
- `unicode` -> encoded to `str`
- `str` -> `str`
For Python 3:
- `str` -> `str`
- `bytes` -> decoded to `str`
"""
# Optimization: Fast return for the common case.
if type(s) is str:
return s
if PY2 and isinstance(s, text_type):
return s.encode(encoding, errors)
elif PY3 and isinstance(s, binary_type):
return s.decode(encoding, errors)
elif not isinstance(s, (text_type, binary_type)):
raise TypeError("not expecting type '%s'" % type(s))
return s
def ensure_text(s, encoding="utf-8", errors="strict"):
"""Coerce *s* to six.text_type.
For Python 2:
- `unicode` -> `unicode`
- `str` -> `unicode`
For Python 3:
- `str` -> `str`
- `bytes` -> decoded to `str`
"""
if isinstance(s, binary_type):
return s.decode(encoding, errors)
elif isinstance(s, text_type):
return s
else:
raise TypeError("not expecting type '%s'" % type(s))
def python_2_unicode_compatible(klass): def python_2_unicode_compatible(klass):
""" """
A decorator that defines __unicode__ and __str__ methods under Python 2. A class decorator that defines __unicode__ and __str__ methods under Python 2.
Under Python 3 it does nothing. Under Python 3 it does nothing.
To support Python 2 and 3 with a single code base, define a __str__ method To support Python 2 and 3 with a single code base, define a __str__ method
returning text and apply this decorator to the class. returning text and apply this decorator to the class.
""" """
if PY2: if PY2:
if '__str__' not in klass.__dict__: if "__str__" not in klass.__dict__:
raise ValueError("@python_2_unicode_compatible cannot be applied " raise ValueError(
"to %s because it doesn't define __str__()." % "@python_2_unicode_compatible cannot be applied "
klass.__name__) "to %s because it doesn't define __str__()." % klass.__name__
)
klass.__unicode__ = klass.__str__ klass.__unicode__ = klass.__str__
klass.__str__ = lambda self: self.__unicode__().encode('utf-8') klass.__str__ = lambda self: self.__unicode__().encode("utf-8")
return klass return klass
@ -859,8 +1066,10 @@ if sys.meta_path:
# be floating around. Therefore, we can't use isinstance() to check for # be floating around. Therefore, we can't use isinstance() to check for
# the six meta path importer, since the other six instance will have # the six meta path importer, since the other six instance will have
# inserted an importer with different class. # inserted an importer with different class.
if (type(importer).__name__ == "_SixMetaPathImporter" and if (
importer.name == __name__): type(importer).__name__ == "_SixMetaPathImporter"
and importer.name == __name__
):
del sys.meta_path[i] del sys.meta_path[i]
break break
del i, importer del i, importer

View file

@ -1,19 +1,24 @@
import sys import sys
try: try:
# Our match_hostname function is the same as 3.5's, so we only want to # Our match_hostname function is the same as 3.10's, so we only want to
# import the match_hostname function if it's at least that good. # import the match_hostname function if it's at least that good.
if sys.version_info < (3, 5): # We also fallback on Python 3.10+ because our code doesn't emit
# deprecation warnings and is the same as Python 3.10 otherwise.
if sys.version_info < (3, 5) or sys.version_info >= (3, 10):
raise ImportError("Fallback to vendored code") raise ImportError("Fallback to vendored code")
from ssl import CertificateError, match_hostname from ssl import CertificateError, match_hostname
except ImportError: except ImportError:
try: try:
# Backport of the function from a pypi module # Backport of the function from a pypi module
from backports.ssl_match_hostname import CertificateError, match_hostname from backports.ssl_match_hostname import ( # type: ignore
CertificateError,
match_hostname,
)
except ImportError: except ImportError:
# Our vendored copy # Our vendored copy
from ._implementation import CertificateError, match_hostname from ._implementation import CertificateError, match_hostname # type: ignore
# Not needed, but documenting what we provide. # Not needed, but documenting what we provide.
__all__ = ('CertificateError', 'match_hostname') __all__ = ("CertificateError", "match_hostname")

View file

@ -9,14 +9,13 @@ import sys
# ipaddress has been backported to 2.6+ in pypi. If it is installed on the # ipaddress has been backported to 2.6+ in pypi. If it is installed on the
# system, use it to handle IPAddress ServerAltnames (this was added in # system, use it to handle IPAddress ServerAltnames (this was added in
# python-3.5) otherwise only do DNS matching. This allows # python-3.5) otherwise only do DNS matching. This allows
# backports.ssl_match_hostname to continue to be used all the way back to # backports.ssl_match_hostname to continue to be used in Python 2.7.
# python-2.4.
try: try:
import ipaddress import ipaddress
except ImportError: except ImportError:
ipaddress = None ipaddress = None
__version__ = '3.5.0.1' __version__ = "3.5.0.1"
class CertificateError(ValueError): class CertificateError(ValueError):
@ -34,18 +33,19 @@ def _dnsname_match(dn, hostname, max_wildcards=1):
# Ported from python3-syntax: # Ported from python3-syntax:
# leftmost, *remainder = dn.split(r'.') # leftmost, *remainder = dn.split(r'.')
parts = dn.split(r'.') parts = dn.split(r".")
leftmost = parts[0] leftmost = parts[0]
remainder = parts[1:] remainder = parts[1:]
wildcards = leftmost.count('*') wildcards = leftmost.count("*")
if wildcards > max_wildcards: if wildcards > max_wildcards:
# Issue #17980: avoid denials of service by refusing more # Issue #17980: avoid denials of service by refusing more
# than one wildcard per fragment. A survey of established # than one wildcard per fragment. A survey of established
# policy among SSL implementations showed it to be a # policy among SSL implementations showed it to be a
# reasonable choice. # reasonable choice.
raise CertificateError( raise CertificateError(
"too many wildcards in certificate DNS name: " + repr(dn)) "too many wildcards in certificate DNS name: " + repr(dn)
)
# speed up common case w/o wildcards # speed up common case w/o wildcards
if not wildcards: if not wildcards:
@ -54,11 +54,11 @@ def _dnsname_match(dn, hostname, max_wildcards=1):
# RFC 6125, section 6.4.3, subitem 1. # RFC 6125, section 6.4.3, subitem 1.
# The client SHOULD NOT attempt to match a presented identifier in which # The client SHOULD NOT attempt to match a presented identifier in which
# the wildcard character comprises a label other than the left-most label. # the wildcard character comprises a label other than the left-most label.
if leftmost == '*': if leftmost == "*":
# When '*' is a fragment by itself, it matches a non-empty dotless # When '*' is a fragment by itself, it matches a non-empty dotless
# fragment. # fragment.
pats.append('[^.]+') pats.append("[^.]+")
elif leftmost.startswith('xn--') or hostname.startswith('xn--'): elif leftmost.startswith("xn--") or hostname.startswith("xn--"):
# RFC 6125, section 6.4.3, subitem 3. # RFC 6125, section 6.4.3, subitem 3.
# The client SHOULD NOT attempt to match a presented identifier # The client SHOULD NOT attempt to match a presented identifier
# where the wildcard character is embedded within an A-label or # where the wildcard character is embedded within an A-label or
@ -66,21 +66,22 @@ def _dnsname_match(dn, hostname, max_wildcards=1):
pats.append(re.escape(leftmost)) pats.append(re.escape(leftmost))
else: else:
# Otherwise, '*' matches any dotless string, e.g. www* # Otherwise, '*' matches any dotless string, e.g. www*
pats.append(re.escape(leftmost).replace(r'\*', '[^.]*')) pats.append(re.escape(leftmost).replace(r"\*", "[^.]*"))
# add the remaining fragments, ignore any wildcards # add the remaining fragments, ignore any wildcards
for frag in remainder: for frag in remainder:
pats.append(re.escape(frag)) pats.append(re.escape(frag))
pat = re.compile(r'\A' + r'\.'.join(pats) + r'\Z', re.IGNORECASE) pat = re.compile(r"\A" + r"\.".join(pats) + r"\Z", re.IGNORECASE)
return pat.match(hostname) return pat.match(hostname)
def _to_unicode(obj): def _to_unicode(obj):
if isinstance(obj, str) and sys.version_info < (3,): if isinstance(obj, str) and sys.version_info < (3,):
obj = unicode(obj, encoding='ascii', errors='strict') obj = unicode(obj, encoding="ascii", errors="strict")
return obj return obj
def _ipaddress_match(ipname, host_ip): def _ipaddress_match(ipname, host_ip):
"""Exact matching of IP addresses. """Exact matching of IP addresses.
@ -102,9 +103,11 @@ def match_hostname(cert, hostname):
returns nothing. returns nothing.
""" """
if not cert: if not cert:
raise ValueError("empty or no certificate, match_hostname needs a " raise ValueError(
"empty or no certificate, match_hostname needs a "
"SSL socket or SSL context with either " "SSL socket or SSL context with either "
"CERT_OPTIONAL or CERT_REQUIRED") "CERT_OPTIONAL or CERT_REQUIRED"
)
try: try:
# Divergence from upstream: ipaddress can't handle byte str # Divergence from upstream: ipaddress can't handle byte str
host_ip = ipaddress.ip_address(_to_unicode(hostname)) host_ip = ipaddress.ip_address(_to_unicode(hostname))
@ -123,35 +126,35 @@ def match_hostname(cert, hostname):
else: else:
raise raise
dnsnames = [] dnsnames = []
san = cert.get('subjectAltName', ()) san = cert.get("subjectAltName", ())
for key, value in san: for key, value in san:
if key == 'DNS': if key == "DNS":
if host_ip is None and _dnsname_match(value, hostname): if host_ip is None and _dnsname_match(value, hostname):
return return
dnsnames.append(value) dnsnames.append(value)
elif key == 'IP Address': elif key == "IP Address":
if host_ip is not None and _ipaddress_match(value, host_ip): if host_ip is not None and _ipaddress_match(value, host_ip):
return return
dnsnames.append(value) dnsnames.append(value)
if not dnsnames: if not dnsnames:
# The subject is only checked when there is no dNSName entry # The subject is only checked when there is no dNSName entry
# in subjectAltName # in subjectAltName
for sub in cert.get('subject', ()): for sub in cert.get("subject", ()):
for key, value in sub: for key, value in sub:
# XXX according to RFC 2818, the most specific Common Name # XXX according to RFC 2818, the most specific Common Name
# must be used. # must be used.
if key == 'commonName': if key == "commonName":
if _dnsname_match(value, hostname): if _dnsname_match(value, hostname):
return return
dnsnames.append(value) dnsnames.append(value)
if len(dnsnames) > 1: if len(dnsnames) > 1:
raise CertificateError("hostname %r " raise CertificateError(
"doesn't match either of %s" "hostname %r "
% (hostname, ', '.join(map(repr, dnsnames)))) "doesn't match either of %s" % (hostname, ", ".join(map(repr, dnsnames)))
)
elif len(dnsnames) == 1: elif len(dnsnames) == 1:
raise CertificateError("hostname %r " raise CertificateError("hostname %r doesn't match %r" % (hostname, dnsnames[0]))
"doesn't match %r"
% (hostname, dnsnames[0]))
else: else:
raise CertificateError("no appropriate commonName or " raise CertificateError(
"subjectAltName fields were found") "no appropriate commonName or subjectAltName fields were found"
)

View file

@ -1,57 +1,78 @@
from __future__ import absolute_import from __future__ import absolute_import
import collections import collections
import functools import functools
import logging import logging
from ._collections import RecentlyUsedContainer from ._collections import RecentlyUsedContainer
from .connectionpool import HTTPConnectionPool, HTTPSConnectionPool from .connectionpool import HTTPConnectionPool, HTTPSConnectionPool, port_by_scheme
from .connectionpool import port_by_scheme from .exceptions import (
from .exceptions import LocationValueError, MaxRetryError, ProxySchemeUnknown LocationValueError,
MaxRetryError,
ProxySchemeUnknown,
ProxySchemeUnsupported,
URLSchemeUnknown,
)
from .packages import six
from .packages.six.moves.urllib.parse import urljoin from .packages.six.moves.urllib.parse import urljoin
from .request import RequestMethods from .request import RequestMethods
from .util.url import parse_url from .util.proxy import connection_requires_http_tunnel
from .util.retry import Retry from .util.retry import Retry
from .util.url import parse_url
__all__ = ["PoolManager", "ProxyManager", "proxy_from_url"]
__all__ = ['PoolManager', 'ProxyManager', 'proxy_from_url']
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
SSL_KEYWORDS = ('key_file', 'cert_file', 'cert_reqs', 'ca_certs', SSL_KEYWORDS = (
'ssl_version', 'ca_cert_dir', 'ssl_context') "key_file",
"cert_file",
"cert_reqs",
"ca_certs",
"ssl_version",
"ca_cert_dir",
"ssl_context",
"key_password",
)
# All known keyword arguments that could be provided to the pool manager, its # All known keyword arguments that could be provided to the pool manager, its
# pools, or the underlying connections. This is used to construct a pool key. # pools, or the underlying connections. This is used to construct a pool key.
_key_fields = ( _key_fields = (
'key_scheme', # str "key_scheme", # str
'key_host', # str "key_host", # str
'key_port', # int "key_port", # int
'key_timeout', # int or float or Timeout "key_timeout", # int or float or Timeout
'key_retries', # int or Retry "key_retries", # int or Retry
'key_strict', # bool "key_strict", # bool
'key_block', # bool "key_block", # bool
'key_source_address', # str "key_source_address", # str
'key_key_file', # str "key_key_file", # str
'key_cert_file', # str "key_key_password", # str
'key_cert_reqs', # str "key_cert_file", # str
'key_ca_certs', # str "key_cert_reqs", # str
'key_ssl_version', # str "key_ca_certs", # str
'key_ca_cert_dir', # str "key_ssl_version", # str
'key_ssl_context', # instance of ssl.SSLContext or urllib3.util.ssl_.SSLContext "key_ca_cert_dir", # str
'key_maxsize', # int "key_ssl_context", # instance of ssl.SSLContext or urllib3.util.ssl_.SSLContext
'key_headers', # dict "key_maxsize", # int
'key__proxy', # parsed proxy url "key_headers", # dict
'key__proxy_headers', # dict "key__proxy", # parsed proxy url
'key_socket_options', # list of (level (int), optname (int), value (int or str)) tuples "key__proxy_headers", # dict
'key__socks_options', # dict "key__proxy_config", # class
'key_assert_hostname', # bool or string "key_socket_options", # list of (level (int), optname (int), value (int or str)) tuples
'key_assert_fingerprint', # str "key__socks_options", # dict
"key_assert_hostname", # bool or string
"key_assert_fingerprint", # str
"key_server_hostname", # str
) )
#: The namedtuple class used to construct keys for the connection pool. #: The namedtuple class used to construct keys for the connection pool.
#: All custom key schemes should include the fields in this key at a minimum. #: All custom key schemes should include the fields in this key at a minimum.
PoolKey = collections.namedtuple('PoolKey', _key_fields) PoolKey = collections.namedtuple("PoolKey", _key_fields)
_proxy_config_fields = ("ssl_context", "use_forwarding_for_https")
ProxyConfig = collections.namedtuple("ProxyConfig", _proxy_config_fields)
def _default_key_normalizer(key_class, request_context): def _default_key_normalizer(key_class, request_context):
@ -76,24 +97,24 @@ def _default_key_normalizer(key_class, request_context):
""" """
# Since we mutate the dictionary, make a copy first # Since we mutate the dictionary, make a copy first
context = request_context.copy() context = request_context.copy()
context['scheme'] = context['scheme'].lower() context["scheme"] = context["scheme"].lower()
context['host'] = context['host'].lower() context["host"] = context["host"].lower()
# These are both dictionaries and need to be transformed into frozensets # These are both dictionaries and need to be transformed into frozensets
for key in ('headers', '_proxy_headers', '_socks_options'): for key in ("headers", "_proxy_headers", "_socks_options"):
if key in context and context[key] is not None: if key in context and context[key] is not None:
context[key] = frozenset(context[key].items()) context[key] = frozenset(context[key].items())
# The socket_options key may be a list and needs to be transformed into a # The socket_options key may be a list and needs to be transformed into a
# tuple. # tuple.
socket_opts = context.get('socket_options') socket_opts = context.get("socket_options")
if socket_opts is not None: if socket_opts is not None:
context['socket_options'] = tuple(socket_opts) context["socket_options"] = tuple(socket_opts)
# Map the kwargs to the names in the namedtuple - this is necessary since # Map the kwargs to the names in the namedtuple - this is necessary since
# namedtuples can't have fields starting with '_'. # namedtuples can't have fields starting with '_'.
for key in list(context.keys()): for key in list(context.keys()):
context['key_' + key] = context.pop(key) context["key_" + key] = context.pop(key)
# Default to ``None`` for keys missing from the context # Default to ``None`` for keys missing from the context
for field in key_class._fields: for field in key_class._fields:
@ -108,14 +129,11 @@ def _default_key_normalizer(key_class, request_context):
#: Each PoolManager makes a copy of this dictionary so they can be configured #: Each PoolManager makes a copy of this dictionary so they can be configured
#: globally here, or individually on the instance. #: globally here, or individually on the instance.
key_fn_by_scheme = { key_fn_by_scheme = {
'http': functools.partial(_default_key_normalizer, PoolKey), "http": functools.partial(_default_key_normalizer, PoolKey),
'https': functools.partial(_default_key_normalizer, PoolKey), "https": functools.partial(_default_key_normalizer, PoolKey),
} }
pool_classes_by_scheme = { pool_classes_by_scheme = {"http": HTTPConnectionPool, "https": HTTPSConnectionPool}
'http': HTTPConnectionPool,
'https': HTTPSConnectionPool,
}
class PoolManager(RequestMethods): class PoolManager(RequestMethods):
@ -147,12 +165,12 @@ class PoolManager(RequestMethods):
""" """
proxy = None proxy = None
proxy_config = None
def __init__(self, num_pools=10, headers=None, **connection_pool_kw): def __init__(self, num_pools=10, headers=None, **connection_pool_kw):
RequestMethods.__init__(self, headers) RequestMethods.__init__(self, headers)
self.connection_pool_kw = connection_pool_kw self.connection_pool_kw = connection_pool_kw
self.pools = RecentlyUsedContainer(num_pools, self.pools = RecentlyUsedContainer(num_pools, dispose_func=lambda p: p.close())
dispose_func=lambda p: p.close())
# Locally set the pool classes and keys so other PoolManagers can # Locally set the pool classes and keys so other PoolManagers can
# override them. # override them.
@ -169,7 +187,7 @@ class PoolManager(RequestMethods):
def _new_pool(self, scheme, host, port, request_context=None): def _new_pool(self, scheme, host, port, request_context=None):
""" """
Create a new :class:`ConnectionPool` based on host, port, scheme, and Create a new :class:`urllib3.connectionpool.ConnectionPool` based on host, port, scheme, and
any additional pool keyword arguments. any additional pool keyword arguments.
If ``request_context`` is provided, it is provided as keyword arguments If ``request_context`` is provided, it is provided as keyword arguments
@ -185,10 +203,10 @@ class PoolManager(RequestMethods):
# this function has historically only used the scheme, host, and port # this function has historically only used the scheme, host, and port
# in the positional args. When an API change is acceptable these can # in the positional args. When an API change is acceptable these can
# be removed. # be removed.
for key in ('scheme', 'host', 'port'): for key in ("scheme", "host", "port"):
request_context.pop(key, None) request_context.pop(key, None)
if scheme == 'http': if scheme == "http":
for kw in SSL_KEYWORDS: for kw in SSL_KEYWORDS:
request_context.pop(kw, None) request_context.pop(kw, None)
@ -203,9 +221,9 @@ class PoolManager(RequestMethods):
""" """
self.pools.clear() self.pools.clear()
def connection_from_host(self, host, port=None, scheme='http', pool_kwargs=None): def connection_from_host(self, host, port=None, scheme="http", pool_kwargs=None):
""" """
Get a :class:`ConnectionPool` based on the host, port, and scheme. Get a :class:`urllib3.connectionpool.ConnectionPool` based on the host, port, and scheme.
If ``port`` isn't given, it will be derived from the ``scheme`` using If ``port`` isn't given, it will be derived from the ``scheme`` using
``urllib3.connectionpool.port_by_scheme``. If ``pool_kwargs`` is ``urllib3.connectionpool.port_by_scheme``. If ``pool_kwargs`` is
@ -218,30 +236,32 @@ class PoolManager(RequestMethods):
raise LocationValueError("No host specified.") raise LocationValueError("No host specified.")
request_context = self._merge_pool_kwargs(pool_kwargs) request_context = self._merge_pool_kwargs(pool_kwargs)
request_context['scheme'] = scheme or 'http' request_context["scheme"] = scheme or "http"
if not port: if not port:
port = port_by_scheme.get(request_context['scheme'].lower(), 80) port = port_by_scheme.get(request_context["scheme"].lower(), 80)
request_context['port'] = port request_context["port"] = port
request_context['host'] = host request_context["host"] = host
return self.connection_from_context(request_context) return self.connection_from_context(request_context)
def connection_from_context(self, request_context): def connection_from_context(self, request_context):
""" """
Get a :class:`ConnectionPool` based on the request context. Get a :class:`urllib3.connectionpool.ConnectionPool` based on the request context.
``request_context`` must at least contain the ``scheme`` key and its ``request_context`` must at least contain the ``scheme`` key and its
value must be a key in ``key_fn_by_scheme`` instance variable. value must be a key in ``key_fn_by_scheme`` instance variable.
""" """
scheme = request_context['scheme'].lower() scheme = request_context["scheme"].lower()
pool_key_constructor = self.key_fn_by_scheme[scheme] pool_key_constructor = self.key_fn_by_scheme.get(scheme)
if not pool_key_constructor:
raise URLSchemeUnknown(scheme)
pool_key = pool_key_constructor(request_context) pool_key = pool_key_constructor(request_context)
return self.connection_from_pool_key(pool_key, request_context=request_context) return self.connection_from_pool_key(pool_key, request_context=request_context)
def connection_from_pool_key(self, pool_key, request_context=None): def connection_from_pool_key(self, pool_key, request_context=None):
""" """
Get a :class:`ConnectionPool` based on the provided pool key. Get a :class:`urllib3.connectionpool.ConnectionPool` based on the provided pool key.
``pool_key`` should be a namedtuple that only contains immutable ``pool_key`` should be a namedtuple that only contains immutable
objects. At a minimum it must have the ``scheme``, ``host``, and objects. At a minimum it must have the ``scheme``, ``host``, and
@ -255,9 +275,9 @@ class PoolManager(RequestMethods):
return pool return pool
# Make a fresh ConnectionPool of the desired type # Make a fresh ConnectionPool of the desired type
scheme = request_context['scheme'] scheme = request_context["scheme"]
host = request_context['host'] host = request_context["host"]
port = request_context['port'] port = request_context["port"]
pool = self._new_pool(scheme, host, port, request_context=request_context) pool = self._new_pool(scheme, host, port, request_context=request_context)
self.pools[pool_key] = pool self.pools[pool_key] = pool
@ -275,8 +295,9 @@ class PoolManager(RequestMethods):
not used. not used.
""" """
u = parse_url(url) u = parse_url(url)
return self.connection_from_host(u.host, port=u.port, scheme=u.scheme, return self.connection_from_host(
pool_kwargs=pool_kwargs) u.host, port=u.port, scheme=u.scheme, pool_kwargs=pool_kwargs
)
def _merge_pool_kwargs(self, override): def _merge_pool_kwargs(self, override):
""" """
@ -298,9 +319,39 @@ class PoolManager(RequestMethods):
base_pool_kwargs[key] = value base_pool_kwargs[key] = value
return base_pool_kwargs return base_pool_kwargs
def _proxy_requires_url_absolute_form(self, parsed_url):
"""
Indicates if the proxy requires the complete destination URL in the
request. Normally this is only needed when not using an HTTP CONNECT
tunnel.
"""
if self.proxy is None:
return False
return not connection_requires_http_tunnel(
self.proxy, self.proxy_config, parsed_url.scheme
)
def _validate_proxy_scheme_url_selection(self, url_scheme):
"""
Validates that were not attempting to do TLS in TLS connections on
Python2 or with unsupported SSL implementations.
"""
if self.proxy is None or url_scheme != "https":
return
if self.proxy.scheme != "https":
return
if six.PY2 and not self.proxy_config.use_forwarding_for_https:
raise ProxySchemeUnsupported(
"Contacting HTTPS destinations through HTTPS proxies "
"'via CONNECT tunnels' is not supported in Python 2"
)
def urlopen(self, method, url, redirect=True, **kw): def urlopen(self, method, url, redirect=True, **kw):
""" """
Same as :meth:`urllib3.connectionpool.HTTPConnectionPool.urlopen` Same as :meth:`urllib3.HTTPConnectionPool.urlopen`
with custom cross-host redirect logic and only sends the request-uri with custom cross-host redirect logic and only sends the request-uri
portion of the ``url``. portion of the ``url``.
@ -308,14 +359,17 @@ class PoolManager(RequestMethods):
:class:`urllib3.connectionpool.ConnectionPool` can be chosen for it. :class:`urllib3.connectionpool.ConnectionPool` can be chosen for it.
""" """
u = parse_url(url) u = parse_url(url)
self._validate_proxy_scheme_url_selection(u.scheme)
conn = self.connection_from_host(u.host, port=u.port, scheme=u.scheme) conn = self.connection_from_host(u.host, port=u.port, scheme=u.scheme)
kw['assert_same_host'] = False kw["assert_same_host"] = False
kw['redirect'] = False kw["redirect"] = False
if 'headers' not in kw:
kw['headers'] = self.headers
if self.proxy is not None and u.scheme == "http": if "headers" not in kw:
kw["headers"] = self.headers.copy()
if self._proxy_requires_url_absolute_form(u):
response = conn.urlopen(method, url, **kw) response = conn.urlopen(method, url, **kw)
else: else:
response = conn.urlopen(method, u.request_uri, **kw) response = conn.urlopen(method, u.request_uri, **kw)
@ -329,23 +383,37 @@ class PoolManager(RequestMethods):
# RFC 7231, Section 6.4.4 # RFC 7231, Section 6.4.4
if response.status == 303: if response.status == 303:
method = 'GET' method = "GET"
retries = kw.get('retries') retries = kw.get("retries")
if not isinstance(retries, Retry): if not isinstance(retries, Retry):
retries = Retry.from_int(retries, redirect=redirect) retries = Retry.from_int(retries, redirect=redirect)
# Strip headers marked as unsafe to forward to the redirected location.
# Check remove_headers_on_redirect to avoid a potential network call within
# conn.is_same_host() which may use socket.gethostbyname() in the future.
if retries.remove_headers_on_redirect and not conn.is_same_host(
redirect_location
):
headers = list(six.iterkeys(kw["headers"]))
for header in headers:
if header.lower() in retries.remove_headers_on_redirect:
kw["headers"].pop(header, None)
try: try:
retries = retries.increment(method, url, response=response, _pool=conn) retries = retries.increment(method, url, response=response, _pool=conn)
except MaxRetryError: except MaxRetryError:
if retries.raise_on_redirect: if retries.raise_on_redirect:
response.drain_conn()
raise raise
return response return response
kw['retries'] = retries kw["retries"] = retries
kw['redirect'] = redirect kw["redirect"] = redirect
log.info("Redirecting %s -> %s", url, redirect_location) log.info("Redirecting %s -> %s", url, redirect_location)
response.drain_conn()
return self.urlopen(method, redirect_location, **kw) return self.urlopen(method, redirect_location, **kw)
@ -358,11 +426,24 @@ class ProxyManager(PoolManager):
The URL of the proxy to be used. The URL of the proxy to be used.
:param proxy_headers: :param proxy_headers:
A dictionary contaning headers that will be sent to the proxy. In case A dictionary containing headers that will be sent to the proxy. In case
of HTTP they are being sent with each request, while in the of HTTP they are being sent with each request, while in the
HTTPS/CONNECT case they are sent only once. Could be used for proxy HTTPS/CONNECT case they are sent only once. Could be used for proxy
authentication. authentication.
:param proxy_ssl_context:
The proxy SSL context is used to establish the TLS connection to the
proxy when using HTTPS proxies.
:param use_forwarding_for_https:
(Defaults to False) If set to True will forward requests to the HTTPS
proxy to be made on behalf of the client instead of creating a TLS
tunnel via the CONNECT method. **Enabling this flag means that request
and response headers and content will be visible from the HTTPS proxy**
whereas tunneling keeps request and response headers and content
private. IP address, target hostname, SNI, and port are always visible
to an HTTPS proxy even when this flag is disabled.
Example: Example:
>>> proxy = urllib3.ProxyManager('http://localhost:3128/') >>> proxy = urllib3.ProxyManager('http://localhost:3128/')
>>> r1 = proxy.request('GET', 'http://google.com/') >>> r1 = proxy.request('GET', 'http://google.com/')
@ -376,47 +457,63 @@ class ProxyManager(PoolManager):
""" """
def __init__(self, proxy_url, num_pools=10, headers=None, def __init__(
proxy_headers=None, **connection_pool_kw): self,
proxy_url,
num_pools=10,
headers=None,
proxy_headers=None,
proxy_ssl_context=None,
use_forwarding_for_https=False,
**connection_pool_kw
):
if isinstance(proxy_url, HTTPConnectionPool): if isinstance(proxy_url, HTTPConnectionPool):
proxy_url = '%s://%s:%i' % (proxy_url.scheme, proxy_url.host, proxy_url = "%s://%s:%i" % (
proxy_url.port) proxy_url.scheme,
proxy_url.host,
proxy_url.port,
)
proxy = parse_url(proxy_url) proxy = parse_url(proxy_url)
if not proxy.port:
port = port_by_scheme.get(proxy.scheme, 80)
proxy = proxy._replace(port=port)
if proxy.scheme not in ("http", "https"): if proxy.scheme not in ("http", "https"):
raise ProxySchemeUnknown(proxy.scheme) raise ProxySchemeUnknown(proxy.scheme)
if not proxy.port:
port = port_by_scheme.get(proxy.scheme, 80)
proxy = proxy._replace(port=port)
self.proxy = proxy self.proxy = proxy
self.proxy_headers = proxy_headers or {} self.proxy_headers = proxy_headers or {}
self.proxy_ssl_context = proxy_ssl_context
self.proxy_config = ProxyConfig(proxy_ssl_context, use_forwarding_for_https)
connection_pool_kw['_proxy'] = self.proxy connection_pool_kw["_proxy"] = self.proxy
connection_pool_kw['_proxy_headers'] = self.proxy_headers connection_pool_kw["_proxy_headers"] = self.proxy_headers
connection_pool_kw["_proxy_config"] = self.proxy_config
super(ProxyManager, self).__init__( super(ProxyManager, self).__init__(num_pools, headers, **connection_pool_kw)
num_pools, headers, **connection_pool_kw)
def connection_from_host(self, host, port=None, scheme='http', pool_kwargs=None): def connection_from_host(self, host, port=None, scheme="http", pool_kwargs=None):
if scheme == "https": if scheme == "https":
return super(ProxyManager, self).connection_from_host( return super(ProxyManager, self).connection_from_host(
host, port, scheme, pool_kwargs=pool_kwargs) host, port, scheme, pool_kwargs=pool_kwargs
)
return super(ProxyManager, self).connection_from_host( return super(ProxyManager, self).connection_from_host(
self.proxy.host, self.proxy.port, self.proxy.scheme, pool_kwargs=pool_kwargs) self.proxy.host, self.proxy.port, self.proxy.scheme, pool_kwargs=pool_kwargs
)
def _set_proxy_headers(self, url, headers=None): def _set_proxy_headers(self, url, headers=None):
""" """
Sets headers needed by proxies: specifically, the Accept and Host Sets headers needed by proxies: specifically, the Accept and Host
headers. Only sets headers not provided by the user. headers. Only sets headers not provided by the user.
""" """
headers_ = {'Accept': '*/*'} headers_ = {"Accept": "*/*"}
netloc = parse_url(url).netloc netloc = parse_url(url).netloc
if netloc: if netloc:
headers_['Host'] = netloc headers_["Host"] = netloc
if headers: if headers:
headers_.update(headers) headers_.update(headers)
@ -425,13 +522,12 @@ class ProxyManager(PoolManager):
def urlopen(self, method, url, redirect=True, **kw): def urlopen(self, method, url, redirect=True, **kw):
"Same as HTTP(S)ConnectionPool.urlopen, ``url`` must be absolute." "Same as HTTP(S)ConnectionPool.urlopen, ``url`` must be absolute."
u = parse_url(url) u = parse_url(url)
if not connection_requires_http_tunnel(self.proxy, self.proxy_config, u.scheme):
if u.scheme == "http": # For connections using HTTP CONNECT, httplib sets the necessary
# For proxied HTTPS requests, httplib sets the necessary headers # headers on the CONNECT to the proxy. If we're not using CONNECT,
# on the CONNECT to the proxy. For HTTP, we'll definitely # we'll definitely need to set 'Host' at the very least.
# need to set 'Host' at the very least. headers = kw.get("headers", self.headers)
headers = kw.get('headers', self.headers) kw["headers"] = self._set_proxy_headers(url, headers)
kw['headers'] = self._set_proxy_headers(url, headers)
return super(ProxyManager, self).urlopen(method, url, redirect=redirect, **kw) return super(ProxyManager, self).urlopen(method, url, redirect=redirect, **kw)

View file

@ -3,15 +3,14 @@ from __future__ import absolute_import
from .filepost import encode_multipart_formdata from .filepost import encode_multipart_formdata
from .packages.six.moves.urllib.parse import urlencode from .packages.six.moves.urllib.parse import urlencode
__all__ = ["RequestMethods"]
__all__ = ['RequestMethods']
class RequestMethods(object): class RequestMethods(object):
""" """
Convenience mixin for classes who implement a :meth:`urlopen` method, such Convenience mixin for classes who implement a :meth:`urlopen` method, such
as :class:`~urllib3.connectionpool.HTTPConnectionPool` and as :class:`urllib3.HTTPConnectionPool` and
:class:`~urllib3.poolmanager.PoolManager`. :class:`urllib3.PoolManager`.
Provides behavior for making common types of HTTP request methods and Provides behavior for making common types of HTTP request methods and
decides which type of request field encoding to use. decides which type of request field encoding to use.
@ -36,16 +35,25 @@ class RequestMethods(object):
explicitly. explicitly.
""" """
_encode_url_methods = set(['DELETE', 'GET', 'HEAD', 'OPTIONS']) _encode_url_methods = {"DELETE", "GET", "HEAD", "OPTIONS"}
def __init__(self, headers=None): def __init__(self, headers=None):
self.headers = headers or {} self.headers = headers or {}
def urlopen(self, method, url, body=None, headers=None, def urlopen(
encode_multipart=True, multipart_boundary=None, self,
**kw): # Abstract method,
raise NotImplemented("Classes extending RequestMethods must implement " url,
"their own ``urlopen`` method.") body=None,
headers=None,
encode_multipart=True,
multipart_boundary=None,
**kw
): # Abstract
raise NotImplementedError(
"Classes extending RequestMethods must implement "
"their own ``urlopen`` method."
)
def request(self, method, url, fields=None, headers=None, **urlopen_kw): def request(self, method, url, fields=None, headers=None, **urlopen_kw):
""" """
@ -60,17 +68,18 @@ class RequestMethods(object):
""" """
method = method.upper() method = method.upper()
if method in self._encode_url_methods: urlopen_kw["request_url"] = url
return self.request_encode_url(method, url, fields=fields,
headers=headers,
**urlopen_kw)
else:
return self.request_encode_body(method, url, fields=fields,
headers=headers,
**urlopen_kw)
def request_encode_url(self, method, url, fields=None, headers=None, if method in self._encode_url_methods:
**urlopen_kw): return self.request_encode_url(
method, url, fields=fields, headers=headers, **urlopen_kw
)
else:
return self.request_encode_body(
method, url, fields=fields, headers=headers, **urlopen_kw
)
def request_encode_url(self, method, url, fields=None, headers=None, **urlopen_kw):
""" """
Make a request using :meth:`urlopen` with the ``fields`` encoded in Make a request using :meth:`urlopen` with the ``fields`` encoded in
the url. This is useful for request methods like GET, HEAD, DELETE, etc. the url. This is useful for request methods like GET, HEAD, DELETE, etc.
@ -78,25 +87,32 @@ class RequestMethods(object):
if headers is None: if headers is None:
headers = self.headers headers = self.headers
extra_kw = {'headers': headers} extra_kw = {"headers": headers}
extra_kw.update(urlopen_kw) extra_kw.update(urlopen_kw)
if fields: if fields:
url += '?' + urlencode(fields) url += "?" + urlencode(fields)
return self.urlopen(method, url, **extra_kw) return self.urlopen(method, url, **extra_kw)
def request_encode_body(self, method, url, fields=None, headers=None, def request_encode_body(
encode_multipart=True, multipart_boundary=None, self,
**urlopen_kw): method,
url,
fields=None,
headers=None,
encode_multipart=True,
multipart_boundary=None,
**urlopen_kw
):
""" """
Make a request using :meth:`urlopen` with the ``fields`` encoded in Make a request using :meth:`urlopen` with the ``fields`` encoded in
the body. This is useful for request methods like POST, PUT, PATCH, etc. the body. This is useful for request methods like POST, PUT, PATCH, etc.
When ``encode_multipart=True`` (default), then When ``encode_multipart=True`` (default), then
:meth:`urllib3.filepost.encode_multipart_formdata` is used to encode :func:`urllib3.encode_multipart_formdata` is used to encode
the payload with the appropriate content type. Otherwise the payload with the appropriate content type. Otherwise
:meth:`urllib.urlencode` is used with the :func:`urllib.parse.urlencode` is used with the
'application/x-www-form-urlencoded' content type. 'application/x-www-form-urlencoded' content type.
Multipart encoding must be used when posting files, and it's reasonably Multipart encoding must be used when posting files, and it's reasonably
@ -117,7 +133,7 @@ class RequestMethods(object):
} }
When uploading a file, providing a filename (the first parameter of the When uploading a file, providing a filename (the first parameter of the
tuple) is optional but recommended to best mimick behavior of browsers. tuple) is optional but recommended to best mimic behavior of browsers.
Note that if ``headers`` are supplied, the 'Content-Type' header will Note that if ``headers`` are supplied, the 'Content-Type' header will
be overwritten because it depends on the dynamic random boundary string be overwritten because it depends on the dynamic random boundary string
@ -127,22 +143,28 @@ class RequestMethods(object):
if headers is None: if headers is None:
headers = self.headers headers = self.headers
extra_kw = {'headers': {}} extra_kw = {"headers": {}}
if fields: if fields:
if 'body' in urlopen_kw: if "body" in urlopen_kw:
raise TypeError( raise TypeError(
"request got values for both 'fields' and 'body', can only specify one.") "request got values for both 'fields' and 'body', can only specify one."
)
if encode_multipart: if encode_multipart:
body, content_type = encode_multipart_formdata(fields, boundary=multipart_boundary) body, content_type = encode_multipart_formdata(
fields, boundary=multipart_boundary
)
else: else:
body, content_type = urlencode(fields), 'application/x-www-form-urlencoded' body, content_type = (
urlencode(fields),
"application/x-www-form-urlencoded",
)
extra_kw['body'] = body extra_kw["body"] = body
extra_kw['headers'] = {'Content-Type': content_type} extra_kw["headers"] = {"Content-Type": content_type}
extra_kw['headers'].update(headers) extra_kw["headers"].update(headers)
extra_kw.update(urlopen_kw) extra_kw.update(urlopen_kw)
return self.urlopen(method, url, **extra_kw) return self.urlopen(method, url, **extra_kw)

View file

@ -1,29 +1,41 @@
from __future__ import absolute_import from __future__ import absolute_import
from contextlib import contextmanager
import zlib
import io import io
import logging import logging
from socket import timeout as SocketTimeout import zlib
from contextlib import contextmanager
from socket import error as SocketError from socket import error as SocketError
from socket import timeout as SocketTimeout
try:
import brotli
except ImportError:
brotli = None
from ._collections import HTTPHeaderDict from ._collections import HTTPHeaderDict
from .connection import BaseSSLError, HTTPException
from .exceptions import ( from .exceptions import (
BodyNotHttplibCompatible, ProtocolError, DecodeError, ReadTimeoutError, BodyNotHttplibCompatible,
ResponseNotChunked, IncompleteRead, InvalidHeader DecodeError,
HTTPError,
IncompleteRead,
InvalidChunkLength,
InvalidHeader,
ProtocolError,
ReadTimeoutError,
ResponseNotChunked,
SSLError,
) )
from .packages.six import string_types as basestring, binary_type, PY3 from .packages import six
from .packages.six.moves import http_client as httplib
from .connection import HTTPException, BaseSSLError
from .util.response import is_fp_closed, is_response_to_head from .util.response import is_fp_closed, is_response_to_head
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
class DeflateDecoder(object): class DeflateDecoder(object):
def __init__(self): def __init__(self):
self._first_try = True self._first_try = True
self._data = binary_type() self._data = b""
self._obj = zlib.decompressobj() self._obj = zlib.decompressobj()
def __getattr__(self, name): def __getattr__(self, name):
@ -52,24 +64,93 @@ class DeflateDecoder(object):
self._data = None self._data = None
class GzipDecoder(object): class GzipDecoderState(object):
FIRST_MEMBER = 0
OTHER_MEMBERS = 1
SWALLOW_DATA = 2
class GzipDecoder(object):
def __init__(self): def __init__(self):
self._obj = zlib.decompressobj(16 + zlib.MAX_WBITS) self._obj = zlib.decompressobj(16 + zlib.MAX_WBITS)
self._state = GzipDecoderState.FIRST_MEMBER
def __getattr__(self, name): def __getattr__(self, name):
return getattr(self._obj, name) return getattr(self._obj, name)
def decompress(self, data): def decompress(self, data):
ret = bytearray()
if self._state == GzipDecoderState.SWALLOW_DATA or not data:
return bytes(ret)
while True:
try:
ret += self._obj.decompress(data)
except zlib.error:
previous_state = self._state
# Ignore data after the first error
self._state = GzipDecoderState.SWALLOW_DATA
if previous_state == GzipDecoderState.OTHER_MEMBERS:
# Allow trailing garbage acceptable in other gzip clients
return bytes(ret)
raise
data = self._obj.unused_data
if not data: if not data:
return bytes(ret)
self._state = GzipDecoderState.OTHER_MEMBERS
self._obj = zlib.decompressobj(16 + zlib.MAX_WBITS)
if brotli is not None:
class BrotliDecoder(object):
# Supports both 'brotlipy' and 'Brotli' packages
# since they share an import name. The top branches
# are for 'brotlipy' and bottom branches for 'Brotli'
def __init__(self):
self._obj = brotli.Decompressor()
if hasattr(self._obj, "decompress"):
self.decompress = self._obj.decompress
else:
self.decompress = self._obj.process
def flush(self):
if hasattr(self._obj, "flush"):
return self._obj.flush()
return b""
class MultiDecoder(object):
"""
From RFC7231:
If one or more encodings have been applied to a representation, the
sender that applied the encodings MUST generate a Content-Encoding
header field that lists the content codings in the order in which
they were applied.
"""
def __init__(self, modes):
self._decoders = [_get_decoder(m.strip()) for m in modes.split(",")]
def flush(self):
return self._decoders[0].flush()
def decompress(self, data):
for d in reversed(self._decoders):
data = d.decompress(data)
return data return data
return self._obj.decompress(data)
def _get_decoder(mode): def _get_decoder(mode):
if mode == 'gzip': if "," in mode:
return MultiDecoder(mode)
if mode == "gzip":
return GzipDecoder() return GzipDecoder()
if brotli is not None and mode == "br":
return BrotliDecoder()
return DeflateDecoder() return DeflateDecoder()
@ -77,24 +158,23 @@ class HTTPResponse(io.IOBase):
""" """
HTTP Response container. HTTP Response container.
Backwards-compatible to httplib's HTTPResponse but the response ``body`` is Backwards-compatible with :class:`http.client.HTTPResponse` but the response ``body`` is
loaded and decoded on-demand when the ``data`` property is accessed. This loaded and decoded on-demand when the ``data`` property is accessed. This
class is also compatible with the Python standard library's :mod:`io` class is also compatible with the Python standard library's :mod:`io`
module, and can hence be treated as a readable object in the context of that module, and can hence be treated as a readable object in the context of that
framework. framework.
Extra parameters for behaviour not present in httplib.HTTPResponse: Extra parameters for behaviour not present in :class:`http.client.HTTPResponse`:
:param preload_content: :param preload_content:
If True, the response's body will be preloaded during construction. If True, the response's body will be preloaded during construction.
:param decode_content: :param decode_content:
If True, attempts to decode specific content-encoding's based on headers If True, will attempt to decode the body based on the
(like 'gzip' and 'deflate') will be skipped and raw data will be used 'content-encoding' header.
instead.
:param original_response: :param original_response:
When this HTTPResponse wrapper is generated from an httplib.HTTPResponse When this HTTPResponse wrapper is generated from an :class:`http.client.HTTPResponse`
object, it's convenient to include the original for debug purposes. It's object, it's convenient to include the original for debug purposes. It's
otherwise unused. otherwise unused.
@ -107,13 +187,31 @@ class HTTPResponse(io.IOBase):
value of Content-Length header, if present. Otherwise, raise error. value of Content-Length header, if present. Otherwise, raise error.
""" """
CONTENT_DECODERS = ['gzip', 'deflate'] CONTENT_DECODERS = ["gzip", "deflate"]
if brotli is not None:
CONTENT_DECODERS += ["br"]
REDIRECT_STATUSES = [301, 302, 303, 307, 308] REDIRECT_STATUSES = [301, 302, 303, 307, 308]
def __init__(self, body='', headers=None, status=0, version=0, reason=None, def __init__(
strict=0, preload_content=True, decode_content=True, self,
original_response=None, pool=None, connection=None, body="",
retries=None, enforce_content_length=False, request_method=None): headers=None,
status=0,
version=0,
reason=None,
strict=0,
preload_content=True,
decode_content=True,
original_response=None,
pool=None,
connection=None,
msg=None,
retries=None,
enforce_content_length=False,
request_method=None,
request_url=None,
auto_close=True,
):
if isinstance(headers, HTTPHeaderDict): if isinstance(headers, HTTPHeaderDict):
self.headers = headers self.headers = headers
@ -126,26 +224,29 @@ class HTTPResponse(io.IOBase):
self.decode_content = decode_content self.decode_content = decode_content
self.retries = retries self.retries = retries
self.enforce_content_length = enforce_content_length self.enforce_content_length = enforce_content_length
self.auto_close = auto_close
self._decoder = None self._decoder = None
self._body = None self._body = None
self._fp = None self._fp = None
self._original_response = original_response self._original_response = original_response
self._fp_bytes_read = 0 self._fp_bytes_read = 0
self.msg = msg
self._request_url = request_url
if body and isinstance(body, (basestring, binary_type)): if body and isinstance(body, (six.string_types, bytes)):
self._body = body self._body = body
self._pool = pool self._pool = pool
self._connection = connection self._connection = connection
if hasattr(body, 'read'): if hasattr(body, "read"):
self._fp = body self._fp = body
# Are we using the chunked-style of transfer encoding? # Are we using the chunked-style of transfer encoding?
self.chunked = False self.chunked = False
self.chunk_left = None self.chunk_left = None
tr_enc = self.headers.get('transfer-encoding', '').lower() tr_enc = self.headers.get("transfer-encoding", "").lower()
# Don't incur the penalty of creating a list and then discarding it # Don't incur the penalty of creating a list and then discarding it
encodings = (enc.strip() for enc in tr_enc.split(",")) encodings = (enc.strip() for enc in tr_enc.split(","))
if "chunked" in encodings: if "chunked" in encodings:
@ -167,7 +268,7 @@ class HTTPResponse(io.IOBase):
location. ``False`` if not a redirect status code. location. ``False`` if not a redirect status code.
""" """
if self.status in self.REDIRECT_STATUSES: if self.status in self.REDIRECT_STATUSES:
return self.headers.get('location') return self.headers.get("location")
return False return False
@ -178,9 +279,20 @@ class HTTPResponse(io.IOBase):
self._pool._put_conn(self._connection) self._pool._put_conn(self._connection)
self._connection = None self._connection = None
def drain_conn(self):
"""
Read and discard any remaining HTTP response data in the response connection.
Unread data in the HTTPResponse connection blocks the connection from being released back to the pool.
"""
try:
self.read()
except (HTTPError, SocketError, BaseSSLError, HTTPException):
pass
@property @property
def data(self): def data(self):
# For backwords-compat with earlier urllib3 0.4 and earlier. # For backwards-compat with earlier urllib3 0.4 and earlier.
if self._body: if self._body:
return self._body return self._body
@ -191,11 +303,14 @@ class HTTPResponse(io.IOBase):
def connection(self): def connection(self):
return self._connection return self._connection
def isclosed(self):
return is_fp_closed(self._fp)
def tell(self): def tell(self):
""" """
Obtain the number of bytes pulled over the wire so far. May differ from Obtain the number of bytes pulled over the wire so far. May differ from
the amount of content returned by :meth:``HTTPResponse.read`` if bytes the amount of content returned by :meth:``urllib3.response.HTTPResponse.read``
are encoded on the wire (e.g, compressed). if bytes are encoded on the wire (e.g, compressed).
""" """
return self._fp_bytes_read return self._fp_bytes_read
@ -203,30 +318,34 @@ class HTTPResponse(io.IOBase):
""" """
Set initial length value for Response content if available. Set initial length value for Response content if available.
""" """
length = self.headers.get('content-length') length = self.headers.get("content-length")
if length is not None and self.chunked: if length is not None:
if self.chunked:
# This Response will fail with an IncompleteRead if it can't be # This Response will fail with an IncompleteRead if it can't be
# received as chunked. This method falls back to attempt reading # received as chunked. This method falls back to attempt reading
# the response before raising an exception. # the response before raising an exception.
log.warning("Received response with both Content-Length and " log.warning(
"Received response with both Content-Length and "
"Transfer-Encoding set. This is expressly forbidden " "Transfer-Encoding set. This is expressly forbidden "
"by RFC 7230 sec 3.3.2. Ignoring Content-Length and " "by RFC 7230 sec 3.3.2. Ignoring Content-Length and "
"attempting to process response as Transfer-Encoding: " "attempting to process response as Transfer-Encoding: "
"chunked.") "chunked."
)
return None return None
elif length is not None:
try: try:
# RFC 7230 section 3.3.2 specifies multiple content lengths can # RFC 7230 section 3.3.2 specifies multiple content lengths can
# be sent in a single Content-Length header # be sent in a single Content-Length header
# (e.g. Content-Length: 42, 42). This line ensures the values # (e.g. Content-Length: 42, 42). This line ensures the values
# are all valid ints and that as long as the `set` length is 1, # are all valid ints and that as long as the `set` length is 1,
# all values are the same. Otherwise, the header is invalid. # all values are the same. Otherwise, the header is invalid.
lengths = set([int(val) for val in length.split(',')]) lengths = set([int(val) for val in length.split(",")])
if len(lengths) > 1: if len(lengths) > 1:
raise InvalidHeader("Content-Length contained multiple " raise InvalidHeader(
"unmatching values (%s)" % length) "Content-Length contained multiple "
"unmatching values (%s)" % length
)
length = lengths.pop() length = lengths.pop()
except ValueError: except ValueError:
length = None length = None
@ -242,7 +361,7 @@ class HTTPResponse(io.IOBase):
status = 0 status = 0
# Check for responses that shouldn't include a body # Check for responses that shouldn't include a body
if status in (204, 304) or 100 <= status < 200 or request_method == 'HEAD': if status in (204, 304) or 100 <= status < 200 or request_method == "HEAD":
length = 0 length = 0
return length return length
@ -253,24 +372,41 @@ class HTTPResponse(io.IOBase):
""" """
# Note: content-encoding value should be case-insensitive, per RFC 7230 # Note: content-encoding value should be case-insensitive, per RFC 7230
# Section 3.2 # Section 3.2
content_encoding = self.headers.get('content-encoding', '').lower() content_encoding = self.headers.get("content-encoding", "").lower()
if self._decoder is None and content_encoding in self.CONTENT_DECODERS: if self._decoder is None:
if content_encoding in self.CONTENT_DECODERS:
self._decoder = _get_decoder(content_encoding) self._decoder = _get_decoder(content_encoding)
elif "," in content_encoding:
encodings = [
e.strip()
for e in content_encoding.split(",")
if e.strip() in self.CONTENT_DECODERS
]
if len(encodings):
self._decoder = _get_decoder(content_encoding)
DECODER_ERROR_CLASSES = (IOError, zlib.error)
if brotli is not None:
DECODER_ERROR_CLASSES += (brotli.error,)
def _decode(self, data, decode_content, flush_decoder): def _decode(self, data, decode_content, flush_decoder):
""" """
Decode the data passed in and potentially flush the decoder. Decode the data passed in and potentially flush the decoder.
""" """
if not decode_content:
return data
try: try:
if decode_content and self._decoder: if self._decoder:
data = self._decoder.decompress(data) data = self._decoder.decompress(data)
except (IOError, zlib.error) as e: except self.DECODER_ERROR_CLASSES as e:
content_encoding = self.headers.get('content-encoding', '').lower() content_encoding = self.headers.get("content-encoding", "").lower()
raise DecodeError( raise DecodeError(
"Received response with content-encoding: %s, but " "Received response with content-encoding: %s, but "
"failed to decode it." % content_encoding, e) "failed to decode it." % content_encoding,
e,
if flush_decoder and decode_content: )
if flush_decoder:
data += self._flush_decoder() data += self._flush_decoder()
return data return data
@ -281,10 +417,10 @@ class HTTPResponse(io.IOBase):
being used. being used.
""" """
if self._decoder: if self._decoder:
buf = self._decoder.decompress(b'') buf = self._decoder.decompress(b"")
return buf + self._decoder.flush() return buf + self._decoder.flush()
return b'' return b""
@contextmanager @contextmanager
def _error_catcher(self): def _error_catcher(self):
@ -304,20 +440,19 @@ class HTTPResponse(io.IOBase):
except SocketTimeout: except SocketTimeout:
# FIXME: Ideally we'd like to include the url in the ReadTimeoutError but # FIXME: Ideally we'd like to include the url in the ReadTimeoutError but
# there is yet no clean way to get at it from this context. # there is yet no clean way to get at it from this context.
raise ReadTimeoutError(self._pool, None, 'Read timed out.') raise ReadTimeoutError(self._pool, None, "Read timed out.")
except BaseSSLError as e: except BaseSSLError as e:
# FIXME: Is there a better way to differentiate between SSLErrors? # FIXME: Is there a better way to differentiate between SSLErrors?
if 'read operation timed out' not in str(e): # Defensive: if "read operation timed out" not in str(e):
# This shouldn't happen but just in case we're missing an edge # SSL errors related to framing/MAC get wrapped and reraised here
# case, let's avoid swallowing SSL errors. raise SSLError(e)
raise
raise ReadTimeoutError(self._pool, None, 'Read timed out.') raise ReadTimeoutError(self._pool, None, "Read timed out.")
except (HTTPException, SocketError) as e: except (HTTPException, SocketError) as e:
# This includes IncompleteRead. # This includes IncompleteRead.
raise ProtocolError('Connection broken: %r' % e, e) raise ProtocolError("Connection broken: %r" % e, e)
# If no exception is thrown, we should avoid cleaning up # If no exception is thrown, we should avoid cleaning up
# unnecessarily. # unnecessarily.
@ -345,7 +480,7 @@ class HTTPResponse(io.IOBase):
def read(self, amt=None, decode_content=None, cache_content=False): def read(self, amt=None, decode_content=None, cache_content=False):
""" """
Similar to :meth:`httplib.HTTPResponse.read`, but with two additional Similar to :meth:`http.client.HTTPResponse.read`, but with two additional
parameters: ``decode_content`` and ``cache_content``. parameters: ``decode_content`` and ``cache_content``.
:param amt: :param amt:
@ -372,17 +507,19 @@ class HTTPResponse(io.IOBase):
return return
flush_decoder = False flush_decoder = False
data = None fp_closed = getattr(self._fp, "closed", False)
with self._error_catcher(): with self._error_catcher():
if amt is None: if amt is None:
# cStringIO doesn't like amt=None # cStringIO doesn't like amt=None
data = self._fp.read() data = self._fp.read() if not fp_closed else b""
flush_decoder = True flush_decoder = True
else: else:
cache_content = False cache_content = False
data = self._fp.read(amt) data = self._fp.read(amt) if not fp_closed else b""
if amt != 0 and not data: # Platform-specific: Buggy versions of Python. if (
amt != 0 and not data
): # Platform-specific: Buggy versions of Python.
# Close the connection when no data is returned # Close the connection when no data is returned
# #
# This is redundant to what httplib/http.client _should_ # This is redundant to what httplib/http.client _should_
@ -392,7 +529,10 @@ class HTTPResponse(io.IOBase):
# no harm in redundantly calling close. # no harm in redundantly calling close.
self._fp.close() self._fp.close()
flush_decoder = True flush_decoder = True
if self.enforce_content_length and self.length_remaining not in (0, None): if self.enforce_content_length and self.length_remaining not in (
0,
None,
):
# This is an edge case that httplib failed to cover due # This is an edge case that httplib failed to cover due
# to concerns of backward compatibility. We're # to concerns of backward compatibility. We're
# addressing it here to make sure IncompleteRead is # addressing it here to make sure IncompleteRead is
@ -441,7 +581,7 @@ class HTTPResponse(io.IOBase):
@classmethod @classmethod
def from_httplib(ResponseCls, r, **response_kw): def from_httplib(ResponseCls, r, **response_kw):
""" """
Given an :class:`httplib.HTTPResponse` instance ``r``, return a Given an :class:`http.client.HTTPResponse` instance ``r``, return a
corresponding :class:`urllib3.response.HTTPResponse` object. corresponding :class:`urllib3.response.HTTPResponse` object.
Remaining parameters are passed to the HTTPResponse constructor, along Remaining parameters are passed to the HTTPResponse constructor, along
@ -450,24 +590,27 @@ class HTTPResponse(io.IOBase):
headers = r.msg headers = r.msg
if not isinstance(headers, HTTPHeaderDict): if not isinstance(headers, HTTPHeaderDict):
if PY3: # Python 3 if six.PY2:
headers = HTTPHeaderDict(headers.items()) # Python 2.7
else: # Python 2
headers = HTTPHeaderDict.from_httplib(headers) headers = HTTPHeaderDict.from_httplib(headers)
else:
headers = HTTPHeaderDict(headers.items())
# HTTPResponse objects in Python 3 don't have a .strict attribute # HTTPResponse objects in Python 3 don't have a .strict attribute
strict = getattr(r, 'strict', 0) strict = getattr(r, "strict", 0)
resp = ResponseCls(body=r, resp = ResponseCls(
body=r,
headers=headers, headers=headers,
status=r.status, status=r.status,
version=r.version, version=r.version,
reason=r.reason, reason=r.reason,
strict=strict, strict=strict,
original_response=r, original_response=r,
**response_kw) **response_kw
)
return resp return resp
# Backwards-compatibility methods for httplib.HTTPResponse # Backwards-compatibility methods for http.client.HTTPResponse
def getheaders(self): def getheaders(self):
return self.headers return self.headers
@ -486,13 +629,18 @@ class HTTPResponse(io.IOBase):
if self._connection: if self._connection:
self._connection.close() self._connection.close()
if not self.auto_close:
io.IOBase.close(self)
@property @property
def closed(self): def closed(self):
if self._fp is None: if not self.auto_close:
return io.IOBase.closed.__get__(self)
elif self._fp is None:
return True return True
elif hasattr(self._fp, 'isclosed'): elif hasattr(self._fp, "isclosed"):
return self._fp.isclosed() return self._fp.isclosed()
elif hasattr(self._fp, 'closed'): elif hasattr(self._fp, "closed"):
return self._fp.closed return self._fp.closed
else: else:
return True return True
@ -503,11 +651,17 @@ class HTTPResponse(io.IOBase):
elif hasattr(self._fp, "fileno"): elif hasattr(self._fp, "fileno"):
return self._fp.fileno() return self._fp.fileno()
else: else:
raise IOError("The file-like object this HTTPResponse is wrapped " raise IOError(
"around has no file descriptor") "The file-like object this HTTPResponse is wrapped "
"around has no file descriptor"
)
def flush(self): def flush(self):
if self._fp is not None and hasattr(self._fp, 'flush'): if (
self._fp is not None
and hasattr(self._fp, "flush")
and not getattr(self._fp, "closed", False)
):
return self._fp.flush() return self._fp.flush()
def readable(self): def readable(self):
@ -526,11 +680,11 @@ class HTTPResponse(io.IOBase):
def supports_chunked_reads(self): def supports_chunked_reads(self):
""" """
Checks if the underlying file-like object looks like a Checks if the underlying file-like object looks like a
httplib.HTTPResponse object. We do this by testing for the fp :class:`http.client.HTTPResponse` object. We do this by testing for
attribute. If it is present we assume it returns raw chunks as the fp attribute. If it is present we assume it returns raw chunks as
processed by read_chunked(). processed by read_chunked().
""" """
return hasattr(self._fp, 'fp') return hasattr(self._fp, "fp")
def _update_chunk_length(self): def _update_chunk_length(self):
# First, we'll figure out length of a chunk and then # First, we'll figure out length of a chunk and then
@ -538,13 +692,13 @@ class HTTPResponse(io.IOBase):
if self.chunk_left is not None: if self.chunk_left is not None:
return return
line = self._fp.fp.readline() line = self._fp.fp.readline()
line = line.split(b';', 1)[0] line = line.split(b";", 1)[0]
try: try:
self.chunk_left = int(line, 16) self.chunk_left = int(line, 16)
except ValueError: except ValueError:
# Invalid chunked protocol response, abort. # Invalid chunked protocol response, abort.
self.close() self.close()
raise httplib.IncompleteRead(line) raise InvalidChunkLength(self, line)
def _handle_chunk(self, amt): def _handle_chunk(self, amt):
returned_chunk = None returned_chunk = None
@ -573,6 +727,11 @@ class HTTPResponse(io.IOBase):
Similar to :meth:`HTTPResponse.read`, but with an additional Similar to :meth:`HTTPResponse.read`, but with an additional
parameter: ``decode_content``. parameter: ``decode_content``.
:param amt:
How much of the content to read. If specified, caching is skipped
because it doesn't make sense to cache partial content as the full
response.
:param decode_content: :param decode_content:
If True, will attempt to decode the body based on the If True, will attempt to decode the body based on the
'content-encoding' header. 'content-encoding' header.
@ -582,25 +741,33 @@ class HTTPResponse(io.IOBase):
if not self.chunked: if not self.chunked:
raise ResponseNotChunked( raise ResponseNotChunked(
"Response is not chunked. " "Response is not chunked. "
"Header 'transfer-encoding: chunked' is missing.") "Header 'transfer-encoding: chunked' is missing."
)
if not self.supports_chunked_reads(): if not self.supports_chunked_reads():
raise BodyNotHttplibCompatible( raise BodyNotHttplibCompatible(
"Body should be httplib.HTTPResponse like. " "Body should be http.client.HTTPResponse like. "
"It should have have an fp attribute which returns raw chunks.") "It should have have an fp attribute which returns raw chunks."
)
with self._error_catcher():
# Don't bother reading the body of a HEAD request. # Don't bother reading the body of a HEAD request.
if self._original_response and is_response_to_head(self._original_response): if self._original_response and is_response_to_head(self._original_response):
self._original_response.close() self._original_response.close()
return return
with self._error_catcher(): # If a response is already read and closed
# then return immediately.
if self._fp.fp is None:
return
while True: while True:
self._update_chunk_length() self._update_chunk_length()
if self.chunk_left == 0: if self.chunk_left == 0:
break break
chunk = self._handle_chunk(amt) chunk = self._handle_chunk(amt)
decoded = self._decode(chunk, decode_content=decode_content, decoded = self._decode(
flush_decoder=False) chunk, decode_content=decode_content, flush_decoder=False
)
if decoded: if decoded:
yield decoded yield decoded
@ -618,9 +785,37 @@ class HTTPResponse(io.IOBase):
if not line: if not line:
# Some sites may not end with '\r\n'. # Some sites may not end with '\r\n'.
break break
if line == b'\r\n': if line == b"\r\n":
break break
# We read everything; close the "file". # We read everything; close the "file".
if self._original_response: if self._original_response:
self._original_response.close() self._original_response.close()
def geturl(self):
"""
Returns the URL that was the source of this response.
If the request that generated this response redirected, this method
will return the final redirect location.
"""
if self.retries is not None and len(self.retries.history):
return self.retries.history[-1].redirect_location
else:
return self._request_url
def __iter__(self):
buffer = []
for chunk in self.stream(decode_content=True):
if b"\n" in chunk:
chunk = chunk.split(b"\n")
yield b"".join(buffer) + chunk[0] + b"\n"
for x in chunk[1:-1]:
yield x + b"\n"
if chunk[-1]:
buffer = [chunk[-1]]
else:
buffer = []
else:
buffer.append(chunk)
if buffer:
yield b"".join(buffer)

View file

@ -1,54 +1,49 @@
from __future__ import absolute_import from __future__ import absolute_import
# For backwards compatibility, provide imports that used to be here. # For backwards compatibility, provide imports that used to be here.
from .connection import is_connection_dropped from .connection import is_connection_dropped
from .request import make_headers from .request import SKIP_HEADER, SKIPPABLE_HEADERS, make_headers
from .response import is_fp_closed from .response import is_fp_closed
from .retry import Retry
from .ssl_ import ( from .ssl_ import (
SSLContext, ALPN_PROTOCOLS,
HAS_SNI, HAS_SNI,
IS_PYOPENSSL, IS_PYOPENSSL,
IS_SECURETRANSPORT, IS_SECURETRANSPORT,
PROTOCOL_TLS,
SSLContext,
assert_fingerprint, assert_fingerprint,
resolve_cert_reqs, resolve_cert_reqs,
resolve_ssl_version, resolve_ssl_version,
ssl_wrap_socket, ssl_wrap_socket,
) )
from .timeout import ( from .timeout import Timeout, current_time
current_time, from .url import Url, get_host, parse_url, split_first
Timeout, from .wait import wait_for_read, wait_for_write
)
from .retry import Retry
from .url import (
get_host,
parse_url,
split_first,
Url,
)
from .wait import (
wait_for_read,
wait_for_write
)
__all__ = ( __all__ = (
'HAS_SNI', "HAS_SNI",
'IS_PYOPENSSL', "IS_PYOPENSSL",
'IS_SECURETRANSPORT', "IS_SECURETRANSPORT",
'SSLContext', "SSLContext",
'Retry', "PROTOCOL_TLS",
'Timeout', "ALPN_PROTOCOLS",
'Url', "Retry",
'assert_fingerprint', "Timeout",
'current_time', "Url",
'is_connection_dropped', "assert_fingerprint",
'is_fp_closed', "current_time",
'get_host', "is_connection_dropped",
'parse_url', "is_fp_closed",
'make_headers', "get_host",
'resolve_cert_reqs', "parse_url",
'resolve_ssl_version', "make_headers",
'split_first', "resolve_cert_reqs",
'ssl_wrap_socket', "resolve_ssl_version",
'wait_for_read', "split_first",
'wait_for_write' "ssl_wrap_socket",
"wait_for_read",
"wait_for_write",
"SKIP_HEADER",
"SKIPPABLE_HEADERS",
) )

View file

@ -1,7 +1,12 @@
from __future__ import absolute_import from __future__ import absolute_import
import socket import socket
from .wait import wait_for_read
from .selectors import HAS_SELECT, SelectorError from urllib3.exceptions import LocationParseError
from ..contrib import _appengine_environ
from ..packages import six
from .wait import NoWayToWaitForSocketError, wait_for_read
def is_connection_dropped(conn): # Platform-specific def is_connection_dropped(conn): # Platform-specific
@ -9,47 +14,48 @@ def is_connection_dropped(conn): # Platform-specific
Returns True if the connection is dropped and should be closed. Returns True if the connection is dropped and should be closed.
:param conn: :param conn:
:class:`httplib.HTTPConnection` object. :class:`http.client.HTTPConnection` object.
Note: For platforms like AppEngine, this will always return ``False`` to Note: For platforms like AppEngine, this will always return ``False`` to
let the platform handle connection recycling transparently for us. let the platform handle connection recycling transparently for us.
""" """
sock = getattr(conn, 'sock', False) sock = getattr(conn, "sock", False)
if sock is False: # Platform-specific: AppEngine if sock is False: # Platform-specific: AppEngine
return False return False
if sock is None: # Connection already closed (such as by httplib). if sock is None: # Connection already closed (such as by httplib).
return True return True
if not HAS_SELECT:
return False
try: try:
return bool(wait_for_read(sock, timeout=0.0)) # Returns True if readable, which here means it's been dropped
except SelectorError: return wait_for_read(sock, timeout=0.0)
return True except NoWayToWaitForSocketError: # Platform-specific: AppEngine
return False
# This function is copied from socket.py in the Python 2.7 standard # This function is copied from socket.py in the Python 2.7 standard
# library test suite. Added to its signature is only `socket_options`. # library test suite. Added to its signature is only `socket_options`.
# One additional modification is that we avoid binding to IPv6 servers # One additional modification is that we avoid binding to IPv6 servers
# discovered in DNS if the system doesn't have IPv6 functionality. # discovered in DNS if the system doesn't have IPv6 functionality.
def create_connection(address, timeout=socket._GLOBAL_DEFAULT_TIMEOUT, def create_connection(
source_address=None, socket_options=None): address,
timeout=socket._GLOBAL_DEFAULT_TIMEOUT,
source_address=None,
socket_options=None,
):
"""Connect to *address* and return the socket object. """Connect to *address* and return the socket object.
Convenience function. Connect to *address* (a 2-tuple ``(host, Convenience function. Connect to *address* (a 2-tuple ``(host,
port)``) and return the socket object. Passing the optional port)``) and return the socket object. Passing the optional
*timeout* parameter will set the timeout on the socket instance *timeout* parameter will set the timeout on the socket instance
before attempting to connect. If no *timeout* is supplied, the before attempting to connect. If no *timeout* is supplied, the
global default timeout setting returned by :func:`getdefaulttimeout` global default timeout setting returned by :func:`socket.getdefaulttimeout`
is used. If *source_address* is set it must be a tuple of (host, port) is used. If *source_address* is set it must be a tuple of (host, port)
for the socket to bind as a source address before making the connection. for the socket to bind as a source address before making the connection.
An host of '' or port 0 tells the OS to use the default. An host of '' or port 0 tells the OS to use the default.
""" """
host, port = address host, port = address
if host.startswith('['): if host.startswith("["):
host = host.strip('[]') host = host.strip("[]")
err = None err = None
# Using the value from allowed_gai_family() in the context of getaddrinfo lets # Using the value from allowed_gai_family() in the context of getaddrinfo lets
@ -57,6 +63,13 @@ def create_connection(address, timeout=socket._GLOBAL_DEFAULT_TIMEOUT,
# The original create_connection function always returns all records. # The original create_connection function always returns all records.
family = allowed_gai_family() family = allowed_gai_family()
try:
host.encode("idna")
except UnicodeError:
return six.raise_from(
LocationParseError(u"'%s', label empty or too long" % host), None
)
for res in socket.getaddrinfo(host, port, family, socket.SOCK_STREAM): for res in socket.getaddrinfo(host, port, family, socket.SOCK_STREAM):
af, socktype, proto, canonname, sa = res af, socktype, proto, canonname, sa = res
sock = None sock = None
@ -109,11 +122,18 @@ def _has_ipv6(host):
sock = None sock = None
has_ipv6 = False has_ipv6 = False
# App Engine doesn't support IPV6 sockets and actually has a quota on the
# number of sockets that can be used, so just early out here instead of
# creating a socket needlessly.
# See https://github.com/urllib3/urllib3/issues/1446
if _appengine_environ.is_appengine_sandbox():
return False
if socket.has_ipv6: if socket.has_ipv6:
# has_ipv6 returns true if cPython was compiled with IPv6 support. # has_ipv6 returns true if cPython was compiled with IPv6 support.
# It does not tell us if the system has IPv6 support enabled. To # It does not tell us if the system has IPv6 support enabled. To
# determine that we must bind to an IPv6 address. # determine that we must bind to an IPv6 address.
# https://github.com/shazow/urllib3/pull/611 # https://github.com/urllib3/urllib3/pull/611
# https://bugs.python.org/issue658327 # https://bugs.python.org/issue658327
try: try:
sock = socket.socket(socket.AF_INET6) sock = socket.socket(socket.AF_INET6)
@ -127,4 +147,4 @@ def _has_ipv6(host):
return has_ipv6 return has_ipv6
HAS_IPV6 = _has_ipv6('::1') HAS_IPV6 = _has_ipv6("::1")

57
lib/urllib3/util/proxy.py Normal file
View file

@ -0,0 +1,57 @@
from .ssl_ import create_urllib3_context, resolve_cert_reqs, resolve_ssl_version
def connection_requires_http_tunnel(
proxy_url=None, proxy_config=None, destination_scheme=None
):
"""
Returns True if the connection requires an HTTP CONNECT through the proxy.
:param URL proxy_url:
URL of the proxy.
:param ProxyConfig proxy_config:
Proxy configuration from poolmanager.py
:param str destination_scheme:
The scheme of the destination. (i.e https, http, etc)
"""
# If we're not using a proxy, no way to use a tunnel.
if proxy_url is None:
return False
# HTTP destinations never require tunneling, we always forward.
if destination_scheme == "http":
return False
# Support for forwarding with HTTPS proxies and HTTPS destinations.
if (
proxy_url.scheme == "https"
and proxy_config
and proxy_config.use_forwarding_for_https
):
return False
# Otherwise always use a tunnel.
return True
def create_proxy_ssl_context(
ssl_version, cert_reqs, ca_certs=None, ca_cert_dir=None, ca_cert_data=None
):
"""
Generates a default proxy ssl context if one hasn't been provided by the
user.
"""
ssl_context = create_urllib3_context(
ssl_version=resolve_ssl_version(ssl_version),
cert_reqs=resolve_cert_reqs(cert_reqs),
)
if (
not ca_certs
and not ca_cert_dir
and not ca_cert_data
and hasattr(ssl_context, "load_default_certs")
):
ssl_context.load_default_certs()
return ssl_context

22
lib/urllib3/util/queue.py Normal file
View file

@ -0,0 +1,22 @@
import collections
from ..packages import six
from ..packages.six.moves import queue
if six.PY2:
# Queue is imported for side effects on MS Windows. See issue #229.
import Queue as _unused_module_Queue # noqa: F401
class LifoQueue(queue.Queue):
def _init(self, _):
self.queue = collections.deque()
def _qsize(self, len=len):
return len(self.queue)
def _put(self, item):
self.queue.append(item)
def _get(self):
return self.queue.pop()

View file

@ -1,15 +1,36 @@
from __future__ import absolute_import from __future__ import absolute_import
from base64 import b64encode from base64 import b64encode
from ..packages.six import b, integer_types
from ..exceptions import UnrewindableBodyError from ..exceptions import UnrewindableBodyError
from ..packages.six import b, integer_types
# Pass as a value within ``headers`` to skip
# emitting some HTTP headers that are added automatically.
# The only headers that are supported are ``Accept-Encoding``,
# ``Host``, and ``User-Agent``.
SKIP_HEADER = "@@@SKIP_HEADER@@@"
SKIPPABLE_HEADERS = frozenset(["accept-encoding", "host", "user-agent"])
ACCEPT_ENCODING = "gzip,deflate"
try:
import brotli as _unused_module_brotli # noqa: F401
except ImportError:
pass
else:
ACCEPT_ENCODING += ",br"
ACCEPT_ENCODING = 'gzip,deflate'
_FAILEDTELL = object() _FAILEDTELL = object()
def make_headers(keep_alive=None, accept_encoding=None, user_agent=None, def make_headers(
basic_auth=None, proxy_basic_auth=None, disable_cache=None): keep_alive=None,
accept_encoding=None,
user_agent=None,
basic_auth=None,
proxy_basic_auth=None,
disable_cache=None,
):
""" """
Shortcuts for generating request headers. Shortcuts for generating request headers.
@ -49,27 +70,27 @@ def make_headers(keep_alive=None, accept_encoding=None, user_agent=None,
if isinstance(accept_encoding, str): if isinstance(accept_encoding, str):
pass pass
elif isinstance(accept_encoding, list): elif isinstance(accept_encoding, list):
accept_encoding = ','.join(accept_encoding) accept_encoding = ",".join(accept_encoding)
else: else:
accept_encoding = ACCEPT_ENCODING accept_encoding = ACCEPT_ENCODING
headers['accept-encoding'] = accept_encoding headers["accept-encoding"] = accept_encoding
if user_agent: if user_agent:
headers['user-agent'] = user_agent headers["user-agent"] = user_agent
if keep_alive: if keep_alive:
headers['connection'] = 'keep-alive' headers["connection"] = "keep-alive"
if basic_auth: if basic_auth:
headers['authorization'] = 'Basic ' + \ headers["authorization"] = "Basic " + b64encode(b(basic_auth)).decode("utf-8")
b64encode(b(basic_auth)).decode('utf-8')
if proxy_basic_auth: if proxy_basic_auth:
headers['proxy-authorization'] = 'Basic ' + \ headers["proxy-authorization"] = "Basic " + b64encode(
b64encode(b(proxy_basic_auth)).decode('utf-8') b(proxy_basic_auth)
).decode("utf-8")
if disable_cache: if disable_cache:
headers['cache-control'] = 'no-cache' headers["cache-control"] = "no-cache"
return headers return headers
@ -81,7 +102,7 @@ def set_file_position(body, pos):
""" """
if pos is not None: if pos is not None:
rewind_body(body, pos) rewind_body(body, pos)
elif getattr(body, 'tell', None) is not None: elif getattr(body, "tell", None) is not None:
try: try:
pos = body.tell() pos = body.tell()
except (IOError, OSError): except (IOError, OSError):
@ -103,16 +124,20 @@ def rewind_body(body, body_pos):
:param int pos: :param int pos:
Position to seek to in file. Position to seek to in file.
""" """
body_seek = getattr(body, 'seek', None) body_seek = getattr(body, "seek", None)
if body_seek is not None and isinstance(body_pos, integer_types): if body_seek is not None and isinstance(body_pos, integer_types):
try: try:
body_seek(body_pos) body_seek(body_pos)
except (IOError, OSError): except (IOError, OSError):
raise UnrewindableBodyError("An error occurred when rewinding request " raise UnrewindableBodyError(
"body for redirect/retry.") "An error occurred when rewinding request body for redirect/retry."
)
elif body_pos is _FAILEDTELL: elif body_pos is _FAILEDTELL:
raise UnrewindableBodyError("Unable to record file position for rewinding " raise UnrewindableBodyError(
"request body during a redirect/retry.") "Unable to record file position for rewinding "
"request body during a redirect/retry."
)
else: else:
raise ValueError("body_pos must be of type integer, " raise ValueError(
"instead it was %s." % type(body_pos)) "body_pos must be of type integer, instead it was %s." % type(body_pos)
)

View file

@ -1,7 +1,9 @@
from __future__ import absolute_import from __future__ import absolute_import
from ..packages.six.moves import http_client as httplib
from email.errors import MultipartInvariantViolationDefect, StartBoundaryNotFoundDefect
from ..exceptions import HeaderParsingError from ..exceptions import HeaderParsingError
from ..packages.six.moves import http_client as httplib
def is_fp_closed(obj): def is_fp_closed(obj):
@ -42,8 +44,7 @@ def assert_header_parsing(headers):
Only works on Python 3. Only works on Python 3.
:param headers: Headers to verify. :param http.client.HTTPMessage headers: Headers to verify.
:type headers: `httplib.HTTPMessage`.
:raises urllib3.exceptions.HeaderParsingError: :raises urllib3.exceptions.HeaderParsingError:
If parsing errors are found. If parsing errors are found.
@ -52,15 +53,39 @@ def assert_header_parsing(headers):
# This will fail silently if we pass in the wrong kind of parameter. # This will fail silently if we pass in the wrong kind of parameter.
# To make debugging easier add an explicit check. # To make debugging easier add an explicit check.
if not isinstance(headers, httplib.HTTPMessage): if not isinstance(headers, httplib.HTTPMessage):
raise TypeError('expected httplib.Message, got {0}.'.format( raise TypeError("expected httplib.Message, got {0}.".format(type(headers)))
type(headers)))
defects = getattr(headers, 'defects', None) defects = getattr(headers, "defects", None)
get_payload = getattr(headers, 'get_payload', None) get_payload = getattr(headers, "get_payload", None)
unparsed_data = None unparsed_data = None
if get_payload: # Platform-specific: Python 3. if get_payload:
unparsed_data = get_payload() # get_payload is actually email.message.Message.get_payload;
# we're only interested in the result if it's not a multipart message
if not headers.is_multipart():
payload = get_payload()
if isinstance(payload, (bytes, str)):
unparsed_data = payload
if defects:
# httplib is assuming a response body is available
# when parsing headers even when httplib only sends
# header data to parse_headers() This results in
# defects on multipart responses in particular.
# See: https://github.com/urllib3/urllib3/issues/800
# So we ignore the following defects:
# - StartBoundaryNotFoundDefect:
# The claimed start boundary was never found.
# - MultipartInvariantViolationDefect:
# A message claimed to be a multipart but no subparts were found.
defects = [
defect
for defect in defects
if not isinstance(
defect, (StartBoundaryNotFoundDefect, MultipartInvariantViolationDefect)
)
]
if defects or unparsed_data: if defects or unparsed_data:
raise HeaderParsingError(defects=defects, unparsed_data=unparsed_data) raise HeaderParsingError(defects=defects, unparsed_data=unparsed_data)
@ -71,11 +96,12 @@ def is_response_to_head(response):
Checks whether the request of a response has been a HEAD-request. Checks whether the request of a response has been a HEAD-request.
Handles the quirks of AppEngine. Handles the quirks of AppEngine.
:param conn: :param http.client.HTTPResponse response:
:type conn: :class:`httplib.HTTPResponse` Response to check if the originating request
used 'HEAD' as a method.
""" """
# FIXME: Can we do this somehow without accessing private httplib _method? # FIXME: Can we do this somehow without accessing private httplib _method?
method = response._method method = response._method
if isinstance(method, int): # Platform-specific: Appengine if isinstance(method, int): # Platform-specific: Appengine
return method == 3 return method == 3
return method.upper() == 'HEAD' return method.upper() == "HEAD"

View file

@ -1,29 +1,76 @@
from __future__ import absolute_import from __future__ import absolute_import
import time
import email
import logging import logging
import re
import time
import warnings
from collections import namedtuple from collections import namedtuple
from itertools import takewhile from itertools import takewhile
import email
import re
from ..exceptions import ( from ..exceptions import (
ConnectTimeoutError, ConnectTimeoutError,
InvalidHeader,
MaxRetryError, MaxRetryError,
ProtocolError, ProtocolError,
ProxyError,
ReadTimeoutError, ReadTimeoutError,
ResponseError, ResponseError,
InvalidHeader,
) )
from ..packages import six from ..packages import six
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
# Data structure for representing the metadata of requests that result in a retry. # Data structure for representing the metadata of requests that result in a retry.
RequestHistory = namedtuple('RequestHistory', ["method", "url", "error", RequestHistory = namedtuple(
"status", "redirect_location"]) "RequestHistory", ["method", "url", "error", "status", "redirect_location"]
)
# TODO: In v2 we can remove this sentinel and metaclass with deprecated options.
_Default = object()
class _RetryMeta(type):
@property
def DEFAULT_METHOD_WHITELIST(cls):
warnings.warn(
"Using 'Retry.DEFAULT_METHOD_WHITELIST' is deprecated and "
"will be removed in v2.0. Use 'Retry.DEFAULT_ALLOWED_METHODS' instead",
DeprecationWarning,
)
return cls.DEFAULT_ALLOWED_METHODS
@DEFAULT_METHOD_WHITELIST.setter
def DEFAULT_METHOD_WHITELIST(cls, value):
warnings.warn(
"Using 'Retry.DEFAULT_METHOD_WHITELIST' is deprecated and "
"will be removed in v2.0. Use 'Retry.DEFAULT_ALLOWED_METHODS' instead",
DeprecationWarning,
)
cls.DEFAULT_ALLOWED_METHODS = value
@property
def DEFAULT_REDIRECT_HEADERS_BLACKLIST(cls):
warnings.warn(
"Using 'Retry.DEFAULT_REDIRECT_HEADERS_BLACKLIST' is deprecated and "
"will be removed in v2.0. Use 'Retry.DEFAULT_REMOVE_HEADERS_ON_REDIRECT' instead",
DeprecationWarning,
)
return cls.DEFAULT_REMOVE_HEADERS_ON_REDIRECT
@DEFAULT_REDIRECT_HEADERS_BLACKLIST.setter
def DEFAULT_REDIRECT_HEADERS_BLACKLIST(cls, value):
warnings.warn(
"Using 'Retry.DEFAULT_REDIRECT_HEADERS_BLACKLIST' is deprecated and "
"will be removed in v2.0. Use 'Retry.DEFAULT_REMOVE_HEADERS_ON_REDIRECT' instead",
DeprecationWarning,
)
cls.DEFAULT_REMOVE_HEADERS_ON_REDIRECT = value
@six.add_metaclass(_RetryMeta)
class Retry(object): class Retry(object):
"""Retry configuration. """Retry configuration.
@ -51,8 +98,7 @@ class Retry(object):
Total number of retries to allow. Takes precedence over other counts. Total number of retries to allow. Takes precedence over other counts.
Set to ``None`` to remove this constraint and fall back on other Set to ``None`` to remove this constraint and fall back on other
counts. It's a good idea to set this to some sensibly-high value to counts.
account for unexpected edge cases and avoid infinite retry loops.
Set to ``0`` to fail on the first retry. Set to ``0`` to fail on the first retry.
@ -93,18 +139,35 @@ class Retry(object):
Set to ``0`` to fail on the first retry of this type. Set to ``0`` to fail on the first retry of this type.
:param iterable method_whitelist: :param int other:
How many times to retry on other errors.
Other errors are errors that are not connect, read, redirect or status errors.
These errors might be raised after the request was sent to the server, so the
request might have side-effects.
Set to ``0`` to fail on the first retry of this type.
If ``total`` is not set, it's a good idea to set this to 0 to account
for unexpected edge cases and avoid infinite retry loops.
:param iterable allowed_methods:
Set of uppercased HTTP method verbs that we should retry on. Set of uppercased HTTP method verbs that we should retry on.
By default, we only retry on methods which are considered to be By default, we only retry on methods which are considered to be
idempotent (multiple requests with the same parameters end with the idempotent (multiple requests with the same parameters end with the
same state). See :attr:`Retry.DEFAULT_METHOD_WHITELIST`. same state). See :attr:`Retry.DEFAULT_ALLOWED_METHODS`.
Set to a ``False`` value to retry on any verb. Set to a ``False`` value to retry on any verb.
.. warning::
Previously this parameter was named ``method_whitelist``, that
usage is deprecated in v1.26.0 and will be removed in v2.0.
:param iterable status_forcelist: :param iterable status_forcelist:
A set of integer HTTP status codes that we should force a retry on. A set of integer HTTP status codes that we should force a retry on.
A retry is initiated if the request method is in ``method_whitelist`` A retry is initiated if the request method is in ``allowed_methods``
and the response status code is in ``status_forcelist``. and the response status code is in ``status_forcelist``.
By default, this is disabled with ``None``. By default, this is disabled with ``None``.
@ -114,7 +177,7 @@ class Retry(object):
(most errors are resolved immediately by a second try without a (most errors are resolved immediately by a second try without a
delay). urllib3 will sleep for:: delay). urllib3 will sleep for::
{backoff factor} * (2 ^ ({number of total retries} - 1)) {backoff factor} * (2 ** ({number of total retries} - 1))
seconds. If the backoff_factor is 0.1, then :func:`.sleep` will sleep seconds. If the backoff_factor is 0.1, then :func:`.sleep` will sleep
for [0.0s, 0.2s, 0.4s, ...] between retries. It will never be longer for [0.0s, 0.2s, 0.4s, ...] between retries. It will never be longer
@ -139,25 +202,70 @@ class Retry(object):
Whether to respect Retry-After header on status codes defined as Whether to respect Retry-After header on status codes defined as
:attr:`Retry.RETRY_AFTER_STATUS_CODES` or not. :attr:`Retry.RETRY_AFTER_STATUS_CODES` or not.
:param iterable remove_headers_on_redirect:
Sequence of headers to remove from the request when a response
indicating a redirect is returned before firing off the redirected
request.
""" """
DEFAULT_METHOD_WHITELIST = frozenset([ #: Default methods to be used for ``allowed_methods``
'HEAD', 'GET', 'PUT', 'DELETE', 'OPTIONS', 'TRACE']) DEFAULT_ALLOWED_METHODS = frozenset(
["HEAD", "GET", "PUT", "DELETE", "OPTIONS", "TRACE"]
)
#: Default status codes to be used for ``status_forcelist``
RETRY_AFTER_STATUS_CODES = frozenset([413, 429, 503]) RETRY_AFTER_STATUS_CODES = frozenset([413, 429, 503])
#: Default headers to be used for ``remove_headers_on_redirect``
DEFAULT_REMOVE_HEADERS_ON_REDIRECT = frozenset(["Authorization"])
#: Maximum backoff time. #: Maximum backoff time.
BACKOFF_MAX = 120 BACKOFF_MAX = 120
def __init__(self, total=10, connect=None, read=None, redirect=None, status=None, def __init__(
method_whitelist=DEFAULT_METHOD_WHITELIST, status_forcelist=None, self,
backoff_factor=0, raise_on_redirect=True, raise_on_status=True, total=10,
history=None, respect_retry_after_header=True): connect=None,
read=None,
redirect=None,
status=None,
other=None,
allowed_methods=_Default,
status_forcelist=None,
backoff_factor=0,
raise_on_redirect=True,
raise_on_status=True,
history=None,
respect_retry_after_header=True,
remove_headers_on_redirect=_Default,
# TODO: Deprecated, remove in v2.0
method_whitelist=_Default,
):
if method_whitelist is not _Default:
if allowed_methods is not _Default:
raise ValueError(
"Using both 'allowed_methods' and "
"'method_whitelist' together is not allowed. "
"Instead only use 'allowed_methods'"
)
warnings.warn(
"Using 'method_whitelist' with Retry is deprecated and "
"will be removed in v2.0. Use 'allowed_methods' instead",
DeprecationWarning,
stacklevel=2,
)
allowed_methods = method_whitelist
if allowed_methods is _Default:
allowed_methods = self.DEFAULT_ALLOWED_METHODS
if remove_headers_on_redirect is _Default:
remove_headers_on_redirect = self.DEFAULT_REMOVE_HEADERS_ON_REDIRECT
self.total = total self.total = total
self.connect = connect self.connect = connect
self.read = read self.read = read
self.status = status self.status = status
self.other = other
if redirect is False or total is False: if redirect is False or total is False:
redirect = 0 redirect = 0
@ -165,24 +273,49 @@ class Retry(object):
self.redirect = redirect self.redirect = redirect
self.status_forcelist = status_forcelist or set() self.status_forcelist = status_forcelist or set()
self.method_whitelist = method_whitelist self.allowed_methods = allowed_methods
self.backoff_factor = backoff_factor self.backoff_factor = backoff_factor
self.raise_on_redirect = raise_on_redirect self.raise_on_redirect = raise_on_redirect
self.raise_on_status = raise_on_status self.raise_on_status = raise_on_status
self.history = history or tuple() self.history = history or tuple()
self.respect_retry_after_header = respect_retry_after_header self.respect_retry_after_header = respect_retry_after_header
self.remove_headers_on_redirect = frozenset(
[h.lower() for h in remove_headers_on_redirect]
)
def new(self, **kw): def new(self, **kw):
params = dict( params = dict(
total=self.total, total=self.total,
connect=self.connect, read=self.read, redirect=self.redirect, status=self.status, connect=self.connect,
method_whitelist=self.method_whitelist, read=self.read,
redirect=self.redirect,
status=self.status,
other=self.other,
status_forcelist=self.status_forcelist, status_forcelist=self.status_forcelist,
backoff_factor=self.backoff_factor, backoff_factor=self.backoff_factor,
raise_on_redirect=self.raise_on_redirect, raise_on_redirect=self.raise_on_redirect,
raise_on_status=self.raise_on_status, raise_on_status=self.raise_on_status,
history=self.history, history=self.history,
remove_headers_on_redirect=self.remove_headers_on_redirect,
respect_retry_after_header=self.respect_retry_after_header,
) )
# TODO: If already given in **kw we use what's given to us
# If not given we need to figure out what to pass. We decide
# based on whether our class has the 'method_whitelist' property
# and if so we pass the deprecated 'method_whitelist' otherwise
# we use 'allowed_methods'. Remove in v2.0
if "method_whitelist" not in kw and "allowed_methods" not in kw:
if "method_whitelist" in self.__dict__:
warnings.warn(
"Using 'method_whitelist' with Retry is deprecated and "
"will be removed in v2.0. Use 'allowed_methods' instead",
DeprecationWarning,
)
params["method_whitelist"] = self.allowed_methods
else:
params["allowed_methods"] = self.allowed_methods
params.update(kw) params.update(kw)
return type(self)(**params) return type(self)(**params)
@ -206,8 +339,11 @@ class Retry(object):
:rtype: float :rtype: float
""" """
# We want to consider only the last consecutive errors sequence (Ignore redirects). # We want to consider only the last consecutive errors sequence (Ignore redirects).
consecutive_errors_len = len(list(takewhile(lambda x: x.redirect_location is None, consecutive_errors_len = len(
reversed(self.history)))) list(
takewhile(lambda x: x.redirect_location is None, reversed(self.history))
)
)
if consecutive_errors_len <= 1: if consecutive_errors_len <= 1:
return 0 return 0
@ -219,10 +355,17 @@ class Retry(object):
if re.match(r"^\s*[0-9]+\s*$", retry_after): if re.match(r"^\s*[0-9]+\s*$", retry_after):
seconds = int(retry_after) seconds = int(retry_after)
else: else:
retry_date_tuple = email.utils.parsedate(retry_after) retry_date_tuple = email.utils.parsedate_tz(retry_after)
if retry_date_tuple is None: if retry_date_tuple is None:
raise InvalidHeader("Invalid Retry-After header: %s" % retry_after) raise InvalidHeader("Invalid Retry-After header: %s" % retry_after)
retry_date = time.mktime(retry_date_tuple) if retry_date_tuple[9] is None: # Python 2
# Assume UTC if no timezone was specified
# On Python2.7, parsedate_tz returns None for a timezone offset
# instead of 0 if no timezone is given, where mktime_tz treats
# a None timezone offset as local time.
retry_date_tuple = retry_date_tuple[:9] + (0,) + retry_date_tuple[10:]
retry_date = email.utils.mktime_tz(retry_date_tuple)
seconds = retry_date - time.time() seconds = retry_date - time.time()
if seconds < 0: if seconds < 0:
@ -263,7 +406,7 @@ class Retry(object):
this method will return immediately. this method will return immediately.
""" """
if response: if self.respect_retry_after_header and response:
slept = self.sleep_for_retry(response) slept = self.sleep_for_retry(response)
if slept: if slept:
return return
@ -274,6 +417,8 @@ class Retry(object):
"""Errors when we're fairly sure that the server did not receive the """Errors when we're fairly sure that the server did not receive the
request, so it should be safe to retry. request, so it should be safe to retry.
""" """
if isinstance(err, ProxyError):
err = err.original_error
return isinstance(err, ConnectTimeoutError) return isinstance(err, ConnectTimeoutError)
def _is_read_error(self, err): def _is_read_error(self, err):
@ -284,15 +429,26 @@ class Retry(object):
def _is_method_retryable(self, method): def _is_method_retryable(self, method):
"""Checks if a given HTTP method should be retried upon, depending if """Checks if a given HTTP method should be retried upon, depending if
it is included on the method whitelist. it is included in the allowed_methods
""" """
if self.method_whitelist and method.upper() not in self.method_whitelist: # TODO: For now favor if the Retry implementation sets its own method_whitelist
return False # property outside of our constructor to avoid breaking custom implementations.
if "method_whitelist" in self.__dict__:
warnings.warn(
"Using 'method_whitelist' with Retry is deprecated and "
"will be removed in v2.0. Use 'allowed_methods' instead",
DeprecationWarning,
)
allowed_methods = self.method_whitelist
else:
allowed_methods = self.allowed_methods
if allowed_methods and method.upper() not in allowed_methods:
return False
return True return True
def is_retry(self, method, status_code, has_retry_after=False): def is_retry(self, method, status_code, has_retry_after=False):
""" Is this method/status code retryable? (Based on whitelists and control """Is this method/status code retryable? (Based on allowlists and control
variables such as the number of total retries to allow, whether to variables such as the number of total retries to allow, whether to
respect the Retry-After header, whether this header is present, and respect the Retry-After header, whether this header is present, and
whether the returned status code is on the list of status codes to whether the returned status code is on the list of status codes to
@ -304,20 +460,38 @@ class Retry(object):
if self.status_forcelist and status_code in self.status_forcelist: if self.status_forcelist and status_code in self.status_forcelist:
return True return True
return (self.total and self.respect_retry_after_header and return (
has_retry_after and (status_code in self.RETRY_AFTER_STATUS_CODES)) self.total
and self.respect_retry_after_header
and has_retry_after
and (status_code in self.RETRY_AFTER_STATUS_CODES)
)
def is_exhausted(self): def is_exhausted(self):
"""Are we out of retries?""" """Are we out of retries?"""
retry_counts = (self.total, self.connect, self.read, self.redirect, self.status) retry_counts = (
self.total,
self.connect,
self.read,
self.redirect,
self.status,
self.other,
)
retry_counts = list(filter(None, retry_counts)) retry_counts = list(filter(None, retry_counts))
if not retry_counts: if not retry_counts:
return False return False
return min(retry_counts) < 0 return min(retry_counts) < 0
def increment(self, method=None, url=None, response=None, error=None, def increment(
_pool=None, _stacktrace=None): self,
method=None,
url=None,
response=None,
error=None,
_pool=None,
_stacktrace=None,
):
"""Return a new Retry object with incremented retry counters. """Return a new Retry object with incremented retry counters.
:param response: A response object, or None, if the server did not :param response: A response object, or None, if the server did not
@ -340,7 +514,8 @@ class Retry(object):
read = self.read read = self.read
redirect = self.redirect redirect = self.redirect
status_count = self.status status_count = self.status
cause = 'unknown' other = self.other
cause = "unknown"
status = None status = None
redirect_location = None redirect_location = None
@ -358,31 +533,42 @@ class Retry(object):
elif read is not None: elif read is not None:
read -= 1 read -= 1
elif error:
# Other retry?
if other is not None:
other -= 1
elif response and response.get_redirect_location(): elif response and response.get_redirect_location():
# Redirect retry? # Redirect retry?
if redirect is not None: if redirect is not None:
redirect -= 1 redirect -= 1
cause = 'too many redirects' cause = "too many redirects"
redirect_location = response.get_redirect_location() redirect_location = response.get_redirect_location()
status = response.status status = response.status
else: else:
# Incrementing because of a server error like a 500 in # Incrementing because of a server error like a 500 in
# status_forcelist and a the given method is in the whitelist # status_forcelist and the given method is in the allowed_methods
cause = ResponseError.GENERIC_ERROR cause = ResponseError.GENERIC_ERROR
if response and response.status: if response and response.status:
if status_count is not None: if status_count is not None:
status_count -= 1 status_count -= 1
cause = ResponseError.SPECIFIC_ERROR.format( cause = ResponseError.SPECIFIC_ERROR.format(status_code=response.status)
status_code=response.status)
status = response.status status = response.status
history = self.history + (RequestHistory(method, url, error, status, redirect_location),) history = self.history + (
RequestHistory(method, url, error, status, redirect_location),
)
new_retry = self.new( new_retry = self.new(
total=total, total=total,
connect=connect, read=read, redirect=redirect, status=status_count, connect=connect,
history=history) read=read,
redirect=redirect,
status=status_count,
other=other,
history=history,
)
if new_retry.is_exhausted(): if new_retry.is_exhausted():
raise MaxRetryError(_pool, url, error or ResponseError(cause)) raise MaxRetryError(_pool, url, error or ResponseError(cause))
@ -392,9 +578,24 @@ class Retry(object):
return new_retry return new_retry
def __repr__(self): def __repr__(self):
return ('{cls.__name__}(total={self.total}, connect={self.connect}, ' return (
'read={self.read}, redirect={self.redirect}, status={self.status})').format( "{cls.__name__}(total={self.total}, connect={self.connect}, "
cls=type(self), self=self) "read={self.read}, redirect={self.redirect}, status={self.status})"
).format(cls=type(self), self=self)
def __getattr__(self, item):
if item == "method_whitelist":
# TODO: Remove this deprecated alias in v2.0
warnings.warn(
"Using 'method_whitelist' with Retry is deprecated and "
"will be removed in v2.0. Use 'allowed_methods' instead",
DeprecationWarning,
)
return self.allowed_methods
try:
return getattr(super(Retry, self), item)
except AttributeError:
return getattr(Retry, item)
# For backwards compatibility (equivalent to pre-v1.9): # For backwards compatibility (equivalent to pre-v1.9):

View file

@ -1,581 +0,0 @@
# Backport of selectors.py from Python 3.5+ to support Python < 3.4
# Also has the behavior specified in PEP 475 which is to retry syscalls
# in the case of an EINTR error. This module is required because selectors34
# does not follow this behavior and instead returns that no dile descriptor
# events have occurred rather than retry the syscall. The decision to drop
# support for select.devpoll is made to maintain 100% test coverage.
import errno
import math
import select
import socket
import sys
import time
from collections import namedtuple, Mapping
try:
monotonic = time.monotonic
except (AttributeError, ImportError): # Python 3.3<
monotonic = time.time
EVENT_READ = (1 << 0)
EVENT_WRITE = (1 << 1)
HAS_SELECT = True # Variable that shows whether the platform has a selector.
_SYSCALL_SENTINEL = object() # Sentinel in case a system call returns None.
_DEFAULT_SELECTOR = None
class SelectorError(Exception):
def __init__(self, errcode):
super(SelectorError, self).__init__()
self.errno = errcode
def __repr__(self):
return "<SelectorError errno={0}>".format(self.errno)
def __str__(self):
return self.__repr__()
def _fileobj_to_fd(fileobj):
""" Return a file descriptor from a file object. If
given an integer will simply return that integer back. """
if isinstance(fileobj, int):
fd = fileobj
else:
try:
fd = int(fileobj.fileno())
except (AttributeError, TypeError, ValueError):
raise ValueError("Invalid file object: {0!r}".format(fileobj))
if fd < 0:
raise ValueError("Invalid file descriptor: {0}".format(fd))
return fd
# Determine which function to use to wrap system calls because Python 3.5+
# already handles the case when system calls are interrupted.
if sys.version_info >= (3, 5):
def _syscall_wrapper(func, _, *args, **kwargs):
""" This is the short-circuit version of the below logic
because in Python 3.5+ all system calls automatically restart
and recalculate their timeouts. """
try:
return func(*args, **kwargs)
except (OSError, IOError, select.error) as e:
errcode = None
if hasattr(e, "errno"):
errcode = e.errno
raise SelectorError(errcode)
else:
def _syscall_wrapper(func, recalc_timeout, *args, **kwargs):
""" Wrapper function for syscalls that could fail due to EINTR.
All functions should be retried if there is time left in the timeout
in accordance with PEP 475. """
timeout = kwargs.get("timeout", None)
if timeout is None:
expires = None
recalc_timeout = False
else:
timeout = float(timeout)
if timeout < 0.0: # Timeout less than 0 treated as no timeout.
expires = None
else:
expires = monotonic() + timeout
args = list(args)
if recalc_timeout and "timeout" not in kwargs:
raise ValueError(
"Timeout must be in args or kwargs to be recalculated")
result = _SYSCALL_SENTINEL
while result is _SYSCALL_SENTINEL:
try:
result = func(*args, **kwargs)
# OSError is thrown by select.select
# IOError is thrown by select.epoll.poll
# select.error is thrown by select.poll.poll
# Aren't we thankful for Python 3.x rework for exceptions?
except (OSError, IOError, select.error) as e:
# select.error wasn't a subclass of OSError in the past.
errcode = None
if hasattr(e, "errno"):
errcode = e.errno
elif hasattr(e, "args"):
errcode = e.args[0]
# Also test for the Windows equivalent of EINTR.
is_interrupt = (errcode == errno.EINTR or (hasattr(errno, "WSAEINTR") and
errcode == errno.WSAEINTR))
if is_interrupt:
if expires is not None:
current_time = monotonic()
if current_time > expires:
raise OSError(errno=errno.ETIMEDOUT)
if recalc_timeout:
if "timeout" in kwargs:
kwargs["timeout"] = expires - current_time
continue
if errcode:
raise SelectorError(errcode)
else:
raise
return result
SelectorKey = namedtuple('SelectorKey', ['fileobj', 'fd', 'events', 'data'])
class _SelectorMapping(Mapping):
""" Mapping of file objects to selector keys """
def __init__(self, selector):
self._selector = selector
def __len__(self):
return len(self._selector._fd_to_key)
def __getitem__(self, fileobj):
try:
fd = self._selector._fileobj_lookup(fileobj)
return self._selector._fd_to_key[fd]
except KeyError:
raise KeyError("{0!r} is not registered.".format(fileobj))
def __iter__(self):
return iter(self._selector._fd_to_key)
class BaseSelector(object):
""" Abstract Selector class
A selector supports registering file objects to be monitored
for specific I/O events.
A file object is a file descriptor or any object with a
`fileno()` method. An arbitrary object can be attached to the
file object which can be used for example to store context info,
a callback, etc.
A selector can use various implementations (select(), poll(), epoll(),
and kqueue()) depending on the platform. The 'DefaultSelector' class uses
the most efficient implementation for the current platform.
"""
def __init__(self):
# Maps file descriptors to keys.
self._fd_to_key = {}
# Read-only mapping returned by get_map()
self._map = _SelectorMapping(self)
def _fileobj_lookup(self, fileobj):
""" Return a file descriptor from a file object.
This wraps _fileobj_to_fd() to do an exhaustive
search in case the object is invalid but we still
have it in our map. Used by unregister() so we can
unregister an object that was previously registered
even if it is closed. It is also used by _SelectorMapping
"""
try:
return _fileobj_to_fd(fileobj)
except ValueError:
# Search through all our mapped keys.
for key in self._fd_to_key.values():
if key.fileobj is fileobj:
return key.fd
# Raise ValueError after all.
raise
def register(self, fileobj, events, data=None):
""" Register a file object for a set of events to monitor. """
if (not events) or (events & ~(EVENT_READ | EVENT_WRITE)):
raise ValueError("Invalid events: {0!r}".format(events))
key = SelectorKey(fileobj, self._fileobj_lookup(fileobj), events, data)
if key.fd in self._fd_to_key:
raise KeyError("{0!r} (FD {1}) is already registered"
.format(fileobj, key.fd))
self._fd_to_key[key.fd] = key
return key
def unregister(self, fileobj):
""" Unregister a file object from being monitored. """
try:
key = self._fd_to_key.pop(self._fileobj_lookup(fileobj))
except KeyError:
raise KeyError("{0!r} is not registered".format(fileobj))
# Getting the fileno of a closed socket on Windows errors with EBADF.
except socket.error as e: # Platform-specific: Windows.
if e.errno != errno.EBADF:
raise
else:
for key in self._fd_to_key.values():
if key.fileobj is fileobj:
self._fd_to_key.pop(key.fd)
break
else:
raise KeyError("{0!r} is not registered".format(fileobj))
return key
def modify(self, fileobj, events, data=None):
""" Change a registered file object monitored events and data. """
# NOTE: Some subclasses optimize this operation even further.
try:
key = self._fd_to_key[self._fileobj_lookup(fileobj)]
except KeyError:
raise KeyError("{0!r} is not registered".format(fileobj))
if events != key.events:
self.unregister(fileobj)
key = self.register(fileobj, events, data)
elif data != key.data:
# Use a shortcut to update the data.
key = key._replace(data=data)
self._fd_to_key[key.fd] = key
return key
def select(self, timeout=None):
""" Perform the actual selection until some monitored file objects
are ready or the timeout expires. """
raise NotImplementedError()
def close(self):
""" Close the selector. This must be called to ensure that all
underlying resources are freed. """
self._fd_to_key.clear()
self._map = None
def get_key(self, fileobj):
""" Return the key associated with a registered file object. """
mapping = self.get_map()
if mapping is None:
raise RuntimeError("Selector is closed")
try:
return mapping[fileobj]
except KeyError:
raise KeyError("{0!r} is not registered".format(fileobj))
def get_map(self):
""" Return a mapping of file objects to selector keys """
return self._map
def _key_from_fd(self, fd):
""" Return the key associated to a given file descriptor
Return None if it is not found. """
try:
return self._fd_to_key[fd]
except KeyError:
return None
def __enter__(self):
return self
def __exit__(self, *args):
self.close()
# Almost all platforms have select.select()
if hasattr(select, "select"):
class SelectSelector(BaseSelector):
""" Select-based selector. """
def __init__(self):
super(SelectSelector, self).__init__()
self._readers = set()
self._writers = set()
def register(self, fileobj, events, data=None):
key = super(SelectSelector, self).register(fileobj, events, data)
if events & EVENT_READ:
self._readers.add(key.fd)
if events & EVENT_WRITE:
self._writers.add(key.fd)
return key
def unregister(self, fileobj):
key = super(SelectSelector, self).unregister(fileobj)
self._readers.discard(key.fd)
self._writers.discard(key.fd)
return key
def _select(self, r, w, timeout=None):
""" Wrapper for select.select because timeout is a positional arg """
return select.select(r, w, [], timeout)
def select(self, timeout=None):
# Selecting on empty lists on Windows errors out.
if not len(self._readers) and not len(self._writers):
return []
timeout = None if timeout is None else max(timeout, 0.0)
ready = []
r, w, _ = _syscall_wrapper(self._select, True, self._readers,
self._writers, timeout)
r = set(r)
w = set(w)
for fd in r | w:
events = 0
if fd in r:
events |= EVENT_READ
if fd in w:
events |= EVENT_WRITE
key = self._key_from_fd(fd)
if key:
ready.append((key, events & key.events))
return ready
if hasattr(select, "poll"):
class PollSelector(BaseSelector):
""" Poll-based selector """
def __init__(self):
super(PollSelector, self).__init__()
self._poll = select.poll()
def register(self, fileobj, events, data=None):
key = super(PollSelector, self).register(fileobj, events, data)
event_mask = 0
if events & EVENT_READ:
event_mask |= select.POLLIN
if events & EVENT_WRITE:
event_mask |= select.POLLOUT
self._poll.register(key.fd, event_mask)
return key
def unregister(self, fileobj):
key = super(PollSelector, self).unregister(fileobj)
self._poll.unregister(key.fd)
return key
def _wrap_poll(self, timeout=None):
""" Wrapper function for select.poll.poll() so that
_syscall_wrapper can work with only seconds. """
if timeout is not None:
if timeout <= 0:
timeout = 0
else:
# select.poll.poll() has a resolution of 1 millisecond,
# round away from zero to wait *at least* timeout seconds.
timeout = math.ceil(timeout * 1e3)
result = self._poll.poll(timeout)
return result
def select(self, timeout=None):
ready = []
fd_events = _syscall_wrapper(self._wrap_poll, True, timeout=timeout)
for fd, event_mask in fd_events:
events = 0
if event_mask & ~select.POLLIN:
events |= EVENT_WRITE
if event_mask & ~select.POLLOUT:
events |= EVENT_READ
key = self._key_from_fd(fd)
if key:
ready.append((key, events & key.events))
return ready
if hasattr(select, "epoll"):
class EpollSelector(BaseSelector):
""" Epoll-based selector """
def __init__(self):
super(EpollSelector, self).__init__()
self._epoll = select.epoll()
def fileno(self):
return self._epoll.fileno()
def register(self, fileobj, events, data=None):
key = super(EpollSelector, self).register(fileobj, events, data)
events_mask = 0
if events & EVENT_READ:
events_mask |= select.EPOLLIN
if events & EVENT_WRITE:
events_mask |= select.EPOLLOUT
_syscall_wrapper(self._epoll.register, False, key.fd, events_mask)
return key
def unregister(self, fileobj):
key = super(EpollSelector, self).unregister(fileobj)
try:
_syscall_wrapper(self._epoll.unregister, False, key.fd)
except SelectorError:
# This can occur when the fd was closed since registry.
pass
return key
def select(self, timeout=None):
if timeout is not None:
if timeout <= 0:
timeout = 0.0
else:
# select.epoll.poll() has a resolution of 1 millisecond
# but luckily takes seconds so we don't need a wrapper
# like PollSelector. Just for better rounding.
timeout = math.ceil(timeout * 1e3) * 1e-3
timeout = float(timeout)
else:
timeout = -1.0 # epoll.poll() must have a float.
# We always want at least 1 to ensure that select can be called
# with no file descriptors registered. Otherwise will fail.
max_events = max(len(self._fd_to_key), 1)
ready = []
fd_events = _syscall_wrapper(self._epoll.poll, True,
timeout=timeout,
maxevents=max_events)
for fd, event_mask in fd_events:
events = 0
if event_mask & ~select.EPOLLIN:
events |= EVENT_WRITE
if event_mask & ~select.EPOLLOUT:
events |= EVENT_READ
key = self._key_from_fd(fd)
if key:
ready.append((key, events & key.events))
return ready
def close(self):
self._epoll.close()
super(EpollSelector, self).close()
if hasattr(select, "kqueue"):
class KqueueSelector(BaseSelector):
""" Kqueue / Kevent-based selector """
def __init__(self):
super(KqueueSelector, self).__init__()
self._kqueue = select.kqueue()
def fileno(self):
return self._kqueue.fileno()
def register(self, fileobj, events, data=None):
key = super(KqueueSelector, self).register(fileobj, events, data)
if events & EVENT_READ:
kevent = select.kevent(key.fd,
select.KQ_FILTER_READ,
select.KQ_EV_ADD)
_syscall_wrapper(self._kqueue.control, False, [kevent], 0, 0)
if events & EVENT_WRITE:
kevent = select.kevent(key.fd,
select.KQ_FILTER_WRITE,
select.KQ_EV_ADD)
_syscall_wrapper(self._kqueue.control, False, [kevent], 0, 0)
return key
def unregister(self, fileobj):
key = super(KqueueSelector, self).unregister(fileobj)
if key.events & EVENT_READ:
kevent = select.kevent(key.fd,
select.KQ_FILTER_READ,
select.KQ_EV_DELETE)
try:
_syscall_wrapper(self._kqueue.control, False, [kevent], 0, 0)
except SelectorError:
pass
if key.events & EVENT_WRITE:
kevent = select.kevent(key.fd,
select.KQ_FILTER_WRITE,
select.KQ_EV_DELETE)
try:
_syscall_wrapper(self._kqueue.control, False, [kevent], 0, 0)
except SelectorError:
pass
return key
def select(self, timeout=None):
if timeout is not None:
timeout = max(timeout, 0)
max_events = len(self._fd_to_key) * 2
ready_fds = {}
kevent_list = _syscall_wrapper(self._kqueue.control, True,
None, max_events, timeout)
for kevent in kevent_list:
fd = kevent.ident
event_mask = kevent.filter
events = 0
if event_mask == select.KQ_FILTER_READ:
events |= EVENT_READ
if event_mask == select.KQ_FILTER_WRITE:
events |= EVENT_WRITE
key = self._key_from_fd(fd)
if key:
if key.fd not in ready_fds:
ready_fds[key.fd] = (key, events & key.events)
else:
old_events = ready_fds[key.fd][1]
ready_fds[key.fd] = (key, (events | old_events) & key.events)
return list(ready_fds.values())
def close(self):
self._kqueue.close()
super(KqueueSelector, self).close()
if not hasattr(select, 'select'): # Platform-specific: AppEngine
HAS_SELECT = False
def _can_allocate(struct):
""" Checks that select structs can be allocated by the underlying
operating system, not just advertised by the select module. We don't
check select() because we'll be hopeful that most platforms that
don't have it available will not advertise it. (ie: GAE) """
try:
# select.poll() objects won't fail until used.
if struct == 'poll':
p = select.poll()
p.poll(0)
# All others will fail on allocation.
else:
getattr(select, struct)().close()
return True
except (OSError, AttributeError) as e:
return False
# Choose the best implementation, roughly:
# kqueue == epoll > poll > select. Devpoll not supported. (See above)
# select() also can't accept a FD > FD_SETSIZE (usually around 1024)
def DefaultSelector():
""" This function serves as a first call for DefaultSelector to
detect if the select module is being monkey-patched incorrectly
by eventlet, greenlet, and preserve proper behavior. """
global _DEFAULT_SELECTOR
if _DEFAULT_SELECTOR is None:
if _can_allocate('kqueue'):
_DEFAULT_SELECTOR = KqueueSelector
elif _can_allocate('epoll'):
_DEFAULT_SELECTOR = EpollSelector
elif _can_allocate('poll'):
_DEFAULT_SELECTOR = PollSelector
elif hasattr(select, 'select'):
_DEFAULT_SELECTOR = SelectSelector
else: # Platform-specific: AppEngine
raise ValueError('Platform does not have a selector')
return _DEFAULT_SELECTOR()

View file

@ -1,25 +1,30 @@
from __future__ import absolute_import from __future__ import absolute_import
import errno
import warnings
import hmac
import hmac
import os
import sys
import warnings
from binascii import hexlify, unhexlify from binascii import hexlify, unhexlify
from hashlib import md5, sha1, sha256 from hashlib import md5, sha1, sha256
from ..exceptions import SSLError, InsecurePlatformWarning, SNIMissingWarning from ..exceptions import (
InsecurePlatformWarning,
ProxySchemeUnsupported,
SNIMissingWarning,
SSLError,
)
from ..packages import six
from .url import BRACELESS_IPV6_ADDRZ_RE, IPV4_RE
SSLContext = None SSLContext = None
SSLTransport = None
HAS_SNI = False HAS_SNI = False
IS_PYOPENSSL = False IS_PYOPENSSL = False
IS_SECURETRANSPORT = False IS_SECURETRANSPORT = False
ALPN_PROTOCOLS = ["http/1.1"]
# Maps the length of a digest to a possible hash function producing this digest # Maps the length of a digest to a possible hash function producing this digest
HASHFUNC_MAP = { HASHFUNC_MAP = {32: md5, 40: sha1, 64: sha256}
32: md5,
40: sha1,
64: sha256,
}
def _const_compare_digest_backport(a, b): def _const_compare_digest_backport(a, b):
@ -30,29 +35,61 @@ def _const_compare_digest_backport(a, b):
Returns True if the digests match, and False otherwise. Returns True if the digests match, and False otherwise.
""" """
result = abs(len(a) - len(b)) result = abs(len(a) - len(b))
for l, r in zip(bytearray(a), bytearray(b)): for left, right in zip(bytearray(a), bytearray(b)):
result |= l ^ r result |= left ^ right
return result == 0 return result == 0
_const_compare_digest = getattr(hmac, 'compare_digest', _const_compare_digest = getattr(hmac, "compare_digest", _const_compare_digest_backport)
_const_compare_digest_backport)
try: # Test for SSL features try: # Test for SSL features
import ssl import ssl
from ssl import wrap_socket, CERT_NONE, PROTOCOL_SSLv23 from ssl import CERT_REQUIRED, wrap_socket
except ImportError:
pass
try:
from ssl import HAS_SNI # Has SNI? from ssl import HAS_SNI # Has SNI?
except ImportError: except ImportError:
pass pass
try:
from .ssltransport import SSLTransport
except ImportError:
pass
try: # Platform-specific: Python 3.6
from ssl import PROTOCOL_TLS
PROTOCOL_SSLv23 = PROTOCOL_TLS
except ImportError:
try:
from ssl import PROTOCOL_SSLv23 as PROTOCOL_TLS
PROTOCOL_SSLv23 = PROTOCOL_TLS
except ImportError:
PROTOCOL_SSLv23 = PROTOCOL_TLS = 2
try: try:
from ssl import OP_NO_SSLv2, OP_NO_SSLv3, OP_NO_COMPRESSION from ssl import PROTOCOL_TLS_CLIENT
except ImportError:
PROTOCOL_TLS_CLIENT = PROTOCOL_TLS
try:
from ssl import OP_NO_COMPRESSION, OP_NO_SSLv2, OP_NO_SSLv3
except ImportError: except ImportError:
OP_NO_SSLv2, OP_NO_SSLv3 = 0x1000000, 0x2000000 OP_NO_SSLv2, OP_NO_SSLv3 = 0x1000000, 0x2000000
OP_NO_COMPRESSION = 0x20000 OP_NO_COMPRESSION = 0x20000
try: # OP_NO_TICKET was added in Python 3.6
from ssl import OP_NO_TICKET
except ImportError:
OP_NO_TICKET = 0x4000
# A secure default. # A secure default.
# Sources for more information on TLS ciphers: # Sources for more information on TLS ciphers:
# #
@ -61,41 +98,39 @@ except ImportError:
# - https://hynek.me/articles/hardening-your-web-servers-ssl-ciphers/ # - https://hynek.me/articles/hardening-your-web-servers-ssl-ciphers/
# #
# The general intent is: # The general intent is:
# - Prefer TLS 1.3 cipher suites
# - prefer cipher suites that offer perfect forward secrecy (DHE/ECDHE), # - prefer cipher suites that offer perfect forward secrecy (DHE/ECDHE),
# - prefer ECDHE over DHE for better performance, # - prefer ECDHE over DHE for better performance,
# - prefer any AES-GCM and ChaCha20 over any AES-CBC for better performance and # - prefer any AES-GCM and ChaCha20 over any AES-CBC for better performance and
# security, # security,
# - prefer AES-GCM over ChaCha20 because hardware-accelerated AES is common, # - prefer AES-GCM over ChaCha20 because hardware-accelerated AES is common,
# - disable NULL authentication, MD5 MACs and DSS for security reasons. # - disable NULL authentication, MD5 MACs, DSS, and other
DEFAULT_CIPHERS = ':'.join([ # insecure ciphers for security reasons.
'TLS13-AES-256-GCM-SHA384', # - NOTE: TLS 1.3 cipher suites are managed through a different interface
'TLS13-CHACHA20-POLY1305-SHA256', # not exposed by CPython (yet!) and are enabled by default if they're available.
'TLS13-AES-128-GCM-SHA256', DEFAULT_CIPHERS = ":".join(
'ECDH+AESGCM', [
'ECDH+CHACHA20', "ECDHE+AESGCM",
'DH+AESGCM', "ECDHE+CHACHA20",
'DH+CHACHA20', "DHE+AESGCM",
'ECDH+AES256', "DHE+CHACHA20",
'DH+AES256', "ECDH+AESGCM",
'ECDH+AES128', "DH+AESGCM",
'DH+AES', "ECDH+AES",
'RSA+AESGCM', "DH+AES",
'RSA+AES', "RSA+AESGCM",
'!aNULL', "RSA+AES",
'!eNULL', "!aNULL",
'!MD5', "!eNULL",
]) "!MD5",
"!DSS",
]
)
try: try:
from ssl import SSLContext # Modern SSL? from ssl import SSLContext # Modern SSL?
except ImportError: except ImportError:
import sys
class SSLContext(object): # Platform-specific: Python 2 & 3.1
supports_set_ciphers = ((2, 7) <= sys.version_info < (3,) or
(3, 2) <= sys.version_info)
class SSLContext(object): # Platform-specific: Python 2
def __init__(self, protocol_version): def __init__(self, protocol_version):
self.protocol = protocol_version self.protocol = protocol_version
# Use default values from a real SSLContext # Use default values from a real SSLContext
@ -111,43 +146,37 @@ except ImportError:
self.certfile = certfile self.certfile = certfile
self.keyfile = keyfile self.keyfile = keyfile
def load_verify_locations(self, cafile=None, capath=None): def load_verify_locations(self, cafile=None, capath=None, cadata=None):
self.ca_certs = cafile self.ca_certs = cafile
if capath is not None: if capath is not None:
raise SSLError("CA directories not supported in older Pythons") raise SSLError("CA directories not supported in older Pythons")
if cadata is not None:
raise SSLError("CA data not supported in older Pythons")
def set_ciphers(self, cipher_suite): def set_ciphers(self, cipher_suite):
if not self.supports_set_ciphers:
raise TypeError(
'Your version of Python does not support setting '
'a custom cipher suite. Please upgrade to Python '
'2.7, 3.2, or later if you need this functionality.'
)
self.ciphers = cipher_suite self.ciphers = cipher_suite
def wrap_socket(self, socket, server_hostname=None, server_side=False): def wrap_socket(self, socket, server_hostname=None, server_side=False):
warnings.warn( warnings.warn(
'A true SSLContext object is not available. This prevents ' "A true SSLContext object is not available. This prevents "
'urllib3 from configuring SSL appropriately and may cause ' "urllib3 from configuring SSL appropriately and may cause "
'certain SSL connections to fail. You can upgrade to a newer ' "certain SSL connections to fail. You can upgrade to a newer "
'version of Python to solve this. For more information, see ' "version of Python to solve this. For more information, see "
'https://urllib3.readthedocs.io/en/latest/advanced-usage.html' "https://urllib3.readthedocs.io/en/1.26.x/advanced-usage.html"
'#ssl-warnings', "#ssl-warnings",
InsecurePlatformWarning InsecurePlatformWarning,
) )
kwargs = { kwargs = {
'keyfile': self.keyfile, "keyfile": self.keyfile,
'certfile': self.certfile, "certfile": self.certfile,
'ca_certs': self.ca_certs, "ca_certs": self.ca_certs,
'cert_reqs': self.verify_mode, "cert_reqs": self.verify_mode,
'ssl_version': self.protocol, "ssl_version": self.protocol,
'server_side': server_side, "server_side": server_side,
} }
if self.supports_set_ciphers: # Platform-specific: Python 2.7+
return wrap_socket(socket, ciphers=self.ciphers, **kwargs) return wrap_socket(socket, ciphers=self.ciphers, **kwargs)
else: # Platform-specific: Python 2.6
return wrap_socket(socket, **kwargs)
def assert_fingerprint(cert, fingerprint): def assert_fingerprint(cert, fingerprint):
@ -160,12 +189,11 @@ def assert_fingerprint(cert, fingerprint):
Fingerprint as string of hexdigits, can be interspersed by colons. Fingerprint as string of hexdigits, can be interspersed by colons.
""" """
fingerprint = fingerprint.replace(':', '').lower() fingerprint = fingerprint.replace(":", "").lower()
digest_length = len(fingerprint) digest_length = len(fingerprint)
hashfunc = HASHFUNC_MAP.get(digest_length) hashfunc = HASHFUNC_MAP.get(digest_length)
if not hashfunc: if not hashfunc:
raise SSLError( raise SSLError("Fingerprint of invalid length: {0}".format(fingerprint))
'Fingerprint of invalid length: {0}'.format(fingerprint))
# We need encode() here for py32; works on py2 and p33. # We need encode() here for py32; works on py2 and p33.
fingerprint_bytes = unhexlify(fingerprint.encode()) fingerprint_bytes = unhexlify(fingerprint.encode())
@ -173,28 +201,31 @@ def assert_fingerprint(cert, fingerprint):
cert_digest = hashfunc(cert).digest() cert_digest = hashfunc(cert).digest()
if not _const_compare_digest(cert_digest, fingerprint_bytes): if not _const_compare_digest(cert_digest, fingerprint_bytes):
raise SSLError('Fingerprints did not match. Expected "{0}", got "{1}".' raise SSLError(
.format(fingerprint, hexlify(cert_digest))) 'Fingerprints did not match. Expected "{0}", got "{1}".'.format(
fingerprint, hexlify(cert_digest)
)
)
def resolve_cert_reqs(candidate): def resolve_cert_reqs(candidate):
""" """
Resolves the argument to a numeric constant, which can be passed to Resolves the argument to a numeric constant, which can be passed to
the wrap_socket function/method from the ssl module. the wrap_socket function/method from the ssl module.
Defaults to :data:`ssl.CERT_NONE`. Defaults to :data:`ssl.CERT_REQUIRED`.
If given a string it is assumed to be the name of the constant in the If given a string it is assumed to be the name of the constant in the
:mod:`ssl` module or its abbrevation. :mod:`ssl` module or its abbreviation.
(So you can specify `REQUIRED` instead of `CERT_REQUIRED`. (So you can specify `REQUIRED` instead of `CERT_REQUIRED`.
If it's neither `None` nor a string we assume it is already the numeric If it's neither `None` nor a string we assume it is already the numeric
constant which can directly be passed to wrap_socket. constant which can directly be passed to wrap_socket.
""" """
if candidate is None: if candidate is None:
return CERT_NONE return CERT_REQUIRED
if isinstance(candidate, str): if isinstance(candidate, str):
res = getattr(ssl, candidate, None) res = getattr(ssl, candidate, None)
if res is None: if res is None:
res = getattr(ssl, 'CERT_' + candidate) res = getattr(ssl, "CERT_" + candidate)
return res return res
return candidate return candidate
@ -205,19 +236,20 @@ def resolve_ssl_version(candidate):
like resolve_cert_reqs like resolve_cert_reqs
""" """
if candidate is None: if candidate is None:
return PROTOCOL_SSLv23 return PROTOCOL_TLS
if isinstance(candidate, str): if isinstance(candidate, str):
res = getattr(ssl, candidate, None) res = getattr(ssl, candidate, None)
if res is None: if res is None:
res = getattr(ssl, 'PROTOCOL_' + candidate) res = getattr(ssl, "PROTOCOL_" + candidate)
return res return res
return candidate return candidate
def create_urllib3_context(ssl_version=None, cert_reqs=None, def create_urllib3_context(
options=None, ciphers=None): ssl_version=None, cert_reqs=None, options=None, ciphers=None
):
"""All arguments have the same meaning as ``ssl_wrap_socket``. """All arguments have the same meaning as ``ssl_wrap_socket``.
By default, this function does a lot of the same work that By default, this function does a lot of the same work that
@ -244,14 +276,20 @@ def create_urllib3_context(ssl_version=None, cert_reqs=None,
``ssl.CERT_REQUIRED``. ``ssl.CERT_REQUIRED``.
:param options: :param options:
Specific OpenSSL options. These default to ``ssl.OP_NO_SSLv2``, Specific OpenSSL options. These default to ``ssl.OP_NO_SSLv2``,
``ssl.OP_NO_SSLv3``, ``ssl.OP_NO_COMPRESSION``. ``ssl.OP_NO_SSLv3``, ``ssl.OP_NO_COMPRESSION``, and ``ssl.OP_NO_TICKET``.
:param ciphers: :param ciphers:
Which cipher suites to allow the server to select. Which cipher suites to allow the server to select.
:returns: :returns:
Constructed SSLContext object with specified options Constructed SSLContext object with specified options
:rtype: SSLContext :rtype: SSLContext
""" """
context = SSLContext(ssl_version or ssl.PROTOCOL_SSLv23) # PROTOCOL_TLS is deprecated in Python 3.10
if not ssl_version or ssl_version == PROTOCOL_TLS:
ssl_version = PROTOCOL_TLS_CLIENT
context = SSLContext(ssl_version)
context.set_ciphers(ciphers or DEFAULT_CIPHERS)
# Setting the default here, as we may have no ssl module on import # Setting the default here, as we may have no ssl module on import
cert_reqs = ssl.CERT_REQUIRED if cert_reqs is None else cert_reqs cert_reqs = ssl.CERT_REQUIRED if cert_reqs is None else cert_reqs
@ -265,24 +303,70 @@ def create_urllib3_context(ssl_version=None, cert_reqs=None,
# Disable compression to prevent CRIME attacks for OpenSSL 1.0+ # Disable compression to prevent CRIME attacks for OpenSSL 1.0+
# (issue #309) # (issue #309)
options |= OP_NO_COMPRESSION options |= OP_NO_COMPRESSION
# TLSv1.2 only. Unless set explicitly, do not request tickets.
# This may save some bandwidth on wire, and although the ticket is encrypted,
# there is a risk associated with it being on wire,
# if the server is not rotating its ticketing keys properly.
options |= OP_NO_TICKET
context.options |= options context.options |= options
if getattr(context, 'supports_set_ciphers', True): # Platform-specific: Python 2.6 # Enable post-handshake authentication for TLS 1.3, see GH #1634. PHA is
context.set_ciphers(ciphers or DEFAULT_CIPHERS) # necessary for conditional client cert authentication with TLS 1.3.
# The attribute is None for OpenSSL <= 1.1.0 or does not exist in older
# versions of Python. We only enable on Python 3.7.4+ or if certificate
# verification is enabled to work around Python issue #37428
# See: https://bugs.python.org/issue37428
if (cert_reqs == ssl.CERT_REQUIRED or sys.version_info >= (3, 7, 4)) and getattr(
context, "post_handshake_auth", None
) is not None:
context.post_handshake_auth = True
context.verify_mode = cert_reqs def disable_check_hostname():
if getattr(context, 'check_hostname', None) is not None: # Platform-specific: Python 3.2 if (
getattr(context, "check_hostname", None) is not None
): # Platform-specific: Python 3.2
# We do our own verification, including fingerprints and alternative # We do our own verification, including fingerprints and alternative
# hostnames. So disable it here # hostnames. So disable it here
context.check_hostname = False context.check_hostname = False
# The order of the below lines setting verify_mode and check_hostname
# matter due to safe-guards SSLContext has to prevent an SSLContext with
# check_hostname=True, verify_mode=NONE/OPTIONAL. This is made even more
# complex because we don't know whether PROTOCOL_TLS_CLIENT will be used
# or not so we don't know the initial state of the freshly created SSLContext.
if cert_reqs == ssl.CERT_REQUIRED:
context.verify_mode = cert_reqs
disable_check_hostname()
else:
disable_check_hostname()
context.verify_mode = cert_reqs
# Enable logging of TLS session keys via defacto standard environment variable
# 'SSLKEYLOGFILE', if the feature is available (Python 3.8+). Skip empty values.
if hasattr(context, "keylog_filename"):
sslkeylogfile = os.environ.get("SSLKEYLOGFILE")
if sslkeylogfile:
context.keylog_filename = sslkeylogfile
return context return context
def ssl_wrap_socket(sock, keyfile=None, certfile=None, cert_reqs=None, def ssl_wrap_socket(
ca_certs=None, server_hostname=None, sock,
ssl_version=None, ciphers=None, ssl_context=None, keyfile=None,
ca_cert_dir=None): certfile=None,
cert_reqs=None,
ca_certs=None,
server_hostname=None,
ssl_version=None,
ciphers=None,
ssl_context=None,
ca_cert_dir=None,
key_password=None,
ca_cert_data=None,
tls_in_tls=False,
):
""" """
All arguments except for server_hostname, ssl_context, and ca_cert_dir have All arguments except for server_hostname, ssl_context, and ca_cert_dir have
the same meaning as they do when using :func:`ssl.wrap_socket`. the same meaning as they do when using :func:`ssl.wrap_socket`.
@ -293,49 +377,119 @@ def ssl_wrap_socket(sock, keyfile=None, certfile=None, cert_reqs=None,
A pre-made :class:`SSLContext` object. If none is provided, one will A pre-made :class:`SSLContext` object. If none is provided, one will
be created using :func:`create_urllib3_context`. be created using :func:`create_urllib3_context`.
:param ciphers: :param ciphers:
A string of ciphers we wish the client to support. This is not A string of ciphers we wish the client to support.
supported on Python 2.6 as the ssl module does not support it.
:param ca_cert_dir: :param ca_cert_dir:
A directory containing CA certificates in multiple separate files, as A directory containing CA certificates in multiple separate files, as
supported by OpenSSL's -CApath flag or the capath argument to supported by OpenSSL's -CApath flag or the capath argument to
SSLContext.load_verify_locations(). SSLContext.load_verify_locations().
:param key_password:
Optional password if the keyfile is encrypted.
:param ca_cert_data:
Optional string containing CA certificates in PEM format suitable for
passing as the cadata parameter to SSLContext.load_verify_locations()
:param tls_in_tls:
Use SSLTransport to wrap the existing socket.
""" """
context = ssl_context context = ssl_context
if context is None: if context is None:
# Note: This branch of code and all the variables in it are no longer # Note: This branch of code and all the variables in it are no longer
# used by urllib3 itself. We should consider deprecating and removing # used by urllib3 itself. We should consider deprecating and removing
# this code. # this code.
context = create_urllib3_context(ssl_version, cert_reqs, context = create_urllib3_context(ssl_version, cert_reqs, ciphers=ciphers)
ciphers=ciphers)
if ca_certs or ca_cert_dir: if ca_certs or ca_cert_dir or ca_cert_data:
try: try:
context.load_verify_locations(ca_certs, ca_cert_dir) context.load_verify_locations(ca_certs, ca_cert_dir, ca_cert_data)
except IOError as e: # Platform-specific: Python 2.6, 2.7, 3.2 except (IOError, OSError) as e:
raise SSLError(e) raise SSLError(e)
# Py33 raises FileNotFoundError which subclasses OSError
# These are not equivalent unless we check the errno attribute elif ssl_context is None and hasattr(context, "load_default_certs"):
except OSError as e: # Platform-specific: Python 3.3 and beyond
if e.errno == errno.ENOENT:
raise SSLError(e)
raise
elif getattr(context, 'load_default_certs', None) is not None:
# try to load OS default certs; works well on Windows (require Python3.4+) # try to load OS default certs; works well on Windows (require Python3.4+)
context.load_default_certs() context.load_default_certs()
if certfile: # Attempt to detect if we get the goofy behavior of the
context.load_cert_chain(certfile, keyfile) # keyfile being encrypted and OpenSSL asking for the
if HAS_SNI: # Platform-specific: OpenSSL with enabled SNI # passphrase via the terminal and instead error out.
return context.wrap_socket(sock, server_hostname=server_hostname) if keyfile and key_password is None and _is_key_file_encrypted(keyfile):
raise SSLError("Client private key is encrypted, password is required")
warnings.warn( if certfile:
'An HTTPS request has been made, but the SNI (Subject Name ' if key_password is None:
'Indication) extension to TLS is not available on this platform. ' context.load_cert_chain(certfile, keyfile)
'This may cause the server to present an incorrect TLS ' else:
'certificate, which can cause validation failures. You can upgrade to ' context.load_cert_chain(certfile, keyfile, key_password)
'a newer version of Python to solve this. For more information, see '
'https://urllib3.readthedocs.io/en/latest/advanced-usage.html' try:
'#ssl-warnings', if hasattr(context, "set_alpn_protocols"):
SNIMissingWarning context.set_alpn_protocols(ALPN_PROTOCOLS)
except NotImplementedError: # Defensive: in CI, we always have set_alpn_protocols
pass
# If we detect server_hostname is an IP address then the SNI
# extension should not be used according to RFC3546 Section 3.1
use_sni_hostname = server_hostname and not is_ipaddress(server_hostname)
# SecureTransport uses server_hostname in certificate verification.
send_sni = (use_sni_hostname and HAS_SNI) or (
IS_SECURETRANSPORT and server_hostname
) )
return context.wrap_socket(sock) # Do not warn the user if server_hostname is an invalid SNI hostname.
if not HAS_SNI and use_sni_hostname:
warnings.warn(
"An HTTPS request has been made, but the SNI (Server Name "
"Indication) extension to TLS is not available on this platform. "
"This may cause the server to present an incorrect TLS "
"certificate, which can cause validation failures. You can upgrade to "
"a newer version of Python to solve this. For more information, see "
"https://urllib3.readthedocs.io/en/1.26.x/advanced-usage.html"
"#ssl-warnings",
SNIMissingWarning,
)
if send_sni:
ssl_sock = _ssl_wrap_socket_impl(
sock, context, tls_in_tls, server_hostname=server_hostname
)
else:
ssl_sock = _ssl_wrap_socket_impl(sock, context, tls_in_tls)
return ssl_sock
def is_ipaddress(hostname):
"""Detects whether the hostname given is an IPv4 or IPv6 address.
Also detects IPv6 addresses with Zone IDs.
:param str hostname: Hostname to examine.
:return: True if the hostname is an IP address, False otherwise.
"""
if not six.PY2 and isinstance(hostname, bytes):
# IDN A-label bytes are ASCII compatible.
hostname = hostname.decode("ascii")
return bool(IPV4_RE.match(hostname) or BRACELESS_IPV6_ADDRZ_RE.match(hostname))
def _is_key_file_encrypted(key_file):
"""Detects if a key file is encrypted or not."""
with open(key_file, "r") as f:
for line in f:
# Look for Proc-Type: 4,ENCRYPTED
if "ENCRYPTED" in line:
return True
return False
def _ssl_wrap_socket_impl(sock, ssl_context, tls_in_tls, server_hostname=None):
if tls_in_tls:
if not SSLTransport:
# Import error, ssl is not available.
raise ProxySchemeUnsupported(
"TLS in TLS requires support for the 'ssl' module"
)
SSLTransport._validate_ssl_context_for_tls_in_tls(ssl_context)
return SSLTransport(sock, ssl_context, server_hostname)
if server_hostname:
return ssl_context.wrap_socket(sock, server_hostname=server_hostname)
else:
return ssl_context.wrap_socket(sock)

View file

@ -0,0 +1,221 @@
import io
import socket
import ssl
from urllib3.exceptions import ProxySchemeUnsupported
from urllib3.packages import six
SSL_BLOCKSIZE = 16384
class SSLTransport:
"""
The SSLTransport wraps an existing socket and establishes an SSL connection.
Contrary to Python's implementation of SSLSocket, it allows you to chain
multiple TLS connections together. It's particularly useful if you need to
implement TLS within TLS.
The class supports most of the socket API operations.
"""
@staticmethod
def _validate_ssl_context_for_tls_in_tls(ssl_context):
"""
Raises a ProxySchemeUnsupported if the provided ssl_context can't be used
for TLS in TLS.
The only requirement is that the ssl_context provides the 'wrap_bio'
methods.
"""
if not hasattr(ssl_context, "wrap_bio"):
if six.PY2:
raise ProxySchemeUnsupported(
"TLS in TLS requires SSLContext.wrap_bio() which isn't "
"supported on Python 2"
)
else:
raise ProxySchemeUnsupported(
"TLS in TLS requires SSLContext.wrap_bio() which isn't "
"available on non-native SSLContext"
)
def __init__(
self, socket, ssl_context, server_hostname=None, suppress_ragged_eofs=True
):
"""
Create an SSLTransport around socket using the provided ssl_context.
"""
self.incoming = ssl.MemoryBIO()
self.outgoing = ssl.MemoryBIO()
self.suppress_ragged_eofs = suppress_ragged_eofs
self.socket = socket
self.sslobj = ssl_context.wrap_bio(
self.incoming, self.outgoing, server_hostname=server_hostname
)
# Perform initial handshake.
self._ssl_io_loop(self.sslobj.do_handshake)
def __enter__(self):
return self
def __exit__(self, *_):
self.close()
def fileno(self):
return self.socket.fileno()
def read(self, len=1024, buffer=None):
return self._wrap_ssl_read(len, buffer)
def recv(self, len=1024, flags=0):
if flags != 0:
raise ValueError("non-zero flags not allowed in calls to recv")
return self._wrap_ssl_read(len)
def recv_into(self, buffer, nbytes=None, flags=0):
if flags != 0:
raise ValueError("non-zero flags not allowed in calls to recv_into")
if buffer and (nbytes is None):
nbytes = len(buffer)
elif nbytes is None:
nbytes = 1024
return self.read(nbytes, buffer)
def sendall(self, data, flags=0):
if flags != 0:
raise ValueError("non-zero flags not allowed in calls to sendall")
count = 0
with memoryview(data) as view, view.cast("B") as byte_view:
amount = len(byte_view)
while count < amount:
v = self.send(byte_view[count:])
count += v
def send(self, data, flags=0):
if flags != 0:
raise ValueError("non-zero flags not allowed in calls to send")
response = self._ssl_io_loop(self.sslobj.write, data)
return response
def makefile(
self, mode="r", buffering=None, encoding=None, errors=None, newline=None
):
"""
Python's httpclient uses makefile and buffered io when reading HTTP
messages and we need to support it.
This is unfortunately a copy and paste of socket.py makefile with small
changes to point to the socket directly.
"""
if not set(mode) <= {"r", "w", "b"}:
raise ValueError("invalid mode %r (only r, w, b allowed)" % (mode,))
writing = "w" in mode
reading = "r" in mode or not writing
assert reading or writing
binary = "b" in mode
rawmode = ""
if reading:
rawmode += "r"
if writing:
rawmode += "w"
raw = socket.SocketIO(self, rawmode)
self.socket._io_refs += 1
if buffering is None:
buffering = -1
if buffering < 0:
buffering = io.DEFAULT_BUFFER_SIZE
if buffering == 0:
if not binary:
raise ValueError("unbuffered streams must be binary")
return raw
if reading and writing:
buffer = io.BufferedRWPair(raw, raw, buffering)
elif reading:
buffer = io.BufferedReader(raw, buffering)
else:
assert writing
buffer = io.BufferedWriter(raw, buffering)
if binary:
return buffer
text = io.TextIOWrapper(buffer, encoding, errors, newline)
text.mode = mode
return text
def unwrap(self):
self._ssl_io_loop(self.sslobj.unwrap)
def close(self):
self.socket.close()
def getpeercert(self, binary_form=False):
return self.sslobj.getpeercert(binary_form)
def version(self):
return self.sslobj.version()
def cipher(self):
return self.sslobj.cipher()
def selected_alpn_protocol(self):
return self.sslobj.selected_alpn_protocol()
def selected_npn_protocol(self):
return self.sslobj.selected_npn_protocol()
def shared_ciphers(self):
return self.sslobj.shared_ciphers()
def compression(self):
return self.sslobj.compression()
def settimeout(self, value):
self.socket.settimeout(value)
def gettimeout(self):
return self.socket.gettimeout()
def _decref_socketios(self):
self.socket._decref_socketios()
def _wrap_ssl_read(self, len, buffer=None):
try:
return self._ssl_io_loop(self.sslobj.read, len, buffer)
except ssl.SSLError as e:
if e.errno == ssl.SSL_ERROR_EOF and self.suppress_ragged_eofs:
return 0 # eof, return 0.
else:
raise
def _ssl_io_loop(self, func, *args):
"""Performs an I/O loop between incoming/outgoing and the socket."""
should_loop = True
ret = None
while should_loop:
errno = None
try:
ret = func(*args)
except ssl.SSLError as e:
if e.errno not in (ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE):
# WANT_READ, and WANT_WRITE are expected, others are not.
raise e
errno = e.errno
buf = self.outgoing.read()
self.socket.sendall(buf)
if errno is None:
should_loop = False
elif errno == ssl.SSL_ERROR_WANT_READ:
buf = self.socket.recv(SSL_BLOCKSIZE)
if buf:
self.incoming.write(buf)
else:
self.incoming.write_eof()
return ret

View file

@ -1,8 +1,10 @@
from __future__ import absolute_import from __future__ import absolute_import
import time
# The default socket timeout, used by httplib to indicate that no timeout was # The default socket timeout, used by httplib to indicate that no timeout was
# specified by the user # specified by the user
from socket import _GLOBAL_DEFAULT_TIMEOUT from socket import _GLOBAL_DEFAULT_TIMEOUT
import time
from ..exceptions import TimeoutStateError from ..exceptions import TimeoutStateError
@ -18,17 +20,23 @@ current_time = getattr(time, "monotonic", time.time)
class Timeout(object): class Timeout(object):
"""Timeout configuration. """Timeout configuration.
Timeouts can be defined as a default for a pool:: Timeouts can be defined as a default for a pool:
.. code-block:: python
timeout = Timeout(connect=2.0, read=7.0) timeout = Timeout(connect=2.0, read=7.0)
http = PoolManager(timeout=timeout) http = PoolManager(timeout=timeout)
response = http.request('GET', 'http://example.com/') response = http.request('GET', 'http://example.com/')
Or per-request (which overrides the default for the pool):: Or per-request (which overrides the default for the pool):
.. code-block:: python
response = http.request('GET', 'http://example.com/', timeout=Timeout(10)) response = http.request('GET', 'http://example.com/', timeout=Timeout(10))
Timeouts can be disabled by setting all the parameters to ``None``:: Timeouts can be disabled by setting all the parameters to ``None``:
.. code-block:: python
no_timeout = Timeout(connect=None, read=None) no_timeout = Timeout(connect=None, read=None)
response = http.request('GET', 'http://example.com/, timeout=no_timeout) response = http.request('GET', 'http://example.com/, timeout=no_timeout)
@ -42,26 +50,27 @@ class Timeout(object):
Defaults to None. Defaults to None.
:type total: integer, float, or None :type total: int, float, or None
:param connect: :param connect:
The maximum amount of time to wait for a connection attempt to a server The maximum amount of time (in seconds) to wait for a connection
to succeed. Omitting the parameter will default the connect timeout to attempt to a server to succeed. Omitting the parameter will default the
the system default, probably `the global default timeout in socket.py connect timeout to the system default, probably `the global default
timeout in socket.py
<http://hg.python.org/cpython/file/603b4d593758/Lib/socket.py#l535>`_. <http://hg.python.org/cpython/file/603b4d593758/Lib/socket.py#l535>`_.
None will set an infinite timeout for connection attempts. None will set an infinite timeout for connection attempts.
:type connect: integer, float, or None :type connect: int, float, or None
:param read: :param read:
The maximum amount of time to wait between consecutive The maximum amount of time (in seconds) to wait between consecutive
read operations for a response from the server. Omitting read operations for a response from the server. Omitting the parameter
the parameter will default the read timeout to the system will default the read timeout to the system default, probably `the
default, probably `the global default timeout in socket.py global default timeout in socket.py
<http://hg.python.org/cpython/file/603b4d593758/Lib/socket.py#l535>`_. <http://hg.python.org/cpython/file/603b4d593758/Lib/socket.py#l535>`_.
None will set an infinite timeout. None will set an infinite timeout.
:type read: integer, float, or None :type read: int, float, or None
.. note:: .. note::
@ -91,14 +100,21 @@ class Timeout(object):
DEFAULT_TIMEOUT = _GLOBAL_DEFAULT_TIMEOUT DEFAULT_TIMEOUT = _GLOBAL_DEFAULT_TIMEOUT
def __init__(self, total=None, connect=_Default, read=_Default): def __init__(self, total=None, connect=_Default, read=_Default):
self._connect = self._validate_timeout(connect, 'connect') self._connect = self._validate_timeout(connect, "connect")
self._read = self._validate_timeout(read, 'read') self._read = self._validate_timeout(read, "read")
self.total = self._validate_timeout(total, 'total') self.total = self._validate_timeout(total, "total")
self._start_connect = None self._start_connect = None
def __str__(self): def __repr__(self):
return '%s(connect=%r, read=%r, total=%r)' % ( return "%s(connect=%r, read=%r, total=%r)" % (
type(self).__name__, self._connect, self._read, self.total) type(self).__name__,
self._connect,
self._read,
self.total,
)
# __str__ provided for backwards compatibility
__str__ = __repr__
@classmethod @classmethod
def _validate_timeout(cls, value, name): def _validate_timeout(cls, value, name):
@ -118,22 +134,31 @@ class Timeout(object):
return value return value
if isinstance(value, bool): if isinstance(value, bool):
raise ValueError("Timeout cannot be a boolean value. It must " raise ValueError(
"be an int, float or None.") "Timeout cannot be a boolean value. It must "
"be an int, float or None."
)
try: try:
float(value) float(value)
except (TypeError, ValueError): except (TypeError, ValueError):
raise ValueError("Timeout value %s was %s, but it must be an " raise ValueError(
"int, float or None." % (name, value)) "Timeout value %s was %s, but it must be an "
"int, float or None." % (name, value)
)
try: try:
if value <= 0: if value <= 0:
raise ValueError("Attempted to set %s timeout to %s, but the " raise ValueError(
"Attempted to set %s timeout to %s, but the "
"timeout cannot be set to a value less " "timeout cannot be set to a value less "
"than or equal to 0." % (name, value)) "than or equal to 0." % (name, value)
except TypeError: # Python 3 )
raise ValueError("Timeout value %s was %s, but it must be an " except TypeError:
"int, float or None." % (name, value)) # Python 3
raise ValueError(
"Timeout value %s was %s, but it must be an "
"int, float or None." % (name, value)
)
return value return value
@ -165,8 +190,7 @@ class Timeout(object):
# We can't use copy.deepcopy because that will also create a new object # We can't use copy.deepcopy because that will also create a new object
# for _GLOBAL_DEFAULT_TIMEOUT, which socket.py uses as a sentinel to # for _GLOBAL_DEFAULT_TIMEOUT, which socket.py uses as a sentinel to
# detect the user default. # detect the user default.
return Timeout(connect=self._connect, read=self._read, return Timeout(connect=self._connect, read=self._read, total=self.total)
total=self.total)
def start_connect(self): def start_connect(self):
"""Start the timeout clock, used during a connect() attempt """Start the timeout clock, used during a connect() attempt
@ -182,14 +206,15 @@ class Timeout(object):
def get_connect_duration(self): def get_connect_duration(self):
"""Gets the time elapsed since the call to :meth:`start_connect`. """Gets the time elapsed since the call to :meth:`start_connect`.
:return: Elapsed time. :return: Elapsed time in seconds.
:rtype: float :rtype: float
:raises urllib3.exceptions.TimeoutStateError: if you attempt :raises urllib3.exceptions.TimeoutStateError: if you attempt
to get duration for a timer that hasn't been started. to get duration for a timer that hasn't been started.
""" """
if self._start_connect is None: if self._start_connect is None:
raise TimeoutStateError("Can't get connect duration for timer " raise TimeoutStateError(
"that has not started.") "Can't get connect duration for timer that has not started."
)
return current_time() - self._start_connect return current_time() - self._start_connect
@property @property
@ -227,15 +252,16 @@ class Timeout(object):
:raises urllib3.exceptions.TimeoutStateError: If :meth:`start_connect` :raises urllib3.exceptions.TimeoutStateError: If :meth:`start_connect`
has not yet been called on this object. has not yet been called on this object.
""" """
if (self.total is not None and if (
self.total is not self.DEFAULT_TIMEOUT and self.total is not None
self._read is not None and and self.total is not self.DEFAULT_TIMEOUT
self._read is not self.DEFAULT_TIMEOUT): and self._read is not None
and self._read is not self.DEFAULT_TIMEOUT
):
# In case the connect timeout has not yet been established. # In case the connect timeout has not yet been established.
if self._start_connect is None: if self._start_connect is None:
return self._read return self._read
return max(0, min(self.total - self.get_connect_duration(), return max(0, min(self.total - self.get_connect_duration(), self._read))
self._read))
elif self.total is not None and self.total is not self.DEFAULT_TIMEOUT: elif self.total is not None and self.total is not self.DEFAULT_TIMEOUT:
return max(0, self.total - self.get_connect_duration()) return max(0, self.total - self.get_connect_duration())
else: else:

View file

@ -1,34 +1,110 @@
from __future__ import absolute_import from __future__ import absolute_import
import re
from collections import namedtuple from collections import namedtuple
from ..exceptions import LocationParseError from ..exceptions import LocationParseError
from ..packages import six
url_attrs = ["scheme", "auth", "host", "port", "path", "query", "fragment"]
url_attrs = ['scheme', 'auth', 'host', 'port', 'path', 'query', 'fragment']
# We only want to normalize urls with an HTTP(S) scheme. # We only want to normalize urls with an HTTP(S) scheme.
# urllib3 infers URLs without a scheme (None) to be http. # urllib3 infers URLs without a scheme (None) to be http.
NORMALIZABLE_SCHEMES = ('http', 'https', None) NORMALIZABLE_SCHEMES = ("http", "https", None)
# Almost all of these patterns were derived from the
# 'rfc3986' module: https://github.com/python-hyper/rfc3986
PERCENT_RE = re.compile(r"%[a-fA-F0-9]{2}")
SCHEME_RE = re.compile(r"^(?:[a-zA-Z][a-zA-Z0-9+-]*:|/)")
URI_RE = re.compile(
r"^(?:([a-zA-Z][a-zA-Z0-9+.-]*):)?"
r"(?://([^\\/?#]*))?"
r"([^?#]*)"
r"(?:\?([^#]*))?"
r"(?:#(.*))?$",
re.UNICODE | re.DOTALL,
)
IPV4_PAT = r"(?:[0-9]{1,3}\.){3}[0-9]{1,3}"
HEX_PAT = "[0-9A-Fa-f]{1,4}"
LS32_PAT = "(?:{hex}:{hex}|{ipv4})".format(hex=HEX_PAT, ipv4=IPV4_PAT)
_subs = {"hex": HEX_PAT, "ls32": LS32_PAT}
_variations = [
# 6( h16 ":" ) ls32
"(?:%(hex)s:){6}%(ls32)s",
# "::" 5( h16 ":" ) ls32
"::(?:%(hex)s:){5}%(ls32)s",
# [ h16 ] "::" 4( h16 ":" ) ls32
"(?:%(hex)s)?::(?:%(hex)s:){4}%(ls32)s",
# [ *1( h16 ":" ) h16 ] "::" 3( h16 ":" ) ls32
"(?:(?:%(hex)s:)?%(hex)s)?::(?:%(hex)s:){3}%(ls32)s",
# [ *2( h16 ":" ) h16 ] "::" 2( h16 ":" ) ls32
"(?:(?:%(hex)s:){0,2}%(hex)s)?::(?:%(hex)s:){2}%(ls32)s",
# [ *3( h16 ":" ) h16 ] "::" h16 ":" ls32
"(?:(?:%(hex)s:){0,3}%(hex)s)?::%(hex)s:%(ls32)s",
# [ *4( h16 ":" ) h16 ] "::" ls32
"(?:(?:%(hex)s:){0,4}%(hex)s)?::%(ls32)s",
# [ *5( h16 ":" ) h16 ] "::" h16
"(?:(?:%(hex)s:){0,5}%(hex)s)?::%(hex)s",
# [ *6( h16 ":" ) h16 ] "::"
"(?:(?:%(hex)s:){0,6}%(hex)s)?::",
]
UNRESERVED_PAT = r"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789._!\-~"
IPV6_PAT = "(?:" + "|".join([x % _subs for x in _variations]) + ")"
ZONE_ID_PAT = "(?:%25|%)(?:[" + UNRESERVED_PAT + "]|%[a-fA-F0-9]{2})+"
IPV6_ADDRZ_PAT = r"\[" + IPV6_PAT + r"(?:" + ZONE_ID_PAT + r")?\]"
REG_NAME_PAT = r"(?:[^\[\]%:/?#]|%[a-fA-F0-9]{2})*"
TARGET_RE = re.compile(r"^(/[^?#]*)(?:\?([^#]*))?(?:#.*)?$")
IPV4_RE = re.compile("^" + IPV4_PAT + "$")
IPV6_RE = re.compile("^" + IPV6_PAT + "$")
IPV6_ADDRZ_RE = re.compile("^" + IPV6_ADDRZ_PAT + "$")
BRACELESS_IPV6_ADDRZ_RE = re.compile("^" + IPV6_ADDRZ_PAT[2:-2] + "$")
ZONE_ID_RE = re.compile("(" + ZONE_ID_PAT + r")\]$")
_HOST_PORT_PAT = ("^(%s|%s|%s)(?::([0-9]{0,5}))?$") % (
REG_NAME_PAT,
IPV4_PAT,
IPV6_ADDRZ_PAT,
)
_HOST_PORT_RE = re.compile(_HOST_PORT_PAT, re.UNICODE | re.DOTALL)
UNRESERVED_CHARS = set(
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789._-~"
)
SUB_DELIM_CHARS = set("!$&'()*+,;=")
USERINFO_CHARS = UNRESERVED_CHARS | SUB_DELIM_CHARS | {":"}
PATH_CHARS = USERINFO_CHARS | {"@", "/"}
QUERY_CHARS = FRAGMENT_CHARS = PATH_CHARS | {"?"}
class Url(namedtuple('Url', url_attrs)): class Url(namedtuple("Url", url_attrs)):
""" """
Data structure for representing an HTTP URL. Used as a return value for Data structure for representing an HTTP URL. Used as a return value for
:func:`parse_url`. Both the scheme and host are normalized as they are :func:`parse_url`. Both the scheme and host are normalized as they are
both case-insensitive according to RFC 3986. both case-insensitive according to RFC 3986.
""" """
__slots__ = () __slots__ = ()
def __new__(cls, scheme=None, auth=None, host=None, port=None, path=None, def __new__(
query=None, fragment=None): cls,
if path and not path.startswith('/'): scheme=None,
path = '/' + path auth=None,
if scheme: host=None,
port=None,
path=None,
query=None,
fragment=None,
):
if path and not path.startswith("/"):
path = "/" + path
if scheme is not None:
scheme = scheme.lower() scheme = scheme.lower()
if host and scheme in NORMALIZABLE_SCHEMES: return super(Url, cls).__new__(
host = host.lower() cls, scheme, auth, host, port, path, query, fragment
return super(Url, cls).__new__(cls, scheme, auth, host, port, path, )
query, fragment)
@property @property
def hostname(self): def hostname(self):
@ -38,10 +114,10 @@ class Url(namedtuple('Url', url_attrs)):
@property @property
def request_uri(self): def request_uri(self):
"""Absolute path including the query string.""" """Absolute path including the query string."""
uri = self.path or '/' uri = self.path or "/"
if self.query is not None: if self.query is not None:
uri += '?' + self.query uri += "?" + self.query
return uri return uri
@ -49,7 +125,7 @@ class Url(namedtuple('Url', url_attrs)):
def netloc(self): def netloc(self):
"""Network location including host and port""" """Network location including host and port"""
if self.port: if self.port:
return '%s:%d' % (self.host, self.port) return "%s:%d" % (self.host, self.port)
return self.host return self.host
@property @property
@ -72,23 +148,23 @@ class Url(namedtuple('Url', url_attrs)):
'http://username:password@host.com:80/path?query#fragment' 'http://username:password@host.com:80/path?query#fragment'
""" """
scheme, auth, host, port, path, query, fragment = self scheme, auth, host, port, path, query, fragment = self
url = '' url = u""
# We use "is not None" we want things to happen with empty strings (or 0 port) # We use "is not None" we want things to happen with empty strings (or 0 port)
if scheme is not None: if scheme is not None:
url += scheme + '://' url += scheme + u"://"
if auth is not None: if auth is not None:
url += auth + '@' url += auth + u"@"
if host is not None: if host is not None:
url += host url += host
if port is not None: if port is not None:
url += ':' + str(port) url += u":" + str(port)
if path is not None: if path is not None:
url += path url += path
if query is not None: if query is not None:
url += '?' + query url += u"?" + query
if fragment is not None: if fragment is not None:
url += '#' + fragment url += u"#" + fragment
return url return url
@ -98,6 +174,8 @@ class Url(namedtuple('Url', url_attrs)):
def split_first(s, delims): def split_first(s, delims):
""" """
.. deprecated:: 1.25
Given a string and an iterable of delimiters, split on the first found Given a string and an iterable of delimiters, split on the first found
delimiter. Return two split parts and the matched delimiter. delimiter. Return two split parts and the matched delimiter.
@ -124,15 +202,141 @@ def split_first(s, delims):
min_delim = d min_delim = d
if min_idx is None or min_idx < 0: if min_idx is None or min_idx < 0:
return s, '', None return s, "", None
return s[:min_idx], s[min_idx + 1 :], min_delim return s[:min_idx], s[min_idx + 1 :], min_delim
def _encode_invalid_chars(component, allowed_chars, encoding="utf-8"):
"""Percent-encodes a URI component without reapplying
onto an already percent-encoded component.
"""
if component is None:
return component
component = six.ensure_text(component)
# Normalize existing percent-encoded bytes.
# Try to see if the component we're encoding is already percent-encoded
# so we can skip all '%' characters but still encode all others.
component, percent_encodings = PERCENT_RE.subn(
lambda match: match.group(0).upper(), component
)
uri_bytes = component.encode("utf-8", "surrogatepass")
is_percent_encoded = percent_encodings == uri_bytes.count(b"%")
encoded_component = bytearray()
for i in range(0, len(uri_bytes)):
# Will return a single character bytestring on both Python 2 & 3
byte = uri_bytes[i : i + 1]
byte_ord = ord(byte)
if (is_percent_encoded and byte == b"%") or (
byte_ord < 128 and byte.decode() in allowed_chars
):
encoded_component += byte
continue
encoded_component.extend(b"%" + (hex(byte_ord)[2:].encode().zfill(2).upper()))
return encoded_component.decode(encoding)
def _remove_path_dot_segments(path):
# See http://tools.ietf.org/html/rfc3986#section-5.2.4 for pseudo-code
segments = path.split("/") # Turn the path into a list of segments
output = [] # Initialize the variable to use to store output
for segment in segments:
# '.' is the current directory, so ignore it, it is superfluous
if segment == ".":
continue
# Anything other than '..', should be appended to the output
elif segment != "..":
output.append(segment)
# In this case segment == '..', if we can, we should pop the last
# element
elif output:
output.pop()
# If the path starts with '/' and the output is empty or the first string
# is non-empty
if path.startswith("/") and (not output or output[0]):
output.insert(0, "")
# If the path starts with '/.' or '/..' ensure we add one more empty
# string to add a trailing '/'
if path.endswith(("/.", "/..")):
output.append("")
return "/".join(output)
def _normalize_host(host, scheme):
if host:
if isinstance(host, six.binary_type):
host = six.ensure_str(host)
if scheme in NORMALIZABLE_SCHEMES:
is_ipv6 = IPV6_ADDRZ_RE.match(host)
if is_ipv6:
match = ZONE_ID_RE.search(host)
if match:
start, end = match.span(1)
zone_id = host[start:end]
if zone_id.startswith("%25") and zone_id != "%25":
zone_id = zone_id[3:]
else:
zone_id = zone_id[1:]
zone_id = "%" + _encode_invalid_chars(zone_id, UNRESERVED_CHARS)
return host[:start].lower() + zone_id + host[end:]
else:
return host.lower()
elif not IPV4_RE.match(host):
return six.ensure_str(
b".".join([_idna_encode(label) for label in host.split(".")])
)
return host
def _idna_encode(name):
if name and any([ord(x) > 128 for x in name]):
try:
import idna
except ImportError:
six.raise_from(
LocationParseError("Unable to parse URL without the 'idna' module"),
None,
)
try:
return idna.encode(name.lower(), strict=True, std3_rules=True)
except idna.IDNAError:
six.raise_from(
LocationParseError(u"Name '%s' is not a valid IDNA label" % name), None
)
return name.lower().encode("ascii")
def _encode_target(target):
"""Percent-encodes a request target so that there are no invalid characters"""
path, query = TARGET_RE.match(target).groups()
target = _encode_invalid_chars(path, PATH_CHARS)
query = _encode_invalid_chars(query, QUERY_CHARS)
if query is not None:
target += "?" + query
return target
def parse_url(url): def parse_url(url):
""" """
Given a url, return a parsed :class:`.Url` namedtuple. Best-effort is Given a url, return a parsed :class:`.Url` namedtuple. Best-effort is
performed to parse incomplete urls. Fields not provided will be None. performed to parse incomplete urls. Fields not provided will be None.
This parser is RFC 3986 compliant.
The parser logic and helper functions are based heavily on
work done in the ``rfc3986`` module.
:param str url: URL to parse into a :class:`.Url` namedtuple.
Partly backwards-compatible with :mod:`urlparse`. Partly backwards-compatible with :mod:`urlparse`.
@ -145,81 +349,79 @@ def parse_url(url):
>>> parse_url('/foo?bar') >>> parse_url('/foo?bar')
Url(scheme=None, host=None, port=None, path='/foo', query='bar', ...) Url(scheme=None, host=None, port=None, path='/foo', query='bar', ...)
""" """
# While this code has overlap with stdlib's urlparse, it is much
# simplified for our needs and less annoying.
# Additionally, this implementations does silly things to be optimal
# on CPython.
if not url: if not url:
# Empty # Empty
return Url() return Url()
scheme = None source_url = url
auth = None if not SCHEME_RE.search(url):
host = None url = "//" + url
port = None
path = None
fragment = None
query = None
# Scheme
if '://' in url:
scheme, url = url.split('://', 1)
# Find the earliest Authority Terminator
# (http://tools.ietf.org/html/rfc3986#section-3.2)
url, path_, delim = split_first(url, ['/', '?', '#'])
if delim:
# Reassemble the path
path = delim + path_
# Auth
if '@' in url:
# Last '@' denotes end of auth part
auth, url = url.rsplit('@', 1)
# IPv6
if url and url[0] == '[':
host, url = url.split(']', 1)
host += ']'
# Port
if ':' in url:
_host, port = url.split(':', 1)
if not host:
host = _host
if port:
# If given, ports must be integers. No whitespace, no plus or
# minus prefixes, no non-integer digits such as ^2 (superscript).
if not port.isdigit():
raise LocationParseError(url)
try: try:
port = int(port) scheme, authority, path, query, fragment = URI_RE.match(url).groups()
except ValueError: normalize_uri = scheme is None or scheme.lower() in NORMALIZABLE_SCHEMES
raise LocationParseError(url)
else: if scheme:
# Blank ports are cool, too. (rfc3986#section-3.2.3) scheme = scheme.lower()
if authority:
auth, _, host_port = authority.rpartition("@")
auth = auth or None
host, port = _HOST_PORT_RE.match(host_port).groups()
if auth and normalize_uri:
auth = _encode_invalid_chars(auth, USERINFO_CHARS)
if port == "":
port = None port = None
else:
auth, host, port = None, None, None
elif not host and url: if port is not None:
host = url port = int(port)
if not (0 <= port <= 65535):
raise LocationParseError(url)
host = _normalize_host(host, scheme)
if normalize_uri and path:
path = _remove_path_dot_segments(path)
path = _encode_invalid_chars(path, PATH_CHARS)
if normalize_uri and query:
query = _encode_invalid_chars(query, QUERY_CHARS)
if normalize_uri and fragment:
fragment = _encode_invalid_chars(fragment, FRAGMENT_CHARS)
except (ValueError, AttributeError):
return six.raise_from(LocationParseError(source_url), None)
# For the sake of backwards compatibility we put empty
# string values for path if there are any defined values
# beyond the path in the URL.
# TODO: Remove this when we break backwards compatibility.
if not path: if not path:
return Url(scheme, auth, host, port, path, query, fragment) if query is not None or fragment is not None:
path = ""
else:
path = None
# Fragment # Ensure that each part of the URL is a `str` for
if '#' in path: # backwards compatibility.
path, fragment = path.split('#', 1) if isinstance(url, six.text_type):
ensure_func = six.ensure_text
else:
ensure_func = six.ensure_str
# Query def ensure_type(x):
if '?' in path: return x if x is None else ensure_func(x)
path, query = path.split('?', 1)
return Url(scheme, auth, host, port, path, query, fragment) return Url(
scheme=ensure_type(scheme),
auth=ensure_type(auth),
host=ensure_type(host),
port=port,
path=ensure_type(path),
query=ensure_type(query),
fragment=ensure_type(fragment),
)
def get_host(url): def get_host(url):
@ -227,4 +429,4 @@ def get_host(url):
Deprecated. Use :func:`parse_url` instead. Deprecated. Use :func:`parse_url` instead.
""" """
p = parse_url(url) p = parse_url(url)
return p.scheme or 'http', p.hostname, p.port return p.scheme or "http", p.hostname, p.port

View file

@ -1,40 +1,153 @@
from .selectors import ( import errno
HAS_SELECT, import select
DefaultSelector, import sys
EVENT_READ, from functools import partial
EVENT_WRITE
) try:
from time import monotonic
except ImportError:
from time import time as monotonic
__all__ = ["NoWayToWaitForSocketError", "wait_for_read", "wait_for_write"]
class NoWayToWaitForSocketError(Exception):
pass
# How should we wait on sockets?
#
# There are two types of APIs you can use for waiting on sockets: the fancy
# modern stateful APIs like epoll/kqueue, and the older stateless APIs like
# select/poll. The stateful APIs are more efficient when you have a lots of
# sockets to keep track of, because you can set them up once and then use them
# lots of times. But we only ever want to wait on a single socket at a time
# and don't want to keep track of state, so the stateless APIs are actually
# more efficient. So we want to use select() or poll().
#
# Now, how do we choose between select() and poll()? On traditional Unixes,
# select() has a strange calling convention that makes it slow, or fail
# altogether, for high-numbered file descriptors. The point of poll() is to fix
# that, so on Unixes, we prefer poll().
#
# On Windows, there is no poll() (or at least Python doesn't provide a wrapper
# for it), but that's OK, because on Windows, select() doesn't have this
# strange calling convention; plain select() works fine.
#
# So: on Windows we use select(), and everywhere else we use poll(). We also
# fall back to select() in case poll() is somehow broken or missing.
if sys.version_info >= (3, 5):
# Modern Python, that retries syscalls by default
def _retry_on_intr(fn, timeout):
return fn(timeout)
def _wait_for_io_events(socks, events, timeout=None):
""" Waits for IO events to be available from a list of sockets
or optionally a single socket if passed in. Returns a list of
sockets that can be interacted with immediately. """
if not HAS_SELECT:
raise ValueError('Platform does not have a selector')
if not isinstance(socks, list):
# Probably just a single socket.
if hasattr(socks, "fileno"):
socks = [socks]
# Otherwise it might be a non-list iterable.
else: else:
socks = list(socks) # Old and broken Pythons.
with DefaultSelector() as selector: def _retry_on_intr(fn, timeout):
for sock in socks: if timeout is None:
selector.register(sock, events) deadline = float("inf")
return [key[0].fileobj for key in else:
selector.select(timeout) if key[1] & events] deadline = monotonic() + timeout
while True:
try:
return fn(timeout)
# OSError for 3 <= pyver < 3.5, select.error for pyver <= 2.7
except (OSError, select.error) as e:
# 'e.args[0]' incantation works for both OSError and select.error
if e.args[0] != errno.EINTR:
raise
else:
timeout = deadline - monotonic()
if timeout < 0:
timeout = 0
if timeout == float("inf"):
timeout = None
continue
def wait_for_read(socks, timeout=None): def select_wait_for_socket(sock, read=False, write=False, timeout=None):
""" Waits for reading to be available from a list of sockets if not read and not write:
or optionally a single socket if passed in. Returns a list of raise RuntimeError("must specify at least one of read=True, write=True")
sockets that can be read from immediately. """ rcheck = []
return _wait_for_io_events(socks, EVENT_READ, timeout) wcheck = []
if read:
rcheck.append(sock)
if write:
wcheck.append(sock)
# When doing a non-blocking connect, most systems signal success by
# marking the socket writable. Windows, though, signals success by marked
# it as "exceptional". We paper over the difference by checking the write
# sockets for both conditions. (The stdlib selectors module does the same
# thing.)
fn = partial(select.select, rcheck, wcheck, wcheck)
rready, wready, xready = _retry_on_intr(fn, timeout)
return bool(rready or wready or xready)
def wait_for_write(socks, timeout=None): def poll_wait_for_socket(sock, read=False, write=False, timeout=None):
""" Waits for writing to be available from a list of sockets if not read and not write:
or optionally a single socket if passed in. Returns a list of raise RuntimeError("must specify at least one of read=True, write=True")
sockets that can be written to immediately. """ mask = 0
return _wait_for_io_events(socks, EVENT_WRITE, timeout) if read:
mask |= select.POLLIN
if write:
mask |= select.POLLOUT
poll_obj = select.poll()
poll_obj.register(sock, mask)
# For some reason, poll() takes timeout in milliseconds
def do_poll(t):
if t is not None:
t *= 1000
return poll_obj.poll(t)
return bool(_retry_on_intr(do_poll, timeout))
def null_wait_for_socket(*args, **kwargs):
raise NoWayToWaitForSocketError("no select-equivalent available")
def _have_working_poll():
# Apparently some systems have a select.poll that fails as soon as you try
# to use it, either due to strange configuration or broken monkeypatching
# from libraries like eventlet/greenlet.
try:
poll_obj = select.poll()
_retry_on_intr(poll_obj.poll, 0)
except (AttributeError, OSError):
return False
else:
return True
def wait_for_socket(*args, **kwargs):
# We delay choosing which implementation to use until the first time we're
# called. We could do it at import time, but then we might make the wrong
# decision if someone goes wild with monkeypatching select.poll after
# we're imported.
global wait_for_socket
if _have_working_poll():
wait_for_socket = poll_wait_for_socket
elif hasattr(select, "select"):
wait_for_socket = select_wait_for_socket
else: # Platform-specific: Appengine.
wait_for_socket = null_wait_for_socket
return wait_for_socket(*args, **kwargs)
def wait_for_read(sock, timeout=None):
"""Waits for reading to be available on a given socket.
Returns True if the socket is readable, or False if the timeout expired.
"""
return wait_for_socket(sock, read=True, timeout=timeout)
def wait_for_write(sock, timeout=None):
"""Waits for writing to be available on a given socket.
Returns True if the socket is readable, or False if the timeout expired.
"""
return wait_for_socket(sock, write=True, timeout=timeout)