Merge branch 'nightly' into dependabot/pip/nightly/plexapi-4.15.0

This commit is contained in:
JonnyWong16 2023-08-24 12:10:47 -07:00 committed by GitHub
commit aeca9a3445
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
156 changed files with 5486 additions and 3067 deletions

View file

@ -1,10 +1,15 @@
from pkg_resources import get_distribution, DistributionNotFound import sys
if sys.version_info >= (3, 8):
import importlib.metadata as importlib_metadata
else:
import importlib_metadata
try: try:
release = get_distribution('APScheduler').version.split('-')[0] release = importlib_metadata.version('APScheduler').split('-')[0]
except DistributionNotFound: except importlib_metadata.PackageNotFoundError:
release = '3.5.0' release = '3.5.0'
version_info = tuple(int(x) if x.isdigit() else x for x in release.split('.')) version_info = tuple(int(x) if x.isdigit() else x for x in release.split('.'))
version = __version__ = '.'.join(str(x) for x in version_info[:3]) version = __version__ = '.'.join(str(x) for x in version_info[:3])
del get_distribution, DistributionNotFound del sys, importlib_metadata

View file

@ -7,7 +7,6 @@ from logging import getLogger
import warnings import warnings
import sys import sys
from pkg_resources import iter_entry_points
from tzlocal import get_localzone from tzlocal import get_localzone
import six import six
@ -31,6 +30,11 @@ try:
except ImportError: except ImportError:
from collections import MutableMapping from collections import MutableMapping
try:
from importlib.metadata import entry_points
except ModuleNotFoundError:
from importlib_metadata import entry_points
#: constant indicating a scheduler's stopped state #: constant indicating a scheduler's stopped state
STATE_STOPPED = 0 STATE_STOPPED = 0
#: constant indicating a scheduler's running state (started and processing jobs) #: constant indicating a scheduler's running state (started and processing jobs)
@ -62,12 +66,18 @@ class BaseScheduler(six.with_metaclass(ABCMeta)):
.. seealso:: :ref:`scheduler-config` .. seealso:: :ref:`scheduler-config`
""" """
# The `group=...` API is only available in the backport, used in <=3.7, and in std>=3.10.
if (3, 8) <= sys.version_info < (3, 10):
_trigger_plugins = {ep.name: ep for ep in entry_points()['apscheduler.triggers']}
_executor_plugins = {ep.name: ep for ep in entry_points()['apscheduler.executors']}
_jobstore_plugins = {ep.name: ep for ep in entry_points()['apscheduler.jobstores']}
else:
_trigger_plugins = {ep.name: ep for ep in entry_points(group='apscheduler.triggers')}
_executor_plugins = {ep.name: ep for ep in entry_points(group='apscheduler.executors')}
_jobstore_plugins = {ep.name: ep for ep in entry_points(group='apscheduler.jobstores')}
_trigger_plugins = dict((ep.name, ep) for ep in iter_entry_points('apscheduler.triggers'))
_trigger_classes = {} _trigger_classes = {}
_executor_plugins = dict((ep.name, ep) for ep in iter_entry_points('apscheduler.executors'))
_executor_classes = {} _executor_classes = {}
_jobstore_plugins = dict((ep.name, ep) for ep in iter_entry_points('apscheduler.jobstores'))
_jobstore_classes = {} _jobstore_classes = {}
# #
@ -1019,6 +1029,7 @@ class BaseScheduler(six.with_metaclass(ABCMeta)):
wait_seconds = None wait_seconds = None
self._logger.debug('No jobs; waiting until a job is added') self._logger.debug('No jobs; waiting until a job is added')
else: else:
now = datetime.now(self.timezone)
wait_seconds = min(max(timedelta_seconds(next_wakeup_time - now), 0), TIMEOUT_MAX) wait_seconds = min(max(timedelta_seconds(next_wakeup_time - now), 0), TIMEOUT_MAX)
self._logger.debug('Next wakeup is due at %s (in %f seconds)', next_wakeup_time, self._logger.debug('Next wakeup is due at %s (in %f seconds)', next_wakeup_time,
wait_seconds) wait_seconds)

View file

@ -1,24 +1,22 @@
from __future__ import absolute_import from __future__ import absolute_import
from importlib import import_module
from itertools import product
from apscheduler.schedulers.base import BaseScheduler from apscheduler.schedulers.base import BaseScheduler
try: for version, pkgname in product(range(6, 1, -1), ("PySide", "PyQt")):
from PyQt5.QtCore import QObject, QTimer
except (ImportError, RuntimeError): # pragma: nocover
try: try:
from PyQt4.QtCore import QObject, QTimer qtcore = import_module(pkgname + str(version) + ".QtCore")
except ImportError: except ImportError:
try: pass
from PySide6.QtCore import QObject, QTimer # noqa else:
except ImportError: QTimer = qtcore.QTimer
try: break
from PySide2.QtCore import QObject, QTimer # noqa else:
except ImportError: raise ImportError(
try: "QtScheduler requires either PySide/PyQt (v6 to v2) installed"
from PySide.QtCore import QObject, QTimer # noqa )
except ImportError:
raise ImportError('QtScheduler requires either PyQt5, PyQt4, PySide6, PySide2 '
'or PySide installed')
class QtScheduler(BaseScheduler): class QtScheduler(BaseScheduler):

View file

@ -6,7 +6,7 @@ from asyncio import iscoroutinefunction
from datetime import date, datetime, time, timedelta, tzinfo from datetime import date, datetime, time, timedelta, tzinfo
from calendar import timegm from calendar import timegm
from functools import partial from functools import partial
from inspect import isclass, ismethod from inspect import isbuiltin, isclass, isfunction, ismethod
import re import re
import sys import sys
@ -214,28 +214,15 @@ def get_callable_name(func):
:rtype: str :rtype: str
""" """
# the easy case (on Python 3.3+) if ismethod(func):
if hasattr(func, '__qualname__'): self = func.__self__
cls = self if isclass(self) else type(self)
return f"{cls.__qualname__}.{func.__name__}"
elif isclass(func) or isfunction(func) or isbuiltin(func):
return func.__qualname__ return func.__qualname__
elif hasattr(func, '__call__') and callable(func.__call__):
# class methods, bound and unbound methods
f_self = getattr(func, '__self__', None) or getattr(func, 'im_self', None)
if f_self and hasattr(func, '__name__'):
f_class = f_self if isclass(f_self) else f_self.__class__
else:
f_class = getattr(func, 'im_class', None)
if f_class and hasattr(func, '__name__'):
return '%s.%s' % (f_class.__name__, func.__name__)
# class or class instance
if hasattr(func, '__call__'):
# class
if hasattr(func, '__name__'):
return func.__name__
# instance of a class with a __call__ method # instance of a class with a __call__ method
return func.__class__.__name__ return type(func).__qualname__
raise TypeError('Unable to determine a name for %r -- maybe it is not a callable?' % func) raise TypeError('Unable to determine a name for %r -- maybe it is not a callable?' % func)
@ -260,16 +247,10 @@ def obj_to_ref(obj):
raise ValueError('Cannot create a reference to a nested function') raise ValueError('Cannot create a reference to a nested function')
if ismethod(obj): if ismethod(obj):
if hasattr(obj, 'im_self') and obj.im_self: module = obj.__self__.__module__
# bound method
module = obj.im_self.__module__
elif hasattr(obj, 'im_class') and obj.im_class:
# unbound method
module = obj.im_class.__module__
else:
module = obj.__module__
else: else:
module = obj.__module__ module = obj.__module__
return '%s:%s' % (module, name) return '%s:%s' % (module, name)

View file

@ -1,5 +1 @@
# A Python "namespace package" http://www.python.org/dev/peps/pep-0382/ __path__ = __import__('pkgutil').extend_path(__path__, __name__) # type: ignore
# This always goes inside of a namespace package's __init__.py
from pkgutil import extend_path
__path__ = extend_path(__path__, __name__) # type: ignore

View file

@ -89,7 +89,6 @@ def lru_cache(maxsize=100, typed=False): # noqa: C901
# to allow the implementation to change (including a possible C version). # to allow the implementation to change (including a possible C version).
def decorating_function(user_function): def decorating_function(user_function):
cache = dict() cache = dict()
stats = [0, 0] # make statistics updateable non-locally stats = [0, 0] # make statistics updateable non-locally
HITS, MISSES = 0, 1 # names for the stats fields HITS, MISSES = 0, 1 # names for the stats fields

View file

@ -24,6 +24,7 @@ SYS_PLATFORM = platform.system()
IS_WINDOWS = SYS_PLATFORM == 'Windows' IS_WINDOWS = SYS_PLATFORM == 'Windows'
IS_LINUX = SYS_PLATFORM == 'Linux' IS_LINUX = SYS_PLATFORM == 'Linux'
IS_MACOS = SYS_PLATFORM == 'Darwin' IS_MACOS = SYS_PLATFORM == 'Darwin'
IS_SOLARIS = SYS_PLATFORM == 'SunOS'
PLATFORM_ARCH = platform.machine() PLATFORM_ARCH = platform.machine()
IS_PPC = PLATFORM_ARCH.startswith('ppc') IS_PPC = PLATFORM_ARCH.startswith('ppc')

View file

@ -10,6 +10,7 @@ SYS_PLATFORM: str
IS_WINDOWS: bool IS_WINDOWS: bool
IS_LINUX: bool IS_LINUX: bool
IS_MACOS: bool IS_MACOS: bool
IS_SOLARIS: bool
PLATFORM_ARCH: str PLATFORM_ARCH: str
IS_PPC: bool IS_PPC: bool

View file

@ -274,8 +274,7 @@ class ConnectionManager:
# One of the reason on why a socket could cause an error # One of the reason on why a socket could cause an error
# is that the socket is already closed, ignore the # is that the socket is already closed, ignore the
# socket error if we try to close it at this point. # socket error if we try to close it at this point.
# This is equivalent to OSError in Py3 with suppress(OSError):
with suppress(socket.error):
conn.close() conn.close()
def _from_server_socket(self, server_socket): # noqa: C901 # FIXME def _from_server_socket(self, server_socket): # noqa: C901 # FIXME
@ -308,7 +307,7 @@ class ConnectionManager:
wfile = mf(s, 'wb', io.DEFAULT_BUFFER_SIZE) wfile = mf(s, 'wb', io.DEFAULT_BUFFER_SIZE)
try: try:
wfile.write(''.join(buf).encode('ISO-8859-1')) wfile.write(''.join(buf).encode('ISO-8859-1'))
except socket.error as ex: except OSError as ex:
if ex.args[0] not in errors.socket_errors_to_ignore: if ex.args[0] not in errors.socket_errors_to_ignore:
raise raise
return return
@ -343,7 +342,7 @@ class ConnectionManager:
# notice keyboard interrupts on Win32, which don't interrupt # notice keyboard interrupts on Win32, which don't interrupt
# accept() by default # accept() by default
return return
except socket.error as ex: except OSError as ex:
if self.server.stats['Enabled']: if self.server.stats['Enabled']:
self.server.stats['Socket Errors'] += 1 self.server.stats['Socket Errors'] += 1
if ex.args[0] in errors.socket_error_eintr: if ex.args[0] in errors.socket_error_eintr:

View file

@ -77,9 +77,4 @@ Refs:
* https://docs.microsoft.com/windows/win32/api/winsock/nf-winsock-shutdown * https://docs.microsoft.com/windows/win32/api/winsock/nf-winsock-shutdown
""" """
try: # py3 acceptable_sock_shutdown_exceptions = (BrokenPipeError, ConnectionResetError)
acceptable_sock_shutdown_exceptions = (
BrokenPipeError, ConnectionResetError,
)
except NameError: # py2
acceptable_sock_shutdown_exceptions = ()

View file

@ -1572,6 +1572,9 @@ class HTTPServer:
``PEERCREDS``-provided IDs. ``PEERCREDS``-provided IDs.
""" """
reuse_port = False
"""If True, set SO_REUSEPORT on the socket."""
keep_alive_conn_limit = 10 keep_alive_conn_limit = 10
"""Maximum number of waiting keep-alive connections that will be kept open. """Maximum number of waiting keep-alive connections that will be kept open.
@ -1581,6 +1584,7 @@ class HTTPServer:
self, bind_addr, gateway, self, bind_addr, gateway,
minthreads=10, maxthreads=-1, server_name=None, minthreads=10, maxthreads=-1, server_name=None,
peercreds_enabled=False, peercreds_resolve_enabled=False, peercreds_enabled=False, peercreds_resolve_enabled=False,
reuse_port=False,
): ):
"""Initialize HTTPServer instance. """Initialize HTTPServer instance.
@ -1591,6 +1595,8 @@ class HTTPServer:
maxthreads (int): maximum number of threads for HTTP thread pool maxthreads (int): maximum number of threads for HTTP thread pool
server_name (str): web server name to be advertised via Server server_name (str): web server name to be advertised via Server
HTTP header HTTP header
reuse_port (bool): if True SO_REUSEPORT option would be set to
socket
""" """
self.bind_addr = bind_addr self.bind_addr = bind_addr
self.gateway = gateway self.gateway = gateway
@ -1606,6 +1612,7 @@ class HTTPServer:
self.peercreds_resolve_enabled = ( self.peercreds_resolve_enabled = (
peercreds_resolve_enabled and peercreds_enabled peercreds_resolve_enabled and peercreds_enabled
) )
self.reuse_port = reuse_port
self.clear_stats() self.clear_stats()
def clear_stats(self): def clear_stats(self):
@ -1880,6 +1887,7 @@ class HTTPServer:
self.bind_addr, self.bind_addr,
family, type, proto, family, type, proto,
self.nodelay, self.ssl_adapter, self.nodelay, self.ssl_adapter,
self.reuse_port,
) )
sock = self.socket = self.bind_socket(sock, self.bind_addr) sock = self.socket = self.bind_socket(sock, self.bind_addr)
self.bind_addr = self.resolve_real_bind_addr(sock) self.bind_addr = self.resolve_real_bind_addr(sock)
@ -1911,9 +1919,6 @@ class HTTPServer:
'remove() argument 1 must be encoded ' 'remove() argument 1 must be encoded '
'string without null bytes, not unicode' 'string without null bytes, not unicode'
not in err_msg not in err_msg
and 'embedded NUL character' not in err_msg # py34
and 'argument must be a '
'string without NUL characters' not in err_msg # pypy2
): ):
raise raise
except ValueError as val_err: except ValueError as val_err:
@ -1931,6 +1936,7 @@ class HTTPServer:
bind_addr=bind_addr, bind_addr=bind_addr,
family=socket.AF_UNIX, type=socket.SOCK_STREAM, proto=0, family=socket.AF_UNIX, type=socket.SOCK_STREAM, proto=0,
nodelay=self.nodelay, ssl_adapter=self.ssl_adapter, nodelay=self.nodelay, ssl_adapter=self.ssl_adapter,
reuse_port=self.reuse_port,
) )
try: try:
@ -1971,7 +1977,36 @@ class HTTPServer:
return sock return sock
@staticmethod @staticmethod
def prepare_socket(bind_addr, family, type, proto, nodelay, ssl_adapter): def _make_socket_reusable(socket_, bind_addr):
host, port = bind_addr[:2]
IS_EPHEMERAL_PORT = port == 0
if socket_.family not in (socket.AF_INET, socket.AF_INET6):
raise ValueError('Cannot reuse a non-IP socket')
if IS_EPHEMERAL_PORT:
raise ValueError('Cannot reuse an ephemeral port (0)')
# Most BSD kernels implement SO_REUSEPORT the way that only the
# latest listener can read from socket. Some of BSD kernels also
# have SO_REUSEPORT_LB that works similarly to SO_REUSEPORT
# in Linux.
if hasattr(socket, 'SO_REUSEPORT_LB'):
socket_.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT_LB, 1)
elif hasattr(socket, 'SO_REUSEPORT'):
socket_.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
elif IS_WINDOWS:
socket_.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
else:
raise NotImplementedError(
'Current platform does not support port reuse',
)
@classmethod
def prepare_socket(
cls, bind_addr, family, type, proto, nodelay, ssl_adapter,
reuse_port=False,
):
"""Create and prepare the socket object.""" """Create and prepare the socket object."""
sock = socket.socket(family, type, proto) sock = socket.socket(family, type, proto)
connections.prevent_socket_inheritance(sock) connections.prevent_socket_inheritance(sock)
@ -1979,6 +2014,9 @@ class HTTPServer:
host, port = bind_addr[:2] host, port = bind_addr[:2]
IS_EPHEMERAL_PORT = port == 0 IS_EPHEMERAL_PORT = port == 0
if reuse_port:
cls._make_socket_reusable(socket_=sock, bind_addr=bind_addr)
if not (IS_WINDOWS or IS_EPHEMERAL_PORT): if not (IS_WINDOWS or IS_EPHEMERAL_PORT):
"""Enable SO_REUSEADDR for the current socket. """Enable SO_REUSEADDR for the current socket.

View file

@ -130,9 +130,10 @@ class HTTPServer:
ssl_adapter: Any ssl_adapter: Any
peercreds_enabled: bool peercreds_enabled: bool
peercreds_resolve_enabled: bool peercreds_resolve_enabled: bool
reuse_port: bool
keep_alive_conn_limit: int keep_alive_conn_limit: int
requests: Any requests: Any
def __init__(self, bind_addr, gateway, minthreads: int = ..., maxthreads: int = ..., server_name: Any | None = ..., peercreds_enabled: bool = ..., peercreds_resolve_enabled: bool = ...) -> None: ... def __init__(self, bind_addr, gateway, minthreads: int = ..., maxthreads: int = ..., server_name: Any | None = ..., peercreds_enabled: bool = ..., peercreds_resolve_enabled: bool = ..., reuse_port: bool = ...) -> None: ...
stats: Any stats: Any
def clear_stats(self): ... def clear_stats(self): ...
def runtime(self): ... def runtime(self): ...
@ -152,7 +153,9 @@ class HTTPServer:
def bind(self, family, type, proto: int = ...): ... def bind(self, family, type, proto: int = ...): ...
def bind_unix_socket(self, bind_addr): ... def bind_unix_socket(self, bind_addr): ...
@staticmethod @staticmethod
def prepare_socket(bind_addr, family, type, proto, nodelay, ssl_adapter): ... def _make_socket_reusable(socket_, bind_addr) -> None: ...
@classmethod
def prepare_socket(cls, bind_addr, family, type, proto, nodelay, ssl_adapter, reuse_port: bool = ...): ...
@staticmethod @staticmethod
def bind_socket(socket_, bind_addr): ... def bind_socket(socket_, bind_addr): ...
@staticmethod @staticmethod

View file

@ -1,7 +1,7 @@
from abc import abstractmethod from abc import abstractmethod, ABCMeta
from typing import Any from typing import Any
class Adapter(): class Adapter(metaclass=ABCMeta):
certificate: Any certificate: Any
private_key: Any private_key: Any
certificate_chain: Any certificate_chain: Any

View file

@ -4,11 +4,7 @@ Contains hooks, which are tightly bound to the Cheroot framework
itself, useless for end-users' app testing. itself, useless for end-users' app testing.
""" """
from __future__ import absolute_import, division, print_function
__metaclass__ = type
import pytest import pytest
import six
pytest_version = tuple(map(int, pytest.__version__.split('.'))) pytest_version = tuple(map(int, pytest.__version__.split('.')))
@ -45,16 +41,3 @@ def pytest_load_initial_conftests(early_config, parser, args):
'type=SocketKind.SOCK_STREAM, proto=.:' 'type=SocketKind.SOCK_STREAM, proto=.:'
'pytest.PytestUnraisableExceptionWarning:_pytest.unraisableexception', 'pytest.PytestUnraisableExceptionWarning:_pytest.unraisableexception',
)) ))
if six.PY2:
return
# NOTE: `ResourceWarning` does not exist under Python 2 and so using
# NOTE: it in warning filters results in an `_OptionError` exception
# NOTE: being raised.
early_config._inicache['filterwarnings'].extend((
# FIXME: Try to figure out what causes this and ensure that the socket
# FIXME: gets closed.
'ignore:unclosed <socket.socket fd=:ResourceWarning',
'ignore:unclosed <ssl.SSLSocket fd=:ResourceWarning',
))

View file

@ -1218,8 +1218,7 @@ def test_No_CRLF(test_client, invalid_terminator):
# Initialize a persistent HTTP connection # Initialize a persistent HTTP connection
conn = test_client.get_connection() conn = test_client.get_connection()
# (b'%s' % b'') is not supported in Python 3.4, so just use bytes.join() conn.send(b'GET /hello HTTP/1.1%s' % invalid_terminator)
conn.send(b''.join((b'GET /hello HTTP/1.1', invalid_terminator)))
response = conn.response_class(conn.sock, method='GET') response = conn.response_class(conn.sock, method='GET')
response.begin() response.begin()
actual_resp_body = response.read() actual_resp_body = response.read()

View file

@ -69,11 +69,7 @@ class HelloController(helper.Controller):
def _get_http_response(connection, method='GET'): def _get_http_response(connection, method='GET'):
c = connection return connection.response_class(connection.sock, method=method)
kwargs = {'strict': c.strict} if hasattr(c, 'strict') else {}
# Python 3.2 removed the 'strict' feature, saying:
# "http.client now always assumes HTTP/1.x compliant servers."
return c.response_class(c.sock, method=method, **kwargs)
@pytest.fixture @pytest.fixture

View file

@ -4,7 +4,7 @@ import pytest
from cheroot import errors from cheroot import errors
from .._compat import IS_LINUX, IS_MACOS, IS_WINDOWS # noqa: WPS130 from .._compat import IS_LINUX, IS_MACOS, IS_SOLARIS, IS_WINDOWS # noqa: WPS130
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -18,6 +18,7 @@ from .._compat import IS_LINUX, IS_MACOS, IS_WINDOWS # noqa: WPS130
), ),
(91, 11, 32) if IS_LINUX else (91, 11, 32) if IS_LINUX else
(32, 35, 41) if IS_MACOS else (32, 35, 41) if IS_MACOS else
(98, 11, 32) if IS_SOLARIS else
(32, 10041, 11, 10035) if IS_WINDOWS else (32, 10041, 11, 10035) if IS_WINDOWS else
(), (),
), ),

View file

