Bump requests from 2.28.2 to 2.31.0 (#2078)

* Bump requests from 2.28.2 to 2.31.0

Bumps [requests](https://github.com/psf/requests) from 2.28.2 to 2.31.0.
- [Release notes](https://github.com/psf/requests/releases)
- [Changelog](https://github.com/psf/requests/blob/main/HISTORY.md)
- [Commits](https://github.com/psf/requests/compare/v2.28.2...v2.31.0)

---
updated-dependencies:
- dependency-name: requests
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>

* Update requests==2.31.0

---------

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: JonnyWong16 <9099342+JonnyWong16@users.noreply.github.com>

[skip ci]
dependabot[bot] 2023-08-23 21:40:02 -07:00 committed by GitHub
commit 6b6d43ef43
54 changed files with 4861 additions and 4958 deletions
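
In practice the dependency change itself is a one-line pin update. As a sketch, the equivalent edit to the project's requirements pin (the exact file name is an assumption; it is not shown in this view):

-requests==2.28.2
+requests==2.31.0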


@@ -1,46 +1,41 @@
from __future__ import absolute_import
# For backwards compatibility, provide imports that used to be here.
from __future__ import annotations
from .connection import is_connection_dropped
from .request import SKIP_HEADER, SKIPPABLE_HEADERS, make_headers
from .response import is_fp_closed
from .retry import Retry
from .ssl_ import (
ALPN_PROTOCOLS,
HAS_SNI,
IS_PYOPENSSL,
IS_SECURETRANSPORT,
PROTOCOL_TLS,
SSLContext,
assert_fingerprint,
create_urllib3_context,
resolve_cert_reqs,
resolve_ssl_version,
ssl_wrap_socket,
)
from .timeout import Timeout, current_time
from .url import Url, get_host, parse_url, split_first
from .timeout import Timeout
from .url import Url, parse_url
from .wait import wait_for_read, wait_for_write
__all__ = (
"HAS_SNI",
"IS_PYOPENSSL",
"IS_SECURETRANSPORT",
"SSLContext",
"PROTOCOL_TLS",
"ALPN_PROTOCOLS",
"Retry",
"Timeout",
"Url",
"assert_fingerprint",
"current_time",
"create_urllib3_context",
"is_connection_dropped",
"is_fp_closed",
"get_host",
"parse_url",
"make_headers",
"resolve_cert_reqs",
"resolve_ssl_version",
"split_first",
"ssl_wrap_socket",
"wait_for_read",
"wait_for_write",


@@ -1,33 +1,23 @@
from __future__ import absolute_import
from __future__ import annotations
import socket
import typing
from ..contrib import _appengine_environ
from ..exceptions import LocationParseError
from ..packages import six
from .wait import NoWayToWaitForSocketError, wait_for_read
from .timeout import _DEFAULT_TIMEOUT, _TYPE_TIMEOUT
_TYPE_SOCKET_OPTIONS = typing.Sequence[typing.Tuple[int, int, typing.Union[int, bytes]]]
if typing.TYPE_CHECKING:
from .._base_connection import BaseHTTPConnection
def is_connection_dropped(conn): # Platform-specific
def is_connection_dropped(conn: BaseHTTPConnection) -> bool: # Platform-specific
"""
Returns True if the connection is dropped and should be closed.
:param conn:
:class:`http.client.HTTPConnection` object.
Note: For platforms like AppEngine, this will always return ``False`` to
let the platform handle connection recycling transparently for us.
:param conn: :class:`urllib3.connection.HTTPConnection` object.
"""
sock = getattr(conn, "sock", False)
if sock is False: # Platform-specific: AppEngine
return False
if sock is None: # Connection already closed (such as by httplib).
return True
try:
# Returns True if readable, which here means it's been dropped
return wait_for_read(sock, timeout=0.0)
except NoWayToWaitForSocketError: # Platform-specific: AppEngine
return False
return not conn.is_connected
# This function is copied from socket.py in the Python 2.7 standard
@@ -35,11 +25,11 @@ def is_connection_dropped(conn): # Platform-specific
# One additional modification is that we avoid binding to IPv6 servers
# discovered in DNS if the system doesn't have IPv6 functionality.
def create_connection(
address,
timeout=socket._GLOBAL_DEFAULT_TIMEOUT,
source_address=None,
socket_options=None,
):
address: tuple[str, int],
timeout: _TYPE_TIMEOUT = _DEFAULT_TIMEOUT,
source_address: tuple[str, int] | None = None,
socket_options: _TYPE_SOCKET_OPTIONS | None = None,
) -> socket.socket:
"""Connect to *address* and return the socket object.
Convenience function. Connect to *address* (a 2-tuple ``(host,
@@ -65,9 +55,7 @@ def create_connection(
try:
host.encode("idna")
except UnicodeError:
return six.raise_from(
LocationParseError(u"'%s', label empty or too long" % host), None
)
raise LocationParseError(f"'{host}', label empty or too long") from None
for res in socket.getaddrinfo(host, port, family, socket.SOCK_STREAM):
af, socktype, proto, canonname, sa = res
@@ -78,26 +66,33 @@ def create_connection(
# If provided, set socket level options before connecting.
_set_socket_options(sock, socket_options)
if timeout is not socket._GLOBAL_DEFAULT_TIMEOUT:
if timeout is not _DEFAULT_TIMEOUT:
sock.settimeout(timeout)
if source_address:
sock.bind(source_address)
sock.connect(sa)
# Break explicitly a reference cycle
err = None
return sock
except socket.error as e:
err = e
except OSError as _:
err = _
if sock is not None:
sock.close()
sock = None
if err is not None:
raise err
raise socket.error("getaddrinfo returns an empty list")
try:
raise err
finally:
# Break explicitly a reference cycle
err = None
else:
raise OSError("getaddrinfo returns an empty list")
def _set_socket_options(sock, options):
def _set_socket_options(
sock: socket.socket, options: _TYPE_SOCKET_OPTIONS | None
) -> None:
if options is None:
return
@@ -105,7 +100,7 @@ def _set_socket_options(sock, options):
sock.setsockopt(*opt)
def allowed_gai_family():
def allowed_gai_family() -> socket.AddressFamily:
"""This function is designed to work in the context of
getaddrinfo, where family=socket.AF_UNSPEC is the default and
will perform a DNS search for both IPv6 and IPv4 records."""
@@ -116,18 +111,11 @@ def allowed_gai_family():
return family
def _has_ipv6(host):
def _has_ipv6(host: str) -> bool:
"""Returns True if the system can bind an IPv6 address."""
sock = None
has_ipv6 = False
# App Engine doesn't support IPV6 sockets and actually has a quota on the
# number of sockets that can be used, so just early out here instead of
# creating a socket needlessly.
# See https://github.com/urllib3/urllib3/issues/1446
if _appengine_environ.is_appengine_sandbox():
return False
if socket.has_ipv6:
# has_ipv6 returns true if cPython was compiled with IPv6 support.
# It does not tell us if the system has IPv6 support enabled. To
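
For context, a minimal sketch of how the newly typed ``socket_options`` sequence (``_TYPE_SOCKET_OPTIONS``) is consumed by ``create_connection``; the host and port are placeholders:

import socket
from urllib3.util.connection import create_connection

# Each entry is (level, optname, value), applied with sock.setsockopt()
# before the connect() call.
sock = create_connection(
    ("example.com", 80),
    timeout=5.0,
    socket_options=[(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)],
)
sock.close()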


@@ -1,9 +1,18 @@
from .ssl_ import create_urllib3_context, resolve_cert_reqs, resolve_ssl_version
from __future__ import annotations
import typing
from .url import Url
if typing.TYPE_CHECKING:
from ..connection import ProxyConfig
def connection_requires_http_tunnel(
proxy_url=None, proxy_config=None, destination_scheme=None
):
proxy_url: Url | None = None,
proxy_config: ProxyConfig | None = None,
destination_scheme: str | None = None,
) -> bool:
"""
Returns True if the connection requires an HTTP CONNECT through the proxy.
@@ -32,26 +41,3 @@ def connection_requires_http_tunnel(
# Otherwise always use a tunnel.
return True
def create_proxy_ssl_context(
ssl_version, cert_reqs, ca_certs=None, ca_cert_dir=None, ca_cert_data=None
):
"""
Generates a default proxy ssl context if one hasn't been provided by the
user.
"""
ssl_context = create_urllib3_context(
ssl_version=resolve_ssl_version(ssl_version),
cert_reqs=resolve_cert_reqs(cert_reqs),
)
if (
not ca_certs
and not ca_cert_dir
and not ca_cert_data
and hasattr(ssl_context, "load_default_certs")
):
ssl_context.load_default_certs()
return ssl_context
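
With ``create_proxy_ssl_context`` gone, ``connection_requires_http_tunnel`` is the surviving helper here. A hedged sketch of its expected behavior, based on the logic retained above (the proxy URL is a placeholder):

from urllib3.util.proxy import connection_requires_http_tunnel
from urllib3.util.url import parse_url

proxy = parse_url("http://proxy.example:8080")

# Plain-HTTP destinations are simply forwarded through the proxy...
assert connection_requires_http_tunnel(proxy, None, "http") is False
# ...while HTTPS destinations need an HTTP CONNECT tunnel.
assert connection_requires_http_tunnel(proxy, None, "https") is True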


@@ -1,22 +0,0 @@
import collections
from ..packages import six
from ..packages.six.moves import queue
if six.PY2:
# Queue is imported for side effects on MS Windows. See issue #229.
import Queue as _unused_module_Queue # noqa: F401
class LifoQueue(queue.Queue):
def _init(self, _):
self.queue = collections.deque()
def _qsize(self, len=len):
return len(self.queue)
def _put(self, item):
self.queue.append(item)
def _get(self):
return self.queue.pop()
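
The deleted module was a Python 2 compatibility shim; on Python 3 the standard library provides equivalent LIFO behavior directly, roughly:

import queue

q = queue.LifoQueue()  # stdlib counterpart of the removed shim
q.put(1)
q.put(2)
assert q.get() == 2  # last in, first out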


@@ -1,9 +1,15 @@
from __future__ import absolute_import
from __future__ import annotations
import io
import typing
from base64 import b64encode
from enum import Enum
from ..exceptions import UnrewindableBodyError
from ..packages.six import b, integer_types
from .util import to_bytes
if typing.TYPE_CHECKING:
from typing_extensions import Final
# Pass as a value within ``headers`` to skip
# emitting some HTTP headers that are added automatically.
@@ -15,25 +21,45 @@ SKIPPABLE_HEADERS = frozenset(["accept-encoding", "host", "user-agent"])
ACCEPT_ENCODING = "gzip,deflate"
try:
try:
import brotlicffi as _unused_module_brotli # noqa: F401
import brotlicffi as _unused_module_brotli # type: ignore[import] # noqa: F401
except ImportError:
import brotli as _unused_module_brotli # noqa: F401
import brotli as _unused_module_brotli # type: ignore[import] # noqa: F401
except ImportError:
pass
else:
ACCEPT_ENCODING += ",br"
try:
import zstandard as _unused_module_zstd # type: ignore[import] # noqa: F401
except ImportError:
pass
else:
ACCEPT_ENCODING += ",zstd"
_FAILEDTELL = object()
class _TYPE_FAILEDTELL(Enum):
token = 0
_FAILEDTELL: Final[_TYPE_FAILEDTELL] = _TYPE_FAILEDTELL.token
_TYPE_BODY_POSITION = typing.Union[int, _TYPE_FAILEDTELL]
# When sending a request with these methods we aren't expecting
# a body so don't need to set an explicit 'Content-Length: 0'
# The reason we do this in the negative instead of tracking methods
# which 'should' have a body is because unknown methods should be
# treated as if they were 'POST' which *does* expect a body.
_METHODS_NOT_EXPECTING_BODY = {"GET", "HEAD", "DELETE", "TRACE", "OPTIONS", "CONNECT"}
def make_headers(
keep_alive=None,
accept_encoding=None,
user_agent=None,
basic_auth=None,
proxy_basic_auth=None,
disable_cache=None,
):
keep_alive: bool | None = None,
accept_encoding: bool | list[str] | str | None = None,
user_agent: str | None = None,
basic_auth: str | None = None,
proxy_basic_auth: str | None = None,
disable_cache: bool | None = None,
) -> dict[str, str]:
"""
Shortcuts for generating request headers.
@@ -42,7 +68,8 @@ def make_headers(
:param accept_encoding:
Can be a boolean, list, or string.
``True`` translates to 'gzip,deflate'.
``True`` translates to 'gzip,deflate'. If either the ``brotli`` or
``brotlicffi`` package is installed 'gzip,deflate,br' is used instead.
List will get joined by comma.
String will be used as provided.
@@ -61,14 +88,18 @@ def make_headers(
:param disable_cache:
If ``True``, adds 'cache-control: no-cache' header.
Example::
Example:
>>> make_headers(keep_alive=True, user_agent="Batman/1.0")
{'connection': 'keep-alive', 'user-agent': 'Batman/1.0'}
>>> make_headers(accept_encoding=True)
{'accept-encoding': 'gzip,deflate'}
.. code-block:: python
import urllib3
print(urllib3.util.make_headers(keep_alive=True, user_agent="Batman/1.0"))
# {'connection': 'keep-alive', 'user-agent': 'Batman/1.0'}
print(urllib3.util.make_headers(accept_encoding=True))
# {'accept-encoding': 'gzip,deflate'}
"""
headers = {}
headers: dict[str, str] = {}
if accept_encoding:
if isinstance(accept_encoding, str):
pass
@@ -85,12 +116,14 @@ def make_headers(
headers["connection"] = "keep-alive"
if basic_auth:
headers["authorization"] = "Basic " + b64encode(b(basic_auth)).decode("utf-8")
headers[
"authorization"
] = f"Basic {b64encode(basic_auth.encode('latin-1')).decode()}"
if proxy_basic_auth:
headers["proxy-authorization"] = "Basic " + b64encode(
b(proxy_basic_auth)
).decode("utf-8")
headers[
"proxy-authorization"
] = f"Basic {b64encode(proxy_basic_auth.encode('latin-1')).decode()}"
if disable_cache:
headers["cache-control"] = "no-cache"
@@ -98,7 +131,9 @@ def make_headers(
return headers
def set_file_position(body, pos):
def set_file_position(
body: typing.Any, pos: _TYPE_BODY_POSITION | None
) -> _TYPE_BODY_POSITION | None:
"""
If a position is provided, move file to that point.
Otherwise, we'll attempt to record a position for future use.
@@ -108,7 +143,7 @@ def set_file_position(body, pos):
elif getattr(body, "tell", None) is not None:
try:
pos = body.tell()
except (IOError, OSError):
except OSError:
# This differentiates from None, allowing us to catch
# a failed `tell()` later when trying to rewind the body.
pos = _FAILEDTELL
@@ -116,7 +151,7 @@ def set_file_position(body, pos):
return pos
def rewind_body(body, body_pos):
def rewind_body(body: typing.IO[typing.AnyStr], body_pos: _TYPE_BODY_POSITION) -> None:
"""
Attempt to rewind body to a certain position.
Primarily used for request redirects and retries.
@@ -128,13 +163,13 @@ def rewind_body(body, body_pos):
Position to seek to in file.
"""
body_seek = getattr(body, "seek", None)
if body_seek is not None and isinstance(body_pos, integer_types):
if body_seek is not None and isinstance(body_pos, int):
try:
body_seek(body_pos)
except (IOError, OSError):
except OSError as e:
raise UnrewindableBodyError(
"An error occurred when rewinding request body for redirect/retry."
)
) from e
elif body_pos is _FAILEDTELL:
raise UnrewindableBodyError(
"Unable to record file position for rewinding "
@@ -142,5 +177,80 @@ def rewind_body(body, body_pos):
)
else:
raise ValueError(
"body_pos must be of type integer, instead it was %s." % type(body_pos)
f"body_pos must be of type integer, instead it was {type(body_pos)}."
)
class ChunksAndContentLength(typing.NamedTuple):
chunks: typing.Iterable[bytes] | None
content_length: int | None
def body_to_chunks(
body: typing.Any | None, method: str, blocksize: int
) -> ChunksAndContentLength:
"""Takes the HTTP request method, body, and blocksize and
transforms them into an iterable of chunks to pass to
socket.sendall() and an optional 'Content-Length' header.
A 'Content-Length' of 'None' indicates the length of the body
can't be determined so should use 'Transfer-Encoding: chunked'
for framing instead.
"""
chunks: typing.Iterable[bytes] | None
content_length: int | None
# No body, we need to make a recommendation on 'Content-Length'
# based on whether that request method is expected to have
# a body or not.
if body is None:
chunks = None
if method.upper() not in _METHODS_NOT_EXPECTING_BODY:
content_length = 0
else:
content_length = None
# Bytes or strings become bytes
elif isinstance(body, (str, bytes)):
chunks = (to_bytes(body),)
content_length = len(chunks[0])
# File-like object, TODO: use seek() and tell() for length?
elif hasattr(body, "read"):
def chunk_readable() -> typing.Iterable[bytes]:
nonlocal body, blocksize
encode = isinstance(body, io.TextIOBase)
while True:
datablock = body.read(blocksize)
if not datablock:
break
if encode:
datablock = datablock.encode("iso-8859-1")
yield datablock
chunks = chunk_readable()
content_length = None
# Otherwise we need to start checking via duck-typing.
else:
try:
# Check if the body implements the buffer API.
mv = memoryview(body)
except TypeError:
try:
# Check if the body is an iterable
chunks = iter(body)
content_length = None
except TypeError:
raise TypeError(
f"'body' must be a bytes-like object, file-like "
f"object, or iterable. Instead was {body!r}"
) from None
else:
# Since it implements the buffer API can be passed directly to socket.sendall()
chunks = (body,)
content_length = mv.nbytes
return ChunksAndContentLength(chunks=chunks, content_length=content_length)
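
A small sketch of what the new ``body_to_chunks`` helper returns for common body types (treat it as a private API; names are as defined in the diff above):

from urllib3.util.request import body_to_chunks

# A bytes body becomes a single chunk with a known Content-Length.
result = body_to_chunks(b"hello", method="POST", blocksize=8192)
assert result.content_length == 5
assert list(result.chunks) == [b"hello"]

# GET never expects a body, so no Content-Length is suggested.
assert body_to_chunks(None, method="GET", blocksize=8192).content_length is None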


@@ -1,12 +1,12 @@
from __future__ import absolute_import
from __future__ import annotations
import http.client as httplib
from email.errors import MultipartInvariantViolationDefect, StartBoundaryNotFoundDefect
from ..exceptions import HeaderParsingError
from ..packages.six.moves import http_client as httplib
def is_fp_closed(obj):
def is_fp_closed(obj: object) -> bool:
"""
Checks whether a given file-like object is closed.
@@ -17,27 +17,27 @@ def is_fp_closed(obj):
try:
# Check `isclosed()` first, in case Python3 doesn't set `closed`.
# GH Issue #928
return obj.isclosed()
return obj.isclosed() # type: ignore[no-any-return, attr-defined]
except AttributeError:
pass
try:
# Check via the official file-like-object way.
return obj.closed
return obj.closed # type: ignore[no-any-return, attr-defined]
except AttributeError:
pass
try:
# Check if the object is a container for another file-like object that
# gets released on exhaustion (e.g. HTTPResponse).
return obj.fp is None
return obj.fp is None # type: ignore[attr-defined]
except AttributeError:
pass
raise ValueError("Unable to determine whether fp is closed.")
def assert_header_parsing(headers):
def assert_header_parsing(headers: httplib.HTTPMessage) -> None:
"""
Asserts whether all headers have been successfully parsed.
Extracts encountered errors from the result of parsing headers.
@@ -53,55 +53,49 @@ def assert_header_parsing(headers):
# This will fail silently if we pass in the wrong kind of parameter.
# To make debugging easier add an explicit check.
if not isinstance(headers, httplib.HTTPMessage):
raise TypeError("expected httplib.Message, got {0}.".format(type(headers)))
defects = getattr(headers, "defects", None)
get_payload = getattr(headers, "get_payload", None)
raise TypeError(f"expected httplib.Message, got {type(headers)}.")
unparsed_data = None
if get_payload:
# get_payload is actually email.message.Message.get_payload;
# we're only interested in the result if it's not a multipart message
if not headers.is_multipart():
payload = get_payload()
if isinstance(payload, (bytes, str)):
unparsed_data = payload
if defects:
# httplib is assuming a response body is available
# when parsing headers even when httplib only sends
# header data to parse_headers() This results in
# defects on multipart responses in particular.
# See: https://github.com/urllib3/urllib3/issues/800
# get_payload is actually email.message.Message.get_payload;
# we're only interested in the result if it's not a multipart message
if not headers.is_multipart():
payload = headers.get_payload()
# So we ignore the following defects:
# - StartBoundaryNotFoundDefect:
# The claimed start boundary was never found.
# - MultipartInvariantViolationDefect:
# A message claimed to be a multipart but no subparts were found.
defects = [
defect
for defect in defects
if not isinstance(
defect, (StartBoundaryNotFoundDefect, MultipartInvariantViolationDefect)
)
]
if isinstance(payload, (bytes, str)):
unparsed_data = payload
# httplib is assuming a response body is available
# when parsing headers even when httplib only sends
# header data to parse_headers() This results in
# defects on multipart responses in particular.
# See: https://github.com/urllib3/urllib3/issues/800
# So we ignore the following defects:
# - StartBoundaryNotFoundDefect:
# The claimed start boundary was never found.
# - MultipartInvariantViolationDefect:
# A message claimed to be a multipart but no subparts were found.
defects = [
defect
for defect in headers.defects
if not isinstance(
defect, (StartBoundaryNotFoundDefect, MultipartInvariantViolationDefect)
)
]
if defects or unparsed_data:
raise HeaderParsingError(defects=defects, unparsed_data=unparsed_data)
def is_response_to_head(response):
def is_response_to_head(response: httplib.HTTPResponse) -> bool:
"""
Checks whether the request of a response has been a HEAD-request.
Handles the quirks of AppEngine.
:param http.client.HTTPResponse response:
Response to check if the originating request
used 'HEAD' as a method.
"""
# FIXME: Can we do this somehow without accessing private httplib _method?
method = response._method
if isinstance(method, int): # Platform-specific: Appengine
return method == 3
return method.upper() == "HEAD"
method_str = response._method # type: str # type: ignore[attr-defined]
return method_str.upper() == "HEAD"
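
For illustration, ``is_fp_closed`` duck-types in order (``isclosed()``, then ``closed``, then ``fp``); on a plain ``BytesIO`` the second check applies:

import io
from urllib3.util.response import is_fp_closed

buf = io.BytesIO(b"payload")
assert is_fp_closed(buf) is False  # BytesIO exposes .closed
buf.close()
assert is_fp_closed(buf) is True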


@@ -1,12 +1,13 @@
from __future__ import absolute_import
from __future__ import annotations
import email
import logging
import random
import re
import time
import warnings
from collections import namedtuple
import typing
from itertools import takewhile
from types import TracebackType
from ..exceptions import (
ConnectTimeoutError,
@@ -17,97 +18,49 @@ from ..exceptions import (
ReadTimeoutError,
ResponseError,
)
from ..packages import six
from .util import reraise
if typing.TYPE_CHECKING:
from ..connectionpool import ConnectionPool
from ..response import BaseHTTPResponse
log = logging.getLogger(__name__)
# Data structure for representing the metadata of requests that result in a retry.
RequestHistory = namedtuple(
"RequestHistory", ["method", "url", "error", "status", "redirect_location"]
)
class RequestHistory(typing.NamedTuple):
method: str | None
url: str | None
error: Exception | None
status: int | None
redirect_location: str | None
# TODO: In v2 we can remove this sentinel and metaclass with deprecated options.
_Default = object()
class _RetryMeta(type):
@property
def DEFAULT_METHOD_WHITELIST(cls):
warnings.warn(
"Using 'Retry.DEFAULT_METHOD_WHITELIST' is deprecated and "
"will be removed in v2.0. Use 'Retry.DEFAULT_ALLOWED_METHODS' instead",
DeprecationWarning,
)
return cls.DEFAULT_ALLOWED_METHODS
@DEFAULT_METHOD_WHITELIST.setter
def DEFAULT_METHOD_WHITELIST(cls, value):
warnings.warn(
"Using 'Retry.DEFAULT_METHOD_WHITELIST' is deprecated and "
"will be removed in v2.0. Use 'Retry.DEFAULT_ALLOWED_METHODS' instead",
DeprecationWarning,
)
cls.DEFAULT_ALLOWED_METHODS = value
@property
def DEFAULT_REDIRECT_HEADERS_BLACKLIST(cls):
warnings.warn(
"Using 'Retry.DEFAULT_REDIRECT_HEADERS_BLACKLIST' is deprecated and "
"will be removed in v2.0. Use 'Retry.DEFAULT_REMOVE_HEADERS_ON_REDIRECT' instead",
DeprecationWarning,
)
return cls.DEFAULT_REMOVE_HEADERS_ON_REDIRECT
@DEFAULT_REDIRECT_HEADERS_BLACKLIST.setter
def DEFAULT_REDIRECT_HEADERS_BLACKLIST(cls, value):
warnings.warn(
"Using 'Retry.DEFAULT_REDIRECT_HEADERS_BLACKLIST' is deprecated and "
"will be removed in v2.0. Use 'Retry.DEFAULT_REMOVE_HEADERS_ON_REDIRECT' instead",
DeprecationWarning,
)
cls.DEFAULT_REMOVE_HEADERS_ON_REDIRECT = value
@property
def BACKOFF_MAX(cls):
warnings.warn(
"Using 'Retry.BACKOFF_MAX' is deprecated and "
"will be removed in v2.0. Use 'Retry.DEFAULT_BACKOFF_MAX' instead",
DeprecationWarning,
)
return cls.DEFAULT_BACKOFF_MAX
@BACKOFF_MAX.setter
def BACKOFF_MAX(cls, value):
warnings.warn(
"Using 'Retry.BACKOFF_MAX' is deprecated and "
"will be removed in v2.0. Use 'Retry.DEFAULT_BACKOFF_MAX' instead",
DeprecationWarning,
)
cls.DEFAULT_BACKOFF_MAX = value
@six.add_metaclass(_RetryMeta)
class Retry(object):
class Retry:
"""Retry configuration.
Each retry attempt will create a new Retry object with updated values, so
they can be safely reused.
Retries can be defined as a default for a pool::
Retries can be defined as a default for a pool:
.. code-block:: python
retries = Retry(connect=5, read=2, redirect=5)
http = PoolManager(retries=retries)
response = http.request('GET', 'http://example.com/')
response = http.request("GET", "https://example.com/")
Or per-request (which overrides the default for the pool)::
Or per-request (which overrides the default for the pool):
response = http.request('GET', 'http://example.com/', retries=Retry(10))
.. code-block:: python
Retries can be disabled by passing ``False``::
response = http.request("GET", "https://example.com/", retries=Retry(10))
response = http.request('GET', 'http://example.com/', retries=False)
Retries can be disabled by passing ``False``:
.. code-block:: python
response = http.request("GET", "https://example.com/", retries=False)
Errors will be wrapped in :class:`~urllib3.exceptions.MaxRetryError` unless
retries are disabled, in which case the causing exception will be raised.
@@ -169,21 +122,16 @@ class Retry(object):
If ``total`` is not set, it's a good idea to set this to 0 to account
for unexpected edge cases and avoid infinite retry loops.
:param iterable allowed_methods:
:param Collection allowed_methods:
Set of uppercased HTTP method verbs that we should retry on.
By default, we only retry on methods which are considered to be
idempotent (multiple requests with the same parameters end with the
same state). See :attr:`Retry.DEFAULT_ALLOWED_METHODS`.
Set to a ``False`` value to retry on any verb.
Set to a ``None`` value to retry on any verb.
.. warning::
Previously this parameter was named ``method_whitelist``, that
usage is deprecated in v1.26.0 and will be removed in v2.0.
:param iterable status_forcelist:
:param Collection status_forcelist:
A set of integer HTTP status codes that we should force a retry on.
A retry is initiated if the request method is in ``allowed_methods``
and the response status code is in ``status_forcelist``.
@@ -195,13 +143,17 @@ class Retry(object):
(most errors are resolved immediately by a second try without a
delay). urllib3 will sleep for::
{backoff factor} * (2 ** ({number of total retries} - 1))
{backoff factor} * (2 ** ({number of previous retries}))
seconds. If the backoff_factor is 0.1, then :func:`.sleep` will sleep
for [0.0s, 0.2s, 0.4s, ...] between retries. It will never be longer
than :attr:`Retry.DEFAULT_BACKOFF_MAX`.
seconds. If `backoff_jitter` is non-zero, this sleep is extended by::
By default, backoff is disabled (set to 0).
random.uniform(0, {backoff jitter})
seconds. For example, if the backoff_factor is 0.1, then :func:`Retry.sleep` will
sleep for [0.0s, 0.2s, 0.4s, 0.8s, ...] between retries. No backoff will ever
be longer than `backoff_max`.
By default, backoff is disabled (factor set to 0).
:param bool raise_on_redirect: Whether, if the number of redirects is
exhausted, to raise a MaxRetryError, or to return a response with a
@@ -220,7 +172,7 @@ class Retry(object):
Whether to respect Retry-After header on status codes defined as
:attr:`Retry.RETRY_AFTER_STATUS_CODES` or not.
:param iterable remove_headers_on_redirect:
:param Collection remove_headers_on_redirect:
Sequence of headers to remove from the request when a response
indicating a redirect is returned before firing off the redirected
request.
@@ -237,48 +189,33 @@ class Retry(object):
#: Default headers to be used for ``remove_headers_on_redirect``
DEFAULT_REMOVE_HEADERS_ON_REDIRECT = frozenset(["Authorization"])
#: Maximum backoff time.
#: Default maximum backoff time.
DEFAULT_BACKOFF_MAX = 120
# Backward compatibility; assigned outside of the class.
DEFAULT: typing.ClassVar[Retry]
def __init__(
self,
total=10,
connect=None,
read=None,
redirect=None,
status=None,
other=None,
allowed_methods=_Default,
status_forcelist=None,
backoff_factor=0,
raise_on_redirect=True,
raise_on_status=True,
history=None,
respect_retry_after_header=True,
remove_headers_on_redirect=_Default,
# TODO: Deprecated, remove in v2.0
method_whitelist=_Default,
):
if method_whitelist is not _Default:
if allowed_methods is not _Default:
raise ValueError(
"Using both 'allowed_methods' and "
"'method_whitelist' together is not allowed. "
"Instead only use 'allowed_methods'"
)
warnings.warn(
"Using 'method_whitelist' with Retry is deprecated and "
"will be removed in v2.0. Use 'allowed_methods' instead",
DeprecationWarning,
stacklevel=2,
)
allowed_methods = method_whitelist
if allowed_methods is _Default:
allowed_methods = self.DEFAULT_ALLOWED_METHODS
if remove_headers_on_redirect is _Default:
remove_headers_on_redirect = self.DEFAULT_REMOVE_HEADERS_ON_REDIRECT
total: bool | int | None = 10,
connect: int | None = None,
read: int | None = None,
redirect: bool | int | None = None,
status: int | None = None,
other: int | None = None,
allowed_methods: typing.Collection[str] | None = DEFAULT_ALLOWED_METHODS,
status_forcelist: typing.Collection[int] | None = None,
backoff_factor: float = 0,
backoff_max: float = DEFAULT_BACKOFF_MAX,
raise_on_redirect: bool = True,
raise_on_status: bool = True,
history: tuple[RequestHistory, ...] | None = None,
respect_retry_after_header: bool = True,
remove_headers_on_redirect: typing.Collection[
str
] = DEFAULT_REMOVE_HEADERS_ON_REDIRECT,
backoff_jitter: float = 0.0,
) -> None:
self.total = total
self.connect = connect
self.read = read
@@ -293,15 +230,17 @@ class Retry(object):
self.status_forcelist = status_forcelist or set()
self.allowed_methods = allowed_methods
self.backoff_factor = backoff_factor
self.backoff_max = backoff_max
self.raise_on_redirect = raise_on_redirect
self.raise_on_status = raise_on_status
self.history = history or tuple()
self.history = history or ()
self.respect_retry_after_header = respect_retry_after_header
self.remove_headers_on_redirect = frozenset(
[h.lower() for h in remove_headers_on_redirect]
h.lower() for h in remove_headers_on_redirect
)
self.backoff_jitter = backoff_jitter
def new(self, **kw):
def new(self, **kw: typing.Any) -> Retry:
params = dict(
total=self.total,
connect=self.connect,
@@ -309,36 +248,28 @@ class Retry(object):
redirect=self.redirect,
status=self.status,
other=self.other,
allowed_methods=self.allowed_methods,
status_forcelist=self.status_forcelist,
backoff_factor=self.backoff_factor,
backoff_max=self.backoff_max,
raise_on_redirect=self.raise_on_redirect,
raise_on_status=self.raise_on_status,
history=self.history,
remove_headers_on_redirect=self.remove_headers_on_redirect,
respect_retry_after_header=self.respect_retry_after_header,
backoff_jitter=self.backoff_jitter,
)
# TODO: If already given in **kw we use what's given to us
# If not given we need to figure out what to pass. We decide
# based on whether our class has the 'method_whitelist' property
# and if so we pass the deprecated 'method_whitelist' otherwise
# we use 'allowed_methods'. Remove in v2.0
if "method_whitelist" not in kw and "allowed_methods" not in kw:
if "method_whitelist" in self.__dict__:
warnings.warn(
"Using 'method_whitelist' with Retry is deprecated and "
"will be removed in v2.0. Use 'allowed_methods' instead",
DeprecationWarning,
)
params["method_whitelist"] = self.allowed_methods
else:
params["allowed_methods"] = self.allowed_methods
params.update(kw)
return type(self)(**params)
return type(self)(**params) # type: ignore[arg-type]
@classmethod
def from_int(cls, retries, redirect=True, default=None):
def from_int(
cls,
retries: Retry | bool | int | None,
redirect: bool | int | None = True,
default: Retry | bool | int | None = None,
) -> Retry:
"""Backwards-compatibility for the old retries format."""
if retries is None:
retries = default if default is not None else cls.DEFAULT
@@ -351,7 +282,7 @@ class Retry(object):
log.debug("Converted retries value: %r -> %r", retries, new_retries)
return new_retries
def get_backoff_time(self):
def get_backoff_time(self) -> float:
"""Formula for computing the current backoff
:rtype: float
@@ -366,32 +297,28 @@ class Retry(object):
return 0
backoff_value = self.backoff_factor * (2 ** (consecutive_errors_len - 1))
return min(self.DEFAULT_BACKOFF_MAX, backoff_value)
if self.backoff_jitter != 0.0:
backoff_value += random.random() * self.backoff_jitter
return float(max(0, min(self.backoff_max, backoff_value)))
def parse_retry_after(self, retry_after):
def parse_retry_after(self, retry_after: str) -> float:
seconds: float
# Whitespace: https://tools.ietf.org/html/rfc7230#section-3.2.4
if re.match(r"^\s*[0-9]+\s*$", retry_after):
seconds = int(retry_after)
else:
retry_date_tuple = email.utils.parsedate_tz(retry_after)
if retry_date_tuple is None:
raise InvalidHeader("Invalid Retry-After header: %s" % retry_after)
if retry_date_tuple[9] is None: # Python 2
# Assume UTC if no timezone was specified
# On Python2.7, parsedate_tz returns None for a timezone offset
# instead of 0 if no timezone is given, where mktime_tz treats
# a None timezone offset as local time.
retry_date_tuple = retry_date_tuple[:9] + (0,) + retry_date_tuple[10:]
raise InvalidHeader(f"Invalid Retry-After header: {retry_after}")
retry_date = email.utils.mktime_tz(retry_date_tuple)
seconds = retry_date - time.time()
if seconds < 0:
seconds = 0
seconds = max(seconds, 0)
return seconds
def get_retry_after(self, response):
def get_retry_after(self, response: BaseHTTPResponse) -> float | None:
"""Get the value of Retry-After in seconds."""
retry_after = response.headers.get("Retry-After")
@@ -401,7 +328,7 @@ class Retry(object):
return self.parse_retry_after(retry_after)
def sleep_for_retry(self, response=None):
def sleep_for_retry(self, response: BaseHTTPResponse) -> bool:
retry_after = self.get_retry_after(response)
if retry_after:
time.sleep(retry_after)
@@ -409,13 +336,13 @@ class Retry(object):
return False
def _sleep_backoff(self):
def _sleep_backoff(self) -> None:
backoff = self.get_backoff_time()
if backoff <= 0:
return
time.sleep(backoff)
def sleep(self, response=None):
def sleep(self, response: BaseHTTPResponse | None = None) -> None:
"""Sleep between retry attempts.
This method will respect a server's ``Retry-After`` response header
@@ -431,7 +358,7 @@ class Retry(object):
self._sleep_backoff()
def _is_connection_error(self, err):
def _is_connection_error(self, err: Exception) -> bool:
"""Errors when we're fairly sure that the server did not receive the
request, so it should be safe to retry.
"""
@@ -439,33 +366,23 @@ class Retry(object):
err = err.original_error
return isinstance(err, ConnectTimeoutError)
def _is_read_error(self, err):
def _is_read_error(self, err: Exception) -> bool:
"""Errors that occur after the request has been started, so we should
assume that the server began processing it.
"""
return isinstance(err, (ReadTimeoutError, ProtocolError))
def _is_method_retryable(self, method):
def _is_method_retryable(self, method: str) -> bool:
"""Checks if a given HTTP method should be retried upon, depending if
it is included in the allowed_methods
"""
# TODO: For now favor if the Retry implementation sets its own method_whitelist
# property outside of our constructor to avoid breaking custom implementations.
if "method_whitelist" in self.__dict__:
warnings.warn(
"Using 'method_whitelist' with Retry is deprecated and "
"will be removed in v2.0. Use 'allowed_methods' instead",
DeprecationWarning,
)
allowed_methods = self.method_whitelist
else:
allowed_methods = self.allowed_methods
if allowed_methods and method.upper() not in allowed_methods:
if self.allowed_methods and method.upper() not in self.allowed_methods:
return False
return True
def is_retry(self, method, status_code, has_retry_after=False):
def is_retry(
self, method: str, status_code: int, has_retry_after: bool = False
) -> bool:
"""Is this method/status code retryable? (Based on allowlists and control
variables such as the number of total retries to allow, whether to
respect the Retry-After header, whether this header is present, and
@@ -478,24 +395,27 @@ class Retry(object):
if self.status_forcelist and status_code in self.status_forcelist:
return True
return (
return bool(
self.total
and self.respect_retry_after_header
and has_retry_after
and (status_code in self.RETRY_AFTER_STATUS_CODES)
)
def is_exhausted(self):
def is_exhausted(self) -> bool:
"""Are we out of retries?"""
retry_counts = (
self.total,
self.connect,
self.read,
self.redirect,
self.status,
self.other,
)
retry_counts = list(filter(None, retry_counts))
retry_counts = [
x
for x in (
self.total,
self.connect,
self.read,
self.redirect,
self.status,
self.other,
)
if x
]
if not retry_counts:
return False
@@ -503,18 +423,18 @@ class Retry(object):
def increment(
self,
method=None,
url=None,
response=None,
error=None,
_pool=None,
_stacktrace=None,
):
method: str | None = None,
url: str | None = None,
response: BaseHTTPResponse | None = None,
error: Exception | None = None,
_pool: ConnectionPool | None = None,
_stacktrace: TracebackType | None = None,
) -> Retry:
"""Return a new Retry object with incremented retry counters.
:param response: A response object, or None, if the server did not
return a response.
:type response: :class:`~urllib3.response.HTTPResponse`
:type response: :class:`~urllib3.response.BaseHTTPResponse`
:param Exception error: An error encountered during the request, or
None if the response was received successfully.
@@ -522,7 +442,7 @@ class Retry(object):
"""
if self.total is False and error:
# Disabled, indicate to re-raise the error.
raise six.reraise(type(error), error, _stacktrace)
raise reraise(type(error), error, _stacktrace)
total = self.total
if total is not None:
@@ -540,14 +460,14 @@ class Retry(object):
if error and self._is_connection_error(error):
# Connect retry?
if connect is False:
raise six.reraise(type(error), error, _stacktrace)
raise reraise(type(error), error, _stacktrace)
elif connect is not None:
connect -= 1
elif error and self._is_read_error(error):
# Read retry?
if read is False or not self._is_method_retryable(method):
raise six.reraise(type(error), error, _stacktrace)
if read is False or method is None or not self._is_method_retryable(method):
raise reraise(type(error), error, _stacktrace)
elif read is not None:
read -= 1
@@ -561,7 +481,9 @@ class Retry(object):
if redirect is not None:
redirect -= 1
cause = "too many redirects"
redirect_location = response.get_redirect_location()
response_redirect_location = response.get_redirect_location()
if response_redirect_location:
redirect_location = response_redirect_location
status = response.status
else:
@@ -589,31 +511,18 @@ class Retry(object):
)
if new_retry.is_exhausted():
raise MaxRetryError(_pool, url, error or ResponseError(cause))
reason = error or ResponseError(cause)
raise MaxRetryError(_pool, url, reason) from reason # type: ignore[arg-type]
log.debug("Incremented Retry for (url='%s'): %r", url, new_retry)
return new_retry
def __repr__(self):
def __repr__(self) -> str:
return (
"{cls.__name__}(total={self.total}, connect={self.connect}, "
"read={self.read}, redirect={self.redirect}, status={self.status})"
).format(cls=type(self), self=self)
def __getattr__(self, item):
if item == "method_whitelist":
# TODO: Remove this deprecated alias in v2.0
warnings.warn(
"Using 'method_whitelist' with Retry is deprecated and "
"will be removed in v2.0. Use 'allowed_methods' instead",
DeprecationWarning,
)
return self.allowed_methods
try:
return getattr(super(Retry, self), item)
except AttributeError:
return getattr(Retry, item)
f"{type(self).__name__}(total={self.total}, connect={self.connect}, "
f"read={self.read}, redirect={self.redirect}, status={self.status})"
)
# For backwards compatibility (equivalent to pre-v1.9):
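
Putting the reworked knobs together, a hedged usage sketch of the 2.x ``Retry`` with the ``backoff_max`` and ``backoff_jitter`` parameters introduced above (the URL is a placeholder):

import urllib3
from urllib3.util.retry import Retry

# Sleeps follow backoff_factor * 2**(previous retries), plus up to
# backoff_jitter seconds of random jitter, capped at backoff_max.
retry = Retry(total=5, backoff_factor=0.5, backoff_max=10.0, backoff_jitter=0.3)
http = urllib3.PoolManager(retries=retry)
resp = http.request("GET", "https://example.com/")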


@@ -1,185 +1,152 @@
from __future__ import absolute_import
from __future__ import annotations
import hmac
import os
import socket
import sys
import typing
import warnings
from binascii import hexlify, unhexlify
from binascii import unhexlify
from hashlib import md5, sha1, sha256
from ..exceptions import (
InsecurePlatformWarning,
ProxySchemeUnsupported,
SNIMissingWarning,
SSLError,
)
from ..packages import six
from .url import BRACELESS_IPV6_ADDRZ_RE, IPV4_RE
from ..exceptions import ProxySchemeUnsupported, SSLError
from .url import _BRACELESS_IPV6_ADDRZ_RE, _IPV4_RE
SSLContext = None
SSLTransport = None
HAS_SNI = False
HAS_NEVER_CHECK_COMMON_NAME = False
IS_PYOPENSSL = False
IS_SECURETRANSPORT = False
ALPN_PROTOCOLS = ["http/1.1"]
_TYPE_VERSION_INFO = typing.Tuple[int, int, int, str, int]
# Maps the length of a digest to a possible hash function producing this digest
HASHFUNC_MAP = {32: md5, 40: sha1, 64: sha256}
def _const_compare_digest_backport(a, b):
def _is_bpo_43522_fixed(
implementation_name: str,
version_info: _TYPE_VERSION_INFO,
pypy_version_info: _TYPE_VERSION_INFO | None,
) -> bool:
"""Return True for CPython 3.8.9+, 3.9.3+ or 3.10+ and PyPy 7.3.8+ where
setting SSLContext.hostname_checks_common_name to False works.
Outside of CPython and PyPy we don't know which implementations work
or not so we conservatively use our hostname matching as we know that works
on all implementations.
https://github.com/urllib3/urllib3/issues/2192#issuecomment-821832963
https://foss.heptapod.net/pypy/pypy/-/issues/3539
"""
Compare two digests of equal length in constant time.
The digests must be of type str/bytes.
Returns True if the digests match, and False otherwise.
"""
result = abs(len(a) - len(b))
for left, right in zip(bytearray(a), bytearray(b)):
result |= left ^ right
return result == 0
if implementation_name == "pypy":
# https://foss.heptapod.net/pypy/pypy/-/issues/3129
return pypy_version_info >= (7, 3, 8) and version_info >= (3, 8) # type: ignore[operator]
elif implementation_name == "cpython":
major_minor = version_info[:2]
micro = version_info[2]
return (
(major_minor == (3, 8) and micro >= 9)
or (major_minor == (3, 9) and micro >= 3)
or major_minor >= (3, 10)
)
else: # Defensive:
return False
_const_compare_digest = getattr(hmac, "compare_digest", _const_compare_digest_backport)
def _is_has_never_check_common_name_reliable(
openssl_version: str,
openssl_version_number: int,
implementation_name: str,
version_info: _TYPE_VERSION_INFO,
pypy_version_info: _TYPE_VERSION_INFO | None,
) -> bool:
# As of May 2023, all released versions of LibreSSL fail to reject certificates with
# only common names, see https://github.com/urllib3/urllib3/pull/3024
is_openssl = openssl_version.startswith("OpenSSL ")
# Before fixing OpenSSL issue #14579, the SSL_new() API was not copying hostflags
# like X509_CHECK_FLAG_NEVER_CHECK_SUBJECT, which tripped up CPython.
# https://github.com/openssl/openssl/issues/14579
# This was released in OpenSSL 1.1.1l+ (>=0x101010cf)
is_openssl_issue_14579_fixed = openssl_version_number >= 0x101010CF
try: # Test for SSL features
return is_openssl and (
is_openssl_issue_14579_fixed
or _is_bpo_43522_fixed(implementation_name, version_info, pypy_version_info)
)
if typing.TYPE_CHECKING:
from ssl import VerifyMode
from typing_extensions import Literal, TypedDict
from .ssltransport import SSLTransport as SSLTransportType
class _TYPE_PEER_CERT_RET_DICT(TypedDict, total=False):
subjectAltName: tuple[tuple[str, str], ...]
subject: tuple[tuple[tuple[str, str], ...], ...]
serialNumber: str
# Mapping from 'ssl.PROTOCOL_TLSX' to 'TLSVersion.X'
_SSL_VERSION_TO_TLS_VERSION: dict[int, int] = {}
try: # Do we have ssl at all?
import ssl
from ssl import CERT_REQUIRED, wrap_socket
except ImportError:
pass
try:
from ssl import HAS_SNI # Has SNI?
except ImportError:
pass
try:
from .ssltransport import SSLTransport
except ImportError:
pass
try: # Platform-specific: Python 3.6
from ssl import PROTOCOL_TLS
from ssl import ( # type: ignore[assignment]
CERT_REQUIRED,
HAS_NEVER_CHECK_COMMON_NAME,
OP_NO_COMPRESSION,
OP_NO_TICKET,
OPENSSL_VERSION,
OPENSSL_VERSION_NUMBER,
PROTOCOL_TLS,
PROTOCOL_TLS_CLIENT,
OP_NO_SSLv2,
OP_NO_SSLv3,
SSLContext,
TLSVersion,
)
PROTOCOL_SSLv23 = PROTOCOL_TLS
except ImportError:
try:
from ssl import PROTOCOL_SSLv23 as PROTOCOL_TLS
PROTOCOL_SSLv23 = PROTOCOL_TLS
except ImportError:
PROTOCOL_SSLv23 = PROTOCOL_TLS = 2
# Setting SSLContext.hostname_checks_common_name = False didn't work before CPython
# 3.8.9, 3.9.3, and 3.10 (but OK on PyPy) or OpenSSL 1.1.1l+
if HAS_NEVER_CHECK_COMMON_NAME and not _is_has_never_check_common_name_reliable(
OPENSSL_VERSION,
OPENSSL_VERSION_NUMBER,
sys.implementation.name,
sys.version_info,
sys.pypy_version_info if sys.implementation.name == "pypy" else None, # type: ignore[attr-defined]
):
HAS_NEVER_CHECK_COMMON_NAME = False
try:
from ssl import PROTOCOL_TLS_CLIENT
except ImportError:
PROTOCOL_TLS_CLIENT = PROTOCOL_TLS
try:
from ssl import OP_NO_COMPRESSION, OP_NO_SSLv2, OP_NO_SSLv3
except ImportError:
OP_NO_SSLv2, OP_NO_SSLv3 = 0x1000000, 0x2000000
OP_NO_COMPRESSION = 0x20000
try: # OP_NO_TICKET was added in Python 3.6
from ssl import OP_NO_TICKET
except ImportError:
OP_NO_TICKET = 0x4000
# A secure default.
# Sources for more information on TLS ciphers:
#
# - https://wiki.mozilla.org/Security/Server_Side_TLS
# - https://www.ssllabs.com/projects/best-practices/index.html
# - https://hynek.me/articles/hardening-your-web-servers-ssl-ciphers/
#
# The general intent is:
# - prefer cipher suites that offer perfect forward secrecy (DHE/ECDHE),
# - prefer ECDHE over DHE for better performance,
# - prefer any AES-GCM and ChaCha20 over any AES-CBC for better performance and
# security,
# - prefer AES-GCM over ChaCha20 because hardware-accelerated AES is common,
# - disable NULL authentication, MD5 MACs, DSS, and other
# insecure ciphers for security reasons.
# - NOTE: TLS 1.3 cipher suites are managed through a different interface
# not exposed by CPython (yet!) and are enabled by default if they're available.
DEFAULT_CIPHERS = ":".join(
[
"ECDHE+AESGCM",
"ECDHE+CHACHA20",
"DHE+AESGCM",
"DHE+CHACHA20",
"ECDH+AESGCM",
"DH+AESGCM",
"ECDH+AES",
"DH+AES",
"RSA+AESGCM",
"RSA+AES",
"!aNULL",
"!eNULL",
"!MD5",
"!DSS",
]
)
try:
from ssl import SSLContext # Modern SSL?
except ImportError:
class SSLContext(object): # Platform-specific: Python 2
def __init__(self, protocol_version):
self.protocol = protocol_version
# Use default values from a real SSLContext
self.check_hostname = False
self.verify_mode = ssl.CERT_NONE
self.ca_certs = None
self.options = 0
self.certfile = None
self.keyfile = None
self.ciphers = None
def load_cert_chain(self, certfile, keyfile):
self.certfile = certfile
self.keyfile = keyfile
def load_verify_locations(self, cafile=None, capath=None, cadata=None):
self.ca_certs = cafile
if capath is not None:
raise SSLError("CA directories not supported in older Pythons")
if cadata is not None:
raise SSLError("CA data not supported in older Pythons")
def set_ciphers(self, cipher_suite):
self.ciphers = cipher_suite
def wrap_socket(self, socket, server_hostname=None, server_side=False):
warnings.warn(
"A true SSLContext object is not available. This prevents "
"urllib3 from configuring SSL appropriately and may cause "
"certain SSL connections to fail. You can upgrade to a newer "
"version of Python to solve this. For more information, see "
"https://urllib3.readthedocs.io/en/1.26.x/advanced-usage.html"
"#ssl-warnings",
InsecurePlatformWarning,
# Need to be careful here in case old TLS versions get
# removed in future 'ssl' module implementations.
for attr in ("TLSv1", "TLSv1_1", "TLSv1_2"):
try:
_SSL_VERSION_TO_TLS_VERSION[getattr(ssl, f"PROTOCOL_{attr}")] = getattr(
TLSVersion, attr
)
kwargs = {
"keyfile": self.keyfile,
"certfile": self.certfile,
"ca_certs": self.ca_certs,
"cert_reqs": self.verify_mode,
"ssl_version": self.protocol,
"server_side": server_side,
}
return wrap_socket(socket, ciphers=self.ciphers, **kwargs)
except AttributeError: # Defensive:
continue
from .ssltransport import SSLTransport # type: ignore[assignment]
except ImportError:
OP_NO_COMPRESSION = 0x20000 # type: ignore[assignment]
OP_NO_TICKET = 0x4000 # type: ignore[assignment]
OP_NO_SSLv2 = 0x1000000 # type: ignore[assignment]
OP_NO_SSLv3 = 0x2000000 # type: ignore[assignment]
PROTOCOL_SSLv23 = PROTOCOL_TLS = 2 # type: ignore[assignment]
PROTOCOL_TLS_CLIENT = 16 # type: ignore[assignment]
def assert_fingerprint(cert, fingerprint):
_TYPE_PEER_CERT_RET = typing.Union["_TYPE_PEER_CERT_RET_DICT", bytes, None]
def assert_fingerprint(cert: bytes | None, fingerprint: str) -> None:
"""
Checks if given fingerprint matches the supplied certificate.
@@ -189,26 +156,27 @@ def assert_fingerprint(cert, fingerprint):
Fingerprint as string of hexdigits, can be interspersed by colons.
"""
if cert is None:
raise SSLError("No certificate for the peer.")
fingerprint = fingerprint.replace(":", "").lower()
digest_length = len(fingerprint)
hashfunc = HASHFUNC_MAP.get(digest_length)
if not hashfunc:
raise SSLError("Fingerprint of invalid length: {0}".format(fingerprint))
raise SSLError(f"Fingerprint of invalid length: {fingerprint}")
# We need encode() here for py32; works on py2 and py33.
fingerprint_bytes = unhexlify(fingerprint.encode())
cert_digest = hashfunc(cert).digest()
if not _const_compare_digest(cert_digest, fingerprint_bytes):
if not hmac.compare_digest(cert_digest, fingerprint_bytes):
raise SSLError(
'Fingerprints did not match. Expected "{0}", got "{1}".'.format(
fingerprint, hexlify(cert_digest)
)
f'Fingerprints did not match. Expected "{fingerprint}", got "{cert_digest.hex()}"'
)
def resolve_cert_reqs(candidate):
def resolve_cert_reqs(candidate: None | int | str) -> VerifyMode:
"""
Resolves the argument to a numeric constant, which can be passed to
the wrap_socket function/method from the ssl module.
@@ -226,12 +194,12 @@ def resolve_cert_reqs(candidate):
res = getattr(ssl, candidate, None)
if res is None:
res = getattr(ssl, "CERT_" + candidate)
return res
return res # type: ignore[no-any-return]
return candidate
return candidate # type: ignore[return-value]
def resolve_ssl_version(candidate):
def resolve_ssl_version(candidate: None | int | str) -> int:
"""
like resolve_cert_reqs
"""
@@ -242,35 +210,33 @@ def resolve_ssl_version(candidate):
res = getattr(ssl, candidate, None)
if res is None:
res = getattr(ssl, "PROTOCOL_" + candidate)
return res
return typing.cast(int, res)
return candidate
def create_urllib3_context(
ssl_version=None, cert_reqs=None, options=None, ciphers=None
):
"""All arguments have the same meaning as ``ssl_wrap_socket``.
By default, this function does a lot of the same work that
``ssl.create_default_context`` does on Python 3.4+. It:
- Disables SSLv2, SSLv3, and compression
- Sets a restricted set of server ciphers
If you wish to enable SSLv3, you can do::
from urllib3.util import ssl_
context = ssl_.create_urllib3_context()
context.options &= ~ssl_.OP_NO_SSLv3
You can do the same to enable compression (substituting ``COMPRESSION``
for ``SSLv3`` in the last line above).
ssl_version: int | None = None,
cert_reqs: int | None = None,
options: int | None = None,
ciphers: str | None = None,
ssl_minimum_version: int | None = None,
ssl_maximum_version: int | None = None,
) -> ssl.SSLContext:
"""Creates and configures an :class:`ssl.SSLContext` instance for use with urllib3.
:param ssl_version:
The desired protocol version to use. This will default to
PROTOCOL_SSLv23 which will negotiate the highest protocol that both
the server and your installation of OpenSSL support.
This parameter is deprecated instead use 'ssl_minimum_version'.
:param ssl_minimum_version:
The minimum version of TLS to be used. Use the 'ssl.TLSVersion' enum for specifying the value.
:param ssl_maximum_version:
The maximum version of TLS to be used. Use the 'ssl.TLSVersion' enum for specifying the value.
Not recommended to set to anything other than 'ssl.TLSVersion.MAXIMUM_SUPPORTED' which is the
default value.
:param cert_reqs:
Whether to require the certificate verification. This defaults to
``ssl.CERT_REQUIRED``.
@@ -278,18 +244,60 @@ def create_urllib3_context(
Specific OpenSSL options. These default to ``ssl.OP_NO_SSLv2``,
``ssl.OP_NO_SSLv3``, ``ssl.OP_NO_COMPRESSION``, and ``ssl.OP_NO_TICKET``.
:param ciphers:
Which cipher suites to allow the server to select.
Which cipher suites to allow the server to select. Defaults to either system configured
ciphers if OpenSSL 1.1.1+, otherwise uses a secure default set of ciphers.
:returns:
Constructed SSLContext object with specified options
:rtype: SSLContext
"""
# PROTOCOL_TLS is deprecated in Python 3.10
if not ssl_version or ssl_version == PROTOCOL_TLS:
ssl_version = PROTOCOL_TLS_CLIENT
if SSLContext is None:
raise TypeError("Can't create an SSLContext object without an ssl module")
context = SSLContext(ssl_version)
# This means 'ssl_version' was specified as an exact value.
if ssl_version not in (None, PROTOCOL_TLS, PROTOCOL_TLS_CLIENT):
# Disallow setting 'ssl_version' and 'ssl_minimum|maximum_version'
# to avoid conflicts.
if ssl_minimum_version is not None or ssl_maximum_version is not None:
raise ValueError(
"Can't specify both 'ssl_version' and either "
"'ssl_minimum_version' or 'ssl_maximum_version'"
)
context.set_ciphers(ciphers or DEFAULT_CIPHERS)
# 'ssl_version' is deprecated and will be removed in the future.
else:
# Use 'ssl_minimum_version' and 'ssl_maximum_version' instead.
ssl_minimum_version = _SSL_VERSION_TO_TLS_VERSION.get(
ssl_version, TLSVersion.MINIMUM_SUPPORTED
)
ssl_maximum_version = _SSL_VERSION_TO_TLS_VERSION.get(
ssl_version, TLSVersion.MAXIMUM_SUPPORTED
)
# This warning message is pushing users to use 'ssl_minimum_version'
# instead of both min/max. Best practice is to only set the minimum version and
# keep the maximum version to be its default value: 'TLSVersion.MAXIMUM_SUPPORTED'
warnings.warn(
"'ssl_version' option is deprecated and will be "
"removed in urllib3 v2.1.0. Instead use 'ssl_minimum_version'",
category=DeprecationWarning,
stacklevel=2,
)
# PROTOCOL_TLS is deprecated in Python 3.10 so we always use PROTOCOL_TLS_CLIENT
context = SSLContext(PROTOCOL_TLS_CLIENT)
if ssl_minimum_version is not None:
context.minimum_version = ssl_minimum_version
else: # Python <3.10 defaults to 'MINIMUM_SUPPORTED' so explicitly set TLSv1.2 here
context.minimum_version = TLSVersion.TLSv1_2
if ssl_maximum_version is not None:
context.maximum_version = ssl_maximum_version
# Unless we're given ciphers defer to either system ciphers in
# the case of OpenSSL 1.1.1+ or use our own secure default ciphers.
if ciphers:
context.set_ciphers(ciphers)
# Setting the default here, as we may have no ssl module on import
cert_reqs = ssl.CERT_REQUIRED if cert_reqs is None else cert_reqs
@@ -322,26 +330,23 @@ def create_urllib3_context(
) is not None:
context.post_handshake_auth = True
def disable_check_hostname():
if (
getattr(context, "check_hostname", None) is not None
): # Platform-specific: Python 3.2
# We do our own verification, including fingerprints and alternative
# hostnames. So disable it here
context.check_hostname = False
# The order of the below lines setting verify_mode and check_hostname
# matter due to safe-guards SSLContext has to prevent an SSLContext with
# check_hostname=True, verify_mode=NONE/OPTIONAL. This is made even more
# complex because we don't know whether PROTOCOL_TLS_CLIENT will be used
# or not so we don't know the initial state of the freshly created SSLContext.
if cert_reqs == ssl.CERT_REQUIRED:
# check_hostname=True, verify_mode=NONE/OPTIONAL.
# We always set 'check_hostname=False' for pyOpenSSL so we rely on our own
# 'ssl.match_hostname()' implementation.
if cert_reqs == ssl.CERT_REQUIRED and not IS_PYOPENSSL:
context.verify_mode = cert_reqs
disable_check_hostname()
context.check_hostname = True
else:
disable_check_hostname()
context.check_hostname = False
context.verify_mode = cert_reqs
try:
context.hostname_checks_common_name = False
except AttributeError: # Defensive: for CPython < 3.8.9 and 3.9.3; for PyPy < 7.3.8
pass
# Enable logging of TLS session keys via defacto standard environment variable
# 'SSLKEYLOGFILE', if the feature is available (Python 3.8+). Skip empty values.
if hasattr(context, "keylog_filename"):
@@ -352,21 +357,59 @@ def create_urllib3_context(
return context
@typing.overload
def ssl_wrap_socket(
sock,
keyfile=None,
certfile=None,
cert_reqs=None,
ca_certs=None,
server_hostname=None,
ssl_version=None,
ciphers=None,
ssl_context=None,
ca_cert_dir=None,
key_password=None,
ca_cert_data=None,
tls_in_tls=False,
):
sock: socket.socket,
keyfile: str | None = ...,
certfile: str | None = ...,
cert_reqs: int | None = ...,
ca_certs: str | None = ...,
server_hostname: str | None = ...,
ssl_version: int | None = ...,
ciphers: str | None = ...,
ssl_context: ssl.SSLContext | None = ...,
ca_cert_dir: str | None = ...,
key_password: str | None = ...,
ca_cert_data: None | str | bytes = ...,
tls_in_tls: Literal[False] = ...,
) -> ssl.SSLSocket:
...
@typing.overload
def ssl_wrap_socket(
sock: socket.socket,
keyfile: str | None = ...,
certfile: str | None = ...,
cert_reqs: int | None = ...,
ca_certs: str | None = ...,
server_hostname: str | None = ...,
ssl_version: int | None = ...,
ciphers: str | None = ...,
ssl_context: ssl.SSLContext | None = ...,
ca_cert_dir: str | None = ...,
key_password: str | None = ...,
ca_cert_data: None | str | bytes = ...,
tls_in_tls: bool = ...,
) -> ssl.SSLSocket | SSLTransportType:
...
def ssl_wrap_socket(
sock: socket.socket,
keyfile: str | None = None,
certfile: str | None = None,
cert_reqs: int | None = None,
ca_certs: str | None = None,
server_hostname: str | None = None,
ssl_version: int | None = None,
ciphers: str | None = None,
ssl_context: ssl.SSLContext | None = None,
ca_cert_dir: str | None = None,
key_password: str | None = None,
ca_cert_data: None | str | bytes = None,
tls_in_tls: bool = False,
) -> ssl.SSLSocket | SSLTransportType:
"""
All arguments except for server_hostname, ssl_context, and ca_cert_dir have
the same meaning as they do when using :func:`ssl.wrap_socket`.
@@ -392,19 +435,18 @@ def ssl_wrap_socket(
"""
context = ssl_context
if context is None:
# Note: This branch of code and all the variables in it are no longer
# used by urllib3 itself. We should consider deprecating and removing
# this code.
# Note: This branch of code and all the variables in it are only used in tests.
# We should consider deprecating and removing this code.
context = create_urllib3_context(ssl_version, cert_reqs, ciphers=ciphers)
if ca_certs or ca_cert_dir or ca_cert_data:
try:
context.load_verify_locations(ca_certs, ca_cert_dir, ca_cert_data)
except (IOError, OSError) as e:
raise SSLError(e)
except OSError as e:
raise SSLError(e) from e
elif ssl_context is None and hasattr(context, "load_default_certs"):
# try to load OS default certs; works well on Windows (require Python3.4+)
# try to load OS default certs; works well on Windows.
context.load_default_certs()
# Attempt to detect if we get the goofy behavior of the
@ -420,56 +462,30 @@ def ssl_wrap_socket(
context.load_cert_chain(certfile, keyfile, key_password)
try:
if hasattr(context, "set_alpn_protocols"):
context.set_alpn_protocols(ALPN_PROTOCOLS)
context.set_alpn_protocols(ALPN_PROTOCOLS)
except NotImplementedError: # Defensive: in CI, we always have set_alpn_protocols
pass
# If we detect server_hostname is an IP address then the SNI
# extension should not be used according to RFC3546 Section 3.1
use_sni_hostname = server_hostname and not is_ipaddress(server_hostname)
# SecureTransport uses server_hostname in certificate verification.
send_sni = (use_sni_hostname and HAS_SNI) or (
IS_SECURETRANSPORT and server_hostname
)
# Do not warn the user if server_hostname is an invalid SNI hostname.
if not HAS_SNI and use_sni_hostname:
warnings.warn(
"An HTTPS request has been made, but the SNI (Server Name "
"Indication) extension to TLS is not available on this platform. "
"This may cause the server to present an incorrect TLS "
"certificate, which can cause validation failures. You can upgrade to "
"a newer version of Python to solve this. For more information, see "
"https://urllib3.readthedocs.io/en/1.26.x/advanced-usage.html"
"#ssl-warnings",
SNIMissingWarning,
)
if send_sni:
ssl_sock = _ssl_wrap_socket_impl(
sock, context, tls_in_tls, server_hostname=server_hostname
)
else:
ssl_sock = _ssl_wrap_socket_impl(sock, context, tls_in_tls)
ssl_sock = _ssl_wrap_socket_impl(sock, context, tls_in_tls, server_hostname)
return ssl_sock
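A minimal sketch of what the paired overloads buy at type-check time; the host and socket here are purely illustrative:

import socket

sock = socket.create_connection(("example.com", 443))
tls_sock = ssl_wrap_socket(sock, server_hostname="example.com")
# With tls_in_tls left at its Literal[False] default, type checkers narrow
# the result to ssl.SSLSocket; passing tls_in_tls=True widens the return
# type to ssl.SSLSocket | SSLTransportType.
print(tls_sock.version())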
def is_ipaddress(hostname):
def is_ipaddress(hostname: str | bytes) -> bool:
"""Detects whether the hostname given is an IPv4 or IPv6 address.
Also detects IPv6 addresses with Zone IDs.
:param str hostname: Hostname to examine.
:return: True if the hostname is an IP address, False otherwise.
"""
if not six.PY2 and isinstance(hostname, bytes):
if isinstance(hostname, bytes):
# IDN A-label bytes are ASCII compatible.
hostname = hostname.decode("ascii")
return bool(IPV4_RE.match(hostname) or BRACELESS_IPV6_ADDRZ_RE.match(hostname))
return bool(_IPV4_RE.match(hostname) or _BRACELESS_IPV6_ADDRZ_RE.match(hostname))
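A few illustrative calls, assuming the regexes defined in this module:

is_ipaddress("10.0.0.1")      # True
is_ipaddress("fe80::1%eth0")  # True -- IPv6 with an RFC 4007 zone ID
is_ipaddress(b"2001:db8::1")  # True -- bytes are decoded as ASCII first
is_ipaddress("example.com")   # False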
def _is_key_file_encrypted(key_file):
def _is_key_file_encrypted(key_file: str) -> bool:
"""Detects if a key file is encrypted or not."""
with open(key_file, "r") as f:
with open(key_file) as f:
for line in f:
# Look for Proc-Type: 4,ENCRYPTED
if "ENCRYPTED" in line:
@ -478,7 +494,12 @@ def _is_key_file_encrypted(key_file):
return False
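A minimal sketch of the legacy PEM header this scan keys on, using a throwaway temporary file (the file contents are an assumption for illustration, not a real key):

import tempfile

with tempfile.NamedTemporaryFile("w", suffix=".pem", delete=False) as f:
    f.write("-----BEGIN RSA PRIVATE KEY-----\nProc-Type: 4,ENCRYPTED\n")
print(_is_key_file_encrypted(f.name))  # True -- the substring check also catches PKCS#8 'BEGIN ENCRYPTED PRIVATE KEY' headers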
def _ssl_wrap_socket_impl(sock, ssl_context, tls_in_tls, server_hostname=None):
def _ssl_wrap_socket_impl(
sock: socket.socket,
ssl_context: ssl.SSLContext,
tls_in_tls: bool,
server_hostname: str | None = None,
) -> ssl.SSLSocket | SSLTransportType:
if tls_in_tls:
if not SSLTransport:
# Import error, ssl is not available.
@ -489,7 +510,4 @@ def _ssl_wrap_socket_impl(sock, ssl_context, tls_in_tls, server_hostname=None):
SSLTransport._validate_ssl_context_for_tls_in_tls(ssl_context)
return SSLTransport(sock, ssl_context, server_hostname)
if server_hostname:
return ssl_context.wrap_socket(sock, server_hostname=server_hostname)
else:
return ssl_context.wrap_socket(sock)
return ssl_context.wrap_socket(sock, server_hostname=server_hostname)
View file
@ -1,19 +1,18 @@
"""The match_hostname() function from Python 3.3.3, essential when using SSL."""
"""The match_hostname() function from Python 3.5, essential when using SSL."""
# Note: This file is under the PSF license as the code comes from the python
# stdlib. http://docs.python.org/3/license.html
# It is modified to remove commonName support.
from __future__ import annotations
import ipaddress
import re
import sys
import typing
from ipaddress import IPv4Address, IPv6Address
# ipaddress has been backported to 2.6+ in pypi. If it is installed on the
# system, use it to handle IPAddress ServerAltnames (this was added in
# python-3.5) otherwise only do DNS matching. This allows
# util.ssl_match_hostname to continue to be used in Python 2.7.
try:
import ipaddress
except ImportError:
ipaddress = None
if typing.TYPE_CHECKING:
from .ssl_ import _TYPE_PEER_CERT_RET_DICT
__version__ = "3.5.0.1"
@ -22,7 +21,9 @@ class CertificateError(ValueError):
pass
def _dnsname_match(dn, hostname, max_wildcards=1):
def _dnsname_match(
dn: typing.Any, hostname: str, max_wildcards: int = 1
) -> typing.Match[str] | None | bool:
"""Matching according to RFC 6125, section 6.4.3
http://tools.ietf.org/html/rfc6125#section-6.4.3
@ -49,7 +50,7 @@ def _dnsname_match(dn, hostname, max_wildcards=1):
# speed up common case w/o wildcards
if not wildcards:
return dn.lower() == hostname.lower()
return bool(dn.lower() == hostname.lower())
# RFC 6125, section 6.4.3, subitem 1.
# The client SHOULD NOT attempt to match a presented identifier in which
@ -76,26 +77,26 @@ def _dnsname_match(dn, hostname, max_wildcards=1):
return pat.match(hostname)
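Illustrative calls, assuming the RFC 6125 single-label wildcard rules implemented above:

_dnsname_match("*.example.com", "a.example.com")    # match object (truthy)
_dnsname_match("*.example.com", "a.b.example.com")  # None -- the wildcard spans one label only
_dnsname_match("example.com", "EXAMPLE.COM")        # True -- exact compares are case-insensitive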
def _to_unicode(obj):
if isinstance(obj, str) and sys.version_info < (3,):
# ignored flake8 # F821 to support python 2.7 function
obj = unicode(obj, encoding="ascii", errors="strict") # noqa: F821
return obj
def _ipaddress_match(ipname, host_ip):
def _ipaddress_match(ipname: str, host_ip: IPv4Address | IPv6Address) -> bool:
"""Exact matching of IP addresses.
RFC 6125 explicitly doesn't define an algorithm for this
(section 1.7.2 - "Out of Scope").
RFC 9110 section 4.3.5: "A reference identity of IP-ID contains the decoded
bytes of the IP address. An IP version 4 address is 4 octets, and an IP
version 6 address is 16 octets. [...] A reference identity of type IP-ID
matches if the address is identical to an iPAddress value of the
subjectAltName extension of the certificate."
"""
# OpenSSL may add a trailing newline to a subjectAltName's IP address
# Divergence from upstream: ipaddress can't handle byte str
ip = ipaddress.ip_address(_to_unicode(ipname).rstrip())
return ip == host_ip
ip = ipaddress.ip_address(ipname.rstrip())
return bool(ip.packed == host_ip.packed)
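A minimal example, assuming only the stdlib ipaddress module: comparing on packed bytes makes textual differences (case, zero-compression, the trailing newline) irrelevant:

import ipaddress

_ipaddress_match("2001:DB8::1\n", ipaddress.ip_address("2001:db8:0:0:0:0:0:1"))  # True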
def match_hostname(cert, hostname):
def match_hostname(
cert: _TYPE_PEER_CERT_RET_DICT | None,
hostname: str,
hostname_checks_common_name: bool = False,
) -> None:
"""Verify that *cert* (in decoded format as returned by
SSLSocket.getpeercert()) matches the *hostname*. RFC 2818 and RFC 6125
rules are followed, but IP addresses are not accepted for *hostname*.
@ -111,21 +112,22 @@ def match_hostname(cert, hostname):
)
try:
# Divergence from upstream: ipaddress can't handle byte str
host_ip = ipaddress.ip_address(_to_unicode(hostname))
except (UnicodeError, ValueError):
# ValueError: Not an IP address (common case)
# UnicodeError: Divergence from upstream: Have to deal with ipaddress not taking
# byte strings. addresses should be all ascii, so we consider it not
# an ipaddress in this case
#
# The ipaddress module shipped with Python < 3.9 does not support
# scoped IPv6 addresses so we unconditionally strip the Zone IDs for
# now. Once we drop support for Python 3.9 we can remove this branch.
if "%" in hostname:
host_ip = ipaddress.ip_address(hostname[: hostname.rfind("%")])
else:
host_ip = ipaddress.ip_address(hostname)
except ValueError:
# Not an IP address (common case)
host_ip = None
except AttributeError:
# Divergence from upstream: Make ipaddress library optional
if ipaddress is None:
host_ip = None
else: # Defensive
raise
dnsnames = []
san = cert.get("subjectAltName", ())
san: tuple[tuple[str, str], ...] = cert.get("subjectAltName", ())
key: str
value: str
for key, value in san:
if key == "DNS":
if host_ip is None and _dnsname_match(value, hostname):
@ -135,25 +137,23 @@ def match_hostname(cert, hostname):
if host_ip is not None and _ipaddress_match(value, host_ip):
return
dnsnames.append(value)
if not dnsnames:
# The subject is only checked when there is no dNSName entry
# in subjectAltName
# We only check 'commonName' if it's enabled and we're not verifying
# an IP address. IP addresses aren't valid within 'commonName'.
if hostname_checks_common_name and host_ip is None and not dnsnames:
for sub in cert.get("subject", ()):
for key, value in sub:
# XXX according to RFC 2818, the most specific Common Name
# must be used.
if key == "commonName":
if _dnsname_match(value, hostname):
return
dnsnames.append(value)
if len(dnsnames) > 1:
raise CertificateError(
"hostname %r "
"doesn't match either of %s" % (hostname, ", ".join(map(repr, dnsnames)))
)
elif len(dnsnames) == 1:
raise CertificateError("hostname %r doesn't match %r" % (hostname, dnsnames[0]))
raise CertificateError(f"hostname {hostname!r} doesn't match {dnsnames[0]!r}")
else:
raise CertificateError(
"no appropriate commonName or subjectAltName fields were found"
)
raise CertificateError("no appropriate subjectAltName fields were found")
View file
@ -1,9 +1,21 @@
from __future__ import annotations
import io
import socket
import ssl
import typing
from ..exceptions import ProxySchemeUnsupported
from ..packages import six
if typing.TYPE_CHECKING:
from typing_extensions import Literal
from .ssl_ import _TYPE_PEER_CERT_RET, _TYPE_PEER_CERT_RET_DICT
_SelfT = typing.TypeVar("_SelfT", bound="SSLTransport")
_WriteBuffer = typing.Union[bytearray, memoryview]
_ReturnValue = typing.TypeVar("_ReturnValue")
SSL_BLOCKSIZE = 16384
@ -20,7 +32,7 @@ class SSLTransport:
"""
@staticmethod
def _validate_ssl_context_for_tls_in_tls(ssl_context):
def _validate_ssl_context_for_tls_in_tls(ssl_context: ssl.SSLContext) -> None:
"""
Raises a ProxySchemeUnsupported if the provided ssl_context can't be used
for TLS in TLS.
@ -30,20 +42,18 @@ class SSLTransport:
"""
if not hasattr(ssl_context, "wrap_bio"):
if six.PY2:
raise ProxySchemeUnsupported(
"TLS in TLS requires SSLContext.wrap_bio() which isn't "
"supported on Python 2"
)
else:
raise ProxySchemeUnsupported(
"TLS in TLS requires SSLContext.wrap_bio() which isn't "
"available on non-native SSLContext"
)
raise ProxySchemeUnsupported(
"TLS in TLS requires SSLContext.wrap_bio() which isn't "
"available on non-native SSLContext"
)
def __init__(
self, socket, ssl_context, server_hostname=None, suppress_ragged_eofs=True
):
self,
socket: socket.socket,
ssl_context: ssl.SSLContext,
server_hostname: str | None = None,
suppress_ragged_eofs: bool = True,
) -> None:
"""
Create an SSLTransport around socket using the provided ssl_context.
"""
@ -60,33 +70,36 @@ class SSLTransport:
# Perform initial handshake.
self._ssl_io_loop(self.sslobj.do_handshake)
def __enter__(self):
def __enter__(self: _SelfT) -> _SelfT:
return self
def __exit__(self, *_):
def __exit__(self, *_: typing.Any) -> None:
self.close()
def fileno(self):
def fileno(self) -> int:
return self.socket.fileno()
def read(self, len=1024, buffer=None):
def read(self, len: int = 1024, buffer: typing.Any | None = None) -> int | bytes:
return self._wrap_ssl_read(len, buffer)
def recv(self, len=1024, flags=0):
def recv(self, buflen: int = 1024, flags: int = 0) -> int | bytes:
if flags != 0:
raise ValueError("non-zero flags not allowed in calls to recv")
return self._wrap_ssl_read(len)
return self._wrap_ssl_read(buflen)
def recv_into(self, buffer, nbytes=None, flags=0):
def recv_into(
self,
buffer: _WriteBuffer,
nbytes: int | None = None,
flags: int = 0,
) -> None | int | bytes:
if flags != 0:
raise ValueError("non-zero flags not allowed in calls to recv_into")
if buffer and (nbytes is None):
if nbytes is None:
nbytes = len(buffer)
elif nbytes is None:
nbytes = 1024
return self.read(nbytes, buffer)
def sendall(self, data, flags=0):
def sendall(self, data: bytes, flags: int = 0) -> None:
if flags != 0:
raise ValueError("non-zero flags not allowed in calls to sendall")
count = 0
@ -96,15 +109,20 @@ class SSLTransport:
v = self.send(byte_view[count:])
count += v
def send(self, data, flags=0):
def send(self, data: bytes, flags: int = 0) -> int:
if flags != 0:
raise ValueError("non-zero flags not allowed in calls to send")
response = self._ssl_io_loop(self.sslobj.write, data)
return response
return self._ssl_io_loop(self.sslobj.write, data)
def makefile(
self, mode="r", buffering=None, encoding=None, errors=None, newline=None
):
self,
mode: str,
buffering: int | None = None,
*,
encoding: str | None = None,
errors: str | None = None,
newline: str | None = None,
) -> typing.BinaryIO | typing.TextIO | socket.SocketIO:
"""
Python's httpclient uses makefile and buffered io when reading HTTP
messages and we need to support it.
@ -113,7 +131,7 @@ class SSLTransport:
changes to point to the socket directly.
"""
if not set(mode) <= {"r", "w", "b"}:
raise ValueError("invalid mode %r (only r, w, b allowed)" % (mode,))
raise ValueError(f"invalid mode {mode!r} (only r, w, b allowed)")
writing = "w" in mode
reading = "r" in mode or not writing
@ -124,8 +142,8 @@ class SSLTransport:
rawmode += "r"
if writing:
rawmode += "w"
raw = socket.SocketIO(self, rawmode)
self.socket._io_refs += 1
raw = socket.SocketIO(self, rawmode) # type: ignore[arg-type]
self.socket._io_refs += 1 # type: ignore[attr-defined]
if buffering is None:
buffering = -1
if buffering < 0:
@ -134,8 +152,9 @@ class SSLTransport:
if not binary:
raise ValueError("unbuffered streams must be binary")
return raw
buffer: typing.BinaryIO
if reading and writing:
buffer = io.BufferedRWPair(raw, raw, buffering)
buffer = io.BufferedRWPair(raw, raw, buffering) # type: ignore[assignment]
elif reading:
buffer = io.BufferedReader(raw, buffering)
else:
@ -144,46 +163,56 @@ class SSLTransport:
if binary:
return buffer
text = io.TextIOWrapper(buffer, encoding, errors, newline)
text.mode = mode
text.mode = mode # type: ignore[misc]
return text
def unwrap(self):
def unwrap(self) -> None:
self._ssl_io_loop(self.sslobj.unwrap)
def close(self):
def close(self) -> None:
self.socket.close()
def getpeercert(self, binary_form=False):
return self.sslobj.getpeercert(binary_form)
@typing.overload
def getpeercert(
self, binary_form: Literal[False] = ...
) -> _TYPE_PEER_CERT_RET_DICT | None:
...
def version(self):
@typing.overload
def getpeercert(self, binary_form: Literal[True]) -> bytes | None:
...
def getpeercert(self, binary_form: bool = False) -> _TYPE_PEER_CERT_RET:
return self.sslobj.getpeercert(binary_form) # type: ignore[return-value]
def version(self) -> str | None:
return self.sslobj.version()
def cipher(self):
def cipher(self) -> tuple[str, str, int] | None:
return self.sslobj.cipher()
def selected_alpn_protocol(self):
def selected_alpn_protocol(self) -> str | None:
return self.sslobj.selected_alpn_protocol()
def selected_npn_protocol(self):
def selected_npn_protocol(self) -> str | None:
return self.sslobj.selected_npn_protocol()
def shared_ciphers(self):
def shared_ciphers(self) -> list[tuple[str, str, int]] | None:
return self.sslobj.shared_ciphers()
def compression(self):
def compression(self) -> str | None:
return self.sslobj.compression()
def settimeout(self, value):
def settimeout(self, value: float | None) -> None:
self.socket.settimeout(value)
def gettimeout(self):
def gettimeout(self) -> float | None:
return self.socket.gettimeout()
def _decref_socketios(self):
self.socket._decref_socketios()
def _decref_socketios(self) -> None:
self.socket._decref_socketios() # type: ignore[attr-defined]
def _wrap_ssl_read(self, len, buffer=None):
def _wrap_ssl_read(self, len: int, buffer: bytearray | None = None) -> int | bytes:
try:
return self._ssl_io_loop(self.sslobj.read, len, buffer)
except ssl.SSLError as e:
@ -192,7 +221,32 @@ class SSLTransport:
else:
raise
def _ssl_io_loop(self, func, *args):
# func is sslobj.do_handshake or sslobj.unwrap
@typing.overload
def _ssl_io_loop(self, func: typing.Callable[[], None]) -> None:
...
# func is sslobj.write, arg1 is data
@typing.overload
def _ssl_io_loop(self, func: typing.Callable[[bytes], int], arg1: bytes) -> int:
...
# func is sslobj.read, arg1 is len, arg2 is buffer
@typing.overload
def _ssl_io_loop(
self,
func: typing.Callable[[int, bytearray | None], bytes],
arg1: int,
arg2: bytearray | None,
) -> bytes:
...
def _ssl_io_loop(
self,
func: typing.Callable[..., _ReturnValue],
arg1: None | bytes | int = None,
arg2: bytearray | None = None,
) -> _ReturnValue:
"""Performs an I/O loop between incoming/outgoing and the socket."""
should_loop = True
ret = None
@ -200,7 +254,12 @@ class SSLTransport:
while should_loop:
errno = None
try:
ret = func(*args)
if arg1 is None and arg2 is None:
ret = func()
elif arg2 is None:
ret = func(arg1)
else:
ret = func(arg1, arg2)
except ssl.SSLError as e:
if e.errno not in (ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE):
# WANT_READ, and WANT_WRITE are expected, others are not.
@ -218,4 +277,4 @@ class SSLTransport:
self.incoming.write(buf)
else:
self.incoming.write_eof()
return ret
return typing.cast(_ReturnValue, ret)
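A compressed sketch of the MemoryBIO pump this loop implements, assuming a connected socket and ignoring EOF and SSLWantWriteError for brevity (the helper name is made up for illustration):

import socket
import ssl

def handshake_over_bios(sock: socket.socket, ctx: ssl.SSLContext, server_hostname: str) -> ssl.SSLObject:
    incoming, outgoing = ssl.MemoryBIO(), ssl.MemoryBIO()
    sslobj = ctx.wrap_bio(incoming, outgoing, server_hostname=server_hostname)
    while True:
        try:
            sslobj.do_handshake()
            break
        except ssl.SSLWantReadError:
            # Flush the handshake bytes OpenSSL queued for the peer, then
            # feed whatever the peer sent back into the incoming BIO.
            sock.sendall(outgoing.read())
            incoming.write(sock.recv(SSL_BLOCKSIZE))
    sock.sendall(outgoing.read())  # flush the final handshake flight, if any
    return sslobj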
View file
@ -1,44 +1,56 @@
from __future__ import absolute_import
from __future__ import annotations
import time
# The default socket timeout, used by httplib to indicate that no timeout was
# specified by the user
from socket import _GLOBAL_DEFAULT_TIMEOUT, getdefaulttimeout
import typing
from enum import Enum
from socket import getdefaulttimeout
from ..exceptions import TimeoutStateError
# A sentinel value to indicate that no timeout was specified by the user in
# urllib3
_Default = object()
if typing.TYPE_CHECKING:
from typing_extensions import Final
# Use time.monotonic if available.
current_time = getattr(time, "monotonic", time.time)
class _TYPE_DEFAULT(Enum):
# This value should never be passed to socket.settimeout() so for safety we use a -1.
# socket.settimeout() raises a ValueError for negative values.
token = -1
class Timeout(object):
_DEFAULT_TIMEOUT: Final[_TYPE_DEFAULT] = _TYPE_DEFAULT.token
_TYPE_TIMEOUT = typing.Optional[typing.Union[float, _TYPE_DEFAULT]]
class Timeout:
"""Timeout configuration.
Timeouts can be defined as a default for a pool:
.. code-block:: python
timeout = Timeout(connect=2.0, read=7.0)
http = PoolManager(timeout=timeout)
response = http.request('GET', 'http://example.com/')
import urllib3
timeout = urllib3.util.Timeout(connect=2.0, read=7.0)
http = urllib3.PoolManager(timeout=timeout)
resp = http.request("GET", "https://example.com/")
print(resp.status)
Or per-request (which overrides the default for the pool):
.. code-block:: python
response = http.request('GET', 'http://example.com/', timeout=Timeout(10))
response = http.request("GET", "https://example.com/", timeout=Timeout(10))
Timeouts can be disabled by setting all the parameters to ``None``:
.. code-block:: python
no_timeout = Timeout(connect=None, read=None)
response = http.request('GET', 'http://example.com/, timeout=no_timeout)
response = http.request("GET", "https://example.com/", timeout=no_timeout)
:param total:
@ -96,31 +108,31 @@ class Timeout(object):
"""
#: A sentinel object representing the default timeout value
DEFAULT_TIMEOUT = _GLOBAL_DEFAULT_TIMEOUT
DEFAULT_TIMEOUT: _TYPE_TIMEOUT = _DEFAULT_TIMEOUT
def __init__(self, total=None, connect=_Default, read=_Default):
def __init__(
self,
total: _TYPE_TIMEOUT = None,
connect: _TYPE_TIMEOUT = _DEFAULT_TIMEOUT,
read: _TYPE_TIMEOUT = _DEFAULT_TIMEOUT,
) -> None:
self._connect = self._validate_timeout(connect, "connect")
self._read = self._validate_timeout(read, "read")
self.total = self._validate_timeout(total, "total")
self._start_connect = None
self._start_connect: float | None = None
def __repr__(self):
return "%s(connect=%r, read=%r, total=%r)" % (
type(self).__name__,
self._connect,
self._read,
self.total,
)
def __repr__(self) -> str:
return f"{type(self).__name__}(connect={self._connect!r}, read={self._read!r}, total={self.total!r})"
# __str__ provided for backwards compatibility
__str__ = __repr__
@classmethod
def resolve_default_timeout(cls, timeout):
return getdefaulttimeout() if timeout is cls.DEFAULT_TIMEOUT else timeout
@staticmethod
def resolve_default_timeout(timeout: _TYPE_TIMEOUT) -> float | None:
return getdefaulttimeout() if timeout is _DEFAULT_TIMEOUT else timeout
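A minimal sketch of the resolution rule, assuming a process-wide socket default of 3 seconds: the sentinel defers to that default, while an explicit None stays None:

import socket

socket.setdefaulttimeout(3.0)
print(Timeout.resolve_default_timeout(_DEFAULT_TIMEOUT))  # 3.0
print(Timeout.resolve_default_timeout(None))              # None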
@classmethod
def _validate_timeout(cls, value, name):
def _validate_timeout(cls, value: _TYPE_TIMEOUT, name: str) -> _TYPE_TIMEOUT:
"""Check that a timeout attribute is valid.
:param value: The timeout value to validate
@ -130,10 +142,7 @@ class Timeout(object):
:raises ValueError: If it is a numeric value less than or equal to
zero, or the type is not an integer, float, or None.
"""
if value is _Default:
return cls.DEFAULT_TIMEOUT
if value is None or value is cls.DEFAULT_TIMEOUT:
if value is None or value is _DEFAULT_TIMEOUT:
return value
if isinstance(value, bool):
@ -147,7 +156,7 @@ class Timeout(object):
raise ValueError(
"Timeout value %s was %s, but it must be an "
"int, float or None." % (name, value)
)
) from None
try:
if value <= 0:
@ -157,16 +166,15 @@ class Timeout(object):
"than or equal to 0." % (name, value)
)
except TypeError:
# Python 3
raise ValueError(
"Timeout value %s was %s, but it must be an "
"int, float or None." % (name, value)
)
) from None
return value
@classmethod
def from_float(cls, timeout):
def from_float(cls, timeout: _TYPE_TIMEOUT) -> Timeout:
"""Create a new Timeout from a legacy timeout value.
The timeout value used by httplib.py sets the same timeout on the
@ -175,13 +183,13 @@ class Timeout(object):
passed to this function.
:param timeout: The legacy timeout value.
:type timeout: integer, float, sentinel default object, or None
:type timeout: integer, float, :attr:`urllib3.util.Timeout.DEFAULT_TIMEOUT`, or None
:return: Timeout object
:rtype: :class:`Timeout`
"""
return Timeout(read=timeout, connect=timeout)
def clone(self):
def clone(self) -> Timeout:
"""Create a copy of the timeout object
Timeout properties are stored per-pool but each request needs a fresh
@ -195,7 +203,7 @@ class Timeout(object):
# detect the user default.
return Timeout(connect=self._connect, read=self._read, total=self.total)
def start_connect(self):
def start_connect(self) -> float:
"""Start the timeout clock, used during a connect() attempt
:raises urllib3.exceptions.TimeoutStateError: if you attempt
@ -203,10 +211,10 @@ class Timeout(object):
"""
if self._start_connect is not None:
raise TimeoutStateError("Timeout timer has already been started.")
self._start_connect = current_time()
self._start_connect = time.monotonic()
return self._start_connect
def get_connect_duration(self):
def get_connect_duration(self) -> float:
"""Gets the time elapsed since the call to :meth:`start_connect`.
:return: Elapsed time in seconds.
@ -218,10 +226,10 @@ class Timeout(object):
raise TimeoutStateError(
"Can't get connect duration for timer that has not started."
)
return current_time() - self._start_connect
return time.monotonic() - self._start_connect
@property
def connect_timeout(self):
def connect_timeout(self) -> _TYPE_TIMEOUT:
"""Get the value to use when setting a connection timeout.
This will be a positive float or integer, the value None
@ -233,13 +241,13 @@ class Timeout(object):
if self.total is None:
return self._connect
if self._connect is None or self._connect is self.DEFAULT_TIMEOUT:
if self._connect is None or self._connect is _DEFAULT_TIMEOUT:
return self.total
return min(self._connect, self.total)
return min(self._connect, self.total) # type: ignore[type-var]
@property
def read_timeout(self):
def read_timeout(self) -> float | None:
"""Get the value for the read timeout.
This assumes some time has elapsed in the connection timeout and
@ -251,21 +259,21 @@ class Timeout(object):
raised.
:return: Value to use for the read timeout.
:rtype: int, float, :attr:`Timeout.DEFAULT_TIMEOUT` or None
:rtype: int, float or None
:raises urllib3.exceptions.TimeoutStateError: If :meth:`start_connect`
has not yet been called on this object.
"""
if (
self.total is not None
and self.total is not self.DEFAULT_TIMEOUT
and self.total is not _DEFAULT_TIMEOUT
and self._read is not None
and self._read is not self.DEFAULT_TIMEOUT
and self._read is not _DEFAULT_TIMEOUT
):
# In case the connect timeout has not yet been established.
if self._start_connect is None:
return self._read
return max(0, min(self.total - self.get_connect_duration(), self._read))
elif self.total is not None and self.total is not self.DEFAULT_TIMEOUT:
elif self.total is not None and self.total is not _DEFAULT_TIMEOUT:
return max(0, self.total - self.get_connect_duration())
else:
return self._read
return self.resolve_default_timeout(self._read)
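A small worked example of the budget math above, assuming a quarter-second connect phase:

import time

t = Timeout(total=5.0, read=10.0)
t.start_connect()
time.sleep(0.25)
# read_timeout == max(0, min(total - elapsed_connect, read)), so a little
# under 4.75 here; the total budget keeps shrinking as connect time elapses.
print(t.read_timeout)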
View file
@ -1,22 +1,20 @@
from __future__ import absolute_import
from __future__ import annotations
import re
from collections import namedtuple
import typing
from ..exceptions import LocationParseError
from ..packages import six
url_attrs = ["scheme", "auth", "host", "port", "path", "query", "fragment"]
from .util import to_str
# We only want to normalize urls with an HTTP(S) scheme.
# urllib3 infers URLs without a scheme (None) to be http.
NORMALIZABLE_SCHEMES = ("http", "https", None)
_NORMALIZABLE_SCHEMES = ("http", "https", None)
# Almost all of these patterns were derived from the
# 'rfc3986' module: https://github.com/python-hyper/rfc3986
PERCENT_RE = re.compile(r"%[a-fA-F0-9]{2}")
SCHEME_RE = re.compile(r"^(?:[a-zA-Z][a-zA-Z0-9+-]*:|/)")
URI_RE = re.compile(
_PERCENT_RE = re.compile(r"%[a-fA-F0-9]{2}")
_SCHEME_RE = re.compile(r"^(?:[a-zA-Z][a-zA-Z0-9+-]*:|/)")
_URI_RE = re.compile(
r"^(?:([a-zA-Z][a-zA-Z0-9+.-]*):)?"
r"(?://([^\\/?#]*))?"
r"([^?#]*)"
@ -25,10 +23,10 @@ URI_RE = re.compile(
re.UNICODE | re.DOTALL,
)
IPV4_PAT = r"(?:[0-9]{1,3}\.){3}[0-9]{1,3}"
HEX_PAT = "[0-9A-Fa-f]{1,4}"
LS32_PAT = "(?:{hex}:{hex}|{ipv4})".format(hex=HEX_PAT, ipv4=IPV4_PAT)
_subs = {"hex": HEX_PAT, "ls32": LS32_PAT}
_IPV4_PAT = r"(?:[0-9]{1,3}\.){3}[0-9]{1,3}"
_HEX_PAT = "[0-9A-Fa-f]{1,4}"
_LS32_PAT = "(?:{hex}:{hex}|{ipv4})".format(hex=_HEX_PAT, ipv4=_IPV4_PAT)
_subs = {"hex": _HEX_PAT, "ls32": _LS32_PAT}
_variations = [
# 6( h16 ":" ) ls32
"(?:%(hex)s:){6}%(ls32)s",
@ -50,69 +48,78 @@ _variations = [
"(?:(?:%(hex)s:){0,6}%(hex)s)?::",
]
UNRESERVED_PAT = r"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789._\-~"
IPV6_PAT = "(?:" + "|".join([x % _subs for x in _variations]) + ")"
ZONE_ID_PAT = "(?:%25|%)(?:[" + UNRESERVED_PAT + "]|%[a-fA-F0-9]{2})+"
IPV6_ADDRZ_PAT = r"\[" + IPV6_PAT + r"(?:" + ZONE_ID_PAT + r")?\]"
REG_NAME_PAT = r"(?:[^\[\]%:/?#]|%[a-fA-F0-9]{2})*"
TARGET_RE = re.compile(r"^(/[^?#]*)(?:\?([^#]*))?(?:#.*)?$")
_UNRESERVED_PAT = r"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789._\-~"
_IPV6_PAT = "(?:" + "|".join([x % _subs for x in _variations]) + ")"
_ZONE_ID_PAT = "(?:%25|%)(?:[" + _UNRESERVED_PAT + "]|%[a-fA-F0-9]{2})+"
_IPV6_ADDRZ_PAT = r"\[" + _IPV6_PAT + r"(?:" + _ZONE_ID_PAT + r")?\]"
_REG_NAME_PAT = r"(?:[^\[\]%:/?#]|%[a-fA-F0-9]{2})*"
_TARGET_RE = re.compile(r"^(/[^?#]*)(?:\?([^#]*))?(?:#.*)?$")
IPV4_RE = re.compile("^" + IPV4_PAT + "$")
IPV6_RE = re.compile("^" + IPV6_PAT + "$")
IPV6_ADDRZ_RE = re.compile("^" + IPV6_ADDRZ_PAT + "$")
BRACELESS_IPV6_ADDRZ_RE = re.compile("^" + IPV6_ADDRZ_PAT[2:-2] + "$")
ZONE_ID_RE = re.compile("(" + ZONE_ID_PAT + r")\]$")
_IPV4_RE = re.compile("^" + _IPV4_PAT + "$")
_IPV6_RE = re.compile("^" + _IPV6_PAT + "$")
_IPV6_ADDRZ_RE = re.compile("^" + _IPV6_ADDRZ_PAT + "$")
_BRACELESS_IPV6_ADDRZ_RE = re.compile("^" + _IPV6_ADDRZ_PAT[2:-2] + "$")
_ZONE_ID_RE = re.compile("(" + _ZONE_ID_PAT + r")\]$")
_HOST_PORT_PAT = ("^(%s|%s|%s)(?::0*?(|0|[1-9][0-9]{0,4}))?$") % (
REG_NAME_PAT,
IPV4_PAT,
IPV6_ADDRZ_PAT,
_REG_NAME_PAT,
_IPV4_PAT,
_IPV6_ADDRZ_PAT,
)
_HOST_PORT_RE = re.compile(_HOST_PORT_PAT, re.UNICODE | re.DOTALL)
UNRESERVED_CHARS = set(
_UNRESERVED_CHARS = set(
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789._-~"
)
SUB_DELIM_CHARS = set("!$&'()*+,;=")
USERINFO_CHARS = UNRESERVED_CHARS | SUB_DELIM_CHARS | {":"}
PATH_CHARS = USERINFO_CHARS | {"@", "/"}
QUERY_CHARS = FRAGMENT_CHARS = PATH_CHARS | {"?"}
_SUB_DELIM_CHARS = set("!$&'()*+,;=")
_USERINFO_CHARS = _UNRESERVED_CHARS | _SUB_DELIM_CHARS | {":"}
_PATH_CHARS = _USERINFO_CHARS | {"@", "/"}
_QUERY_CHARS = _FRAGMENT_CHARS = _PATH_CHARS | {"?"}
class Url(namedtuple("Url", url_attrs)):
class Url(
typing.NamedTuple(
"Url",
[
("scheme", typing.Optional[str]),
("auth", typing.Optional[str]),
("host", typing.Optional[str]),
("port", typing.Optional[int]),
("path", typing.Optional[str]),
("query", typing.Optional[str]),
("fragment", typing.Optional[str]),
],
)
):
"""
Data structure for representing an HTTP URL. Used as a return value for
:func:`parse_url`. Both the scheme and host are normalized as they are
both case-insensitive according to RFC 3986.
"""
__slots__ = ()
def __new__(
def __new__( # type: ignore[no-untyped-def]
cls,
scheme=None,
auth=None,
host=None,
port=None,
path=None,
query=None,
fragment=None,
scheme: str | None = None,
auth: str | None = None,
host: str | None = None,
port: int | None = None,
path: str | None = None,
query: str | None = None,
fragment: str | None = None,
):
if path and not path.startswith("/"):
path = "/" + path
if scheme is not None:
scheme = scheme.lower()
return super(Url, cls).__new__(
cls, scheme, auth, host, port, path, query, fragment
)
return super().__new__(cls, scheme, auth, host, port, path, query, fragment)
@property
def hostname(self):
def hostname(self) -> str | None:
"""For backwards-compatibility with urlparse. We're nice like that."""
return self.host
@property
def request_uri(self):
def request_uri(self) -> str:
"""Absolute path including the query string."""
uri = self.path or "/"
@ -122,14 +129,37 @@ class Url(namedtuple("Url", url_attrs)):
return uri
@property
def netloc(self):
"""Network location including host and port"""
def authority(self) -> str | None:
"""
Authority component as defined in RFC 3986 3.2.
This includes userinfo (auth), host and port.
i.e.
userinfo@host:port
"""
userinfo = self.auth
netloc = self.netloc
if netloc is None or userinfo is None:
return netloc
else:
return f"{userinfo}@{netloc}"
@property
def netloc(self) -> str | None:
"""
Network location including host and port.
If you need the equivalent of urllib.parse's ``netloc``,
use the ``authority`` property instead.
"""
if self.host is None:
return None
if self.port:
return "%s:%d" % (self.host, self.port)
return f"{self.host}:{self.port}"
return self.host
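A minimal sketch of the distinction, assuming parse_url from later in this file:

u = parse_url("https://user:pass@example.com:8443/index")
print(u.netloc)     # 'example.com:8443' -- host and port only
print(u.authority)  # 'user:pass@example.com:8443' -- adds the userinfo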
@property
def url(self):
def url(self) -> str:
"""
Convert self into a url
@ -138,88 +168,77 @@ class Url(namedtuple("Url", url_attrs)):
:func:`.parse_url`, but it should be equivalent by the RFC (e.g., urls
with a blank port will have : removed).
Example: ::
Example:
>>> U = parse_url('http://google.com/mail/')
>>> U.url
'http://google.com/mail/'
>>> Url('http', 'username:password', 'host.com', 80,
... '/path', 'query', 'fragment').url
'http://username:password@host.com:80/path?query#fragment'
.. code-block:: python
import urllib3
U = urllib3.util.parse_url("https://google.com/mail/")
print(U.url)
# "https://google.com/mail/"
print(urllib3.util.Url("https", "username:password",
"host.com", 80, "/path", "query", "fragment"
).url
)
# "https://username:password@host.com:80/path?query#fragment"
"""
scheme, auth, host, port, path, query, fragment = self
url = u""
url = ""
# We use "is not None" we want things to happen with empty strings (or 0 port)
if scheme is not None:
url += scheme + u"://"
url += scheme + "://"
if auth is not None:
url += auth + u"@"
url += auth + "@"
if host is not None:
url += host
if port is not None:
url += u":" + str(port)
url += ":" + str(port)
if path is not None:
url += path
if query is not None:
url += u"?" + query
url += "?" + query
if fragment is not None:
url += u"#" + fragment
url += "#" + fragment
return url
def __str__(self):
def __str__(self) -> str:
return self.url
def split_first(s, delims):
"""
.. deprecated:: 1.25
Given a string and an iterable of delimiters, split on the first found
delimiter. Return two split parts and the matched delimiter.
If not found, then the first part is the full input string.
Example::
>>> split_first('foo/bar?baz', '?/=')
('foo', 'bar?baz', '/')
>>> split_first('foo/bar?baz', '123')
('foo/bar?baz', '', None)
Scales linearly with number of delims. Not ideal for large number of delims.
"""
min_idx = None
min_delim = None
for d in delims:
idx = s.find(d)
if idx < 0:
continue
if min_idx is None or idx < min_idx:
min_idx = idx
min_delim = d
if min_idx is None or min_idx < 0:
return s, "", None
return s[:min_idx], s[min_idx + 1 :], min_delim
@typing.overload
def _encode_invalid_chars(
component: str, allowed_chars: typing.Container[str]
) -> str: # Abstract
...
def _encode_invalid_chars(component, allowed_chars, encoding="utf-8"):
@typing.overload
def _encode_invalid_chars(
component: None, allowed_chars: typing.Container[str]
) -> None: # Abstract
...
def _encode_invalid_chars(
component: str | None, allowed_chars: typing.Container[str]
) -> str | None:
"""Percent-encodes a URI component without reapplying
onto an already percent-encoded component.
"""
if component is None:
return component
component = six.ensure_text(component)
component = to_str(component)
# Normalize existing percent-encoded bytes.
# Try to see if the component we're encoding is already percent-encoded
# so we can skip all '%' characters but still encode all others.
component, percent_encodings = PERCENT_RE.subn(
component, percent_encodings = _PERCENT_RE.subn(
lambda match: match.group(0).upper(), component
)
@ -228,7 +247,7 @@ def _encode_invalid_chars(component, allowed_chars, encoding="utf-8"):
encoded_component = bytearray()
for i in range(0, len(uri_bytes)):
# Will return a single character bytestring on both Python 2 & 3
# Will return a single character bytestring
byte = uri_bytes[i : i + 1]
byte_ord = ord(byte)
if (is_percent_encoded and byte == b"%") or (
@ -238,10 +257,10 @@ def _encode_invalid_chars(component, allowed_chars, encoding="utf-8"):
continue
encoded_component.extend(b"%" + (hex(byte_ord)[2:].encode().zfill(2).upper()))
return encoded_component.decode(encoding)
return encoded_component.decode()
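Illustrative behavior, assuming the _PATH_CHARS set defined above: already-encoded sequences are normalized, not double-encoded:

print(_encode_invalid_chars("path with spaces", _PATH_CHARS))  # 'path%20with%20spaces'
print(_encode_invalid_chars("%2fa%2fb", _PATH_CHARS))          # '%2Fa%2Fb'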
def _remove_path_dot_segments(path):
def _remove_path_dot_segments(path: str) -> str:
# See http://tools.ietf.org/html/rfc3986#section-5.2.4 for pseudo-code
segments = path.split("/") # Turn the path into a list of segments
output = [] # Initialize the variable to use to store output
@ -251,7 +270,7 @@ def _remove_path_dot_segments(path):
if segment == ".":
continue
# Anything other than '..', should be appended to the output
elif segment != "..":
if segment != "..":
output.append(segment)
# In this case segment == '..', if we can, we should pop the last
# element
@ -271,18 +290,25 @@ def _remove_path_dot_segments(path):
return "/".join(output)
def _normalize_host(host, scheme):
if host:
if isinstance(host, six.binary_type):
host = six.ensure_str(host)
@typing.overload
def _normalize_host(host: None, scheme: str | None) -> None:
...
if scheme in NORMALIZABLE_SCHEMES:
is_ipv6 = IPV6_ADDRZ_RE.match(host)
@typing.overload
def _normalize_host(host: str, scheme: str | None) -> str:
...
def _normalize_host(host: str | None, scheme: str | None) -> str | None:
if host:
if scheme in _NORMALIZABLE_SCHEMES:
is_ipv6 = _IPV6_ADDRZ_RE.match(host)
if is_ipv6:
# IPv6 hosts of the form 'a::b%zone' are encoded in a URL as
# such per RFC 6874: 'a::b%25zone'. Unquote the ZoneID
# separator as necessary to return a valid RFC 4007 scoped IP.
match = ZONE_ID_RE.search(host)
match = _ZONE_ID_RE.search(host)
if match:
start, end = match.span(1)
zone_id = host[start:end]
@ -291,46 +317,56 @@ def _normalize_host(host, scheme):
zone_id = zone_id[3:]
else:
zone_id = zone_id[1:]
zone_id = "%" + _encode_invalid_chars(zone_id, UNRESERVED_CHARS)
return host[:start].lower() + zone_id + host[end:]
zone_id = _encode_invalid_chars(zone_id, _UNRESERVED_CHARS)
return f"{host[:start].lower()}%{zone_id}{host[end:]}"
else:
return host.lower()
elif not IPV4_RE.match(host):
return six.ensure_str(
b".".join([_idna_encode(label) for label in host.split(".")])
elif not _IPV4_RE.match(host):
return to_str(
b".".join([_idna_encode(label) for label in host.split(".")]),
"ascii",
)
return host
def _idna_encode(name):
if name and any(ord(x) >= 128 for x in name):
def _idna_encode(name: str) -> bytes:
if not name.isascii():
try:
import idna
except ImportError:
six.raise_from(
LocationParseError("Unable to parse URL without the 'idna' module"),
None,
)
raise LocationParseError(
"Unable to parse URL without the 'idna' module"
) from None
try:
return idna.encode(name.lower(), strict=True, std3_rules=True)
except idna.IDNAError:
six.raise_from(
LocationParseError(u"Name '%s' is not a valid IDNA label" % name), None
)
raise LocationParseError(
f"Name '{name}' is not a valid IDNA label"
) from None
return name.lower().encode("ascii")
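Illustrative calls; the non-ASCII case assumes the optional 'idna' package is installed:

print(_idna_encode("EXAMPLE"))  # b'example' -- pure-ASCII labels are just lower-cased
print(_idna_encode("bücher"))   # b'xn--bcher-kva'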
def _encode_target(target):
"""Percent-encodes a request target so that there are no invalid characters"""
path, query = TARGET_RE.match(target).groups()
target = _encode_invalid_chars(path, PATH_CHARS)
query = _encode_invalid_chars(query, QUERY_CHARS)
def _encode_target(target: str) -> str:
"""Percent-encodes a request target so that there are no invalid characters
Pre-condition for this function is that 'target' must start with '/'.
If that is the case then _TARGET_RE will always produce a match.
"""
match = _TARGET_RE.match(target)
if not match: # Defensive:
raise LocationParseError(f"{target!r} is not a valid request URI")
path, query = match.groups()
encoded_target = _encode_invalid_chars(path, _PATH_CHARS)
if query is not None:
target += "?" + query
return target
query = _encode_invalid_chars(query, _QUERY_CHARS)
encoded_target += "?" + query
return encoded_target
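A minimal example, assuming a target that satisfies the '/' pre-condition:

print(_encode_target("/search?q=hello world"))  # '/search?q=hello%20world'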
def parse_url(url):
def parse_url(url: str) -> Url:
"""
Given a url, return a parsed :class:`.Url` namedtuple. Best-effort is
performed to parse incomplete urls. Fields not provided will be None.
@ -341,28 +377,44 @@ def parse_url(url):
:param str url: URL to parse into a :class:`.Url` namedtuple.
Partly backwards-compatible with :mod:`urlparse`.
Partly backwards-compatible with :mod:`urllib.parse`.
Example::
Example:
>>> parse_url('http://google.com/mail/')
Url(scheme='http', host='google.com', port=None, path='/mail/', ...)
>>> parse_url('google.com:80')
Url(scheme=None, host='google.com', port=80, path=None, ...)
>>> parse_url('/foo?bar')
Url(scheme=None, host=None, port=None, path='/foo', query='bar', ...)
.. code-block:: python
import urllib3
print(urllib3.util.parse_url('http://google.com/mail/'))
# Url(scheme='http', host='google.com', port=None, path='/mail/', ...)
print(urllib3.util.parse_url('google.com:80'))
# Url(scheme=None, host='google.com', port=80, path=None, ...)
print(urllib3.util.parse_url('/foo?bar'))
# Url(scheme=None, host=None, port=None, path='/foo', query='bar', ...)
"""
if not url:
# Empty
return Url()
source_url = url
if not SCHEME_RE.search(url):
if not _SCHEME_RE.search(url):
url = "//" + url
scheme: str | None
authority: str | None
auth: str | None
host: str | None
port: str | None
port_int: int | None
path: str | None
query: str | None
fragment: str | None
try:
scheme, authority, path, query, fragment = URI_RE.match(url).groups()
normalize_uri = scheme is None or scheme.lower() in NORMALIZABLE_SCHEMES
scheme, authority, path, query, fragment = _URI_RE.match(url).groups() # type: ignore[union-attr]
normalize_uri = scheme is None or scheme.lower() in _NORMALIZABLE_SCHEMES
if scheme:
scheme = scheme.lower()
@ -370,31 +422,33 @@ def parse_url(url):
if authority:
auth, _, host_port = authority.rpartition("@")
auth = auth or None
host, port = _HOST_PORT_RE.match(host_port).groups()
host, port = _HOST_PORT_RE.match(host_port).groups() # type: ignore[union-attr]
if auth and normalize_uri:
auth = _encode_invalid_chars(auth, USERINFO_CHARS)
auth = _encode_invalid_chars(auth, _USERINFO_CHARS)
if port == "":
port = None
else:
auth, host, port = None, None, None
if port is not None:
port = int(port)
if not (0 <= port <= 65535):
port_int = int(port)
if not (0 <= port_int <= 65535):
raise LocationParseError(url)
else:
port_int = None
host = _normalize_host(host, scheme)
if normalize_uri and path:
path = _remove_path_dot_segments(path)
path = _encode_invalid_chars(path, PATH_CHARS)
path = _encode_invalid_chars(path, _PATH_CHARS)
if normalize_uri and query:
query = _encode_invalid_chars(query, QUERY_CHARS)
query = _encode_invalid_chars(query, _QUERY_CHARS)
if normalize_uri and fragment:
fragment = _encode_invalid_chars(fragment, FRAGMENT_CHARS)
fragment = _encode_invalid_chars(fragment, _FRAGMENT_CHARS)
except (ValueError, AttributeError):
return six.raise_from(LocationParseError(source_url), None)
except (ValueError, AttributeError) as e:
raise LocationParseError(source_url) from e
# For the sake of backwards compatibility we put empty
# string values for path if there are any defined values
@ -406,30 +460,12 @@ def parse_url(url):
else:
path = None
# Ensure that each part of the URL is a `str` for
# backwards compatibility.
if isinstance(url, six.text_type):
ensure_func = six.ensure_text
else:
ensure_func = six.ensure_str
def ensure_type(x):
return x if x is None else ensure_func(x)
return Url(
scheme=ensure_type(scheme),
auth=ensure_type(auth),
host=ensure_type(host),
port=port,
path=ensure_type(path),
query=ensure_type(query),
fragment=ensure_type(fragment),
scheme=scheme,
auth=auth,
host=host,
port=port_int,
path=path,
query=query,
fragment=fragment,
)
def get_host(url):
"""
Deprecated. Use :func:`parse_url` instead.
"""
p = parse_url(url)
return p.scheme or "http", p.hostname, p.port
lib/urllib3/util/util.py
View file
@ -0,0 +1,42 @@
from __future__ import annotations
import typing
from types import TracebackType
def to_bytes(
x: str | bytes, encoding: str | None = None, errors: str | None = None
) -> bytes:
if isinstance(x, bytes):
return x
elif not isinstance(x, str):
raise TypeError(f"not expecting type {type(x).__name__}")
if encoding or errors:
return x.encode(encoding or "utf-8", errors=errors or "strict")
return x.encode()
def to_str(
x: str | bytes, encoding: str | None = None, errors: str | None = None
) -> str:
if isinstance(x, str):
return x
elif not isinstance(x, bytes):
raise TypeError(f"not expecting type {type(x).__name__}")
if encoding or errors:
return x.decode(encoding or "utf-8", errors=errors or "strict")
return x.decode()
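A minimal round-trip sketch of the helpers that replace the old six.ensure_* calls:

print(to_bytes("héllo"))                  # b'h\xc3\xa9llo' -- UTF-8 by default
print(to_str(b"h\xc3\xa9llo"))            # 'héllo'
print(to_str(b"\xff", errors="replace"))  # '\ufffd'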
def reraise(
tp: type[BaseException] | None,
value: BaseException,
tb: TracebackType | None = None,
) -> typing.NoReturn:
try:
if value.__traceback__ is not tb:
raise value.with_traceback(tb)
raise value
finally:
value = None # type: ignore[assignment]
tb = None
View file
@ -1,18 +1,10 @@
import errno
from __future__ import annotations
import select
import sys
import socket
from functools import partial
try:
from time import monotonic
except ImportError:
from time import time as monotonic
__all__ = ["NoWayToWaitForSocketError", "wait_for_read", "wait_for_write"]
class NoWayToWaitForSocketError(Exception):
pass
__all__ = ["wait_for_read", "wait_for_write"]
# How should we wait on sockets?
@ -37,37 +29,13 @@ class NoWayToWaitForSocketError(Exception):
# So: on Windows we use select(), and everywhere else we use poll(). We also
# fall back to select() in case poll() is somehow broken or missing.
if sys.version_info >= (3, 5):
# Modern Python, that retries syscalls by default
def _retry_on_intr(fn, timeout):
return fn(timeout)
else:
# Old and broken Pythons.
def _retry_on_intr(fn, timeout):
if timeout is None:
deadline = float("inf")
else:
deadline = monotonic() + timeout
while True:
try:
return fn(timeout)
# OSError for 3 <= pyver < 3.5, select.error for pyver <= 2.7
except (OSError, select.error) as e:
# 'e.args[0]' incantation works for both OSError and select.error
if e.args[0] != errno.EINTR:
raise
else:
timeout = deadline - monotonic()
if timeout < 0:
timeout = 0
if timeout == float("inf"):
timeout = None
continue
def select_wait_for_socket(sock, read=False, write=False, timeout=None):
def select_wait_for_socket(
sock: socket.socket,
read: bool = False,
write: bool = False,
timeout: float | None = None,
) -> bool:
if not read and not write:
raise RuntimeError("must specify at least one of read=True, write=True")
rcheck = []
@ -82,11 +50,16 @@ def select_wait_for_socket(sock, read=False, write=False, timeout=None):
# sockets for both conditions. (The stdlib selectors module does the same
# thing.)
fn = partial(select.select, rcheck, wcheck, wcheck)
rready, wready, xready = _retry_on_intr(fn, timeout)
rready, wready, xready = fn(timeout)
return bool(rready or wready or xready)
def poll_wait_for_socket(sock, read=False, write=False, timeout=None):
def poll_wait_for_socket(
sock: socket.socket,
read: bool = False,
write: bool = False,
timeout: float | None = None,
) -> bool:
if not read and not write:
raise RuntimeError("must specify at least one of read=True, write=True")
mask = 0
@ -98,32 +71,33 @@ def poll_wait_for_socket(sock, read=False, write=False, timeout=None):
poll_obj.register(sock, mask)
# For some reason, poll() takes timeout in milliseconds
def do_poll(t):
def do_poll(t: float | None) -> list[tuple[int, int]]:
if t is not None:
t *= 1000
return poll_obj.poll(t)
return bool(_retry_on_intr(do_poll, timeout))
return bool(do_poll(timeout))
def null_wait_for_socket(*args, **kwargs):
raise NoWayToWaitForSocketError("no select-equivalent available")
def _have_working_poll():
def _have_working_poll() -> bool:
# Apparently some systems have a select.poll that fails as soon as you try
# to use it, either due to strange configuration or broken monkeypatching
# from libraries like eventlet/greenlet.
try:
poll_obj = select.poll()
_retry_on_intr(poll_obj.poll, 0)
poll_obj.poll(0)
except (AttributeError, OSError):
return False
else:
return True
def wait_for_socket(*args, **kwargs):
def wait_for_socket(
sock: socket.socket,
read: bool = False,
write: bool = False,
timeout: float | None = None,
) -> bool:
# We delay choosing which implementation to use until the first time we're
# called. We could do it at import time, but then we might make the wrong
# decision if someone goes wild with monkeypatching select.poll after
@ -133,19 +107,17 @@ def wait_for_socket(*args, **kwargs):
wait_for_socket = poll_wait_for_socket
elif hasattr(select, "select"):
wait_for_socket = select_wait_for_socket
else: # Platform-specific: Appengine.
wait_for_socket = null_wait_for_socket
return wait_for_socket(*args, **kwargs)
return wait_for_socket(sock, read, write, timeout)
def wait_for_read(sock, timeout=None):
def wait_for_read(sock: socket.socket, timeout: float | None = None) -> bool:
"""Waits for reading to be available on a given socket.
Returns True if the socket is readable, or False if the timeout expired.
"""
return wait_for_socket(sock, read=True, timeout=timeout)
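A minimal self-contained sketch, assuming socket.socketpair() is available (Python 3.5+ on all platforms):

import socket

a, b = socket.socketpair()
print(wait_for_read(a, timeout=0))    # False -- nothing buffered yet
b.sendall(b"x")
print(wait_for_read(a, timeout=1.0))  # True -- data is waiting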
def wait_for_write(sock, timeout=None):
def wait_for_write(sock: socket.socket, timeout: float | None = None) -> bool:
"""Waits for writing to be available on a given socket.
Returns True if the socket is readable, or False if the timeout expired.
"""