Mirror of https://github.com/Tautulli/Tautulli.git, synced 2025-07-07 05:31:15 -07:00
Bump dnspython from 2.3.0 to 2.4.2 (#2123)
* Bump dnspython from 2.3.0 to 2.4.2

  Bumps [dnspython](https://github.com/rthalley/dnspython) from 2.3.0 to 2.4.2.
  - [Release notes](https://github.com/rthalley/dnspython/releases)
  - [Changelog](https://github.com/rthalley/dnspython/blob/master/doc/whatsnew.rst)
  - [Commits](https://github.com/rthalley/dnspython/compare/v2.3.0...v2.4.2)

  ---
  updated-dependencies:
  - dependency-name: dnspython
    dependency-type: direct:production
    update-type: version-update:semver-minor
  ...

  Signed-off-by: dependabot[bot] <support@github.com>

* Update dnspython==2.4.2

---------

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: JonnyWong16 <9099342+JonnyWong16@users.noreply.github.com>

[skip ci]
This commit is contained in:
parent 9f00f5dafa
commit c0aa4e4996
108 changed files with 2985 additions and 1136 deletions
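Everything below is the refreshed vendored copy of dnspython; Tautulli's own behaviour only changes through the bumped dependency. A quick, hedged way to confirm which bundled version is actually being imported after the update (dns.version is part of dnspython's public API):

import dns.version

# Should print "2.4.2" once this commit is deployed.
print(dns.version.version)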
@@ -22,6 +22,7 @@ __all__ = [
    "asyncquery",
    "asyncresolver",
    "dnssec",
    "dnssecalgs",
    "dnssectypes",
    "e164",
    "edns",
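The new "dnssecalgs" entry in __all__ corresponds to the per-algorithm DNSSEC crypto package introduced in dnspython 2.4, which the dns.dnssec changes later in this diff delegate to. A minimal sketch, assuming the get_algorithm_cls() helper that package exposes:

import dns.dnssecalgs
import dns.dnssectypes

# Look up the implementation class registered for an algorithm (assumed API).
cls = dns.dnssecalgs.get_algorithm_cls(dns.dnssectypes.Algorithm.ED25519)
print(cls.__name__)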
@@ -35,6 +35,9 @@ class Socket:  # pragma: no cover
    async def getsockname(self):
        raise NotImplementedError

    async def getpeercert(self, timeout):
        raise NotImplementedError

    async def __aenter__(self):
        return self

@@ -61,6 +64,11 @@ class StreamSocket(Socket):  # pragma: no cover
        raise NotImplementedError


class NullTransport:
    async def connect_tcp(self, host, port, timeout, local_address):
        raise NotImplementedError


class Backend:  # pragma: no cover
    def name(self):
        return "unknown"

@@ -83,3 +91,9 @@ class Backend:  # pragma: no cover

    async def sleep(self, interval):
        raise NotImplementedError

    def get_transport_class(self):
        raise NotImplementedError

    async def wait_for(self, awaitable, timeout):
        raise NotImplementedError
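The three hunks above (dns/_asyncbackend.py) extend the abstract Backend contract that each concrete async backend fills in: getpeercert() on sockets, a NullTransport placeholder for when HTTPX support is unavailable, and the new get_transport_class()/wait_for() hooks. A minimal sketch of a conforming subclass wired to asyncio, illustrative only and not part of the diff:

import asyncio

import dns._asyncbackend


class SketchBackend(dns._asyncbackend.Backend):
    def name(self):
        return "sketch"

    async def sleep(self, interval):
        await asyncio.sleep(interval)

    async def wait_for(self, awaitable, timeout):
        # None means "no deadline"; anything else is enforced.
        if timeout is None:
            return await awaitable
        return await asyncio.wait_for(awaitable, timeout)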
@@ -2,14 +2,13 @@

"""asyncio library query support"""

import socket
import asyncio
import socket
import sys

import dns._asyncbackend
import dns.exception


_is_win32 = sys.platform == "win32"


@@ -38,6 +37,13 @@ class _DatagramProtocol:

    def connection_lost(self, exc):
        if self.recvfrom and not self.recvfrom.done():
            if exc is None:
                # EOF we triggered.  Is there a better way to do this?
                try:
                    raise EOFError
                except EOFError as e:
                    self.recvfrom.set_exception(e)
            else:
                self.recvfrom.set_exception(exc)

    def close(self):

@@ -45,7 +51,7 @@ class _DatagramProtocol:


async def _maybe_wait_for(awaitable, timeout):
    if timeout:
    if timeout is not None:
        try:
            return await asyncio.wait_for(awaitable, timeout)
        except asyncio.TimeoutError:

@@ -85,6 +91,9 @@ class DatagramSocket(dns._asyncbackend.DatagramSocket):
    async def getsockname(self):
        return self.transport.get_extra_info("sockname")

    async def getpeercert(self, timeout):
        raise NotImplementedError


class StreamSocket(dns._asyncbackend.StreamSocket):
    def __init__(self, af, reader, writer):

@@ -101,10 +110,6 @@ class StreamSocket(dns._asyncbackend.StreamSocket):

    async def close(self):
        self.writer.close()
        try:
            await self.writer.wait_closed()
        except AttributeError:  # pragma: no cover
            pass

    async def getpeername(self):
        return self.writer.get_extra_info("peername")

@@ -112,6 +117,97 @@ class StreamSocket(dns._asyncbackend.StreamSocket):
    async def getsockname(self):
        return self.writer.get_extra_info("sockname")

    async def getpeercert(self, timeout):
        return self.writer.get_extra_info("peercert")


try:
    import anyio
    import httpcore
    import httpcore._backends.anyio
    import httpx

    _CoreAsyncNetworkBackend = httpcore.AsyncNetworkBackend
    _CoreAnyIOStream = httpcore._backends.anyio.AnyIOStream

    from dns.query import _compute_times, _expiration_for_this_attempt, _remaining

    class _NetworkBackend(_CoreAsyncNetworkBackend):
        def __init__(self, resolver, local_port, bootstrap_address, family):
            super().__init__()
            self._local_port = local_port
            self._resolver = resolver
            self._bootstrap_address = bootstrap_address
            self._family = family
            if local_port != 0:
                raise NotImplementedError(
                    "the asyncio transport for HTTPX cannot set the local port"
                )

        async def connect_tcp(
            self, host, port, timeout, local_address, socket_options=None
        ):  # pylint: disable=signature-differs
            addresses = []
            _, expiration = _compute_times(timeout)
            if dns.inet.is_address(host):
                addresses.append(host)
            elif self._bootstrap_address is not None:
                addresses.append(self._bootstrap_address)
            else:
                timeout = _remaining(expiration)
                family = self._family
                if local_address:
                    family = dns.inet.af_for_address(local_address)
                answers = await self._resolver.resolve_name(
                    host, family=family, lifetime=timeout
                )
                addresses = answers.addresses()
            for address in addresses:
                try:
                    attempt_expiration = _expiration_for_this_attempt(2.0, expiration)
                    timeout = _remaining(attempt_expiration)
                    with anyio.fail_after(timeout):
                        stream = await anyio.connect_tcp(
                            remote_host=address,
                            remote_port=port,
                            local_host=local_address,
                        )
                    return _CoreAnyIOStream(stream)
                except Exception:
                    pass
            raise httpcore.ConnectError

        async def connect_unix_socket(
            self, path, timeout, socket_options=None
        ):  # pylint: disable=signature-differs
            raise NotImplementedError

        async def sleep(self, seconds):  # pylint: disable=signature-differs
            await anyio.sleep(seconds)

    class _HTTPTransport(httpx.AsyncHTTPTransport):
        def __init__(
            self,
            *args,
            local_port=0,
            bootstrap_address=None,
            resolver=None,
            family=socket.AF_UNSPEC,
            **kwargs,
        ):
            if resolver is None:
                # pylint: disable=import-outside-toplevel,redefined-outer-name
                import dns.asyncresolver

                resolver = dns.asyncresolver.Resolver()
            super().__init__(*args, **kwargs)
            self._pool._network_backend = _NetworkBackend(
                resolver, local_port, bootstrap_address, family
            )

except ImportError:
    _HTTPTransport = dns._asyncbackend.NullTransport  # type: ignore


class Backend(dns._asyncbackend.Backend):
    def name(self):

@@ -171,3 +267,9 @@ class Backend(dns._asyncbackend.Backend):

    def datagram_connection_required(self):
        return _is_win32

    def get_transport_class(self):
        return _HTTPTransport

    async def wait_for(self, awaitable, timeout):
        return await _maybe_wait_for(awaitable, timeout)
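One behavioural detail worth calling out in dns/_asyncio_backend.py above: the timeout guard in _maybe_wait_for() changes from "if timeout:" to "if timeout is not None:", so a zero timeout is now an immediately-expiring deadline instead of being silently ignored. A small sketch of the difference (illustrative, not part of the diff):

import asyncio


async def old_guard(awaitable, timeout):
    if timeout:  # 0 and None are both falsy, so timeout=0 meant "wait forever"
        return await asyncio.wait_for(awaitable, timeout)
    return await awaitable


async def new_guard(awaitable, timeout):
    if timeout is not None:  # only None disables the deadline
        return await asyncio.wait_for(awaitable, timeout)
    return await awaitable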
@@ -1,122 +0,0 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license

"""curio async I/O library query support"""

import socket
import curio
import curio.socket  # type: ignore

import dns._asyncbackend
import dns.exception
import dns.inet


def _maybe_timeout(timeout):
    if timeout:
        return curio.ignore_after(timeout)
    else:
        return dns._asyncbackend.NullContext()


# for brevity
_lltuple = dns.inet.low_level_address_tuple

# pylint: disable=redefined-outer-name


class DatagramSocket(dns._asyncbackend.DatagramSocket):
    def __init__(self, socket):
        super().__init__(socket.family)
        self.socket = socket

    async def sendto(self, what, destination, timeout):
        async with _maybe_timeout(timeout):
            return await self.socket.sendto(what, destination)
        raise dns.exception.Timeout(
            timeout=timeout
        )  # pragma: no cover  lgtm[py/unreachable-statement]

    async def recvfrom(self, size, timeout):
        async with _maybe_timeout(timeout):
            return await self.socket.recvfrom(size)
        raise dns.exception.Timeout(timeout=timeout)  # lgtm[py/unreachable-statement]

    async def close(self):
        await self.socket.close()

    async def getpeername(self):
        return self.socket.getpeername()

    async def getsockname(self):
        return self.socket.getsockname()


class StreamSocket(dns._asyncbackend.StreamSocket):
    def __init__(self, socket):
        self.socket = socket
        self.family = socket.family

    async def sendall(self, what, timeout):
        async with _maybe_timeout(timeout):
            return await self.socket.sendall(what)
        raise dns.exception.Timeout(timeout=timeout)  # lgtm[py/unreachable-statement]

    async def recv(self, size, timeout):
        async with _maybe_timeout(timeout):
            return await self.socket.recv(size)
        raise dns.exception.Timeout(timeout=timeout)  # lgtm[py/unreachable-statement]

    async def close(self):
        await self.socket.close()

    async def getpeername(self):
        return self.socket.getpeername()

    async def getsockname(self):
        return self.socket.getsockname()


class Backend(dns._asyncbackend.Backend):
    def name(self):
        return "curio"

    async def make_socket(
        self,
        af,
        socktype,
        proto=0,
        source=None,
        destination=None,
        timeout=None,
        ssl_context=None,
        server_hostname=None,
    ):
        if socktype == socket.SOCK_DGRAM:
            s = curio.socket.socket(af, socktype, proto)
            try:
                if source:
                    s.bind(_lltuple(source, af))
            except Exception:  # pragma: no cover
                await s.close()
                raise
            return DatagramSocket(s)
        elif socktype == socket.SOCK_STREAM:
            if source:
                source_addr = _lltuple(source, af)
            else:
                source_addr = None
            async with _maybe_timeout(timeout):
                s = await curio.open_connection(
                    destination[0],
                    destination[1],
                    ssl=ssl_context,
                    source_addr=source_addr,
                    server_hostname=server_hostname,
                )
            return StreamSocket(s)
        raise NotImplementedError(
            "unsupported socket " + f"type {socktype}"
        )  # pragma: no cover

    async def sleep(self, interval):
        await curio.sleep(interval)
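The hunk above removes dns/_curio_backend.py outright: dnspython 2.4 drops curio support, leaving "trio" and "asyncio" as the selectable backends (the get_backend() docstring later in this diff is updated to match). A hedged sketch of what callers that explicitly picked a backend should do after this bump:

import dns.asyncbackend

# "curio" now raises NotImplementedError; ask for asyncio or trio instead,
# or omit this call entirely and let dnspython sniff the running event loop.
backend = dns.asyncbackend.get_backend("asyncio")
print(backend.name())  # "asyncio"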
154  lib/dns/_ddr.py  Normal file

@@ -0,0 +1,154 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
#
# Support for Discovery of Designated Resolvers

import socket
import time
from urllib.parse import urlparse

import dns.asyncbackend
import dns.inet
import dns.name
import dns.nameserver
import dns.query
import dns.rdtypes.svcbbase

# The special name of the local resolver when using DDR
_local_resolver_name = dns.name.from_text("_dns.resolver.arpa")


#
# Processing is split up into I/O independent and I/O dependent parts to
# make supporting sync and async versions easy.
#


class _SVCBInfo:
    def __init__(self, bootstrap_address, port, hostname, nameservers):
        self.bootstrap_address = bootstrap_address
        self.port = port
        self.hostname = hostname
        self.nameservers = nameservers

    def ddr_check_certificate(self, cert):
        """Verify that the _SVCBInfo's address is in the cert's subjectAltName (SAN)"""
        for name, value in cert["subjectAltName"]:
            if name == "IP Address" and value == self.bootstrap_address:
                return True
        return False

    def make_tls_context(self):
        ssl = dns.query.ssl
        ctx = ssl.create_default_context()
        ctx.minimum_version = ssl.TLSVersion.TLSv1_2
        return ctx

    def ddr_tls_check_sync(self, lifetime):
        ctx = self.make_tls_context()
        expiration = time.time() + lifetime
        with socket.create_connection(
            (self.bootstrap_address, self.port), lifetime
        ) as s:
            with ctx.wrap_socket(s, server_hostname=self.hostname) as ts:
                ts.settimeout(dns.query._remaining(expiration))
                ts.do_handshake()
                cert = ts.getpeercert()
                return self.ddr_check_certificate(cert)

    async def ddr_tls_check_async(self, lifetime, backend=None):
        if backend is None:
            backend = dns.asyncbackend.get_default_backend()
        ctx = self.make_tls_context()
        expiration = time.time() + lifetime
        async with await backend.make_socket(
            dns.inet.af_for_address(self.bootstrap_address),
            socket.SOCK_STREAM,
            0,
            None,
            (self.bootstrap_address, self.port),
            lifetime,
            ctx,
            self.hostname,
        ) as ts:
            cert = await ts.getpeercert(dns.query._remaining(expiration))
            return self.ddr_check_certificate(cert)


def _extract_nameservers_from_svcb(answer):
    bootstrap_address = answer.nameserver
    if not dns.inet.is_address(bootstrap_address):
        return []
    infos = []
    for rr in answer.rrset.processing_order():
        nameservers = []
        param = rr.params.get(dns.rdtypes.svcbbase.ParamKey.ALPN)
        if param is None:
            continue
        alpns = set(param.ids)
        host = rr.target.to_text(omit_final_dot=True)
        port = None
        param = rr.params.get(dns.rdtypes.svcbbase.ParamKey.PORT)
        if param is not None:
            port = param.port
        # For now we ignore address hints and address resolution and always use the
        # bootstrap address
        if b"h2" in alpns:
            param = rr.params.get(dns.rdtypes.svcbbase.ParamKey.DOHPATH)
            if param is None or not param.value.endswith(b"{?dns}"):
                continue
            path = param.value[:-6].decode()
            if not path.startswith("/"):
                path = "/" + path
            if port is None:
                port = 443
            url = f"https://{host}:{port}{path}"
            # check the URL
            try:
                urlparse(url)
                nameservers.append(dns.nameserver.DoHNameserver(url, bootstrap_address))
            except Exception:
                # continue processing other ALPN types
                pass
        if b"dot" in alpns:
            if port is None:
                port = 853
            nameservers.append(
                dns.nameserver.DoTNameserver(bootstrap_address, port, host)
            )
        if b"doq" in alpns:
            if port is None:
                port = 853
            nameservers.append(
                dns.nameserver.DoQNameserver(bootstrap_address, port, True, host)
            )
        if len(nameservers) > 0:
            infos.append(_SVCBInfo(bootstrap_address, port, host, nameservers))
    return infos


def _get_nameservers_sync(answer, lifetime):
    """Return a list of TLS-validated resolver nameservers extracted from an SVCB
    answer."""
    nameservers = []
    infos = _extract_nameservers_from_svcb(answer)
    for info in infos:
        try:
            if info.ddr_tls_check_sync(lifetime):
                nameservers.extend(info.nameservers)
        except Exception:
            pass
    return nameservers


async def _get_nameservers_async(answer, lifetime):
    """Return a list of TLS-validated resolver nameservers extracted from an SVCB
    answer."""
    nameservers = []
    infos = _extract_nameservers_from_svcb(answer)
    for info in infos:
        try:
            if await info.ddr_tls_check_async(lifetime):
                nameservers.extend(info.nameservers)
        except Exception:
            pass
    return nameservers
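lib/dns/_ddr.py is the I/O-independent half of the new Discovery of Designated Resolvers support; the resolver entry points that drive it (try_ddr on the sync and async resolvers) appear further down in this diff. A hedged usage sketch, assuming a routable network and a DDR-capable upstream resolver:

import asyncio

import dns.asyncresolver


async def upgrade_to_encrypted_dns():
    resolver = dns.asyncresolver.Resolver()
    # Queries _dns.resolver.arpa for SVCB records and TLS-validates the
    # advertised designated resolvers; on success the nameserver list is
    # replaced with DoH/DoT/DoQ nameserver objects.
    await resolver.try_ddr(lifetime=5.0)
    return await resolver.resolve("example.com", "A")


# asyncio.run(upgrade_to_encrypted_dns())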
@@ -7,7 +7,6 @@
import contextvars
import inspect


_in__init__ = contextvars.ContextVar("_immutable_in__init__", default=False)

@ -3,6 +3,7 @@
|
|||
"""trio async I/O library query support"""
|
||||
|
||||
import socket
|
||||
|
||||
import trio
|
||||
import trio.socket # type: ignore
|
||||
|
||||
|
@ -12,7 +13,7 @@ import dns.inet
|
|||
|
||||
|
||||
def _maybe_timeout(timeout):
|
||||
if timeout:
|
||||
if timeout is not None:
|
||||
return trio.move_on_after(timeout)
|
||||
else:
|
||||
return dns._asyncbackend.NullContext()
|
||||
|
@ -50,6 +51,9 @@ class DatagramSocket(dns._asyncbackend.DatagramSocket):
|
|||
async def getsockname(self):
|
||||
return self.socket.getsockname()
|
||||
|
||||
async def getpeercert(self, timeout):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class StreamSocket(dns._asyncbackend.StreamSocket):
|
||||
def __init__(self, family, stream, tls=False):
|
||||
|
@ -82,6 +86,100 @@ class StreamSocket(dns._asyncbackend.StreamSocket):
|
|||
else:
|
||||
return self.stream.socket.getsockname()
|
||||
|
||||
async def getpeercert(self, timeout):
|
||||
if self.tls:
|
||||
with _maybe_timeout(timeout):
|
||||
await self.stream.do_handshake()
|
||||
return self.stream.getpeercert()
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
try:
|
||||
import httpcore
|
||||
import httpcore._backends.trio
|
||||
import httpx
|
||||
|
||||
_CoreAsyncNetworkBackend = httpcore.AsyncNetworkBackend
|
||||
_CoreTrioStream = httpcore._backends.trio.TrioStream
|
||||
|
||||
from dns.query import _compute_times, _expiration_for_this_attempt, _remaining
|
||||
|
||||
class _NetworkBackend(_CoreAsyncNetworkBackend):
|
||||
def __init__(self, resolver, local_port, bootstrap_address, family):
|
||||
super().__init__()
|
||||
self._local_port = local_port
|
||||
self._resolver = resolver
|
||||
self._bootstrap_address = bootstrap_address
|
||||
self._family = family
|
||||
|
||||
async def connect_tcp(
|
||||
self, host, port, timeout, local_address, socket_options=None
|
||||
): # pylint: disable=signature-differs
|
||||
addresses = []
|
||||
_, expiration = _compute_times(timeout)
|
||||
if dns.inet.is_address(host):
|
||||
addresses.append(host)
|
||||
elif self._bootstrap_address is not None:
|
||||
addresses.append(self._bootstrap_address)
|
||||
else:
|
||||
timeout = _remaining(expiration)
|
||||
family = self._family
|
||||
if local_address:
|
||||
family = dns.inet.af_for_address(local_address)
|
||||
answers = await self._resolver.resolve_name(
|
||||
host, family=family, lifetime=timeout
|
||||
)
|
||||
addresses = answers.addresses()
|
||||
for address in addresses:
|
||||
try:
|
||||
af = dns.inet.af_for_address(address)
|
||||
if local_address is not None or self._local_port != 0:
|
||||
source = (local_address, self._local_port)
|
||||
else:
|
||||
source = None
|
||||
destination = (address, port)
|
||||
attempt_expiration = _expiration_for_this_attempt(2.0, expiration)
|
||||
timeout = _remaining(attempt_expiration)
|
||||
sock = await Backend().make_socket(
|
||||
af, socket.SOCK_STREAM, 0, source, destination, timeout
|
||||
)
|
||||
return _CoreTrioStream(sock.stream)
|
||||
except Exception:
|
||||
continue
|
||||
raise httpcore.ConnectError
|
||||
|
||||
async def connect_unix_socket(
|
||||
self, path, timeout, socket_options=None
|
||||
): # pylint: disable=signature-differs
|
||||
raise NotImplementedError
|
||||
|
||||
async def sleep(self, seconds): # pylint: disable=signature-differs
|
||||
await trio.sleep(seconds)
|
||||
|
||||
class _HTTPTransport(httpx.AsyncHTTPTransport):
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
local_port=0,
|
||||
bootstrap_address=None,
|
||||
resolver=None,
|
||||
family=socket.AF_UNSPEC,
|
||||
**kwargs,
|
||||
):
|
||||
if resolver is None:
|
||||
# pylint: disable=import-outside-toplevel,redefined-outer-name
|
||||
import dns.asyncresolver
|
||||
|
||||
resolver = dns.asyncresolver.Resolver()
|
||||
super().__init__(*args, **kwargs)
|
||||
self._pool._network_backend = _NetworkBackend(
|
||||
resolver, local_port, bootstrap_address, family
|
||||
)
|
||||
|
||||
except ImportError:
|
||||
_HTTPTransport = dns._asyncbackend.NullTransport # type: ignore
|
||||
|
||||
|
||||
class Backend(dns._asyncbackend.Backend):
|
||||
def name(self):
|
||||
|
@ -104,8 +202,14 @@ class Backend(dns._asyncbackend.Backend):
|
|||
if source:
|
||||
await s.bind(_lltuple(source, af))
|
||||
if socktype == socket.SOCK_STREAM:
|
||||
connected = False
|
||||
with _maybe_timeout(timeout):
|
||||
await s.connect(_lltuple(destination, af))
|
||||
connected = True
|
||||
if not connected:
|
||||
raise dns.exception.Timeout(
|
||||
timeout=timeout
|
||||
) # lgtm[py/unreachable-statement]
|
||||
except Exception: # pragma: no cover
|
||||
s.close()
|
||||
raise
|
||||
|
@ -130,3 +234,13 @@ class Backend(dns._asyncbackend.Backend):
|
|||
|
||||
async def sleep(self, interval):
|
||||
await trio.sleep(interval)
|
||||
|
||||
def get_transport_class(self):
|
||||
return _HTTPTransport
|
||||
|
||||
async def wait_for(self, awaitable, timeout):
|
||||
with _maybe_timeout(timeout):
|
||||
return await awaitable
|
||||
raise dns.exception.Timeout(
|
||||
timeout=timeout
|
||||
) # pragma: no cover lgtm[py/unreachable-statement]
|
||||
|
|
|
@ -5,13 +5,12 @@ from typing import Dict
|
|||
import dns.exception
|
||||
|
||||
# pylint: disable=unused-import
|
||||
|
||||
from dns._asyncbackend import (
|
||||
Socket,
|
||||
DatagramSocket,
|
||||
StreamSocket,
|
||||
from dns._asyncbackend import ( # noqa: F401 lgtm[py/unused-import]
|
||||
Backend,
|
||||
) # noqa: F401 lgtm[py/unused-import]
|
||||
DatagramSocket,
|
||||
Socket,
|
||||
StreamSocket,
|
||||
)
|
||||
|
||||
# pylint: enable=unused-import
|
||||
|
||||
|
@ -30,8 +29,8 @@ class AsyncLibraryNotFoundError(dns.exception.DNSException):
|
|||
def get_backend(name: str) -> Backend:
|
||||
"""Get the specified asynchronous backend.
|
||||
|
||||
*name*, a ``str``, the name of the backend. Currently the "trio",
|
||||
"curio", and "asyncio" backends are available.
|
||||
*name*, a ``str``, the name of the backend. Currently the "trio"
|
||||
and "asyncio" backends are available.
|
||||
|
||||
Raises NotImplementedError if an unknown backend name is specified.
|
||||
"""
|
||||
|
@ -43,10 +42,6 @@ def get_backend(name: str) -> Backend:
|
|||
import dns._trio_backend
|
||||
|
||||
backend = dns._trio_backend.Backend()
|
||||
elif name == "curio":
|
||||
import dns._curio_backend
|
||||
|
||||
backend = dns._curio_backend.Backend()
|
||||
elif name == "asyncio":
|
||||
import dns._asyncio_backend
|
||||
|
||||
|
@ -73,9 +68,7 @@ def sniff() -> str:
|
|||
try:
|
||||
return sniffio.current_async_library()
|
||||
except sniffio.AsyncLibraryNotFoundError:
|
||||
raise AsyncLibraryNotFoundError(
|
||||
"sniffio cannot determine " + "async library"
|
||||
)
|
||||
raise AsyncLibraryNotFoundError("sniffio cannot determine async library")
|
||||
except ImportError:
|
||||
import asyncio
|
||||
|
||||
|
|
|
@ -17,39 +17,38 @@
|
|||
|
||||
"""Talk to a DNS server."""
|
||||
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import base64
|
||||
import contextlib
|
||||
import socket
|
||||
import struct
|
||||
import time
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import dns.asyncbackend
|
||||
import dns.exception
|
||||
import dns.inet
|
||||
import dns.name
|
||||
import dns.message
|
||||
import dns.name
|
||||
import dns.quic
|
||||
import dns.rcode
|
||||
import dns.rdataclass
|
||||
import dns.rdatatype
|
||||
import dns.transaction
|
||||
|
||||
from dns._asyncbackend import NullContext
|
||||
from dns.query import (
|
||||
_compute_times,
|
||||
_matches_destination,
|
||||
BadResponse,
|
||||
ssl,
|
||||
UDPMode,
|
||||
_have_httpx,
|
||||
_have_http2,
|
||||
NoDOH,
|
||||
NoDOQ,
|
||||
UDPMode,
|
||||
_compute_times,
|
||||
_have_http2,
|
||||
_matches_destination,
|
||||
_remaining,
|
||||
have_doh,
|
||||
ssl,
|
||||
)
|
||||
|
||||
if _have_httpx:
|
||||
if have_doh:
|
||||
import httpx
|
||||
|
||||
# for brevity
|
||||
|
@ -73,7 +72,7 @@ def _source_tuple(af, address, port):
|
|||
|
||||
|
||||
def _timeout(expiration, now=None):
|
||||
if expiration:
|
||||
if expiration is not None:
|
||||
if not now:
|
||||
now = time.time()
|
||||
return max(expiration - now, 0)
|
||||
|
@ -445,9 +444,6 @@ async def tls(
|
|||
ssl_context.minimum_version = ssl.TLSVersion.TLSv1_2
|
||||
if server_hostname is None:
|
||||
ssl_context.check_hostname = False
|
||||
else:
|
||||
ssl_context = None
|
||||
server_hostname = None
|
||||
af = dns.inet.af_for_address(where)
|
||||
stuple = _source_tuple(af, source, source_port)
|
||||
dtuple = (where, port)
|
||||
|
@ -495,6 +491,9 @@ async def https(
|
|||
path: str = "/dns-query",
|
||||
post: bool = True,
|
||||
verify: Union[bool, str] = True,
|
||||
bootstrap_address: Optional[str] = None,
|
||||
resolver: Optional["dns.asyncresolver.Resolver"] = None,
|
||||
family: Optional[int] = socket.AF_UNSPEC,
|
||||
) -> dns.message.Message:
|
||||
"""Return the response obtained after sending a query via DNS-over-HTTPS.
|
||||
|
||||
|
@ -508,8 +507,10 @@ async def https(
|
|||
parameters, exceptions, and return type of this method.
|
||||
"""
|
||||
|
||||
if not _have_httpx:
|
||||
raise NoDOH("httpx is not available.") # pragma: no cover
|
||||
if not have_doh:
|
||||
raise NoDOH # pragma: no cover
|
||||
if client and not isinstance(client, httpx.AsyncClient):
|
||||
raise ValueError("session parameter must be an httpx.AsyncClient")
|
||||
|
||||
wire = q.to_wire()
|
||||
try:
|
||||
|
@ -518,15 +519,32 @@ async def https(
|
|||
af = None
|
||||
transport = None
|
||||
headers = {"accept": "application/dns-message"}
|
||||
if af is not None:
|
||||
if af is not None and dns.inet.is_address(where):
|
||||
if af == socket.AF_INET:
|
||||
url = "https://{}:{}{}".format(where, port, path)
|
||||
elif af == socket.AF_INET6:
|
||||
url = "https://[{}]:{}{}".format(where, port, path)
|
||||
else:
|
||||
url = where
|
||||
if source is not None:
|
||||
transport = httpx.AsyncHTTPTransport(local_address=source[0])
|
||||
|
||||
backend = dns.asyncbackend.get_default_backend()
|
||||
|
||||
if source is None:
|
||||
local_address = None
|
||||
local_port = 0
|
||||
else:
|
||||
local_address = source
|
||||
local_port = source_port
|
||||
transport = backend.get_transport_class()(
|
||||
local_address=local_address,
|
||||
http1=True,
|
||||
http2=_have_http2,
|
||||
verify=verify,
|
||||
local_port=local_port,
|
||||
bootstrap_address=bootstrap_address,
|
||||
resolver=resolver,
|
||||
family=family,
|
||||
)
|
||||
|
||||
if client:
|
||||
cm: contextlib.AbstractAsyncContextManager = NullContext(client)
|
||||
|
@ -545,14 +563,14 @@ async def https(
|
|||
"content-length": str(len(wire)),
|
||||
}
|
||||
)
|
||||
response = await the_client.post(
|
||||
url, headers=headers, content=wire, timeout=timeout
|
||||
response = await backend.wait_for(
|
||||
the_client.post(url, headers=headers, content=wire), timeout
|
||||
)
|
||||
else:
|
||||
wire = base64.urlsafe_b64encode(wire).rstrip(b"=")
|
||||
twire = wire.decode() # httpx does a repr() if we give it bytes
|
||||
response = await the_client.get(
|
||||
url, headers=headers, timeout=timeout, params={"dns": twire}
|
||||
response = await backend.wait_for(
|
||||
the_client.get(url, headers=headers, params={"dns": twire}), timeout
|
||||
)
|
||||
|
||||
# see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH
|
||||
|
@ -690,6 +708,7 @@ async def quic(
|
|||
connection: Optional[dns.quic.AsyncQuicConnection] = None,
|
||||
verify: Union[bool, str] = True,
|
||||
backend: Optional[dns.asyncbackend.Backend] = None,
|
||||
server_hostname: Optional[str] = None,
|
||||
) -> dns.message.Message:
|
||||
"""Return the response obtained after sending an asynchronous query via
|
||||
DNS-over-QUIC.
|
||||
|
@ -715,14 +734,16 @@ async def quic(
|
|||
(cfactory, mfactory) = dns.quic.factories_for_backend(backend)
|
||||
|
||||
async with cfactory() as context:
|
||||
async with mfactory(context, verify_mode=verify) as the_manager:
|
||||
async with mfactory(
|
||||
context, verify_mode=verify, server_name=server_hostname
|
||||
) as the_manager:
|
||||
if not connection:
|
||||
the_connection = the_manager.connect(where, port, source, source_port)
|
||||
start = time.time()
|
||||
stream = await the_connection.make_stream()
|
||||
(start, expiration) = _compute_times(timeout)
|
||||
stream = await the_connection.make_stream(timeout)
|
||||
async with stream:
|
||||
await stream.send(wire, True)
|
||||
wire = await stream.receive(timeout)
|
||||
wire = await stream.receive(_remaining(expiration))
|
||||
finish = time.time()
|
||||
r = dns.message.from_wire(
|
||||
wire,
|
||||
|
|
|
@ -17,10 +17,11 @@
|
|||
|
||||
"""Asynchronous DNS stub resolver."""
|
||||
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import socket
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import dns._ddr
|
||||
import dns.asyncbackend
|
||||
import dns.asyncquery
|
||||
import dns.exception
|
||||
|
@ -31,8 +32,7 @@ import dns.rdatatype
|
|||
import dns.resolver # lgtm[py/import-and-import-from]
|
||||
|
||||
# import some resolver symbols for brevity
|
||||
from dns.resolver import NXDOMAIN, NoAnswer, NotAbsolute, NoRootSOA
|
||||
|
||||
from dns.resolver import NXDOMAIN, NoAnswer, NoRootSOA, NotAbsolute
|
||||
|
||||
# for indentation purposes below
|
||||
_udp = dns.asyncquery.udp
|
||||
|
@ -83,37 +83,19 @@ class Resolver(dns.resolver.BaseResolver):
|
|||
assert request is not None # needed for type checking
|
||||
done = False
|
||||
while not done:
|
||||
(nameserver, port, tcp, backoff) = resolution.next_nameserver()
|
||||
(nameserver, tcp, backoff) = resolution.next_nameserver()
|
||||
if backoff:
|
||||
await backend.sleep(backoff)
|
||||
timeout = self._compute_timeout(start, lifetime, resolution.errors)
|
||||
try:
|
||||
if dns.inet.is_address(nameserver):
|
||||
if tcp:
|
||||
response = await _tcp(
|
||||
response = await nameserver.async_query(
|
||||
request,
|
||||
nameserver,
|
||||
timeout,
|
||||
port,
|
||||
source,
|
||||
source_port,
|
||||
timeout=timeout,
|
||||
source=source,
|
||||
source_port=source_port,
|
||||
max_size=tcp,
|
||||
backend=backend,
|
||||
)
|
||||
else:
|
||||
response = await _udp(
|
||||
request,
|
||||
nameserver,
|
||||
timeout,
|
||||
port,
|
||||
source,
|
||||
source_port,
|
||||
raise_on_truncation=True,
|
||||
backend=backend,
|
||||
)
|
||||
else:
|
||||
response = await dns.asyncquery.https(
|
||||
request, nameserver, timeout=timeout
|
||||
)
|
||||
except Exception as ex:
|
||||
(_, done) = resolution.query_result(None, ex)
|
||||
continue
|
||||
|
@ -153,6 +135,73 @@ class Resolver(dns.resolver.BaseResolver):
|
|||
dns.reversename.from_address(ipaddr), *args, **modified_kwargs
|
||||
)
|
||||
|
||||
async def resolve_name(
|
||||
self,
|
||||
name: Union[dns.name.Name, str],
|
||||
family: int = socket.AF_UNSPEC,
|
||||
**kwargs: Any,
|
||||
) -> dns.resolver.HostAnswers:
|
||||
"""Use an asynchronous resolver to query for address records.
|
||||
|
||||
This utilizes the resolve() method to perform A and/or AAAA lookups on
|
||||
the specified name.
|
||||
|
||||
*qname*, a ``dns.name.Name`` or ``str``, the name to resolve.
|
||||
|
||||
*family*, an ``int``, the address family. If socket.AF_UNSPEC
|
||||
(the default), both A and AAAA records will be retrieved.
|
||||
|
||||
All other arguments that can be passed to the resolve() function
|
||||
except for rdtype and rdclass are also supported by this
|
||||
function.
|
||||
"""
|
||||
# We make a modified kwargs for type checking happiness, as otherwise
|
||||
# we get a legit warning about possibly having rdtype and rdclass
|
||||
# in the kwargs more than once.
|
||||
modified_kwargs: Dict[str, Any] = {}
|
||||
modified_kwargs.update(kwargs)
|
||||
modified_kwargs.pop("rdtype", None)
|
||||
modified_kwargs["rdclass"] = dns.rdataclass.IN
|
||||
|
||||
if family == socket.AF_INET:
|
||||
v4 = await self.resolve(name, dns.rdatatype.A, **modified_kwargs)
|
||||
return dns.resolver.HostAnswers.make(v4=v4)
|
||||
elif family == socket.AF_INET6:
|
||||
v6 = await self.resolve(name, dns.rdatatype.AAAA, **modified_kwargs)
|
||||
return dns.resolver.HostAnswers.make(v6=v6)
|
||||
elif family != socket.AF_UNSPEC:
|
||||
raise NotImplementedError(f"unknown address family {family}")
|
||||
|
||||
raise_on_no_answer = modified_kwargs.pop("raise_on_no_answer", True)
|
||||
lifetime = modified_kwargs.pop("lifetime", None)
|
||||
start = time.time()
|
||||
v6 = await self.resolve(
|
||||
name,
|
||||
dns.rdatatype.AAAA,
|
||||
raise_on_no_answer=False,
|
||||
lifetime=self._compute_timeout(start, lifetime),
|
||||
**modified_kwargs,
|
||||
)
|
||||
# Note that setting name ensures we query the same name
|
||||
# for A as we did for AAAA. (This is just in case search lists
|
||||
# are active by default in the resolver configuration and
|
||||
# we might be talking to a server that says NXDOMAIN when it
|
||||
# wants to say NOERROR no data.)
|
||||
name = v6.qname
|
||||
v4 = await self.resolve(
|
||||
name,
|
||||
dns.rdatatype.A,
|
||||
raise_on_no_answer=False,
|
||||
lifetime=self._compute_timeout(start, lifetime),
|
||||
**modified_kwargs,
|
||||
)
|
||||
answers = dns.resolver.HostAnswers.make(
|
||||
v6=v6, v4=v4, add_empty=not raise_on_no_answer
|
||||
)
|
||||
if not answers:
|
||||
raise NoAnswer(response=v6.response)
|
||||
return answers
|
||||
|
||||
# pylint: disable=redefined-outer-name
|
||||
|
||||
async def canonical_name(self, name: Union[dns.name.Name, str]) -> dns.name.Name:
|
||||
|
@ -176,6 +225,37 @@ class Resolver(dns.resolver.BaseResolver):
|
|||
canonical_name = e.canonical_name
|
||||
return canonical_name
|
||||
|
||||
async def try_ddr(self, lifetime: float = 5.0) -> None:
|
||||
"""Try to update the resolver's nameservers using Discovery of Designated
|
||||
Resolvers (DDR). If successful, the resolver will subsequently use
|
||||
DNS-over-HTTPS or DNS-over-TLS for future queries.
|
||||
|
||||
*lifetime*, a float, is the maximum time to spend attempting DDR. The default
|
||||
is 5 seconds.
|
||||
|
||||
If the SVCB query is successful and results in a non-empty list of nameservers,
|
||||
then the resolver's nameservers are set to the returned servers in priority
|
||||
order.
|
||||
|
||||
The current implementation does not use any address hints from the SVCB record,
|
||||
nor does it resolve addresses for the SVCB target name; rather, it assumes that
|
||||
the bootstrap nameserver will always be one of the addresses and uses it.
|
||||
A future revision to the code may offer fuller support. The code verifies that
|
||||
the bootstrap nameserver is in the Subject Alternative Name field of the
|
||||
TLS certificate.
|
||||
"""
|
||||
try:
|
||||
expiration = time.time() + lifetime
|
||||
answer = await self.resolve(
|
||||
dns._ddr._local_resolver_name, "svcb", lifetime=lifetime
|
||||
)
|
||||
timeout = dns.query._remaining(expiration)
|
||||
nameservers = await dns._ddr._get_nameservers_async(answer, timeout)
|
||||
if len(nameservers) > 0:
|
||||
self.nameservers = nameservers
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
default_resolver = None
|
||||
|
||||
|
@ -246,6 +326,18 @@ async def resolve_address(
|
|||
return await get_default_resolver().resolve_address(ipaddr, *args, **kwargs)
|
||||
|
||||
|
||||
async def resolve_name(
|
||||
name: Union[dns.name.Name, str], family: int = socket.AF_UNSPEC, **kwargs: Any
|
||||
) -> dns.resolver.HostAnswers:
|
||||
"""Use a resolver to asynchronously query for address records.
|
||||
|
||||
See :py:func:`dns.asyncresolver.Resolver.resolve_name` for more
|
||||
information on the parameters.
|
||||
"""
|
||||
|
||||
return await get_default_resolver().resolve_name(name, family, **kwargs)
|
||||
|
||||
|
||||
async def canonical_name(name: Union[dns.name.Name, str]) -> dns.name.Name:
|
||||
"""Determine the canonical name of *name*.
|
||||
|
||||
|
@ -256,6 +348,16 @@ async def canonical_name(name: Union[dns.name.Name, str]) -> dns.name.Name:
|
|||
return await get_default_resolver().canonical_name(name)
|
||||
|
||||
|
||||
async def try_ddr(timeout: float = 5.0) -> None:
|
||||
"""Try to update the default resolver's nameservers using Discovery of Designated
|
||||
Resolvers (DDR). If successful, the resolver will subsequently use
|
||||
DNS-over-HTTPS or DNS-over-TLS for future queries.
|
||||
|
||||
See :py:func:`dns.resolver.Resolver.try_ddr` for more information.
|
||||
"""
|
||||
return await get_default_resolver().try_ddr(timeout)
|
||||
|
||||
|
||||
async def zone_for_name(
|
||||
name: Union[dns.name.Name, str],
|
||||
rdclass: dns.rdataclass.RdataClass = dns.rdataclass.IN,
|
||||
|
@ -290,3 +392,84 @@ async def zone_for_name(
|
|||
name = name.parent()
|
||||
except dns.name.NoParent: # pragma: no cover
|
||||
raise NoRootSOA
|
||||
|
||||
|
||||
async def make_resolver_at(
|
||||
where: Union[dns.name.Name, str],
|
||||
port: int = 53,
|
||||
family: int = socket.AF_UNSPEC,
|
||||
resolver: Optional[Resolver] = None,
|
||||
) -> Resolver:
|
||||
"""Make a stub resolver using the specified destination as the full resolver.
|
||||
|
||||
*where*, a ``dns.name.Name`` or ``str`` the domain name or IP address of the
|
||||
full resolver.
|
||||
|
||||
*port*, an ``int``, the port to use. If not specified, the default is 53.
|
||||
|
||||
*family*, an ``int``, the address family to use. This parameter is used if
|
||||
*where* is not an address. The default is ``socket.AF_UNSPEC`` in which case
|
||||
the first address returned by ``resolve_name()`` will be used, otherwise the
|
||||
first address of the specified family will be used.
|
||||
|
||||
*resolver*, a ``dns.asyncresolver.Resolver`` or ``None``, the resolver to use for
|
||||
resolution of hostnames. If not specified, the default resolver will be used.
|
||||
|
||||
Returns a ``dns.resolver.Resolver`` or raises an exception.
|
||||
"""
|
||||
if resolver is None:
|
||||
resolver = get_default_resolver()
|
||||
nameservers: List[Union[str, dns.nameserver.Nameserver]] = []
|
||||
if isinstance(where, str) and dns.inet.is_address(where):
|
||||
nameservers.append(dns.nameserver.Do53Nameserver(where, port))
|
||||
else:
|
||||
answers = await resolver.resolve_name(where, family)
|
||||
for address in answers.addresses():
|
||||
nameservers.append(dns.nameserver.Do53Nameserver(address, port))
|
||||
res = dns.asyncresolver.Resolver(configure=False)
|
||||
res.nameservers = nameservers
|
||||
return res
|
||||
|
||||
|
||||
async def resolve_at(
|
||||
where: Union[dns.name.Name, str],
|
||||
qname: Union[dns.name.Name, str],
|
||||
rdtype: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.A,
|
||||
rdclass: Union[dns.rdataclass.RdataClass, str] = dns.rdataclass.IN,
|
||||
tcp: bool = False,
|
||||
source: Optional[str] = None,
|
||||
raise_on_no_answer: bool = True,
|
||||
source_port: int = 0,
|
||||
lifetime: Optional[float] = None,
|
||||
search: Optional[bool] = None,
|
||||
backend: Optional[dns.asyncbackend.Backend] = None,
|
||||
port: int = 53,
|
||||
family: int = socket.AF_UNSPEC,
|
||||
resolver: Optional[Resolver] = None,
|
||||
) -> dns.resolver.Answer:
|
||||
"""Query nameservers to find the answer to the question.
|
||||
|
||||
This is a convenience function that calls ``dns.asyncresolver.make_resolver_at()``
|
||||
to make a resolver, and then uses it to resolve the query.
|
||||
|
||||
See ``dns.asyncresolver.Resolver.resolve`` for more information on the resolution
|
||||
parameters, and ``dns.asyncresolver.make_resolver_at`` for information about the
|
||||
resolver parameters *where*, *port*, *family*, and *resolver*.
|
||||
|
||||
If making more than one query, it is more efficient to call
|
||||
``dns.asyncresolver.make_resolver_at()`` and then use that resolver for the queries
|
||||
instead of calling ``resolve_at()`` multiple times.
|
||||
"""
|
||||
res = await make_resolver_at(where, port, family, resolver)
|
||||
return await res.resolve(
|
||||
qname,
|
||||
rdtype,
|
||||
rdclass,
|
||||
tcp,
|
||||
source,
|
||||
raise_on_no_answer,
|
||||
source_port,
|
||||
lifetime,
|
||||
search,
|
||||
backend,
|
||||
)
|
||||
|
|
|
@ -17,50 +17,44 @@
|
|||
|
||||
"""Common DNSSEC-related functions and constants."""
|
||||
|
||||
from typing import Any, cast, Dict, List, Optional, Set, Tuple, Union
|
||||
|
||||
import base64
|
||||
import contextlib
|
||||
import functools
|
||||
import hashlib
|
||||
import math
|
||||
import struct
|
||||
import time
|
||||
import base64
|
||||
from datetime import datetime
|
||||
|
||||
from dns.dnssectypes import Algorithm, DSDigest, NSEC3Hash
|
||||
from typing import Callable, Dict, List, Optional, Set, Tuple, Union, cast
|
||||
|
||||
import dns.exception
|
||||
import dns.name
|
||||
import dns.node
|
||||
import dns.rdataset
|
||||
import dns.rdata
|
||||
import dns.rdatatype
|
||||
import dns.rdataclass
|
||||
import dns.rdataset
|
||||
import dns.rdatatype
|
||||
import dns.rrset
|
||||
import dns.transaction
|
||||
import dns.zone
|
||||
from dns.dnssectypes import Algorithm, DSDigest, NSEC3Hash
|
||||
from dns.exception import ( # pylint: disable=W0611
|
||||
AlgorithmKeyMismatch,
|
||||
DeniedByPolicy,
|
||||
UnsupportedAlgorithm,
|
||||
ValidationFailure,
|
||||
)
|
||||
from dns.rdtypes.ANY.CDNSKEY import CDNSKEY
|
||||
from dns.rdtypes.ANY.CDS import CDS
|
||||
from dns.rdtypes.ANY.DNSKEY import DNSKEY
|
||||
from dns.rdtypes.ANY.DS import DS
|
||||
from dns.rdtypes.ANY.NSEC import NSEC, Bitmap
|
||||
from dns.rdtypes.ANY.NSEC3PARAM import NSEC3PARAM
|
||||
from dns.rdtypes.ANY.RRSIG import RRSIG, sigtime_to_posixtime
|
||||
from dns.rdtypes.dnskeybase import Flag
|
||||
|
||||
|
||||
class UnsupportedAlgorithm(dns.exception.DNSException):
|
||||
"""The DNSSEC algorithm is not supported."""
|
||||
|
||||
|
||||
class AlgorithmKeyMismatch(UnsupportedAlgorithm):
|
||||
"""The DNSSEC algorithm is not supported for the given key type."""
|
||||
|
||||
|
||||
class ValidationFailure(dns.exception.DNSException):
|
||||
"""The DNSSEC signature is invalid."""
|
||||
|
||||
|
||||
class DeniedByPolicy(dns.exception.DNSException):
|
||||
"""Denied by DNSSEC policy."""
|
||||
|
||||
|
||||
PublicKey = Union[
|
||||
"GenericPublicKey",
|
||||
"rsa.RSAPublicKey",
|
||||
"ec.EllipticCurvePublicKey",
|
||||
"ed25519.Ed25519PublicKey",
|
||||
|
@ -68,12 +62,15 @@ PublicKey = Union[
|
|||
]
|
||||
|
||||
PrivateKey = Union[
|
||||
"GenericPrivateKey",
|
||||
"rsa.RSAPrivateKey",
|
||||
"ec.EllipticCurvePrivateKey",
|
||||
"ed25519.Ed25519PrivateKey",
|
||||
"ed448.Ed448PrivateKey",
|
||||
]
|
||||
|
||||
RRsetSigner = Callable[[dns.transaction.Transaction, dns.rrset.RRset], None]
|
||||
|
||||
|
||||
def algorithm_from_text(text: str) -> Algorithm:
|
||||
"""Convert text into a DNSSEC algorithm value.
|
||||
|
@ -308,113 +305,13 @@ def _find_candidate_keys(
|
|||
return [
|
||||
cast(DNSKEY, rd)
|
||||
for rd in rdataset
|
||||
if rd.algorithm == rrsig.algorithm and key_id(rd) == rrsig.key_tag
|
||||
if rd.algorithm == rrsig.algorithm
|
||||
and key_id(rd) == rrsig.key_tag
|
||||
and (rd.flags & Flag.ZONE) == Flag.ZONE # RFC 4034 2.1.1
|
||||
and rd.protocol == 3 # RFC 4034 2.1.2
|
||||
]
|
||||
|
||||
|
||||
def _is_rsa(algorithm: int) -> bool:
|
||||
return algorithm in (
|
||||
Algorithm.RSAMD5,
|
||||
Algorithm.RSASHA1,
|
||||
Algorithm.RSASHA1NSEC3SHA1,
|
||||
Algorithm.RSASHA256,
|
||||
Algorithm.RSASHA512,
|
||||
)
|
||||
|
||||
|
||||
def _is_dsa(algorithm: int) -> bool:
|
||||
return algorithm in (Algorithm.DSA, Algorithm.DSANSEC3SHA1)
|
||||
|
||||
|
||||
def _is_ecdsa(algorithm: int) -> bool:
|
||||
return algorithm in (Algorithm.ECDSAP256SHA256, Algorithm.ECDSAP384SHA384)
|
||||
|
||||
|
||||
def _is_eddsa(algorithm: int) -> bool:
|
||||
return algorithm in (Algorithm.ED25519, Algorithm.ED448)
|
||||
|
||||
|
||||
def _is_gost(algorithm: int) -> bool:
|
||||
return algorithm == Algorithm.ECCGOST
|
||||
|
||||
|
||||
def _is_md5(algorithm: int) -> bool:
|
||||
return algorithm == Algorithm.RSAMD5
|
||||
|
||||
|
||||
def _is_sha1(algorithm: int) -> bool:
|
||||
return algorithm in (
|
||||
Algorithm.DSA,
|
||||
Algorithm.RSASHA1,
|
||||
Algorithm.DSANSEC3SHA1,
|
||||
Algorithm.RSASHA1NSEC3SHA1,
|
||||
)
|
||||
|
||||
|
||||
def _is_sha256(algorithm: int) -> bool:
|
||||
return algorithm in (Algorithm.RSASHA256, Algorithm.ECDSAP256SHA256)
|
||||
|
||||
|
||||
def _is_sha384(algorithm: int) -> bool:
|
||||
return algorithm == Algorithm.ECDSAP384SHA384
|
||||
|
||||
|
||||
def _is_sha512(algorithm: int) -> bool:
|
||||
return algorithm == Algorithm.RSASHA512
|
||||
|
||||
|
||||
def _ensure_algorithm_key_combination(algorithm: int, key: PublicKey) -> None:
|
||||
"""Ensure algorithm is valid for key type, throwing an exception on
|
||||
mismatch."""
|
||||
if isinstance(key, rsa.RSAPublicKey):
|
||||
if _is_rsa(algorithm):
|
||||
return
|
||||
raise AlgorithmKeyMismatch('algorithm "%s" not valid for RSA key' % algorithm)
|
||||
if isinstance(key, dsa.DSAPublicKey):
|
||||
if _is_dsa(algorithm):
|
||||
return
|
||||
raise AlgorithmKeyMismatch('algorithm "%s" not valid for DSA key' % algorithm)
|
||||
if isinstance(key, ec.EllipticCurvePublicKey):
|
||||
if _is_ecdsa(algorithm):
|
||||
return
|
||||
raise AlgorithmKeyMismatch('algorithm "%s" not valid for ECDSA key' % algorithm)
|
||||
if isinstance(key, ed25519.Ed25519PublicKey):
|
||||
if algorithm == Algorithm.ED25519:
|
||||
return
|
||||
raise AlgorithmKeyMismatch(
|
||||
'algorithm "%s" not valid for ED25519 key' % algorithm
|
||||
)
|
||||
if isinstance(key, ed448.Ed448PublicKey):
|
||||
if algorithm == Algorithm.ED448:
|
||||
return
|
||||
raise AlgorithmKeyMismatch('algorithm "%s" not valid for ED448 key' % algorithm)
|
||||
|
||||
raise TypeError("unsupported key type")
|
||||
|
||||
|
||||
def _make_hash(algorithm: int) -> Any:
|
||||
if _is_md5(algorithm):
|
||||
return hashes.MD5()
|
||||
if _is_sha1(algorithm):
|
||||
return hashes.SHA1()
|
||||
if _is_sha256(algorithm):
|
||||
return hashes.SHA256()
|
||||
if _is_sha384(algorithm):
|
||||
return hashes.SHA384()
|
||||
if _is_sha512(algorithm):
|
||||
return hashes.SHA512()
|
||||
if algorithm == Algorithm.ED25519:
|
||||
return hashes.SHA512()
|
||||
if algorithm == Algorithm.ED448:
|
||||
return hashes.SHAKE256(114)
|
||||
|
||||
raise ValidationFailure("unknown hash for algorithm %u" % algorithm)
|
||||
|
||||
|
||||
def _bytes_to_long(b: bytes) -> int:
|
||||
return int.from_bytes(b, "big")
|
||||
|
||||
|
||||
def _get_rrname_rdataset(
|
||||
rrset: Union[dns.rrset.RRset, Tuple[dns.name.Name, dns.rdataset.Rdataset]],
|
||||
) -> Tuple[dns.name.Name, dns.rdataset.Rdataset]:
|
||||
|
@ -424,85 +321,13 @@ def _get_rrname_rdataset(
|
|||
return rrset.name, rrset
|
||||
|
||||
|
||||
def _validate_signature(sig: bytes, data: bytes, key: DNSKEY, chosen_hash: Any) -> None:
|
||||
keyptr: bytes
|
||||
if _is_rsa(key.algorithm):
|
||||
# we ignore because mypy is confused and thinks key.key is a str for unknown
|
||||
# reasons.
|
||||
keyptr = key.key
|
||||
(bytes_,) = struct.unpack("!B", keyptr[0:1])
|
||||
keyptr = keyptr[1:]
|
||||
if bytes_ == 0:
|
||||
(bytes_,) = struct.unpack("!H", keyptr[0:2])
|
||||
keyptr = keyptr[2:]
|
||||
rsa_e = keyptr[0:bytes_]
|
||||
rsa_n = keyptr[bytes_:]
|
||||
def _validate_signature(sig: bytes, data: bytes, key: DNSKEY) -> None:
|
||||
public_cls = get_algorithm_cls_from_dnskey(key).public_cls
|
||||
try:
|
||||
rsa_public_key = rsa.RSAPublicNumbers(
|
||||
_bytes_to_long(rsa_e), _bytes_to_long(rsa_n)
|
||||
).public_key(default_backend())
|
||||
public_key = public_cls.from_dnskey(key)
|
||||
except ValueError:
|
||||
raise ValidationFailure("invalid public key")
|
||||
rsa_public_key.verify(sig, data, padding.PKCS1v15(), chosen_hash)
|
||||
elif _is_dsa(key.algorithm):
|
||||
keyptr = key.key
|
||||
(t,) = struct.unpack("!B", keyptr[0:1])
|
||||
keyptr = keyptr[1:]
|
||||
octets = 64 + t * 8
|
||||
dsa_q = keyptr[0:20]
|
||||
keyptr = keyptr[20:]
|
||||
dsa_p = keyptr[0:octets]
|
||||
keyptr = keyptr[octets:]
|
||||
dsa_g = keyptr[0:octets]
|
||||
keyptr = keyptr[octets:]
|
||||
dsa_y = keyptr[0:octets]
|
||||
try:
|
||||
dsa_public_key = dsa.DSAPublicNumbers( # type: ignore
|
||||
_bytes_to_long(dsa_y),
|
||||
dsa.DSAParameterNumbers(
|
||||
_bytes_to_long(dsa_p), _bytes_to_long(dsa_q), _bytes_to_long(dsa_g)
|
||||
),
|
||||
).public_key(default_backend())
|
||||
except ValueError:
|
||||
raise ValidationFailure("invalid public key")
|
||||
dsa_public_key.verify(sig, data, chosen_hash)
|
||||
elif _is_ecdsa(key.algorithm):
|
||||
keyptr = key.key
|
||||
curve: Any
|
||||
if key.algorithm == Algorithm.ECDSAP256SHA256:
|
||||
curve = ec.SECP256R1()
|
||||
octets = 32
|
||||
else:
|
||||
curve = ec.SECP384R1()
|
||||
octets = 48
|
||||
ecdsa_x = keyptr[0:octets]
|
||||
ecdsa_y = keyptr[octets : octets * 2]
|
||||
try:
|
||||
ecdsa_public_key = ec.EllipticCurvePublicNumbers(
|
||||
curve=curve, x=_bytes_to_long(ecdsa_x), y=_bytes_to_long(ecdsa_y)
|
||||
).public_key(default_backend())
|
||||
except ValueError:
|
||||
raise ValidationFailure("invalid public key")
|
||||
ecdsa_public_key.verify(sig, data, ec.ECDSA(chosen_hash))
|
||||
elif _is_eddsa(key.algorithm):
|
||||
keyptr = key.key
|
||||
loader: Any
|
||||
if key.algorithm == Algorithm.ED25519:
|
||||
loader = ed25519.Ed25519PublicKey
|
||||
else:
|
||||
loader = ed448.Ed448PublicKey
|
||||
try:
|
||||
eddsa_public_key = loader.from_public_bytes(keyptr)
|
||||
except ValueError:
|
||||
raise ValidationFailure("invalid public key")
|
||||
eddsa_public_key.verify(sig, data)
|
||||
elif _is_gost(key.algorithm):
|
||||
raise UnsupportedAlgorithm(
|
||||
'algorithm "%s" not supported by dnspython'
|
||||
% algorithm_to_text(key.algorithm)
|
||||
)
|
||||
else:
|
||||
raise ValidationFailure("unknown algorithm %u" % key.algorithm)
|
||||
public_key.verify(sig, data)
|
||||
|
||||
|
||||
def _validate_rrsig(
|
||||
|
@ -559,29 +384,13 @@ def _validate_rrsig(
|
|||
if rrsig.inception > now:
|
||||
raise ValidationFailure("not yet valid")
|
||||
|
||||
if _is_dsa(rrsig.algorithm):
|
||||
sig_r = rrsig.signature[1:21]
|
||||
sig_s = rrsig.signature[21:]
|
||||
sig = utils.encode_dss_signature(_bytes_to_long(sig_r), _bytes_to_long(sig_s))
|
||||
elif _is_ecdsa(rrsig.algorithm):
|
||||
if rrsig.algorithm == Algorithm.ECDSAP256SHA256:
|
||||
octets = 32
|
||||
else:
|
||||
octets = 48
|
||||
sig_r = rrsig.signature[0:octets]
|
||||
sig_s = rrsig.signature[octets:]
|
||||
sig = utils.encode_dss_signature(_bytes_to_long(sig_r), _bytes_to_long(sig_s))
|
||||
else:
|
||||
sig = rrsig.signature
|
||||
|
||||
data = _make_rrsig_signature_data(rrset, rrsig, origin)
|
||||
chosen_hash = _make_hash(rrsig.algorithm)
|
||||
|
||||
for candidate_key in candidate_keys:
|
||||
if not policy.ok_to_validate(candidate_key):
|
||||
continue
|
||||
try:
|
||||
_validate_signature(sig, data, candidate_key, chosen_hash)
|
||||
_validate_signature(rrsig.signature, data, candidate_key)
|
||||
return
|
||||
except (InvalidSignature, ValidationFailure):
|
||||
# this happens on an individual validation failure
|
||||
|
@ -673,6 +482,7 @@ def _sign(
|
|||
lifetime: Optional[int] = None,
|
||||
verify: bool = False,
|
||||
policy: Optional[Policy] = None,
|
||||
origin: Optional[dns.name.Name] = None,
|
||||
) -> RRSIG:
|
||||
"""Sign RRset using private key.
|
||||
|
||||
|
@ -708,6 +518,10 @@ def _sign(
|
|||
*policy*, a ``dns.dnssec.Policy`` or ``None``. If ``None``, the default policy,
|
||||
``dns.dnssec.default_policy`` is used; this policy defaults to that of RFC 8624.
|
||||
|
||||
*origin*, a ``dns.name.Name`` or ``None``. If ``None``, the default, then all
|
||||
names in the rrset (including its owner name) must be absolute; otherwise the
|
||||
specified origin will be used to make names absolute when signing.
|
||||
|
||||
Raises ``DeniedByPolicy`` if the signature is denied by policy.
|
||||
"""
|
||||
|
||||
|
@ -735,16 +549,26 @@ def _sign(
|
|||
if expiration is not None:
|
||||
rrsig_expiration = to_timestamp(expiration)
|
||||
elif lifetime is not None:
|
||||
rrsig_expiration = int(time.time()) + lifetime
|
||||
rrsig_expiration = rrsig_inception + lifetime
|
||||
else:
|
||||
raise ValueError("expiration or lifetime must be specified")
|
||||
|
||||
# Derelativize now because we need a correct labels length for the
|
||||
# rrsig_template.
|
||||
if origin is not None:
|
||||
rrname = rrname.derelativize(origin)
|
||||
labels = len(rrname) - 1
|
||||
|
||||
# Adjust labels appropriately for wildcards.
|
||||
if rrname.is_wild():
|
||||
labels -= 1
|
||||
|
||||
rrsig_template = RRSIG(
|
||||
rdclass=rdclass,
|
||||
rdtype=dns.rdatatype.RRSIG,
|
||||
type_covered=rdtype,
|
||||
algorithm=dnskey.algorithm,
|
||||
labels=len(rrname) - 1,
|
||||
labels=labels,
|
||||
original_ttl=original_ttl,
|
||||
expiration=rrsig_expiration,
|
||||
inception=rrsig_inception,
|
||||
|
@ -753,64 +577,19 @@ def _sign(
|
|||
signature=b"",
|
||||
)
|
||||
|
||||
data = dns.dnssec._make_rrsig_signature_data(rrset, rrsig_template)
|
||||
chosen_hash = _make_hash(rrsig_template.algorithm)
|
||||
signature = None
|
||||
data = dns.dnssec._make_rrsig_signature_data(rrset, rrsig_template, origin)
|
||||
|
||||
if isinstance(private_key, rsa.RSAPrivateKey):
|
||||
if not _is_rsa(dnskey.algorithm):
|
||||
raise ValueError("Invalid DNSKEY algorithm for RSA key")
|
||||
signature = private_key.sign(data, padding.PKCS1v15(), chosen_hash)
|
||||
if verify:
|
||||
private_key.public_key().verify(
|
||||
signature, data, padding.PKCS1v15(), chosen_hash
|
||||
)
|
||||
elif isinstance(private_key, dsa.DSAPrivateKey):
|
||||
if not _is_dsa(dnskey.algorithm):
|
||||
raise ValueError("Invalid DNSKEY algorithm for DSA key")
|
||||
public_dsa_key = private_key.public_key()
|
||||
if public_dsa_key.key_size > 1024:
|
||||
raise ValueError("DSA key size overflow")
|
||||
der_signature = private_key.sign(data, chosen_hash)
|
||||
if verify:
|
||||
public_dsa_key.verify(der_signature, data, chosen_hash)
|
||||
dsa_r, dsa_s = utils.decode_dss_signature(der_signature)
|
||||
dsa_t = (public_dsa_key.key_size // 8 - 64) // 8
|
||||
octets = 20
|
||||
signature = (
|
||||
struct.pack("!B", dsa_t)
|
||||
+ int.to_bytes(dsa_r, length=octets, byteorder="big")
|
||||
+ int.to_bytes(dsa_s, length=octets, byteorder="big")
|
||||
)
|
||||
elif isinstance(private_key, ec.EllipticCurvePrivateKey):
|
||||
if not _is_ecdsa(dnskey.algorithm):
|
||||
raise ValueError("Invalid DNSKEY algorithm for EC key")
|
||||
der_signature = private_key.sign(data, ec.ECDSA(chosen_hash))
|
||||
if verify:
|
||||
private_key.public_key().verify(der_signature, data, ec.ECDSA(chosen_hash))
|
||||
if dnskey.algorithm == Algorithm.ECDSAP256SHA256:
|
||||
octets = 32
|
||||
else:
|
||||
octets = 48
|
||||
dsa_r, dsa_s = utils.decode_dss_signature(der_signature)
|
||||
signature = int.to_bytes(dsa_r, length=octets, byteorder="big") + int.to_bytes(
|
||||
dsa_s, length=octets, byteorder="big"
|
||||
)
|
||||
elif isinstance(private_key, ed25519.Ed25519PrivateKey):
|
||||
if dnskey.algorithm != Algorithm.ED25519:
|
||||
raise ValueError("Invalid DNSKEY algorithm for ED25519 key")
|
||||
signature = private_key.sign(data)
|
||||
if verify:
|
||||
private_key.public_key().verify(signature, data)
|
||||
elif isinstance(private_key, ed448.Ed448PrivateKey):
|
||||
if dnskey.algorithm != Algorithm.ED448:
|
||||
raise ValueError("Invalid DNSKEY algorithm for ED448 key")
|
||||
signature = private_key.sign(data)
|
||||
if verify:
|
||||
private_key.public_key().verify(signature, data)
|
||||
if isinstance(private_key, GenericPrivateKey):
|
||||
signing_key = private_key
|
||||
else:
|
||||
try:
|
||||
private_cls = get_algorithm_cls_from_dnskey(dnskey)
|
||||
signing_key = private_cls(key=private_key)
|
||||
except UnsupportedAlgorithm:
|
||||
raise TypeError("Unsupported key algorithm")
|
||||
|
||||
signature = signing_key.sign(data, verify)
|
||||
|
||||
return cast(RRSIG, rrsig_template.replace(signature=signature))
|
||||
|
||||
|
||||
|
@ -858,9 +637,12 @@ def _make_rrsig_signature_data(
|
|||
raise ValidationFailure("relative RR name without an origin specified")
|
||||
rrname = rrname.derelativize(origin)
|
||||
|
||||
if len(rrname) - 1 < rrsig.labels:
|
||||
name_len = len(rrname)
|
||||
if rrname.is_wild() and rrsig.labels != name_len - 2:
|
||||
raise ValidationFailure("wild owner name has wrong label length")
|
||||
if name_len - 1 < rrsig.labels:
|
||||
raise ValidationFailure("owner name longer than RRSIG labels")
|
||||
elif rrsig.labels < len(rrname) - 1:
|
||||
elif rrsig.labels < name_len - 1:
|
||||
suffix = rrname.split(rrsig.labels + 1)[1]
|
||||
rrname = dns.name.from_text("*", suffix)
|
||||
rrnamebuf = rrname.to_digestable()
|
||||
|
@ -884,9 +666,8 @@ def _make_dnskey(
|
|||
) -> DNSKEY:
|
||||
"""Convert a public key to DNSKEY Rdata
|
||||
|
||||
*public_key*, the public key to convert, a
|
||||
``cryptography.hazmat.primitives.asymmetric`` public key class applicable
|
||||
for DNSSEC.
|
||||
*public_key*, a ``PublicKey`` (``GenericPublicKey`` or
|
||||
``cryptography.hazmat.primitives.asymmetric``) to convert.
|
||||
|
||||
*algorithm*, a ``str`` or ``int`` specifying the DNSKEY algorithm.
|
||||
|
||||
|
@ -902,72 +683,13 @@ def _make_dnskey(
|
|||
Return DNSKEY ``Rdata``.
|
||||
"""
|
||||
|
||||
def encode_rsa_public_key(public_key: "rsa.RSAPublicKey") -> bytes:
|
||||
"""Encode a public key per RFC 3110, section 2."""
|
||||
pn = public_key.public_numbers()
|
||||
_exp_len = math.ceil(int.bit_length(pn.e) / 8)
|
||||
exp = int.to_bytes(pn.e, length=_exp_len, byteorder="big")
|
||||
if _exp_len > 255:
|
||||
exp_header = b"\0" + struct.pack("!H", _exp_len)
|
||||
algorithm = Algorithm.make(algorithm)
|
||||
|
||||
if isinstance(public_key, GenericPublicKey):
|
||||
return public_key.to_dnskey(flags=flags, protocol=protocol)
|
||||
else:
|
||||
exp_header = struct.pack("!B", _exp_len)
|
||||
if pn.n.bit_length() < 512 or pn.n.bit_length() > 4096:
|
||||
raise ValueError("unsupported RSA key length")
|
||||
return exp_header + exp + pn.n.to_bytes((pn.n.bit_length() + 7) // 8, "big")
|
||||
|
||||
def encode_dsa_public_key(public_key: "dsa.DSAPublicKey") -> bytes:
|
||||
"""Encode a public key per RFC 2536, section 2."""
|
||||
pn = public_key.public_numbers()
|
||||
dsa_t = (public_key.key_size // 8 - 64) // 8
|
||||
if dsa_t > 8:
|
||||
raise ValueError("unsupported DSA key size")
|
||||
octets = 64 + dsa_t * 8
|
||||
res = struct.pack("!B", dsa_t)
|
||||
res += pn.parameter_numbers.q.to_bytes(20, "big")
|
||||
res += pn.parameter_numbers.p.to_bytes(octets, "big")
|
||||
res += pn.parameter_numbers.g.to_bytes(octets, "big")
|
||||
res += pn.y.to_bytes(octets, "big")
|
||||
return res
|
||||
|
||||
def encode_ecdsa_public_key(public_key: "ec.EllipticCurvePublicKey") -> bytes:
|
||||
"""Encode a public key per RFC 6605, section 4."""
|
||||
pn = public_key.public_numbers()
|
||||
if isinstance(public_key.curve, ec.SECP256R1):
|
||||
return pn.x.to_bytes(32, "big") + pn.y.to_bytes(32, "big")
|
||||
elif isinstance(public_key.curve, ec.SECP384R1):
|
||||
return pn.x.to_bytes(48, "big") + pn.y.to_bytes(48, "big")
|
||||
else:
|
||||
raise ValueError("unsupported ECDSA curve")
|
||||
|
||||
the_algorithm = Algorithm.make(algorithm)
|
||||
|
||||
_ensure_algorithm_key_combination(the_algorithm, public_key)
|
||||
|
||||
if isinstance(public_key, rsa.RSAPublicKey):
|
||||
key_bytes = encode_rsa_public_key(public_key)
|
||||
elif isinstance(public_key, dsa.DSAPublicKey):
|
||||
key_bytes = encode_dsa_public_key(public_key)
|
||||
elif isinstance(public_key, ec.EllipticCurvePublicKey):
|
||||
key_bytes = encode_ecdsa_public_key(public_key)
|
||||
elif isinstance(public_key, ed25519.Ed25519PublicKey):
|
||||
key_bytes = public_key.public_bytes(
|
||||
encoding=serialization.Encoding.Raw, format=serialization.PublicFormat.Raw
|
||||
)
|
||||
elif isinstance(public_key, ed448.Ed448PublicKey):
|
||||
key_bytes = public_key.public_bytes(
|
||||
encoding=serialization.Encoding.Raw, format=serialization.PublicFormat.Raw
|
||||
)
|
||||
else:
|
||||
raise TypeError("unsupported key algorithm")
|
||||
|
||||
return DNSKEY(
|
||||
rdclass=dns.rdataclass.IN,
|
||||
rdtype=dns.rdatatype.DNSKEY,
|
||||
flags=flags,
|
||||
protocol=protocol,
|
||||
algorithm=the_algorithm,
|
||||
key=key_bytes,
|
||||
)
|
||||
public_cls = get_algorithm_cls(algorithm).public_cls
|
||||
return public_cls(key=public_key).to_dnskey(flags=flags, protocol=protocol)
|
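A minimal usage sketch for the key-to-DNSKEY conversion above, assuming the public ``dns.dnssec.make_dnskey`` wrapper and the ``cryptography`` package are available; the key and flag choices are illustrative only:

# Sketch: generate an Ed25519 key pair and express the public half as DNSKEY
# rdata (assumes dns.dnssec.make_dnskey wraps the helper shown above).
from cryptography.hazmat.primitives.asymmetric import ed25519

import dns.dnssec
from dns.dnssectypes import Algorithm
from dns.rdtypes.dnskeybase import Flag

private_key = ed25519.Ed25519PrivateKey.generate()
dnskey = dns.dnssec.make_dnskey(
    public_key=private_key.public_key(),
    algorithm=Algorithm.ED25519,
    flags=Flag.ZONE | Flag.SEP,
)
print(dnskey.to_text())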
||||
|
||||
|
||||
def _make_cdnskey(
|
||||
|
@ -1216,23 +938,252 @@ def dnskey_rdataset_to_cdnskey_rdataset(
|
|||
return dns.rdataset.from_rdata_list(rdataset.ttl, res)
|
||||
|
||||
|
||||
def default_rrset_signer(
|
||||
txn: dns.transaction.Transaction,
|
||||
rrset: dns.rrset.RRset,
|
||||
signer: dns.name.Name,
|
||||
ksks: List[Tuple[PrivateKey, DNSKEY]],
|
||||
zsks: List[Tuple[PrivateKey, DNSKEY]],
|
||||
inception: Optional[Union[datetime, str, int, float]] = None,
|
||||
expiration: Optional[Union[datetime, str, int, float]] = None,
|
||||
lifetime: Optional[int] = None,
|
||||
policy: Optional[Policy] = None,
|
||||
origin: Optional[dns.name.Name] = None,
|
||||
) -> None:
|
||||
"""Default RRset signer"""
|
||||
|
||||
if rrset.rdtype in set(
|
||||
[
|
||||
dns.rdatatype.RdataType.DNSKEY,
|
||||
dns.rdatatype.RdataType.CDS,
|
||||
dns.rdatatype.RdataType.CDNSKEY,
|
||||
]
|
||||
):
|
||||
keys = ksks
|
||||
else:
|
||||
keys = zsks
|
||||
|
||||
for private_key, dnskey in keys:
|
||||
rrsig = dns.dnssec.sign(
|
||||
rrset=rrset,
|
||||
private_key=private_key,
|
||||
dnskey=dnskey,
|
||||
inception=inception,
|
||||
expiration=expiration,
|
||||
lifetime=lifetime,
|
||||
signer=signer,
|
||||
policy=policy,
|
||||
origin=origin,
|
||||
)
|
||||
txn.add(rrset.name, rrset.ttl, rrsig)
|
||||
|
||||
|
||||
def sign_zone(
|
||||
zone: dns.zone.Zone,
|
||||
txn: Optional[dns.transaction.Transaction] = None,
|
||||
keys: Optional[List[Tuple[PrivateKey, DNSKEY]]] = None,
|
||||
add_dnskey: bool = True,
|
||||
dnskey_ttl: Optional[int] = None,
|
||||
inception: Optional[Union[datetime, str, int, float]] = None,
|
||||
expiration: Optional[Union[datetime, str, int, float]] = None,
|
||||
lifetime: Optional[int] = None,
|
||||
nsec3: Optional[NSEC3PARAM] = None,
|
||||
rrset_signer: Optional[RRsetSigner] = None,
|
||||
policy: Optional[Policy] = None,
|
||||
) -> None:
|
||||
"""Sign zone.
|
||||
|
||||
*zone*, a ``dns.zone.Zone``, the zone to sign.
|
||||
|
||||
*txn*, a ``dns.transaction.Transaction``, an optional transaction to use for
|
||||
signing.
|
||||
|
||||
*keys*, a list of (``PrivateKey``, ``DNSKEY``) tuples, to use for signing. KSK/ZSK
|
||||
roles are assigned automatically if the SEP flag is used, otherwise all RRsets are
|
||||
signed by all keys.
|
||||
|
||||
*add_dnskey*, a ``bool``. If ``True``, the default, all specified DNSKEYs are
|
||||
automatically added to the zone on signing.
|
||||
|
||||
*dnskey_ttl*, an ``int``, specifies the TTL for DNSKEY RRs. If not specified, the TTL
|
||||
of the existing DNSKEY RRset is used, or the TTL of the SOA RRset.
|
||||
|
||||
*inception*, a ``datetime``, ``str``, ``int``, ``float`` or ``None``, the signature
|
||||
inception time. If ``None``, the current time is used. If a ``str``, the format is
|
||||
"YYYYMMDDHHMMSS" or alternatively the number of seconds since the UNIX epoch in text
|
||||
form; this is the same as the RRSIG rdata's text form. Values of type `int` or `float`
|
||||
are interpreted as seconds since the UNIX epoch.
|
||||
|
||||
*expiration*, a ``datetime``, ``str``, ``int``, ``float`` or ``None``, the signature
|
||||
expiration time. If ``None``, the expiration time will be the inception time plus
|
||||
the value of the *lifetime* parameter. See the description of *inception* above for
|
||||
how the various parameter types are interpreted.
|
||||
|
||||
*lifetime*, an ``int`` or ``None``, the signature lifetime in seconds. This
|
||||
parameter is only meaningful if *expiration* is ``None``.
|
||||
|
||||
*nsec3*, a ``NSEC3PARAM`` Rdata, configures signing using NSEC3. Not yet
|
||||
implemented.
|
||||
|
||||
*rrset_signer*, a ``Callable``, an optional function for signing RRsets. The
|
||||
function requires two arguments: transaction and RRset. If not specified,
|
||||
``dns.dnssec.default_rrset_signer`` will be used.
|
||||
|
||||
Returns ``None``.
|
||||
"""
|
||||
|
||||
ksks = []
|
||||
zsks = []
|
||||
|
||||
# if we have both KSKs and ZSKs, split by SEP flag. if not, sign all
|
||||
# records with all keys
|
||||
if keys:
|
||||
for key in keys:
|
||||
if key[1].flags & Flag.SEP:
|
||||
ksks.append(key)
|
||||
else:
|
||||
zsks.append(key)
|
||||
if not ksks:
|
||||
ksks = keys
|
||||
if not zsks:
|
||||
zsks = keys
|
||||
else:
|
||||
keys = []
|
||||
|
||||
if txn:
|
||||
cm: contextlib.AbstractContextManager = contextlib.nullcontext(txn)
|
||||
else:
|
||||
cm = zone.writer()
|
||||
|
||||
with cm as _txn:
|
||||
if add_dnskey:
|
||||
if dnskey_ttl is None:
|
||||
dnskey = _txn.get(zone.origin, dns.rdatatype.DNSKEY)
|
||||
if dnskey:
|
||||
dnskey_ttl = dnskey.ttl
|
||||
else:
|
||||
soa = _txn.get(zone.origin, dns.rdatatype.SOA)
|
||||
dnskey_ttl = soa.ttl
|
||||
for _, dnskey in keys:
|
||||
_txn.add(zone.origin, dnskey_ttl, dnskey)
|
||||
|
||||
if nsec3:
|
||||
raise NotImplementedError("Signing with NSEC3 not yet implemented")
|
||||
else:
|
||||
_rrset_signer = rrset_signer or functools.partial(
|
||||
default_rrset_signer,
|
||||
signer=zone.origin,
|
||||
ksks=ksks,
|
||||
zsks=zsks,
|
||||
inception=inception,
|
||||
expiration=expiration,
|
||||
lifetime=lifetime,
|
||||
policy=policy,
|
||||
origin=zone.origin,
|
||||
)
|
||||
return _sign_zone_nsec(zone, _txn, _rrset_signer)
|
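A hedged end-to-end sketch of ``sign_zone`` as documented above; the zone contents, key material, and the 30-day lifetime are illustrative placeholders:

# Sketch: sign a tiny zone in place with one Ed25519 key acting as both KSK and ZSK.
import dns.dnssec
import dns.zone
from cryptography.hazmat.primitives.asymmetric import ed25519
from dns.dnssectypes import Algorithm
from dns.rdtypes.dnskeybase import Flag

zone = dns.zone.from_text(
    """
@ 3600 IN SOA ns1 hostmaster 1 7200 3600 1209600 3600
@ 3600 IN NS ns1
ns1 3600 IN A 192.0.2.1
""",
    origin="example.",
    relativize=True,
)

private_key = ed25519.Ed25519PrivateKey.generate()
dnskey = dns.dnssec.make_dnskey(
    private_key.public_key(), Algorithm.ED25519, flags=Flag.ZONE | Flag.SEP
)
dns.dnssec.sign_zone(zone, keys=[(private_key, dnskey)], lifetime=86400 * 30)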
||||
|
||||
|
||||
def _sign_zone_nsec(
|
||||
zone: dns.zone.Zone,
|
||||
txn: dns.transaction.Transaction,
|
||||
rrset_signer: Optional[RRsetSigner] = None,
|
||||
) -> None:
|
||||
"""NSEC zone signer"""
|
||||
|
||||
def _txn_add_nsec(
|
||||
txn: dns.transaction.Transaction,
|
||||
name: dns.name.Name,
|
||||
next_secure: Optional[dns.name.Name],
|
||||
rdclass: dns.rdataclass.RdataClass,
|
||||
ttl: int,
|
||||
rrset_signer: Optional[RRsetSigner] = None,
|
||||
) -> None:
|
||||
"""NSEC zone signer helper"""
|
||||
mandatory_types = set(
|
||||
[dns.rdatatype.RdataType.RRSIG, dns.rdatatype.RdataType.NSEC]
|
||||
)
|
||||
node = txn.get_node(name)
|
||||
if node and next_secure:
|
||||
types = (
|
||||
set([rdataset.rdtype for rdataset in node.rdatasets]) | mandatory_types
|
||||
)
|
||||
windows = Bitmap.from_rdtypes(list(types))
|
||||
rrset = dns.rrset.from_rdata(
|
||||
name,
|
||||
ttl,
|
||||
NSEC(
|
||||
rdclass=rdclass,
|
||||
rdtype=dns.rdatatype.RdataType.NSEC,
|
||||
next=next_secure,
|
||||
windows=windows,
|
||||
),
|
||||
)
|
||||
txn.add(rrset)
|
||||
if rrset_signer:
|
||||
rrset_signer(txn, rrset)
|
||||
|
||||
rrsig_ttl = zone.get_soa().minimum
|
||||
delegation = None
|
||||
last_secure = None
|
||||
|
||||
for name in sorted(txn.iterate_names()):
|
||||
if delegation and name.is_subdomain(delegation):
|
||||
# names below delegations are not secure
|
||||
continue
|
||||
elif txn.get(name, dns.rdatatype.NS) and name != zone.origin:
|
||||
# inside delegation
|
||||
delegation = name
|
||||
else:
|
||||
# outside delegation
|
||||
delegation = None
|
||||
|
||||
if rrset_signer:
|
||||
node = txn.get_node(name)
|
||||
if node:
|
||||
for rdataset in node.rdatasets:
|
||||
if rdataset.rdtype == dns.rdatatype.RRSIG:
|
||||
# do not sign RRSIGs
|
||||
continue
|
||||
elif delegation and rdataset.rdtype != dns.rdatatype.DS:
|
||||
# do not sign delegations except DS records
|
||||
continue
|
||||
else:
|
||||
rrset = dns.rrset.from_rdata(name, rdataset.ttl, *rdataset)
|
||||
rrset_signer(txn, rrset)
|
||||
|
||||
# We need "is not None" as the empty name is False because its length is 0.
|
||||
if last_secure is not None:
|
||||
_txn_add_nsec(txn, last_secure, name, zone.rdclass, rrsig_ttl, rrset_signer)
|
||||
last_secure = name
|
||||
|
||||
if last_secure:
|
||||
_txn_add_nsec(
|
||||
txn, last_secure, zone.origin, zone.rdclass, rrsig_ttl, rrset_signer
|
||||
)
|
||||
|
||||
|
||||
def _need_pyca(*args, **kwargs):
|
||||
raise ImportError(
|
||||
"DNSSEC validation requires " + "python cryptography"
|
||||
"DNSSEC validation requires python cryptography"
|
||||
) # pragma: no cover
|
||||
|
||||
|
||||
try:
|
||||
from cryptography.exceptions import InvalidSignature
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
from cryptography.hazmat.primitives import hashes, serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import padding
|
||||
from cryptography.hazmat.primitives.asymmetric import utils
|
||||
from cryptography.hazmat.primitives.asymmetric import dsa
|
||||
from cryptography.hazmat.primitives.asymmetric import ec
|
||||
from cryptography.hazmat.primitives.asymmetric import ed25519
|
||||
from cryptography.hazmat.primitives.asymmetric import ed448
|
||||
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||
from cryptography.hazmat.primitives.asymmetric import dsa # pylint: disable=W0611
|
||||
from cryptography.hazmat.primitives.asymmetric import ec # pylint: disable=W0611
|
||||
from cryptography.hazmat.primitives.asymmetric import ed448 # pylint: disable=W0611
|
||||
from cryptography.hazmat.primitives.asymmetric import rsa # pylint: disable=W0611
|
||||
from cryptography.hazmat.primitives.asymmetric import ( # pylint: disable=W0611
|
||||
ed25519,
|
||||
)
|
||||
|
||||
from dns.dnssecalgs import ( # pylint: disable=C0412
|
||||
get_algorithm_cls,
|
||||
get_algorithm_cls_from_dnskey,
|
||||
)
|
||||
from dns.dnssecalgs.base import GenericPrivateKey, GenericPublicKey
|
||||
except ImportError: # pragma: no cover
|
||||
validate = _need_pyca
|
||||
validate_rrsig = _need_pyca
|
||||
|
|
121
lib/dns/dnssecalgs/__init__.py
Normal file
|
@ -0,0 +1,121 @@
|
|||
from typing import Dict, Optional, Tuple, Type, Union
|
||||
|
||||
import dns.name
|
||||
|
||||
try:
|
||||
from dns.dnssecalgs.base import GenericPrivateKey
|
||||
from dns.dnssecalgs.dsa import PrivateDSA, PrivateDSANSEC3SHA1
|
||||
from dns.dnssecalgs.ecdsa import PrivateECDSAP256SHA256, PrivateECDSAP384SHA384
|
||||
from dns.dnssecalgs.eddsa import PrivateED448, PrivateED25519
|
||||
from dns.dnssecalgs.rsa import (
|
||||
PrivateRSAMD5,
|
||||
PrivateRSASHA1,
|
||||
PrivateRSASHA1NSEC3SHA1,
|
||||
PrivateRSASHA256,
|
||||
PrivateRSASHA512,
|
||||
)
|
||||
|
||||
_have_cryptography = True
|
||||
except ImportError:
|
||||
_have_cryptography = False
|
||||
|
||||
from dns.dnssectypes import Algorithm
|
||||
from dns.exception import UnsupportedAlgorithm
|
||||
from dns.rdtypes.ANY.DNSKEY import DNSKEY
|
||||
|
||||
AlgorithmPrefix = Optional[Union[bytes, dns.name.Name]]
|
||||
|
||||
algorithms: Dict[Tuple[Algorithm, AlgorithmPrefix], Type[GenericPrivateKey]] = {}
|
||||
if _have_cryptography:
|
||||
algorithms.update(
|
||||
{
|
||||
(Algorithm.RSAMD5, None): PrivateRSAMD5,
|
||||
(Algorithm.DSA, None): PrivateDSA,
|
||||
(Algorithm.RSASHA1, None): PrivateRSASHA1,
|
||||
(Algorithm.DSANSEC3SHA1, None): PrivateDSANSEC3SHA1,
|
||||
(Algorithm.RSASHA1NSEC3SHA1, None): PrivateRSASHA1NSEC3SHA1,
|
||||
(Algorithm.RSASHA256, None): PrivateRSASHA256,
|
||||
(Algorithm.RSASHA512, None): PrivateRSASHA512,
|
||||
(Algorithm.ECDSAP256SHA256, None): PrivateECDSAP256SHA256,
|
||||
(Algorithm.ECDSAP384SHA384, None): PrivateECDSAP384SHA384,
|
||||
(Algorithm.ED25519, None): PrivateED25519,
|
||||
(Algorithm.ED448, None): PrivateED448,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def get_algorithm_cls(
|
||||
algorithm: Union[int, str], prefix: AlgorithmPrefix = None
|
||||
) -> Type[GenericPrivateKey]:
|
||||
"""Get Private Key class from Algorithm.
|
||||
|
||||
*algorithm*, a ``str`` or ``int`` specifying the DNSKEY algorithm.
|
||||
|
||||
Raises ``UnsupportedAlgorithm`` if the algorithm is unknown.
|
||||
|
||||
Returns a ``dns.dnssecalgs.GenericPrivateKey``
|
||||
"""
|
||||
algorithm = Algorithm.make(algorithm)
|
||||
cls = algorithms.get((algorithm, prefix))
|
||||
if cls:
|
||||
return cls
|
||||
raise UnsupportedAlgorithm(
|
||||
'algorithm "%s" not supported by dnspython' % Algorithm.to_text(algorithm)
|
||||
)
|
||||
|
||||
|
||||
def get_algorithm_cls_from_dnskey(dnskey: DNSKEY) -> Type[GenericPrivateKey]:
|
||||
"""Get Private Key class from DNSKEY.
|
||||
|
||||
*dnskey*, a ``DNSKEY`` to get Algorithm class for.
|
||||
|
||||
Raises ``UnsupportedAlgorithm`` if the algorithm is unknown.
|
||||
|
||||
Returns a ``dns.dnssecalgs.GenericPrivateKey``
|
||||
"""
|
||||
prefix: AlgorithmPrefix = None
|
||||
if dnskey.algorithm == Algorithm.PRIVATEDNS:
|
||||
prefix, _ = dns.name.from_wire(dnskey.key, 0)
|
||||
elif dnskey.algorithm == Algorithm.PRIVATEOID:
|
||||
length = int(dnskey.key[0])
|
||||
prefix = dnskey.key[0 : length + 1]
|
||||
return get_algorithm_cls(dnskey.algorithm, prefix)
|
||||
|
||||
|
||||
def register_algorithm_cls(
|
||||
algorithm: Union[int, str],
|
||||
algorithm_cls: Type[GenericPrivateKey],
|
||||
name: Optional[Union[dns.name.Name, str]] = None,
|
||||
oid: Optional[bytes] = None,
|
||||
) -> None:
|
||||
"""Register Algorithm Private Key class.
|
||||
|
||||
*algorithm*, a ``str`` or ``int`` specifying the DNSKEY algorithm.
|
||||
|
||||
*algorithm_cls*: A `GenericPrivateKey` class.
|
||||
|
||||
*name*, an optional ``dns.name.Name`` or ``str``, for PRIVATEDNS algorithms.
|
||||
|
||||
*oid*: an optional BER-encoded `bytes` for PRIVATEOID algorithms.
|
||||
|
||||
Raises ``ValueError`` if a name or oid is specified incorrectly.
|
||||
"""
|
||||
if not issubclass(algorithm_cls, GenericPrivateKey):
|
||||
raise TypeError("Invalid algorithm class")
|
||||
algorithm = Algorithm.make(algorithm)
|
||||
prefix: AlgorithmPrefix = None
|
||||
if algorithm == Algorithm.PRIVATEDNS:
|
||||
if name is None:
|
||||
raise ValueError("Name required for PRIVATEDNS algorithms")
|
||||
if isinstance(name, str):
|
||||
name = dns.name.from_text(name)
|
||||
prefix = name
|
||||
elif algorithm == Algorithm.PRIVATEOID:
|
||||
if oid is None:
|
||||
raise ValueError("OID required for PRIVATEOID algorithms")
|
||||
prefix = bytes([len(oid)]) + oid
|
||||
elif name:
|
||||
raise ValueError("Name only supported for PRIVATEDNS algorithm")
|
||||
elif oid:
|
||||
raise ValueError("OID only supported for PRIVATEOID algorithm")
|
||||
algorithms[(algorithm, prefix)] = algorithm_cls
|
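A short sketch of the registry API defined in this new file; the PRIVATEDNS algorithm name and the subclass are purely illustrative:

# Sketch: register a hypothetical PRIVATEDNS implementation, then look it up again.
import dns.name
from dns.dnssecalgs import get_algorithm_cls, register_algorithm_cls
from dns.dnssecalgs.eddsa import PrivateED25519
from dns.dnssectypes import Algorithm


class MyPrivateED25519(PrivateED25519):
    """Illustrative private-algorithm wrapper; a real one would handle its own prefix."""


register_algorithm_cls(
    Algorithm.PRIVATEDNS, MyPrivateED25519, name="ed25519.alg.example."
)
assert (
    get_algorithm_cls(Algorithm.PRIVATEDNS, dns.name.from_text("ed25519.alg.example."))
    is MyPrivateED25519
)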
84
lib/dns/dnssecalgs/base.py
Normal file
|
@ -0,0 +1,84 @@
|
|||
from abc import ABC, abstractmethod # pylint: disable=no-name-in-module
|
||||
from typing import Any, Optional, Type
|
||||
|
||||
import dns.rdataclass
|
||||
import dns.rdatatype
|
||||
from dns.dnssectypes import Algorithm
|
||||
from dns.exception import AlgorithmKeyMismatch
|
||||
from dns.rdtypes.ANY.DNSKEY import DNSKEY
|
||||
from dns.rdtypes.dnskeybase import Flag
|
||||
|
||||
|
||||
class GenericPublicKey(ABC):
|
||||
algorithm: Algorithm
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, key: Any) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def verify(self, signature: bytes, data: bytes) -> None:
|
||||
"""Verify signed DNSSEC data"""
|
||||
|
||||
@abstractmethod
|
||||
def encode_key_bytes(self) -> bytes:
|
||||
"""Encode key as bytes for DNSKEY"""
|
||||
|
||||
@classmethod
|
||||
def _ensure_algorithm_key_combination(cls, key: DNSKEY) -> None:
|
||||
if key.algorithm != cls.algorithm:
|
||||
raise AlgorithmKeyMismatch
|
||||
|
||||
def to_dnskey(self, flags: int = Flag.ZONE, protocol: int = 3) -> DNSKEY:
|
||||
"""Return public key as DNSKEY"""
|
||||
return DNSKEY(
|
||||
rdclass=dns.rdataclass.IN,
|
||||
rdtype=dns.rdatatype.DNSKEY,
|
||||
flags=flags,
|
||||
protocol=protocol,
|
||||
algorithm=self.algorithm,
|
||||
key=self.encode_key_bytes(),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def from_dnskey(cls, key: DNSKEY) -> "GenericPublicKey":
|
||||
"""Create public key from DNSKEY"""
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def from_pem(cls, public_pem: bytes) -> "GenericPublicKey":
|
||||
"""Create public key from PEM-encoded SubjectPublicKeyInfo as specified
|
||||
in RFC 5280"""
|
||||
|
||||
@abstractmethod
|
||||
def to_pem(self) -> bytes:
|
||||
"""Return public-key as PEM-encoded SubjectPublicKeyInfo as specified
|
||||
in RFC 5280"""
|
||||
|
||||
|
||||
class GenericPrivateKey(ABC):
|
||||
public_cls: Type[GenericPublicKey]
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, key: Any) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def sign(self, data: bytes, verify: bool = False) -> bytes:
|
||||
"""Sign DNSSEC data"""
|
||||
|
||||
@abstractmethod
|
||||
def public_key(self) -> "GenericPublicKey":
|
||||
"""Return public key instance"""
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def from_pem(
|
||||
cls, private_pem: bytes, password: Optional[bytes] = None
|
||||
) -> "GenericPrivateKey":
|
||||
"""Create private key from PEM-encoded PKCS#8"""
|
||||
|
||||
@abstractmethod
|
||||
def to_pem(self, password: Optional[bytes] = None) -> bytes:
|
||||
"""Return private key as PEM-encoded PKCS#8"""
|
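The abstract interface above supports a straightforward PEM round trip; a sketch using the concrete ED25519 implementation added later in this change (the signed payload is a placeholder):

# Sketch: PEM round trip and signing through the GenericPrivateKey interface.
from dns.dnssecalgs.eddsa import PrivateED25519

key = PrivateED25519.generate()
pem = key.to_pem()                       # unencrypted PKCS#8
restored = PrivateED25519.from_pem(pem)
dnskey = restored.public_key().to_dnskey()
signature = restored.sign(b"example data", verify=True)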
68
lib/dns/dnssecalgs/cryptography.py
Normal file
|
@ -0,0 +1,68 @@
|
|||
from typing import Any, Optional, Type
|
||||
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
|
||||
from dns.dnssecalgs.base import GenericPrivateKey, GenericPublicKey
|
||||
from dns.exception import AlgorithmKeyMismatch
|
||||
|
||||
|
||||
class CryptographyPublicKey(GenericPublicKey):
|
||||
key: Any = None
|
||||
key_cls: Any = None
|
||||
|
||||
def __init__(self, key: Any) -> None: # pylint: disable=super-init-not-called
|
||||
if self.key_cls is None:
|
||||
raise TypeError("Undefined private key class")
|
||||
if not isinstance( # pylint: disable=isinstance-second-argument-not-valid-type
|
||||
key, self.key_cls
|
||||
):
|
||||
raise AlgorithmKeyMismatch
|
||||
self.key = key
|
||||
|
||||
@classmethod
|
||||
def from_pem(cls, public_pem: bytes) -> "GenericPublicKey":
|
||||
key = serialization.load_pem_public_key(public_pem)
|
||||
return cls(key=key)
|
||||
|
||||
def to_pem(self) -> bytes:
|
||||
return self.key.public_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||
)
|
||||
|
||||
|
||||
class CryptographyPrivateKey(GenericPrivateKey):
|
||||
key: Any = None
|
||||
key_cls: Any = None
|
||||
public_cls: Type[CryptographyPublicKey]
|
||||
|
||||
def __init__(self, key: Any) -> None: # pylint: disable=super-init-not-called
|
||||
if self.key_cls is None:
|
||||
raise TypeError("Undefined private key class")
|
||||
if not isinstance( # pylint: disable=isinstance-second-argument-not-valid-type
|
||||
key, self.key_cls
|
||||
):
|
||||
raise AlgorithmKeyMismatch
|
||||
self.key = key
|
||||
|
||||
def public_key(self) -> "CryptographyPublicKey":
|
||||
return self.public_cls(key=self.key.public_key())
|
||||
|
||||
@classmethod
|
||||
def from_pem(
|
||||
cls, private_pem: bytes, password: Optional[bytes] = None
|
||||
) -> "GenericPrivateKey":
|
||||
key = serialization.load_pem_private_key(private_pem, password=password)
|
||||
return cls(key=key)
|
||||
|
||||
def to_pem(self, password: Optional[bytes] = None) -> bytes:
|
||||
encryption_algorithm: serialization.KeySerializationEncryption
|
||||
if password:
|
||||
encryption_algorithm = serialization.BestAvailableEncryption(password)
|
||||
else:
|
||||
encryption_algorithm = serialization.NoEncryption()
|
||||
return self.key.private_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PrivateFormat.PKCS8,
|
||||
encryption_algorithm=encryption_algorithm,
|
||||
)
|
101
lib/dns/dnssecalgs/dsa.py
Normal file
|
@ -0,0 +1,101 @@
|
|||
import struct
|
||||
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
from cryptography.hazmat.primitives import hashes
|
||||
from cryptography.hazmat.primitives.asymmetric import dsa, utils
|
||||
|
||||
from dns.dnssecalgs.cryptography import CryptographyPrivateKey, CryptographyPublicKey
|
||||
from dns.dnssectypes import Algorithm
|
||||
from dns.rdtypes.ANY.DNSKEY import DNSKEY
|
||||
|
||||
|
||||
class PublicDSA(CryptographyPublicKey):
|
||||
key: dsa.DSAPublicKey
|
||||
key_cls = dsa.DSAPublicKey
|
||||
algorithm = Algorithm.DSA
|
||||
chosen_hash = hashes.SHA1()
|
||||
|
||||
def verify(self, signature: bytes, data: bytes) -> None:
|
||||
sig_r = signature[1:21]
|
||||
sig_s = signature[21:]
|
||||
sig = utils.encode_dss_signature(
|
||||
int.from_bytes(sig_r, "big"), int.from_bytes(sig_s, "big")
|
||||
)
|
||||
self.key.verify(sig, data, self.chosen_hash)
|
||||
|
||||
def encode_key_bytes(self) -> bytes:
|
||||
"""Encode a public key per RFC 2536, section 2."""
|
||||
pn = self.key.public_numbers()
|
||||
dsa_t = (self.key.key_size // 8 - 64) // 8
|
||||
if dsa_t > 8:
|
||||
raise ValueError("unsupported DSA key size")
|
||||
octets = 64 + dsa_t * 8
|
||||
res = struct.pack("!B", dsa_t)
|
||||
res += pn.parameter_numbers.q.to_bytes(20, "big")
|
||||
res += pn.parameter_numbers.p.to_bytes(octets, "big")
|
||||
res += pn.parameter_numbers.g.to_bytes(octets, "big")
|
||||
res += pn.y.to_bytes(octets, "big")
|
||||
return res
|
||||
|
||||
@classmethod
|
||||
def from_dnskey(cls, key: DNSKEY) -> "PublicDSA":
|
||||
cls._ensure_algorithm_key_combination(key)
|
||||
keyptr = key.key
|
||||
(t,) = struct.unpack("!B", keyptr[0:1])
|
||||
keyptr = keyptr[1:]
|
||||
octets = 64 + t * 8
|
||||
dsa_q = keyptr[0:20]
|
||||
keyptr = keyptr[20:]
|
||||
dsa_p = keyptr[0:octets]
|
||||
keyptr = keyptr[octets:]
|
||||
dsa_g = keyptr[0:octets]
|
||||
keyptr = keyptr[octets:]
|
||||
dsa_y = keyptr[0:octets]
|
||||
return cls(
|
||||
key=dsa.DSAPublicNumbers( # type: ignore
|
||||
int.from_bytes(dsa_y, "big"),
|
||||
dsa.DSAParameterNumbers(
|
||||
int.from_bytes(dsa_p, "big"),
|
||||
int.from_bytes(dsa_q, "big"),
|
||||
int.from_bytes(dsa_g, "big"),
|
||||
),
|
||||
).public_key(default_backend()),
|
||||
)
|
||||
|
||||
|
||||
class PrivateDSA(CryptographyPrivateKey):
|
||||
key: dsa.DSAPrivateKey
|
||||
key_cls = dsa.DSAPrivateKey
|
||||
public_cls = PublicDSA
|
||||
|
||||
def sign(self, data: bytes, verify: bool = False) -> bytes:
|
||||
"""Sign using a private key per RFC 2536, section 3."""
|
||||
public_dsa_key = self.key.public_key()
|
||||
if public_dsa_key.key_size > 1024:
|
||||
raise ValueError("DSA key size overflow")
|
||||
der_signature = self.key.sign(data, self.public_cls.chosen_hash)
|
||||
dsa_r, dsa_s = utils.decode_dss_signature(der_signature)
|
||||
dsa_t = (public_dsa_key.key_size // 8 - 64) // 8
|
||||
octets = 20
|
||||
signature = (
|
||||
struct.pack("!B", dsa_t)
|
||||
+ int.to_bytes(dsa_r, length=octets, byteorder="big")
|
||||
+ int.to_bytes(dsa_s, length=octets, byteorder="big")
|
||||
)
|
||||
if verify:
|
||||
self.public_key().verify(signature, data)
|
||||
return signature
|
||||
|
||||
@classmethod
|
||||
def generate(cls, key_size: int) -> "PrivateDSA":
|
||||
return cls(
|
||||
key=dsa.generate_private_key(key_size=key_size),
|
||||
)
|
||||
|
||||
|
||||
class PublicDSANSEC3SHA1(PublicDSA):
|
||||
algorithm = Algorithm.DSANSEC3SHA1
|
||||
|
||||
|
||||
class PrivateDSANSEC3SHA1(PrivateDSA):
|
||||
public_cls = PublicDSANSEC3SHA1
|
89
lib/dns/dnssecalgs/ecdsa.py
Normal file
|
@ -0,0 +1,89 @@
|
|||
from cryptography.hazmat.backends import default_backend
|
||||
from cryptography.hazmat.primitives import hashes
|
||||
from cryptography.hazmat.primitives.asymmetric import ec, utils
|
||||
|
||||
from dns.dnssecalgs.cryptography import CryptographyPrivateKey, CryptographyPublicKey
|
||||
from dns.dnssectypes import Algorithm
|
||||
from dns.rdtypes.ANY.DNSKEY import DNSKEY
|
||||
|
||||
|
||||
class PublicECDSA(CryptographyPublicKey):
|
||||
key: ec.EllipticCurvePublicKey
|
||||
key_cls = ec.EllipticCurvePublicKey
|
||||
algorithm: Algorithm
|
||||
chosen_hash: hashes.HashAlgorithm
|
||||
curve: ec.EllipticCurve
|
||||
octets: int
|
||||
|
||||
def verify(self, signature: bytes, data: bytes) -> None:
|
||||
sig_r = signature[0 : self.octets]
|
||||
sig_s = signature[self.octets :]
|
||||
sig = utils.encode_dss_signature(
|
||||
int.from_bytes(sig_r, "big"), int.from_bytes(sig_s, "big")
|
||||
)
|
||||
self.key.verify(sig, data, ec.ECDSA(self.chosen_hash))
|
||||
|
||||
def encode_key_bytes(self) -> bytes:
|
||||
"""Encode a public key per RFC 6605, section 4."""
|
||||
pn = self.key.public_numbers()
|
||||
return pn.x.to_bytes(self.octets, "big") + pn.y.to_bytes(self.octets, "big")
|
||||
|
||||
@classmethod
|
||||
def from_dnskey(cls, key: DNSKEY) -> "PublicECDSA":
|
||||
cls._ensure_algorithm_key_combination(key)
|
||||
ecdsa_x = key.key[0 : cls.octets]
|
||||
ecdsa_y = key.key[cls.octets : cls.octets * 2]
|
||||
return cls(
|
||||
key=ec.EllipticCurvePublicNumbers(
|
||||
curve=cls.curve,
|
||||
x=int.from_bytes(ecdsa_x, "big"),
|
||||
y=int.from_bytes(ecdsa_y, "big"),
|
||||
).public_key(default_backend()),
|
||||
)
|
||||
|
||||
|
||||
class PrivateECDSA(CryptographyPrivateKey):
|
||||
key: ec.EllipticCurvePrivateKey
|
||||
key_cls = ec.EllipticCurvePrivateKey
|
||||
public_cls = PublicECDSA
|
||||
|
||||
def sign(self, data: bytes, verify: bool = False) -> bytes:
|
||||
"""Sign using a private key per RFC 6605, section 4."""
|
||||
der_signature = self.key.sign(data, ec.ECDSA(self.public_cls.chosen_hash))
|
||||
dsa_r, dsa_s = utils.decode_dss_signature(der_signature)
|
||||
signature = int.to_bytes(
|
||||
dsa_r, length=self.public_cls.octets, byteorder="big"
|
||||
) + int.to_bytes(dsa_s, length=self.public_cls.octets, byteorder="big")
|
||||
if verify:
|
||||
self.public_key().verify(signature, data)
|
||||
return signature
|
||||
|
||||
@classmethod
|
||||
def generate(cls) -> "PrivateECDSA":
|
||||
return cls(
|
||||
key=ec.generate_private_key(
|
||||
curve=cls.public_cls.curve, backend=default_backend()
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class PublicECDSAP256SHA256(PublicECDSA):
|
||||
algorithm = Algorithm.ECDSAP256SHA256
|
||||
chosen_hash = hashes.SHA256()
|
||||
curve = ec.SECP256R1()
|
||||
octets = 32
|
||||
|
||||
|
||||
class PrivateECDSAP256SHA256(PrivateECDSA):
|
||||
public_cls = PublicECDSAP256SHA256
|
||||
|
||||
|
||||
class PublicECDSAP384SHA384(PublicECDSA):
|
||||
algorithm = Algorithm.ECDSAP384SHA384
|
||||
chosen_hash = hashes.SHA384()
|
||||
curve = ec.SECP384R1()
|
||||
octets = 48
|
||||
|
||||
|
||||
class PrivateECDSAP384SHA384(PrivateECDSA):
|
||||
public_cls = PublicECDSAP384SHA384
|
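A small sketch of the DNSKEY round trip these ECDSA classes enable; the signed payload is a placeholder:

# Sketch: generate a P-256 key, export its DNSKEY, rebuild the public key from
# that DNSKEY, and verify a signature with it.
from dns.dnssecalgs.ecdsa import PrivateECDSAP256SHA256, PublicECDSAP256SHA256

key = PrivateECDSAP256SHA256.generate()
dnskey = key.public_key().to_dnskey()
signature = key.sign(b"placeholder data")
PublicECDSAP256SHA256.from_dnskey(dnskey).verify(signature, b"placeholder data")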
65
lib/dns/dnssecalgs/eddsa.py
Normal file
|
@ -0,0 +1,65 @@
|
|||
from typing import Type
|
||||
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import ed448, ed25519
|
||||
|
||||
from dns.dnssecalgs.cryptography import CryptographyPrivateKey, CryptographyPublicKey
|
||||
from dns.dnssectypes import Algorithm
|
||||
from dns.rdtypes.ANY.DNSKEY import DNSKEY
|
||||
|
||||
|
||||
class PublicEDDSA(CryptographyPublicKey):
|
||||
def verify(self, signature: bytes, data: bytes) -> None:
|
||||
self.key.verify(signature, data)
|
||||
|
||||
def encode_key_bytes(self) -> bytes:
|
||||
"""Encode a public key per RFC 8080, section 3."""
|
||||
return self.key.public_bytes(
|
||||
encoding=serialization.Encoding.Raw, format=serialization.PublicFormat.Raw
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_dnskey(cls, key: DNSKEY) -> "PublicEDDSA":
|
||||
cls._ensure_algorithm_key_combination(key)
|
||||
return cls(
|
||||
key=cls.key_cls.from_public_bytes(key.key),
|
||||
)
|
||||
|
||||
|
||||
class PrivateEDDSA(CryptographyPrivateKey):
|
||||
public_cls: Type[PublicEDDSA]
|
||||
|
||||
def sign(self, data: bytes, verify: bool = False) -> bytes:
|
||||
"""Sign using a private key per RFC 8080, section 4."""
|
||||
signature = self.key.sign(data)
|
||||
if verify:
|
||||
self.public_key().verify(signature, data)
|
||||
return signature
|
||||
|
||||
@classmethod
|
||||
def generate(cls) -> "PrivateEDDSA":
|
||||
return cls(key=cls.key_cls.generate())
|
||||
|
||||
|
||||
class PublicED25519(PublicEDDSA):
|
||||
key: ed25519.Ed25519PublicKey
|
||||
key_cls = ed25519.Ed25519PublicKey
|
||||
algorithm = Algorithm.ED25519
|
||||
|
||||
|
||||
class PrivateED25519(PrivateEDDSA):
|
||||
key: ed25519.Ed25519PrivateKey
|
||||
key_cls = ed25519.Ed25519PrivateKey
|
||||
public_cls = PublicED25519
|
||||
|
||||
|
||||
class PublicED448(PublicEDDSA):
|
||||
key: ed448.Ed448PublicKey
|
||||
key_cls = ed448.Ed448PublicKey
|
||||
algorithm = Algorithm.ED448
|
||||
|
||||
|
||||
class PrivateED448(PrivateEDDSA):
|
||||
key: ed448.Ed448PrivateKey
|
||||
key_cls = ed448.Ed448PrivateKey
|
||||
public_cls = PublicED448
|
119
lib/dns/dnssecalgs/rsa.py
Normal file
|
@ -0,0 +1,119 @@
|
|||
import math
|
||||
import struct
|
||||
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
from cryptography.hazmat.primitives import hashes
|
||||
from cryptography.hazmat.primitives.asymmetric import padding, rsa
|
||||
|
||||
from dns.dnssecalgs.cryptography import CryptographyPrivateKey, CryptographyPublicKey
|
||||
from dns.dnssectypes import Algorithm
|
||||
from dns.rdtypes.ANY.DNSKEY import DNSKEY
|
||||
|
||||
|
||||
class PublicRSA(CryptographyPublicKey):
|
||||
key: rsa.RSAPublicKey
|
||||
key_cls = rsa.RSAPublicKey
|
||||
algorithm: Algorithm
|
||||
chosen_hash: hashes.HashAlgorithm
|
||||
|
||||
def verify(self, signature: bytes, data: bytes) -> None:
|
||||
self.key.verify(signature, data, padding.PKCS1v15(), self.chosen_hash)
|
||||
|
||||
def encode_key_bytes(self) -> bytes:
|
||||
"""Encode a public key per RFC 3110, section 2."""
|
||||
pn = self.key.public_numbers()
|
||||
_exp_len = math.ceil(int.bit_length(pn.e) / 8)
|
||||
exp = int.to_bytes(pn.e, length=_exp_len, byteorder="big")
|
||||
if _exp_len > 255:
|
||||
exp_header = b"\0" + struct.pack("!H", _exp_len)
|
||||
else:
|
||||
exp_header = struct.pack("!B", _exp_len)
|
||||
if pn.n.bit_length() < 512 or pn.n.bit_length() > 4096:
|
||||
raise ValueError("unsupported RSA key length")
|
||||
return exp_header + exp + pn.n.to_bytes((pn.n.bit_length() + 7) // 8, "big")
|
||||
|
||||
@classmethod
|
||||
def from_dnskey(cls, key: DNSKEY) -> "PublicRSA":
|
||||
cls._ensure_algorithm_key_combination(key)
|
||||
keyptr = key.key
|
||||
(bytes_,) = struct.unpack("!B", keyptr[0:1])
|
||||
keyptr = keyptr[1:]
|
||||
if bytes_ == 0:
|
||||
(bytes_,) = struct.unpack("!H", keyptr[0:2])
|
||||
keyptr = keyptr[2:]
|
||||
rsa_e = keyptr[0:bytes_]
|
||||
rsa_n = keyptr[bytes_:]
|
||||
return cls(
|
||||
key=rsa.RSAPublicNumbers(
|
||||
int.from_bytes(rsa_e, "big"), int.from_bytes(rsa_n, "big")
|
||||
).public_key(default_backend())
|
||||
)
|
||||
|
||||
|
||||
class PrivateRSA(CryptographyPrivateKey):
|
||||
key: rsa.RSAPrivateKey
|
||||
key_cls = rsa.RSAPrivateKey
|
||||
public_cls = PublicRSA
|
||||
default_public_exponent = 65537
|
||||
|
||||
def sign(self, data: bytes, verify: bool = False) -> bytes:
|
||||
"""Sign using a private key per RFC 3110, section 3."""
|
||||
signature = self.key.sign(data, padding.PKCS1v15(), self.public_cls.chosen_hash)
|
||||
if verify:
|
||||
self.public_key().verify(signature, data)
|
||||
return signature
|
||||
|
||||
@classmethod
|
||||
def generate(cls, key_size: int) -> "PrivateRSA":
|
||||
return cls(
|
||||
key=rsa.generate_private_key(
|
||||
public_exponent=cls.default_public_exponent,
|
||||
key_size=key_size,
|
||||
backend=default_backend(),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class PublicRSAMD5(PublicRSA):
|
||||
algorithm = Algorithm.RSAMD5
|
||||
chosen_hash = hashes.MD5()
|
||||
|
||||
|
||||
class PrivateRSAMD5(PrivateRSA):
|
||||
public_cls = PublicRSAMD5
|
||||
|
||||
|
||||
class PublicRSASHA1(PublicRSA):
|
||||
algorithm = Algorithm.RSASHA1
|
||||
chosen_hash = hashes.SHA1()
|
||||
|
||||
|
||||
class PrivateRSASHA1(PrivateRSA):
|
||||
public_cls = PublicRSASHA1
|
||||
|
||||
|
||||
class PublicRSASHA1NSEC3SHA1(PublicRSA):
|
||||
algorithm = Algorithm.RSASHA1NSEC3SHA1
|
||||
chosen_hash = hashes.SHA1()
|
||||
|
||||
|
||||
class PrivateRSASHA1NSEC3SHA1(PrivateRSA):
|
||||
public_cls = PublicRSASHA1NSEC3SHA1
|
||||
|
||||
|
||||
class PublicRSASHA256(PublicRSA):
|
||||
algorithm = Algorithm.RSASHA256
|
||||
chosen_hash = hashes.SHA256()
|
||||
|
||||
|
||||
class PrivateRSASHA256(PrivateRSA):
|
||||
public_cls = PublicRSASHA256
|
||||
|
||||
|
||||
class PublicRSASHA512(PublicRSA):
|
||||
algorithm = Algorithm.RSASHA512
|
||||
chosen_hash = hashes.SHA512()
|
||||
|
||||
|
||||
class PrivateRSASHA512(PrivateRSA):
|
||||
public_cls = PublicRSASHA512
|
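Unlike the curve-based classes, the RSA private-key classes take an explicit key size on generation; 2048 bits below is an illustrative choice:

# Sketch: generate an RSA/SHA-256 key and export its DNSKEY rdata.
from dns.dnssecalgs.rsa import PrivateRSASHA256

key = PrivateRSASHA256.generate(key_size=2048)
dnskey = key.public_key().to_dnskey()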
|
@ -17,11 +17,10 @@
|
|||
|
||||
"""EDNS Options"""
|
||||
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import math
|
||||
import socket
|
||||
import struct
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import dns.enum
|
||||
import dns.inet
|
||||
|
@ -380,7 +379,7 @@ class EDEOption(Option): # lgtm[py/missing-equals]
|
|||
def from_wire_parser(
|
||||
cls, otype: Union[OptionType, str], parser: "dns.wire.Parser"
|
||||
) -> Option:
|
||||
the_code = EDECode.make(parser.get_uint16())
|
||||
code = EDECode.make(parser.get_uint16())
|
||||
text = parser.get_remaining()
|
||||
|
||||
if text:
|
||||
|
@ -390,7 +389,7 @@ class EDEOption(Option): # lgtm[py/missing-equals]
|
|||
else:
|
||||
btext = None
|
||||
|
||||
return cls(the_code, btext)
|
||||
return cls(code, btext)
|
||||
|
||||
|
||||
_type_to_class: Dict[OptionType, Any] = {
|
||||
|
@ -424,8 +423,8 @@ def option_from_wire_parser(
|
|||
|
||||
Returns an instance of a subclass of ``dns.edns.Option``.
|
||||
"""
|
||||
the_otype = OptionType.make(otype)
|
||||
cls = get_option_class(the_otype)
|
||||
otype = OptionType.make(otype)
|
||||
cls = get_option_class(otype)
|
||||
return cls.from_wire_parser(otype, parser)
|
||||
|
||||
|
||||
|
|
|
@ -15,17 +15,15 @@
|
|||
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
||||
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
import os
|
||||
import hashlib
|
||||
import os
|
||||
import random
|
||||
import threading
|
||||
import time
|
||||
from typing import Any, Optional
|
||||
|
||||
|
||||
class EntropyPool:
|
||||
|
||||
# This is an entropy pool for Python implementations that do not
|
||||
# have a working SystemRandom. I'm not sure there are any, but
|
||||
# leaving this code doesn't hurt anything as the library code
|
||||
|
|
|
@ -16,18 +16,31 @@
|
|||
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
|
||||
import enum
|
||||
from typing import Type, TypeVar, Union
|
||||
|
||||
TIntEnum = TypeVar("TIntEnum", bound="IntEnum")
|
||||
|
||||
|
||||
class IntEnum(enum.IntEnum):
|
||||
@classmethod
|
||||
def _check_value(cls, value):
|
||||
max = cls._maximum()
|
||||
if value < 0 or value > max:
|
||||
name = cls._short_name()
|
||||
raise ValueError(f"{name} must be between >= 0 and <= {max}")
|
||||
def _missing_(cls, value):
|
||||
cls._check_value(value)
|
||||
val = int.__new__(cls, value)
|
||||
val._name_ = cls._extra_to_text(value, None) or f"{cls._prefix()}{value}"
|
||||
val._value_ = value
|
||||
return val
|
||||
|
||||
@classmethod
|
||||
def from_text(cls, text):
|
||||
def _check_value(cls, value):
|
||||
max = cls._maximum()
|
||||
if not isinstance(value, int):
|
||||
raise TypeError
|
||||
if value < 0 or value > max:
|
||||
name = cls._short_name()
|
||||
raise ValueError(f"{name} must be an int between >= 0 and <= {max}")
|
||||
|
||||
@classmethod
|
||||
def from_text(cls: Type[TIntEnum], text: str) -> TIntEnum:
|
||||
text = text.upper()
|
||||
try:
|
||||
return cls[text]
|
||||
|
@ -47,7 +60,7 @@ class IntEnum(enum.IntEnum):
|
|||
raise cls._unknown_exception_class()
|
||||
|
||||
@classmethod
|
||||
def to_text(cls, value):
|
||||
def to_text(cls: Type[TIntEnum], value: int) -> str:
|
||||
cls._check_value(value)
|
||||
try:
|
||||
text = cls(value).name
|
||||
|
@ -59,7 +72,7 @@ class IntEnum(enum.IntEnum):
|
|||
return text
|
||||
|
||||
@classmethod
|
||||
def make(cls, value):
|
||||
def make(cls: Type[TIntEnum], value: Union[int, str]) -> TIntEnum:
|
||||
"""Convert text or a value into an enumerated type, if possible.
|
||||
|
||||
*value*, the ``int`` or ``str`` to convert.
|
||||
|
@ -76,10 +89,7 @@ class IntEnum(enum.IntEnum):
|
|||
if isinstance(value, str):
|
||||
return cls.from_text(value)
|
||||
cls._check_value(value)
|
||||
try:
|
||||
return cls(value)
|
||||
except ValueError:
|
||||
return value
|
||||
|
||||
@classmethod
|
||||
def _maximum(cls):
|
||||
|
|
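The typed ``make`` helper above accepts either mnemonic text or a numeric value; a quick sketch with the DNSSEC algorithm enum used elsewhere in this change:

# Sketch: IntEnum.make() converts mnemonics and numbers to enum members.
from dns.dnssectypes import Algorithm

assert Algorithm.make("ED25519") is Algorithm.ED25519
assert Algorithm.make(15) is Algorithm.ED25519
assert Algorithm.to_text(Algorithm.ED25519) == "ED25519"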
|
@ -140,6 +140,22 @@ class Timeout(DNSException):
|
|||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
class UnsupportedAlgorithm(DNSException):
|
||||
"""The DNSSEC algorithm is not supported."""
|
||||
|
||||
|
||||
class AlgorithmKeyMismatch(UnsupportedAlgorithm):
|
||||
"""The DNSSEC algorithm is not supported for the given key type."""
|
||||
|
||||
|
||||
class ValidationFailure(DNSException):
|
||||
"""The DNSSEC signature is invalid."""
|
||||
|
||||
|
||||
class DeniedByPolicy(DNSException):
|
||||
"""Denied by DNSSEC policy."""
|
||||
|
||||
|
||||
class ExceptionWrapper:
|
||||
def __init__(self, exception_class):
|
||||
self.exception_class = exception_class
|
||||
|
|
|
@ -17,9 +17,8 @@
|
|||
|
||||
"""DNS Message Flags."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import enum
|
||||
from typing import Any
|
||||
|
||||
# Standard DNS flags
|
||||
|
||||
|
|
|
@ -1,8 +1,7 @@
|
|||
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
|
||||
|
||||
from typing import Any
|
||||
|
||||
import collections.abc
|
||||
from typing import Any
|
||||
|
||||
from dns._immutable_ctx import immutable
|
||||
|
||||
|
|
|
@ -17,14 +17,12 @@
|
|||
|
||||
"""Generic Internet address helper functions."""
|
||||
|
||||
from typing import Any, Optional, Tuple
|
||||
|
||||
import socket
|
||||
from typing import Any, Optional, Tuple
|
||||
|
||||
import dns.ipv4
|
||||
import dns.ipv6
|
||||
|
||||
|
||||
# We assume that AF_INET and AF_INET6 are always defined. We keep
|
||||
# these here for the benefit of any old code (unlikely though that
|
||||
# is!).
|
||||
|
@ -171,3 +169,12 @@ def low_level_address_tuple(
|
|||
return tup
|
||||
else:
|
||||
raise NotImplementedError(f"unknown address family {af}")
|
||||
|
||||
|
||||
def any_for_af(af):
|
||||
"""Return the 'any' address for the specified address family."""
|
||||
if af == socket.AF_INET:
|
||||
return "0.0.0.0"
|
||||
elif af == socket.AF_INET6:
|
||||
return "::"
|
||||
raise NotImplementedError(f"unknown address family {af}")
|
||||
|
|
|
@ -17,9 +17,8 @@
|
|||
|
||||
"""IPv4 helper functions."""
|
||||
|
||||
from typing import Union
|
||||
|
||||
import struct
|
||||
from typing import Union
|
||||
|
||||
import dns.exception
|
||||
|
||||
|
|
|
@ -17,10 +17,9 @@
|
|||
|
||||
"""IPv6 helper functions."""
|
||||
|
||||
from typing import List, Union
|
||||
|
||||
import re
|
||||
import binascii
|
||||
import re
|
||||
from typing import List, Union
|
||||
|
||||
import dns.exception
|
||||
import dns.ipv4
|
||||
|
|
|
@ -17,30 +17,29 @@
|
|||
|
||||
"""DNS Messages"""
|
||||
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import contextlib
|
||||
import io
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import dns.wire
|
||||
import dns.edns
|
||||
import dns.entropy
|
||||
import dns.enum
|
||||
import dns.exception
|
||||
import dns.flags
|
||||
import dns.name
|
||||
import dns.opcode
|
||||
import dns.entropy
|
||||
import dns.rcode
|
||||
import dns.rdata
|
||||
import dns.rdataclass
|
||||
import dns.rdatatype
|
||||
import dns.rrset
|
||||
import dns.renderer
|
||||
import dns.ttl
|
||||
import dns.tsig
|
||||
import dns.rdtypes.ANY.OPT
|
||||
import dns.rdtypes.ANY.TSIG
|
||||
import dns.renderer
|
||||
import dns.rrset
|
||||
import dns.tsig
|
||||
import dns.ttl
|
||||
import dns.wire
|
||||
|
||||
|
||||
class ShortHeader(dns.exception.FormError):
|
||||
|
@ -135,7 +134,7 @@ IndexKeyType = Tuple[
|
|||
Optional[dns.rdataclass.RdataClass],
|
||||
]
|
||||
IndexType = Dict[IndexKeyType, dns.rrset.RRset]
|
||||
SectionType = Union[int, List[dns.rrset.RRset]]
|
||||
SectionType = Union[int, str, List[dns.rrset.RRset]]
|
||||
|
||||
|
||||
class Message:
|
||||
|
@ -231,7 +230,7 @@ class Message:
|
|||
s.write("payload %d\n" % self.payload)
|
||||
for opt in self.options:
|
||||
s.write("option %s\n" % opt.to_text())
|
||||
for (name, which) in self._section_enum.__members__.items():
|
||||
for name, which in self._section_enum.__members__.items():
|
||||
s.write(f";{name}\n")
|
||||
for rrset in self.section_from_number(which):
|
||||
s.write(rrset.to_text(origin, relativize, **kw))
|
||||
|
@ -348,27 +347,29 @@ class Message:
|
|||
deleting: Optional[dns.rdataclass.RdataClass] = None,
|
||||
create: bool = False,
|
||||
force_unique: bool = False,
|
||||
idna_codec: Optional[dns.name.IDNACodec] = None,
|
||||
) -> dns.rrset.RRset:
|
||||
"""Find the RRset with the given attributes in the specified section.
|
||||
|
||||
*section*, an ``int`` section number, or one of the section
|
||||
attributes of this message. This specifies the
|
||||
*section*, an ``int`` section number, a ``str`` section name, or one of
|
||||
the section attributes of this message. This specifies the
|
||||
section of the message to search. For example::
|
||||
|
||||
my_message.find_rrset(my_message.answer, name, rdclass, rdtype)
|
||||
my_message.find_rrset(dns.message.ANSWER, name, rdclass, rdtype)
|
||||
my_message.find_rrset("ANSWER", name, rdclass, rdtype)
|
||||
|
||||
*name*, a ``dns.name.Name``, the name of the RRset.
|
||||
*name*, a ``dns.name.Name`` or ``str``, the name of the RRset.
|
||||
|
||||
*rdclass*, an ``int``, the class of the RRset.
|
||||
*rdclass*, an ``int`` or ``str``, the class of the RRset.
|
||||
|
||||
*rdtype*, an ``int``, the type of the RRset.
|
||||
*rdtype*, an ``int`` or ``str``, the type of the RRset.
|
||||
|
||||
*covers*, an ``int`` or ``None``, the covers value of the RRset.
|
||||
The default is ``None``.
|
||||
*covers*, an ``int`` or ``str``, the covers value of the RRset.
|
||||
The default is ``dns.rdatatype.NONE``.
|
||||
|
||||
*deleting*, an ``int`` or ``None``, the deleting value of the RRset.
|
||||
The default is ``None``.
|
||||
*deleting*, an ``int``, ``str``, or ``None``, the deleting value of the
|
||||
RRset. The default is ``None``.
|
||||
|
||||
*create*, a ``bool``. If ``True``, create the RRset if it is not found.
|
||||
The created RRset is appended to *section*.
|
||||
|
@ -378,6 +379,10 @@ class Message:
|
|||
already. The default is ``False``. This is useful when creating
|
||||
DDNS Update messages, as order matters for them.
|
||||
|
||||
*idna_codec*, a ``dns.name.IDNACodec``, specifies the IDNA
|
||||
encoder/decoder. If ``None``, the default IDNA 2003 encoder/decoder
|
||||
is used.
|
||||
|
||||
Raises ``KeyError`` if the RRset was not found and create was
|
||||
``False``.
|
||||
|
||||
|
@ -386,10 +391,19 @@ class Message:
|
|||
|
||||
if isinstance(section, int):
|
||||
section_number = section
|
||||
the_section = self.section_from_number(section_number)
|
||||
section = self.section_from_number(section_number)
|
||||
elif isinstance(section, str):
|
||||
section_number = MessageSection.from_text(section)
|
||||
section = self.section_from_number(section_number)
|
||||
else:
|
||||
section_number = self.section_number(section)
|
||||
the_section = section
|
||||
if isinstance(name, str):
|
||||
name = dns.name.from_text(name, idna_codec=idna_codec)
|
||||
rdtype = dns.rdatatype.RdataType.make(rdtype)
|
||||
rdclass = dns.rdataclass.RdataClass.make(rdclass)
|
||||
covers = dns.rdatatype.RdataType.make(covers)
|
||||
if deleting is not None:
|
||||
deleting = dns.rdataclass.RdataClass.make(deleting)
|
||||
key = (section_number, name, rdclass, rdtype, covers, deleting)
|
||||
if not force_unique:
|
||||
if self.index is not None:
|
||||
|
@ -397,13 +411,13 @@ class Message:
|
|||
if rrset is not None:
|
||||
return rrset
|
||||
else:
|
||||
for rrset in the_section:
|
||||
for rrset in section:
|
||||
if rrset.full_match(name, rdclass, rdtype, covers, deleting):
|
||||
return rrset
|
||||
if not create:
|
||||
raise KeyError
|
||||
rrset = dns.rrset.RRset(name, rdclass, rdtype, covers, deleting)
|
||||
the_section.append(rrset)
|
||||
section.append(rrset)
|
||||
if self.index is not None:
|
||||
self.index[key] = rrset
|
||||
return rrset
|
||||
|
@ -418,29 +432,31 @@ class Message:
|
|||
deleting: Optional[dns.rdataclass.RdataClass] = None,
|
||||
create: bool = False,
|
||||
force_unique: bool = False,
|
||||
idna_codec: Optional[dns.name.IDNACodec] = None,
|
||||
) -> Optional[dns.rrset.RRset]:
|
||||
"""Get the RRset with the given attributes in the specified section.
|
||||
|
||||
If the RRset is not found, None is returned.
|
||||
|
||||
*section*, an ``int`` section number, or one of the section
|
||||
attributes of this message. This specifies the
|
||||
*section*, an ``int`` section number, a ``str`` section name, or one of
|
||||
the section attributes of this message. This specifies the
|
||||
section of the message to search. For example::
|
||||
|
||||
my_message.get_rrset(my_message.answer, name, rdclass, rdtype)
|
||||
my_message.get_rrset(dns.message.ANSWER, name, rdclass, rdtype)
|
||||
my_message.get_rrset("ANSWER", name, rdclass, rdtype)
|
||||
|
||||
*name*, a ``dns.name.Name``, the name of the RRset.
|
||||
*name*, a ``dns.name.Name`` or ``str``, the name of the RRset.
|
||||
|
||||
*rdclass*, an ``int``, the class of the RRset.
|
||||
*rdclass*, an ``int`` or ``str``, the class of the RRset.
|
||||
|
||||
*rdtype*, an ``int``, the type of the RRset.
|
||||
*rdtype*, an ``int`` or ``str``, the type of the RRset.
|
||||
|
||||
*covers*, an ``int`` or ``None``, the covers value of the RRset.
|
||||
The default is ``None``.
|
||||
*covers*, an ``int`` or ``str``, the covers value of the RRset.
|
||||
The default is ``dns.rdatatype.NONE``.
|
||||
|
||||
*deleting*, an ``int`` or ``None``, the deleting value of the RRset.
|
||||
The default is ``None``.
|
||||
*deleting*, an ``int``, ``str``, or ``None``, the deleting value of the
|
||||
RRset. The default is ``None``.
|
||||
|
||||
*create*, a ``bool``. If ``True``, create the RRset if it is not found.
|
||||
The created RRset is appended to *section*.
|
||||
|
@ -450,12 +466,24 @@ class Message:
|
|||
already. The default is ``False``. This is useful when creating
|
||||
DDNS Update messages, as order matters for them.
|
||||
|
||||
*idna_codec*, a ``dns.name.IDNACodec``, specifies the IDNA
|
||||
encoder/decoder. If ``None``, the default IDNA 2003 encoder/decoder
|
||||
is used.
|
||||
|
||||
Returns a ``dns.rrset.RRset`` object or ``None``.
|
||||
"""
|
||||
|
||||
try:
|
||||
rrset = self.find_rrset(
|
||||
section, name, rdclass, rdtype, covers, deleting, create, force_unique
|
||||
section,
|
||||
name,
|
||||
rdclass,
|
||||
rdtype,
|
||||
covers,
|
||||
deleting,
|
||||
create,
|
||||
force_unique,
|
||||
idna_codec,
|
||||
)
|
||||
except KeyError:
|
||||
rrset = None
|
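As the expanded docstrings above note, sections, names, classes, and types may now all be passed as strings; a brief sketch ("example.org" is a placeholder name):

# Sketch: string arguments to find_rrset()/get_rrset() on a freshly built query.
import dns.message

query = dns.message.make_query("example.org", "A")
question = query.find_rrset("QUESTION", "example.org", "IN", "A")
answer = query.get_rrset("ANSWER", "example.org", "IN", "A")  # None: no answer yet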
||||
|
@ -1708,13 +1736,11 @@ def make_query(
|
|||
|
||||
if isinstance(qname, str):
|
||||
qname = dns.name.from_text(qname, idna_codec=idna_codec)
|
||||
the_rdtype = dns.rdatatype.RdataType.make(rdtype)
|
||||
the_rdclass = dns.rdataclass.RdataClass.make(rdclass)
|
||||
rdtype = dns.rdatatype.RdataType.make(rdtype)
|
||||
rdclass = dns.rdataclass.RdataClass.make(rdclass)
|
||||
m = QueryMessage(id=id)
|
||||
m.flags = dns.flags.Flag(flags)
|
||||
m.find_rrset(
|
||||
m.question, qname, the_rdclass, the_rdtype, create=True, force_unique=True
|
||||
)
|
||||
m.find_rrset(m.question, qname, rdclass, rdtype, create=True, force_unique=True)
|
||||
# only pass keywords on to use_edns if they have been set to a
|
||||
# non-None value. Setting a field will turn EDNS on if it hasn't
|
||||
# been configured.
|
||||
|
|
|
@ -18,12 +18,10 @@
|
|||
"""DNS Names.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Iterable, Optional, Tuple, Union
|
||||
|
||||
import copy
|
||||
import struct
|
||||
|
||||
import encodings.idna # type: ignore
|
||||
import struct
|
||||
from typing import Any, Dict, Iterable, Optional, Tuple, Union
|
||||
|
||||
try:
|
||||
import idna # type: ignore
|
||||
|
@ -33,10 +31,9 @@ except ImportError: # pragma: no cover
|
|||
have_idna_2008 = False
|
||||
|
||||
import dns.enum
|
||||
import dns.wire
|
||||
import dns.exception
|
||||
import dns.immutable
|
||||
|
||||
import dns.wire
|
||||
|
||||
CompressType = Dict["Name", int]
|
||||
|
||||
|
|
329
lib/dns/nameserver.py
Normal file
|
@ -0,0 +1,329 @@
from typing import Optional, Union
from urllib.parse import urlparse

import dns.asyncbackend
import dns.asyncquery
import dns.inet
import dns.message
import dns.query


class Nameserver:
    def __init__(self):
        pass

    def __str__(self):
        raise NotImplementedError

    def kind(self) -> str:
        raise NotImplementedError

    def is_always_max_size(self) -> bool:
        raise NotImplementedError

    def answer_nameserver(self) -> str:
        raise NotImplementedError

    def answer_port(self) -> int:
        raise NotImplementedError

    def query(
        self,
        request: dns.message.QueryMessage,
        timeout: float,
        source: Optional[str],
        source_port: int,
        max_size: bool,
        one_rr_per_rrset: bool = False,
        ignore_trailing: bool = False,
    ) -> dns.message.Message:
        raise NotImplementedError

    async def async_query(
        self,
        request: dns.message.QueryMessage,
        timeout: float,
        source: Optional[str],
        source_port: int,
        max_size: bool,
        backend: dns.asyncbackend.Backend,
        one_rr_per_rrset: bool = False,
        ignore_trailing: bool = False,
    ) -> dns.message.Message:
        raise NotImplementedError


class AddressAndPortNameserver(Nameserver):
    def __init__(self, address: str, port: int):
        super().__init__()
        self.address = address
        self.port = port

    def kind(self) -> str:
        raise NotImplementedError

    def is_always_max_size(self) -> bool:
        return False

    def __str__(self):
        ns_kind = self.kind()
        return f"{ns_kind}:{self.address}@{self.port}"

    def answer_nameserver(self) -> str:
        return self.address

    def answer_port(self) -> int:
        return self.port


class Do53Nameserver(AddressAndPortNameserver):
    def __init__(self, address: str, port: int = 53):
        super().__init__(address, port)

    def kind(self):
        return "Do53"

    def query(
        self,
        request: dns.message.QueryMessage,
        timeout: float,
        source: Optional[str],
        source_port: int,
        max_size: bool,
        one_rr_per_rrset: bool = False,
        ignore_trailing: bool = False,
    ) -> dns.message.Message:
        if max_size:
            response = dns.query.tcp(
                request,
                self.address,
                timeout=timeout,
                port=self.port,
                source=source,
                source_port=source_port,
                one_rr_per_rrset=one_rr_per_rrset,
                ignore_trailing=ignore_trailing,
            )
        else:
            response = dns.query.udp(
                request,
                self.address,
                timeout=timeout,
                port=self.port,
                source=source,
                source_port=source_port,
                raise_on_truncation=True,
                one_rr_per_rrset=one_rr_per_rrset,
                ignore_trailing=ignore_trailing,
            )
        return response

    async def async_query(
        self,
        request: dns.message.QueryMessage,
        timeout: float,
        source: Optional[str],
        source_port: int,
        max_size: bool,
        backend: dns.asyncbackend.Backend,
        one_rr_per_rrset: bool = False,
        ignore_trailing: bool = False,
    ) -> dns.message.Message:
        if max_size:
            response = await dns.asyncquery.tcp(
                request,
                self.address,
                timeout=timeout,
                port=self.port,
                source=source,
                source_port=source_port,
                backend=backend,
                one_rr_per_rrset=one_rr_per_rrset,
                ignore_trailing=ignore_trailing,
            )
        else:
            response = await dns.asyncquery.udp(
                request,
                self.address,
                timeout=timeout,
                port=self.port,
                source=source,
                source_port=source_port,
                raise_on_truncation=True,
                backend=backend,
                one_rr_per_rrset=one_rr_per_rrset,
                ignore_trailing=ignore_trailing,
            )
        return response


class DoHNameserver(Nameserver):
    def __init__(self, url: str, bootstrap_address: Optional[str] = None):
        super().__init__()
        self.url = url
        self.bootstrap_address = bootstrap_address

    def kind(self):
        return "DoH"

    def is_always_max_size(self) -> bool:
        return True

    def __str__(self):
        return self.url

    def answer_nameserver(self) -> str:
        return self.url

    def answer_port(self) -> int:
        port = urlparse(self.url).port
        if port is None:
            port = 443
        return port

    def query(
        self,
        request: dns.message.QueryMessage,
        timeout: float,
        source: Optional[str],
        source_port: int,
        max_size: bool = False,
        one_rr_per_rrset: bool = False,
        ignore_trailing: bool = False,
    ) -> dns.message.Message:
        return dns.query.https(
            request,
            self.url,
            timeout=timeout,
            bootstrap_address=self.bootstrap_address,
            one_rr_per_rrset=one_rr_per_rrset,
            ignore_trailing=ignore_trailing,
        )

    async def async_query(
        self,
        request: dns.message.QueryMessage,
        timeout: float,
        source: Optional[str],
        source_port: int,
        max_size: bool,
        backend: dns.asyncbackend.Backend,
        one_rr_per_rrset: bool = False,
        ignore_trailing: bool = False,
    ) -> dns.message.Message:
        return await dns.asyncquery.https(
            request,
            self.url,
            timeout=timeout,
            one_rr_per_rrset=one_rr_per_rrset,
            ignore_trailing=ignore_trailing,
        )


class DoTNameserver(AddressAndPortNameserver):
    def __init__(self, address: str, port: int = 853, hostname: Optional[str] = None):
        super().__init__(address, port)
        self.hostname = hostname

    def kind(self):
        return "DoT"

    def query(
        self,
        request: dns.message.QueryMessage,
        timeout: float,
        source: Optional[str],
        source_port: int,
        max_size: bool = False,
        one_rr_per_rrset: bool = False,
        ignore_trailing: bool = False,
    ) -> dns.message.Message:
        return dns.query.tls(
            request,
            self.address,
            port=self.port,
            timeout=timeout,
            one_rr_per_rrset=one_rr_per_rrset,
            ignore_trailing=ignore_trailing,
            server_hostname=self.hostname,
        )

    async def async_query(
        self,
        request: dns.message.QueryMessage,
        timeout: float,
        source: Optional[str],
        source_port: int,
        max_size: bool,
        backend: dns.asyncbackend.Backend,
        one_rr_per_rrset: bool = False,
        ignore_trailing: bool = False,
    ) -> dns.message.Message:
        return await dns.asyncquery.tls(
            request,
            self.address,
            port=self.port,
            timeout=timeout,
            one_rr_per_rrset=one_rr_per_rrset,
            ignore_trailing=ignore_trailing,
            server_hostname=self.hostname,
        )


class DoQNameserver(AddressAndPortNameserver):
    def __init__(
        self,
        address: str,
        port: int = 853,
        verify: Union[bool, str] = True,
        server_hostname: Optional[str] = None,
    ):
        super().__init__(address, port)
        self.verify = verify
        self.server_hostname = server_hostname

    def kind(self):
        return "DoQ"

    def query(
        self,
        request: dns.message.QueryMessage,
        timeout: float,
        source: Optional[str],
        source_port: int,
        max_size: bool = False,
        one_rr_per_rrset: bool = False,
        ignore_trailing: bool = False,
    ) -> dns.message.Message:
        return dns.query.quic(
            request,
            self.address,
            port=self.port,
            timeout=timeout,
            one_rr_per_rrset=one_rr_per_rrset,
            ignore_trailing=ignore_trailing,
            verify=self.verify,
            server_hostname=self.server_hostname,
        )

    async def async_query(
        self,
        request: dns.message.QueryMessage,
        timeout: float,
        source: Optional[str],
        source_port: int,
        max_size: bool,
        backend: dns.asyncbackend.Backend,
        one_rr_per_rrset: bool = False,
        ignore_trailing: bool = False,
    ) -> dns.message.Message:
        return await dns.asyncquery.quic(
            request,
            self.address,
            port=self.port,
            timeout=timeout,
            one_rr_per_rrset=one_rr_per_rrset,
            ignore_trailing=ignore_trailing,
            verify=self.verify,
            server_hostname=self.server_hostname,
        )
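A quick sketch of how these classes are meant to be used (not part of the diff; assumes dnspython 2.4+, network access, and the public DoT resolver shown here purely as an example):

import dns.message
import dns.nameserver

ns = dns.nameserver.DoTNameserver("1.1.1.1", hostname="cloudflare-dns.com")
q = dns.message.make_query("example.com", "A")
# DoT is stream-based, so max_size has no effect; source/source_port are unset.
r = ns.query(q, timeout=5.0, source=None, source_port=0, max_size=False)
print(ns.kind(), ns.answer_nameserver(), ns.answer_port(), r.rcode())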
|
@ -17,19 +17,17 @@
|
|||
|
||||
"""DNS nodes. A node is a set of rdatasets."""
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import enum
|
||||
import io
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import dns.immutable
|
||||
import dns.name
|
||||
import dns.rdataclass
|
||||
import dns.rdataset
|
||||
import dns.rdatatype
|
||||
import dns.rrset
|
||||
import dns.renderer
|
||||
|
||||
import dns.rrset
|
||||
|
||||
_cname_types = {
|
||||
dns.rdatatype.CNAME,
|
||||
248 lib/dns/query.py
|
@ -17,8 +17,6 @@
|
|||
|
||||
"""Talk to a DNS server."""
|
||||
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import base64
|
||||
import contextlib
|
||||
import enum
|
||||
|
@ -28,12 +26,12 @@ import selectors
|
|||
import socket
|
||||
import struct
|
||||
import time
|
||||
import urllib.parse
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import dns.exception
|
||||
import dns.inet
|
||||
import dns.name
|
||||
import dns.message
|
||||
import dns.name
|
||||
import dns.quic
|
||||
import dns.rcode
|
||||
import dns.rdataclass
|
||||
|
@ -43,20 +41,32 @@ import dns.transaction
|
|||
import dns.tsig
|
||||
import dns.xfr
|
||||
|
||||
try:
|
||||
import requests
|
||||
from requests_toolbelt.adapters.source import SourceAddressAdapter
|
||||
from requests_toolbelt.adapters.host_header_ssl import HostHeaderSSLAdapter
|
||||
|
||||
_have_requests = True
|
||||
except ImportError: # pragma: no cover
|
||||
_have_requests = False
|
||||
def _remaining(expiration):
|
||||
if expiration is None:
|
||||
return None
|
||||
timeout = expiration - time.time()
|
||||
if timeout <= 0.0:
|
||||
raise dns.exception.Timeout
|
||||
return timeout
|
||||
|
||||
|
||||
def _expiration_for_this_attempt(timeout, expiration):
|
||||
if expiration is None:
|
||||
return None
|
||||
return min(time.time() + timeout, expiration)
|
||||
|
||||
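These helpers implement a shared-deadline pattern: compute an absolute expiration once, then derive a bounded timeout before each blocking step. A standalone sketch of the same idea (names are illustrative, not from the diff):

import time

def compute_deadline(lifetime):
    # Illustrative: absolute deadline from a total lifetime (None = no limit).
    return None if lifetime is None else time.time() + lifetime

deadline = compute_deadline(5.0)
for address in ["192.0.2.1", "192.0.2.2"]:
    # Cap each attempt at 2 seconds without overshooting the overall deadline,
    # mirroring _expiration_for_this_attempt(2.0, expiration) above.
    attempt_deadline = deadline if deadline is None else min(time.time() + 2.0, deadline)
    remaining = None if attempt_deadline is None else attempt_deadline - time.time()
    print(address, "timeout for this attempt:", remaining)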
|
||||
_have_httpx = False
|
||||
_have_http2 = False
|
||||
try:
|
||||
import httpcore
|
||||
import httpcore._backends.sync
|
||||
import httpx
|
||||
|
||||
_CoreNetworkBackend = httpcore.NetworkBackend
|
||||
_CoreSyncStream = httpcore._backends.sync.SyncStream
|
||||
|
||||
_have_httpx = True
|
||||
try:
|
||||
# See if http2 support is available.
|
||||
|
@ -64,10 +74,87 @@ try:
|
|||
_have_http2 = True
|
||||
except Exception:
|
||||
pass
|
||||
except ImportError: # pragma: no cover
|
||||
pass
|
||||
|
||||
have_doh = _have_requests or _have_httpx
|
||||
class _NetworkBackend(_CoreNetworkBackend):
|
||||
def __init__(self, resolver, local_port, bootstrap_address, family):
|
||||
super().__init__()
|
||||
self._local_port = local_port
|
||||
self._resolver = resolver
|
||||
self._bootstrap_address = bootstrap_address
|
||||
self._family = family
|
||||
|
||||
def connect_tcp(
|
||||
self, host, port, timeout, local_address, socket_options=None
|
||||
): # pylint: disable=signature-differs
|
||||
addresses = []
|
||||
_, expiration = _compute_times(timeout)
|
||||
if dns.inet.is_address(host):
|
||||
addresses.append(host)
|
||||
elif self._bootstrap_address is not None:
|
||||
addresses.append(self._bootstrap_address)
|
||||
else:
|
||||
timeout = _remaining(expiration)
|
||||
family = self._family
|
||||
if local_address:
|
||||
family = dns.inet.af_for_address(local_address)
|
||||
answers = self._resolver.resolve_name(
|
||||
host, family=family, lifetime=timeout
|
||||
)
|
||||
addresses = answers.addresses()
|
||||
for address in addresses:
|
||||
af = dns.inet.af_for_address(address)
|
||||
if local_address is not None or self._local_port != 0:
|
||||
source = dns.inet.low_level_address_tuple(
|
||||
(local_address, self._local_port), af
|
||||
)
|
||||
else:
|
||||
source = None
|
||||
sock = _make_socket(af, socket.SOCK_STREAM, source)
|
||||
attempt_expiration = _expiration_for_this_attempt(2.0, expiration)
|
||||
try:
|
||||
_connect(
|
||||
sock,
|
||||
dns.inet.low_level_address_tuple((address, port), af),
|
||||
attempt_expiration,
|
||||
)
|
||||
return _CoreSyncStream(sock)
|
||||
except Exception:
|
||||
pass
|
||||
raise httpcore.ConnectError
|
||||
|
||||
def connect_unix_socket(
|
||||
self, path, timeout, socket_options=None
|
||||
): # pylint: disable=signature-differs
|
||||
raise NotImplementedError
|
||||
|
||||
class _HTTPTransport(httpx.HTTPTransport):
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
local_port=0,
|
||||
bootstrap_address=None,
|
||||
resolver=None,
|
||||
family=socket.AF_UNSPEC,
|
||||
**kwargs,
|
||||
):
|
||||
if resolver is None:
|
||||
# pylint: disable=import-outside-toplevel,redefined-outer-name
|
||||
import dns.resolver
|
||||
|
||||
resolver = dns.resolver.Resolver()
|
||||
super().__init__(*args, **kwargs)
|
||||
self._pool._network_backend = _NetworkBackend(
|
||||
resolver, local_port, bootstrap_address, family
|
||||
)
|
||||
|
||||
except ImportError: # pragma: no cover
|
||||
|
||||
class _HTTPTransport: # type: ignore
|
||||
def connect_tcp(self, host, port, timeout, local_address):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
have_doh = _have_httpx
|
||||
|
||||
try:
|
||||
import ssl
|
||||
|
@ -88,7 +175,7 @@ except ImportError: # pragma: no cover
|
|||
|
||||
@classmethod
|
||||
def create_default_context(cls, *args, **kwargs):
|
||||
raise Exception("no ssl support")
|
||||
raise Exception("no ssl support") # pylint: disable=broad-exception-raised
|
||||
|
||||
|
||||
# Function used to create a socket. Can be overridden if needed in special
|
||||
|
@ -105,7 +192,7 @@ class BadResponse(dns.exception.FormError):
|
|||
|
||||
|
||||
class NoDOH(dns.exception.DNSException):
|
||||
"""DNS over HTTPS (DOH) was requested but the requests module is not
|
||||
"""DNS over HTTPS (DOH) was requested but the httpx module is not
|
||||
available."""
|
||||
|
||||
|
||||
|
@ -230,7 +317,7 @@ def _destination_and_source(
|
|||
# We know the destination af, so source had better agree!
|
||||
if saf != af:
|
||||
raise ValueError(
|
||||
"different address families for source " + "and destination"
|
||||
"different address families for source and destination"
|
||||
)
|
||||
else:
|
||||
# We didn't know the destination af, but we know the source,
|
||||
|
@ -240,11 +327,10 @@ def _destination_and_source(
|
|||
# Caller has specified a source_port but not an address, so we
|
||||
# need to return a source, and we need to use the appropriate
|
||||
# wildcard address as the address.
|
||||
if af == socket.AF_INET:
|
||||
source = "0.0.0.0"
|
||||
elif af == socket.AF_INET6:
|
||||
source = "::"
|
||||
else:
|
||||
try:
|
||||
source = dns.inet.any_for_af(af)
|
||||
except Exception:
|
||||
# we catch this and raise ValueError for backwards compatibility
|
||||
raise ValueError("source_port specified but address family is unknown")
|
||||
# Convert high-level (address, port) tuples into low-level address
|
||||
# tuples.
|
||||
|
@ -289,6 +375,8 @@ def https(
|
|||
post: bool = True,
|
||||
bootstrap_address: Optional[str] = None,
|
||||
verify: Union[bool, str] = True,
|
||||
resolver: Optional["dns.resolver.Resolver"] = None,
|
||||
family: Optional[int] = socket.AF_UNSPEC,
|
||||
) -> dns.message.Message:
|
||||
"""Return the response obtained after sending a query via DNS-over-HTTPS.
|
||||
|
||||
|
@ -314,91 +402,78 @@ def https(
|
|||
*ignore_trailing*, a ``bool``. If ``True``, ignore trailing junk at end of the
|
||||
received message.
|
||||
|
||||
*session*, an ``httpx.Client`` or ``requests.session.Session``. If provided, the
|
||||
client/session to use to send the queries.
|
||||
*session*, an ``httpx.Client``. If provided, the client session to use to send the
|
||||
queries.
|
||||
|
||||
*path*, a ``str``. If *where* is an IP address, then *path* will be used to
|
||||
construct the URL to send the DNS query to.
|
||||
|
||||
*post*, a ``bool``. If ``True``, the default, POST method will be used.
|
||||
|
||||
*bootstrap_address*, a ``str``, the IP address to use to bypass the system's DNS
|
||||
resolver.
|
||||
*bootstrap_address*, a ``str``, the IP address to use to bypass resolution.
|
||||
|
||||
*verify*, a ``bool`` or ``str``. If ``True``, then TLS certificate verification
|
||||
of the server is done using the default CA bundle; if ``False``, then no
|
||||
verification is done; if a `str` then it specifies the path to a certificate file or
|
||||
directory which will be used for verification.
|
||||
|
||||
*resolver*, a ``dns.resolver.Resolver`` or ``None``, the resolver to use for
|
||||
resolution of hostnames in URLs. If not specified, a new resolver with a default
|
||||
configuration will be used; note this is *not* the default resolver as that resolver
|
||||
might have been configured to use DoH causing a chicken-and-egg problem. This
|
||||
parameter only has an effect if the HTTP library is httpx.
|
||||
|
||||
*family*, an ``int``, the address family. If socket.AF_UNSPEC (the default), both A
|
||||
and AAAA records will be retrieved.
|
||||
|
||||
Returns a ``dns.message.Message``.
|
||||
"""
|
||||
|
||||
if not have_doh:
|
||||
raise NoDOH("Neither httpx nor requests is available.") # pragma: no cover
|
||||
|
||||
_httpx_ok = _have_httpx
|
||||
raise NoDOH # pragma: no cover
|
||||
if session and not isinstance(session, httpx.Client):
|
||||
raise ValueError("session parameter must be an httpx.Client")
|
||||
|
||||
wire = q.to_wire()
|
||||
(af, _, source) = _destination_and_source(where, port, source, source_port, False)
|
||||
transport_adapter = None
|
||||
(af, _, the_source) = _destination_and_source(
|
||||
where, port, source, source_port, False
|
||||
)
|
||||
transport = None
|
||||
headers = {"accept": "application/dns-message"}
|
||||
if af is not None:
|
||||
if af is not None and dns.inet.is_address(where):
|
||||
if af == socket.AF_INET:
|
||||
url = "https://{}:{}{}".format(where, port, path)
|
||||
elif af == socket.AF_INET6:
|
||||
url = "https://[{}]:{}{}".format(where, port, path)
|
||||
elif bootstrap_address is not None:
|
||||
_httpx_ok = False
|
||||
split_url = urllib.parse.urlsplit(where)
|
||||
if split_url.hostname is None:
|
||||
raise ValueError("DoH URL has no hostname")
|
||||
headers["Host"] = split_url.hostname
|
||||
url = where.replace(split_url.hostname, bootstrap_address)
|
||||
if _have_requests:
|
||||
transport_adapter = HostHeaderSSLAdapter()
|
||||
else:
|
||||
url = where
|
||||
if source is not None:
|
||||
|
||||
# set source port and source address
|
||||
if _have_httpx:
|
||||
if source_port == 0:
|
||||
transport = httpx.HTTPTransport(local_address=source[0], verify=verify)
|
||||
else:
|
||||
_httpx_ok = False
|
||||
if _have_requests:
|
||||
transport_adapter = SourceAddressAdapter(source)
|
||||
|
||||
if session:
|
||||
if _have_httpx:
|
||||
_is_httpx = isinstance(session, httpx.Client)
|
||||
if the_source is None:
|
||||
local_address = None
|
||||
local_port = 0
|
||||
else:
|
||||
_is_httpx = False
|
||||
if _is_httpx and not _httpx_ok:
|
||||
raise NoDOH(
|
||||
"Session is httpx, but httpx cannot be used for "
|
||||
"the requested operation."
|
||||
)
|
||||
else:
|
||||
_is_httpx = _httpx_ok
|
||||
|
||||
if not _httpx_ok and not _have_requests:
|
||||
raise NoDOH(
|
||||
"Cannot use httpx for this operation, and requests is not available."
|
||||
local_address = the_source[0]
|
||||
local_port = the_source[1]
|
||||
transport = _HTTPTransport(
|
||||
local_address=local_address,
|
||||
http1=True,
|
||||
http2=_have_http2,
|
||||
verify=verify,
|
||||
local_port=local_port,
|
||||
bootstrap_address=bootstrap_address,
|
||||
resolver=resolver,
|
||||
family=family,
|
||||
)
|
||||
|
||||
if session:
|
||||
cm: contextlib.AbstractContextManager = contextlib.nullcontext(session)
|
||||
elif _is_httpx:
|
||||
else:
|
||||
cm = httpx.Client(
|
||||
http1=True, http2=_have_http2, verify=verify, transport=transport
|
||||
)
|
||||
else:
|
||||
cm = requests.sessions.Session()
|
||||
with cm as session:
|
||||
if transport_adapter and not _is_httpx:
|
||||
session.mount(url, transport_adapter)
|
||||
|
||||
# see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH
|
||||
# GET and POST examples
|
||||
if post:
|
||||
|
@ -408,29 +483,13 @@ def https(
|
|||
"content-length": str(len(wire)),
|
||||
}
|
||||
)
|
||||
if _is_httpx:
|
||||
response = session.post(
|
||||
url, headers=headers, content=wire, timeout=timeout
|
||||
)
|
||||
else:
|
||||
response = session.post(
|
||||
url, headers=headers, data=wire, timeout=timeout, verify=verify
|
||||
)
|
||||
response = session.post(url, headers=headers, content=wire, timeout=timeout)
|
||||
else:
|
||||
wire = base64.urlsafe_b64encode(wire).rstrip(b"=")
|
||||
if _is_httpx:
|
||||
twire = wire.decode() # httpx does a repr() if we give it bytes
|
||||
response = session.get(
|
||||
url, headers=headers, timeout=timeout, params={"dns": twire}
|
||||
)
|
||||
else:
|
||||
response = session.get(
|
||||
url,
|
||||
headers=headers,
|
||||
timeout=timeout,
|
||||
verify=verify,
|
||||
params={"dns": wire},
|
||||
)
|
||||
|
||||
# see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH
|
||||
# status codes
|
||||
|
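For reference, a hedged sketch of calling the httpx-only DoH path above (requires the httpx extra; the Google DoH URL and bootstrap address are just examples):

import dns.message
import dns.query

q = dns.message.make_query("example.com", "AAAA")
# bootstrap_address bypasses resolution of the URL's hostname, which is the
# case handled by the Host header / _NetworkBackend logic above.
r = dns.query.https(
    q,
    "https://dns.google/dns-query",
    timeout=5.0,
    bootstrap_address="8.8.8.8",
    post=True,
)
print(r.rcode())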
@ -1070,6 +1129,7 @@ def quic(
|
|||
ignore_trailing: bool = False,
|
||||
connection: Optional[dns.quic.SyncQuicConnection] = None,
|
||||
verify: Union[bool, str] = True,
|
||||
server_hostname: Optional[str] = None,
|
||||
) -> dns.message.Message:
|
||||
"""Return the response obtained after sending a query via DNS-over-QUIC.
|
||||
|
||||
|
@ -1101,6 +1161,10 @@ def quic(
|
|||
verification is done; if a `str` then it specifies the path to a certificate file or
|
||||
directory which will be used for verification.
|
||||
|
||||
*server_hostname*, a ``str`` containing the server's hostname. The
|
||||
default is ``None``, which means that no hostname is known, and if an
|
||||
SSL context is created, hostname checking will be disabled.
|
||||
|
||||
Returns a ``dns.message.Message``.
|
||||
"""
|
||||
|
||||
|
@ -1115,16 +1179,18 @@ def quic(
|
|||
manager: contextlib.AbstractContextManager = contextlib.nullcontext(None)
|
||||
the_connection = connection
|
||||
else:
|
||||
manager = dns.quic.SyncQuicManager(verify_mode=verify)
|
||||
manager = dns.quic.SyncQuicManager(
|
||||
verify_mode=verify, server_name=server_hostname
|
||||
)
|
||||
the_manager = manager # for type checking happiness
|
||||
|
||||
with manager:
|
||||
if not connection:
|
||||
the_connection = the_manager.connect(where, port, source, source_port)
|
||||
start = time.time()
|
||||
with the_connection.make_stream() as stream:
|
||||
(start, expiration) = _compute_times(timeout)
|
||||
with the_connection.make_stream(timeout) as stream:
|
||||
stream.send(wire, True)
|
||||
wire = stream.receive(timeout)
|
||||
wire = stream.receive(_remaining(expiration))
|
||||
finish = time.time()
|
||||
r = dns.message.from_wire(
|
||||
wire,
|
||||
|
|
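A minimal sketch of the new server_hostname parameter in use (requires aioquic; the AdGuard DoQ endpoint is only an example):

import dns.message
import dns.query

q = dns.message.make_query("example.com", "A")
r = dns.query.quic(
    q,
    "94.140.14.14",
    timeout=5.0,
    server_hostname="dns.adguard-dns.com",  # enables certificate hostname checking
)
print(r.rcode())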
|
@ -5,13 +5,13 @@ try:
|
|||
|
||||
import dns.asyncbackend
|
||||
from dns._asyncbackend import NullContext
|
||||
from dns.quic._sync import SyncQuicManager, SyncQuicConnection, SyncQuicStream
|
||||
from dns.quic._asyncio import (
|
||||
AsyncioQuicManager,
|
||||
AsyncioQuicConnection,
|
||||
AsyncioQuicManager,
|
||||
AsyncioQuicStream,
|
||||
)
|
||||
from dns.quic._common import AsyncQuicConnection, AsyncQuicManager
|
||||
from dns.quic._sync import SyncQuicConnection, SyncQuicManager, SyncQuicStream
|
||||
|
||||
have_quic = True
|
||||
|
||||
|
@ -33,9 +33,10 @@ try:
|
|||
|
||||
try:
|
||||
import trio
|
||||
|
||||
from dns.quic._trio import ( # pylint: disable=ungrouped-imports
|
||||
TrioQuicManager,
|
||||
TrioQuicConnection,
|
||||
TrioQuicManager,
|
||||
TrioQuicStream,
|
||||
)
|
||||
|
||||
|
|
|
@ -9,14 +9,16 @@ import time
|
|||
import aioquic.quic.configuration # type: ignore
|
||||
import aioquic.quic.connection # type: ignore
|
||||
import aioquic.quic.events # type: ignore
|
||||
import dns.inet
|
||||
import dns.asyncbackend
|
||||
|
||||
import dns.asyncbackend
|
||||
import dns.exception
|
||||
import dns.inet
|
||||
from dns.quic._common import (
|
||||
BaseQuicStream,
|
||||
QUIC_MAX_DATAGRAM,
|
||||
AsyncQuicConnection,
|
||||
AsyncQuicManager,
|
||||
QUIC_MAX_DATAGRAM,
|
||||
BaseQuicStream,
|
||||
UnexpectedEOF,
|
||||
)
|
||||
|
||||
|
||||
|
@ -30,15 +32,15 @@ class AsyncioQuicStream(BaseQuicStream):
|
|||
await self._wake_up.wait()
|
||||
|
||||
async def wait_for(self, amount, expiration):
|
||||
timeout = self._timeout_from_expiration(expiration)
|
||||
while True:
|
||||
timeout = self._timeout_from_expiration(expiration)
|
||||
if self._buffer.have(amount):
|
||||
return
|
||||
self._expecting = amount
|
||||
try:
|
||||
await asyncio.wait_for(self._wait_for_wake_up(), timeout)
|
||||
except Exception:
|
||||
pass
|
||||
except TimeoutError:
|
||||
raise dns.exception.Timeout
|
||||
self._expecting = 0
|
||||
|
||||
async def receive(self, timeout=None):
|
||||
|
@ -86,8 +88,10 @@ class AsyncioQuicConnection(AsyncQuicConnection):
|
|||
try:
|
||||
af = dns.inet.af_for_address(self._address)
|
||||
backend = dns.asyncbackend.get_backend("asyncio")
|
||||
# Note that peer is a low-level address tuple, but make_socket() wants
|
||||
# a high-level address tuple, so we convert.
|
||||
self._socket = await backend.make_socket(
|
||||
af, socket.SOCK_DGRAM, 0, self._source, self._peer
|
||||
af, socket.SOCK_DGRAM, 0, self._source, (self._peer[0], self._peer[1])
|
||||
)
|
||||
self._socket_created.set()
|
||||
async with self._socket:
|
||||
|
@ -106,6 +110,11 @@ class AsyncioQuicConnection(AsyncQuicConnection):
|
|||
self._wake_timer.notify_all()
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
self._done = True
|
||||
async with self._wake_timer:
|
||||
self._wake_timer.notify_all()
|
||||
self._handshake_complete.set()
|
||||
|
||||
async def _wait_for_wake_timer(self):
|
||||
async with self._wake_timer:
|
||||
|
@ -115,7 +124,7 @@ class AsyncioQuicConnection(AsyncQuicConnection):
|
|||
await self._socket_created.wait()
|
||||
while not self._done:
|
||||
datagrams = self._connection.datagrams_to_send(time.time())
|
||||
for (datagram, address) in datagrams:
|
||||
for datagram, address in datagrams:
|
||||
assert address == self._peer[0]
|
||||
await self._socket.sendto(datagram, self._peer, None)
|
||||
(expiration, interval) = self._get_timer_values()
|
||||
|
@ -160,8 +169,13 @@ class AsyncioQuicConnection(AsyncQuicConnection):
|
|||
self._receiver_task = asyncio.Task(self._receiver())
|
||||
self._sender_task = asyncio.Task(self._sender())
|
||||
|
||||
async def make_stream(self):
|
||||
await self._handshake_complete.wait()
|
||||
async def make_stream(self, timeout=None):
|
||||
try:
|
||||
await asyncio.wait_for(self._handshake_complete.wait(), timeout)
|
||||
except TimeoutError:
|
||||
raise dns.exception.Timeout
|
||||
if self._done:
|
||||
raise UnexpectedEOF
|
||||
stream_id = self._connection.get_next_available_stream_id(False)
|
||||
stream = AsyncioQuicStream(self, stream_id)
|
||||
self._streams[stream_id] = stream
|
||||
|
@ -172,6 +186,9 @@ class AsyncioQuicConnection(AsyncQuicConnection):
|
|||
self._manager.closed(self._peer[0], self._peer[1])
|
||||
self._closed = True
|
||||
self._connection.close()
|
||||
# sender might be blocked on this, so set it
|
||||
self._socket_created.set()
|
||||
await self._socket.close()
|
||||
async with self._wake_timer:
|
||||
self._wake_timer.notify_all()
|
||||
try:
|
||||
|
@ -185,8 +202,8 @@ class AsyncioQuicConnection(AsyncQuicConnection):
|
|||
|
||||
|
||||
class AsyncioQuicManager(AsyncQuicManager):
|
||||
def __init__(self, conf=None, verify_mode=ssl.CERT_REQUIRED):
|
||||
super().__init__(conf, verify_mode, AsyncioQuicConnection)
|
||||
def __init__(self, conf=None, verify_mode=ssl.CERT_REQUIRED, server_name=None):
|
||||
super().__init__(conf, verify_mode, AsyncioQuicConnection, server_name)
|
||||
|
||||
def connect(self, address, port=853, source=None, source_port=0):
|
||||
(connection, start) = self._connect(address, port, source, source_port)
|
||||
|
@ -198,7 +215,7 @@ class AsyncioQuicManager(AsyncQuicManager):
|
|||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
# Copy the itertor into a list as exiting things will mutate the connections
|
||||
# Copy the iterator into a list as exiting things will mutate the connections
|
||||
# table.
|
||||
connections = list(self._connections.values())
|
||||
for connection in connections:
|
||||
|
|
|
@ -3,13 +3,12 @@
|
|||
import socket
|
||||
import struct
|
||||
import time
|
||||
|
||||
from typing import Any
|
||||
from typing import Any, Optional
|
||||
|
||||
import aioquic.quic.configuration # type: ignore
|
||||
import aioquic.quic.connection # type: ignore
|
||||
import dns.inet
|
||||
|
||||
import dns.inet
|
||||
|
||||
QUIC_MAX_DATAGRAM = 2048
|
||||
|
||||
|
@ -135,12 +134,12 @@ class BaseQuicConnection:
|
|||
|
||||
|
||||
class AsyncQuicConnection(BaseQuicConnection):
|
||||
async def make_stream(self) -> Any:
|
||||
async def make_stream(self, timeout: Optional[float] = None) -> Any:
|
||||
pass
|
||||
|
||||
|
||||
class BaseQuicManager:
|
||||
def __init__(self, conf, verify_mode, connection_factory):
|
||||
def __init__(self, conf, verify_mode, connection_factory, server_name=None):
|
||||
self._connections = {}
|
||||
self._connection_factory = connection_factory
|
||||
if conf is None:
|
||||
|
@ -151,6 +150,7 @@ class BaseQuicManager:
|
|||
conf = aioquic.quic.configuration.QuicConfiguration(
|
||||
alpn_protocols=["doq", "doq-i03"],
|
||||
verify_mode=verify_mode,
|
||||
server_name=server_name,
|
||||
)
|
||||
if verify_path is not None:
|
||||
conf.load_verify_locations(verify_path)
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
|
||||
|
||||
import selectors
|
||||
import socket
|
||||
import ssl
|
||||
import selectors
|
||||
import struct
|
||||
import threading
|
||||
import time
|
||||
|
@ -10,13 +10,15 @@ import time
|
|||
import aioquic.quic.configuration # type: ignore
|
||||
import aioquic.quic.connection # type: ignore
|
||||
import aioquic.quic.events # type: ignore
|
||||
import dns.inet
|
||||
|
||||
import dns.exception
|
||||
import dns.inet
|
||||
from dns.quic._common import (
|
||||
BaseQuicStream,
|
||||
QUIC_MAX_DATAGRAM,
|
||||
BaseQuicConnection,
|
||||
BaseQuicManager,
|
||||
QUIC_MAX_DATAGRAM,
|
||||
BaseQuicStream,
|
||||
UnexpectedEOF,
|
||||
)
|
||||
|
||||
# Avoid circularity with dns.query
|
||||
|
@ -33,14 +35,15 @@ class SyncQuicStream(BaseQuicStream):
|
|||
self._lock = threading.Lock()
|
||||
|
||||
def wait_for(self, amount, expiration):
|
||||
timeout = self._timeout_from_expiration(expiration)
|
||||
while True:
|
||||
timeout = self._timeout_from_expiration(expiration)
|
||||
with self._lock:
|
||||
if self._buffer.have(amount):
|
||||
return
|
||||
self._expecting = amount
|
||||
with self._wake_up:
|
||||
self._wake_up.wait(timeout)
|
||||
if not self._wake_up.wait(timeout):
|
||||
raise dns.exception.Timeout
|
||||
self._expecting = 0
|
||||
|
||||
def receive(self, timeout=None):
|
||||
|
@ -114,24 +117,30 @@ class SyncQuicConnection(BaseQuicConnection):
|
|||
return
|
||||
|
||||
def _worker(self):
|
||||
try:
|
||||
sel = _selector_class()
|
||||
sel.register(self._socket, selectors.EVENT_READ, self._read)
|
||||
sel.register(self._receive_wakeup, selectors.EVENT_READ, self._drain_wakeup)
|
||||
while not self._done:
|
||||
(expiration, interval) = self._get_timer_values(False)
|
||||
items = sel.select(interval)
|
||||
for (key, _) in items:
|
||||
for key, _ in items:
|
||||
key.data()
|
||||
with self._lock:
|
||||
self._handle_timer(expiration)
|
||||
datagrams = self._connection.datagrams_to_send(time.time())
|
||||
for (datagram, _) in datagrams:
|
||||
for datagram, _ in datagrams:
|
||||
try:
|
||||
self._socket.send(datagram)
|
||||
except BlockingIOError:
|
||||
# we let QUIC handle any lossage
|
||||
pass
|
||||
self._handle_events()
|
||||
finally:
|
||||
with self._lock:
|
||||
self._done = True
|
||||
# Ensure anyone waiting for this gets woken up.
|
||||
self._handshake_complete.set()
|
||||
|
||||
def _handle_events(self):
|
||||
while True:
|
||||
|
@ -163,9 +172,12 @@ class SyncQuicConnection(BaseQuicConnection):
|
|||
self._worker_thread = threading.Thread(target=self._worker)
|
||||
self._worker_thread.start()
|
||||
|
||||
def make_stream(self):
|
||||
self._handshake_complete.wait()
|
||||
def make_stream(self, timeout=None):
|
||||
if not self._handshake_complete.wait(timeout):
|
||||
raise dns.exception.Timeout
|
||||
with self._lock:
|
||||
if self._done:
|
||||
raise UnexpectedEOF
|
||||
stream_id = self._connection.get_next_available_stream_id(False)
|
||||
stream = SyncQuicStream(self, stream_id)
|
||||
self._streams[stream_id] = stream
|
||||
|
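The change above turns an unbounded Event wait into a timed one. The pattern, reduced to a standalone sketch (names are illustrative):

import threading

import dns.exception

handshake_complete = threading.Event()  # stand-in for self._handshake_complete

def wait_for_handshake(timeout=None):
    # Event.wait() returns False when the timeout elapses; that result is
    # converted into dns.exception.Timeout, as make_stream() now does.
    if not handshake_complete.wait(timeout):
        raise dns.exception.Timeout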
@ -187,8 +199,8 @@ class SyncQuicConnection(BaseQuicConnection):
|
|||
|
||||
|
||||
class SyncQuicManager(BaseQuicManager):
|
||||
def __init__(self, conf=None, verify_mode=ssl.CERT_REQUIRED):
|
||||
super().__init__(conf, verify_mode, SyncQuicConnection)
|
||||
def __init__(self, conf=None, verify_mode=ssl.CERT_REQUIRED, server_name=None):
|
||||
super().__init__(conf, verify_mode, SyncQuicConnection, server_name)
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def connect(self, address, port=853, source=None, source_port=0):
|
||||
|
@ -206,7 +218,7 @@ class SyncQuicManager(BaseQuicManager):
|
|||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
# Copy the itertor into a list as exiting things will mutate the connections
|
||||
# Copy the iterator into a list as exiting things will mutate the connections
|
||||
# table.
|
||||
connections = list(self._connections.values())
|
||||
for connection in connections:
|
||||
|
|
|
@ -10,13 +10,15 @@ import aioquic.quic.connection # type: ignore
|
|||
import aioquic.quic.events # type: ignore
|
||||
import trio
|
||||
|
||||
import dns.exception
|
||||
import dns.inet
|
||||
from dns._asyncbackend import NullContext
|
||||
from dns.quic._common import (
|
||||
BaseQuicStream,
|
||||
QUIC_MAX_DATAGRAM,
|
||||
AsyncQuicConnection,
|
||||
AsyncQuicManager,
|
||||
QUIC_MAX_DATAGRAM,
|
||||
BaseQuicStream,
|
||||
UnexpectedEOF,
|
||||
)
|
||||
|
||||
|
||||
|
@ -44,6 +46,7 @@ class TrioQuicStream(BaseQuicStream):
|
|||
(size,) = struct.unpack("!H", self._buffer.get(2))
|
||||
await self.wait_for(size)
|
||||
return self._buffer.get(size)
|
||||
raise dns.exception.Timeout
|
||||
|
||||
async def send(self, datagram, is_end=False):
|
||||
data = self._encapsulate(datagram)
|
||||
|
@ -80,6 +83,7 @@ class TrioQuicConnection(AsyncQuicConnection):
|
|||
self._worker_scope = None
|
||||
|
||||
async def _worker(self):
|
||||
try:
|
||||
await self._socket.connect(self._peer)
|
||||
while not self._done:
|
||||
(expiration, interval) = self._get_timer_values(False)
|
||||
|
@ -87,13 +91,18 @@ class TrioQuicConnection(AsyncQuicConnection):
|
|||
deadline=trio.current_time() + interval
|
||||
) as self._worker_scope:
|
||||
datagram = await self._socket.recv(QUIC_MAX_DATAGRAM)
|
||||
self._connection.receive_datagram(datagram, self._peer[0], time.time())
|
||||
self._connection.receive_datagram(
|
||||
datagram, self._peer[0], time.time()
|
||||
)
|
||||
self._worker_scope = None
|
||||
self._handle_timer(expiration)
|
||||
datagrams = self._connection.datagrams_to_send(time.time())
|
||||
for (datagram, _) in datagrams:
|
||||
for datagram, _ in datagrams:
|
||||
await self._socket.send(datagram)
|
||||
await self._handle_events()
|
||||
finally:
|
||||
self._done = True
|
||||
self._handshake_complete.set()
|
||||
|
||||
async def _handle_events(self):
|
||||
count = 0
|
||||
|
@ -130,12 +139,20 @@ class TrioQuicConnection(AsyncQuicConnection):
|
|||
nursery.start_soon(self._worker)
|
||||
self._run_done.set()
|
||||
|
||||
async def make_stream(self):
|
||||
async def make_stream(self, timeout=None):
|
||||
if timeout is None:
|
||||
context = NullContext(None)
|
||||
else:
|
||||
context = trio.move_on_after(timeout)
|
||||
with context:
|
||||
await self._handshake_complete.wait()
|
||||
if self._done:
|
||||
raise UnexpectedEOF
|
||||
stream_id = self._connection.get_next_available_stream_id(False)
|
||||
stream = TrioQuicStream(self, stream_id)
|
||||
self._streams[stream_id] = stream
|
||||
return stream
|
||||
raise dns.exception.Timeout
|
||||
|
||||
async def close(self):
|
||||
if not self._closed:
|
||||
|
@ -148,8 +165,10 @@ class TrioQuicConnection(AsyncQuicConnection):
|
|||
|
||||
|
||||
class TrioQuicManager(AsyncQuicManager):
|
||||
def __init__(self, nursery, conf=None, verify_mode=ssl.CERT_REQUIRED):
|
||||
super().__init__(conf, verify_mode, TrioQuicConnection)
|
||||
def __init__(
|
||||
self, nursery, conf=None, verify_mode=ssl.CERT_REQUIRED, server_name=None
|
||||
):
|
||||
super().__init__(conf, verify_mode, TrioQuicConnection, server_name)
|
||||
self._nursery = nursery
|
||||
|
||||
def connect(self, address, port=853, source=None, source_port=0):
|
||||
|
@ -162,7 +181,7 @@ class TrioQuicManager(AsyncQuicManager):
|
|||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
# Copy the itertor into a list as exiting things will mutate the connections
|
||||
# Copy the iterator into a list as exiting things will mutate the connections
|
||||
# table.
|
||||
connections = list(self._connections.values())
|
||||
for connection in connections:
|
||||
|
|
|
@ -17,17 +17,15 @@
|
|||
|
||||
"""DNS rdata."""
|
||||
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
from importlib import import_module
|
||||
import base64
|
||||
import binascii
|
||||
import io
|
||||
import inspect
|
||||
import io
|
||||
import itertools
|
||||
import random
|
||||
from importlib import import_module
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import dns.wire
|
||||
import dns.exception
|
||||
import dns.immutable
|
||||
import dns.ipv4
|
||||
|
@ -37,6 +35,7 @@ import dns.rdataclass
|
|||
import dns.rdatatype
|
||||
import dns.tokenizer
|
||||
import dns.ttl
|
||||
import dns.wire
|
||||
|
||||
_chunksize = 32
|
||||
|
||||
|
@ -358,7 +357,6 @@ class Rdata:
|
|||
or self.rdclass != other.rdclass
|
||||
or self.rdtype != other.rdtype
|
||||
):
|
||||
|
||||
return NotImplemented
|
||||
return self._cmp(other) < 0
|
||||
|
||||
|
@ -881,16 +879,11 @@ def register_type(
|
|||
it applies to all classes.
|
||||
"""
|
||||
|
||||
the_rdtype = dns.rdatatype.RdataType.make(rdtype)
|
||||
existing_cls = get_rdata_class(rdclass, the_rdtype)
|
||||
if existing_cls != GenericRdata or dns.rdatatype.is_metatype(the_rdtype):
|
||||
raise RdatatypeExists(rdclass=rdclass, rdtype=the_rdtype)
|
||||
try:
|
||||
if dns.rdatatype.RdataType(the_rdtype).name != rdtype_text:
|
||||
raise RdatatypeExists(rdclass=rdclass, rdtype=the_rdtype)
|
||||
except ValueError:
|
||||
pass
|
||||
_rdata_classes[(rdclass, the_rdtype)] = getattr(
|
||||
rdtype = dns.rdatatype.RdataType.make(rdtype)
|
||||
existing_cls = get_rdata_class(rdclass, rdtype)
|
||||
if existing_cls != GenericRdata or dns.rdatatype.is_metatype(rdtype):
|
||||
raise RdatatypeExists(rdclass=rdclass, rdtype=rdtype)
|
||||
_rdata_classes[(rdclass, rdtype)] = getattr(
|
||||
implementation, rdtype_text.replace("-", "_")
|
||||
)
|
||||
dns.rdatatype.register_type(the_rdtype, rdtype_text, is_singleton)
|
||||
dns.rdatatype.register_type(rdtype, rdtype_text, is_singleton)
|
||||
|
|
|
@ -17,18 +17,17 @@
|
|||
|
||||
"""DNS rdatasets (an rdataset is a set of rdatas of a given type and class)"""
|
||||
|
||||
from typing import Any, cast, Collection, Dict, List, Optional, Union
|
||||
|
||||
import io
|
||||
import random
|
||||
import struct
|
||||
from typing import Any, Collection, Dict, List, Optional, Union, cast
|
||||
|
||||
import dns.exception
|
||||
import dns.immutable
|
||||
import dns.name
|
||||
import dns.rdatatype
|
||||
import dns.rdataclass
|
||||
import dns.rdata
|
||||
import dns.rdataclass
|
||||
import dns.rdatatype
|
||||
import dns.set
|
||||
import dns.ttl
|
||||
|
||||
|
@ -471,9 +470,9 @@ def from_text_list(
|
|||
Returns a ``dns.rdataset.Rdataset`` object.
|
||||
"""
|
||||
|
||||
the_rdclass = dns.rdataclass.RdataClass.make(rdclass)
|
||||
the_rdtype = dns.rdatatype.RdataType.make(rdtype)
|
||||
r = Rdataset(the_rdclass, the_rdtype)
|
||||
rdclass = dns.rdataclass.RdataClass.make(rdclass)
|
||||
rdtype = dns.rdatatype.RdataType.make(rdtype)
|
||||
r = Rdataset(rdclass, rdtype)
|
||||
r.update_ttl(ttl)
|
||||
for t in text_rdatas:
|
||||
rd = dns.rdata.from_text(
|
||||
|
|
|
@ -15,8 +15,8 @@
|
|||
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
||||
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
|
||||
import dns.rdtypes.mxbase
|
||||
import dns.immutable
|
||||
import dns.rdtypes.mxbase
|
||||
|
||||
|
||||
@dns.immutable.immutable
|
||||
|
|
|
@ -15,8 +15,8 @@
|
|||
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
||||
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
|
||||
import dns.rdtypes.txtbase
|
||||
import dns.immutable
|
||||
import dns.rdtypes.txtbase
|
||||
|
||||
|
||||
@dns.immutable.immutable
|
||||
|
|
|
@ -15,15 +15,15 @@
|
|||
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
||||
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
|
||||
import dns.rdtypes.dnskeybase # lgtm[py/import-and-import-from]
|
||||
import dns.immutable
|
||||
import dns.rdtypes.dnskeybase # lgtm[py/import-and-import-from]
|
||||
|
||||
# pylint: disable=unused-import
|
||||
from dns.rdtypes.dnskeybase import (
|
||||
SEP,
|
||||
from dns.rdtypes.dnskeybase import ( # noqa: F401 lgtm[py/unused-import]
|
||||
REVOKE,
|
||||
SEP,
|
||||
ZONE,
|
||||
) # noqa: F401 lgtm[py/unused-import]
|
||||
)
|
||||
|
||||
# pylint: enable=unused-import
|
||||
|
||||
|
|
|
@ -15,8 +15,8 @@
|
|||
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
||||
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
|
||||
import dns.rdtypes.dsbase
|
||||
import dns.immutable
|
||||
import dns.rdtypes.dsbase
|
||||
|
||||
|
||||
@dns.immutable.immutable
|
||||
|
|
|
@ -15,12 +15,12 @@
|
|||
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
||||
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
|
||||
import struct
|
||||
import base64
|
||||
import struct
|
||||
|
||||
import dns.dnssectypes
|
||||
import dns.exception
|
||||
import dns.immutable
|
||||
import dns.dnssectypes
|
||||
import dns.rdata
|
||||
import dns.tokenizer
|
||||
|
||||
|
|
|
@ -15,8 +15,8 @@
|
|||
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
||||
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
|
||||
import dns.rdtypes.nsbase
|
||||
import dns.immutable
|
||||
import dns.rdtypes.nsbase
|
||||
|
||||
|
||||
@dns.immutable.immutable
|
||||
|
|
|
@ -19,9 +19,9 @@ import struct
|
|||
|
||||
import dns.exception
|
||||
import dns.immutable
|
||||
import dns.name
|
||||
import dns.rdata
|
||||
import dns.rdatatype
|
||||
import dns.name
|
||||
import dns.rdtypes.util
|
||||
|
||||
|
||||
|
|
|
@ -15,8 +15,8 @@
|
|||
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
||||
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
|
||||
import dns.rdtypes.dsbase
|
||||
import dns.immutable
|
||||
import dns.rdtypes.dsbase
|
||||
|
||||
|
||||
@dns.immutable.immutable
|
||||
|
|
|
@ -15,8 +15,8 @@
|
|||
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
||||
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
|
||||
import dns.rdtypes.nsbase
|
||||
import dns.immutable
|
||||
import dns.rdtypes.nsbase
|
||||
|
||||
|
||||
@dns.immutable.immutable
|
||||
|
|
|
@ -15,15 +15,15 @@
|
|||
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
||||
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
|
||||
import dns.rdtypes.dnskeybase # lgtm[py/import-and-import-from]
|
||||
import dns.immutable
|
||||
import dns.rdtypes.dnskeybase # lgtm[py/import-and-import-from]
|
||||
|
||||
# pylint: disable=unused-import
|
||||
from dns.rdtypes.dnskeybase import (
|
||||
SEP,
|
||||
from dns.rdtypes.dnskeybase import ( # noqa: F401 lgtm[py/unused-import]
|
||||
REVOKE,
|
||||
SEP,
|
||||
ZONE,
|
||||
) # noqa: F401 lgtm[py/unused-import]
|
||||
)
|
||||
|
||||
# pylint: enable=unused-import
|
||||
|
||||
|
|
|
@ -15,8 +15,8 @@
|
|||
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
||||
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
|
||||
import dns.rdtypes.dsbase
|
||||
import dns.immutable
|
||||
import dns.rdtypes.dsbase
|
||||
|
||||
|
||||
@dns.immutable.immutable
|
||||
|
|
|
@ -16,8 +16,8 @@
|
|||
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
||||
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
|
||||
import dns.rdtypes.euibase
|
||||
import dns.immutable
|
||||
import dns.rdtypes.euibase
|
||||
|
||||
|
||||
@dns.immutable.immutable
|
||||
|
|
|
@ -16,8 +16,8 @@
|
|||
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
||||
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
|
||||
import dns.rdtypes.euibase
|
||||
import dns.immutable
|
||||
import dns.rdtypes.euibase
|
||||
|
||||
|
||||
@dns.immutable.immutable
|
||||
|
|
|
@ -15,9 +15,9 @@
|
|||
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
||||
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
|
||||
import struct
|
||||
import base64
|
||||
import binascii
|
||||
import struct
|
||||
|
||||
import dns.exception
|
||||
import dns.immutable
|
||||
|
|
|
@ -21,7 +21,6 @@ import dns.exception
|
|||
import dns.immutable
|
||||
import dns.rdata
|
||||
|
||||
|
||||
_pows = tuple(10**i for i in range(0, 11))
|
||||
|
||||
# default values are in centimeters
|
||||
|
@ -40,7 +39,7 @@ def _exponent_of(what, desc):
|
|||
if what == 0:
|
||||
return 0
|
||||
exp = None
|
||||
for (i, pow) in enumerate(_pows):
|
||||
for i, pow in enumerate(_pows):
|
||||
if what < pow:
|
||||
exp = i - 1
|
||||
break
|
||||
|
|
|
@ -15,8 +15,8 @@
|
|||
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
||||
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
|
||||
import dns.rdtypes.mxbase
|
||||
import dns.immutable
|
||||
import dns.rdtypes.mxbase
|
||||
|
||||
|
||||
@dns.immutable.immutable
|
||||
|
|
|
@ -15,8 +15,8 @@
|
|||
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
||||
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
|
||||
import dns.rdtypes.txtbase
|
||||
import dns.immutable
|
||||
import dns.rdtypes.txtbase
|
||||
|
||||
|
||||
@dns.immutable.immutable
|
||||
|
|
|
@ -15,8 +15,8 @@
|
|||
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
||||
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
|
||||
import dns.rdtypes.nsbase
|
||||
import dns.immutable
|
||||
import dns.rdtypes.nsbase
|
||||
|
||||
|
||||
@dns.immutable.immutable
|
||||
|
|
|
@ -17,9 +17,9 @@
|
|||
|
||||
import dns.exception
|
||||
import dns.immutable
|
||||
import dns.name
|
||||
import dns.rdata
|
||||
import dns.rdatatype
|
||||
import dns.name
|
||||
import dns.rdtypes.util
|
||||
|
||||
|
||||
|
|
|
@ -25,7 +25,6 @@ import dns.rdata
|
|||
import dns.rdatatype
|
||||
import dns.rdtypes.util
|
||||
|
||||
|
||||
b32_hex_to_normal = bytes.maketrans(
|
||||
b"0123456789ABCDEFGHIJKLMNOPQRSTUV", b"ABCDEFGHIJKLMNOPQRSTUVWXYZ234567"
|
||||
)
|
||||
|
@ -67,6 +66,7 @@ class NSEC3(dns.rdata.Rdata):
|
|||
|
||||
def to_text(self, origin=None, relativize=True, **kw):
|
||||
next = base64.b32encode(self.next).translate(b32_normal_to_hex).lower().decode()
|
||||
next = next.rstrip("=")
|
||||
if self.salt == b"":
|
||||
salt = "-"
|
||||
else:
|
||||
|
@ -94,6 +94,10 @@ class NSEC3(dns.rdata.Rdata):
|
|||
else:
|
||||
salt = binascii.unhexlify(salt.encode("ascii"))
|
||||
next = tok.get_string().encode("ascii").upper().translate(b32_hex_to_normal)
|
||||
if next.endswith(b"="):
|
||||
raise binascii.Error("Incorrect padding")
|
||||
if len(next) % 8 != 0:
|
||||
next += b"=" * (8 - len(next) % 8)
|
||||
next = base64.b32decode(next)
|
||||
bitmap = Bitmap.from_text(tok)
|
||||
return cls(rdclass, rdtype, algorithm, flags, iterations, salt, next, bitmap)
|
||||
|
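The added lines restore the base32hex padding that zone-file presentation omits before decoding the NSEC3 next-hash field. The same logic as a standalone sketch (the sample value is just a 32-character base32hex string used for illustration):

import base64

b32_hex_to_normal = bytes.maketrans(
    b"0123456789ABCDEFGHIJKLMNOPQRSTUV", b"ABCDEFGHIJKLMNOPQRSTUVWXYZ234567"
)

def decode_next_hash(text):
    data = text.encode("ascii").upper().translate(b32_hex_to_normal)
    if data.endswith(b"="):
        # Presentation form must not carry its own padding.
        raise ValueError("Incorrect padding")
    if len(data) % 8 != 0:
        data += b"=" * (8 - len(data) % 8)
    return base64.b32decode(data)

print(len(decode_next_hash("0p9mhaveqvm6t7vbl5lop2u3t2rp3tom")))  # 20-byte SHA-1 hash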
|
|
@ -15,8 +15,8 @@
|
|||
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
||||
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
|
||||
import struct
|
||||
import binascii
|
||||
import struct
|
||||
|
||||
import dns.exception
|
||||
import dns.immutable
|
||||
|
|
|
@ -18,11 +18,10 @@
|
|||
import struct
|
||||
|
||||
import dns.edns
|
||||
import dns.immutable
|
||||
import dns.exception
|
||||
import dns.immutable
|
||||
import dns.rdata
|
||||
|
||||
|
||||
# We don't implement from_text, and that's ok.
|
||||
# pylint: disable=abstract-method
|
||||
|
||||
|
|
|
@ -15,8 +15,8 @@
|
|||
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
||||
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
|
||||
import dns.rdtypes.nsbase
|
||||
import dns.immutable
|
||||
import dns.rdtypes.nsbase
|
||||
|
||||
|
||||
@dns.immutable.immutable
|
||||
|
|
|
@ -17,8 +17,8 @@
|
|||
|
||||
import dns.exception
|
||||
import dns.immutable
|
||||
import dns.rdata
|
||||
import dns.name
|
||||
import dns.rdata
|
||||
|
||||
|
||||
@dns.immutable.immutable
|
||||
|
|
|
@ -21,8 +21,8 @@ import struct
|
|||
import time
|
||||
|
||||
import dns.dnssectypes
|
||||
import dns.immutable
|
||||
import dns.exception
|
||||
import dns.immutable
|
||||
import dns.rdata
|
||||
import dns.rdatatype
|
||||
|
||||
|
|
|
@ -15,8 +15,8 @@
|
|||
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
||||
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
|
||||
import dns.rdtypes.mxbase
|
||||
import dns.immutable
|
||||
import dns.rdtypes.mxbase
|
||||
|
||||
|
||||
@dns.immutable.immutable
|
||||
|
|
|
@ -19,8 +19,8 @@ import struct
|
|||
|
||||
import dns.exception
|
||||
import dns.immutable
|
||||
import dns.rdata
|
||||
import dns.name
|
||||
import dns.rdata
|
||||
|
||||
|
||||
@dns.immutable.immutable
|
||||
|
|
|
@ -15,8 +15,8 @@
|
|||
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
||||
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
|
||||
import dns.rdtypes.txtbase
|
||||
import dns.immutable
|
||||
import dns.rdtypes.txtbase
|
||||
|
||||
|
||||
@dns.immutable.immutable
|
||||
|
|
|
@ -15,11 +15,11 @@
|
|||
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
||||
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
|
||||
import struct
|
||||
import binascii
|
||||
import struct
|
||||
|
||||
import dns.rdata
|
||||
import dns.immutable
|
||||
import dns.rdata
|
||||
import dns.rdatatype
|
||||
|
||||
|
||||
|
|
|
@ -18,8 +18,8 @@
|
|||
import base64
|
||||
import struct
|
||||
|
||||
import dns.immutable
|
||||
import dns.exception
|
||||
import dns.immutable
|
||||
import dns.rdata
|
||||
|
||||
|
||||
|
|
|
@ -15,8 +15,8 @@
|
|||
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
||||
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
|
||||
import dns.rdtypes.txtbase
|
||||
import dns.immutable
|
||||
import dns.rdtypes.txtbase
|
||||
|
||||
|
||||
@dns.immutable.immutable
|
||||
|
|
|
@ -20,9 +20,9 @@ import struct
|
|||
|
||||
import dns.exception
|
||||
import dns.immutable
|
||||
import dns.name
|
||||
import dns.rdata
|
||||
import dns.rdtypes.util
|
||||
import dns.name
|
||||
|
||||
|
||||
@dns.immutable.immutable
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
|
||||
|
||||
import struct
|
||||
import binascii
|
||||
import struct
|
||||
|
||||
import dns.immutable
|
||||
import dns.rdata
|
||||
|
|
|
@ -17,8 +17,8 @@
|
|||
|
||||
import struct
|
||||
|
||||
import dns.rdtypes.mxbase
|
||||
import dns.immutable
|
||||
import dns.rdtypes.mxbase
|
||||
|
||||
|
||||
@dns.immutable.immutable
|
||||
|
|
|
@ -124,7 +124,6 @@ class APL(dns.rdata.Rdata):
|
|||
|
||||
@classmethod
|
||||
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
|
||||
|
||||
items = []
|
||||
while parser.remaining() > 0:
|
||||
header = parser.get_struct("!HBB")
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
|
||||
|
||||
import dns.rdtypes.svcbbase
|
||||
import dns.immutable
|
||||
import dns.rdtypes.svcbbase
|
||||
|
||||
|
||||
@dns.immutable.immutable
|
||||
|
|
|
@ -15,8 +15,8 @@
|
|||
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
||||
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
|
||||
import struct
|
||||
import base64
|
||||
import struct
|
||||
|
||||
import dns.exception
|
||||
import dns.immutable
|
||||
|
|
|
@ -15,8 +15,8 @@
|
|||
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
||||
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
|
||||
import dns.rdtypes.mxbase
|
||||
import dns.immutable
|
||||
import dns.rdtypes.mxbase
|
||||
|
||||
|
||||
@dns.immutable.immutable
|
||||
|
|
|
@ -15,8 +15,8 @@
|
|||
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
||||
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
|
||||
import dns.rdtypes.nsbase
|
||||
import dns.immutable
|
||||
import dns.rdtypes.nsbase
|
||||
|
||||
|
||||
@dns.immutable.immutable
|
||||
|
|
|
@ -19,9 +19,9 @@ import struct
|
|||
|
||||
import dns.exception
|
||||
import dns.immutable
|
||||
import dns.name
|
||||
import dns.rdata
|
||||
import dns.rdtypes.util
|
||||
import dns.name
|
||||
|
||||
|
||||
@dns.immutable.immutable
|
||||
|
|
|
@ -19,9 +19,9 @@ import struct
|
|||
|
||||
import dns.exception
|
||||
import dns.immutable
|
||||
import dns.name
|
||||
import dns.rdata
|
||||
import dns.rdtypes.util
|
||||
import dns.name
|
||||
|
||||
|
||||
@dns.immutable.immutable
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
|
||||
|
||||
import dns.rdtypes.svcbbase
|
||||
import dns.immutable
|
||||
import dns.rdtypes.svcbbase
|
||||
|
||||
|
||||
@dns.immutable.immutable
|
||||
|
|
|
@ -18,8 +18,8 @@
|
|||
import socket
|
||||
import struct
|
||||
|
||||
import dns.ipv4
|
||||
import dns.immutable
|
||||
import dns.ipv4
|
||||
import dns.rdata
|
||||
|
||||
try:
|
||||
|
|
|
@ -19,9 +19,9 @@ import base64
|
|||
import enum
|
||||
import struct
|
||||
|
||||
import dns.dnssectypes
|
||||
import dns.exception
|
||||
import dns.immutable
|
||||
import dns.dnssectypes
|
||||
import dns.rdata
|
||||
|
||||
# wildcard import
|
||||
|
@ -43,7 +43,7 @@ class DNSKEYBase(dns.rdata.Rdata):
|
|||
|
||||
def __init__(self, rdclass, rdtype, flags, protocol, algorithm, key):
|
||||
super().__init__(rdclass, rdtype)
|
||||
self.flags = self._as_uint16(flags)
|
||||
self.flags = Flag(self._as_uint16(flags))
|
||||
self.protocol = self._as_uint8(protocol)
|
||||
self.algorithm = dns.dnssectypes.Algorithm.make(algorithm)
|
||||
self.key = self._as_bytes(key)
|
||||
|
|
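The constructor now stores DNSKEY flags as a Flag value rather than a bare int, so flag tests read naturally. A small sketch (assumes the Flag enum defined in dns.rdtypes.dnskeybase):

import dns.rdtypes.dnskeybase as dnskeybase

flags = dnskeybase.Flag(257)  # ZONE (256) | SEP (1), i.e. a typical KSK
print(dnskeybase.Flag.SEP in flags, dnskeybase.Flag.ZONE in flags)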
|
@ -15,8 +15,8 @@
|
|||
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
||||
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
|
||||
import struct
|
||||
import binascii
|
||||
import struct
|
||||
|
||||
import dns.dnssectypes
|
||||
import dns.immutable
|
||||
|
@ -44,7 +44,7 @@ class DSBase(dns.rdata.Rdata):
|
|||
super().__init__(rdclass, rdtype)
|
||||
self.key_tag = self._as_uint16(key_tag)
|
||||
self.algorithm = dns.dnssectypes.Algorithm.make(algorithm)
|
||||
self.digest_type = self._as_uint8(digest_type)
|
||||
self.digest_type = dns.dnssectypes.DSDigest.make(self._as_uint8(digest_type))
|
||||
self.digest = self._as_bytes(digest)
|
||||
try:
|
||||
if len(self.digest) != self._digest_length_by_type[self.digest_type]:
|
||||
|
|
|
@ -16,8 +16,8 @@
|
|||
|
||||
import binascii
|
||||
|
||||
import dns.rdata
|
||||
import dns.immutable
|
||||
import dns.rdata
|
||||
|
||||
|
||||
@dns.immutable.immutable
|
||||
|
|
|
@ -21,8 +21,8 @@ import struct
|
|||
|
||||
import dns.exception
|
||||
import dns.immutable
|
||||
import dns.rdata
|
||||
import dns.name
|
||||
import dns.rdata
|
||||
import dns.rdtypes.util
|
||||
|
||||
|
||||
|
|
|
@ -19,8 +19,8 @@
|
|||
|
||||
import dns.exception
|
||||
import dns.immutable
|
||||
import dns.rdata
|
||||
import dns.name
|
||||
import dns.rdata
|
||||
|
||||
|
||||
@dns.immutable.immutable
|
||||
|
|
|
@ -34,6 +34,7 @@ class ParamKey(dns.enum.IntEnum):
|
|||
IPV4HINT = 4
|
||||
ECH = 5
|
||||
IPV6HINT = 6
|
||||
DOHPATH = 7
|
||||
|
||||
@classmethod
|
||||
def _maximum(cls):
|
||||
|
|
|
@ -15,11 +15,11 @@
|
|||
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
||||
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
|
||||
import struct
|
||||
import binascii
|
||||
import struct
|
||||
|
||||
import dns.rdata
|
||||
import dns.immutable
|
||||
import dns.rdata
|
||||
import dns.rdatatype
|
||||
|
||||
|
||||
|
|
|
@ -17,9 +17,8 @@
|
|||
|
||||
"""TXT-like base class."""
|
||||
|
||||
from typing import Any, Dict, Iterable, Optional, Tuple, Union
|
||||
|
||||
import struct
|
||||
from typing import Any, Dict, Iterable, Optional, Tuple, Union
|
||||
|
||||
import dns.exception
|
||||
import dns.immutable
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
import collections
|
||||
import random
|
||||
import struct
|
||||
from typing import Any, List
|
||||
|
||||
import dns.exception
|
||||
import dns.ipv4
|
||||
|
@ -119,7 +120,7 @@ class Bitmap:
|
|||
def __init__(self, windows=None):
|
||||
last_window = -1
|
||||
self.windows = windows
|
||||
for (window, bitmap) in self.windows:
|
||||
for window, bitmap in self.windows:
|
||||
if not isinstance(window, int):
|
||||
raise ValueError(f"bad {self.type_name} window type")
|
||||
if window <= last_window:
|
||||
|
@ -132,11 +133,11 @@ class Bitmap:
|
|||
if len(bitmap) == 0 or len(bitmap) > 32:
|
||||
raise ValueError(f"bad {self.type_name} octets")
|
||||
|
||||
def to_text(self):
|
||||
def to_text(self) -> str:
|
||||
text = ""
|
||||
for (window, bitmap) in self.windows:
|
||||
for window, bitmap in self.windows:
|
||||
bits = []
|
||||
for (i, byte) in enumerate(bitmap):
|
||||
for i, byte in enumerate(bitmap):
|
||||
for j in range(0, 8):
|
||||
if byte & (0x80 >> j):
|
||||
rdtype = window * 256 + i * 8 + j
|
||||
|
@ -145,14 +146,18 @@ class Bitmap:
|
|||
return text
|
||||
|
||||
@classmethod
|
||||
def from_text(cls, tok):
|
||||
def from_text(cls, tok: "dns.tokenizer.Tokenizer") -> "Bitmap":
|
||||
rdtypes = []
|
||||
for token in tok.get_remaining():
|
||||
rdtype = dns.rdatatype.from_text(token.unescape().value)
|
||||
if rdtype == 0:
|
||||
raise dns.exception.SyntaxError(f"{cls.type_name} with bit 0")
|
||||
rdtypes.append(rdtype)
|
||||
rdtypes.sort()
|
||||
return cls.from_rdtypes(rdtypes)
|
||||
|
||||
@classmethod
|
||||
def from_rdtypes(cls, rdtypes: List[dns.rdatatype.RdataType]) -> "Bitmap":
|
||||
rdtypes = sorted(rdtypes)
|
||||
window = 0
|
||||
octets = 0
|
||||
prior_rdtype = 0
|
||||
|
@ -177,13 +182,13 @@ class Bitmap:
|
|||
windows.append((window, bytes(bitmap[0:octets])))
|
||||
return cls(windows)
|
||||
|
||||
def to_wire(self, file):
|
||||
for (window, bitmap) in self.windows:
|
||||
def to_wire(self, file: Any) -> None:
|
||||
for window, bitmap in self.windows:
|
||||
file.write(struct.pack("!BB", window, len(bitmap)))
|
||||
file.write(bitmap)
|
||||
|
||||
@classmethod
|
||||
def from_wire_parser(cls, parser):
|
||||
def from_wire_parser(cls, parser: "dns.wire.Parser") -> "Bitmap":
|
||||
windows = []
|
||||
while parser.remaining() > 0:
|
||||
window = parser.get_uint8()
|
||||
|
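These Bitmap hunks mostly add type annotations, but the class is the helper behind NSEC/NSEC3/CSYNC type bitmaps. A minimal sketch of the round trip the annotated methods describe; the rdata types chosen are arbitrary.

import dns.rdatatype
import dns.rdtypes.util

# Pack a list of rdata types into (window, bitmap) tuples, then render them
# the way NSEC-style records print their type maps.
bitmap = dns.rdtypes.util.Bitmap.from_rdtypes(
    [dns.rdatatype.MX, dns.rdatatype.A, dns.rdatatype.AAAA]
)
print(bitmap.windows)    # one (window, bitmap) tuple; window 0 covers types 0-255
print(bitmap.to_text())  # " A MX AAAA" (leading space per window, ordered by type code)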
@ -226,7 +231,7 @@ def weighted_processing_order(iterable):
|
|||
total = sum(rdata._processing_weight() or _no_weight for rdata in rdatas)
|
||||
while len(rdatas) > 1:
|
||||
r = random.uniform(0, total)
|
||||
for (n, rdata) in enumerate(rdatas):
|
||||
for n, rdata in enumerate(rdatas):
|
||||
weight = rdata._processing_weight() or _no_weight
|
||||
if weight > r:
|
||||
break
|
||||
|
|
|
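weighted_processing_order() is the selection helper used for SRV-style records: each pass draws a uniform random number against the remaining total weight, so heavier records tend to be yielded earlier. A small sketch using SRV rdatas, which provide the _processing_weight() hook the loop relies on; the names and weights are illustrative.

import dns.rrset
import dns.rdtypes.util

# Two SRV targets at the same priority with weights 10 and 90; over many runs
# b.example.com. comes out first roughly 90% of the time.
rrset = dns.rrset.from_text(
    "_sip._tcp.example.com.", 300, "IN", "SRV",
    "0 10 5060 a.example.com.",
    "0 90 5060 b.example.com.",
)
for srv in dns.rdtypes.util.weighted_processing_order(list(rrset)):
    print(srv.weight, srv.target)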
@ -19,14 +19,13 @@
|
|||
|
||||
import contextlib
|
||||
import io
|
||||
import struct
|
||||
import random
|
||||
import struct
|
||||
import time
|
||||
|
||||
import dns.exception
|
||||
import dns.tsig
|
||||
|
||||
|
||||
QUESTION = 0
|
||||
ANSWER = 1
|
||||
AUTHORITY = 2
|
||||
|
|
|
@ -17,29 +17,31 @@
|
|||
|
||||
"""DNS stub resolver."""
|
||||
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from urllib.parse import urlparse
|
||||
import contextlib
|
||||
import random
|
||||
import socket
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import random
|
||||
import warnings
|
||||
from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import dns.exception
|
||||
import dns._ddr
|
||||
import dns.edns
|
||||
import dns.exception
|
||||
import dns.flags
|
||||
import dns.inet
|
||||
import dns.ipv4
|
||||
import dns.ipv6
|
||||
import dns.message
|
||||
import dns.name
|
||||
import dns.nameserver
|
||||
import dns.query
|
||||
import dns.rcode
|
||||
import dns.rdataclass
|
||||
import dns.rdatatype
|
||||
import dns.rdtypes.svcbbase
|
||||
import dns.reversename
|
||||
import dns.tsig
|
||||
|
||||
|
@ -72,7 +74,7 @@ class NXDOMAIN(dns.exception.DNSException):
|
|||
kwargs = dict(qnames=qnames, responses=responses)
|
||||
return kwargs
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
if "qnames" not in self.kwargs:
|
||||
return super().__str__()
|
||||
qnames = self.kwargs["qnames"]
|
||||
|
@ -140,7 +142,11 @@ class YXDOMAIN(dns.exception.DNSException):
|
|||
|
||||
|
||||
ErrorTuple = Tuple[
|
||||
Optional[str], bool, int, Union[Exception, str], Optional[dns.message.Message]
|
||||
Optional[str],
|
||||
bool,
|
||||
int,
|
||||
Union[Exception, str],
|
||||
Optional[dns.message.Message],
|
||||
]
|
||||
|
||||
|
||||
|
@ -148,11 +154,7 @@ def _errors_to_text(errors: List[ErrorTuple]) -> List[str]:
|
|||
"""Turn a resolution errors trace into a list of text."""
|
||||
texts = []
|
||||
for err in errors:
|
||||
texts.append(
|
||||
"Server {} {} port {} answered {}".format(
|
||||
err[0], "TCP" if err[1] else "UDP", err[2], err[3]
|
||||
)
|
||||
)
|
||||
texts.append("Server {} answered {}".format(err[0], err[3]))
|
||||
return texts
|
||||
|
||||
|
||||
|
@ -184,7 +186,7 @@ Timeout = LifetimeTimeout
|
|||
class NoAnswer(dns.exception.DNSException):
|
||||
"""The DNS response does not contain an answer to the question."""
|
||||
|
||||
fmt = "The DNS response does not contain an answer " + "to the question: {query}"
|
||||
fmt = "The DNS response does not contain an answer to the question: {query}"
|
||||
supp_kwargs = {"response"}
|
||||
|
||||
# We do this as otherwise mypy complains about unexpected keyword argument
|
||||
|
@ -264,7 +266,7 @@ class Answer:
|
|||
response: dns.message.QueryMessage,
|
||||
nameserver: Optional[str] = None,
|
||||
port: Optional[int] = None,
|
||||
):
|
||||
) -> None:
|
||||
self.qname = qname
|
||||
self.rdtype = rdtype
|
||||
self.rdclass = rdclass
|
||||
|
@ -292,7 +294,7 @@ class Answer:
|
|||
else:
|
||||
raise AttributeError(attr)
|
||||
|
||||
def __len__(self):
|
||||
def __len__(self) -> int:
|
||||
return self.rrset and len(self.rrset) or 0
|
||||
|
||||
def __iter__(self):
|
||||
|
@ -309,14 +311,67 @@ class Answer:
|
|||
del self.rrset[i]
|
||||
|
||||
|
||||
class Answers(dict):
|
||||
"""A dict of DNS stub resolver answers, indexed by type."""
|
||||
|
||||
|
||||
class HostAnswers(Answers):
|
||||
"""A dict of DNS stub resolver answers to a host name lookup, indexed by
|
||||
type.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def make(
|
||||
cls,
|
||||
v6: Optional[Answer] = None,
|
||||
v4: Optional[Answer] = None,
|
||||
add_empty: bool = True,
|
||||
) -> "HostAnswers":
|
||||
answers = HostAnswers()
|
||||
if v6 is not None and (add_empty or v6.rrset):
|
||||
answers[dns.rdatatype.AAAA] = v6
|
||||
if v4 is not None and (add_empty or v4.rrset):
|
||||
answers[dns.rdatatype.A] = v4
|
||||
return answers
|
||||
|
||||
# Returns pairs of (address, family) from this result, potentially
|
||||
# filtering by address family.
|
||||
def addresses_and_families(
|
||||
self, family: int = socket.AF_UNSPEC
|
||||
) -> Iterator[Tuple[str, int]]:
|
||||
if family == socket.AF_UNSPEC:
|
||||
yield from self.addresses_and_families(socket.AF_INET6)
|
||||
yield from self.addresses_and_families(socket.AF_INET)
|
||||
return
|
||||
elif family == socket.AF_INET6:
|
||||
answer = self.get(dns.rdatatype.AAAA)
|
||||
elif family == socket.AF_INET:
|
||||
answer = self.get(dns.rdatatype.A)
|
||||
else:
|
||||
raise NotImplementedError(f"unknown address family {family}")
|
||||
if answer:
|
||||
for rdata in answer:
|
||||
yield (rdata.address, family)
|
||||
|
||||
# Returns addresses from this result, potentially filtering by
|
||||
# address family.
|
||||
def addresses(self, family: int = socket.AF_UNSPEC) -> Iterator[str]:
|
||||
return (pair[0] for pair in self.addresses_and_families(family))
|
||||
|
||||
# Returns the canonical name from this result.
|
||||
def canonical_name(self) -> dns.name.Name:
|
||||
answer = self.get(dns.rdatatype.AAAA, self.get(dns.rdatatype.A))
|
||||
return answer.canonical_name
|
||||
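HostAnswers is the container returned by the resolve_name() additions later in this diff: at most one Answer per address type, plus helpers that flatten it into plain addresses. A hedged sketch of consuming one; the host name is illustrative and the lookup needs network access.

import socket
import dns.resolver

# resolve_name() (added further down) returns a HostAnswers; its helpers
# flatten the A/AAAA answers, optionally filtered by address family.
answers = dns.resolver.Resolver().resolve_name("www.dnspython.org")
for address, family in answers.addresses_and_families():
    print("IPv6" if family == socket.AF_INET6 else "IPv4", address)
print(answers.canonical_name())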
|
||||
|
||||
class CacheStatistics:
|
||||
"""Cache Statistics"""
|
||||
|
||||
def __init__(self, hits=0, misses=0):
|
||||
def __init__(self, hits: int = 0, misses: int = 0) -> None:
|
||||
self.hits = hits
|
||||
self.misses = misses
|
||||
|
||||
def reset(self):
|
||||
def reset(self) -> None:
|
||||
self.hits = 0
|
||||
self.misses = 0
|
||||
|
||||
|
@ -325,7 +380,7 @@ class CacheStatistics:
|
|||
|
||||
|
||||
class CacheBase:
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
self.lock = threading.Lock()
|
||||
self.statistics = CacheStatistics()
|
||||
|
||||
|
@ -361,7 +416,7 @@ CacheKey = Tuple[dns.name.Name, dns.rdatatype.RdataType, dns.rdataclass.RdataCla
|
|||
class Cache(CacheBase):
|
||||
"""Simple thread-safe DNS answer cache."""
|
||||
|
||||
def __init__(self, cleaning_interval: float = 300.0):
|
||||
def __init__(self, cleaning_interval: float = 300.0) -> None:
|
||||
"""*cleaning_interval*, a ``float`` is the number of seconds between
|
||||
periodic cleanings.
|
||||
"""
|
||||
|
@ -377,7 +432,7 @@ class Cache(CacheBase):
|
|||
now = time.time()
|
||||
if self.next_cleaning <= now:
|
||||
keys_to_delete = []
|
||||
for (k, v) in self.data.items():
|
||||
for k, v in self.data.items():
|
||||
if v.expiration <= now:
|
||||
keys_to_delete.append(k)
|
||||
for k in keys_to_delete:
|
||||
|
@ -447,13 +502,13 @@ class LRUCacheNode:
|
|||
self.prev = self
|
||||
self.next = self
|
||||
|
||||
def link_after(self, node):
|
||||
def link_after(self, node: "LRUCacheNode") -> None:
|
||||
self.prev = node
|
||||
self.next = node.next
|
||||
node.next.prev = self
|
||||
node.next = self
|
||||
|
||||
def unlink(self):
|
||||
def unlink(self) -> None:
|
||||
self.next.prev = self.prev
|
||||
self.prev.next = self.next
|
||||
|
||||
|
@ -468,7 +523,7 @@ class LRUCache(CacheBase):
|
|||
for a new one.
|
||||
"""
|
||||
|
||||
def __init__(self, max_size: int = 100000):
|
||||
def __init__(self, max_size: int = 100000) -> None:
|
||||
"""*max_size*, an ``int``, is the maximum number of nodes to cache;
|
||||
it must be greater than 0.
|
||||
"""
|
||||
|
@ -590,30 +645,29 @@ class _Resolution:
|
|||
tcp: bool,
|
||||
raise_on_no_answer: bool,
|
||||
search: Optional[bool],
|
||||
):
|
||||
) -> None:
|
||||
if isinstance(qname, str):
|
||||
qname = dns.name.from_text(qname, None)
|
||||
the_rdtype = dns.rdatatype.RdataType.make(rdtype)
|
||||
if dns.rdatatype.is_metatype(the_rdtype):
|
||||
rdtype = dns.rdatatype.RdataType.make(rdtype)
|
||||
if dns.rdatatype.is_metatype(rdtype):
|
||||
raise NoMetaqueries
|
||||
the_rdclass = dns.rdataclass.RdataClass.make(rdclass)
|
||||
if dns.rdataclass.is_metaclass(the_rdclass):
|
||||
rdclass = dns.rdataclass.RdataClass.make(rdclass)
|
||||
if dns.rdataclass.is_metaclass(rdclass):
|
||||
raise NoMetaqueries
|
||||
self.resolver = resolver
|
||||
self.qnames_to_try = resolver._get_qnames_to_try(qname, search)
|
||||
self.qnames = self.qnames_to_try[:]
|
||||
self.rdtype = the_rdtype
|
||||
self.rdclass = the_rdclass
|
||||
self.rdtype = rdtype
|
||||
self.rdclass = rdclass
|
||||
self.tcp = tcp
|
||||
self.raise_on_no_answer = raise_on_no_answer
|
||||
self.nxdomain_responses: Dict[dns.name.Name, dns.message.QueryMessage] = {}
|
||||
# Initialize other things to help analysis tools
|
||||
self.qname = dns.name.empty
|
||||
self.nameservers: List[str] = []
|
||||
self.current_nameservers: List[str] = []
|
||||
self.nameservers: List[dns.nameserver.Nameserver] = []
|
||||
self.current_nameservers: List[dns.nameserver.Nameserver] = []
|
||||
self.errors: List[ErrorTuple] = []
|
||||
self.nameserver: Optional[str] = None
|
||||
self.port = 0
|
||||
self.nameserver: Optional[dns.nameserver.Nameserver] = None
|
||||
self.tcp_attempt = False
|
||||
self.retry_with_tcp = False
|
||||
self.request: Optional[dns.message.QueryMessage] = None
|
||||
|
@ -670,7 +724,11 @@ class _Resolution:
|
|||
if self.resolver.flags is not None:
|
||||
request.flags = self.resolver.flags
|
||||
|
||||
self.nameservers = self.resolver.nameservers[:]
|
||||
self.nameservers = self.resolver._enrich_nameservers(
|
||||
self.resolver._nameservers,
|
||||
self.resolver.nameserver_ports,
|
||||
self.resolver.port,
|
||||
)
|
||||
if self.resolver.rotate:
|
||||
random.shuffle(self.nameservers)
|
||||
self.current_nameservers = self.nameservers[:]
|
||||
|
@ -690,12 +748,13 @@ class _Resolution:
|
|||
#
|
||||
raise NXDOMAIN(qnames=self.qnames_to_try, responses=self.nxdomain_responses)
|
||||
|
||||
def next_nameserver(self) -> Tuple[str, int, bool, float]:
|
||||
def next_nameserver(self) -> Tuple[dns.nameserver.Nameserver, bool, float]:
|
||||
if self.retry_with_tcp:
|
||||
assert self.nameserver is not None
|
||||
assert not self.nameserver.is_always_max_size()
|
||||
self.tcp_attempt = True
|
||||
self.retry_with_tcp = False
|
||||
return (self.nameserver, self.port, True, 0)
|
||||
return (self.nameserver, True, 0)
|
||||
|
||||
backoff = 0.0
|
||||
if not self.current_nameservers:
|
||||
|
@ -707,11 +766,8 @@ class _Resolution:
|
|||
self.backoff = min(self.backoff * 2, 2)
|
||||
|
||||
self.nameserver = self.current_nameservers.pop(0)
|
||||
self.port = self.resolver.nameserver_ports.get(
|
||||
self.nameserver, self.resolver.port
|
||||
)
|
||||
self.tcp_attempt = self.tcp
|
||||
return (self.nameserver, self.port, self.tcp_attempt, backoff)
|
||||
self.tcp_attempt = self.tcp or self.nameserver.is_always_max_size()
|
||||
return (self.nameserver, self.tcp_attempt, backoff)
|
||||
|
||||
def query_result(
|
||||
self, response: Optional[dns.message.Message], ex: Optional[Exception]
|
||||
|
@ -724,7 +780,13 @@ class _Resolution:
|
|||
# Exception during I/O or from_wire()
|
||||
assert response is None
|
||||
self.errors.append(
|
||||
(self.nameserver, self.tcp_attempt, self.port, ex, response)
|
||||
(
|
||||
str(self.nameserver),
|
||||
self.tcp_attempt,
|
||||
self.nameserver.answer_port(),
|
||||
ex,
|
||||
response,
|
||||
)
|
||||
)
|
||||
if (
|
||||
isinstance(ex, dns.exception.FormError)
|
||||
|
@ -752,12 +814,18 @@ class _Resolution:
|
|||
self.rdtype,
|
||||
self.rdclass,
|
||||
response,
|
||||
self.nameserver,
|
||||
self.port,
|
||||
self.nameserver.answer_nameserver(),
|
||||
self.nameserver.answer_port(),
|
||||
)
|
||||
except Exception as e:
|
||||
self.errors.append(
|
||||
(self.nameserver, self.tcp_attempt, self.port, e, response)
|
||||
(
|
||||
str(self.nameserver),
|
||||
self.tcp_attempt,
|
||||
self.nameserver.answer_port(),
|
||||
e,
|
||||
response,
|
||||
)
|
||||
)
|
||||
# The nameserver is no good, take it out of the mix.
|
||||
self.nameservers.remove(self.nameserver)
|
||||
|
@ -776,7 +844,13 @@ class _Resolution:
|
|||
)
|
||||
except Exception as e:
|
||||
self.errors.append(
|
||||
(self.nameserver, self.tcp_attempt, self.port, e, response)
|
||||
(
|
||||
str(self.nameserver),
|
||||
self.tcp_attempt,
|
||||
self.nameserver.answer_port(),
|
||||
e,
|
||||
response,
|
||||
)
|
||||
)
|
||||
# The nameserver is no good, take it out of the mix.
|
||||
self.nameservers.remove(self.nameserver)
|
||||
|
@ -792,7 +866,13 @@ class _Resolution:
|
|||
elif rcode == dns.rcode.YXDOMAIN:
|
||||
yex = YXDOMAIN()
|
||||
self.errors.append(
|
||||
(self.nameserver, self.tcp_attempt, self.port, yex, response)
|
||||
(
|
||||
str(self.nameserver),
|
||||
self.tcp_attempt,
|
||||
self.nameserver.answer_port(),
|
||||
yex,
|
||||
response,
|
||||
)
|
||||
)
|
||||
raise yex
|
||||
else:
|
||||
|
@ -804,9 +884,9 @@ class _Resolution:
|
|||
self.nameservers.remove(self.nameserver)
|
||||
self.errors.append(
|
||||
(
|
||||
self.nameserver,
|
||||
str(self.nameserver),
|
||||
self.tcp_attempt,
|
||||
self.port,
|
||||
self.nameserver.answer_port(),
|
||||
dns.rcode.to_text(rcode),
|
||||
response,
|
||||
)
|
||||
|
@ -840,8 +920,11 @@ class BaseResolver:
|
|||
retry_servfail: bool
|
||||
rotate: bool
|
||||
ndots: Optional[int]
|
||||
_nameservers: Sequence[Union[str, dns.nameserver.Nameserver]]
|
||||
|
||||
def __init__(self, filename: str = "/etc/resolv.conf", configure: bool = True):
|
||||
def __init__(
|
||||
self, filename: str = "/etc/resolv.conf", configure: bool = True
|
||||
) -> None:
|
||||
"""*filename*, a ``str`` or file object, specifying a file
|
||||
in standard /etc/resolv.conf format. This parameter is meaningful
|
||||
only when *configure* is true and the platform is POSIX.
|
||||
|
@ -860,13 +943,13 @@ class BaseResolver:
|
|||
elif filename:
|
||||
self.read_resolv_conf(filename)
|
||||
|
||||
def reset(self):
|
||||
def reset(self) -> None:
|
||||
"""Reset all resolver configuration to the defaults."""
|
||||
|
||||
self.domain = dns.name.Name(dns.name.from_text(socket.gethostname())[1:])
|
||||
if len(self.domain) == 0:
|
||||
self.domain = dns.name.root
|
||||
self.nameservers = []
|
||||
self._nameservers = []
|
||||
self.nameserver_ports = {}
|
||||
self.port = 53
|
||||
self.search = []
|
||||
|
@ -903,6 +986,7 @@ class BaseResolver:
|
|||
|
||||
"""
|
||||
|
||||
nameservers = []
|
||||
if isinstance(f, str):
|
||||
try:
|
||||
cm: contextlib.AbstractContextManager = open(f)
|
||||
|
@ -922,7 +1006,7 @@ class BaseResolver:
|
|||
continue
|
||||
|
||||
if tokens[0] == "nameserver":
|
||||
self.nameservers.append(tokens[1])
|
||||
nameservers.append(tokens[1])
|
||||
elif tokens[0] == "domain":
|
||||
self.domain = dns.name.from_text(tokens[1])
|
||||
# domain and search are exclusive
|
||||
|
@ -950,8 +1034,11 @@ class BaseResolver:
|
|||
self.ndots = int(opt.split(":")[1])
|
||||
except (ValueError, IndexError):
|
||||
pass
|
||||
if len(self.nameservers) == 0:
|
||||
if len(nameservers) == 0:
|
||||
raise NoResolverConfiguration("no nameservers")
|
||||
# Assigning directly instead of appending means we invoke the
|
||||
# setter logic, with additional checking and enrichment.
|
||||
self.nameservers = nameservers
|
||||
|
||||
def read_registry(self) -> None:
|
||||
"""Extract resolver configuration from the Windows registry."""
|
||||
|
@ -1086,34 +1173,64 @@ class BaseResolver:
|
|||
|
||||
self.flags = flags
|
||||
|
||||
@property
|
||||
def nameservers(self) -> List[str]:
|
||||
return self._nameservers
|
||||
|
||||
@nameservers.setter
|
||||
def nameservers(self, nameservers: List[str]) -> None:
|
||||
"""
|
||||
*nameservers*, a ``list`` of nameservers.
|
||||
|
||||
Raises ``ValueError`` if *nameservers* is anything other than a
|
||||
``list``.
|
||||
"""
|
||||
@classmethod
|
||||
def _enrich_nameservers(
|
||||
cls,
|
||||
nameservers: Sequence[Union[str, dns.nameserver.Nameserver]],
|
||||
nameserver_ports: Dict[str, int],
|
||||
default_port: int,
|
||||
) -> List[dns.nameserver.Nameserver]:
|
||||
enriched_nameservers = []
|
||||
if isinstance(nameservers, list):
|
||||
for nameserver in nameservers:
|
||||
if not dns.inet.is_address(nameserver):
|
||||
enriched_nameserver: dns.nameserver.Nameserver
|
||||
if isinstance(nameserver, dns.nameserver.Nameserver):
|
||||
enriched_nameserver = nameserver
|
||||
elif dns.inet.is_address(nameserver):
|
||||
port = nameserver_ports.get(nameserver, default_port)
|
||||
enriched_nameserver = dns.nameserver.Do53Nameserver(
|
||||
nameserver, port
|
||||
)
|
||||
else:
|
||||
try:
|
||||
if urlparse(nameserver).scheme != "https":
|
||||
raise NotImplementedError
|
||||
except Exception:
|
||||
raise ValueError(
|
||||
f"nameserver {nameserver} is not an "
|
||||
"IP address or valid https URL"
|
||||
f"nameserver {nameserver} is not a "
|
||||
"dns.nameserver.Nameserver instance or text form, "
|
||||
"IP address, nor a valid https URL"
|
||||
)
|
||||
self._nameservers = nameservers
|
||||
enriched_nameserver = dns.nameserver.DoHNameserver(nameserver)
|
||||
enriched_nameservers.append(enriched_nameserver)
|
||||
else:
|
||||
raise ValueError(
|
||||
"nameservers must be a list (not a {})".format(type(nameservers))
|
||||
"nameservers must be a list or tuple (not a {})".format(
|
||||
type(nameservers)
|
||||
)
|
||||
)
|
||||
return enriched_nameservers
|
||||
|
||||
@property
|
||||
def nameservers(
|
||||
self,
|
||||
) -> Sequence[Union[str, dns.nameserver.Nameserver]]:
|
||||
return self._nameservers
|
||||
|
||||
@nameservers.setter
|
||||
def nameservers(
|
||||
self, nameservers: Sequence[Union[str, dns.nameserver.Nameserver]]
|
||||
) -> None:
|
||||
"""
|
||||
*nameservers*, a ``list`` of nameservers, where a nameserver is either
|
||||
a string interpretable as a nameserver, or a ``dns.nameserver.Nameserver``
|
||||
instance.
|
||||
|
||||
Raises ``ValueError`` if *nameservers* is not a list of nameservers.
|
||||
"""
|
||||
# We just call _enrich_nameservers() for checking
|
||||
self._enrich_nameservers(nameservers, self.nameserver_ports, self.port)
|
||||
self._nameservers = nameservers
|
||||
|
||||
|
||||
class Resolver(BaseResolver):
|
||||
|
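With _enrich_nameservers() in place, the nameservers setter accepts plain IP address strings, https:// URLs, and dns.nameserver.Nameserver objects alike; everything is turned into Nameserver instances when a resolution actually runs. A short sketch; the server addresses and URL are illustrative.

import dns.nameserver
import dns.resolver

# Mix the accepted forms; the setter only validates here, and enrichment into
# Nameserver objects happens when a query is issued.
res = dns.resolver.Resolver(configure=False)
res.nameservers = [
    "9.9.9.9",                                     # plain Do53 address
    "https://dns.quad9.net/dns-query",             # DNS-over-HTTPS URL
    dns.nameserver.Do53Nameserver("149.112.112.112", 53),
]
print(res.nameservers)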
@ -1198,33 +1315,18 @@ class Resolver(BaseResolver):
|
|||
assert request is not None # needed for type checking
|
||||
done = False
|
||||
while not done:
|
||||
(nameserver, port, tcp, backoff) = resolution.next_nameserver()
|
||||
(nameserver, tcp, backoff) = resolution.next_nameserver()
|
||||
if backoff:
|
||||
time.sleep(backoff)
|
||||
timeout = self._compute_timeout(start, lifetime, resolution.errors)
|
||||
try:
|
||||
if dns.inet.is_address(nameserver):
|
||||
if tcp:
|
||||
response = dns.query.tcp(
|
||||
response = nameserver.query(
|
||||
request,
|
||||
nameserver,
|
||||
timeout=timeout,
|
||||
port=port,
|
||||
source=source,
|
||||
source_port=source_port,
|
||||
max_size=tcp,
|
||||
)
|
||||
else:
|
||||
response = dns.query.udp(
|
||||
request,
|
||||
nameserver,
|
||||
timeout=timeout,
|
||||
port=port,
|
||||
source=source,
|
||||
source_port=source_port,
|
||||
raise_on_truncation=True,
|
||||
)
|
||||
else:
|
||||
response = dns.query.https(request, nameserver, timeout=timeout)
|
||||
except Exception as ex:
|
||||
(_, done) = resolution.query_result(None, ex)
|
||||
continue
|
||||
|
@ -1293,7 +1395,72 @@ class Resolver(BaseResolver):
|
|||
modified_kwargs["rdclass"] = dns.rdataclass.IN
|
||||
return self.resolve(
|
||||
dns.reversename.from_address(ipaddr), *args, **modified_kwargs
|
||||
) # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
def resolve_name(
|
||||
self,
|
||||
name: Union[dns.name.Name, str],
|
||||
family: int = socket.AF_UNSPEC,
|
||||
**kwargs: Any,
|
||||
) -> HostAnswers:
|
||||
"""Use a resolver to query for address records.
|
||||
|
||||
This utilizes the resolve() method to perform A and/or AAAA lookups on
|
||||
the specified name.
|
||||
|
||||
*qname*, a ``dns.name.Name`` or ``str``, the name to resolve.
|
||||
|
||||
*family*, an ``int``, the address family. If socket.AF_UNSPEC
|
||||
(the default), both A and AAAA records will be retrieved.
|
||||
|
||||
All other arguments that can be passed to the resolve() function
|
||||
except for rdtype and rdclass are also supported by this
|
||||
function.
|
||||
"""
|
||||
# We make a modified kwargs for type checking happiness, as otherwise
|
||||
# we get a legit warning about possibly having rdtype and rdclass
|
||||
# in the kwargs more than once.
|
||||
modified_kwargs: Dict[str, Any] = {}
|
||||
modified_kwargs.update(kwargs)
|
||||
modified_kwargs.pop("rdtype", None)
|
||||
modified_kwargs["rdclass"] = dns.rdataclass.IN
|
||||
|
||||
if family == socket.AF_INET:
|
||||
v4 = self.resolve(name, dns.rdatatype.A, **modified_kwargs)
|
||||
return HostAnswers.make(v4=v4)
|
||||
elif family == socket.AF_INET6:
|
||||
v6 = self.resolve(name, dns.rdatatype.AAAA, **modified_kwargs)
|
||||
return HostAnswers.make(v6=v6)
|
||||
elif family != socket.AF_UNSPEC:
|
||||
raise NotImplementedError(f"unknown address family {family}")
|
||||
|
||||
raise_on_no_answer = modified_kwargs.pop("raise_on_no_answer", True)
|
||||
lifetime = modified_kwargs.pop("lifetime", None)
|
||||
start = time.time()
|
||||
v6 = self.resolve(
|
||||
name,
|
||||
dns.rdatatype.AAAA,
|
||||
raise_on_no_answer=False,
|
||||
lifetime=self._compute_timeout(start, lifetime),
|
||||
**modified_kwargs,
|
||||
)
|
||||
# Note that setting name ensures we query the same name
|
||||
# for A as we did for AAAA. (This is just in case search lists
|
||||
# are active by default in the resolver configuration and
|
||||
# we might be talking to a server that says NXDOMAIN when it
|
||||
# wants to say NOERROR no data.)
|
||||
name = v6.qname
|
||||
v4 = self.resolve(
|
||||
name,
|
||||
dns.rdatatype.A,
|
||||
raise_on_no_answer=False,
|
||||
lifetime=self._compute_timeout(start, lifetime),
|
||||
**modified_kwargs,
|
||||
)
|
||||
answers = HostAnswers.make(v6=v6, v4=v4, add_empty=not raise_on_no_answer)
|
||||
if not answers:
|
||||
raise NoAnswer(response=v6.response)
|
||||
return answers
|
||||
|
||||
# pylint: disable=redefined-outer-name
|
||||
|
||||
|
@ -1320,6 +1487,37 @@ class Resolver(BaseResolver):
|
|||
|
||||
# pylint: enable=redefined-outer-name
|
||||
|
||||
def try_ddr(self, lifetime: float = 5.0) -> None:
|
||||
"""Try to update the resolver's nameservers using Discovery of Designated
|
||||
Resolvers (DDR). If successful, the resolver will subsequently use
|
||||
DNS-over-HTTPS or DNS-over-TLS for future queries.
|
||||
|
||||
*lifetime*, a float, is the maximum time to spend attempting DDR. The default
|
||||
is 5 seconds.
|
||||
|
||||
If the SVCB query is successful and results in a non-empty list of nameservers,
|
||||
then the resolver's nameservers are set to the returned servers in priority
|
||||
order.
|
||||
|
||||
The current implementation does not use any address hints from the SVCB record,
|
||||
nor does it resolve addresses for the SVCB target name; rather it assumes that
|
||||
the bootstrap nameserver will always be one of the addresses and uses it.
|
||||
A future revision to the code may offer fuller support. The code verifies that
|
||||
the bootstrap nameserver is in the Subject Alternative Name field of the
|
||||
TLS certificate.
|
||||
"""
|
||||
try:
|
||||
expiration = time.time() + lifetime
|
||||
answer = self.resolve(
|
||||
dns._ddr._local_resolver_name, "SVCB", lifetime=lifetime
|
||||
)
|
||||
timeout = dns.query._remaining(expiration)
|
||||
nameservers = dns._ddr._get_nameservers_sync(answer, timeout)
|
||||
if len(nameservers) > 0:
|
||||
self.nameservers = nameservers
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
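A hedged sketch of the new DDR hook: start from the system-configured resolver, query the DDR special-use name for SVCB records, and keep whatever encrypted-transport nameservers come back. It needs network access and a DDR-capable resolver; on any failure the nameserver list is left untouched.

import dns.resolver

res = dns.resolver.Resolver()   # bootstrap from /etc/resolv.conf
res.try_ddr(lifetime=5.0)       # silently keeps the old servers on failure
for ns in res.nameservers:
    print(ns)                   # DoH/DoT nameservers if the upgrade succeeded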
||||
#: The default resolver.
|
||||
default_resolver: Optional[Resolver] = None
|
||||
|
@ -1333,7 +1531,7 @@ def get_default_resolver() -> Resolver:
|
|||
return default_resolver
|
||||
|
||||
|
||||
def reset_default_resolver():
|
||||
def reset_default_resolver() -> None:
|
||||
"""Re-initialize default resolver.
|
||||
|
||||
Note that the resolver configuration (i.e. /etc/resolv.conf on UNIX
|
||||
|
@ -1355,7 +1553,6 @@ def resolve(
|
|||
lifetime: Optional[float] = None,
|
||||
search: Optional[bool] = None,
|
||||
) -> Answer: # pragma: no cover
|
||||
|
||||
"""Query nameservers to find the answer to the question.
|
||||
|
||||
This is a convenience function that uses the default resolver
|
||||
|
@ -1421,6 +1618,18 @@ def resolve_address(ipaddr: str, *args: Any, **kwargs: Any) -> Answer:
|
|||
return get_default_resolver().resolve_address(ipaddr, *args, **kwargs)
|
||||
|
||||
|
||||
def resolve_name(
|
||||
name: Union[dns.name.Name, str], family: int = socket.AF_UNSPEC, **kwargs: Any
|
||||
) -> HostAnswers:
|
||||
"""Use a resolver to query for address records.
|
||||
|
||||
See ``dns.resolver.Resolver.resolve_name`` for more information on the
|
||||
parameters.
|
||||
"""
|
||||
|
||||
return get_default_resolver().resolve_name(name, family, **kwargs)
|
||||
|
||||
|
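The module-level wrapper mirrors the Resolver method using the default resolver. A short sketch restricted to IPv4; the host name is illustrative and the call needs network access.

import socket
import dns.resolver

# Only A records are queried when a specific family is requested.
answers = dns.resolver.resolve_name("www.dnspython.org", socket.AF_INET)
print(list(answers.addresses()))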
||||
def canonical_name(name: Union[dns.name.Name, str]) -> dns.name.Name:
|
||||
"""Determine the canonical name of *name*.
|
||||
|
||||
|
@ -1431,6 +1640,16 @@ def canonical_name(name: Union[dns.name.Name, str]) -> dns.name.Name:
|
|||
return get_default_resolver().canonical_name(name)
|
||||
|
||||
|
||||
def try_ddr(lifetime: float = 5.0) -> None:
|
||||
"""Try to update the default resolver's nameservers using Discovery of Designated
|
||||
Resolvers (DDR). If successful, the resolver will subsequently use
|
||||
DNS-over-HTTPS or DNS-over-TLS for future queries.
|
||||
|
||||
See :py:func:`dns.resolver.Resolver.try_ddr` for more information.
|
||||
"""
|
||||
return get_default_resolver().try_ddr(lifetime)
|
||||
|
||||
|
||||
def zone_for_name(
|
||||
name: Union[dns.name.Name, str],
|
||||
rdclass: dns.rdataclass.RdataClass = dns.rdataclass.IN,
|
||||
|
@ -1478,7 +1697,7 @@ def zone_for_name(
|
|||
while 1:
|
||||
try:
|
||||
rlifetime: Optional[float]
|
||||
if expiration:
|
||||
if expiration is not None:
|
||||
rlifetime = expiration - time.time()
|
||||
if rlifetime <= 0:
|
||||
rlifetime = 0
|
||||
|
@ -1516,6 +1735,83 @@ def zone_for_name(
|
|||
raise NoRootSOA
|
||||
|
||||
|
||||
def make_resolver_at(
|
||||
where: Union[dns.name.Name, str],
|
||||
port: int = 53,
|
||||
family: int = socket.AF_UNSPEC,
|
||||
resolver: Optional[Resolver] = None,
|
||||
) -> Resolver:
|
||||
"""Make a stub resolver using the specified destination as the full resolver.
|
||||
|
||||
*where*, a ``dns.name.Name`` or ``str``, the domain name or IP address of the
|
||||
full resolver.
|
||||
|
||||
*port*, an ``int``, the port to use. If not specified, the default is 53.
|
||||
|
||||
*family*, an ``int``, the address family to use. This parameter is used if
|
||||
*where* is not an address. The default is ``socket.AF_UNSPEC`` in which case
|
||||
the first address returned by ``resolve_name()`` will be used, otherwise the
|
||||
first address of the specified family will be used.
|
||||
|
||||
*resolver*, a ``dns.resolver.Resolver`` or ``None``, the resolver to use for
|
||||
resolution of hostnames. If not specified, the default resolver will be used.
|
||||
|
||||
Returns a ``dns.resolver.Resolver`` or raises an exception.
|
||||
"""
|
||||
if resolver is None:
|
||||
resolver = get_default_resolver()
|
||||
nameservers: List[Union[str, dns.nameserver.Nameserver]] = []
|
||||
if isinstance(where, str) and dns.inet.is_address(where):
|
||||
nameservers.append(dns.nameserver.Do53Nameserver(where, port))
|
||||
else:
|
||||
for address in resolver.resolve_name(where, family).addresses():
|
||||
nameservers.append(dns.nameserver.Do53Nameserver(address, port))
|
||||
res = dns.resolver.Resolver(configure=False)
|
||||
res.nameservers = nameservers
|
||||
return res
|
||||
|
||||
|
||||
def resolve_at(
|
||||
where: Union[dns.name.Name, str],
|
||||
qname: Union[dns.name.Name, str],
|
||||
rdtype: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.A,
|
||||
rdclass: Union[dns.rdataclass.RdataClass, str] = dns.rdataclass.IN,
|
||||
tcp: bool = False,
|
||||
source: Optional[str] = None,
|
||||
raise_on_no_answer: bool = True,
|
||||
source_port: int = 0,
|
||||
lifetime: Optional[float] = None,
|
||||
search: Optional[bool] = None,
|
||||
port: int = 53,
|
||||
family: int = socket.AF_UNSPEC,
|
||||
resolver: Optional[Resolver] = None,
|
||||
) -> Answer:
|
||||
"""Query nameservers to find the answer to the question.
|
||||
|
||||
This is a convenience function that calls ``dns.resolver.make_resolver_at()`` to
|
||||
make a resolver, and then uses it to resolve the query.
|
||||
|
||||
See ``dns.resolver.Resolver.resolve`` for more information on the resolution
|
||||
parameters, and ``dns.resolver.make_resolver_at`` for information about the resolver
|
||||
parameters *where*, *port*, *family*, and *resolver*.
|
||||
|
||||
If making more than one query, it is more efficient to call
|
||||
``dns.resolver.make_resolver_at()`` and then use that resolver for the queries
|
||||
instead of calling ``resolve_at()`` multiple times.
|
||||
"""
|
||||
return make_resolver_at(where, port, family, resolver).resolve(
|
||||
qname,
|
||||
rdtype,
|
||||
rdclass,
|
||||
tcp,
|
||||
source,
|
||||
raise_on_no_answer,
|
||||
source_port,
|
||||
lifetime,
|
||||
search,
|
||||
)
|
||||
|
||||
|
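A hedged sketch of the two new convenience entry points: resolve_at() for a one-off query against a specific server, and make_resolver_at() when several queries will go to the same place. The server names and addresses are illustrative and the queries need network access.

import dns.resolver

# One-shot query against a specific full resolver.
answer = dns.resolver.resolve_at("8.8.8.8", "dns.google", "A")
print([rdata.address for rdata in answer])

# Reusable stub resolver; the host name is resolved once via resolve_name().
res = dns.resolver.make_resolver_at("dns.quad9.net")
print(res.resolve("example.com", "AAAA").rrset)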
||||
#
|
||||
# Support for overriding the system resolver for all python code in the
|
||||
# running process.
|
||||
|
@ -1559,8 +1855,7 @@ def _getaddrinfo(
|
|||
)
|
||||
if host is None and service is None:
|
||||
raise socket.gaierror(socket.EAI_NONAME, "Name or service not known")
|
||||
v6addrs = []
|
||||
v4addrs = []
|
||||
addrs = []
|
||||
canonical_name = None # pylint: disable=redefined-outer-name
|
||||
# Is host None or an address literal? If so, use the system's
|
||||
# getaddrinfo().
|
||||
|
@ -1576,24 +1871,9 @@ def _getaddrinfo(
|
|||
pass
|
||||
# Something needs resolution!
|
||||
try:
|
||||
if family == socket.AF_INET6 or family == socket.AF_UNSPEC:
|
||||
v6 = _resolver.resolve(host, dns.rdatatype.AAAA, raise_on_no_answer=False)
|
||||
# Note that setting host ensures we query the same name
|
||||
# for A as we did for AAAA. (This is just in case search lists
|
||||
# are active by default in the resolver configuration and
|
||||
# we might be talking to a server that says NXDOMAIN when it
|
||||
# wants to say NOERROR no data.
|
||||
host = v6.qname
|
||||
canonical_name = v6.canonical_name.to_text(True)
|
||||
if v6.rrset is not None:
|
||||
for rdata in v6.rrset:
|
||||
v6addrs.append(rdata.address)
|
||||
if family == socket.AF_INET or family == socket.AF_UNSPEC:
|
||||
v4 = _resolver.resolve(host, dns.rdatatype.A, raise_on_no_answer=False)
|
||||
canonical_name = v4.canonical_name.to_text(True)
|
||||
if v4.rrset is not None:
|
||||
for rdata in v4.rrset:
|
||||
v4addrs.append(rdata.address)
|
||||
answers = _resolver.resolve_name(host, family)
|
||||
addrs = answers.addresses_and_families()
|
||||
canonical_name = answers.canonical_name().to_text(True)
|
||||
except dns.resolver.NXDOMAIN:
|
||||
raise socket.gaierror(socket.EAI_NONAME, "Name or service not known")
|
||||
except Exception:
|
||||
|
@ -1625,20 +1905,11 @@ def _getaddrinfo(
|
|||
cname = canonical_name
|
||||
else:
|
||||
cname = ""
|
||||
if family == socket.AF_INET6 or family == socket.AF_UNSPEC:
|
||||
for addr in v6addrs:
|
||||
for addr, af in addrs:
|
||||
for socktype in socktypes:
|
||||
for proto in _protocols_for_socktype[socktype]:
|
||||
tuples.append(
|
||||
(socket.AF_INET6, socktype, proto, cname, (addr, port, 0, 0))
|
||||
)
|
||||
if family == socket.AF_INET or family == socket.AF_UNSPEC:
|
||||
for addr in v4addrs:
|
||||
for socktype in socktypes:
|
||||
for proto in _protocols_for_socktype[socktype]:
|
||||
tuples.append(
|
||||
(socket.AF_INET, socktype, proto, cname, (addr, port))
|
||||
)
|
||||
addr_tuple = dns.inet.low_level_address_tuple((addr, port), af)
|
||||
tuples.append((af, socktype, proto, cname, addr_tuple))
|
||||
if len(tuples) == 0:
|
||||
raise socket.gaierror(socket.EAI_NONAME, "Name or service not known")
|
||||
return tuples
|
||||
|
|
|
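This rewrite routes the getaddrinfo() shim through resolve_name() instead of keeping separate A and AAAA bookkeeping. The shim is what dns.resolver.override_system_resolver() installs; a short sketch, needing network access.

import socket
import dns.resolver

# Replace socket.getaddrinfo() and friends with dnspython's implementation,
# then restore the originals.
dns.resolver.override_system_resolver()
print(socket.getaddrinfo("www.dnspython.org", 443, proto=socket.IPPROTO_TCP))
dns.resolver.restore_system_resolver()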
@ -19,9 +19,9 @@
|
|||
|
||||
import binascii
|
||||
|
||||
import dns.name
|
||||
import dns.ipv6
|
||||
import dns.ipv4
|
||||
import dns.ipv6
|
||||
import dns.name
|
||||
|
||||
ipv4_reverse_domain = dns.name.from_text("in-addr.arpa.")
|
||||
ipv6_reverse_domain = dns.name.from_text("ip6.arpa.")
|
||||
|
|
|
@ -17,11 +17,11 @@
|
|||
|
||||
"""DNS RRsets (an RRset is a named rdataset)"""
|
||||
|
||||
from typing import Any, cast, Collection, Dict, Optional, Union
|
||||
from typing import Any, Collection, Dict, Optional, Union, cast
|
||||
|
||||
import dns.name
|
||||
import dns.rdataset
|
||||
import dns.rdataclass
|
||||
import dns.rdataset
|
||||
import dns.renderer
|
||||
|
||||
|
||||
|
@ -214,9 +214,9 @@ def from_text_list(
|
|||
|
||||
if isinstance(name, str):
|
||||
name = dns.name.from_text(name, None, idna_codec=idna_codec)
|
||||
the_rdclass = dns.rdataclass.RdataClass.make(rdclass)
|
||||
the_rdtype = dns.rdatatype.RdataType.make(rdtype)
|
||||
r = RRset(name, the_rdclass, the_rdtype)
|
||||
rdclass = dns.rdataclass.RdataClass.make(rdclass)
|
||||
rdtype = dns.rdatatype.RdataType.make(rdtype)
|
||||
r = RRset(name, rdclass, rdtype)
|
||||
r.update_ttl(ttl)
|
||||
for t in text_rdatas:
|
||||
rd = dns.rdata.from_text(
|
||||
|
|
|
@ -17,10 +17,9 @@
|
|||
|
||||
"""Tokenize DNS zone file format"""
|
||||
|
||||
from typing import Any, Optional, List, Tuple
|
||||
|
||||
import io
|
||||
import sys
|
||||
from typing import Any, List, Optional, Tuple
|
||||
|
||||
import dns.exception
|
||||
import dns.name
|
||||
|
|
|
@ -1,8 +1,7 @@
|
|||
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
|
||||
|
||||
from typing import Any, Callable, List, Optional, Tuple, Union
|
||||
|
||||
import collections
|
||||
from typing import Any, Callable, Iterator, List, Optional, Tuple, Union
|
||||
|
||||
import dns.exception
|
||||
import dns.name
|
||||
|
@ -357,6 +356,27 @@ class Transaction:
|
|||
"""
|
||||
self._check_delete_name.append(check)
|
||||
|
||||
def iterate_rdatasets(
|
||||
self,
|
||||
) -> Iterator[Tuple[dns.name.Name, dns.rdataset.Rdataset]]:
|
||||
"""Iterate all the rdatasets in the transaction, returning
|
||||
(`dns.name.Name`, `dns.rdataset.Rdataset`) tuples.
|
||||
|
||||
Note that as is usual with python iterators, adding or removing items
|
||||
while iterating will invalidate the iterator and may raise `RuntimeError`
|
||||
or fail to iterate over all entries."""
|
||||
self._check_ended()
|
||||
return self._iterate_rdatasets()
|
||||
|
||||
def iterate_names(self) -> Iterator[dns.name.Name]:
|
||||
"""Iterate all the names in the transaction.
|
||||
|
||||
Note that as is usual with python iterators, adding or removing names
|
||||
while iterating will invalidate the iterator and may raise `RuntimeError`
|
||||
or fail to iterate over all entries."""
|
||||
self._check_ended()
|
||||
return self._iterate_names()
|
||||
|
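iterate_rdatasets() and iterate_names() promote the previously private iteration helpers to the public Transaction API. A minimal sketch against a small in-memory zone; the names and records are illustrative.

import dns.zone

zone = dns.zone.from_text(
    "@ 300 IN SOA ns1 hostmaster 1 7200 900 1209600 86400\n"
    "@ 300 IN NS ns1\n"
    "ns1 300 IN A 10.0.0.1\n",
    origin="example.",
)

# A read-only transaction exposes the new public iterators.
with zone.reader() as txn:
    for name, rdataset in txn.iterate_rdatasets():
        print(name, rdataset)
    print(list(txn.iterate_names()))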
||||
#
|
||||
# Helper methods
|
||||
#
|
||||
|
@ -416,7 +436,7 @@ class Transaction:
|
|||
rdataset = rrset.to_rdataset()
|
||||
else:
|
||||
raise TypeError(
|
||||
f"{method} requires a name or RRset " + "as the first argument"
|
||||
f"{method} requires a name or RRset as the first argument"
|
||||
)
|
||||
if rdataset.rdclass != self.manager.get_class():
|
||||
raise ValueError(f"{method} has objects of wrong RdataClass")
|
||||
|
@ -475,7 +495,7 @@ class Transaction:
|
|||
name = rdataset.name
|
||||
else:
|
||||
raise TypeError(
|
||||
f"{method} requires a name or RRset " + "as the first argument"
|
||||
f"{method} requires a name or RRset as the first argument"
|
||||
)
|
||||
self._raise_if_not_empty(method, args)
|
||||
if rdataset:
|
||||
|
@ -610,6 +630,10 @@ class Transaction:
|
|||
"""Return an iterator that yields (name, rdataset) tuples."""
|
||||
raise NotImplementedError # pragma: no cover
|
||||
|
||||
def _iterate_names(self):
|
||||
"""Return an iterator that yields a name."""
|
||||
raise NotImplementedError # pragma: no cover
|
||||
|
||||
def _get_node(self, name):
|
||||
"""Return the node at *name*, if any.
|
||||
|
||||
|
|
|
@ -23,9 +23,9 @@ import hmac
|
|||
import struct
|
||||
|
||||
import dns.exception
|
||||
import dns.rdataclass
|
||||
import dns.name
|
||||
import dns.rcode
|
||||
import dns.rdataclass
|
||||
|
||||
|
||||
class BadTime(dns.exception.DNSException):
|
||||
|
@ -187,9 +187,7 @@ class HMACTSig:
|
|||
try:
|
||||
hashinfo = self._hashes[algorithm]
|
||||
except KeyError:
|
||||
raise NotImplementedError(
|
||||
f"TSIG algorithm {algorithm} " + "is not supported"
|
||||
)
|
||||
raise NotImplementedError(f"TSIG algorithm {algorithm} is not supported")
|
||||
|
||||
# create the HMAC context
|
||||
if isinstance(hashinfo, tuple):
|
||||
|
|
|
@ -17,9 +17,8 @@
|
|||
|
||||
"""A place to store TSIG keys."""
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
import base64
|
||||
from typing import Any, Dict
|
||||
|
||||
import dns.name
|
||||
import dns.tsig
|
||||
|
@ -33,7 +32,7 @@ def from_text(textring: Dict[str, Any]) -> Dict[dns.name.Name, dns.tsig.Key]:
|
|||
@rtype: dict"""
|
||||
|
||||
keyring = {}
|
||||
for (name, value) in textring.items():
|
||||
for name, value in textring.items():
|
||||
kname = dns.name.from_text(name)
|
||||
if isinstance(value, str):
|
||||
keyring[kname] = dns.tsig.Key(kname, value).secret
|
||||
|
@ -55,7 +54,7 @@ def to_text(keyring: Dict[dns.name.Name, Any]) -> Dict[str, Any]:
|
|||
def b64encode(secret):
|
||||
return base64.encodebytes(secret).decode().rstrip()
|
||||
|
||||
for (name, key) in keyring.items():
|
||||
for name, key in keyring.items():
|
||||
tname = name.to_text()
|
||||
if isinstance(key, bytes):
|
||||
textring[tname] = b64encode(key)
|
||||
|
|
|
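The keyring helpers only lose their tuple parentheses here, but they are the usual entry point for TSIG-signed operations. A minimal round-trip sketch; the key name and base64 secret are illustrative.

import dns.tsigkeyring

keyring = dns.tsigkeyring.from_text(
    {"keyname.": "NjHwPsMKjdN++dOfE5iAiQ=="}
)
# to_text() re-encodes the raw secrets back into base64 text.
print(dns.tsigkeyring.to_text(keyring))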
@ -24,8 +24,8 @@ import dns.name
|
|||
import dns.opcode
|
||||
import dns.rdata
|
||||
import dns.rdataclass
|
||||
import dns.rdatatype
|
||||
import dns.rdataset
|
||||
import dns.rdatatype
|
||||
import dns.tsig
|
||||
|
||||
|
||||
|
@ -43,7 +43,6 @@ class UpdateSection(dns.enum.IntEnum):
|
|||
|
||||
|
||||
class UpdateMessage(dns.message.Message): # lgtm[py/missing-equals]
|
||||
|
||||
# ignore the mypy error here as we mean to use a different enum
|
||||
_section_enum = UpdateSection # type: ignore
|
||||
|
||||
|
@ -336,12 +335,12 @@ class UpdateMessage(dns.message.Message): # lgtm[py/missing-equals]
|
|||
True,
|
||||
)
|
||||
else:
|
||||
the_rdtype = dns.rdatatype.RdataType.make(rdtype)
|
||||
rdtype = dns.rdatatype.RdataType.make(rdtype)
|
||||
self.find_rrset(
|
||||
self.prerequisite,
|
||||
name,
|
||||
dns.rdataclass.NONE,
|
||||
the_rdtype,
|
||||
rdtype,
|
||||
dns.rdatatype.NONE,
|
||||
None,
|
||||
True,
|
||||
|
|
Some files were not shown because too many files have changed in this diff.