@ -5,6 +5,7 @@ import queue
import socket import socket
import tempfile import tempfile
import threading import threading
import types
import uuid import uuid
import urllib.parse # noqa: WPS301 import urllib.parse # noqa: WPS301
@ -17,6 +18,7 @@ from pypytools.gc.custom import DefaultGc
from .._compat import bton, ntob from .._compat import bton, ntob
from .._compat import IS_LINUX, IS_MACOS, IS_WINDOWS, SYS_PLATFORM from .._compat import IS_LINUX, IS_MACOS, IS_WINDOWS, SYS_PLATFORM
from ..server import IS_UID_GID_RESOLVABLE, Gateway, HTTPServer from ..server import IS_UID_GID_RESOLVABLE, Gateway, HTTPServer
from ..workers.threadpool import ThreadPool
from ..testing import ( from ..testing import (
ANY_INTERFACE_IPV4, ANY_INTERFACE_IPV4,
ANY_INTERFACE_IPV6, ANY_INTERFACE_IPV6,
@ -254,6 +256,7 @@ def peercreds_enabled_server(http_server, unix_sock_file):
@unix_only_sock_test @unix_only_sock_test
@non_macos_sock_test @non_macos_sock_test
@pytest.mark.flaky(reruns=3, reruns_delay=2)
def test_peercreds_unix_sock(http_request_timeout, peercreds_enabled_server): def test_peercreds_unix_sock(http_request_timeout, peercreds_enabled_server):
"""Check that ``PEERCRED`` lookup works when enabled.""" """Check that ``PEERCRED`` lookup works when enabled."""
httpserver = peercreds_enabled_server httpserver = peercreds_enabled_server
@ -370,6 +373,33 @@ def test_high_number_of_file_descriptors(native_server_client, resource_limit):
assert any(fn >= resource_limit for fn in native_process_conn.filenos) assert any(fn >= resource_limit for fn in native_process_conn.filenos)
@pytest.mark.skipif(
not hasattr(socket, 'SO_REUSEPORT'),
reason='socket.SO_REUSEPORT is not supported on this platform',
)
@pytest.mark.parametrize(
'ip_addr',
(
ANY_INTERFACE_IPV4,
ANY_INTERFACE_IPV6,
),
)
def test_reuse_port(http_server, ip_addr, mocker):
"""Check that port initialized externally can be reused."""
family = socket.getaddrinfo(ip_addr, EPHEMERAL_PORT)[0][0]
s = socket.socket(family)
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
s.bind((ip_addr, EPHEMERAL_PORT))
server = HTTPServer(
bind_addr=s.getsockname()[:2], gateway=Gateway, reuse_port=True,
)
spy = mocker.spy(server, 'prepare')
server.prepare()
server.stop()
s.close()
assert spy.spy_exception is None
ISSUE511 = IS_MACOS ISSUE511 = IS_MACOS
@ -439,3 +469,90 @@ def many_open_sockets(request, resource_limit):
# Close our open resources # Close our open resources
for test_socket in test_sockets: for test_socket in test_sockets:
test_socket.close() test_socket.close()
@pytest.mark.parametrize(
('minthreads', 'maxthreads', 'inited_maxthreads'),
(
(
# NOTE: The docstring only mentions -1 to mean "no max", but other
# NOTE: negative numbers should also work.
1,
-2,
float('inf'),
),
(1, -1, float('inf')),
(1, 1, 1),
(1, 2, 2),
(1, float('inf'), float('inf')),
(2, -2, float('inf')),
(2, -1, float('inf')),
(2, 2, 2),
(2, float('inf'), float('inf')),
),
)
def test_threadpool_threadrange_set(minthreads, maxthreads, inited_maxthreads):
"""Test setting the number of threads in a ThreadPool.
The ThreadPool should properly set the min+max number of the threads to use
in the pool if those limits are valid.
"""
tp = ThreadPool(
server=None,
min=minthreads,
max=maxthreads,
)
assert tp.min == minthreads
assert tp.max == inited_maxthreads
@pytest.mark.parametrize(
('minthreads', 'maxthreads', 'error'),
(
(-1, -1, 'min=-1 must be > 0'),
(-1, 0, 'min=-1 must be > 0'),
(-1, 1, 'min=-1 must be > 0'),
(-1, 2, 'min=-1 must be > 0'),
(0, -1, 'min=0 must be > 0'),
(0, 0, 'min=0 must be > 0'),
(0, 1, 'min=0 must be > 0'),
(0, 2, 'min=0 must be > 0'),
(1, 0, 'Expected an integer or the infinity value for the `max` argument but got 0.'),
(1, 0.5, 'Expected an integer or the infinity value for the `max` argument but got 0.5.'),
(2, 0, 'Expected an integer or the infinity value for the `max` argument but got 0.'),
(2, '1', "Expected an integer or the infinity value for the `max` argument but got '1'."),
(2, 1, 'max=1 must be > min=2'),
),
)
def test_threadpool_invalid_threadrange(minthreads, maxthreads, error):
"""Test that a ThreadPool rejects invalid min/max values.
The ThreadPool should raise an error with the proper message when
initialized with an invalid min+max number of threads.
"""
with pytest.raises((ValueError, TypeError), match=error):
ThreadPool(
server=None,
min=minthreads,
max=maxthreads,
)
def test_threadpool_multistart_validation(monkeypatch):
"""Test for ThreadPool multi-start behavior.
Tests that when calling start() on a ThreadPool multiple times raises a
:exc:`RuntimeError`
"""
# replace _spawn_worker with a function that returns a placeholder to avoid
# actually starting any threads
monkeypatch.setattr(
ThreadPool,
'_spawn_worker',
lambda _: types.SimpleNamespace(ready=True),
)
tp = ThreadPool(server=None)
tp.start()
with pytest.raises(RuntimeError, match='Threadpools can only be started once.'):
tp.start()

View file

@ -55,17 +55,6 @@ _stdlib_to_openssl_verify = {
} }
fails_under_py3 = pytest.mark.xfail(
reason='Fails under Python 3+',
)
fails_under_py3_in_pypy = pytest.mark.xfail(
IS_PYPY,
reason='Fails under PyPy3',
)
missing_ipv6 = pytest.mark.skipif( missing_ipv6 = pytest.mark.skipif(
not _probe_ipv6_sock('::1'), not _probe_ipv6_sock('::1'),
reason='' reason=''
@ -556,7 +545,6 @@ def test_ssl_env( # noqa: C901 # FIXME
# builtin ssl environment generation may use a loopback socket # builtin ssl environment generation may use a loopback socket
# ensure no ResourceWarning was raised during the test # ensure no ResourceWarning was raised during the test
# NOTE: python 2.7 does not emit ResourceWarning for ssl sockets
if IS_PYPY: if IS_PYPY:
# NOTE: PyPy doesn't have ResourceWarning # NOTE: PyPy doesn't have ResourceWarning
# Ref: https://doc.pypy.org/en/latest/cpython_differences.html # Ref: https://doc.pypy.org/en/latest/cpython_differences.html

View file

@ -463,16 +463,13 @@ def shb(response):
return resp_status_line, response.getheaders(), response.read() return resp_status_line, response.getheaders(), response.read()
# def openURL(*args, raise_subcls=(), **kwargs): def openURL(*args, raise_subcls=(), **kwargs):
# py27 compatible signature:
def openURL(*args, **kwargs):
""" """
Open a URL, retrying when it fails. Open a URL, retrying when it fails.
Specify ``raise_subcls`` (class or tuple of classes) to exclude Specify ``raise_subcls`` (class or tuple of classes) to exclude
those socket.error subclasses from being suppressed and retried. those socket.error subclasses from being suppressed and retried.
""" """
raise_subcls = kwargs.pop('raise_subcls', ())
opener = functools.partial(_open_url_once, *args, **kwargs) opener = functools.partial(_open_url_once, *args, **kwargs)
def on_exception(): def on_exception():

View file

@ -119,9 +119,7 @@ def _probe_ipv6_sock(interface):
try: try:
with closing(socket.socket(family=socket.AF_INET6)) as sock: with closing(socket.socket(family=socket.AF_INET6)) as sock:
sock.bind((interface, 0)) sock.bind((interface, 0))
except (OSError, socket.error) as sock_err: except OSError as sock_err:
# In Python 3 socket.error is an alias for OSError
# In Python 2 socket.error is a subclass of IOError
if sock_err.errno != errno.EADDRNOTAVAIL: if sock_err.errno != errno.EADDRNOTAVAIL:
raise raise
else: else:

View file

@ -151,12 +151,33 @@ class ThreadPool:
server (cheroot.server.HTTPServer): web server object server (cheroot.server.HTTPServer): web server object
receiving this request receiving this request
min (int): minimum number of worker threads min (int): minimum number of worker threads
max (int): maximum number of worker threads max (int): maximum number of worker threads (-1/inf for no max)
accepted_queue_size (int): maximum number of active accepted_queue_size (int): maximum number of active
requests in queue requests in queue
accepted_queue_timeout (int): timeout for putting request accepted_queue_timeout (int): timeout for putting request
into queue into queue
:raises ValueError: if the min/max values are invalid
:raises TypeError: if the max is not an integer or inf
""" """
if min < 1:
raise ValueError(f'min={min!s} must be > 0')
if max == float('inf'):
pass
elif not isinstance(max, int) or max == 0:
raise TypeError(
'Expected an integer or the infinity value for the `max` '
f'argument but got {max!r}.',
)
elif max < 0:
max = float('inf')
if max < min:
raise ValueError(
f'max={max!s} must be > min={min!s} (or infinity for no max)',
)
self.server = server self.server = server
self.min = min self.min = min
self.max = max self.max = max
@ -167,18 +188,13 @@ class ThreadPool:
self._pending_shutdowns = collections.deque() self._pending_shutdowns = collections.deque()
def start(self): def start(self):
"""Start the pool of threads.""" """Start the pool of threads.
for _ in range(self.min):
self._threads.append(WorkerThread(self.server)) :raises RuntimeError: if the pool is already started
for worker in self._threads: """
worker.name = ( if self._threads:
'CP Server {worker_name!s}'. raise RuntimeError('Threadpools can only be started once.')
format(worker_name=worker.name) self.grow(self.min)
)
worker.start()
for worker in self._threads:
while not worker.ready:
time.sleep(.1)
@property @property
def idle(self): # noqa: D401; irrelevant for properties def idle(self): # noqa: D401; irrelevant for properties
@ -206,17 +222,13 @@ class ThreadPool:
def grow(self, amount): def grow(self, amount):
"""Spawn new worker threads (not above self.max).""" """Spawn new worker threads (not above self.max)."""
if self.max > 0: budget = max(self.max - len(self._threads), 0)
budget = max(self.max - len(self._threads), 0)
else:
# self.max <= 0 indicates no maximum
budget = float('inf')
n_new = min(amount, budget) n_new = min(amount, budget)
workers = [self._spawn_worker() for i in range(n_new)] workers = [self._spawn_worker() for i in range(n_new)]
while not all(worker.ready for worker in workers): for worker in workers:
time.sleep(.1) while not worker.ready:
time.sleep(.1)
self._threads.extend(workers) self._threads.extend(workers)
def _spawn_worker(self): def _spawn_worker(self):

View file

@ -43,6 +43,7 @@ class Server(server.HTTPServer):
max=-1, request_queue_size=5, timeout=10, shutdown_timeout=5, max=-1, request_queue_size=5, timeout=10, shutdown_timeout=5,
accepted_queue_size=-1, accepted_queue_timeout=10, accepted_queue_size=-1, accepted_queue_timeout=10,
peercreds_enabled=False, peercreds_resolve_enabled=False, peercreds_enabled=False, peercreds_resolve_enabled=False,
reuse_port=False,
): ):
"""Initialize WSGI Server instance. """Initialize WSGI Server instance.
@ -69,6 +70,7 @@ class Server(server.HTTPServer):
server_name=server_name, server_name=server_name,
peercreds_enabled=peercreds_enabled, peercreds_enabled=peercreds_enabled,
peercreds_resolve_enabled=peercreds_resolve_enabled, peercreds_resolve_enabled=peercreds_resolve_enabled,
reuse_port=reuse_port,
) )
self.wsgi_app = wsgi_app self.wsgi_app = wsgi_app
self.request_queue_size = request_queue_size self.request_queue_size = request_queue_size

View file

@ -8,7 +8,7 @@ class Server(server.HTTPServer):
timeout: Any timeout: Any
shutdown_timeout: Any shutdown_timeout: Any
requests: Any requests: Any
def __init__(self, bind_addr, wsgi_app, numthreads: int = ..., server_name: Any | None = ..., max: int = ..., request_queue_size: int = ..., timeout: int = ..., shutdown_timeout: int = ..., accepted_queue_size: int = ..., accepted_queue_timeout: int = ..., peercreds_enabled: bool = ..., peercreds_resolve_enabled: bool = ...) -> None: ... def __init__(self, bind_addr, wsgi_app, numthreads: int = ..., server_name: Any | None = ..., max: int = ..., request_queue_size: int = ..., timeout: int = ..., shutdown_timeout: int = ..., accepted_queue_size: int = ..., accepted_queue_timeout: int = ..., peercreds_enabled: bool = ..., peercreds_resolve_enabled: bool = ..., reuse_port: bool = ...) -> None: ...
@property @property
def numthreads(self): ... def numthreads(self): ...
@numthreads.setter @numthreads.setter

View file

@ -22,6 +22,7 @@ __all__ = [
"asyncquery", "asyncquery",
"asyncresolver", "asyncresolver",
"dnssec", "dnssec",
"dnssecalgs",
"dnssectypes", "dnssectypes",
"e164", "e164",
"edns", "edns",

View file

@ -35,6 +35,9 @@ class Socket: # pragma: no cover
async def getsockname(self): async def getsockname(self):
raise NotImplementedError raise NotImplementedError
async def getpeercert(self, timeout):
raise NotImplementedError
async def __aenter__(self): async def __aenter__(self):
return self return self
@ -61,6 +64,11 @@ class StreamSocket(Socket): # pragma: no cover
raise NotImplementedError raise NotImplementedError
class NullTransport:
async def connect_tcp(self, host, port, timeout, local_address):
raise NotImplementedError
class Backend: # pragma: no cover class Backend: # pragma: no cover
def name(self): def name(self):
return "unknown" return "unknown"
@ -83,3 +91,9 @@ class Backend: # pragma: no cover
async def sleep(self, interval): async def sleep(self, interval):
raise NotImplementedError raise NotImplementedError
def get_transport_class(self):
raise NotImplementedError
async def wait_for(self, awaitable, timeout):
raise NotImplementedError

View file

@ -2,14 +2,13 @@
"""asyncio library query support""" """asyncio library query support"""
import socket
import asyncio import asyncio
import socket
import sys import sys
import dns._asyncbackend import dns._asyncbackend
import dns.exception import dns.exception
_is_win32 = sys.platform == "win32" _is_win32 = sys.platform == "win32"
@ -38,14 +37,21 @@ class _DatagramProtocol:
def connection_lost(self, exc): def connection_lost(self, exc):
if self.recvfrom and not self.recvfrom.done(): if self.recvfrom and not self.recvfrom.done():
self.recvfrom.set_exception(exc) if exc is None:
# EOF we triggered. Is there a better way to do this?
try:
raise EOFError
except EOFError as e:
self.recvfrom.set_exception(e)
else:
self.recvfrom.set_exception(exc)
def close(self): def close(self):
self.transport.close() self.transport.close()
async def _maybe_wait_for(awaitable, timeout): async def _maybe_wait_for(awaitable, timeout):
if timeout: if timeout is not None:
try: try:
return await asyncio.wait_for(awaitable, timeout) return await asyncio.wait_for(awaitable, timeout)
except asyncio.TimeoutError: except asyncio.TimeoutError:
@ -85,6 +91,9 @@ class DatagramSocket(dns._asyncbackend.DatagramSocket):
async def getsockname(self): async def getsockname(self):
return self.transport.get_extra_info("sockname") return self.transport.get_extra_info("sockname")
async def getpeercert(self, timeout):
raise NotImplementedError
class StreamSocket(dns._asyncbackend.StreamSocket): class StreamSocket(dns._asyncbackend.StreamSocket):
def __init__(self, af, reader, writer): def __init__(self, af, reader, writer):
@ -101,10 +110,6 @@ class StreamSocket(dns._asyncbackend.StreamSocket):
async def close(self): async def close(self):
self.writer.close() self.writer.close()
try:
await self.writer.wait_closed()
except AttributeError: # pragma: no cover
pass
async def getpeername(self): async def getpeername(self):
return self.writer.get_extra_info("peername") return self.writer.get_extra_info("peername")
@ -112,6 +117,97 @@ class StreamSocket(dns._asyncbackend.StreamSocket):
async def getsockname(self): async def getsockname(self):
return self.writer.get_extra_info("sockname") return self.writer.get_extra_info("sockname")
async def getpeercert(self, timeout):
return self.writer.get_extra_info("peercert")
try:
import anyio
import httpcore
import httpcore._backends.anyio
import httpx
_CoreAsyncNetworkBackend = httpcore.AsyncNetworkBackend
_CoreAnyIOStream = httpcore._backends.anyio.AnyIOStream
from dns.query import _compute_times, _expiration_for_this_attempt, _remaining
class _NetworkBackend(_CoreAsyncNetworkBackend):
def __init__(self, resolver, local_port, bootstrap_address, family):
super().__init__()
self._local_port = local_port
self._resolver = resolver
self._bootstrap_address = bootstrap_address
self._family = family
if local_port != 0:
raise NotImplementedError(
"the asyncio transport for HTTPX cannot set the local port"
)
async def connect_tcp(
self, host, port, timeout, local_address, socket_options=None
): # pylint: disable=signature-differs
addresses = []
_, expiration = _compute_times(timeout)
if dns.inet.is_address(host):
addresses.append(host)
elif self._bootstrap_address is not None:
addresses.append(self._bootstrap_address)
else:
timeout = _remaining(expiration)
family = self._family
if local_address:
family = dns.inet.af_for_address(local_address)
answers = await self._resolver.resolve_name(
host, family=family, lifetime=timeout
)
addresses = answers.addresses()
for address in addresses:
try:
attempt_expiration = _expiration_for_this_attempt(2.0, expiration)
timeout = _remaining(attempt_expiration)
with anyio.fail_after(timeout):
stream = await anyio.connect_tcp(
remote_host=address,
remote_port=port,
local_host=local_address,
)
return _CoreAnyIOStream(stream)
except Exception:
pass
raise httpcore.ConnectError
async def connect_unix_socket(
self, path, timeout, socket_options=None
): # pylint: disable=signature-differs
raise NotImplementedError
async def sleep(self, seconds): # pylint: disable=signature-differs
await anyio.sleep(seconds)
class _HTTPTransport(httpx.AsyncHTTPTransport):
def __init__(
self,
*args,
local_port=0,
bootstrap_address=None,
resolver=None,
family=socket.AF_UNSPEC,
**kwargs,
):
if resolver is None:
# pylint: disable=import-outside-toplevel,redefined-outer-name
import dns.asyncresolver
resolver = dns.asyncresolver.Resolver()
super().__init__(*args, **kwargs)
self._pool._network_backend = _NetworkBackend(
resolver, local_port, bootstrap_address, family
)
except ImportError:
_HTTPTransport = dns._asyncbackend.NullTransport # type: ignore
class Backend(dns._asyncbackend.Backend): class Backend(dns._asyncbackend.Backend):
def name(self): def name(self):
@ -171,3 +267,9 @@ class Backend(dns._asyncbackend.Backend):
def datagram_connection_required(self): def datagram_connection_required(self):
return _is_win32 return _is_win32
def get_transport_class(self):
return _HTTPTransport
async def wait_for(self, awaitable, timeout):
return await _maybe_wait_for(awaitable, timeout)

View file

@ -1,122 +0,0 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
"""curio async I/O library query support"""
import socket
import curio
import curio.socket # type: ignore
import dns._asyncbackend
import dns.exception
import dns.inet
def _maybe_timeout(timeout):
if timeout:
return curio.ignore_after(timeout)
else:
return dns._asyncbackend.NullContext()
# for brevity
_lltuple = dns.inet.low_level_address_tuple
# pylint: disable=redefined-outer-name
class DatagramSocket(dns._asyncbackend.DatagramSocket):
def __init__(self, socket):
super().__init__(socket.family)
self.socket = socket
async def sendto(self, what, destination, timeout):
async with _maybe_timeout(timeout):
return await self.socket.sendto(what, destination)
raise dns.exception.Timeout(
timeout=timeout
) # pragma: no cover lgtm[py/unreachable-statement]
async def recvfrom(self, size, timeout):
async with _maybe_timeout(timeout):
return await self.socket.recvfrom(size)
raise dns.exception.Timeout(timeout=timeout) # lgtm[py/unreachable-statement]
async def close(self):
await self.socket.close()
async def getpeername(self):
return self.socket.getpeername()
async def getsockname(self):
return self.socket.getsockname()
class StreamSocket(dns._asyncbackend.StreamSocket):
def __init__(self, socket):
self.socket = socket
self.family = socket.family
async def sendall(self, what, timeout):
async with _maybe_timeout(timeout):
return await self.socket.sendall(what)
raise dns.exception.Timeout(timeout=timeout) # lgtm[py/unreachable-statement]
async def recv(self, size, timeout):
async with _maybe_timeout(timeout):
return await self.socket.recv(size)
raise dns.exception.Timeout(timeout=timeout) # lgtm[py/unreachable-statement]
async def close(self):
await self.socket.close()
async def getpeername(self):
return self.socket.getpeername()
async def getsockname(self):
return self.socket.getsockname()
class Backend(dns._asyncbackend.Backend):
def name(self):
return "curio"
async def make_socket(
self,
af,
socktype,
proto=0,
source=None,
destination=None,
timeout=None,
ssl_context=None,
server_hostname=None,
):
if socktype == socket.SOCK_DGRAM:
s = curio.socket.socket(af, socktype, proto)
try:
if source:
s.bind(_lltuple(source, af))
except Exception: # pragma: no cover
await s.close()
raise
return DatagramSocket(s)
elif socktype == socket.SOCK_STREAM:
if source:
source_addr = _lltuple(source, af)
else:
source_addr = None
async with _maybe_timeout(timeout):
s = await curio.open_connection(
destination[0],
destination[1],
ssl=ssl_context,
source_addr=source_addr,
server_hostname=server_hostname,
)
return StreamSocket(s)
raise NotImplementedError(
"unsupported socket " + f"type {socktype}"
) # pragma: no cover
async def sleep(self, interval):
await curio.sleep(interval)

154
lib/dns/_ddr.py Normal file
View file

@ -0,0 +1,154 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
#
# Support for Discovery of Designated Resolvers
import socket
import time
from urllib.parse import urlparse
import dns.asyncbackend
import dns.inet
import dns.name
import dns.nameserver
import dns.query
import dns.rdtypes.svcbbase
# The special name of the local resolver when using DDR
_local_resolver_name = dns.name.from_text("_dns.resolver.arpa")
#
# Processing is split up into I/O independent and I/O dependent parts to
# make supporting sync and async versions easy.
#
class _SVCBInfo:
def __init__(self, bootstrap_address, port, hostname, nameservers):
self.bootstrap_address = bootstrap_address
self.port = port
self.hostname = hostname
self.nameservers = nameservers
def ddr_check_certificate(self, cert):
"""Verify that the _SVCBInfo's address is in the cert's subjectAltName (SAN)"""
for name, value in cert["subjectAltName"]:
if name == "IP Address" and value == self.bootstrap_address:
return True
return False
def make_tls_context(self):
ssl = dns.query.ssl
ctx = ssl.create_default_context()
ctx.minimum_version = ssl.TLSVersion.TLSv1_2
return ctx
def ddr_tls_check_sync(self, lifetime):
ctx = self.make_tls_context()
expiration = time.time() + lifetime
with socket.create_connection(
(self.bootstrap_address, self.port), lifetime
) as s:
with ctx.wrap_socket(s, server_hostname=self.hostname) as ts:
ts.settimeout(dns.query._remaining(expiration))
ts.do_handshake()
cert = ts.getpeercert()
return self.ddr_check_certificate(cert)
async def ddr_tls_check_async(self, lifetime, backend=None):
if backend is None:
backend = dns.asyncbackend.get_default_backend()
ctx = self.make_tls_context()
expiration = time.time() + lifetime
async with await backend.make_socket(
dns.inet.af_for_address(self.bootstrap_address),
socket.SOCK_STREAM,
0,
None,
(self.bootstrap_address, self.port),
lifetime,
ctx,
self.hostname,
) as ts:
cert = await ts.getpeercert(dns.query._remaining(expiration))
return self.ddr_check_certificate(cert)
def _extract_nameservers_from_svcb(answer):
bootstrap_address = answer.nameserver
if not dns.inet.is_address(bootstrap_address):
return []
infos = []
for rr in answer.rrset.processing_order():
nameservers = []
param = rr.params.get(dns.rdtypes.svcbbase.ParamKey.ALPN)
if param is None:
continue
alpns = set(param.ids)
host = rr.target.to_text(omit_final_dot=True)
port = None
param = rr.params.get(dns.rdtypes.svcbbase.ParamKey.PORT)
if param is not None:
port = param.port
# For now we ignore address hints and address resolution and always use the
# bootstrap address
if b"h2" in alpns:
param = rr.params.get(dns.rdtypes.svcbbase.ParamKey.DOHPATH)
if param is None or not param.value.endswith(b"{?dns}"):
continue
path = param.value[:-6].decode()
if not path.startswith("/"):
path = "/" + path
if port is None:
port = 443
url = f"https://{host}:{port}{path}"
# check the URL
try:
urlparse(url)
nameservers.append(dns.nameserver.DoHNameserver(url, bootstrap_address))
except Exception:
# continue processing other ALPN types
pass
if b"dot" in alpns:
if port is None:
port = 853
nameservers.append(
dns.nameserver.DoTNameserver(bootstrap_address, port, host)
)
if b"doq" in alpns:
if port is None:
port = 853
nameservers.append(
dns.nameserver.DoQNameserver(bootstrap_address, port, True, host)
)
if len(nameservers) > 0:
infos.append(_SVCBInfo(bootstrap_address, port, host, nameservers))
return infos
def _get_nameservers_sync(answer, lifetime):
"""Return a list of TLS-validated resolver nameservers extracted from an SVCB
answer."""
nameservers = []
infos = _extract_nameservers_from_svcb(answer)
for info in infos:
try:
if info.ddr_tls_check_sync(lifetime):
nameservers.extend(info.nameservers)
except Exception:
pass
return nameservers
async def _get_nameservers_async(answer, lifetime):
"""Return a list of TLS-validated resolver nameservers extracted from an SVCB
answer."""
nameservers = []
infos = _extract_nameservers_from_svcb(answer)
for info in infos:
try:
if await info.ddr_tls_check_async(lifetime):
nameservers.extend(info.nameservers)
except Exception:
pass
return nameservers

View file

@ -7,7 +7,6 @@
import contextvars import contextvars
import inspect import inspect
_in__init__ = contextvars.ContextVar("_immutable_in__init__", default=False) _in__init__ = contextvars.ContextVar("_immutable_in__init__", default=False)

View file

@ -3,6 +3,7 @@
"""trio async I/O library query support""" """trio async I/O library query support"""
import socket import socket
import trio import trio
import trio.socket # type: ignore import trio.socket # type: ignore
@ -12,7 +13,7 @@ import dns.inet
def _maybe_timeout(timeout): def _maybe_timeout(timeout):
if timeout: if timeout is not None:
return trio.move_on_after(timeout) return trio.move_on_after(timeout)
else: else:
return dns._asyncbackend.NullContext() return dns._asyncbackend.NullContext()
@ -50,6 +51,9 @@ class DatagramSocket(dns._asyncbackend.DatagramSocket):
async def getsockname(self): async def getsockname(self):
return self.socket.getsockname() return self.socket.getsockname()
async def getpeercert(self, timeout):
raise NotImplementedError
class StreamSocket(dns._asyncbackend.StreamSocket): class StreamSocket(dns._asyncbackend.StreamSocket):
def __init__(self, family, stream, tls=False): def __init__(self, family, stream, tls=False):
@ -82,6 +86,100 @@ class StreamSocket(dns._asyncbackend.StreamSocket):
else: else:
return self.stream.socket.getsockname() return self.stream.socket.getsockname()
async def getpeercert(self, timeout):
if self.tls:
with _maybe_timeout(timeout):
await self.stream.do_handshake()
return self.stream.getpeercert()
else:
raise NotImplementedError
try:
import httpcore
import httpcore._backends.trio
import httpx
_CoreAsyncNetworkBackend = httpcore.AsyncNetworkBackend
_CoreTrioStream = httpcore._backends.trio.TrioStream
from dns.query import _compute_times, _expiration_for_this_attempt, _remaining
class _NetworkBackend(_CoreAsyncNetworkBackend):
def __init__(self, resolver, local_port, bootstrap_address, family):
super().__init__()
self._local_port = local_port
self._resolver = resolver
self._bootstrap_address = bootstrap_address
self._family = family
async def connect_tcp(
self, host, port, timeout, local_address, socket_options=None
): # pylint: disable=signature-differs
addresses = []
_, expiration = _compute_times(timeout)
if dns.inet.is_address(host):
addresses.append(host)
elif self._bootstrap_address is not None:
addresses.append(self._bootstrap_address)
else:
timeout = _remaining(expiration)
family = self._family
if local_address:
family = dns.inet.af_for_address(local_address)
answers = await self._resolver.resolve_name(
host, family=family, lifetime=timeout
)
addresses = answers.addresses()
for address in addresses:
try:
af = dns.inet.af_for_address(address)
if local_address is not None or self._local_port != 0:
source = (local_address, self._local_port)
else:
source = None
destination = (address, port)
attempt_expiration = _expiration_for_this_attempt(2.0, expiration)
timeout = _remaining(attempt_expiration)
sock = await Backend().make_socket(
af, socket.SOCK_STREAM, 0, source, destination, timeout
)
return _CoreTrioStream(sock.stream)
except Exception:
continue
raise httpcore.ConnectError
async def connect_unix_socket(
self, path, timeout, socket_options=None
): # pylint: disable=signature-differs
raise NotImplementedError
async def sleep(self, seconds): # pylint: disable=signature-differs
await trio.sleep(seconds)
class _HTTPTransport(httpx.AsyncHTTPTransport):
def __init__(
self,
*args,
local_port=0,
bootstrap_address=None,
resolver=None,
family=socket.AF_UNSPEC,
**kwargs,
):
if resolver is None:
# pylint: disable=import-outside-toplevel,redefined-outer-name
import dns.asyncresolver
resolver = dns.asyncresolver.Resolver()
super().__init__(*args, **kwargs)
self._pool._network_backend = _NetworkBackend(
resolver, local_port, bootstrap_address, family
)
except ImportError:
_HTTPTransport = dns._asyncbackend.NullTransport # type: ignore
class Backend(dns._asyncbackend.Backend): class Backend(dns._asyncbackend.Backend):
def name(self): def name(self):
@ -104,8 +202,14 @@ class Backend(dns._asyncbackend.Backend):
if source: if source:
await s.bind(_lltuple(source, af)) await s.bind(_lltuple(source, af))
if socktype == socket.SOCK_STREAM: if socktype == socket.SOCK_STREAM:
connected = False
with _maybe_timeout(timeout): with _maybe_timeout(timeout):
await s.connect(_lltuple(destination, af)) await s.connect(_lltuple(destination, af))
connected = True
if not connected:
raise dns.exception.Timeout(
timeout=timeout
) # lgtm[py/unreachable-statement]
except Exception: # pragma: no cover except Exception: # pragma: no cover
s.close() s.close()
raise raise
@ -130,3 +234,13 @@ class Backend(dns._asyncbackend.Backend):
async def sleep(self, interval): async def sleep(self, interval):
await trio.sleep(interval) await trio.sleep(interval)
def get_transport_class(self):
return _HTTPTransport
async def wait_for(self, awaitable, timeout):
with _maybe_timeout(timeout):
return await awaitable
raise dns.exception.Timeout(
timeout=timeout
) # pragma: no cover lgtm[py/unreachable-statement]

View file

@ -5,13 +5,12 @@ from typing import Dict
import dns.exception import dns.exception
# pylint: disable=unused-import # pylint: disable=unused-import
from dns._asyncbackend import ( # noqa: F401 lgtm[py/unused-import]
from dns._asyncbackend import (
Socket,
DatagramSocket,
StreamSocket,
Backend, Backend,
) # noqa: F401 lgtm[py/unused-import] DatagramSocket,
Socket,
StreamSocket,
)
# pylint: enable=unused-import # pylint: enable=unused-import
@ -30,8 +29,8 @@ class AsyncLibraryNotFoundError(dns.exception.DNSException):
def get_backend(name: str) -> Backend: def get_backend(name: str) -> Backend:
"""Get the specified asynchronous backend. """Get the specified asynchronous backend.
*name*, a ``str``, the name of the backend. Currently the "trio", *name*, a ``str``, the name of the backend. Currently the "trio"
"curio", and "asyncio" backends are available. and "asyncio" backends are available.
Raises NotImplementError if an unknown backend name is specified. Raises NotImplementError if an unknown backend name is specified.
""" """
@ -43,10 +42,6 @@ def get_backend(name: str) -> Backend:
import dns._trio_backend import dns._trio_backend
backend = dns._trio_backend.Backend() backend = dns._trio_backend.Backend()
elif name == "curio":
import dns._curio_backend
backend = dns._curio_backend.Backend()
elif name == "asyncio": elif name == "asyncio":
import dns._asyncio_backend import dns._asyncio_backend
@ -73,9 +68,7 @@ def sniff() -> str:
try: try:
return sniffio.current_async_library() return sniffio.current_async_library()
except sniffio.AsyncLibraryNotFoundError: except sniffio.AsyncLibraryNotFoundError:
raise AsyncLibraryNotFoundError( raise AsyncLibraryNotFoundError("sniffio cannot determine async library")
"sniffio cannot determine " + "async library"
)
except ImportError: except ImportError:
import asyncio import asyncio

View file

@ -17,39 +17,38 @@
"""Talk to a DNS server.""" """Talk to a DNS server."""
from typing import Any, Dict, Optional, Tuple, Union
import base64 import base64
import contextlib import contextlib
import socket import socket
import struct import struct
import time import time
from typing import Any, Dict, Optional, Tuple, Union
import dns.asyncbackend import dns.asyncbackend
import dns.exception import dns.exception
import dns.inet import dns.inet
import dns.name
import dns.message import dns.message
import dns.name
import dns.quic import dns.quic
import dns.rcode import dns.rcode
import dns.rdataclass import dns.rdataclass
import dns.rdatatype import dns.rdatatype
import dns.transaction import dns.transaction
from dns._asyncbackend import NullContext from dns._asyncbackend import NullContext
from dns.query import ( from dns.query import (
_compute_times,
_matches_destination,
BadResponse, BadResponse,
ssl,
UDPMode,
_have_httpx,
_have_http2,
NoDOH, NoDOH,
NoDOQ, NoDOQ,
UDPMode,
_compute_times,
_have_http2,
_matches_destination,
_remaining,
have_doh,
ssl,
) )
if _have_httpx: if have_doh:
import httpx import httpx
# for brevity # for brevity
@ -73,7 +72,7 @@ def _source_tuple(af, address, port):
def _timeout(expiration, now=None): def _timeout(expiration, now=None):
if expiration: if expiration is not None:
if not now: if not now:
now = time.time() now = time.time()
return max(expiration - now, 0) return max(expiration - now, 0)
@ -445,9 +444,6 @@ async def tls(
ssl_context.minimum_version = ssl.TLSVersion.TLSv1_2 ssl_context.minimum_version = ssl.TLSVersion.TLSv1_2
if server_hostname is None: if server_hostname is None:
ssl_context.check_hostname = False ssl_context.check_hostname = False
else:
ssl_context = None
server_hostname = None
af = dns.inet.af_for_address(where) af = dns.inet.af_for_address(where)
stuple = _source_tuple(af, source, source_port) stuple = _source_tuple(af, source, source_port)
dtuple = (where, port) dtuple = (where, port)
@ -495,6 +491,9 @@ async def https(
path: str = "/dns-query", path: str = "/dns-query",
post: bool = True, post: bool = True,
verify: Union[bool, str] = True, verify: Union[bool, str] = True,
bootstrap_address: Optional[str] = None,
resolver: Optional["dns.asyncresolver.Resolver"] = None,
family: Optional[int] = socket.AF_UNSPEC,
) -> dns.message.Message: ) -> dns.message.Message:
"""Return the response obtained after sending a query via DNS-over-HTTPS. """Return the response obtained after sending a query via DNS-over-HTTPS.
@ -508,8 +507,10 @@ async def https(
parameters, exceptions, and return type of this method. parameters, exceptions, and return type of this method.
""" """
if not _have_httpx: if not have_doh:
raise NoDOH("httpx is not available.") # pragma: no cover raise NoDOH # pragma: no cover
if client and not isinstance(client, httpx.AsyncClient):
raise ValueError("session parameter must be an httpx.AsyncClient")
wire = q.to_wire() wire = q.to_wire()
try: try:
@ -518,15 +519,32 @@ async def https(
af = None af = None
transport = None transport = None
headers = {"accept": "application/dns-message"} headers = {"accept": "application/dns-message"}
if af is not None: if af is not None and dns.inet.is_address(where):
if af == socket.AF_INET: if af == socket.AF_INET:
url = "https://{}:{}{}".format(where, port, path) url = "https://{}:{}{}".format(where, port, path)
elif af == socket.AF_INET6: elif af == socket.AF_INET6:
url = "https://[{}]:{}{}".format(where, port, path) url = "https://[{}]:{}{}".format(where, port, path)
else: else:
url = where url = where
if source is not None:
transport = httpx.AsyncHTTPTransport(local_address=source[0]) backend = dns.asyncbackend.get_default_backend()
if source is None:
local_address = None
local_port = 0
else:
local_address = source
local_port = source_port
transport = backend.get_transport_class()(
local_address=local_address,
http1=True,
http2=_have_http2,
verify=verify,
local_port=local_port,
bootstrap_address=bootstrap_address,
resolver=resolver,
family=family,
)
if client: if client:
cm: contextlib.AbstractAsyncContextManager = NullContext(client) cm: contextlib.AbstractAsyncContextManager = NullContext(client)
@ -545,14 +563,14 @@ async def https(
"content-length": str(len(wire)), "content-length": str(len(wire)),
} }
) )
response = await the_client.post( response = await backend.wait_for(
url, headers=headers, content=wire, timeout=timeout the_client.post(url, headers=headers, content=wire), timeout
) )
else: else:
wire = base64.urlsafe_b64encode(wire).rstrip(b"=") wire = base64.urlsafe_b64encode(wire).rstrip(b"=")
twire = wire.decode() # httpx does a repr() if we give it bytes twire = wire.decode() # httpx does a repr() if we give it bytes
response = await the_client.get( response = await backend.wait_for(
url, headers=headers, timeout=timeout, params={"dns": twire} the_client.get(url, headers=headers, params={"dns": twire}), timeout
) )
# see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH # see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH
@ -690,6 +708,7 @@ async def quic(
connection: Optional[dns.quic.AsyncQuicConnection] = None, connection: Optional[dns.quic.AsyncQuicConnection] = None,
verify: Union[bool, str] = True, verify: Union[bool, str] = True,
backend: Optional[dns.asyncbackend.Backend] = None, backend: Optional[dns.asyncbackend.Backend] = None,
server_hostname: Optional[str] = None,
) -> dns.message.Message: ) -> dns.message.Message:
"""Return the response obtained after sending an asynchronous query via """Return the response obtained after sending an asynchronous query via
DNS-over-QUIC. DNS-over-QUIC.
@ -715,14 +734,16 @@ async def quic(
(cfactory, mfactory) = dns.quic.factories_for_backend(backend) (cfactory, mfactory) = dns.quic.factories_for_backend(backend)
async with cfactory() as context: async with cfactory() as context:
async with mfactory(context, verify_mode=verify) as the_manager: async with mfactory(
context, verify_mode=verify, server_name=server_hostname
) as the_manager:
if not connection: if not connection:
the_connection = the_manager.connect(where, port, source, source_port) the_connection = the_manager.connect(where, port, source, source_port)
start = time.time() (start, expiration) = _compute_times(timeout)
stream = await the_connection.make_stream() stream = await the_connection.make_stream(timeout)
async with stream: async with stream:
await stream.send(wire, True) await stream.send(wire, True)
wire = await stream.receive(timeout) wire = await stream.receive(_remaining(expiration))
finish = time.time() finish = time.time()
r = dns.message.from_wire( r = dns.message.from_wire(
wire, wire,

View file

@ -17,10 +17,11 @@
"""Asynchronous DNS stub resolver.""" """Asynchronous DNS stub resolver."""
from typing import Any, Dict, Optional, Union import socket
import time import time
from typing import Any, Dict, List, Optional, Union
import dns._ddr
import dns.asyncbackend import dns.asyncbackend
import dns.asyncquery import dns.asyncquery
import dns.exception import dns.exception
@ -31,8 +32,7 @@ import dns.rdatatype
import dns.resolver # lgtm[py/import-and-import-from] import dns.resolver # lgtm[py/import-and-import-from]
# import some resolver symbols for brevity # import some resolver symbols for brevity
from dns.resolver import NXDOMAIN, NoAnswer, NotAbsolute, NoRootSOA from dns.resolver import NXDOMAIN, NoAnswer, NoRootSOA, NotAbsolute
# for indentation purposes below # for indentation purposes below
_udp = dns.asyncquery.udp _udp = dns.asyncquery.udp
@ -83,37 +83,19 @@ class Resolver(dns.resolver.BaseResolver):
assert request is not None # needed for type checking assert request is not None # needed for type checking
done = False done = False
while not done: while not done:
(nameserver, port, tcp, backoff) = resolution.next_nameserver() (nameserver, tcp, backoff) = resolution.next_nameserver()
if backoff: if backoff:
await backend.sleep(backoff) await backend.sleep(backoff)
timeout = self._compute_timeout(start, lifetime, resolution.errors) timeout = self._compute_timeout(start, lifetime, resolution.errors)
try: try:
if dns.inet.is_address(nameserver): response = await nameserver.async_query(
if tcp: request,
response = await _tcp( timeout=timeout,
request, source=source,
nameserver, source_port=source_port,
timeout, max_size=tcp,
port, backend=backend,
source, )
source_port,
backend=backend,
)
else:
response = await _udp(
request,
nameserver,
timeout,
port,
source,
source_port,
raise_on_truncation=True,
backend=backend,
)
else:
response = await dns.asyncquery.https(
request, nameserver, timeout=timeout
)
except Exception as ex: except Exception as ex:
(_, done) = resolution.query_result(None, ex) (_, done) = resolution.query_result(None, ex)
continue continue
@ -153,6 +135,73 @@ class Resolver(dns.resolver.BaseResolver):
dns.reversename.from_address(ipaddr), *args, **modified_kwargs dns.reversename.from_address(ipaddr), *args, **modified_kwargs
) )
async def resolve_name(
self,
name: Union[dns.name.Name, str],
family: int = socket.AF_UNSPEC,
**kwargs: Any,
) -> dns.resolver.HostAnswers:
"""Use an asynchronous resolver to query for address records.
This utilizes the resolve() method to perform A and/or AAAA lookups on
the specified name.
*qname*, a ``dns.name.Name`` or ``str``, the name to resolve.
*family*, an ``int``, the address family. If socket.AF_UNSPEC
(the default), both A and AAAA records will be retrieved.
All other arguments that can be passed to the resolve() function
except for rdtype and rdclass are also supported by this
function.
"""
# We make a modified kwargs for type checking happiness, as otherwise
# we get a legit warning about possibly having rdtype and rdclass
# in the kwargs more than once.
modified_kwargs: Dict[str, Any] = {}
modified_kwargs.update(kwargs)
modified_kwargs.pop("rdtype", None)
modified_kwargs["rdclass"] = dns.rdataclass.IN
if family == socket.AF_INET:
v4 = await self.resolve(name, dns.rdatatype.A, **modified_kwargs)
return dns.resolver.HostAnswers.make(v4=v4)
elif family == socket.AF_INET6:
v6 = await self.resolve(name, dns.rdatatype.AAAA, **modified_kwargs)
return dns.resolver.HostAnswers.make(v6=v6)
elif family != socket.AF_UNSPEC:
raise NotImplementedError(f"unknown address family {family}")
raise_on_no_answer = modified_kwargs.pop("raise_on_no_answer", True)
lifetime = modified_kwargs.pop("lifetime", None)
start = time.time()
v6 = await self.resolve(
name,
dns.rdatatype.AAAA,
raise_on_no_answer=False,
lifetime=self._compute_timeout(start, lifetime),
**modified_kwargs,
)
# Note that setting name ensures we query the same name
# for A as we did for AAAA. (This is just in case search lists
# are active by default in the resolver configuration and
# we might be talking to a server that says NXDOMAIN when it
# wants to say NOERROR no data.
name = v6.qname
v4 = await self.resolve(
name,
dns.rdatatype.A,
raise_on_no_answer=False,
lifetime=self._compute_timeout(start, lifetime),
**modified_kwargs,
)
answers = dns.resolver.HostAnswers.make(
v6=v6, v4=v4, add_empty=not raise_on_no_answer
)
if not answers:
raise NoAnswer(response=v6.response)
return answers
# pylint: disable=redefined-outer-name # pylint: disable=redefined-outer-name
async def canonical_name(self, name: Union[dns.name.Name, str]) -> dns.name.Name: async def canonical_name(self, name: Union[dns.name.Name, str]) -> dns.name.Name:
@ -176,6 +225,37 @@ class Resolver(dns.resolver.BaseResolver):
canonical_name = e.canonical_name canonical_name = e.canonical_name
return canonical_name return canonical_name
async def try_ddr(self, lifetime: float = 5.0) -> None:
"""Try to update the resolver's nameservers using Discovery of Designated
Resolvers (DDR). If successful, the resolver will subsequently use
DNS-over-HTTPS or DNS-over-TLS for future queries.
*lifetime*, a float, is the maximum time to spend attempting DDR. The default
is 5 seconds.
If the SVCB query is successful and results in a non-empty list of nameservers,
then the resolver's nameservers are set to the returned servers in priority
order.
The current implementation does not use any address hints from the SVCB record,
nor does it resolve addresses for the SCVB target name, rather it assumes that
the bootstrap nameserver will always be one of the addresses and uses it.
A future revision to the code may offer fuller support. The code verifies that
the bootstrap nameserver is in the Subject Alternative Name field of the
TLS certficate.
"""
try:
expiration = time.time() + lifetime
answer = await self.resolve(
dns._ddr._local_resolver_name, "svcb", lifetime=lifetime
)
timeout = dns.query._remaining(expiration)
nameservers = await dns._ddr._get_nameservers_async(answer, timeout)
if len(nameservers) > 0:
self.nameservers = nameservers
except Exception:
pass
default_resolver = None default_resolver = None
@ -246,6 +326,18 @@ async def resolve_address(
return await get_default_resolver().resolve_address(ipaddr, *args, **kwargs) return await get_default_resolver().resolve_address(ipaddr, *args, **kwargs)
async def resolve_name(
name: Union[dns.name.Name, str], family: int = socket.AF_UNSPEC, **kwargs: Any
) -> dns.resolver.HostAnswers:
"""Use a resolver to asynchronously query for address records.
See :py:func:`dns.asyncresolver.Resolver.resolve_name` for more
information on the parameters.
"""
return await get_default_resolver().resolve_name(name, family, **kwargs)
async def canonical_name(name: Union[dns.name.Name, str]) -> dns.name.Name: async def canonical_name(name: Union[dns.name.Name, str]) -> dns.name.Name:
"""Determine the canonical name of *name*. """Determine the canonical name of *name*.
@ -256,6 +348,16 @@ async def canonical_name(name: Union[dns.name.Name, str]) -> dns.name.Name:
return await get_default_resolver().canonical_name(name) return await get_default_resolver().canonical_name(name)
async def try_ddr(timeout: float = 5.0) -> None:
"""Try to update the default resolver's nameservers using Discovery of Designated
Resolvers (DDR). If successful, the resolver will subsequently use
DNS-over-HTTPS or DNS-over-TLS for future queries.
See :py:func:`dns.resolver.Resolver.try_ddr` for more information.
"""
return await get_default_resolver().try_ddr(timeout)
async def zone_for_name( async def zone_for_name(
name: Union[dns.name.Name, str], name: Union[dns.name.Name, str],
rdclass: dns.rdataclass.RdataClass = dns.rdataclass.IN, rdclass: dns.rdataclass.RdataClass = dns.rdataclass.IN,
@ -290,3 +392,84 @@ async def zone_for_name(
name = name.parent() name = name.parent()
except dns.name.NoParent: # pragma: no cover except dns.name.NoParent: # pragma: no cover
raise NoRootSOA raise NoRootSOA
async def make_resolver_at(
where: Union[dns.name.Name, str],
port: int = 53,
family: int = socket.AF_UNSPEC,
resolver: Optional[Resolver] = None,
) -> Resolver:
"""Make a stub resolver using the specified destination as the full resolver.
*where*, a ``dns.name.Name`` or ``str`` the domain name or IP address of the
full resolver.
*port*, an ``int``, the port to use. If not specified, the default is 53.
*family*, an ``int``, the address family to use. This parameter is used if
*where* is not an address. The default is ``socket.AF_UNSPEC`` in which case
the first address returned by ``resolve_name()`` will be used, otherwise the
first address of the specified family will be used.
*resolver*, a ``dns.asyncresolver.Resolver`` or ``None``, the resolver to use for
resolution of hostnames. If not specified, the default resolver will be used.
Returns a ``dns.resolver.Resolver`` or raises an exception.
"""
if resolver is None:
resolver = get_default_resolver()
nameservers: List[Union[str, dns.nameserver.Nameserver]] = []
if isinstance(where, str) and dns.inet.is_address(where):
nameservers.append(dns.nameserver.Do53Nameserver(where, port))
else:
answers = await resolver.resolve_name(where, family)
for address in answers.addresses():
nameservers.append(dns.nameserver.Do53Nameserver(address, port))
res = dns.asyncresolver.Resolver(configure=False)
res.nameservers = nameservers
return res
async def resolve_at(
where: Union[dns.name.Name, str],
qname: Union[dns.name.Name, str],
rdtype: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.A,
rdclass: Union[dns.rdataclass.RdataClass, str] = dns.rdataclass.IN,
tcp: bool = False,
source: Optional[str] = None,
raise_on_no_answer: bool = True,
source_port: int = 0,
lifetime: Optional[float] = None,
search: Optional[bool] = None,
backend: Optional[dns.asyncbackend.Backend] = None,
port: int = 53,
family: int = socket.AF_UNSPEC,
resolver: Optional[Resolver] = None,
) -> dns.resolver.Answer:
"""Query nameservers to find the answer to the question.
This is a convenience function that calls ``dns.asyncresolver.make_resolver_at()``
to make a resolver, and then uses it to resolve the query.
See ``dns.asyncresolver.Resolver.resolve`` for more information on the resolution
parameters, and ``dns.asyncresolver.make_resolver_at`` for information about the
resolver parameters *where*, *port*, *family*, and *resolver*.
If making more than one query, it is more efficient to call
``dns.asyncresolver.make_resolver_at()`` and then use that resolver for the queries
instead of calling ``resolve_at()`` multiple times.
"""
res = await make_resolver_at(where, port, family, resolver)
return await res.resolve(
qname,
rdtype,
rdclass,
tcp,
source,
raise_on_no_answer,
source_port,
lifetime,
search,
backend,
)

View file

@ -17,50 +17,44 @@
"""Common DNSSEC-related functions and constants.""" """Common DNSSEC-related functions and constants."""
from typing import Any, cast, Dict, List, Optional, Set, Tuple, Union
import base64
import contextlib
import functools
import hashlib import hashlib
import math
import struct import struct
import time import time
import base64
from datetime import datetime from datetime import datetime
from typing import Callable, Dict, List, Optional, Set, Tuple, Union, cast
from dns.dnssectypes import Algorithm, DSDigest, NSEC3Hash
import dns.exception import dns.exception
import dns.name import dns.name
import dns.node import dns.node
import dns.rdataset
import dns.rdata import dns.rdata
import dns.rdatatype
import dns.rdataclass import dns.rdataclass
import dns.rdataset
import dns.rdatatype
import dns.rrset import dns.rrset
import dns.transaction
import dns.zone
from dns.dnssectypes import Algorithm, DSDigest, NSEC3Hash
from dns.exception import ( # pylint: disable=W0611
AlgorithmKeyMismatch,
DeniedByPolicy,
UnsupportedAlgorithm,
ValidationFailure,
)
from dns.rdtypes.ANY.CDNSKEY import CDNSKEY from dns.rdtypes.ANY.CDNSKEY import CDNSKEY
from dns.rdtypes.ANY.CDS import CDS from dns.rdtypes.ANY.CDS import CDS
from dns.rdtypes.ANY.DNSKEY import DNSKEY from dns.rdtypes.ANY.DNSKEY import DNSKEY
from dns.rdtypes.ANY.DS import DS from dns.rdtypes.ANY.DS import DS
from dns.rdtypes.ANY.NSEC import NSEC, Bitmap
from dns.rdtypes.ANY.NSEC3PARAM import NSEC3PARAM
from dns.rdtypes.ANY.RRSIG import RRSIG, sigtime_to_posixtime from dns.rdtypes.ANY.RRSIG import RRSIG, sigtime_to_posixtime
from dns.rdtypes.dnskeybase import Flag from dns.rdtypes.dnskeybase import Flag
class UnsupportedAlgorithm(dns.exception.DNSException):
"""The DNSSEC algorithm is not supported."""
class AlgorithmKeyMismatch(UnsupportedAlgorithm):
"""The DNSSEC algorithm is not supported for the given key type."""
class ValidationFailure(dns.exception.DNSException):
"""The DNSSEC signature is invalid."""
class DeniedByPolicy(dns.exception.DNSException):
"""Denied by DNSSEC policy."""
PublicKey = Union[ PublicKey = Union[
"GenericPublicKey",
"rsa.RSAPublicKey", "rsa.RSAPublicKey",
"ec.EllipticCurvePublicKey", "ec.EllipticCurvePublicKey",
"ed25519.Ed25519PublicKey", "ed25519.Ed25519PublicKey",
@ -68,12 +62,15 @@ PublicKey = Union[
] ]
PrivateKey = Union[ PrivateKey = Union[
"GenericPrivateKey",
"rsa.RSAPrivateKey", "rsa.RSAPrivateKey",
"ec.EllipticCurvePrivateKey", "ec.EllipticCurvePrivateKey",
"ed25519.Ed25519PrivateKey", "ed25519.Ed25519PrivateKey",
"ed448.Ed448PrivateKey", "ed448.Ed448PrivateKey",
] ]
RRsetSigner = Callable[[dns.transaction.Transaction, dns.rrset.RRset], None]
def algorithm_from_text(text: str) -> Algorithm: def algorithm_from_text(text: str) -> Algorithm:
"""Convert text into a DNSSEC algorithm value. """Convert text into a DNSSEC algorithm value.
@ -308,113 +305,13 @@ def _find_candidate_keys(
return [ return [
cast(DNSKEY, rd) cast(DNSKEY, rd)
for rd in rdataset for rd in rdataset
if rd.algorithm == rrsig.algorithm and key_id(rd) == rrsig.key_tag if rd.algorithm == rrsig.algorithm
and key_id(rd) == rrsig.key_tag
and (rd.flags & Flag.ZONE) == Flag.ZONE # RFC 4034 2.1.1
and rd.protocol == 3 # RFC 4034 2.1.2
] ]
def _is_rsa(algorithm: int) -> bool:
return algorithm in (
Algorithm.RSAMD5,
Algorithm.RSASHA1,
Algorithm.RSASHA1NSEC3SHA1,
Algorithm.RSASHA256,
Algorithm.RSASHA512,
)
def _is_dsa(algorithm: int) -> bool:
return algorithm in (Algorithm.DSA, Algorithm.DSANSEC3SHA1)
def _is_ecdsa(algorithm: int) -> bool:
return algorithm in (Algorithm.ECDSAP256SHA256, Algorithm.ECDSAP384SHA384)
def _is_eddsa(algorithm: int) -> bool:
return algorithm in (Algorithm.ED25519, Algorithm.ED448)
def _is_gost(algorithm: int) -> bool:
return algorithm == Algorithm.ECCGOST
def _is_md5(algorithm: int) -> bool:
return algorithm == Algorithm.RSAMD5
def _is_sha1(algorithm: int) -> bool:
return algorithm in (
Algorithm.DSA,
Algorithm.RSASHA1,
Algorithm.DSANSEC3SHA1,
Algorithm.RSASHA1NSEC3SHA1,
)
def _is_sha256(algorithm: int) -> bool:
return algorithm in (Algorithm.RSASHA256, Algorithm.ECDSAP256SHA256)
def _is_sha384(algorithm: int) -> bool:
return algorithm == Algorithm.ECDSAP384SHA384
def _is_sha512(algorithm: int) -> bool:
return algorithm == Algorithm.RSASHA512
def _ensure_algorithm_key_combination(algorithm: int, key: PublicKey) -> None:
"""Ensure algorithm is valid for key type, throwing an exception on
mismatch."""
if isinstance(key, rsa.RSAPublicKey):
if _is_rsa(algorithm):
return
raise AlgorithmKeyMismatch('algorithm "%s" not valid for RSA key' % algorithm)
if isinstance(key, dsa.DSAPublicKey):
if _is_dsa(algorithm):
return
raise AlgorithmKeyMismatch('algorithm "%s" not valid for DSA key' % algorithm)
if isinstance(key, ec.EllipticCurvePublicKey):
if _is_ecdsa(algorithm):
return
raise AlgorithmKeyMismatch('algorithm "%s" not valid for ECDSA key' % algorithm)
if isinstance(key, ed25519.Ed25519PublicKey):
if algorithm == Algorithm.ED25519:
return
raise AlgorithmKeyMismatch(
'algorithm "%s" not valid for ED25519 key' % algorithm
)
if isinstance(key, ed448.Ed448PublicKey):
if algorithm == Algorithm.ED448:
return
raise AlgorithmKeyMismatch('algorithm "%s" not valid for ED448 key' % algorithm)
raise TypeError("unsupported key type")
def _make_hash(algorithm: int) -> Any:
if _is_md5(algorithm):
return hashes.MD5()
if _is_sha1(algorithm):
return hashes.SHA1()
if _is_sha256(algorithm):
return hashes.SHA256()
if _is_sha384(algorithm):
return hashes.SHA384()
if _is_sha512(algorithm):
return hashes.SHA512()
if algorithm == Algorithm.ED25519:
return hashes.SHA512()
if algorithm == Algorithm.ED448:
return hashes.SHAKE256(114)
raise ValidationFailure("unknown hash for algorithm %u" % algorithm)
def _bytes_to_long(b: bytes) -> int:
return int.from_bytes(b, "big")
def _get_rrname_rdataset( def _get_rrname_rdataset(
rrset: Union[dns.rrset.RRset, Tuple[dns.name.Name, dns.rdataset.Rdataset]], rrset: Union[dns.rrset.RRset, Tuple[dns.name.Name, dns.rdataset.Rdataset]],
) -> Tuple[dns.name.Name, dns.rdataset.Rdataset]: ) -> Tuple[dns.name.Name, dns.rdataset.Rdataset]:
@ -424,85 +321,13 @@ def _get_rrname_rdataset(
return rrset.name, rrset return rrset.name, rrset
def _validate_signature(sig: bytes, data: bytes, key: DNSKEY, chosen_hash: Any) -> None: def _validate_signature(sig: bytes, data: bytes, key: DNSKEY) -> None:
keyptr: bytes public_cls = get_algorithm_cls_from_dnskey(key).public_cls
if _is_rsa(key.algorithm): try:
# we ignore because mypy is confused and thinks key.key is a str for unknown public_key = public_cls.from_dnskey(key)
# reasons. except ValueError:
keyptr = key.key raise ValidationFailure("invalid public key")
(bytes_,) = struct.unpack("!B", keyptr[0:1]) public_key.verify(sig, data)
keyptr = keyptr[1:]
if bytes_ == 0:
(bytes_,) = struct.unpack("!H", keyptr[0:2])
keyptr = keyptr[2:]
rsa_e = keyptr[0:bytes_]
rsa_n = keyptr[bytes_:]
try:
rsa_public_key = rsa.RSAPublicNumbers(
_bytes_to_long(rsa_e), _bytes_to_long(rsa_n)
).public_key(default_backend())
except ValueError:
raise ValidationFailure("invalid public key")
rsa_public_key.verify(sig, data, padding.PKCS1v15(), chosen_hash)
elif _is_dsa(key.algorithm):
keyptr = key.key
(t,) = struct.unpack("!B", keyptr[0:1])
keyptr = keyptr[1:]
octets = 64 + t * 8
dsa_q = keyptr[0:20]
keyptr = keyptr[20:]
dsa_p = keyptr[0:octets]
keyptr = keyptr[octets:]
dsa_g = keyptr[0:octets]
keyptr = keyptr[octets:]
dsa_y = keyptr[0:octets]
try:
dsa_public_key = dsa.DSAPublicNumbers( # type: ignore
_bytes_to_long(dsa_y),
dsa.DSAParameterNumbers(
_bytes_to_long(dsa_p), _bytes_to_long(dsa_q), _bytes_to_long(dsa_g)
),
).public_key(default_backend())
except ValueError:
raise ValidationFailure("invalid public key")
dsa_public_key.verify(sig, data, chosen_hash)
elif _is_ecdsa(key.algorithm):
keyptr = key.key
curve: Any
if key.algorithm == Algorithm.ECDSAP256SHA256:
curve = ec.SECP256R1()
octets = 32
else:
curve = ec.SECP384R1()
octets = 48
ecdsa_x = keyptr[0:octets]
ecdsa_y = keyptr[octets : octets * 2]
try:
ecdsa_public_key = ec.EllipticCurvePublicNumbers(
curve=curve, x=_bytes_to_long(ecdsa_x), y=_bytes_to_long(ecdsa_y)
).public_key(default_backend())
except ValueError:
raise ValidationFailure("invalid public key")
ecdsa_public_key.verify(sig, data, ec.ECDSA(chosen_hash))
elif _is_eddsa(key.algorithm):
keyptr = key.key
loader: Any
if key.algorithm == Algorithm.ED25519:
loader = ed25519.Ed25519PublicKey
else:
loader = ed448.Ed448PublicKey
try:
eddsa_public_key = loader.from_public_bytes(keyptr)
except ValueError:
raise ValidationFailure("invalid public key")
eddsa_public_key.verify(sig, data)
elif _is_gost(key.algorithm):
raise UnsupportedAlgorithm(
'algorithm "%s" not supported by dnspython'
% algorithm_to_text(key.algorithm)
)
else:
raise ValidationFailure("unknown algorithm %u" % key.algorithm)
def _validate_rrsig( def _validate_rrsig(
@ -559,29 +384,13 @@ def _validate_rrsig(
if rrsig.inception > now: if rrsig.inception > now:
raise ValidationFailure("not yet valid") raise ValidationFailure("not yet valid")
if _is_dsa(rrsig.algorithm):
sig_r = rrsig.signature[1:21]
sig_s = rrsig.signature[21:]
sig = utils.encode_dss_signature(_bytes_to_long(sig_r), _bytes_to_long(sig_s))
elif _is_ecdsa(rrsig.algorithm):
if rrsig.algorithm == Algorithm.ECDSAP256SHA256:
octets = 32
else:
octets = 48
sig_r = rrsig.signature[0:octets]
sig_s = rrsig.signature[octets:]
sig = utils.encode_dss_signature(_bytes_to_long(sig_r), _bytes_to_long(sig_s))
else:
sig = rrsig.signature
data = _make_rrsig_signature_data(rrset, rrsig, origin) data = _make_rrsig_signature_data(rrset, rrsig, origin)
chosen_hash = _make_hash(rrsig.algorithm)
for candidate_key in candidate_keys: for candidate_key in candidate_keys:
if not policy.ok_to_validate(candidate_key): if not policy.ok_to_validate(candidate_key):
continue continue
try: try:
_validate_signature(sig, data, candidate_key, chosen_hash) _validate_signature(rrsig.signature, data, candidate_key)
return return
except (InvalidSignature, ValidationFailure): except (InvalidSignature, ValidationFailure):
# this happens on an individual validation failure # this happens on an individual validation failure
@ -673,6 +482,7 @@ def _sign(
lifetime: Optional[int] = None, lifetime: Optional[int] = None,
verify: bool = False, verify: bool = False,
policy: Optional[Policy] = None, policy: Optional[Policy] = None,
origin: Optional[dns.name.Name] = None,
) -> RRSIG: ) -> RRSIG:
"""Sign RRset using private key. """Sign RRset using private key.
@ -708,6 +518,10 @@ def _sign(
*policy*, a ``dns.dnssec.Policy`` or ``None``. If ``None``, the default policy, *policy*, a ``dns.dnssec.Policy`` or ``None``. If ``None``, the default policy,
``dns.dnssec.default_policy`` is used; this policy defaults to that of RFC 8624. ``dns.dnssec.default_policy`` is used; this policy defaults to that of RFC 8624.
*origin*, a ``dns.name.Name`` or ``None``. If ``None``, the default, then all
names in the rrset (including its owner name) must be absolute; otherwise the
specified origin will be used to make names absolute when signing.
Raises ``DeniedByPolicy`` if the signature is denied by policy. Raises ``DeniedByPolicy`` if the signature is denied by policy.
""" """
@ -735,16 +549,26 @@ def _sign(
if expiration is not None: if expiration is not None:
rrsig_expiration = to_timestamp(expiration) rrsig_expiration = to_timestamp(expiration)
elif lifetime is not None: elif lifetime is not None:
rrsig_expiration = int(time.time()) + lifetime rrsig_expiration = rrsig_inception + lifetime
else: else:
raise ValueError("expiration or lifetime must be specified") raise ValueError("expiration or lifetime must be specified")
# Derelativize now because we need a correct labels length for the
# rrsig_template.
if origin is not None:
rrname = rrname.derelativize(origin)
labels = len(rrname) - 1
# Adjust labels appropriately for wildcards.
if rrname.is_wild():
labels -= 1
rrsig_template = RRSIG( rrsig_template = RRSIG(
rdclass=rdclass, rdclass=rdclass,
rdtype=dns.rdatatype.RRSIG, rdtype=dns.rdatatype.RRSIG,
type_covered=rdtype, type_covered=rdtype,
algorithm=dnskey.algorithm, algorithm=dnskey.algorithm,
labels=len(rrname) - 1, labels=labels,
original_ttl=original_ttl, original_ttl=original_ttl,
expiration=rrsig_expiration, expiration=rrsig_expiration,
inception=rrsig_inception, inception=rrsig_inception,
@ -753,63 +577,18 @@ def _sign(
signature=b"", signature=b"",
) )
data = dns.dnssec._make_rrsig_signature_data(rrset, rrsig_template) data = dns.dnssec._make_rrsig_signature_data(rrset, rrsig_template, origin)
chosen_hash = _make_hash(rrsig_template.algorithm)
signature = None
if isinstance(private_key, rsa.RSAPrivateKey): if isinstance(private_key, GenericPrivateKey):
if not _is_rsa(dnskey.algorithm): signing_key = private_key
raise ValueError("Invalid DNSKEY algorithm for RSA key")
signature = private_key.sign(data, padding.PKCS1v15(), chosen_hash)
if verify:
private_key.public_key().verify(
signature, data, padding.PKCS1v15(), chosen_hash
)
elif isinstance(private_key, dsa.DSAPrivateKey):
if not _is_dsa(dnskey.algorithm):
raise ValueError("Invalid DNSKEY algorithm for DSA key")
public_dsa_key = private_key.public_key()
if public_dsa_key.key_size > 1024:
raise ValueError("DSA key size overflow")
der_signature = private_key.sign(data, chosen_hash)
if verify:
public_dsa_key.verify(der_signature, data, chosen_hash)
dsa_r, dsa_s = utils.decode_dss_signature(der_signature)
dsa_t = (public_dsa_key.key_size // 8 - 64) // 8
octets = 20
signature = (
struct.pack("!B", dsa_t)
+ int.to_bytes(dsa_r, length=octets, byteorder="big")
+ int.to_bytes(dsa_s, length=octets, byteorder="big")
)
elif isinstance(private_key, ec.EllipticCurvePrivateKey):
if not _is_ecdsa(dnskey.algorithm):
raise ValueError("Invalid DNSKEY algorithm for EC key")
der_signature = private_key.sign(data, ec.ECDSA(chosen_hash))
if verify:
private_key.public_key().verify(der_signature, data, ec.ECDSA(chosen_hash))
if dnskey.algorithm == Algorithm.ECDSAP256SHA256:
octets = 32
else:
octets = 48
dsa_r, dsa_s = utils.decode_dss_signature(der_signature)
signature = int.to_bytes(dsa_r, length=octets, byteorder="big") + int.to_bytes(
dsa_s, length=octets, byteorder="big"
)
elif isinstance(private_key, ed25519.Ed25519PrivateKey):
if dnskey.algorithm != Algorithm.ED25519:
raise ValueError("Invalid DNSKEY algorithm for ED25519 key")
signature = private_key.sign(data)
if verify:
private_key.public_key().verify(signature, data)
elif isinstance(private_key, ed448.Ed448PrivateKey):
if dnskey.algorithm != Algorithm.ED448:
raise ValueError("Invalid DNSKEY algorithm for ED448 key")
signature = private_key.sign(data)
if verify:
private_key.public_key().verify(signature, data)
else: else:
raise TypeError("Unsupported key algorithm") try:
private_cls = get_algorithm_cls_from_dnskey(dnskey)
signing_key = private_cls(key=private_key)
except UnsupportedAlgorithm:
raise TypeError("Unsupported key algorithm")
signature = signing_key.sign(data, verify)
return cast(RRSIG, rrsig_template.replace(signature=signature)) return cast(RRSIG, rrsig_template.replace(signature=signature))
@ -858,9 +637,12 @@ def _make_rrsig_signature_data(
raise ValidationFailure("relative RR name without an origin specified") raise ValidationFailure("relative RR name without an origin specified")
rrname = rrname.derelativize(origin) rrname = rrname.derelativize(origin)
if len(rrname) - 1 < rrsig.labels: name_len = len(rrname)
if rrname.is_wild() and rrsig.labels != name_len - 2:
raise ValidationFailure("wild owner name has wrong label length")
if name_len - 1 < rrsig.labels:
raise ValidationFailure("owner name longer than RRSIG labels") raise ValidationFailure("owner name longer than RRSIG labels")
elif rrsig.labels < len(rrname) - 1: elif rrsig.labels < name_len - 1:
suffix = rrname.split(rrsig.labels + 1)[1] suffix = rrname.split(rrsig.labels + 1)[1]
rrname = dns.name.from_text("*", suffix) rrname = dns.name.from_text("*", suffix)
rrnamebuf = rrname.to_digestable() rrnamebuf = rrname.to_digestable()
@ -884,9 +666,8 @@ def _make_dnskey(
) -> DNSKEY: ) -> DNSKEY:
"""Convert a public key to DNSKEY Rdata """Convert a public key to DNSKEY Rdata
*public_key*, the public key to convert, a *public_key*, a ``PublicKey`` (``GenericPublicKey`` or
``cryptography.hazmat.primitives.asymmetric`` public key class applicable ``cryptography.hazmat.primitives.asymmetric``) to convert.
for DNSSEC.
*algorithm*, a ``str`` or ``int`` specifying the DNSKEY algorithm. *algorithm*, a ``str`` or ``int`` specifying the DNSKEY algorithm.
@ -902,72 +683,13 @@ def _make_dnskey(
Return DNSKEY ``Rdata``. Return DNSKEY ``Rdata``.
""" """
def encode_rsa_public_key(public_key: "rsa.RSAPublicKey") -> bytes: algorithm = Algorithm.make(algorithm)
"""Encode a public key per RFC 3110, section 2."""
pn = public_key.public_numbers()
_exp_len = math.ceil(int.bit_length(pn.e) / 8)
exp = int.to_bytes(pn.e, length=_exp_len, byteorder="big")
if _exp_len > 255:
exp_header = b"\0" + struct.pack("!H", _exp_len)
else:
exp_header = struct.pack("!B", _exp_len)
if pn.n.bit_length() < 512 or pn.n.bit_length() > 4096:
raise ValueError("unsupported RSA key length")
return exp_header + exp + pn.n.to_bytes((pn.n.bit_length() + 7) // 8, "big")
def encode_dsa_public_key(public_key: "dsa.DSAPublicKey") -> bytes: if isinstance(public_key, GenericPublicKey):
"""Encode a public key per RFC 2536, section 2.""" return public_key.to_dnskey(flags=flags, protocol=protocol)
pn = public_key.public_numbers()
dsa_t = (public_key.key_size // 8 - 64) // 8
if dsa_t > 8:
raise ValueError("unsupported DSA key size")
octets = 64 + dsa_t * 8
res = struct.pack("!B", dsa_t)
res += pn.parameter_numbers.q.to_bytes(20, "big")
res += pn.parameter_numbers.p.to_bytes(octets, "big")
res += pn.parameter_numbers.g.to_bytes(octets, "big")
res += pn.y.to_bytes(octets, "big")
return res
def encode_ecdsa_public_key(public_key: "ec.EllipticCurvePublicKey") -> bytes:
"""Encode a public key per RFC 6605, section 4."""
pn = public_key.public_numbers()
if isinstance(public_key.curve, ec.SECP256R1):
return pn.x.to_bytes(32, "big") + pn.y.to_bytes(32, "big")
elif isinstance(public_key.curve, ec.SECP384R1):
return pn.x.to_bytes(48, "big") + pn.y.to_bytes(48, "big")
else:
raise ValueError("unsupported ECDSA curve")
the_algorithm = Algorithm.make(algorithm)
_ensure_algorithm_key_combination(the_algorithm, public_key)
if isinstance(public_key, rsa.RSAPublicKey):
key_bytes = encode_rsa_public_key(public_key)
elif isinstance(public_key, dsa.DSAPublicKey):
key_bytes = encode_dsa_public_key(public_key)
elif isinstance(public_key, ec.EllipticCurvePublicKey):
key_bytes = encode_ecdsa_public_key(public_key)
elif isinstance(public_key, ed25519.Ed25519PublicKey):
key_bytes = public_key.public_bytes(
encoding=serialization.Encoding.Raw, format=serialization.PublicFormat.Raw
)
elif isinstance(public_key, ed448.Ed448PublicKey):
key_bytes = public_key.public_bytes(
encoding=serialization.Encoding.Raw, format=serialization.PublicFormat.Raw
)
else: else:
raise TypeError("unsupported key algorithm") public_cls = get_algorithm_cls(algorithm).public_cls
return public_cls(key=public_key).to_dnskey(flags=flags, protocol=protocol)
return DNSKEY(
rdclass=dns.rdataclass.IN,
rdtype=dns.rdatatype.DNSKEY,
flags=flags,
protocol=protocol,
algorithm=the_algorithm,
key=key_bytes,
)
def _make_cdnskey( def _make_cdnskey(
@ -1216,23 +938,252 @@ def dnskey_rdataset_to_cdnskey_rdataset(
return dns.rdataset.from_rdata_list(rdataset.ttl, res) return dns.rdataset.from_rdata_list(rdataset.ttl, res)
def default_rrset_signer(
txn: dns.transaction.Transaction,
rrset: dns.rrset.RRset,
signer: dns.name.Name,
ksks: List[Tuple[PrivateKey, DNSKEY]],
zsks: List[Tuple[PrivateKey, DNSKEY]],
inception: Optional[Union[datetime, str, int, float]] = None,
expiration: Optional[Union[datetime, str, int, float]] = None,
lifetime: Optional[int] = None,
policy: Optional[Policy] = None,
origin: Optional[dns.name.Name] = None,
) -> None:
"""Default RRset signer"""
if rrset.rdtype in set(
[
dns.rdatatype.RdataType.DNSKEY,
dns.rdatatype.RdataType.CDS,
dns.rdatatype.RdataType.CDNSKEY,
]
):
keys = ksks
else:
keys = zsks
for private_key, dnskey in keys:
rrsig = dns.dnssec.sign(
rrset=rrset,
private_key=private_key,
dnskey=dnskey,
inception=inception,
expiration=expiration,
lifetime=lifetime,
signer=signer,
policy=policy,
origin=origin,
)
txn.add(rrset.name, rrset.ttl, rrsig)
def sign_zone(
zone: dns.zone.Zone,
txn: Optional[dns.transaction.Transaction] = None,
keys: Optional[List[Tuple[PrivateKey, DNSKEY]]] = None,
add_dnskey: bool = True,
dnskey_ttl: Optional[int] = None,
inception: Optional[Union[datetime, str, int, float]] = None,
expiration: Optional[Union[datetime, str, int, float]] = None,
lifetime: Optional[int] = None,
nsec3: Optional[NSEC3PARAM] = None,
rrset_signer: Optional[RRsetSigner] = None,
policy: Optional[Policy] = None,
) -> None:
"""Sign zone.
*zone*, a ``dns.zone.Zone``, the zone to sign.
*txn*, a ``dns.transaction.Transaction``, an optional transaction to use for
signing.
*keys*, a list of (``PrivateKey``, ``DNSKEY``) tuples, to use for signing. KSK/ZSK
roles are assigned automatically if the SEP flag is used, otherwise all RRsets are
signed by all keys.
*add_dnskey*, a ``bool``. If ``True``, the default, all specified DNSKEYs are
automatically added to the zone on signing.
*dnskey_ttl*, a``int``, specifies the TTL for DNSKEY RRs. If not specified the TTL
of the existing DNSKEY RRset used or the TTL of the SOA RRset.
*inception*, a ``datetime``, ``str``, ``int``, ``float`` or ``None``, the signature
inception time. If ``None``, the current time is used. If a ``str``, the format is
"YYYYMMDDHHMMSS" or alternatively the number of seconds since the UNIX epoch in text
form; this is the same the RRSIG rdata's text form. Values of type `int` or `float`
are interpreted as seconds since the UNIX epoch.
*expiration*, a ``datetime``, ``str``, ``int``, ``float`` or ``None``, the signature
expiration time. If ``None``, the expiration time will be the inception time plus
the value of the *lifetime* parameter. See the description of *inception* above for
how the various parameter types are interpreted.
*lifetime*, an ``int`` or ``None``, the signature lifetime in seconds. This
parameter is only meaningful if *expiration* is ``None``.
*nsec3*, a ``NSEC3PARAM`` Rdata, configures signing using NSEC3. Not yet
implemented.
*rrset_signer*, a ``Callable``, an optional function for signing RRsets. The
function requires two arguments: transaction and RRset. If the not specified,
``dns.dnssec.default_rrset_signer`` will be used.
Returns ``None``.
"""
ksks = []
zsks = []
# if we have both KSKs and ZSKs, split by SEP flag. if not, sign all
# records with all keys
if keys:
for key in keys:
if key[1].flags & Flag.SEP:
ksks.append(key)
else:
zsks.append(key)
if not ksks:
ksks = keys
if not zsks:
zsks = keys
else:
keys = []
if txn:
cm: contextlib.AbstractContextManager = contextlib.nullcontext(txn)
else:
cm = zone.writer()
with cm as _txn:
if add_dnskey:
if dnskey_ttl is None:
dnskey = _txn.get(zone.origin, dns.rdatatype.DNSKEY)
if dnskey:
dnskey_ttl = dnskey.ttl
else:
soa = _txn.get(zone.origin, dns.rdatatype.SOA)
dnskey_ttl = soa.ttl
for _, dnskey in keys:
_txn.add(zone.origin, dnskey_ttl, dnskey)
if nsec3:
raise NotImplementedError("Signing with NSEC3 not yet implemented")
else:
_rrset_signer = rrset_signer or functools.partial(
default_rrset_signer,
signer=zone.origin,
ksks=ksks,
zsks=zsks,
inception=inception,
expiration=expiration,
lifetime=lifetime,
policy=policy,
origin=zone.origin,
)
return _sign_zone_nsec(zone, _txn, _rrset_signer)
def _sign_zone_nsec(
zone: dns.zone.Zone,
txn: dns.transaction.Transaction,
rrset_signer: Optional[RRsetSigner] = None,
) -> None:
"""NSEC zone signer"""
def _txn_add_nsec(
txn: dns.transaction.Transaction,
name: dns.name.Name,
next_secure: Optional[dns.name.Name],
rdclass: dns.rdataclass.RdataClass,
ttl: int,
rrset_signer: Optional[RRsetSigner] = None,
) -> None:
"""NSEC zone signer helper"""
mandatory_types = set(
[dns.rdatatype.RdataType.RRSIG, dns.rdatatype.RdataType.NSEC]
)
node = txn.get_node(name)
if node and next_secure:
types = (
set([rdataset.rdtype for rdataset in node.rdatasets]) | mandatory_types
)
windows = Bitmap.from_rdtypes(list(types))
rrset = dns.rrset.from_rdata(
name,
ttl,
NSEC(
rdclass=rdclass,
rdtype=dns.rdatatype.RdataType.NSEC,
next=next_secure,
windows=windows,
),
)
txn.add(rrset)
if rrset_signer:
rrset_signer(txn, rrset)
rrsig_ttl = zone.get_soa().minimum
delegation = None
last_secure = None
for name in sorted(txn.iterate_names()):
if delegation and name.is_subdomain(delegation):
# names below delegations are not secure
continue
elif txn.get(name, dns.rdatatype.NS) and name != zone.origin:
# inside delegation
delegation = name
else:
# outside delegation
delegation = None
if rrset_signer:
node = txn.get_node(name)
if node:
for rdataset in node.rdatasets:
if rdataset.rdtype == dns.rdatatype.RRSIG:
# do not sign RRSIGs
continue
elif delegation and rdataset.rdtype != dns.rdatatype.DS:
# do not sign delegations except DS records
continue
else:
rrset = dns.rrset.from_rdata(name, rdataset.ttl, *rdataset)
rrset_signer(txn, rrset)
# We need "is not None" as the empty name is False because its length is 0.
if last_secure is not None:
_txn_add_nsec(txn, last_secure, name, zone.rdclass, rrsig_ttl, rrset_signer)
last_secure = name
if last_secure:
_txn_add_nsec(
txn, last_secure, zone.origin, zone.rdclass, rrsig_ttl, rrset_signer
)
def _need_pyca(*args, **kwargs): def _need_pyca(*args, **kwargs):
raise ImportError( raise ImportError(
"DNSSEC validation requires " + "python cryptography" "DNSSEC validation requires python cryptography"
) # pragma: no cover ) # pragma: no cover
try: try:
from cryptography.exceptions import InvalidSignature from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives.asymmetric import dsa # pylint: disable=W0611
from cryptography.hazmat.primitives import hashes, serialization from cryptography.hazmat.primitives.asymmetric import ec # pylint: disable=W0611
from cryptography.hazmat.primitives.asymmetric import padding from cryptography.hazmat.primitives.asymmetric import ed448 # pylint: disable=W0611
from cryptography.hazmat.primitives.asymmetric import utils from cryptography.hazmat.primitives.asymmetric import rsa # pylint: disable=W0611
from cryptography.hazmat.primitives.asymmetric import dsa from cryptography.hazmat.primitives.asymmetric import ( # pylint: disable=W0611
from cryptography.hazmat.primitives.asymmetric import ec ed25519,
from cryptography.hazmat.primitives.asymmetric import ed25519 )
from cryptography.hazmat.primitives.asymmetric import ed448
from cryptography.hazmat.primitives.asymmetric import rsa from dns.dnssecalgs import ( # pylint: disable=C0412
get_algorithm_cls,
get_algorithm_cls_from_dnskey,
)
from dns.dnssecalgs.base import GenericPrivateKey, GenericPublicKey
except ImportError: # pragma: no cover except ImportError: # pragma: no cover
validate = _need_pyca validate = _need_pyca
validate_rrsig = _need_pyca validate_rrsig = _need_pyca

View file

@ -0,0 +1,121 @@
from typing import Dict, Optional, Tuple, Type, Union
import dns.name
try:
from dns.dnssecalgs.base import GenericPrivateKey
from dns.dnssecalgs.dsa import PrivateDSA, PrivateDSANSEC3SHA1
from dns.dnssecalgs.ecdsa import PrivateECDSAP256SHA256, PrivateECDSAP384SHA384
from dns.dnssecalgs.eddsa import PrivateED448, PrivateED25519
from dns.dnssecalgs.rsa import (
PrivateRSAMD5,
PrivateRSASHA1,
PrivateRSASHA1NSEC3SHA1,
PrivateRSASHA256,
PrivateRSASHA512,
)
_have_cryptography = True
except ImportError:
_have_cryptography = False
from dns.dnssectypes import Algorithm
from dns.exception import UnsupportedAlgorithm
from dns.rdtypes.ANY.DNSKEY import DNSKEY
# Prefix material further qualifying PRIVATEDNS/PRIVATEOID algorithms;
# ``None`` for algorithms fully identified by their algorithm number alone.
AlgorithmPrefix = Optional[Union[bytes, dns.name.Name]]

# Registry mapping (algorithm, prefix) -> private-key implementation class.
# Extended at runtime via register_algorithm_cls().
algorithms: Dict[Tuple[Algorithm, AlgorithmPrefix], Type[GenericPrivateKey]] = {}
if _have_cryptography:
    # Built-in implementations are only registered when pyca/cryptography
    # imported successfully (see the try/except above).
    algorithms.update(
        {
            (Algorithm.RSAMD5, None): PrivateRSAMD5,
            (Algorithm.DSA, None): PrivateDSA,
            (Algorithm.RSASHA1, None): PrivateRSASHA1,
            (Algorithm.DSANSEC3SHA1, None): PrivateDSANSEC3SHA1,
            (Algorithm.RSASHA1NSEC3SHA1, None): PrivateRSASHA1NSEC3SHA1,
            (Algorithm.RSASHA256, None): PrivateRSASHA256,
            (Algorithm.RSASHA512, None): PrivateRSASHA512,
            (Algorithm.ECDSAP256SHA256, None): PrivateECDSAP256SHA256,
            (Algorithm.ECDSAP384SHA384, None): PrivateECDSAP384SHA384,
            (Algorithm.ED25519, None): PrivateED25519,
            (Algorithm.ED448, None): PrivateED448,
        }
    )
def get_algorithm_cls(
    algorithm: Union[int, str], prefix: AlgorithmPrefix = None
) -> Type[GenericPrivateKey]:
    """Look up the private-key class registered for a DNSKEY algorithm.

    *algorithm*, a ``str`` or ``int`` specifying the DNSKEY algorithm.

    Raises ``UnsupportedAlgorithm`` if the algorithm is unknown.

    Returns a ``dns.dnssecalgs.GenericPrivateKey``
    """
    algorithm = Algorithm.make(algorithm)
    registered = algorithms.get((algorithm, prefix))
    if not registered:
        raise UnsupportedAlgorithm(
            'algorithm "%s" not supported by dnspython' % Algorithm.to_text(algorithm)
        )
    return registered
def get_algorithm_cls_from_dnskey(dnskey: DNSKEY) -> Type[GenericPrivateKey]:
    """Look up the private-key class matching a DNSKEY record.

    *dnskey*, a ``DNSKEY`` to get Algorithm class for.

    Raises ``UnsupportedAlgorithm`` if the algorithm is unknown.

    Returns a ``dns.dnssecalgs.GenericPrivateKey``
    """
    prefix: AlgorithmPrefix = None
    algorithm = dnskey.algorithm
    if algorithm == Algorithm.PRIVATEDNS:
        # Key data begins with the algorithm's domain name.
        prefix, _ = dns.name.from_wire(dnskey.key, 0)
    elif algorithm == Algorithm.PRIVATEOID:
        # Key data begins with a one-octet length followed by a BER OID;
        # the prefix keeps the length octet.
        oid_len = int(dnskey.key[0])
        prefix = dnskey.key[0 : oid_len + 1]
    return get_algorithm_cls(algorithm, prefix)
def register_algorithm_cls(
    algorithm: Union[int, str],
    algorithm_cls: Type[GenericPrivateKey],
    name: Optional[Union[dns.name.Name, str]] = None,
    oid: Optional[bytes] = None,
) -> None:
    """Register Algorithm Private Key class.

    *algorithm*, a ``str`` or ``int`` specifying the DNSKEY algorithm.

    *algorithm_cls*: A `GenericPrivateKey` class.

    *name*, an optional ``dns.name.Name`` or ``str``, for PRIVATEDNS algorithms.

    *oid*: an optional BER-encoded `bytes` for PRIVATEOID algorithms.

    Raises ``ValueError`` if a name or oid is specified incorrectly.
    """
    if not issubclass(algorithm_cls, GenericPrivateKey):
        raise TypeError("Invalid algorithm class")
    algorithm = Algorithm.make(algorithm)
    prefix: AlgorithmPrefix = None
    if algorithm == Algorithm.PRIVATEDNS:
        if name is None:
            raise ValueError("Name required for PRIVATEDNS algorithms")
        # Accept a textual name for convenience.
        prefix = dns.name.from_text(name) if isinstance(name, str) else name
    elif algorithm == Algorithm.PRIVATEOID:
        if oid is None:
            raise ValueError("OID required for PRIVATEOID algorithms")
        # Store the OID with its length octet, matching the DNSKEY layout.
        prefix = bytes([len(oid)]) + oid
    elif name:
        raise ValueError("Name only supported for PRIVATEDNS algorithm")
    elif oid:
        raise ValueError("OID only supported for PRIVATEOID algorithm")
    algorithms[(algorithm, prefix)] = algorithm_cls

View file

@ -0,0 +1,84 @@
from abc import ABC, abstractmethod # pylint: disable=no-name-in-module
from typing import Any, Optional, Type
import dns.rdataclass
import dns.rdatatype
from dns.dnssectypes import Algorithm
from dns.exception import AlgorithmKeyMismatch
from dns.rdtypes.ANY.DNSKEY import DNSKEY
from dns.rdtypes.dnskeybase import Flag
class GenericPublicKey(ABC):
    """Abstract interface for DNSSEC public keys.

    Concrete subclasses set ``algorithm`` and implement signature
    verification plus conversion to and from DNSKEY and PEM forms.
    """

    # DNSSEC algorithm number implemented by the subclass.
    algorithm: Algorithm

    @abstractmethod
    def __init__(self, key: Any) -> None:
        pass

    @abstractmethod
    def verify(self, signature: bytes, data: bytes) -> None:
        """Verify signed DNSSEC data"""

    @abstractmethod
    def encode_key_bytes(self) -> bytes:
        """Encode key as bytes for DNSKEY"""

    @classmethod
    def _ensure_algorithm_key_combination(cls, key: DNSKEY) -> None:
        # Guard used by from_dnskey(): the DNSKEY's algorithm number must
        # match the algorithm this class implements.
        if key.algorithm != cls.algorithm:
            raise AlgorithmKeyMismatch

    def to_dnskey(self, flags: int = Flag.ZONE, protocol: int = 3) -> DNSKEY:
        """Return public key as DNSKEY"""
        return DNSKEY(
            rdclass=dns.rdataclass.IN,
            rdtype=dns.rdatatype.DNSKEY,
            flags=flags,
            protocol=protocol,
            algorithm=self.algorithm,
            key=self.encode_key_bytes(),
        )

    @classmethod
    @abstractmethod
    def from_dnskey(cls, key: DNSKEY) -> "GenericPublicKey":
        """Create public key from DNSKEY"""

    @classmethod
    @abstractmethod
    def from_pem(cls, public_pem: bytes) -> "GenericPublicKey":
        """Create public key from PEM-encoded SubjectPublicKeyInfo as specified
        in RFC 5280"""

    @abstractmethod
    def to_pem(self) -> bytes:
        """Return public-key as PEM-encoded SubjectPublicKeyInfo as specified
        in RFC 5280"""
class GenericPrivateKey(ABC):
    """Abstract interface for DNSSEC private (signing) keys.

    Concrete subclasses implement signing, PEM (PKCS#8) conversion, and
    production of the matching public key.
    """

    # Class implementing the matching public-key interface.
    public_cls: Type[GenericPublicKey]

    @abstractmethod
    def __init__(self, key: Any) -> None:
        pass

    @abstractmethod
    def sign(self, data: bytes, verify: bool = False) -> bytes:
        """Sign DNSSEC data"""

    @abstractmethod
    def public_key(self) -> "GenericPublicKey":
        """Return public key instance"""

    @classmethod
    @abstractmethod
    def from_pem(
        cls, private_pem: bytes, password: Optional[bytes] = None
    ) -> "GenericPrivateKey":
        """Create private key from PEM-encoded PKCS#8"""

    @abstractmethod
    def to_pem(self, password: Optional[bytes] = None) -> bytes:
        """Return private key as PEM-encoded PKCS#8"""

View file

@ -0,0 +1,68 @@
from typing import Any, Optional, Type
from cryptography.hazmat.primitives import serialization
from dns.dnssecalgs.base import GenericPrivateKey, GenericPublicKey
from dns.exception import AlgorithmKeyMismatch
class CryptographyPublicKey(GenericPublicKey):
    """Base for public keys backed by *pyca/cryptography* key objects.

    Subclasses set ``key_cls`` to the backend key type they wrap; this base
    validates wrapped keys and handles PEM (SubjectPublicKeyInfo) conversion.
    """

    key: Any = None
    key_cls: Any = None

    def __init__(self, key: Any) -> None:  # pylint: disable=super-init-not-called
        if self.key_cls is None:
            # Fix: this message previously said "private key class", copied
            # from CryptographyPrivateKey; this is the public-key base.
            raise TypeError("Undefined public key class")
        if not isinstance(  # pylint: disable=isinstance-second-argument-not-valid-type
            key, self.key_cls
        ):
            raise AlgorithmKeyMismatch
        self.key = key

    @classmethod
    def from_pem(cls, public_pem: bytes) -> "GenericPublicKey":
        """Create a public key from PEM-encoded SubjectPublicKeyInfo."""
        key = serialization.load_pem_public_key(public_pem)
        return cls(key=key)

    def to_pem(self) -> bytes:
        """Return the public key as PEM-encoded SubjectPublicKeyInfo."""
        return self.key.public_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PublicFormat.SubjectPublicKeyInfo,
        )
class CryptographyPrivateKey(GenericPrivateKey):
    """Base for private keys backed by *pyca/cryptography* key objects.

    Subclasses set ``key_cls`` to the backend key type they wrap; this base
    validates wrapped keys, derives the public key, and handles PEM (PKCS#8)
    conversion.
    """

    key: Any = None
    key_cls: Any = None
    public_cls: Type[CryptographyPublicKey]

    def __init__(self, key: Any) -> None:  # pylint: disable=super-init-not-called
        if self.key_cls is None:
            raise TypeError("Undefined private key class")
        # pylint: disable-next=isinstance-second-argument-not-valid-type
        if not isinstance(key, self.key_cls):
            raise AlgorithmKeyMismatch
        self.key = key

    def public_key(self) -> "CryptographyPublicKey":
        """Return the matching public key, wrapped in ``public_cls``."""
        return self.public_cls(key=self.key.public_key())

    @classmethod
    def from_pem(
        cls, private_pem: bytes, password: Optional[bytes] = None
    ) -> "GenericPrivateKey":
        """Create a private key from PEM-encoded PKCS#8 (optionally encrypted)."""
        return cls(
            key=serialization.load_pem_private_key(private_pem, password=password)
        )

    def to_pem(self, password: Optional[bytes] = None) -> bytes:
        """Return the private key as PEM-encoded PKCS#8, encrypted when a
        password is given."""
        if password:
            enc: serialization.KeySerializationEncryption = (
                serialization.BestAvailableEncryption(password)
            )
        else:
            enc = serialization.NoEncryption()
        return self.key.private_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PrivateFormat.PKCS8,
            encryption_algorithm=enc,
        )

101
lib/dns/dnssecalgs/dsa.py Normal file
View file

@ -0,0 +1,101 @@
import struct
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import dsa, utils
from dns.dnssecalgs.cryptography import CryptographyPrivateKey, CryptographyPublicKey
from dns.dnssectypes import Algorithm
from dns.rdtypes.ANY.DNSKEY import DNSKEY
class PublicDSA(CryptographyPublicKey):
    # DSA public key; wire format per RFC 2536.
    key: dsa.DSAPublicKey
    key_cls = dsa.DSAPublicKey
    algorithm = Algorithm.DSA
    chosen_hash = hashes.SHA1()

    def verify(self, signature: bytes, data: bytes) -> None:
        """Verify a signature per RFC 2536, section 3.

        The wire signature is T || r || s: the leading T octet is skipped,
        and the two fixed 20-octet integers are re-encoded as a DER
        signature for the backend.  Raises on mismatch.
        """
        sig_r = signature[1:21]
        sig_s = signature[21:]
        sig = utils.encode_dss_signature(
            int.from_bytes(sig_r, "big"), int.from_bytes(sig_s, "big")
        )
        self.key.verify(sig, data, self.chosen_hash)

    def encode_key_bytes(self) -> bytes:
        """Encode a public key per RFC 2536, section 2."""
        pn = self.key.public_numbers()
        # T parameter: (key size in octets - 64) / 8, limited to 0..8.
        dsa_t = (self.key.key_size // 8 - 64) // 8
        if dsa_t > 8:
            raise ValueError("unsupported DSA key size")
        octets = 64 + dsa_t * 8
        # Wire layout: T, Q (20 octets), then P, G, Y at `octets` each.
        res = struct.pack("!B", dsa_t)
        res += pn.parameter_numbers.q.to_bytes(20, "big")
        res += pn.parameter_numbers.p.to_bytes(octets, "big")
        res += pn.parameter_numbers.g.to_bytes(octets, "big")
        res += pn.y.to_bytes(octets, "big")
        return res

    @classmethod
    def from_dnskey(cls, key: DNSKEY) -> "PublicDSA":
        """Decode a DNSKEY's key field per RFC 2536, section 2."""
        cls._ensure_algorithm_key_combination(key)
        keyptr = key.key
        # First octet is the T parameter, which fixes the field widths below.
        (t,) = struct.unpack("!B", keyptr[0:1])
        keyptr = keyptr[1:]
        octets = 64 + t * 8
        dsa_q = keyptr[0:20]
        keyptr = keyptr[20:]
        dsa_p = keyptr[0:octets]
        keyptr = keyptr[octets:]
        dsa_g = keyptr[0:octets]
        keyptr = keyptr[octets:]
        dsa_y = keyptr[0:octets]
        return cls(
            key=dsa.DSAPublicNumbers(  # type: ignore
                int.from_bytes(dsa_y, "big"),
                dsa.DSAParameterNumbers(
                    int.from_bytes(dsa_p, "big"),
                    int.from_bytes(dsa_q, "big"),
                    int.from_bytes(dsa_g, "big"),
                ),
            ).public_key(default_backend()),
        )
class PrivateDSA(CryptographyPrivateKey):
    """DSA private key (RFC 2536)."""

    key: dsa.DSAPrivateKey
    key_cls = dsa.DSAPrivateKey
    public_cls = PublicDSA

    def sign(self, data: bytes, verify: bool = False) -> bytes:
        """Sign using a private key per RFC 2536, section 3."""
        pub = self.key.public_key()
        if pub.key_size > 1024:
            raise ValueError("DSA key size overflow")
        r, s = utils.decode_dss_signature(
            self.key.sign(data, self.public_cls.chosen_hash)
        )
        t = (pub.key_size // 8 - 64) // 8
        # r and s are always 20 octets each on the wire (160-bit q).
        wire = struct.pack("!B", t) + r.to_bytes(20, "big") + s.to_bytes(20, "big")
        if verify:
            self.public_key().verify(wire, data)
        return wire

    @classmethod
    def generate(cls, key_size: int) -> "PrivateDSA":
        """Generate a new DSA private key of *key_size* bits."""
        return cls(key=dsa.generate_private_key(key_size=key_size))
class PublicDSANSEC3SHA1(PublicDSA):
    # Same key handling as DSA; distinct algorithm number.
    algorithm = Algorithm.DSANSEC3SHA1


class PrivateDSANSEC3SHA1(PrivateDSA):
    public_cls = PublicDSANSEC3SHA1

View file

@ -0,0 +1,89 @@
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import ec, utils
from dns.dnssecalgs.cryptography import CryptographyPrivateKey, CryptographyPublicKey
from dns.dnssectypes import Algorithm
from dns.rdtypes.ANY.DNSKEY import DNSKEY
class PublicECDSA(CryptographyPublicKey):
    # ECDSA public key; wire format per RFC 6605.
    key: ec.EllipticCurvePublicKey
    key_cls = ec.EllipticCurvePublicKey
    # The following are set by the concrete per-curve subclasses.
    algorithm: Algorithm
    chosen_hash: hashes.HashAlgorithm
    curve: ec.EllipticCurve
    octets: int  # fixed width in bytes of each coordinate/signature integer

    def verify(self, signature: bytes, data: bytes) -> None:
        """Verify a signature per RFC 6605, section 4.

        The wire signature is r || s, each `octets` wide; it is re-encoded
        as a DER signature for the backend.  Raises on mismatch.
        """
        sig_r = signature[0 : self.octets]
        sig_s = signature[self.octets :]
        sig = utils.encode_dss_signature(
            int.from_bytes(sig_r, "big"), int.from_bytes(sig_s, "big")
        )
        self.key.verify(sig, data, ec.ECDSA(self.chosen_hash))

    def encode_key_bytes(self) -> bytes:
        """Encode a public key per RFC 6605, section 4."""
        pn = self.key.public_numbers()
        # Wire layout: x || y, each `octets` wide.
        return pn.x.to_bytes(self.octets, "big") + pn.y.to_bytes(self.octets, "big")

    @classmethod
    def from_dnskey(cls, key: DNSKEY) -> "PublicECDSA":
        """Decode a DNSKEY's key field (x || y) per RFC 6605, section 4."""
        cls._ensure_algorithm_key_combination(key)
        ecdsa_x = key.key[0 : cls.octets]
        ecdsa_y = key.key[cls.octets : cls.octets * 2]
        return cls(
            key=ec.EllipticCurvePublicNumbers(
                curve=cls.curve,
                x=int.from_bytes(ecdsa_x, "big"),
                y=int.from_bytes(ecdsa_y, "big"),
            ).public_key(default_backend()),
        )
class PrivateECDSA(CryptographyPrivateKey):
    """ECDSA private key (RFC 6605)."""

    key: ec.EllipticCurvePrivateKey
    key_cls = ec.EllipticCurvePrivateKey
    public_cls = PublicECDSA

    def sign(self, data: bytes, verify: bool = False) -> bytes:
        """Sign using a private key per RFC 6605, section 4."""
        width = self.public_cls.octets
        r, s = utils.decode_dss_signature(
            self.key.sign(data, ec.ECDSA(self.public_cls.chosen_hash))
        )
        # Wire signature is r || s, each padded to the curve's fixed width.
        wire = r.to_bytes(width, "big") + s.to_bytes(width, "big")
        if verify:
            self.public_key().verify(wire, data)
        return wire

    @classmethod
    def generate(cls) -> "PrivateECDSA":
        """Generate a new private key on this class's curve."""
        return cls(
            key=ec.generate_private_key(
                curve=cls.public_cls.curve, backend=default_backend()
            ),
        )
class PublicECDSAP256SHA256(PublicECDSA):
    # ECDSA over P-256 with SHA-256 (RFC 6605).
    algorithm = Algorithm.ECDSAP256SHA256
    chosen_hash = hashes.SHA256()
    curve = ec.SECP256R1()
    octets = 32


class PrivateECDSAP256SHA256(PrivateECDSA):
    public_cls = PublicECDSAP256SHA256


class PublicECDSAP384SHA384(PublicECDSA):
    # ECDSA over P-384 with SHA-384 (RFC 6605).
    algorithm = Algorithm.ECDSAP384SHA384
    chosen_hash = hashes.SHA384()
    curve = ec.SECP384R1()
    octets = 48


class PrivateECDSAP384SHA384(PrivateECDSA):
    public_cls = PublicECDSAP384SHA384

View file

@ -0,0 +1,65 @@
from typing import Type
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import ed448, ed25519
from dns.dnssecalgs.cryptography import CryptographyPrivateKey, CryptographyPublicKey
from dns.dnssectypes import Algorithm
from dns.rdtypes.ANY.DNSKEY import DNSKEY
class PublicEDDSA(CryptographyPublicKey):
    """EdDSA public key (RFC 8080); concrete key type set by subclasses."""

    def verify(self, signature: bytes, data: bytes) -> None:
        """Verify *signature* over *data*; raises on mismatch."""
        self.key.verify(signature, data)

    def encode_key_bytes(self) -> bytes:
        """Encode a public key per RFC 8080, section 3."""
        return self.key.public_bytes(
            encoding=serialization.Encoding.Raw,
            format=serialization.PublicFormat.Raw,
        )

    @classmethod
    def from_dnskey(cls, key: DNSKEY) -> "PublicEDDSA":
        """Build an instance from a DNSKEY's raw key material."""
        cls._ensure_algorithm_key_combination(key)
        return cls(key=cls.key_cls.from_public_bytes(key.key))
class PrivateEDDSA(CryptographyPrivateKey):
    """EdDSA private key (RFC 8080)."""

    public_cls: Type[PublicEDDSA]

    def sign(self, data: bytes, verify: bool = False) -> bytes:
        """Sign using a private key per RFC 8080, section 4."""
        sig = self.key.sign(data)
        if verify:
            self.public_key().verify(sig, data)
        return sig

    @classmethod
    def generate(cls) -> "PrivateEDDSA":
        """Generate a fresh key pair of this class's key type."""
        return cls(key=cls.key_cls.generate())
class PublicED25519(PublicEDDSA):
    # Ed25519 (RFC 8080).
    key: ed25519.Ed25519PublicKey
    key_cls = ed25519.Ed25519PublicKey
    algorithm = Algorithm.ED25519


class PrivateED25519(PrivateEDDSA):
    key: ed25519.Ed25519PrivateKey
    key_cls = ed25519.Ed25519PrivateKey
    public_cls = PublicED25519


class PublicED448(PublicEDDSA):
    # Ed448 (RFC 8080).
    key: ed448.Ed448PublicKey
    key_cls = ed448.Ed448PublicKey
    algorithm = Algorithm.ED448


class PrivateED448(PrivateEDDSA):
    key: ed448.Ed448PrivateKey
    key_cls = ed448.Ed448PrivateKey
    public_cls = PublicED448

119
lib/dns/dnssecalgs/rsa.py Normal file
View file

@ -0,0 +1,119 @@
import math
import struct
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import padding, rsa
from dns.dnssecalgs.cryptography import CryptographyPrivateKey, CryptographyPublicKey
from dns.dnssectypes import Algorithm
from dns.rdtypes.ANY.DNSKEY import DNSKEY
class PublicRSA(CryptographyPublicKey):
    # RSA public key; wire format per RFC 3110.
    key: rsa.RSAPublicKey
    key_cls = rsa.RSAPublicKey
    # Set by the concrete per-hash subclasses.
    algorithm: Algorithm
    chosen_hash: hashes.HashAlgorithm

    def verify(self, signature: bytes, data: bytes) -> None:
        """Verify a PKCS#1 v1.5 signature; raises on mismatch."""
        self.key.verify(signature, data, padding.PKCS1v15(), self.chosen_hash)

    def encode_key_bytes(self) -> bytes:
        """Encode a public key per RFC 3110, section 2."""
        pn = self.key.public_numbers()
        _exp_len = math.ceil(int.bit_length(pn.e) / 8)
        exp = int.to_bytes(pn.e, length=_exp_len, byteorder="big")
        # Exponent length field: one octet, or a zero octet plus two
        # length octets when the exponent is longer than 255 octets.
        if _exp_len > 255:
            exp_header = b"\0" + struct.pack("!H", _exp_len)
        else:
            exp_header = struct.pack("!B", _exp_len)
        if pn.n.bit_length() < 512 or pn.n.bit_length() > 4096:
            raise ValueError("unsupported RSA key length")
        return exp_header + exp + pn.n.to_bytes((pn.n.bit_length() + 7) // 8, "big")

    @classmethod
    def from_dnskey(cls, key: DNSKEY) -> "PublicRSA":
        """Decode a DNSKEY's key field per RFC 3110, section 2."""
        cls._ensure_algorithm_key_combination(key)
        keyptr = key.key
        (bytes_,) = struct.unpack("!B", keyptr[0:1])
        keyptr = keyptr[1:]
        if bytes_ == 0:
            # A zero first octet means a two-octet exponent length follows.
            (bytes_,) = struct.unpack("!H", keyptr[0:2])
            keyptr = keyptr[2:]
        rsa_e = keyptr[0:bytes_]
        rsa_n = keyptr[bytes_:]
        return cls(
            key=rsa.RSAPublicNumbers(
                int.from_bytes(rsa_e, "big"), int.from_bytes(rsa_n, "big")
            ).public_key(default_backend())
        )
class PrivateRSA(CryptographyPrivateKey):
    """RSA private key (RFC 3110)."""

    key: rsa.RSAPrivateKey
    key_cls = rsa.RSAPrivateKey
    public_cls = PublicRSA
    default_public_exponent = 65537

    def sign(self, data: bytes, verify: bool = False) -> bytes:
        """Sign using a private key per RFC 3110, section 3."""
        sig = self.key.sign(data, padding.PKCS1v15(), self.public_cls.chosen_hash)
        if verify:
            self.public_key().verify(sig, data)
        return sig

    @classmethod
    def generate(cls, key_size: int) -> "PrivateRSA":
        """Generate a new RSA key of *key_size* bits with the default exponent."""
        return cls(
            key=rsa.generate_private_key(
                public_exponent=cls.default_public_exponent,
                key_size=key_size,
                backend=default_backend(),
            )
        )
class PublicRSAMD5(PublicRSA):
    # RSA with MD5.
    algorithm = Algorithm.RSAMD5
    chosen_hash = hashes.MD5()


class PrivateRSAMD5(PrivateRSA):
    public_cls = PublicRSAMD5


class PublicRSASHA1(PublicRSA):
    algorithm = Algorithm.RSASHA1
    chosen_hash = hashes.SHA1()


class PrivateRSASHA1(PrivateRSA):
    public_cls = PublicRSASHA1


class PublicRSASHA1NSEC3SHA1(PublicRSA):
    # Same hash as RSASHA1; distinct algorithm number.
    algorithm = Algorithm.RSASHA1NSEC3SHA1
    chosen_hash = hashes.SHA1()


class PrivateRSASHA1NSEC3SHA1(PrivateRSA):
    public_cls = PublicRSASHA1NSEC3SHA1


class PublicRSASHA256(PublicRSA):
    algorithm = Algorithm.RSASHA256
    chosen_hash = hashes.SHA256()


class PrivateRSASHA256(PrivateRSA):
    public_cls = PublicRSASHA256


class PublicRSASHA512(PublicRSA):
    algorithm = Algorithm.RSASHA512
    chosen_hash = hashes.SHA512()


class PrivateRSASHA512(PrivateRSA):
    public_cls = PublicRSASHA512

View file

@ -17,11 +17,10 @@
"""EDNS Options""" """EDNS Options"""
from typing import Any, Dict, Optional, Union
import math import math
import socket import socket
import struct import struct
from typing import Any, Dict, Optional, Union
import dns.enum import dns.enum
import dns.inet import dns.inet
@ -380,7 +379,7 @@ class EDEOption(Option): # lgtm[py/missing-equals]
def from_wire_parser( def from_wire_parser(
cls, otype: Union[OptionType, str], parser: "dns.wire.Parser" cls, otype: Union[OptionType, str], parser: "dns.wire.Parser"
) -> Option: ) -> Option:
the_code = EDECode.make(parser.get_uint16()) code = EDECode.make(parser.get_uint16())
text = parser.get_remaining() text = parser.get_remaining()
if text: if text:
@ -390,7 +389,7 @@ class EDEOption(Option): # lgtm[py/missing-equals]
else: else:
btext = None btext = None
return cls(the_code, btext) return cls(code, btext)
_type_to_class: Dict[OptionType, Any] = { _type_to_class: Dict[OptionType, Any] = {
@ -424,8 +423,8 @@ def option_from_wire_parser(
Returns an instance of a subclass of ``dns.edns.Option``. Returns an instance of a subclass of ``dns.edns.Option``.
""" """
the_otype = OptionType.make(otype) otype = OptionType.make(otype)
cls = get_option_class(the_otype) cls = get_option_class(otype)
return cls.from_wire_parser(otype, parser) return cls.from_wire_parser(otype, parser)

View file

@ -15,17 +15,15 @@
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
from typing import Any, Optional
import os
import hashlib import hashlib
import os
import random import random
import threading import threading
import time import time
from typing import Any, Optional
class EntropyPool: class EntropyPool:
# This is an entropy pool for Python implementations that do not # This is an entropy pool for Python implementations that do not
# have a working SystemRandom. I'm not sure there are any, but # have a working SystemRandom. I'm not sure there are any, but
# leaving this code doesn't hurt anything as the library code # leaving this code doesn't hurt anything as the library code

View file

@ -16,18 +16,31 @@
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import enum import enum
from typing import Type, TypeVar, Union
TIntEnum = TypeVar("TIntEnum", bound="IntEnum")
class IntEnum(enum.IntEnum): class IntEnum(enum.IntEnum):
@classmethod @classmethod
def _check_value(cls, value): def _missing_(cls, value):
max = cls._maximum() cls._check_value(value)
if value < 0 or value > max: val = int.__new__(cls, value)
name = cls._short_name() val._name_ = cls._extra_to_text(value, None) or f"{cls._prefix()}{value}"
raise ValueError(f"{name} must be between >= 0 and <= {max}") val._value_ = value
return val
@classmethod @classmethod
def from_text(cls, text): def _check_value(cls, value):
max = cls._maximum()
if not isinstance(value, int):
raise TypeError
if value < 0 or value > max:
name = cls._short_name()
raise ValueError(f"{name} must be an int between >= 0 and <= {max}")
@classmethod
def from_text(cls: Type[TIntEnum], text: str) -> TIntEnum:
text = text.upper() text = text.upper()
try: try:
return cls[text] return cls[text]
@ -47,7 +60,7 @@ class IntEnum(enum.IntEnum):
raise cls._unknown_exception_class() raise cls._unknown_exception_class()
@classmethod @classmethod
def to_text(cls, value): def to_text(cls: Type[TIntEnum], value: int) -> str:
cls._check_value(value) cls._check_value(value)
try: try:
text = cls(value).name text = cls(value).name
@ -59,7 +72,7 @@ class IntEnum(enum.IntEnum):
return text return text
@classmethod @classmethod
def make(cls, value): def make(cls: Type[TIntEnum], value: Union[int, str]) -> TIntEnum:
"""Convert text or a value into an enumerated type, if possible. """Convert text or a value into an enumerated type, if possible.
*value*, the ``int`` or ``str`` to convert. *value*, the ``int`` or ``str`` to convert.
@ -76,10 +89,7 @@ class IntEnum(enum.IntEnum):
if isinstance(value, str): if isinstance(value, str):
return cls.from_text(value) return cls.from_text(value)
cls._check_value(value) cls._check_value(value)
try: return cls(value)
return cls(value)
except ValueError:
return value
@classmethod @classmethod
def _maximum(cls): def _maximum(cls):

View file

@ -140,6 +140,22 @@ class Timeout(DNSException):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
class UnsupportedAlgorithm(DNSException):
"""The DNSSEC algorithm is not supported."""
class AlgorithmKeyMismatch(UnsupportedAlgorithm):
"""The DNSSEC algorithm is not supported for the given key type."""
class ValidationFailure(DNSException):
"""The DNSSEC signature is invalid."""
class DeniedByPolicy(DNSException):
"""Denied by DNSSEC policy."""
class ExceptionWrapper: class ExceptionWrapper:
def __init__(self, exception_class): def __init__(self, exception_class):
self.exception_class = exception_class self.exception_class = exception_class

View file

@ -17,9 +17,8 @@
"""DNS Message Flags.""" """DNS Message Flags."""
from typing import Any
import enum import enum
from typing import Any
# Standard DNS flags # Standard DNS flags

View file

@ -1,8 +1,7 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
from typing import Any
import collections.abc import collections.abc
from typing import Any
from dns._immutable_ctx import immutable from dns._immutable_ctx import immutable

View file

@ -17,14 +17,12 @@
"""Generic Internet address helper functions.""" """Generic Internet address helper functions."""
from typing import Any, Optional, Tuple
import socket import socket
from typing import Any, Optional, Tuple
import dns.ipv4 import dns.ipv4
import dns.ipv6 import dns.ipv6
# We assume that AF_INET and AF_INET6 are always defined. We keep # We assume that AF_INET and AF_INET6 are always defined. We keep
# these here for the benefit of any old code (unlikely though that # these here for the benefit of any old code (unlikely though that
# is!). # is!).
@ -171,3 +169,12 @@ def low_level_address_tuple(
return tup return tup
else: else:
raise NotImplementedError(f"unknown address family {af}") raise NotImplementedError(f"unknown address family {af}")
def any_for_af(af):
"""Return the 'any' address for the specified address family."""
if af == socket.AF_INET:
return "0.0.0.0"
elif af == socket.AF_INET6:
return "::"
raise NotImplementedError(f"unknown address family {af}")

View file

@ -17,9 +17,8 @@
"""IPv4 helper functions.""" """IPv4 helper functions."""
from typing import Union
import struct import struct
from typing import Union
import dns.exception import dns.exception

View file

@ -17,10 +17,9 @@
"""IPv6 helper functions.""" """IPv6 helper functions."""
from typing import List, Union
import re
import binascii import binascii
import re
from typing import List, Union
import dns.exception import dns.exception
import dns.ipv4 import dns.ipv4

View file

@ -17,30 +17,29 @@
"""DNS Messages""" """DNS Messages"""
from typing import Any, Dict, List, Optional, Tuple, Union
import contextlib import contextlib
import io import io
import time import time
from typing import Any, Dict, List, Optional, Tuple, Union
import dns.wire
import dns.edns import dns.edns
import dns.entropy
import dns.enum import dns.enum
import dns.exception import dns.exception
import dns.flags import dns.flags
import dns.name import dns.name
import dns.opcode import dns.opcode
import dns.entropy
import dns.rcode import dns.rcode
import dns.rdata import dns.rdata
import dns.rdataclass import dns.rdataclass
import dns.rdatatype import dns.rdatatype
import dns.rrset
import dns.renderer
import dns.ttl
import dns.tsig
import dns.rdtypes.ANY.OPT import dns.rdtypes.ANY.OPT
import dns.rdtypes.ANY.TSIG import dns.rdtypes.ANY.TSIG
import dns.renderer
import dns.rrset
import dns.tsig
import dns.ttl
import dns.wire
class ShortHeader(dns.exception.FormError): class ShortHeader(dns.exception.FormError):
@ -135,7 +134,7 @@ IndexKeyType = Tuple[
Optional[dns.rdataclass.RdataClass], Optional[dns.rdataclass.RdataClass],
] ]
IndexType = Dict[IndexKeyType, dns.rrset.RRset] IndexType = Dict[IndexKeyType, dns.rrset.RRset]
SectionType = Union[int, List[dns.rrset.RRset]] SectionType = Union[int, str, List[dns.rrset.RRset]]
class Message: class Message:
@ -231,7 +230,7 @@ class Message:
s.write("payload %d\n" % self.payload) s.write("payload %d\n" % self.payload)
for opt in self.options: for opt in self.options:
s.write("option %s\n" % opt.to_text()) s.write("option %s\n" % opt.to_text())
for (name, which) in self._section_enum.__members__.items(): for name, which in self._section_enum.__members__.items():
s.write(f";{name}\n") s.write(f";{name}\n")
for rrset in self.section_from_number(which): for rrset in self.section_from_number(which):
s.write(rrset.to_text(origin, relativize, **kw)) s.write(rrset.to_text(origin, relativize, **kw))
@ -348,27 +347,29 @@ class Message:
deleting: Optional[dns.rdataclass.RdataClass] = None, deleting: Optional[dns.rdataclass.RdataClass] = None,
create: bool = False, create: bool = False,
force_unique: bool = False, force_unique: bool = False,
idna_codec: Optional[dns.name.IDNACodec] = None,
) -> dns.rrset.RRset: ) -> dns.rrset.RRset:
"""Find the RRset with the given attributes in the specified section. """Find the RRset with the given attributes in the specified section.
*section*, an ``int`` section number, or one of the section *section*, an ``int`` section number, a ``str`` section name, or one of
attributes of this message. This specifies the the section attributes of this message. This specifies the
the section of the message to search. For example:: the section of the message to search. For example::
my_message.find_rrset(my_message.answer, name, rdclass, rdtype) my_message.find_rrset(my_message.answer, name, rdclass, rdtype)
my_message.find_rrset(dns.message.ANSWER, name, rdclass, rdtype) my_message.find_rrset(dns.message.ANSWER, name, rdclass, rdtype)
my_message.find_rrset("ANSWER", name, rdclass, rdtype)
*name*, a ``dns.name.Name``, the name of the RRset. *name*, a ``dns.name.Name`` or ``str``, the name of the RRset.
*rdclass*, an ``int``, the class of the RRset. *rdclass*, an ``int`` or ``str``, the class of the RRset.
*rdtype*, an ``int``, the type of the RRset. *rdtype*, an ``int`` or ``str``, the type of the RRset.
*covers*, an ``int`` or ``None``, the covers value of the RRset. *covers*, an ``int`` or ``str``, the covers value of the RRset.
The default is ``None``. The default is ``dns.rdatatype.NONE``.
*deleting*, an ``int`` or ``None``, the deleting value of the RRset. *deleting*, an ``int``, ``str``, or ``None``, the deleting value of the
The default is ``None``. RRset. The default is ``None``.
*create*, a ``bool``. If ``True``, create the RRset if it is not found. *create*, a ``bool``. If ``True``, create the RRset if it is not found.
The created RRset is appended to *section*. The created RRset is appended to *section*.
@ -378,6 +379,10 @@ class Message:
already. The default is ``False``. This is useful when creating already. The default is ``False``. This is useful when creating
DDNS Update messages, as order matters for them. DDNS Update messages, as order matters for them.
*idna_codec*, a ``dns.name.IDNACodec``, specifies the IDNA
encoder/decoder. If ``None``, the default IDNA 2003 encoder/decoder
is used.
Raises ``KeyError`` if the RRset was not found and create was Raises ``KeyError`` if the RRset was not found and create was
``False``. ``False``.
@ -386,10 +391,19 @@ class Message:
if isinstance(section, int): if isinstance(section, int):
section_number = section section_number = section
the_section = self.section_from_number(section_number) section = self.section_from_number(section_number)
elif isinstance(section, str):
section_number = MessageSection.from_text(section)
section = self.section_from_number(section_number)
else: else:
section_number = self.section_number(section) section_number = self.section_number(section)
the_section = section if isinstance(name, str):
name = dns.name.from_text(name, idna_codec=idna_codec)
rdtype = dns.rdatatype.RdataType.make(rdtype)
rdclass = dns.rdataclass.RdataClass.make(rdclass)
covers = dns.rdatatype.RdataType.make(covers)
if deleting is not None:
deleting = dns.rdataclass.RdataClass.make(deleting)
key = (section_number, name, rdclass, rdtype, covers, deleting) key = (section_number, name, rdclass, rdtype, covers, deleting)
if not force_unique: if not force_unique:
if self.index is not None: if self.index is not None:
@ -397,13 +411,13 @@ class Message:
if rrset is not None: if rrset is not None:
return rrset return rrset
else: else:
for rrset in the_section: for rrset in section:
if rrset.full_match(name, rdclass, rdtype, covers, deleting): if rrset.full_match(name, rdclass, rdtype, covers, deleting):
return rrset return rrset
if not create: if not create:
raise KeyError raise KeyError
rrset = dns.rrset.RRset(name, rdclass, rdtype, covers, deleting) rrset = dns.rrset.RRset(name, rdclass, rdtype, covers, deleting)
the_section.append(rrset) section.append(rrset)
if self.index is not None: if self.index is not None:
self.index[key] = rrset self.index[key] = rrset
return rrset return rrset
@ -418,29 +432,31 @@ class Message:
deleting: Optional[dns.rdataclass.RdataClass] = None, deleting: Optional[dns.rdataclass.RdataClass] = None,
create: bool = False, create: bool = False,
force_unique: bool = False, force_unique: bool = False,
idna_codec: Optional[dns.name.IDNACodec] = None,
) -> Optional[dns.rrset.RRset]: ) -> Optional[dns.rrset.RRset]:
"""Get the RRset with the given attributes in the specified section. """Get the RRset with the given attributes in the specified section.
If the RRset is not found, None is returned. If the RRset is not found, None is returned.
*section*, an ``int`` section number, or one of the section *section*, an ``int`` section number, a ``str`` section name, or one of
attributes of this message. This specifies the the section attributes of this message. This specifies the
the section of the message to search. For example:: the section of the message to search. For example::
my_message.get_rrset(my_message.answer, name, rdclass, rdtype) my_message.get_rrset(my_message.answer, name, rdclass, rdtype)
my_message.get_rrset(dns.message.ANSWER, name, rdclass, rdtype) my_message.get_rrset(dns.message.ANSWER, name, rdclass, rdtype)
my_message.get_rrset("ANSWER", name, rdclass, rdtype)
*name*, a ``dns.name.Name``, the name of the RRset. *name*, a ``dns.name.Name`` or ``str``, the name of the RRset.
*rdclass*, an ``int``, the class of the RRset. *rdclass*, an ``int`` or ``str``, the class of the RRset.
*rdtype*, an ``int``, the type of the RRset. *rdtype*, an ``int`` or ``str``, the type of the RRset.
*covers*, an ``int`` or ``None``, the covers value of the RRset. *covers*, an ``int`` or ``str``, the covers value of the RRset.
The default is ``None``. The default is ``dns.rdatatype.NONE``.
*deleting*, an ``int`` or ``None``, the deleting value of the RRset. *deleting*, an ``int``, ``str``, or ``None``, the deleting value of the
The default is ``None``. RRset. The default is ``None``.
*create*, a ``bool``. If ``True``, create the RRset if it is not found. *create*, a ``bool``. If ``True``, create the RRset if it is not found.
The created RRset is appended to *section*. The created RRset is appended to *section*.
@ -450,12 +466,24 @@ class Message:
already. The default is ``False``. This is useful when creating already. The default is ``False``. This is useful when creating
DDNS Update messages, as order matters for them. DDNS Update messages, as order matters for them.
*idna_codec*, a ``dns.name.IDNACodec``, specifies the IDNA
encoder/decoder. If ``None``, the default IDNA 2003 encoder/decoder
is used.
Returns a ``dns.rrset.RRset object`` or ``None``. Returns a ``dns.rrset.RRset object`` or ``None``.
""" """
try: try:
rrset = self.find_rrset( rrset = self.find_rrset(
section, name, rdclass, rdtype, covers, deleting, create, force_unique section,
name,
rdclass,
rdtype,
covers,
deleting,
create,
force_unique,
idna_codec,
) )
except KeyError: except KeyError:
rrset = None rrset = None
@ -1708,13 +1736,11 @@ def make_query(
if isinstance(qname, str): if isinstance(qname, str):
qname = dns.name.from_text(qname, idna_codec=idna_codec) qname = dns.name.from_text(qname, idna_codec=idna_codec)
the_rdtype = dns.rdatatype.RdataType.make(rdtype) rdtype = dns.rdatatype.RdataType.make(rdtype)
the_rdclass = dns.rdataclass.RdataClass.make(rdclass) rdclass = dns.rdataclass.RdataClass.make(rdclass)
m = QueryMessage(id=id) m = QueryMessage(id=id)
m.flags = dns.flags.Flag(flags) m.flags = dns.flags.Flag(flags)
m.find_rrset( m.find_rrset(m.question, qname, rdclass, rdtype, create=True, force_unique=True)
m.question, qname, the_rdclass, the_rdtype, create=True, force_unique=True
)
# only pass keywords on to use_edns if they have been set to a # only pass keywords on to use_edns if they have been set to a
# non-None value. Setting a field will turn EDNS on if it hasn't # non-None value. Setting a field will turn EDNS on if it hasn't
# been configured. # been configured.

View file

@ -18,12 +18,10 @@
"""DNS Names. """DNS Names.
""" """
from typing import Any, Dict, Iterable, Optional, Tuple, Union
import copy import copy
import struct
import encodings.idna # type: ignore import encodings.idna # type: ignore
import struct
from typing import Any, Dict, Iterable, Optional, Tuple, Union
try: try:
import idna # type: ignore import idna # type: ignore
@ -33,10 +31,9 @@ except ImportError: # pragma: no cover
have_idna_2008 = False have_idna_2008 = False
import dns.enum import dns.enum
import dns.wire
import dns.exception import dns.exception
import dns.immutable import dns.immutable
import dns.wire
CompressType = Dict["Name", int] CompressType = Dict["Name", int]

329
lib/dns/nameserver.py Normal file
View file

@ -0,0 +1,329 @@
from typing import Optional, Union
from urllib.parse import urlparse
import dns.asyncbackend
import dns.asyncquery
import dns.inet
import dns.message
import dns.query
class Nameserver:
    """Abstract description of a single upstream DNS server.

    Concrete subclasses bind one transport (Do53, DoH, DoT, DoQ) and know
    how to send a query to that server both synchronously and
    asynchronously.  Every method here is abstract and raises
    ``NotImplementedError``.
    """

    def __init__(self):
        pass

    def kind(self) -> str:
        """Return a short transport label, e.g. ``"Do53"`` or ``"DoH"``."""
        raise NotImplementedError

    def is_always_max_size(self) -> bool:
        """Return ``True`` if this transport always permits max-size responses."""
        raise NotImplementedError

    def answer_nameserver(self) -> str:
        """Return the nameserver value to attach to answers (address or URL)."""
        raise NotImplementedError

    def answer_port(self) -> int:
        """Return the port value to attach to answers."""
        raise NotImplementedError

    def __str__(self):
        raise NotImplementedError

    def query(
        self,
        request: dns.message.QueryMessage,
        timeout: float,
        source: Optional[str],
        source_port: int,
        max_size: bool,
        one_rr_per_rrset: bool = False,
        ignore_trailing: bool = False,
    ) -> dns.message.Message:
        """Synchronously send *request* and return the response message."""
        raise NotImplementedError

    async def async_query(
        self,
        request: dns.message.QueryMessage,
        timeout: float,
        source: Optional[str],
        source_port: int,
        max_size: bool,
        backend: dns.asyncbackend.Backend,
        one_rr_per_rrset: bool = False,
        ignore_trailing: bool = False,
    ) -> dns.message.Message:
        """Asynchronously send *request* and return the response message."""
        raise NotImplementedError
class AddressAndPortNameserver(Nameserver):
    """Base class for nameservers reached via a plain (address, port) pair."""

    def __init__(self, address: str, port: int):
        super().__init__()
        self.address = address  # server IP address
        self.port = port  # server port

    def kind(self) -> str:
        raise NotImplementedError

    def is_always_max_size(self) -> bool:
        # Default: responses are not guaranteed to be max-size on these
        # transports; subclasses may override.
        return False

    def __str__(self):
        # e.g. "Do53:192.0.2.1@53"
        return "{}:{}@{}".format(self.kind(), self.address, self.port)

    def answer_nameserver(self) -> str:
        return self.address

    def answer_port(self) -> int:
        return self.port
class Do53Nameserver(AddressAndPortNameserver):
    """Classic DNS over UDP/TCP, defaulting to port 53."""

    def __init__(self, address: str, port: int = 53):
        super().__init__(address, port)

    def kind(self):
        return "Do53"

    def query(
        self,
        request: dns.message.QueryMessage,
        timeout: float,
        source: Optional[str],
        source_port: int,
        max_size: bool,
        one_rr_per_rrset: bool = False,
        ignore_trailing: bool = False,
    ) -> dns.message.Message:
        """Synchronously send *request* over TCP when *max_size* is set,
        otherwise over UDP.

        UDP queries raise on truncation so the caller can retry with
        *max_size* set.
        """
        shared = dict(
            timeout=timeout,
            port=self.port,
            source=source,
            source_port=source_port,
            one_rr_per_rrset=one_rr_per_rrset,
            ignore_trailing=ignore_trailing,
        )
        if max_size:
            return dns.query.tcp(request, self.address, **shared)
        return dns.query.udp(
            request, self.address, raise_on_truncation=True, **shared
        )

    async def async_query(
        self,
        request: dns.message.QueryMessage,
        timeout: float,
        source: Optional[str],
        source_port: int,
        max_size: bool,
        backend: dns.asyncbackend.Backend,
        one_rr_per_rrset: bool = False,
        ignore_trailing: bool = False,
    ) -> dns.message.Message:
        """Async variant of :meth:`query`, using *backend* for sockets."""
        shared = dict(
            timeout=timeout,
            port=self.port,
            source=source,
            source_port=source_port,
            backend=backend,
            one_rr_per_rrset=one_rr_per_rrset,
            ignore_trailing=ignore_trailing,
        )
        if max_size:
            return await dns.asyncquery.tcp(request, self.address, **shared)
        return await dns.asyncquery.udp(
            request, self.address, raise_on_truncation=True, **shared
        )
class DoHNameserver(Nameserver):
    """DNS over HTTPS upstream server, addressed by URL."""

    def __init__(self, url: str, bootstrap_address: Optional[str] = None):
        super().__init__()
        self.url = url  # full https:// URL of the DoH endpoint
        self.bootstrap_address = bootstrap_address  # optional pre-resolved IP

    def kind(self):
        return "DoH"

    def is_always_max_size(self) -> bool:
        # DoH responses are not subject to UDP-style size limits.
        return True

    def __str__(self):
        return self.url

    def answer_nameserver(self) -> str:
        return self.url

    def answer_port(self) -> int:
        # Use the port embedded in the URL, defaulting to HTTPS's 443.
        explicit = urlparse(self.url).port
        return 443 if explicit is None else explicit

    def query(
        self,
        request: dns.message.QueryMessage,
        timeout: float,
        source: Optional[str],
        source_port: int,
        max_size: bool = False,
        one_rr_per_rrset: bool = False,
        ignore_trailing: bool = False,
    ) -> dns.message.Message:
        """Synchronously send *request* to the DoH endpoint.

        NOTE(review): *source*, *source_port*, and *max_size* are accepted
        for interface compatibility but are not forwarded to
        ``dns.query.https`` — confirm this is intentional.
        """
        return dns.query.https(
            request,
            self.url,
            timeout=timeout,
            bootstrap_address=self.bootstrap_address,
            one_rr_per_rrset=one_rr_per_rrset,
            ignore_trailing=ignore_trailing,
        )

    async def async_query(
        self,
        request: dns.message.QueryMessage,
        timeout: float,
        source: Optional[str],
        source_port: int,
        max_size: bool,
        backend: dns.asyncbackend.Backend,
        one_rr_per_rrset: bool = False,
        ignore_trailing: bool = False,
    ) -> dns.message.Message:
        """Async variant of :meth:`query`.

        NOTE(review): unlike the sync path, ``bootstrap_address`` is not
        passed to ``dns.asyncquery.https`` here — confirm whether that is
        intentional.
        """
        return await dns.asyncquery.https(
            request,
            self.url,
            timeout=timeout,
            one_rr_per_rrset=one_rr_per_rrset,
            ignore_trailing=ignore_trailing,
        )
class DoTNameserver(AddressAndPortNameserver):
    """DNS over TLS upstream server, defaulting to port 853."""

    def __init__(self, address: str, port: int = 853, hostname: Optional[str] = None):
        super().__init__(address, port)
        self.hostname = hostname  # passed as server_hostname to the TLS helpers

    def kind(self):
        return "DoT"

    def _tls_kwargs(self, timeout, one_rr_per_rrset, ignore_trailing):
        # Keyword arguments shared by the sync and async TLS query helpers.
        return dict(
            port=self.port,
            timeout=timeout,
            one_rr_per_rrset=one_rr_per_rrset,
            ignore_trailing=ignore_trailing,
            server_hostname=self.hostname,
        )

    def query(
        self,
        request: dns.message.QueryMessage,
        timeout: float,
        source: Optional[str],
        source_port: int,
        max_size: bool = False,
        one_rr_per_rrset: bool = False,
        ignore_trailing: bool = False,
    ) -> dns.message.Message:
        """Synchronously send *request* over TLS and return the response.

        NOTE(review): *source*/*source_port*/*max_size* are accepted for
        interface compatibility but not forwarded — confirm intended.
        """
        return dns.query.tls(
            request,
            self.address,
            **self._tls_kwargs(timeout, one_rr_per_rrset, ignore_trailing),
        )

    async def async_query(
        self,
        request: dns.message.QueryMessage,
        timeout: float,
        source: Optional[str],
        source_port: int,
        max_size: bool,
        backend: dns.asyncbackend.Backend,
        one_rr_per_rrset: bool = False,
        ignore_trailing: bool = False,
    ) -> dns.message.Message:
        """Async variant of :meth:`query`; *backend* is accepted but unused here."""
        return await dns.asyncquery.tls(
            request,
            self.address,
            **self._tls_kwargs(timeout, one_rr_per_rrset, ignore_trailing),
        )
class DoQNameserver(AddressAndPortNameserver):
    """DNS over dedicated QUIC connections, defaulting to port 853."""

    def __init__(
        self,
        address: str,
        port: int = 853,
        verify: Union[bool, str] = True,
        server_hostname: Optional[str] = None,
    ):
        super().__init__(address, port)
        self.verify = verify  # bool, or a path to a certificate file/directory
        self.server_hostname = server_hostname  # name used for certificate checks

    def kind(self):
        return "DoQ"

    def _quic_kwargs(self, timeout, one_rr_per_rrset, ignore_trailing):
        # Keyword arguments shared by the sync and async QUIC query helpers.
        return dict(
            port=self.port,
            timeout=timeout,
            one_rr_per_rrset=one_rr_per_rrset,
            ignore_trailing=ignore_trailing,
            verify=self.verify,
            server_hostname=self.server_hostname,
        )

    def query(
        self,
        request: dns.message.QueryMessage,
        timeout: float,
        source: Optional[str],
        source_port: int,
        max_size: bool = False,
        one_rr_per_rrset: bool = False,
        ignore_trailing: bool = False,
    ) -> dns.message.Message:
        """Synchronously send *request* over QUIC and return the response.

        NOTE(review): *source*/*source_port*/*max_size* are accepted for
        interface compatibility but not forwarded — confirm intended.
        """
        return dns.query.quic(
            request,
            self.address,
            **self._quic_kwargs(timeout, one_rr_per_rrset, ignore_trailing),
        )

    async def async_query(
        self,
        request: dns.message.QueryMessage,
        timeout: float,
        source: Optional[str],
        source_port: int,
        max_size: bool,
        backend: dns.asyncbackend.Backend,
        one_rr_per_rrset: bool = False,
        ignore_trailing: bool = False,
    ) -> dns.message.Message:
        """Async variant of :meth:`query`; *backend* is accepted but unused here."""
        return await dns.asyncquery.quic(
            request,
            self.address,
            **self._quic_kwargs(timeout, one_rr_per_rrset, ignore_trailing),
        )

View file

@ -17,19 +17,17 @@
"""DNS nodes. A node is a set of rdatasets.""" """DNS nodes. A node is a set of rdatasets."""
from typing import Any, Dict, Optional
import enum import enum
import io import io
from typing import Any, Dict, Optional
import dns.immutable import dns.immutable
import dns.name import dns.name
import dns.rdataclass import dns.rdataclass
import dns.rdataset import dns.rdataset
import dns.rdatatype import dns.rdatatype
import dns.rrset
import dns.renderer import dns.renderer
import dns.rrset
_cname_types = { _cname_types = {
dns.rdatatype.CNAME, dns.rdatatype.CNAME,

View file

@ -17,8 +17,6 @@
"""Talk to a DNS server.""" """Talk to a DNS server."""
from typing import Any, Dict, Optional, Tuple, Union
import base64 import base64
import contextlib import contextlib
import enum import enum
@ -28,12 +26,12 @@ import selectors
import socket import socket
import struct import struct
import time import time
import urllib.parse from typing import Any, Dict, Optional, Tuple, Union
import dns.exception import dns.exception
import dns.inet import dns.inet
import dns.name
import dns.message import dns.message
import dns.name
import dns.quic import dns.quic
import dns.rcode import dns.rcode
import dns.rdataclass import dns.rdataclass
@ -43,20 +41,32 @@ import dns.transaction
import dns.tsig import dns.tsig
import dns.xfr import dns.xfr
try:
import requests
from requests_toolbelt.adapters.source import SourceAddressAdapter
from requests_toolbelt.adapters.host_header_ssl import HostHeaderSSLAdapter
_have_requests = True def _remaining(expiration):
except ImportError: # pragma: no cover if expiration is None:
_have_requests = False return None
timeout = expiration - time.time()
if timeout <= 0.0:
raise dns.exception.Timeout
return timeout
def _expiration_for_this_attempt(timeout, expiration):
if expiration is None:
return None
return min(time.time() + timeout, expiration)
_have_httpx = False _have_httpx = False
_have_http2 = False _have_http2 = False
try: try:
import httpcore
import httpcore._backends.sync
import httpx import httpx
_CoreNetworkBackend = httpcore.NetworkBackend
_CoreSyncStream = httpcore._backends.sync.SyncStream
_have_httpx = True _have_httpx = True
try: try:
# See if http2 support is available. # See if http2 support is available.
@ -64,10 +74,87 @@ try:
_have_http2 = True _have_http2 = True
except Exception: except Exception:
pass pass
except ImportError: # pragma: no cover
pass
have_doh = _have_requests or _have_httpx class _NetworkBackend(_CoreNetworkBackend):
def __init__(self, resolver, local_port, bootstrap_address, family):
super().__init__()
self._local_port = local_port
self._resolver = resolver
self._bootstrap_address = bootstrap_address
self._family = family
def connect_tcp(
self, host, port, timeout, local_address, socket_options=None
): # pylint: disable=signature-differs
addresses = []
_, expiration = _compute_times(timeout)
if dns.inet.is_address(host):
addresses.append(host)
elif self._bootstrap_address is not None:
addresses.append(self._bootstrap_address)
else:
timeout = _remaining(expiration)
family = self._family
if local_address:
family = dns.inet.af_for_address(local_address)
answers = self._resolver.resolve_name(
host, family=family, lifetime=timeout
)
addresses = answers.addresses()
for address in addresses:
af = dns.inet.af_for_address(address)
if local_address is not None or self._local_port != 0:
source = dns.inet.low_level_address_tuple(
(local_address, self._local_port), af
)
else:
source = None
sock = _make_socket(af, socket.SOCK_STREAM, source)
attempt_expiration = _expiration_for_this_attempt(2.0, expiration)
try:
_connect(
sock,
dns.inet.low_level_address_tuple((address, port), af),
attempt_expiration,
)
return _CoreSyncStream(sock)
except Exception:
pass
raise httpcore.ConnectError
def connect_unix_socket(
self, path, timeout, socket_options=None
): # pylint: disable=signature-differs
raise NotImplementedError
class _HTTPTransport(httpx.HTTPTransport):
def __init__(
self,
*args,
local_port=0,
bootstrap_address=None,
resolver=None,
family=socket.AF_UNSPEC,
**kwargs,
):
if resolver is None:
# pylint: disable=import-outside-toplevel,redefined-outer-name
import dns.resolver
resolver = dns.resolver.Resolver()
super().__init__(*args, **kwargs)
self._pool._network_backend = _NetworkBackend(
resolver, local_port, bootstrap_address, family
)
except ImportError: # pragma: no cover
class _HTTPTransport: # type: ignore
def connect_tcp(self, host, port, timeout, local_address):
raise NotImplementedError
have_doh = _have_httpx
try: try:
import ssl import ssl
@ -88,7 +175,7 @@ except ImportError: # pragma: no cover
@classmethod @classmethod
def create_default_context(cls, *args, **kwargs): def create_default_context(cls, *args, **kwargs):
raise Exception("no ssl support") raise Exception("no ssl support") # pylint: disable=broad-exception-raised
# Function used to create a socket. Can be overridden if needed in special # Function used to create a socket. Can be overridden if needed in special
@ -105,7 +192,7 @@ class BadResponse(dns.exception.FormError):
class NoDOH(dns.exception.DNSException): class NoDOH(dns.exception.DNSException):
"""DNS over HTTPS (DOH) was requested but the requests module is not """DNS over HTTPS (DOH) was requested but the httpx module is not
available.""" available."""
@ -230,7 +317,7 @@ def _destination_and_source(
# We know the destination af, so source had better agree! # We know the destination af, so source had better agree!
if saf != af: if saf != af:
raise ValueError( raise ValueError(
"different address families for source " + "and destination" "different address families for source and destination"
) )
else: else:
# We didn't know the destination af, but we know the source, # We didn't know the destination af, but we know the source,
@ -240,11 +327,10 @@ def _destination_and_source(
# Caller has specified a source_port but not an address, so we # Caller has specified a source_port but not an address, so we
# need to return a source, and we need to use the appropriate # need to return a source, and we need to use the appropriate
# wildcard address as the address. # wildcard address as the address.
if af == socket.AF_INET: try:
source = "0.0.0.0" source = dns.inet.any_for_af(af)
elif af == socket.AF_INET6: except Exception:
source = "::" # we catch this and raise ValueError for backwards compatibility
else:
raise ValueError("source_port specified but address family is unknown") raise ValueError("source_port specified but address family is unknown")
# Convert high-level (address, port) tuples into low-level address # Convert high-level (address, port) tuples into low-level address
# tuples. # tuples.
@ -289,6 +375,8 @@ def https(
post: bool = True, post: bool = True,
bootstrap_address: Optional[str] = None, bootstrap_address: Optional[str] = None,
verify: Union[bool, str] = True, verify: Union[bool, str] = True,
resolver: Optional["dns.resolver.Resolver"] = None,
family: Optional[int] = socket.AF_UNSPEC,
) -> dns.message.Message: ) -> dns.message.Message:
"""Return the response obtained after sending a query via DNS-over-HTTPS. """Return the response obtained after sending a query via DNS-over-HTTPS.
@ -314,91 +402,78 @@ def https(
*ignore_trailing*, a ``bool``. If ``True``, ignore trailing junk at end of the *ignore_trailing*, a ``bool``. If ``True``, ignore trailing junk at end of the
received message. received message.
*session*, an ``httpx.Client`` or ``requests.session.Session``. If provided, the *session*, an ``httpx.Client``. If provided, the client session to use to send the
client/session to use to send the queries. queries.
*path*, a ``str``. If *where* is an IP address, then *path* will be used to *path*, a ``str``. If *where* is an IP address, then *path* will be used to
construct the URL to send the DNS query to. construct the URL to send the DNS query to.
*post*, a ``bool``. If ``True``, the default, POST method will be used. *post*, a ``bool``. If ``True``, the default, POST method will be used.
*bootstrap_address*, a ``str``, the IP address to use to bypass the system's DNS *bootstrap_address*, a ``str``, the IP address to use to bypass resolution.
resolver.
*verify*, a ``bool`` or ``str``. If a ``True``, then TLS certificate verification *verify*, a ``bool`` or ``str``. If a ``True``, then TLS certificate verification
of the server is done using the default CA bundle; if ``False``, then no of the server is done using the default CA bundle; if ``False``, then no
verification is done; if a `str` then it specifies the path to a certificate file or verification is done; if a `str` then it specifies the path to a certificate file or
directory which will be used for verification. directory which will be used for verification.
*resolver*, a ``dns.resolver.Resolver`` or ``None``, the resolver to use for
resolution of hostnames in URLs. If not specified, a new resolver with a default
configuration will be used; note this is *not* the default resolver as that resolver
might have been configured to use DoH causing a chicken-and-egg problem. This
parameter only has an effect if the HTTP library is httpx.
*family*, an ``int``, the address family. If socket.AF_UNSPEC (the default), both A
and AAAA records will be retrieved.
Returns a ``dns.message.Message``. Returns a ``dns.message.Message``.
""" """
if not have_doh: if not have_doh:
raise NoDOH("Neither httpx nor requests is available.") # pragma: no cover raise NoDOH # pragma: no cover
if session and not isinstance(session, httpx.Client):
_httpx_ok = _have_httpx raise ValueError("session parameter must be an httpx.Client")
wire = q.to_wire() wire = q.to_wire()
(af, _, source) = _destination_and_source(where, port, source, source_port, False) (af, _, the_source) = _destination_and_source(
transport_adapter = None where, port, source, source_port, False
)
transport = None transport = None
headers = {"accept": "application/dns-message"} headers = {"accept": "application/dns-message"}
if af is not None: if af is not None and dns.inet.is_address(where):
if af == socket.AF_INET: if af == socket.AF_INET:
url = "https://{}:{}{}".format(where, port, path) url = "https://{}:{}{}".format(where, port, path)
elif af == socket.AF_INET6: elif af == socket.AF_INET6:
url = "https://[{}]:{}{}".format(where, port, path) url = "https://[{}]:{}{}".format(where, port, path)
elif bootstrap_address is not None:
_httpx_ok = False
split_url = urllib.parse.urlsplit(where)
if split_url.hostname is None:
raise ValueError("DoH URL has no hostname")
headers["Host"] = split_url.hostname
url = where.replace(split_url.hostname, bootstrap_address)
if _have_requests:
transport_adapter = HostHeaderSSLAdapter()
else: else:
url = where url = where
if source is not None:
# set source port and source address
if _have_httpx:
if source_port == 0:
transport = httpx.HTTPTransport(local_address=source[0], verify=verify)
else:
_httpx_ok = False
if _have_requests:
transport_adapter = SourceAddressAdapter(source)
if session: # set source port and source address
if _have_httpx:
_is_httpx = isinstance(session, httpx.Client) if the_source is None:
else: local_address = None
_is_httpx = False local_port = 0
if _is_httpx and not _httpx_ok:
raise NoDOH(
"Session is httpx, but httpx cannot be used for "
"the requested operation."
)
else: else:
_is_httpx = _httpx_ok local_address = the_source[0]
local_port = the_source[1]
if not _httpx_ok and not _have_requests: transport = _HTTPTransport(
raise NoDOH( local_address=local_address,
"Cannot use httpx for this operation, and requests is not available." http1=True,
) http2=_have_http2,
verify=verify,
local_port=local_port,
bootstrap_address=bootstrap_address,
resolver=resolver,
family=family,
)
if session: if session:
cm: contextlib.AbstractContextManager = contextlib.nullcontext(session) cm: contextlib.AbstractContextManager = contextlib.nullcontext(session)
elif _is_httpx: else:
cm = httpx.Client( cm = httpx.Client(
http1=True, http2=_have_http2, verify=verify, transport=transport http1=True, http2=_have_http2, verify=verify, transport=transport
) )
else:
cm = requests.sessions.Session()
with cm as session: with cm as session:
if transport_adapter and not _is_httpx:
session.mount(url, transport_adapter)
# see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH # see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH
# GET and POST examples # GET and POST examples
if post: if post:
@ -408,29 +483,13 @@ def https(
"content-length": str(len(wire)), "content-length": str(len(wire)),
} }
) )
if _is_httpx: response = session.post(url, headers=headers, content=wire, timeout=timeout)
response = session.post(
url, headers=headers, content=wire, timeout=timeout
)
else:
response = session.post(
url, headers=headers, data=wire, timeout=timeout, verify=verify
)
else: else:
wire = base64.urlsafe_b64encode(wire).rstrip(b"=") wire = base64.urlsafe_b64encode(wire).rstrip(b"=")
if _is_httpx: twire = wire.decode() # httpx does a repr() if we give it bytes
twire = wire.decode() # httpx does a repr() if we give it bytes response = session.get(
response = session.get( url, headers=headers, timeout=timeout, params={"dns": twire}
url, headers=headers, timeout=timeout, params={"dns": twire} )
)
else:
response = session.get(
url,
headers=headers,
timeout=timeout,
verify=verify,
params={"dns": wire},
)
# see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH # see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH
# status codes # status codes
@ -1070,6 +1129,7 @@ def quic(
ignore_trailing: bool = False, ignore_trailing: bool = False,
connection: Optional[dns.quic.SyncQuicConnection] = None, connection: Optional[dns.quic.SyncQuicConnection] = None,
verify: Union[bool, str] = True, verify: Union[bool, str] = True,
server_hostname: Optional[str] = None,
) -> dns.message.Message: ) -> dns.message.Message:
"""Return the response obtained after sending a query via DNS-over-QUIC. """Return the response obtained after sending a query via DNS-over-QUIC.
@ -1101,6 +1161,10 @@ def quic(
verification is done; if a `str` then it specifies the path to a certificate file or verification is done; if a `str` then it specifies the path to a certificate file or
directory which will be used for verification. directory which will be used for verification.
*server_hostname*, a ``str`` containing the server's hostname. The
default is ``None``, which means that no hostname is known, and if an
SSL context is created, hostname checking will be disabled.
Returns a ``dns.message.Message``. Returns a ``dns.message.Message``.
""" """
@ -1115,16 +1179,18 @@ def quic(
manager: contextlib.AbstractContextManager = contextlib.nullcontext(None) manager: contextlib.AbstractContextManager = contextlib.nullcontext(None)
the_connection = connection the_connection = connection
else: else:
manager = dns.quic.SyncQuicManager(verify_mode=verify) manager = dns.quic.SyncQuicManager(
verify_mode=verify, server_name=server_hostname
)
the_manager = manager # for type checking happiness the_manager = manager # for type checking happiness
with manager: with manager:
if not connection: if not connection:
the_connection = the_manager.connect(where, port, source, source_port) the_connection = the_manager.connect(where, port, source, source_port)
start = time.time() (start, expiration) = _compute_times(timeout)
with the_connection.make_stream() as stream: with the_connection.make_stream(timeout) as stream:
stream.send(wire, True) stream.send(wire, True)
wire = stream.receive(timeout) wire = stream.receive(_remaining(expiration))
finish = time.time() finish = time.time()
r = dns.message.from_wire( r = dns.message.from_wire(
wire, wire,

View file

@ -5,13 +5,13 @@ try:
import dns.asyncbackend import dns.asyncbackend
from dns._asyncbackend import NullContext from dns._asyncbackend import NullContext
from dns.quic._sync import SyncQuicManager, SyncQuicConnection, SyncQuicStream
from dns.quic._asyncio import ( from dns.quic._asyncio import (
AsyncioQuicManager,
AsyncioQuicConnection, AsyncioQuicConnection,
AsyncioQuicManager,
AsyncioQuicStream, AsyncioQuicStream,
) )
from dns.quic._common import AsyncQuicConnection, AsyncQuicManager from dns.quic._common import AsyncQuicConnection, AsyncQuicManager
from dns.quic._sync import SyncQuicConnection, SyncQuicManager, SyncQuicStream
have_quic = True have_quic = True
@ -33,9 +33,10 @@ try:
try: try:
import trio import trio
from dns.quic._trio import ( # pylint: disable=ungrouped-imports from dns.quic._trio import ( # pylint: disable=ungrouped-imports
TrioQuicManager,
TrioQuicConnection, TrioQuicConnection,
TrioQuicManager,
TrioQuicStream, TrioQuicStream,
) )

View file

@ -9,14 +9,16 @@ import time
import aioquic.quic.configuration # type: ignore import aioquic.quic.configuration # type: ignore
import aioquic.quic.connection # type: ignore import aioquic.quic.connection # type: ignore
import aioquic.quic.events # type: ignore import aioquic.quic.events # type: ignore
import dns.inet
import dns.asyncbackend
import dns.asyncbackend
import dns.exception
import dns.inet
from dns.quic._common import ( from dns.quic._common import (
BaseQuicStream, QUIC_MAX_DATAGRAM,
AsyncQuicConnection, AsyncQuicConnection,
AsyncQuicManager, AsyncQuicManager,
QUIC_MAX_DATAGRAM, BaseQuicStream,
UnexpectedEOF,
) )
@ -30,15 +32,15 @@ class AsyncioQuicStream(BaseQuicStream):
await self._wake_up.wait() await self._wake_up.wait()
async def wait_for(self, amount, expiration): async def wait_for(self, amount, expiration):
timeout = self._timeout_from_expiration(expiration)
while True: while True:
timeout = self._timeout_from_expiration(expiration)
if self._buffer.have(amount): if self._buffer.have(amount):
return return
self._expecting = amount self._expecting = amount
try: try:
await asyncio.wait_for(self._wait_for_wake_up(), timeout) await asyncio.wait_for(self._wait_for_wake_up(), timeout)
except Exception: except TimeoutError:
pass raise dns.exception.Timeout
self._expecting = 0 self._expecting = 0
async def receive(self, timeout=None): async def receive(self, timeout=None):
@ -86,8 +88,10 @@ class AsyncioQuicConnection(AsyncQuicConnection):
try: try:
af = dns.inet.af_for_address(self._address) af = dns.inet.af_for_address(self._address)
backend = dns.asyncbackend.get_backend("asyncio") backend = dns.asyncbackend.get_backend("asyncio")
# Note that peer is a low-level address tuple, but make_socket() wants
# a high-level address tuple, so we convert.
self._socket = await backend.make_socket( self._socket = await backend.make_socket(
af, socket.SOCK_DGRAM, 0, self._source, self._peer af, socket.SOCK_DGRAM, 0, self._source, (self._peer[0], self._peer[1])
) )
self._socket_created.set() self._socket_created.set()
async with self._socket: async with self._socket:
@ -106,6 +110,11 @@ class AsyncioQuicConnection(AsyncQuicConnection):
self._wake_timer.notify_all() self._wake_timer.notify_all()
except Exception: except Exception:
pass pass
finally:
self._done = True
async with self._wake_timer:
self._wake_timer.notify_all()
self._handshake_complete.set()
async def _wait_for_wake_timer(self): async def _wait_for_wake_timer(self):
async with self._wake_timer: async with self._wake_timer:
@ -115,7 +124,7 @@ class AsyncioQuicConnection(AsyncQuicConnection):
await self._socket_created.wait() await self._socket_created.wait()
while not self._done: while not self._done:
datagrams = self._connection.datagrams_to_send(time.time()) datagrams = self._connection.datagrams_to_send(time.time())
for (datagram, address) in datagrams: for datagram, address in datagrams:
assert address == self._peer[0] assert address == self._peer[0]
await self._socket.sendto(datagram, self._peer, None) await self._socket.sendto(datagram, self._peer, None)
(expiration, interval) = self._get_timer_values() (expiration, interval) = self._get_timer_values()
@ -160,8 +169,13 @@ class AsyncioQuicConnection(AsyncQuicConnection):
self._receiver_task = asyncio.Task(self._receiver()) self._receiver_task = asyncio.Task(self._receiver())
self._sender_task = asyncio.Task(self._sender()) self._sender_task = asyncio.Task(self._sender())
async def make_stream(self): async def make_stream(self, timeout=None):
await self._handshake_complete.wait() try:
await asyncio.wait_for(self._handshake_complete.wait(), timeout)
except TimeoutError:
raise dns.exception.Timeout
if self._done:
raise UnexpectedEOF
stream_id = self._connection.get_next_available_stream_id(False) stream_id = self._connection.get_next_available_stream_id(False)
stream = AsyncioQuicStream(self, stream_id) stream = AsyncioQuicStream(self, stream_id)
self._streams[stream_id] = stream self._streams[stream_id] = stream
@ -172,6 +186,9 @@ class AsyncioQuicConnection(AsyncQuicConnection):
self._manager.closed(self._peer[0], self._peer[1]) self._manager.closed(self._peer[0], self._peer[1])
self._closed = True self._closed = True
self._connection.close() self._connection.close()
# sender might be blocked on this, so set it
self._socket_created.set()
await self._socket.close()
async with self._wake_timer: async with self._wake_timer:
self._wake_timer.notify_all() self._wake_timer.notify_all()
try: try:
@ -185,8 +202,8 @@ class AsyncioQuicConnection(AsyncQuicConnection):
class AsyncioQuicManager(AsyncQuicManager): class AsyncioQuicManager(AsyncQuicManager):
def __init__(self, conf=None, verify_mode=ssl.CERT_REQUIRED): def __init__(self, conf=None, verify_mode=ssl.CERT_REQUIRED, server_name=None):
super().__init__(conf, verify_mode, AsyncioQuicConnection) super().__init__(conf, verify_mode, AsyncioQuicConnection, server_name)
def connect(self, address, port=853, source=None, source_port=0): def connect(self, address, port=853, source=None, source_port=0):
(connection, start) = self._connect(address, port, source, source_port) (connection, start) = self._connect(address, port, source, source_port)
@ -198,7 +215,7 @@ class AsyncioQuicManager(AsyncQuicManager):
return self return self
async def __aexit__(self, exc_type, exc_val, exc_tb): async def __aexit__(self, exc_type, exc_val, exc_tb):
# Copy the itertor into a list as exiting things will mutate the connections # Copy the iterator into a list as exiting things will mutate the connections
# table. # table.
connections = list(self._connections.values()) connections = list(self._connections.values())
for connection in connections: for connection in connections:

View file

@ -3,13 +3,12 @@
import socket import socket
import struct import struct
import time import time
from typing import Any, Optional
from typing import Any
import aioquic.quic.configuration # type: ignore import aioquic.quic.configuration # type: ignore
import aioquic.quic.connection # type: ignore import aioquic.quic.connection # type: ignore
import dns.inet
import dns.inet
QUIC_MAX_DATAGRAM = 2048 QUIC_MAX_DATAGRAM = 2048
@ -135,12 +134,12 @@ class BaseQuicConnection:
class AsyncQuicConnection(BaseQuicConnection): class AsyncQuicConnection(BaseQuicConnection):
async def make_stream(self) -> Any: async def make_stream(self, timeout: Optional[float] = None) -> Any:
pass pass
class BaseQuicManager: class BaseQuicManager:
def __init__(self, conf, verify_mode, connection_factory): def __init__(self, conf, verify_mode, connection_factory, server_name=None):
self._connections = {} self._connections = {}
self._connection_factory = connection_factory self._connection_factory = connection_factory
if conf is None: if conf is None:
@ -151,6 +150,7 @@ class BaseQuicManager:
conf = aioquic.quic.configuration.QuicConfiguration( conf = aioquic.quic.configuration.QuicConfiguration(
alpn_protocols=["doq", "doq-i03"], alpn_protocols=["doq", "doq-i03"],
verify_mode=verify_mode, verify_mode=verify_mode,
server_name=server_name,
) )
if verify_path is not None: if verify_path is not None:
conf.load_verify_locations(verify_path) conf.load_verify_locations(verify_path)

View file

@ -1,8 +1,8 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
import selectors
import socket import socket
import ssl import ssl
import selectors
import struct import struct
import threading import threading
import time import time
@ -10,13 +10,15 @@ import time
import aioquic.quic.configuration # type: ignore import aioquic.quic.configuration # type: ignore
import aioquic.quic.connection # type: ignore import aioquic.quic.connection # type: ignore
import aioquic.quic.events # type: ignore import aioquic.quic.events # type: ignore
import dns.inet
import dns.exception
import dns.inet
from dns.quic._common import ( from dns.quic._common import (
BaseQuicStream, QUIC_MAX_DATAGRAM,
BaseQuicConnection, BaseQuicConnection,
BaseQuicManager, BaseQuicManager,
QUIC_MAX_DATAGRAM, BaseQuicStream,
UnexpectedEOF,
) )
# Avoid circularity with dns.query # Avoid circularity with dns.query
@ -33,14 +35,15 @@ class SyncQuicStream(BaseQuicStream):
self._lock = threading.Lock() self._lock = threading.Lock()
def wait_for(self, amount, expiration): def wait_for(self, amount, expiration):
timeout = self._timeout_from_expiration(expiration)
while True: while True:
timeout = self._timeout_from_expiration(expiration)
with self._lock: with self._lock:
if self._buffer.have(amount): if self._buffer.have(amount):
return return
self._expecting = amount self._expecting = amount
with self._wake_up: with self._wake_up:
self._wake_up.wait(timeout) if not self._wake_up.wait(timeout):
raise dns.exception.Timeout
self._expecting = 0 self._expecting = 0
def receive(self, timeout=None): def receive(self, timeout=None):
@ -114,24 +117,30 @@ class SyncQuicConnection(BaseQuicConnection):
return return
def _worker(self): def _worker(self):
sel = _selector_class() try:
sel.register(self._socket, selectors.EVENT_READ, self._read) sel = _selector_class()
sel.register(self._receive_wakeup, selectors.EVENT_READ, self._drain_wakeup) sel.register(self._socket, selectors.EVENT_READ, self._read)
while not self._done: sel.register(self._receive_wakeup, selectors.EVENT_READ, self._drain_wakeup)
(expiration, interval) = self._get_timer_values(False) while not self._done:
items = sel.select(interval) (expiration, interval) = self._get_timer_values(False)
for (key, _) in items: items = sel.select(interval)
key.data() for key, _ in items:
key.data()
with self._lock:
self._handle_timer(expiration)
datagrams = self._connection.datagrams_to_send(time.time())
for datagram, _ in datagrams:
try:
self._socket.send(datagram)
except BlockingIOError:
# we let QUIC handle any lossage
pass
self._handle_events()
finally:
with self._lock: with self._lock:
self._handle_timer(expiration) self._done = True
datagrams = self._connection.datagrams_to_send(time.time()) # Ensure anyone waiting for this gets woken up.
for (datagram, _) in datagrams: self._handshake_complete.set()
try:
self._socket.send(datagram)
except BlockingIOError:
# we let QUIC handle any lossage
pass
self._handle_events()
def _handle_events(self): def _handle_events(self):
while True: while True:
@ -163,9 +172,12 @@ class SyncQuicConnection(BaseQuicConnection):
self._worker_thread = threading.Thread(target=self._worker) self._worker_thread = threading.Thread(target=self._worker)
self._worker_thread.start() self._worker_thread.start()
def make_stream(self): def make_stream(self, timeout=None):
self._handshake_complete.wait() if not self._handshake_complete.wait(timeout):
raise dns.exception.Timeout
with self._lock: with self._lock:
if self._done:
raise UnexpectedEOF
stream_id = self._connection.get_next_available_stream_id(False) stream_id = self._connection.get_next_available_stream_id(False)
stream = SyncQuicStream(self, stream_id) stream = SyncQuicStream(self, stream_id)
self._streams[stream_id] = stream self._streams[stream_id] = stream
@ -187,8 +199,8 @@ class SyncQuicConnection(BaseQuicConnection):
class SyncQuicManager(BaseQuicManager): class SyncQuicManager(BaseQuicManager):
def __init__(self, conf=None, verify_mode=ssl.CERT_REQUIRED): def __init__(self, conf=None, verify_mode=ssl.CERT_REQUIRED, server_name=None):
super().__init__(conf, verify_mode, SyncQuicConnection) super().__init__(conf, verify_mode, SyncQuicConnection, server_name)
self._lock = threading.Lock() self._lock = threading.Lock()
def connect(self, address, port=853, source=None, source_port=0): def connect(self, address, port=853, source=None, source_port=0):
@ -206,7 +218,7 @@ class SyncQuicManager(BaseQuicManager):
return self return self
def __exit__(self, exc_type, exc_val, exc_tb): def __exit__(self, exc_type, exc_val, exc_tb):
# Copy the itertor into a list as exiting things will mutate the connections # Copy the iterator into a list as exiting things will mutate the connections
# table. # table.
connections = list(self._connections.values()) connections = list(self._connections.values())
for connection in connections: for connection in connections:

View file

@ -10,13 +10,15 @@ import aioquic.quic.connection # type: ignore
import aioquic.quic.events # type: ignore import aioquic.quic.events # type: ignore
import trio import trio
import dns.exception
import dns.inet import dns.inet
from dns._asyncbackend import NullContext from dns._asyncbackend import NullContext
from dns.quic._common import ( from dns.quic._common import (
BaseQuicStream, QUIC_MAX_DATAGRAM,
AsyncQuicConnection, AsyncQuicConnection,
AsyncQuicManager, AsyncQuicManager,
QUIC_MAX_DATAGRAM, BaseQuicStream,
UnexpectedEOF,
) )
@ -44,6 +46,7 @@ class TrioQuicStream(BaseQuicStream):
(size,) = struct.unpack("!H", self._buffer.get(2)) (size,) = struct.unpack("!H", self._buffer.get(2))
await self.wait_for(size) await self.wait_for(size)
return self._buffer.get(size) return self._buffer.get(size)
raise dns.exception.Timeout
async def send(self, datagram, is_end=False): async def send(self, datagram, is_end=False):
data = self._encapsulate(datagram) data = self._encapsulate(datagram)
@ -80,20 +83,26 @@ class TrioQuicConnection(AsyncQuicConnection):
self._worker_scope = None self._worker_scope = None
async def _worker(self): async def _worker(self):
await self._socket.connect(self._peer) try:
while not self._done: await self._socket.connect(self._peer)
(expiration, interval) = self._get_timer_values(False) while not self._done:
with trio.CancelScope( (expiration, interval) = self._get_timer_values(False)
deadline=trio.current_time() + interval with trio.CancelScope(
) as self._worker_scope: deadline=trio.current_time() + interval
datagram = await self._socket.recv(QUIC_MAX_DATAGRAM) ) as self._worker_scope:
self._connection.receive_datagram(datagram, self._peer[0], time.time()) datagram = await self._socket.recv(QUIC_MAX_DATAGRAM)
self._worker_scope = None self._connection.receive_datagram(
self._handle_timer(expiration) datagram, self._peer[0], time.time()
datagrams = self._connection.datagrams_to_send(time.time()) )
for (datagram, _) in datagrams: self._worker_scope = None
await self._socket.send(datagram) self._handle_timer(expiration)
await self._handle_events() datagrams = self._connection.datagrams_to_send(time.time())
for datagram, _ in datagrams:
await self._socket.send(datagram)
await self._handle_events()
finally:
self._done = True
self._handshake_complete.set()
async def _handle_events(self): async def _handle_events(self):
count = 0 count = 0
@ -130,12 +139,20 @@ class TrioQuicConnection(AsyncQuicConnection):
nursery.start_soon(self._worker) nursery.start_soon(self._worker)
self._run_done.set() self._run_done.set()
async def make_stream(self): async def make_stream(self, timeout=None):
await self._handshake_complete.wait() if timeout is None:
stream_id = self._connection.get_next_available_stream_id(False) context = NullContext(None)
stream = TrioQuicStream(self, stream_id) else:
self._streams[stream_id] = stream context = trio.move_on_after(timeout)
return stream with context:
await self._handshake_complete.wait()
if self._done:
raise UnexpectedEOF
stream_id = self._connection.get_next_available_stream_id(False)
stream = TrioQuicStream(self, stream_id)
self._streams[stream_id] = stream
return stream
raise dns.exception.Timeout
async def close(self): async def close(self):
if not self._closed: if not self._closed:
@ -148,8 +165,10 @@ class TrioQuicConnection(AsyncQuicConnection):
class TrioQuicManager(AsyncQuicManager): class TrioQuicManager(AsyncQuicManager):
def __init__(self, nursery, conf=None, verify_mode=ssl.CERT_REQUIRED): def __init__(
super().__init__(conf, verify_mode, TrioQuicConnection) self, nursery, conf=None, verify_mode=ssl.CERT_REQUIRED, server_name=None
):
super().__init__(conf, verify_mode, TrioQuicConnection, server_name)
self._nursery = nursery self._nursery = nursery
def connect(self, address, port=853, source=None, source_port=0): def connect(self, address, port=853, source=None, source_port=0):
@ -162,7 +181,7 @@ class TrioQuicManager(AsyncQuicManager):
return self return self
async def __aexit__(self, exc_type, exc_val, exc_tb): async def __aexit__(self, exc_type, exc_val, exc_tb):
# Copy the itertor into a list as exiting things will mutate the connections # Copy the iterator into a list as exiting things will mutate the connections
# table. # table.
connections = list(self._connections.values()) connections = list(self._connections.values())
for connection in connections: for connection in connections:

View file

@ -17,17 +17,15 @@
"""DNS rdata.""" """DNS rdata."""
from typing import Any, Dict, Optional, Tuple, Union
from importlib import import_module
import base64 import base64
import binascii import binascii
import io
import inspect import inspect
import io
import itertools import itertools
import random import random
from importlib import import_module
from typing import Any, Dict, Optional, Tuple, Union
import dns.wire
import dns.exception import dns.exception
import dns.immutable import dns.immutable
import dns.ipv4 import dns.ipv4
@ -37,6 +35,7 @@ import dns.rdataclass
import dns.rdatatype import dns.rdatatype
import dns.tokenizer import dns.tokenizer
import dns.ttl import dns.ttl
import dns.wire
_chunksize = 32 _chunksize = 32
@ -358,7 +357,6 @@ class Rdata:
or self.rdclass != other.rdclass or self.rdclass != other.rdclass
or self.rdtype != other.rdtype or self.rdtype != other.rdtype
): ):
return NotImplemented return NotImplemented
return self._cmp(other) < 0 return self._cmp(other) < 0
@ -881,16 +879,11 @@ def register_type(
it applies to all classes. it applies to all classes.
""" """
the_rdtype = dns.rdatatype.RdataType.make(rdtype) rdtype = dns.rdatatype.RdataType.make(rdtype)
existing_cls = get_rdata_class(rdclass, the_rdtype) existing_cls = get_rdata_class(rdclass, rdtype)
if existing_cls != GenericRdata or dns.rdatatype.is_metatype(the_rdtype): if existing_cls != GenericRdata or dns.rdatatype.is_metatype(rdtype):
raise RdatatypeExists(rdclass=rdclass, rdtype=the_rdtype) raise RdatatypeExists(rdclass=rdclass, rdtype=rdtype)
try: _rdata_classes[(rdclass, rdtype)] = getattr(
if dns.rdatatype.RdataType(the_rdtype).name != rdtype_text:
raise RdatatypeExists(rdclass=rdclass, rdtype=the_rdtype)
except ValueError:
pass
_rdata_classes[(rdclass, the_rdtype)] = getattr(
implementation, rdtype_text.replace("-", "_") implementation, rdtype_text.replace("-", "_")
) )
dns.rdatatype.register_type(the_rdtype, rdtype_text, is_singleton) dns.rdatatype.register_type(rdtype, rdtype_text, is_singleton)

View file

@ -17,18 +17,17 @@
"""DNS rdatasets (an rdataset is a set of rdatas of a given type and class)""" """DNS rdatasets (an rdataset is a set of rdatas of a given type and class)"""
from typing import Any, cast, Collection, Dict, List, Optional, Union
import io import io
import random import random
import struct import struct
from typing import Any, Collection, Dict, List, Optional, Union, cast
import dns.exception import dns.exception
import dns.immutable import dns.immutable
import dns.name import dns.name
import dns.rdatatype
import dns.rdataclass
import dns.rdata import dns.rdata
import dns.rdataclass
import dns.rdatatype
import dns.set import dns.set
import dns.ttl import dns.ttl
@ -471,9 +470,9 @@ def from_text_list(
Returns a ``dns.rdataset.Rdataset`` object. Returns a ``dns.rdataset.Rdataset`` object.
""" """
the_rdclass = dns.rdataclass.RdataClass.make(rdclass) rdclass = dns.rdataclass.RdataClass.make(rdclass)
the_rdtype = dns.rdatatype.RdataType.make(rdtype) rdtype = dns.rdatatype.RdataType.make(rdtype)
r = Rdataset(the_rdclass, the_rdtype) r = Rdataset(rdclass, rdtype)
r.update_ttl(ttl) r.update_ttl(ttl)
for t in text_rdatas: for t in text_rdatas:
rd = dns.rdata.from_text( rd = dns.rdata.from_text(

View file

@ -15,8 +15,8 @@
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import dns.rdtypes.mxbase
import dns.immutable import dns.immutable
import dns.rdtypes.mxbase
@dns.immutable.immutable @dns.immutable.immutable

View file

@ -15,8 +15,8 @@
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import dns.rdtypes.txtbase
import dns.immutable import dns.immutable
import dns.rdtypes.txtbase
@dns.immutable.immutable @dns.immutable.immutable

View file

@ -15,15 +15,15 @@
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import dns.rdtypes.dnskeybase # lgtm[py/import-and-import-from]
import dns.immutable import dns.immutable
import dns.rdtypes.dnskeybase # lgtm[py/import-and-import-from]
# pylint: disable=unused-import # pylint: disable=unused-import
from dns.rdtypes.dnskeybase import ( from dns.rdtypes.dnskeybase import ( # noqa: F401 lgtm[py/unused-import]
SEP,
REVOKE, REVOKE,
SEP,
ZONE, ZONE,
) # noqa: F401 lgtm[py/unused-import] )
# pylint: enable=unused-import # pylint: enable=unused-import

View file

@ -15,8 +15,8 @@
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import dns.rdtypes.dsbase
import dns.immutable import dns.immutable
import dns.rdtypes.dsbase
@dns.immutable.immutable @dns.immutable.immutable

View file

@ -15,12 +15,12 @@
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import struct
import base64 import base64
import struct
import dns.dnssectypes
import dns.exception import dns.exception
import dns.immutable import dns.immutable
import dns.dnssectypes
import dns.rdata import dns.rdata
import dns.tokenizer import dns.tokenizer

View file

@ -15,8 +15,8 @@
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import dns.rdtypes.nsbase
import dns.immutable import dns.immutable
import dns.rdtypes.nsbase
@dns.immutable.immutable @dns.immutable.immutable

View file

@ -19,9 +19,9 @@ import struct
import dns.exception import dns.exception
import dns.immutable import dns.immutable
import dns.name
import dns.rdata import dns.rdata
import dns.rdatatype import dns.rdatatype
import dns.name
import dns.rdtypes.util import dns.rdtypes.util

View file

@ -15,8 +15,8 @@
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import dns.rdtypes.dsbase
import dns.immutable import dns.immutable
import dns.rdtypes.dsbase
@dns.immutable.immutable @dns.immutable.immutable

View file

@ -15,8 +15,8 @@
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import dns.rdtypes.nsbase
import dns.immutable import dns.immutable
import dns.rdtypes.nsbase
@dns.immutable.immutable @dns.immutable.immutable

View file

@ -15,15 +15,15 @@
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import dns.rdtypes.dnskeybase # lgtm[py/import-and-import-from]
import dns.immutable import dns.immutable
import dns.rdtypes.dnskeybase # lgtm[py/import-and-import-from]
# pylint: disable=unused-import # pylint: disable=unused-import
from dns.rdtypes.dnskeybase import ( from dns.rdtypes.dnskeybase import ( # noqa: F401 lgtm[py/unused-import]
SEP,
REVOKE, REVOKE,
SEP,
ZONE, ZONE,
) # noqa: F401 lgtm[py/unused-import] )
# pylint: enable=unused-import # pylint: enable=unused-import

View file

@ -15,8 +15,8 @@
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import dns.rdtypes.dsbase
import dns.immutable import dns.immutable
import dns.rdtypes.dsbase
@dns.immutable.immutable @dns.immutable.immutable

View file

@ -16,8 +16,8 @@
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import dns.rdtypes.euibase
import dns.immutable import dns.immutable
import dns.rdtypes.euibase
@dns.immutable.immutable @dns.immutable.immutable

View file

@ -16,8 +16,8 @@
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import dns.rdtypes.euibase
import dns.immutable import dns.immutable
import dns.rdtypes.euibase
@dns.immutable.immutable @dns.immutable.immutable

View file

@ -15,9 +15,9 @@
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import struct
import base64 import base64
import binascii import binascii
import struct
import dns.exception import dns.exception
import dns.immutable import dns.immutable

View file

@ -21,7 +21,6 @@ import dns.exception
import dns.immutable import dns.immutable
import dns.rdata import dns.rdata
_pows = tuple(10**i for i in range(0, 11)) _pows = tuple(10**i for i in range(0, 11))
# default values are in centimeters # default values are in centimeters
@ -40,7 +39,7 @@ def _exponent_of(what, desc):
if what == 0: if what == 0:
return 0 return 0
exp = None exp = None
for (i, pow) in enumerate(_pows): for i, pow in enumerate(_pows):
if what < pow: if what < pow:
exp = i - 1 exp = i - 1
break break

View file

@ -15,8 +15,8 @@
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import dns.rdtypes.mxbase
import dns.immutable import dns.immutable
import dns.rdtypes.mxbase
@dns.immutable.immutable @dns.immutable.immutable

View file

@ -15,8 +15,8 @@
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import dns.rdtypes.txtbase
import dns.immutable import dns.immutable
import dns.rdtypes.txtbase
@dns.immutable.immutable @dns.immutable.immutable

View file

@ -15,8 +15,8 @@
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import dns.rdtypes.nsbase
import dns.immutable import dns.immutable
import dns.rdtypes.nsbase
@dns.immutable.immutable @dns.immutable.immutable

View file

@ -17,9 +17,9 @@
import dns.exception import dns.exception
import dns.immutable import dns.immutable
import dns.name
import dns.rdata import dns.rdata
import dns.rdatatype import dns.rdatatype
import dns.name
import dns.rdtypes.util import dns.rdtypes.util

View file

@ -25,7 +25,6 @@ import dns.rdata
import dns.rdatatype import dns.rdatatype
import dns.rdtypes.util import dns.rdtypes.util
b32_hex_to_normal = bytes.maketrans( b32_hex_to_normal = bytes.maketrans(
b"0123456789ABCDEFGHIJKLMNOPQRSTUV", b"ABCDEFGHIJKLMNOPQRSTUVWXYZ234567" b"0123456789ABCDEFGHIJKLMNOPQRSTUV", b"ABCDEFGHIJKLMNOPQRSTUVWXYZ234567"
) )
@ -67,6 +66,7 @@ class NSEC3(dns.rdata.Rdata):
def to_text(self, origin=None, relativize=True, **kw): def to_text(self, origin=None, relativize=True, **kw):
next = base64.b32encode(self.next).translate(b32_normal_to_hex).lower().decode() next = base64.b32encode(self.next).translate(b32_normal_to_hex).lower().decode()
next = next.rstrip("=")
if self.salt == b"": if self.salt == b"":
salt = "-" salt = "-"
else: else:
@ -94,6 +94,10 @@ class NSEC3(dns.rdata.Rdata):
else: else:
salt = binascii.unhexlify(salt.encode("ascii")) salt = binascii.unhexlify(salt.encode("ascii"))
next = tok.get_string().encode("ascii").upper().translate(b32_hex_to_normal) next = tok.get_string().encode("ascii").upper().translate(b32_hex_to_normal)
if next.endswith(b"="):
raise binascii.Error("Incorrect padding")
if len(next) % 8 != 0:
next += b"=" * (8 - len(next) % 8)
next = base64.b32decode(next) next = base64.b32decode(next)
bitmap = Bitmap.from_text(tok) bitmap = Bitmap.from_text(tok)
return cls(rdclass, rdtype, algorithm, flags, iterations, salt, next, bitmap) return cls(rdclass, rdtype, algorithm, flags, iterations, salt, next, bitmap)

View file

@ -15,8 +15,8 @@
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import struct
import binascii import binascii
import struct
import dns.exception import dns.exception
import dns.immutable import dns.immutable

View file

@ -18,11 +18,10 @@
import struct import struct
import dns.edns import dns.edns
import dns.immutable
import dns.exception import dns.exception
import dns.immutable
import dns.rdata import dns.rdata
# We don't implement from_text, and that's ok. # We don't implement from_text, and that's ok.
# pylint: disable=abstract-method # pylint: disable=abstract-method

View file

@ -15,8 +15,8 @@
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import dns.rdtypes.nsbase
import dns.immutable import dns.immutable
import dns.rdtypes.nsbase
@dns.immutable.immutable @dns.immutable.immutable

View file

@ -17,8 +17,8 @@
import dns.exception import dns.exception
import dns.immutable import dns.immutable
import dns.rdata
import dns.name import dns.name
import dns.rdata
@dns.immutable.immutable @dns.immutable.immutable

View file

@ -21,8 +21,8 @@ import struct
import time import time
import dns.dnssectypes import dns.dnssectypes
import dns.immutable
import dns.exception import dns.exception
import dns.immutable
import dns.rdata import dns.rdata
import dns.rdatatype import dns.rdatatype

View file

@ -15,8 +15,8 @@
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import dns.rdtypes.mxbase
import dns.immutable import dns.immutable
import dns.rdtypes.mxbase
@dns.immutable.immutable @dns.immutable.immutable

View file

@ -19,8 +19,8 @@ import struct
import dns.exception import dns.exception
import dns.immutable import dns.immutable
import dns.rdata
import dns.name import dns.name
import dns.rdata
@dns.immutable.immutable @dns.immutable.immutable

View file

@ -15,8 +15,8 @@
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import dns.rdtypes.txtbase
import dns.immutable import dns.immutable
import dns.rdtypes.txtbase
@dns.immutable.immutable @dns.immutable.immutable

View file

@ -15,11 +15,11 @@
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import struct
import binascii import binascii
import struct
import dns.rdata
import dns.immutable import dns.immutable
import dns.rdata
import dns.rdatatype import dns.rdatatype

View file

@ -18,8 +18,8 @@
import base64 import base64
import struct import struct
import dns.immutable
import dns.exception import dns.exception
import dns.immutable
import dns.rdata import dns.rdata

View file

@ -15,8 +15,8 @@
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import dns.rdtypes.txtbase
import dns.immutable import dns.immutable
import dns.rdtypes.txtbase
@dns.immutable.immutable @dns.immutable.immutable

View file

@ -20,9 +20,9 @@ import struct
import dns.exception import dns.exception
import dns.immutable import dns.immutable
import dns.name
import dns.rdata import dns.rdata
import dns.rdtypes.util import dns.rdtypes.util
import dns.name
@dns.immutable.immutable @dns.immutable.immutable

View file

@ -1,7 +1,7 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
import struct
import binascii import binascii
import struct
import dns.immutable import dns.immutable
import dns.rdata import dns.rdata

View file

@ -17,8 +17,8 @@
import struct import struct
import dns.rdtypes.mxbase
import dns.immutable import dns.immutable
import dns.rdtypes.mxbase
@dns.immutable.immutable @dns.immutable.immutable

View file

@ -124,7 +124,6 @@ class APL(dns.rdata.Rdata):
@classmethod @classmethod
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
items = [] items = []
while parser.remaining() > 0: while parser.remaining() > 0:
header = parser.get_struct("!HBB") header = parser.get_struct("!HBB")

View file

@ -1,7 +1,7 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
import dns.rdtypes.svcbbase
import dns.immutable import dns.immutable
import dns.rdtypes.svcbbase
@dns.immutable.immutable @dns.immutable.immutable

View file

@ -15,8 +15,8 @@
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import struct
import base64 import base64
import struct
import dns.exception import dns.exception
import dns.immutable import dns.immutable

Some files were not shown because too many files have changed in this diff Show more