mirror of
https://github.com/Tautulli/Tautulli.git
synced 2025-07-07 13:41:15 -07:00
Bump dnspython from 2.3.0 to 2.4.2 (#2123)
* Bump dnspython from 2.3.0 to 2.4.2 Bumps [dnspython](https://github.com/rthalley/dnspython) from 2.3.0 to 2.4.2. - [Release notes](https://github.com/rthalley/dnspython/releases) - [Changelog](https://github.com/rthalley/dnspython/blob/master/doc/whatsnew.rst) - [Commits](https://github.com/rthalley/dnspython/compare/v2.3.0...v2.4.2) --- updated-dependencies: - dependency-name: dnspython dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] <support@github.com> * Update dnspython==2.4.2 --------- Signed-off-by: dependabot[bot] <support@github.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: JonnyWong16 <9099342+JonnyWong16@users.noreply.github.com> [skip ci]
This commit is contained in:
parent
9f00f5dafa
commit
c0aa4e4996
108 changed files with 2985 additions and 1136 deletions
|
@ -22,6 +22,7 @@ __all__ = [
|
||||||
"asyncquery",
|
"asyncquery",
|
||||||
"asyncresolver",
|
"asyncresolver",
|
||||||
"dnssec",
|
"dnssec",
|
||||||
|
"dnssecalgs",
|
||||||
"dnssectypes",
|
"dnssectypes",
|
||||||
"e164",
|
"e164",
|
||||||
"edns",
|
"edns",
|
||||||
|
|
|
@ -35,6 +35,9 @@ class Socket: # pragma: no cover
|
||||||
async def getsockname(self):
|
async def getsockname(self):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def getpeercert(self, timeout):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
async def __aenter__(self):
|
async def __aenter__(self):
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
@ -61,6 +64,11 @@ class StreamSocket(Socket): # pragma: no cover
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class NullTransport:
|
||||||
|
async def connect_tcp(self, host, port, timeout, local_address):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
class Backend: # pragma: no cover
|
class Backend: # pragma: no cover
|
||||||
def name(self):
|
def name(self):
|
||||||
return "unknown"
|
return "unknown"
|
||||||
|
@ -83,3 +91,9 @@ class Backend: # pragma: no cover
|
||||||
|
|
||||||
async def sleep(self, interval):
|
async def sleep(self, interval):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def get_transport_class(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def wait_for(self, awaitable, timeout):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
|
@ -2,14 +2,13 @@
|
||||||
|
|
||||||
"""asyncio library query support"""
|
"""asyncio library query support"""
|
||||||
|
|
||||||
import socket
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import socket
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
import dns._asyncbackend
|
import dns._asyncbackend
|
||||||
import dns.exception
|
import dns.exception
|
||||||
|
|
||||||
|
|
||||||
_is_win32 = sys.platform == "win32"
|
_is_win32 = sys.platform == "win32"
|
||||||
|
|
||||||
|
|
||||||
|
@ -38,14 +37,21 @@ class _DatagramProtocol:
|
||||||
|
|
||||||
def connection_lost(self, exc):
|
def connection_lost(self, exc):
|
||||||
if self.recvfrom and not self.recvfrom.done():
|
if self.recvfrom and not self.recvfrom.done():
|
||||||
self.recvfrom.set_exception(exc)
|
if exc is None:
|
||||||
|
# EOF we triggered. Is there a better way to do this?
|
||||||
|
try:
|
||||||
|
raise EOFError
|
||||||
|
except EOFError as e:
|
||||||
|
self.recvfrom.set_exception(e)
|
||||||
|
else:
|
||||||
|
self.recvfrom.set_exception(exc)
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
self.transport.close()
|
self.transport.close()
|
||||||
|
|
||||||
|
|
||||||
async def _maybe_wait_for(awaitable, timeout):
|
async def _maybe_wait_for(awaitable, timeout):
|
||||||
if timeout:
|
if timeout is not None:
|
||||||
try:
|
try:
|
||||||
return await asyncio.wait_for(awaitable, timeout)
|
return await asyncio.wait_for(awaitable, timeout)
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
|
@ -85,6 +91,9 @@ class DatagramSocket(dns._asyncbackend.DatagramSocket):
|
||||||
async def getsockname(self):
|
async def getsockname(self):
|
||||||
return self.transport.get_extra_info("sockname")
|
return self.transport.get_extra_info("sockname")
|
||||||
|
|
||||||
|
async def getpeercert(self, timeout):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
class StreamSocket(dns._asyncbackend.StreamSocket):
|
class StreamSocket(dns._asyncbackend.StreamSocket):
|
||||||
def __init__(self, af, reader, writer):
|
def __init__(self, af, reader, writer):
|
||||||
|
@ -101,10 +110,6 @@ class StreamSocket(dns._asyncbackend.StreamSocket):
|
||||||
|
|
||||||
async def close(self):
|
async def close(self):
|
||||||
self.writer.close()
|
self.writer.close()
|
||||||
try:
|
|
||||||
await self.writer.wait_closed()
|
|
||||||
except AttributeError: # pragma: no cover
|
|
||||||
pass
|
|
||||||
|
|
||||||
async def getpeername(self):
|
async def getpeername(self):
|
||||||
return self.writer.get_extra_info("peername")
|
return self.writer.get_extra_info("peername")
|
||||||
|
@ -112,6 +117,97 @@ class StreamSocket(dns._asyncbackend.StreamSocket):
|
||||||
async def getsockname(self):
|
async def getsockname(self):
|
||||||
return self.writer.get_extra_info("sockname")
|
return self.writer.get_extra_info("sockname")
|
||||||
|
|
||||||
|
async def getpeercert(self, timeout):
|
||||||
|
return self.writer.get_extra_info("peercert")
|
||||||
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
import anyio
|
||||||
|
import httpcore
|
||||||
|
import httpcore._backends.anyio
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
_CoreAsyncNetworkBackend = httpcore.AsyncNetworkBackend
|
||||||
|
_CoreAnyIOStream = httpcore._backends.anyio.AnyIOStream
|
||||||
|
|
||||||
|
from dns.query import _compute_times, _expiration_for_this_attempt, _remaining
|
||||||
|
|
||||||
|
class _NetworkBackend(_CoreAsyncNetworkBackend):
|
||||||
|
def __init__(self, resolver, local_port, bootstrap_address, family):
|
||||||
|
super().__init__()
|
||||||
|
self._local_port = local_port
|
||||||
|
self._resolver = resolver
|
||||||
|
self._bootstrap_address = bootstrap_address
|
||||||
|
self._family = family
|
||||||
|
if local_port != 0:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"the asyncio transport for HTTPX cannot set the local port"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def connect_tcp(
|
||||||
|
self, host, port, timeout, local_address, socket_options=None
|
||||||
|
): # pylint: disable=signature-differs
|
||||||
|
addresses = []
|
||||||
|
_, expiration = _compute_times(timeout)
|
||||||
|
if dns.inet.is_address(host):
|
||||||
|
addresses.append(host)
|
||||||
|
elif self._bootstrap_address is not None:
|
||||||
|
addresses.append(self._bootstrap_address)
|
||||||
|
else:
|
||||||
|
timeout = _remaining(expiration)
|
||||||
|
family = self._family
|
||||||
|
if local_address:
|
||||||
|
family = dns.inet.af_for_address(local_address)
|
||||||
|
answers = await self._resolver.resolve_name(
|
||||||
|
host, family=family, lifetime=timeout
|
||||||
|
)
|
||||||
|
addresses = answers.addresses()
|
||||||
|
for address in addresses:
|
||||||
|
try:
|
||||||
|
attempt_expiration = _expiration_for_this_attempt(2.0, expiration)
|
||||||
|
timeout = _remaining(attempt_expiration)
|
||||||
|
with anyio.fail_after(timeout):
|
||||||
|
stream = await anyio.connect_tcp(
|
||||||
|
remote_host=address,
|
||||||
|
remote_port=port,
|
||||||
|
local_host=local_address,
|
||||||
|
)
|
||||||
|
return _CoreAnyIOStream(stream)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
raise httpcore.ConnectError
|
||||||
|
|
||||||
|
async def connect_unix_socket(
|
||||||
|
self, path, timeout, socket_options=None
|
||||||
|
): # pylint: disable=signature-differs
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def sleep(self, seconds): # pylint: disable=signature-differs
|
||||||
|
await anyio.sleep(seconds)
|
||||||
|
|
||||||
|
class _HTTPTransport(httpx.AsyncHTTPTransport):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*args,
|
||||||
|
local_port=0,
|
||||||
|
bootstrap_address=None,
|
||||||
|
resolver=None,
|
||||||
|
family=socket.AF_UNSPEC,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
if resolver is None:
|
||||||
|
# pylint: disable=import-outside-toplevel,redefined-outer-name
|
||||||
|
import dns.asyncresolver
|
||||||
|
|
||||||
|
resolver = dns.asyncresolver.Resolver()
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self._pool._network_backend = _NetworkBackend(
|
||||||
|
resolver, local_port, bootstrap_address, family
|
||||||
|
)
|
||||||
|
|
||||||
|
except ImportError:
|
||||||
|
_HTTPTransport = dns._asyncbackend.NullTransport # type: ignore
|
||||||
|
|
||||||
|
|
||||||
class Backend(dns._asyncbackend.Backend):
|
class Backend(dns._asyncbackend.Backend):
|
||||||
def name(self):
|
def name(self):
|
||||||
|
@ -171,3 +267,9 @@ class Backend(dns._asyncbackend.Backend):
|
||||||
|
|
||||||
def datagram_connection_required(self):
|
def datagram_connection_required(self):
|
||||||
return _is_win32
|
return _is_win32
|
||||||
|
|
||||||
|
def get_transport_class(self):
|
||||||
|
return _HTTPTransport
|
||||||
|
|
||||||
|
async def wait_for(self, awaitable, timeout):
|
||||||
|
return await _maybe_wait_for(awaitable, timeout)
|
||||||
|
|
|
@ -1,122 +0,0 @@
|
||||||
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
|
|
||||||
|
|
||||||
"""curio async I/O library query support"""
|
|
||||||
|
|
||||||
import socket
|
|
||||||
import curio
|
|
||||||
import curio.socket # type: ignore
|
|
||||||
|
|
||||||
import dns._asyncbackend
|
|
||||||
import dns.exception
|
|
||||||
import dns.inet
|
|
||||||
|
|
||||||
|
|
||||||
def _maybe_timeout(timeout):
|
|
||||||
if timeout:
|
|
||||||
return curio.ignore_after(timeout)
|
|
||||||
else:
|
|
||||||
return dns._asyncbackend.NullContext()
|
|
||||||
|
|
||||||
|
|
||||||
# for brevity
|
|
||||||
_lltuple = dns.inet.low_level_address_tuple
|
|
||||||
|
|
||||||
# pylint: disable=redefined-outer-name
|
|
||||||
|
|
||||||
|
|
||||||
class DatagramSocket(dns._asyncbackend.DatagramSocket):
|
|
||||||
def __init__(self, socket):
|
|
||||||
super().__init__(socket.family)
|
|
||||||
self.socket = socket
|
|
||||||
|
|
||||||
async def sendto(self, what, destination, timeout):
|
|
||||||
async with _maybe_timeout(timeout):
|
|
||||||
return await self.socket.sendto(what, destination)
|
|
||||||
raise dns.exception.Timeout(
|
|
||||||
timeout=timeout
|
|
||||||
) # pragma: no cover lgtm[py/unreachable-statement]
|
|
||||||
|
|
||||||
async def recvfrom(self, size, timeout):
|
|
||||||
async with _maybe_timeout(timeout):
|
|
||||||
return await self.socket.recvfrom(size)
|
|
||||||
raise dns.exception.Timeout(timeout=timeout) # lgtm[py/unreachable-statement]
|
|
||||||
|
|
||||||
async def close(self):
|
|
||||||
await self.socket.close()
|
|
||||||
|
|
||||||
async def getpeername(self):
|
|
||||||
return self.socket.getpeername()
|
|
||||||
|
|
||||||
async def getsockname(self):
|
|
||||||
return self.socket.getsockname()
|
|
||||||
|
|
||||||
|
|
||||||
class StreamSocket(dns._asyncbackend.StreamSocket):
|
|
||||||
def __init__(self, socket):
|
|
||||||
self.socket = socket
|
|
||||||
self.family = socket.family
|
|
||||||
|
|
||||||
async def sendall(self, what, timeout):
|
|
||||||
async with _maybe_timeout(timeout):
|
|
||||||
return await self.socket.sendall(what)
|
|
||||||
raise dns.exception.Timeout(timeout=timeout) # lgtm[py/unreachable-statement]
|
|
||||||
|
|
||||||
async def recv(self, size, timeout):
|
|
||||||
async with _maybe_timeout(timeout):
|
|
||||||
return await self.socket.recv(size)
|
|
||||||
raise dns.exception.Timeout(timeout=timeout) # lgtm[py/unreachable-statement]
|
|
||||||
|
|
||||||
async def close(self):
|
|
||||||
await self.socket.close()
|
|
||||||
|
|
||||||
async def getpeername(self):
|
|
||||||
return self.socket.getpeername()
|
|
||||||
|
|
||||||
async def getsockname(self):
|
|
||||||
return self.socket.getsockname()
|
|
||||||
|
|
||||||
|
|
||||||
class Backend(dns._asyncbackend.Backend):
|
|
||||||
def name(self):
|
|
||||||
return "curio"
|
|
||||||
|
|
||||||
async def make_socket(
|
|
||||||
self,
|
|
||||||
af,
|
|
||||||
socktype,
|
|
||||||
proto=0,
|
|
||||||
source=None,
|
|
||||||
destination=None,
|
|
||||||
timeout=None,
|
|
||||||
ssl_context=None,
|
|
||||||
server_hostname=None,
|
|
||||||
):
|
|
||||||
if socktype == socket.SOCK_DGRAM:
|
|
||||||
s = curio.socket.socket(af, socktype, proto)
|
|
||||||
try:
|
|
||||||
if source:
|
|
||||||
s.bind(_lltuple(source, af))
|
|
||||||
except Exception: # pragma: no cover
|
|
||||||
await s.close()
|
|
||||||
raise
|
|
||||||
return DatagramSocket(s)
|
|
||||||
elif socktype == socket.SOCK_STREAM:
|
|
||||||
if source:
|
|
||||||
source_addr = _lltuple(source, af)
|
|
||||||
else:
|
|
||||||
source_addr = None
|
|
||||||
async with _maybe_timeout(timeout):
|
|
||||||
s = await curio.open_connection(
|
|
||||||
destination[0],
|
|
||||||
destination[1],
|
|
||||||
ssl=ssl_context,
|
|
||||||
source_addr=source_addr,
|
|
||||||
server_hostname=server_hostname,
|
|
||||||
)
|
|
||||||
return StreamSocket(s)
|
|
||||||
raise NotImplementedError(
|
|
||||||
"unsupported socket " + f"type {socktype}"
|
|
||||||
) # pragma: no cover
|
|
||||||
|
|
||||||
async def sleep(self, interval):
|
|
||||||
await curio.sleep(interval)
|
|
154
lib/dns/_ddr.py
Normal file
154
lib/dns/_ddr.py
Normal file
|
@ -0,0 +1,154 @@
|
||||||
|
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
|
||||||
|
#
|
||||||
|
# Support for Discovery of Designated Resolvers
|
||||||
|
|
||||||
|
import socket
|
||||||
|
import time
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
import dns.asyncbackend
|
||||||
|
import dns.inet
|
||||||
|
import dns.name
|
||||||
|
import dns.nameserver
|
||||||
|
import dns.query
|
||||||
|
import dns.rdtypes.svcbbase
|
||||||
|
|
||||||
|
# The special name of the local resolver when using DDR
|
||||||
|
_local_resolver_name = dns.name.from_text("_dns.resolver.arpa")
|
||||||
|
|
||||||
|
|
||||||
|
#
|
||||||
|
# Processing is split up into I/O independent and I/O dependent parts to
|
||||||
|
# make supporting sync and async versions easy.
|
||||||
|
#
|
||||||
|
|
||||||
|
|
||||||
|
class _SVCBInfo:
|
||||||
|
def __init__(self, bootstrap_address, port, hostname, nameservers):
|
||||||
|
self.bootstrap_address = bootstrap_address
|
||||||
|
self.port = port
|
||||||
|
self.hostname = hostname
|
||||||
|
self.nameservers = nameservers
|
||||||
|
|
||||||
|
def ddr_check_certificate(self, cert):
|
||||||
|
"""Verify that the _SVCBInfo's address is in the cert's subjectAltName (SAN)"""
|
||||||
|
for name, value in cert["subjectAltName"]:
|
||||||
|
if name == "IP Address" and value == self.bootstrap_address:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def make_tls_context(self):
|
||||||
|
ssl = dns.query.ssl
|
||||||
|
ctx = ssl.create_default_context()
|
||||||
|
ctx.minimum_version = ssl.TLSVersion.TLSv1_2
|
||||||
|
return ctx
|
||||||
|
|
||||||
|
def ddr_tls_check_sync(self, lifetime):
|
||||||
|
ctx = self.make_tls_context()
|
||||||
|
expiration = time.time() + lifetime
|
||||||
|
with socket.create_connection(
|
||||||
|
(self.bootstrap_address, self.port), lifetime
|
||||||
|
) as s:
|
||||||
|
with ctx.wrap_socket(s, server_hostname=self.hostname) as ts:
|
||||||
|
ts.settimeout(dns.query._remaining(expiration))
|
||||||
|
ts.do_handshake()
|
||||||
|
cert = ts.getpeercert()
|
||||||
|
return self.ddr_check_certificate(cert)
|
||||||
|
|
||||||
|
async def ddr_tls_check_async(self, lifetime, backend=None):
|
||||||
|
if backend is None:
|
||||||
|
backend = dns.asyncbackend.get_default_backend()
|
||||||
|
ctx = self.make_tls_context()
|
||||||
|
expiration = time.time() + lifetime
|
||||||
|
async with await backend.make_socket(
|
||||||
|
dns.inet.af_for_address(self.bootstrap_address),
|
||||||
|
socket.SOCK_STREAM,
|
||||||
|
0,
|
||||||
|
None,
|
||||||
|
(self.bootstrap_address, self.port),
|
||||||
|
lifetime,
|
||||||
|
ctx,
|
||||||
|
self.hostname,
|
||||||
|
) as ts:
|
||||||
|
cert = await ts.getpeercert(dns.query._remaining(expiration))
|
||||||
|
return self.ddr_check_certificate(cert)
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_nameservers_from_svcb(answer):
|
||||||
|
bootstrap_address = answer.nameserver
|
||||||
|
if not dns.inet.is_address(bootstrap_address):
|
||||||
|
return []
|
||||||
|
infos = []
|
||||||
|
for rr in answer.rrset.processing_order():
|
||||||
|
nameservers = []
|
||||||
|
param = rr.params.get(dns.rdtypes.svcbbase.ParamKey.ALPN)
|
||||||
|
if param is None:
|
||||||
|
continue
|
||||||
|
alpns = set(param.ids)
|
||||||
|
host = rr.target.to_text(omit_final_dot=True)
|
||||||
|
port = None
|
||||||
|
param = rr.params.get(dns.rdtypes.svcbbase.ParamKey.PORT)
|
||||||
|
if param is not None:
|
||||||
|
port = param.port
|
||||||
|
# For now we ignore address hints and address resolution and always use the
|
||||||
|
# bootstrap address
|
||||||
|
if b"h2" in alpns:
|
||||||
|
param = rr.params.get(dns.rdtypes.svcbbase.ParamKey.DOHPATH)
|
||||||
|
if param is None or not param.value.endswith(b"{?dns}"):
|
||||||
|
continue
|
||||||
|
path = param.value[:-6].decode()
|
||||||
|
if not path.startswith("/"):
|
||||||
|
path = "/" + path
|
||||||
|
if port is None:
|
||||||
|
port = 443
|
||||||
|
url = f"https://{host}:{port}{path}"
|
||||||
|
# check the URL
|
||||||
|
try:
|
||||||
|
urlparse(url)
|
||||||
|
nameservers.append(dns.nameserver.DoHNameserver(url, bootstrap_address))
|
||||||
|
except Exception:
|
||||||
|
# continue processing other ALPN types
|
||||||
|
pass
|
||||||
|
if b"dot" in alpns:
|
||||||
|
if port is None:
|
||||||
|
port = 853
|
||||||
|
nameservers.append(
|
||||||
|
dns.nameserver.DoTNameserver(bootstrap_address, port, host)
|
||||||
|
)
|
||||||
|
if b"doq" in alpns:
|
||||||
|
if port is None:
|
||||||
|
port = 853
|
||||||
|
nameservers.append(
|
||||||
|
dns.nameserver.DoQNameserver(bootstrap_address, port, True, host)
|
||||||
|
)
|
||||||
|
if len(nameservers) > 0:
|
||||||
|
infos.append(_SVCBInfo(bootstrap_address, port, host, nameservers))
|
||||||
|
return infos
|
||||||
|
|
||||||
|
|
||||||
|
def _get_nameservers_sync(answer, lifetime):
|
||||||
|
"""Return a list of TLS-validated resolver nameservers extracted from an SVCB
|
||||||
|
answer."""
|
||||||
|
nameservers = []
|
||||||
|
infos = _extract_nameservers_from_svcb(answer)
|
||||||
|
for info in infos:
|
||||||
|
try:
|
||||||
|
if info.ddr_tls_check_sync(lifetime):
|
||||||
|
nameservers.extend(info.nameservers)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return nameservers
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_nameservers_async(answer, lifetime):
|
||||||
|
"""Return a list of TLS-validated resolver nameservers extracted from an SVCB
|
||||||
|
answer."""
|
||||||
|
nameservers = []
|
||||||
|
infos = _extract_nameservers_from_svcb(answer)
|
||||||
|
for info in infos:
|
||||||
|
try:
|
||||||
|
if await info.ddr_tls_check_async(lifetime):
|
||||||
|
nameservers.extend(info.nameservers)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return nameservers
|
|
@ -7,7 +7,6 @@
|
||||||
import contextvars
|
import contextvars
|
||||||
import inspect
|
import inspect
|
||||||
|
|
||||||
|
|
||||||
_in__init__ = contextvars.ContextVar("_immutable_in__init__", default=False)
|
_in__init__ = contextvars.ContextVar("_immutable_in__init__", default=False)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -3,6 +3,7 @@
|
||||||
"""trio async I/O library query support"""
|
"""trio async I/O library query support"""
|
||||||
|
|
||||||
import socket
|
import socket
|
||||||
|
|
||||||
import trio
|
import trio
|
||||||
import trio.socket # type: ignore
|
import trio.socket # type: ignore
|
||||||
|
|
||||||
|
@ -12,7 +13,7 @@ import dns.inet
|
||||||
|
|
||||||
|
|
||||||
def _maybe_timeout(timeout):
|
def _maybe_timeout(timeout):
|
||||||
if timeout:
|
if timeout is not None:
|
||||||
return trio.move_on_after(timeout)
|
return trio.move_on_after(timeout)
|
||||||
else:
|
else:
|
||||||
return dns._asyncbackend.NullContext()
|
return dns._asyncbackend.NullContext()
|
||||||
|
@ -50,6 +51,9 @@ class DatagramSocket(dns._asyncbackend.DatagramSocket):
|
||||||
async def getsockname(self):
|
async def getsockname(self):
|
||||||
return self.socket.getsockname()
|
return self.socket.getsockname()
|
||||||
|
|
||||||
|
async def getpeercert(self, timeout):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
class StreamSocket(dns._asyncbackend.StreamSocket):
|
class StreamSocket(dns._asyncbackend.StreamSocket):
|
||||||
def __init__(self, family, stream, tls=False):
|
def __init__(self, family, stream, tls=False):
|
||||||
|
@ -82,6 +86,100 @@ class StreamSocket(dns._asyncbackend.StreamSocket):
|
||||||
else:
|
else:
|
||||||
return self.stream.socket.getsockname()
|
return self.stream.socket.getsockname()
|
||||||
|
|
||||||
|
async def getpeercert(self, timeout):
|
||||||
|
if self.tls:
|
||||||
|
with _maybe_timeout(timeout):
|
||||||
|
await self.stream.do_handshake()
|
||||||
|
return self.stream.getpeercert()
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
import httpcore
|
||||||
|
import httpcore._backends.trio
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
_CoreAsyncNetworkBackend = httpcore.AsyncNetworkBackend
|
||||||
|
_CoreTrioStream = httpcore._backends.trio.TrioStream
|
||||||
|
|
||||||
|
from dns.query import _compute_times, _expiration_for_this_attempt, _remaining
|
||||||
|
|
||||||
|
class _NetworkBackend(_CoreAsyncNetworkBackend):
|
||||||
|
def __init__(self, resolver, local_port, bootstrap_address, family):
|
||||||
|
super().__init__()
|
||||||
|
self._local_port = local_port
|
||||||
|
self._resolver = resolver
|
||||||
|
self._bootstrap_address = bootstrap_address
|
||||||
|
self._family = family
|
||||||
|
|
||||||
|
async def connect_tcp(
|
||||||
|
self, host, port, timeout, local_address, socket_options=None
|
||||||
|
): # pylint: disable=signature-differs
|
||||||
|
addresses = []
|
||||||
|
_, expiration = _compute_times(timeout)
|
||||||
|
if dns.inet.is_address(host):
|
||||||
|
addresses.append(host)
|
||||||
|
elif self._bootstrap_address is not None:
|
||||||
|
addresses.append(self._bootstrap_address)
|
||||||
|
else:
|
||||||
|
timeout = _remaining(expiration)
|
||||||
|
family = self._family
|
||||||
|
if local_address:
|
||||||
|
family = dns.inet.af_for_address(local_address)
|
||||||
|
answers = await self._resolver.resolve_name(
|
||||||
|
host, family=family, lifetime=timeout
|
||||||
|
)
|
||||||
|
addresses = answers.addresses()
|
||||||
|
for address in addresses:
|
||||||
|
try:
|
||||||
|
af = dns.inet.af_for_address(address)
|
||||||
|
if local_address is not None or self._local_port != 0:
|
||||||
|
source = (local_address, self._local_port)
|
||||||
|
else:
|
||||||
|
source = None
|
||||||
|
destination = (address, port)
|
||||||
|
attempt_expiration = _expiration_for_this_attempt(2.0, expiration)
|
||||||
|
timeout = _remaining(attempt_expiration)
|
||||||
|
sock = await Backend().make_socket(
|
||||||
|
af, socket.SOCK_STREAM, 0, source, destination, timeout
|
||||||
|
)
|
||||||
|
return _CoreTrioStream(sock.stream)
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
raise httpcore.ConnectError
|
||||||
|
|
||||||
|
async def connect_unix_socket(
|
||||||
|
self, path, timeout, socket_options=None
|
||||||
|
): # pylint: disable=signature-differs
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def sleep(self, seconds): # pylint: disable=signature-differs
|
||||||
|
await trio.sleep(seconds)
|
||||||
|
|
||||||
|
class _HTTPTransport(httpx.AsyncHTTPTransport):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*args,
|
||||||
|
local_port=0,
|
||||||
|
bootstrap_address=None,
|
||||||
|
resolver=None,
|
||||||
|
family=socket.AF_UNSPEC,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
if resolver is None:
|
||||||
|
# pylint: disable=import-outside-toplevel,redefined-outer-name
|
||||||
|
import dns.asyncresolver
|
||||||
|
|
||||||
|
resolver = dns.asyncresolver.Resolver()
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self._pool._network_backend = _NetworkBackend(
|
||||||
|
resolver, local_port, bootstrap_address, family
|
||||||
|
)
|
||||||
|
|
||||||
|
except ImportError:
|
||||||
|
_HTTPTransport = dns._asyncbackend.NullTransport # type: ignore
|
||||||
|
|
||||||
|
|
||||||
class Backend(dns._asyncbackend.Backend):
|
class Backend(dns._asyncbackend.Backend):
|
||||||
def name(self):
|
def name(self):
|
||||||
|
@ -104,8 +202,14 @@ class Backend(dns._asyncbackend.Backend):
|
||||||
if source:
|
if source:
|
||||||
await s.bind(_lltuple(source, af))
|
await s.bind(_lltuple(source, af))
|
||||||
if socktype == socket.SOCK_STREAM:
|
if socktype == socket.SOCK_STREAM:
|
||||||
|
connected = False
|
||||||
with _maybe_timeout(timeout):
|
with _maybe_timeout(timeout):
|
||||||
await s.connect(_lltuple(destination, af))
|
await s.connect(_lltuple(destination, af))
|
||||||
|
connected = True
|
||||||
|
if not connected:
|
||||||
|
raise dns.exception.Timeout(
|
||||||
|
timeout=timeout
|
||||||
|
) # lgtm[py/unreachable-statement]
|
||||||
except Exception: # pragma: no cover
|
except Exception: # pragma: no cover
|
||||||
s.close()
|
s.close()
|
||||||
raise
|
raise
|
||||||
|
@ -130,3 +234,13 @@ class Backend(dns._asyncbackend.Backend):
|
||||||
|
|
||||||
async def sleep(self, interval):
|
async def sleep(self, interval):
|
||||||
await trio.sleep(interval)
|
await trio.sleep(interval)
|
||||||
|
|
||||||
|
def get_transport_class(self):
|
||||||
|
return _HTTPTransport
|
||||||
|
|
||||||
|
async def wait_for(self, awaitable, timeout):
|
||||||
|
with _maybe_timeout(timeout):
|
||||||
|
return await awaitable
|
||||||
|
raise dns.exception.Timeout(
|
||||||
|
timeout=timeout
|
||||||
|
) # pragma: no cover lgtm[py/unreachable-statement]
|
||||||
|
|
|
@ -5,13 +5,12 @@ from typing import Dict
|
||||||
import dns.exception
|
import dns.exception
|
||||||
|
|
||||||
# pylint: disable=unused-import
|
# pylint: disable=unused-import
|
||||||
|
from dns._asyncbackend import ( # noqa: F401 lgtm[py/unused-import]
|
||||||
from dns._asyncbackend import (
|
|
||||||
Socket,
|
|
||||||
DatagramSocket,
|
|
||||||
StreamSocket,
|
|
||||||
Backend,
|
Backend,
|
||||||
) # noqa: F401 lgtm[py/unused-import]
|
DatagramSocket,
|
||||||
|
Socket,
|
||||||
|
StreamSocket,
|
||||||
|
)
|
||||||
|
|
||||||
# pylint: enable=unused-import
|
# pylint: enable=unused-import
|
||||||
|
|
||||||
|
@ -30,8 +29,8 @@ class AsyncLibraryNotFoundError(dns.exception.DNSException):
|
||||||
def get_backend(name: str) -> Backend:
|
def get_backend(name: str) -> Backend:
|
||||||
"""Get the specified asynchronous backend.
|
"""Get the specified asynchronous backend.
|
||||||
|
|
||||||
*name*, a ``str``, the name of the backend. Currently the "trio",
|
*name*, a ``str``, the name of the backend. Currently the "trio"
|
||||||
"curio", and "asyncio" backends are available.
|
and "asyncio" backends are available.
|
||||||
|
|
||||||
Raises NotImplementError if an unknown backend name is specified.
|
Raises NotImplementError if an unknown backend name is specified.
|
||||||
"""
|
"""
|
||||||
|
@ -43,10 +42,6 @@ def get_backend(name: str) -> Backend:
|
||||||
import dns._trio_backend
|
import dns._trio_backend
|
||||||
|
|
||||||
backend = dns._trio_backend.Backend()
|
backend = dns._trio_backend.Backend()
|
||||||
elif name == "curio":
|
|
||||||
import dns._curio_backend
|
|
||||||
|
|
||||||
backend = dns._curio_backend.Backend()
|
|
||||||
elif name == "asyncio":
|
elif name == "asyncio":
|
||||||
import dns._asyncio_backend
|
import dns._asyncio_backend
|
||||||
|
|
||||||
|
@ -73,9 +68,7 @@ def sniff() -> str:
|
||||||
try:
|
try:
|
||||||
return sniffio.current_async_library()
|
return sniffio.current_async_library()
|
||||||
except sniffio.AsyncLibraryNotFoundError:
|
except sniffio.AsyncLibraryNotFoundError:
|
||||||
raise AsyncLibraryNotFoundError(
|
raise AsyncLibraryNotFoundError("sniffio cannot determine async library")
|
||||||
"sniffio cannot determine " + "async library"
|
|
||||||
)
|
|
||||||
except ImportError:
|
except ImportError:
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
|
|
|
@ -17,39 +17,38 @@
|
||||||
|
|
||||||
"""Talk to a DNS server."""
|
"""Talk to a DNS server."""
|
||||||
|
|
||||||
from typing import Any, Dict, Optional, Tuple, Union
|
|
||||||
|
|
||||||
import base64
|
import base64
|
||||||
import contextlib
|
import contextlib
|
||||||
import socket
|
import socket
|
||||||
import struct
|
import struct
|
||||||
import time
|
import time
|
||||||
|
from typing import Any, Dict, Optional, Tuple, Union
|
||||||
|
|
||||||
import dns.asyncbackend
|
import dns.asyncbackend
|
||||||
import dns.exception
|
import dns.exception
|
||||||
import dns.inet
|
import dns.inet
|
||||||
import dns.name
|
|
||||||
import dns.message
|
import dns.message
|
||||||
|
import dns.name
|
||||||
import dns.quic
|
import dns.quic
|
||||||
import dns.rcode
|
import dns.rcode
|
||||||
import dns.rdataclass
|
import dns.rdataclass
|
||||||
import dns.rdatatype
|
import dns.rdatatype
|
||||||
import dns.transaction
|
import dns.transaction
|
||||||
|
|
||||||
from dns._asyncbackend import NullContext
|
from dns._asyncbackend import NullContext
|
||||||
from dns.query import (
|
from dns.query import (
|
||||||
_compute_times,
|
|
||||||
_matches_destination,
|
|
||||||
BadResponse,
|
BadResponse,
|
||||||
ssl,
|
|
||||||
UDPMode,
|
|
||||||
_have_httpx,
|
|
||||||
_have_http2,
|
|
||||||
NoDOH,
|
NoDOH,
|
||||||
NoDOQ,
|
NoDOQ,
|
||||||
|
UDPMode,
|
||||||
|
_compute_times,
|
||||||
|
_have_http2,
|
||||||
|
_matches_destination,
|
||||||
|
_remaining,
|
||||||
|
have_doh,
|
||||||
|
ssl,
|
||||||
)
|
)
|
||||||
|
|
||||||
if _have_httpx:
|
if have_doh:
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
# for brevity
|
# for brevity
|
||||||
|
@ -73,7 +72,7 @@ def _source_tuple(af, address, port):
|
||||||
|
|
||||||
|
|
||||||
def _timeout(expiration, now=None):
|
def _timeout(expiration, now=None):
|
||||||
if expiration:
|
if expiration is not None:
|
||||||
if not now:
|
if not now:
|
||||||
now = time.time()
|
now = time.time()
|
||||||
return max(expiration - now, 0)
|
return max(expiration - now, 0)
|
||||||
|
@ -445,9 +444,6 @@ async def tls(
|
||||||
ssl_context.minimum_version = ssl.TLSVersion.TLSv1_2
|
ssl_context.minimum_version = ssl.TLSVersion.TLSv1_2
|
||||||
if server_hostname is None:
|
if server_hostname is None:
|
||||||
ssl_context.check_hostname = False
|
ssl_context.check_hostname = False
|
||||||
else:
|
|
||||||
ssl_context = None
|
|
||||||
server_hostname = None
|
|
||||||
af = dns.inet.af_for_address(where)
|
af = dns.inet.af_for_address(where)
|
||||||
stuple = _source_tuple(af, source, source_port)
|
stuple = _source_tuple(af, source, source_port)
|
||||||
dtuple = (where, port)
|
dtuple = (where, port)
|
||||||
|
@ -495,6 +491,9 @@ async def https(
|
||||||
path: str = "/dns-query",
|
path: str = "/dns-query",
|
||||||
post: bool = True,
|
post: bool = True,
|
||||||
verify: Union[bool, str] = True,
|
verify: Union[bool, str] = True,
|
||||||
|
bootstrap_address: Optional[str] = None,
|
||||||
|
resolver: Optional["dns.asyncresolver.Resolver"] = None,
|
||||||
|
family: Optional[int] = socket.AF_UNSPEC,
|
||||||
) -> dns.message.Message:
|
) -> dns.message.Message:
|
||||||
"""Return the response obtained after sending a query via DNS-over-HTTPS.
|
"""Return the response obtained after sending a query via DNS-over-HTTPS.
|
||||||
|
|
||||||
|
@ -508,8 +507,10 @@ async def https(
|
||||||
parameters, exceptions, and return type of this method.
|
parameters, exceptions, and return type of this method.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if not _have_httpx:
|
if not have_doh:
|
||||||
raise NoDOH("httpx is not available.") # pragma: no cover
|
raise NoDOH # pragma: no cover
|
||||||
|
if client and not isinstance(client, httpx.AsyncClient):
|
||||||
|
raise ValueError("session parameter must be an httpx.AsyncClient")
|
||||||
|
|
||||||
wire = q.to_wire()
|
wire = q.to_wire()
|
||||||
try:
|
try:
|
||||||
|
@ -518,15 +519,32 @@ async def https(
|
||||||
af = None
|
af = None
|
||||||
transport = None
|
transport = None
|
||||||
headers = {"accept": "application/dns-message"}
|
headers = {"accept": "application/dns-message"}
|
||||||
if af is not None:
|
if af is not None and dns.inet.is_address(where):
|
||||||
if af == socket.AF_INET:
|
if af == socket.AF_INET:
|
||||||
url = "https://{}:{}{}".format(where, port, path)
|
url = "https://{}:{}{}".format(where, port, path)
|
||||||
elif af == socket.AF_INET6:
|
elif af == socket.AF_INET6:
|
||||||
url = "https://[{}]:{}{}".format(where, port, path)
|
url = "https://[{}]:{}{}".format(where, port, path)
|
||||||
else:
|
else:
|
||||||
url = where
|
url = where
|
||||||
if source is not None:
|
|
||||||
transport = httpx.AsyncHTTPTransport(local_address=source[0])
|
backend = dns.asyncbackend.get_default_backend()
|
||||||
|
|
||||||
|
if source is None:
|
||||||
|
local_address = None
|
||||||
|
local_port = 0
|
||||||
|
else:
|
||||||
|
local_address = source
|
||||||
|
local_port = source_port
|
||||||
|
transport = backend.get_transport_class()(
|
||||||
|
local_address=local_address,
|
||||||
|
http1=True,
|
||||||
|
http2=_have_http2,
|
||||||
|
verify=verify,
|
||||||
|
local_port=local_port,
|
||||||
|
bootstrap_address=bootstrap_address,
|
||||||
|
resolver=resolver,
|
||||||
|
family=family,
|
||||||
|
)
|
||||||
|
|
||||||
if client:
|
if client:
|
||||||
cm: contextlib.AbstractAsyncContextManager = NullContext(client)
|
cm: contextlib.AbstractAsyncContextManager = NullContext(client)
|
||||||
|
@ -545,14 +563,14 @@ async def https(
|
||||||
"content-length": str(len(wire)),
|
"content-length": str(len(wire)),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
response = await the_client.post(
|
response = await backend.wait_for(
|
||||||
url, headers=headers, content=wire, timeout=timeout
|
the_client.post(url, headers=headers, content=wire), timeout
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
wire = base64.urlsafe_b64encode(wire).rstrip(b"=")
|
wire = base64.urlsafe_b64encode(wire).rstrip(b"=")
|
||||||
twire = wire.decode() # httpx does a repr() if we give it bytes
|
twire = wire.decode() # httpx does a repr() if we give it bytes
|
||||||
response = await the_client.get(
|
response = await backend.wait_for(
|
||||||
url, headers=headers, timeout=timeout, params={"dns": twire}
|
the_client.get(url, headers=headers, params={"dns": twire}), timeout
|
||||||
)
|
)
|
||||||
|
|
||||||
# see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH
|
# see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH
|
||||||
|
@ -690,6 +708,7 @@ async def quic(
|
||||||
connection: Optional[dns.quic.AsyncQuicConnection] = None,
|
connection: Optional[dns.quic.AsyncQuicConnection] = None,
|
||||||
verify: Union[bool, str] = True,
|
verify: Union[bool, str] = True,
|
||||||
backend: Optional[dns.asyncbackend.Backend] = None,
|
backend: Optional[dns.asyncbackend.Backend] = None,
|
||||||
|
server_hostname: Optional[str] = None,
|
||||||
) -> dns.message.Message:
|
) -> dns.message.Message:
|
||||||
"""Return the response obtained after sending an asynchronous query via
|
"""Return the response obtained after sending an asynchronous query via
|
||||||
DNS-over-QUIC.
|
DNS-over-QUIC.
|
||||||
|
@ -715,14 +734,16 @@ async def quic(
|
||||||
(cfactory, mfactory) = dns.quic.factories_for_backend(backend)
|
(cfactory, mfactory) = dns.quic.factories_for_backend(backend)
|
||||||
|
|
||||||
async with cfactory() as context:
|
async with cfactory() as context:
|
||||||
async with mfactory(context, verify_mode=verify) as the_manager:
|
async with mfactory(
|
||||||
|
context, verify_mode=verify, server_name=server_hostname
|
||||||
|
) as the_manager:
|
||||||
if not connection:
|
if not connection:
|
||||||
the_connection = the_manager.connect(where, port, source, source_port)
|
the_connection = the_manager.connect(where, port, source, source_port)
|
||||||
start = time.time()
|
(start, expiration) = _compute_times(timeout)
|
||||||
stream = await the_connection.make_stream()
|
stream = await the_connection.make_stream(timeout)
|
||||||
async with stream:
|
async with stream:
|
||||||
await stream.send(wire, True)
|
await stream.send(wire, True)
|
||||||
wire = await stream.receive(timeout)
|
wire = await stream.receive(_remaining(expiration))
|
||||||
finish = time.time()
|
finish = time.time()
|
||||||
r = dns.message.from_wire(
|
r = dns.message.from_wire(
|
||||||
wire,
|
wire,
|
||||||
|
|
|
@ -17,10 +17,11 @@
|
||||||
|
|
||||||
"""Asynchronous DNS stub resolver."""
|
"""Asynchronous DNS stub resolver."""
|
||||||
|
|
||||||
from typing import Any, Dict, Optional, Union
|
import socket
|
||||||
|
|
||||||
import time
|
import time
|
||||||
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
|
import dns._ddr
|
||||||
import dns.asyncbackend
|
import dns.asyncbackend
|
||||||
import dns.asyncquery
|
import dns.asyncquery
|
||||||
import dns.exception
|
import dns.exception
|
||||||
|
@ -31,8 +32,7 @@ import dns.rdatatype
|
||||||
import dns.resolver # lgtm[py/import-and-import-from]
|
import dns.resolver # lgtm[py/import-and-import-from]
|
||||||
|
|
||||||
# import some resolver symbols for brevity
|
# import some resolver symbols for brevity
|
||||||
from dns.resolver import NXDOMAIN, NoAnswer, NotAbsolute, NoRootSOA
|
from dns.resolver import NXDOMAIN, NoAnswer, NoRootSOA, NotAbsolute
|
||||||
|
|
||||||
|
|
||||||
# for indentation purposes below
|
# for indentation purposes below
|
||||||
_udp = dns.asyncquery.udp
|
_udp = dns.asyncquery.udp
|
||||||
|
@ -83,37 +83,19 @@ class Resolver(dns.resolver.BaseResolver):
|
||||||
assert request is not None # needed for type checking
|
assert request is not None # needed for type checking
|
||||||
done = False
|
done = False
|
||||||
while not done:
|
while not done:
|
||||||
(nameserver, port, tcp, backoff) = resolution.next_nameserver()
|
(nameserver, tcp, backoff) = resolution.next_nameserver()
|
||||||
if backoff:
|
if backoff:
|
||||||
await backend.sleep(backoff)
|
await backend.sleep(backoff)
|
||||||
timeout = self._compute_timeout(start, lifetime, resolution.errors)
|
timeout = self._compute_timeout(start, lifetime, resolution.errors)
|
||||||
try:
|
try:
|
||||||
if dns.inet.is_address(nameserver):
|
response = await nameserver.async_query(
|
||||||
if tcp:
|
request,
|
||||||
response = await _tcp(
|
timeout=timeout,
|
||||||
request,
|
source=source,
|
||||||
nameserver,
|
source_port=source_port,
|
||||||
timeout,
|
max_size=tcp,
|
||||||
port,
|
backend=backend,
|
||||||
source,
|
)
|
||||||
source_port,
|
|
||||||
backend=backend,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
response = await _udp(
|
|
||||||
request,
|
|
||||||
nameserver,
|
|
||||||
timeout,
|
|
||||||
port,
|
|
||||||
source,
|
|
||||||
source_port,
|
|
||||||
raise_on_truncation=True,
|
|
||||||
backend=backend,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
response = await dns.asyncquery.https(
|
|
||||||
request, nameserver, timeout=timeout
|
|
||||||
)
|
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
(_, done) = resolution.query_result(None, ex)
|
(_, done) = resolution.query_result(None, ex)
|
||||||
continue
|
continue
|
||||||
|
@ -153,6 +135,73 @@ class Resolver(dns.resolver.BaseResolver):
|
||||||
dns.reversename.from_address(ipaddr), *args, **modified_kwargs
|
dns.reversename.from_address(ipaddr), *args, **modified_kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def resolve_name(
|
||||||
|
self,
|
||||||
|
name: Union[dns.name.Name, str],
|
||||||
|
family: int = socket.AF_UNSPEC,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> dns.resolver.HostAnswers:
|
||||||
|
"""Use an asynchronous resolver to query for address records.
|
||||||
|
|
||||||
|
This utilizes the resolve() method to perform A and/or AAAA lookups on
|
||||||
|
the specified name.
|
||||||
|
|
||||||
|
*qname*, a ``dns.name.Name`` or ``str``, the name to resolve.
|
||||||
|
|
||||||
|
*family*, an ``int``, the address family. If socket.AF_UNSPEC
|
||||||
|
(the default), both A and AAAA records will be retrieved.
|
||||||
|
|
||||||
|
All other arguments that can be passed to the resolve() function
|
||||||
|
except for rdtype and rdclass are also supported by this
|
||||||
|
function.
|
||||||
|
"""
|
||||||
|
# We make a modified kwargs for type checking happiness, as otherwise
|
||||||
|
# we get a legit warning about possibly having rdtype and rdclass
|
||||||
|
# in the kwargs more than once.
|
||||||
|
modified_kwargs: Dict[str, Any] = {}
|
||||||
|
modified_kwargs.update(kwargs)
|
||||||
|
modified_kwargs.pop("rdtype", None)
|
||||||
|
modified_kwargs["rdclass"] = dns.rdataclass.IN
|
||||||
|
|
||||||
|
if family == socket.AF_INET:
|
||||||
|
v4 = await self.resolve(name, dns.rdatatype.A, **modified_kwargs)
|
||||||
|
return dns.resolver.HostAnswers.make(v4=v4)
|
||||||
|
elif family == socket.AF_INET6:
|
||||||
|
v6 = await self.resolve(name, dns.rdatatype.AAAA, **modified_kwargs)
|
||||||
|
return dns.resolver.HostAnswers.make(v6=v6)
|
||||||
|
elif family != socket.AF_UNSPEC:
|
||||||
|
raise NotImplementedError(f"unknown address family {family}")
|
||||||
|
|
||||||
|
raise_on_no_answer = modified_kwargs.pop("raise_on_no_answer", True)
|
||||||
|
lifetime = modified_kwargs.pop("lifetime", None)
|
||||||
|
start = time.time()
|
||||||
|
v6 = await self.resolve(
|
||||||
|
name,
|
||||||
|
dns.rdatatype.AAAA,
|
||||||
|
raise_on_no_answer=False,
|
||||||
|
lifetime=self._compute_timeout(start, lifetime),
|
||||||
|
**modified_kwargs,
|
||||||
|
)
|
||||||
|
# Note that setting name ensures we query the same name
|
||||||
|
# for A as we did for AAAA. (This is just in case search lists
|
||||||
|
# are active by default in the resolver configuration and
|
||||||
|
# we might be talking to a server that says NXDOMAIN when it
|
||||||
|
# wants to say NOERROR no data.
|
||||||
|
name = v6.qname
|
||||||
|
v4 = await self.resolve(
|
||||||
|
name,
|
||||||
|
dns.rdatatype.A,
|
||||||
|
raise_on_no_answer=False,
|
||||||
|
lifetime=self._compute_timeout(start, lifetime),
|
||||||
|
**modified_kwargs,
|
||||||
|
)
|
||||||
|
answers = dns.resolver.HostAnswers.make(
|
||||||
|
v6=v6, v4=v4, add_empty=not raise_on_no_answer
|
||||||
|
)
|
||||||
|
if not answers:
|
||||||
|
raise NoAnswer(response=v6.response)
|
||||||
|
return answers
|
||||||
|
|
||||||
# pylint: disable=redefined-outer-name
|
# pylint: disable=redefined-outer-name
|
||||||
|
|
||||||
async def canonical_name(self, name: Union[dns.name.Name, str]) -> dns.name.Name:
|
async def canonical_name(self, name: Union[dns.name.Name, str]) -> dns.name.Name:
|
||||||
|
@ -176,6 +225,37 @@ class Resolver(dns.resolver.BaseResolver):
|
||||||
canonical_name = e.canonical_name
|
canonical_name = e.canonical_name
|
||||||
return canonical_name
|
return canonical_name
|
||||||
|
|
||||||
|
async def try_ddr(self, lifetime: float = 5.0) -> None:
|
||||||
|
"""Try to update the resolver's nameservers using Discovery of Designated
|
||||||
|
Resolvers (DDR). If successful, the resolver will subsequently use
|
||||||
|
DNS-over-HTTPS or DNS-over-TLS for future queries.
|
||||||
|
|
||||||
|
*lifetime*, a float, is the maximum time to spend attempting DDR. The default
|
||||||
|
is 5 seconds.
|
||||||
|
|
||||||
|
If the SVCB query is successful and results in a non-empty list of nameservers,
|
||||||
|
then the resolver's nameservers are set to the returned servers in priority
|
||||||
|
order.
|
||||||
|
|
||||||
|
The current implementation does not use any address hints from the SVCB record,
|
||||||
|
nor does it resolve addresses for the SCVB target name, rather it assumes that
|
||||||
|
the bootstrap nameserver will always be one of the addresses and uses it.
|
||||||
|
A future revision to the code may offer fuller support. The code verifies that
|
||||||
|
the bootstrap nameserver is in the Subject Alternative Name field of the
|
||||||
|
TLS certficate.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
expiration = time.time() + lifetime
|
||||||
|
answer = await self.resolve(
|
||||||
|
dns._ddr._local_resolver_name, "svcb", lifetime=lifetime
|
||||||
|
)
|
||||||
|
timeout = dns.query._remaining(expiration)
|
||||||
|
nameservers = await dns._ddr._get_nameservers_async(answer, timeout)
|
||||||
|
if len(nameservers) > 0:
|
||||||
|
self.nameservers = nameservers
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
default_resolver = None
|
default_resolver = None
|
||||||
|
|
||||||
|
@ -246,6 +326,18 @@ async def resolve_address(
|
||||||
return await get_default_resolver().resolve_address(ipaddr, *args, **kwargs)
|
return await get_default_resolver().resolve_address(ipaddr, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
async def resolve_name(
|
||||||
|
name: Union[dns.name.Name, str], family: int = socket.AF_UNSPEC, **kwargs: Any
|
||||||
|
) -> dns.resolver.HostAnswers:
|
||||||
|
"""Use a resolver to asynchronously query for address records.
|
||||||
|
|
||||||
|
See :py:func:`dns.asyncresolver.Resolver.resolve_name` for more
|
||||||
|
information on the parameters.
|
||||||
|
"""
|
||||||
|
|
||||||
|
return await get_default_resolver().resolve_name(name, family, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
async def canonical_name(name: Union[dns.name.Name, str]) -> dns.name.Name:
|
async def canonical_name(name: Union[dns.name.Name, str]) -> dns.name.Name:
|
||||||
"""Determine the canonical name of *name*.
|
"""Determine the canonical name of *name*.
|
||||||
|
|
||||||
|
@ -256,6 +348,16 @@ async def canonical_name(name: Union[dns.name.Name, str]) -> dns.name.Name:
|
||||||
return await get_default_resolver().canonical_name(name)
|
return await get_default_resolver().canonical_name(name)
|
||||||
|
|
||||||
|
|
||||||
|
async def try_ddr(timeout: float = 5.0) -> None:
|
||||||
|
"""Try to update the default resolver's nameservers using Discovery of Designated
|
||||||
|
Resolvers (DDR). If successful, the resolver will subsequently use
|
||||||
|
DNS-over-HTTPS or DNS-over-TLS for future queries.
|
||||||
|
|
||||||
|
See :py:func:`dns.resolver.Resolver.try_ddr` for more information.
|
||||||
|
"""
|
||||||
|
return await get_default_resolver().try_ddr(timeout)
|
||||||
|
|
||||||
|
|
||||||
async def zone_for_name(
|
async def zone_for_name(
|
||||||
name: Union[dns.name.Name, str],
|
name: Union[dns.name.Name, str],
|
||||||
rdclass: dns.rdataclass.RdataClass = dns.rdataclass.IN,
|
rdclass: dns.rdataclass.RdataClass = dns.rdataclass.IN,
|
||||||
|
@ -290,3 +392,84 @@ async def zone_for_name(
|
||||||
name = name.parent()
|
name = name.parent()
|
||||||
except dns.name.NoParent: # pragma: no cover
|
except dns.name.NoParent: # pragma: no cover
|
||||||
raise NoRootSOA
|
raise NoRootSOA
|
||||||
|
|
||||||
|
|
||||||
|
async def make_resolver_at(
|
||||||
|
where: Union[dns.name.Name, str],
|
||||||
|
port: int = 53,
|
||||||
|
family: int = socket.AF_UNSPEC,
|
||||||
|
resolver: Optional[Resolver] = None,
|
||||||
|
) -> Resolver:
|
||||||
|
"""Make a stub resolver using the specified destination as the full resolver.
|
||||||
|
|
||||||
|
*where*, a ``dns.name.Name`` or ``str`` the domain name or IP address of the
|
||||||
|
full resolver.
|
||||||
|
|
||||||
|
*port*, an ``int``, the port to use. If not specified, the default is 53.
|
||||||
|
|
||||||
|
*family*, an ``int``, the address family to use. This parameter is used if
|
||||||
|
*where* is not an address. The default is ``socket.AF_UNSPEC`` in which case
|
||||||
|
the first address returned by ``resolve_name()`` will be used, otherwise the
|
||||||
|
first address of the specified family will be used.
|
||||||
|
|
||||||
|
*resolver*, a ``dns.asyncresolver.Resolver`` or ``None``, the resolver to use for
|
||||||
|
resolution of hostnames. If not specified, the default resolver will be used.
|
||||||
|
|
||||||
|
Returns a ``dns.resolver.Resolver`` or raises an exception.
|
||||||
|
"""
|
||||||
|
if resolver is None:
|
||||||
|
resolver = get_default_resolver()
|
||||||
|
nameservers: List[Union[str, dns.nameserver.Nameserver]] = []
|
||||||
|
if isinstance(where, str) and dns.inet.is_address(where):
|
||||||
|
nameservers.append(dns.nameserver.Do53Nameserver(where, port))
|
||||||
|
else:
|
||||||
|
answers = await resolver.resolve_name(where, family)
|
||||||
|
for address in answers.addresses():
|
||||||
|
nameservers.append(dns.nameserver.Do53Nameserver(address, port))
|
||||||
|
res = dns.asyncresolver.Resolver(configure=False)
|
||||||
|
res.nameservers = nameservers
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
async def resolve_at(
|
||||||
|
where: Union[dns.name.Name, str],
|
||||||
|
qname: Union[dns.name.Name, str],
|
||||||
|
rdtype: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.A,
|
||||||
|
rdclass: Union[dns.rdataclass.RdataClass, str] = dns.rdataclass.IN,
|
||||||
|
tcp: bool = False,
|
||||||
|
source: Optional[str] = None,
|
||||||
|
raise_on_no_answer: bool = True,
|
||||||
|
source_port: int = 0,
|
||||||
|
lifetime: Optional[float] = None,
|
||||||
|
search: Optional[bool] = None,
|
||||||
|
backend: Optional[dns.asyncbackend.Backend] = None,
|
||||||
|
port: int = 53,
|
||||||
|
family: int = socket.AF_UNSPEC,
|
||||||
|
resolver: Optional[Resolver] = None,
|
||||||
|
) -> dns.resolver.Answer:
|
||||||
|
"""Query nameservers to find the answer to the question.
|
||||||
|
|
||||||
|
This is a convenience function that calls ``dns.asyncresolver.make_resolver_at()``
|
||||||
|
to make a resolver, and then uses it to resolve the query.
|
||||||
|
|
||||||
|
See ``dns.asyncresolver.Resolver.resolve`` for more information on the resolution
|
||||||
|
parameters, and ``dns.asyncresolver.make_resolver_at`` for information about the
|
||||||
|
resolver parameters *where*, *port*, *family*, and *resolver*.
|
||||||
|
|
||||||
|
If making more than one query, it is more efficient to call
|
||||||
|
``dns.asyncresolver.make_resolver_at()`` and then use that resolver for the queries
|
||||||
|
instead of calling ``resolve_at()`` multiple times.
|
||||||
|
"""
|
||||||
|
res = await make_resolver_at(where, port, family, resolver)
|
||||||
|
return await res.resolve(
|
||||||
|
qname,
|
||||||
|
rdtype,
|
||||||
|
rdclass,
|
||||||
|
tcp,
|
||||||
|
source,
|
||||||
|
raise_on_no_answer,
|
||||||
|
source_port,
|
||||||
|
lifetime,
|
||||||
|
search,
|
||||||
|
backend,
|
||||||
|
)
|
||||||
|
|
|
@ -17,50 +17,44 @@
|
||||||
|
|
||||||
"""Common DNSSEC-related functions and constants."""
|
"""Common DNSSEC-related functions and constants."""
|
||||||
|
|
||||||
from typing import Any, cast, Dict, List, Optional, Set, Tuple, Union
|
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import contextlib
|
||||||
|
import functools
|
||||||
import hashlib
|
import hashlib
|
||||||
import math
|
|
||||||
import struct
|
import struct
|
||||||
import time
|
import time
|
||||||
import base64
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from typing import Callable, Dict, List, Optional, Set, Tuple, Union, cast
|
||||||
from dns.dnssectypes import Algorithm, DSDigest, NSEC3Hash
|
|
||||||
|
|
||||||
import dns.exception
|
import dns.exception
|
||||||
import dns.name
|
import dns.name
|
||||||
import dns.node
|
import dns.node
|
||||||
import dns.rdataset
|
|
||||||
import dns.rdata
|
import dns.rdata
|
||||||
import dns.rdatatype
|
|
||||||
import dns.rdataclass
|
import dns.rdataclass
|
||||||
|
import dns.rdataset
|
||||||
|
import dns.rdatatype
|
||||||
import dns.rrset
|
import dns.rrset
|
||||||
|
import dns.transaction
|
||||||
|
import dns.zone
|
||||||
|
from dns.dnssectypes import Algorithm, DSDigest, NSEC3Hash
|
||||||
|
from dns.exception import ( # pylint: disable=W0611
|
||||||
|
AlgorithmKeyMismatch,
|
||||||
|
DeniedByPolicy,
|
||||||
|
UnsupportedAlgorithm,
|
||||||
|
ValidationFailure,
|
||||||
|
)
|
||||||
from dns.rdtypes.ANY.CDNSKEY import CDNSKEY
|
from dns.rdtypes.ANY.CDNSKEY import CDNSKEY
|
||||||
from dns.rdtypes.ANY.CDS import CDS
|
from dns.rdtypes.ANY.CDS import CDS
|
||||||
from dns.rdtypes.ANY.DNSKEY import DNSKEY
|
from dns.rdtypes.ANY.DNSKEY import DNSKEY
|
||||||
from dns.rdtypes.ANY.DS import DS
|
from dns.rdtypes.ANY.DS import DS
|
||||||
|
from dns.rdtypes.ANY.NSEC import NSEC, Bitmap
|
||||||
|
from dns.rdtypes.ANY.NSEC3PARAM import NSEC3PARAM
|
||||||
from dns.rdtypes.ANY.RRSIG import RRSIG, sigtime_to_posixtime
|
from dns.rdtypes.ANY.RRSIG import RRSIG, sigtime_to_posixtime
|
||||||
from dns.rdtypes.dnskeybase import Flag
|
from dns.rdtypes.dnskeybase import Flag
|
||||||
|
|
||||||
|
|
||||||
class UnsupportedAlgorithm(dns.exception.DNSException):
|
|
||||||
"""The DNSSEC algorithm is not supported."""
|
|
||||||
|
|
||||||
|
|
||||||
class AlgorithmKeyMismatch(UnsupportedAlgorithm):
|
|
||||||
"""The DNSSEC algorithm is not supported for the given key type."""
|
|
||||||
|
|
||||||
|
|
||||||
class ValidationFailure(dns.exception.DNSException):
|
|
||||||
"""The DNSSEC signature is invalid."""
|
|
||||||
|
|
||||||
|
|
||||||
class DeniedByPolicy(dns.exception.DNSException):
|
|
||||||
"""Denied by DNSSEC policy."""
|
|
||||||
|
|
||||||
|
|
||||||
PublicKey = Union[
|
PublicKey = Union[
|
||||||
|
"GenericPublicKey",
|
||||||
"rsa.RSAPublicKey",
|
"rsa.RSAPublicKey",
|
||||||
"ec.EllipticCurvePublicKey",
|
"ec.EllipticCurvePublicKey",
|
||||||
"ed25519.Ed25519PublicKey",
|
"ed25519.Ed25519PublicKey",
|
||||||
|
@ -68,12 +62,15 @@ PublicKey = Union[
|
||||||
]
|
]
|
||||||
|
|
||||||
PrivateKey = Union[
|
PrivateKey = Union[
|
||||||
|
"GenericPrivateKey",
|
||||||
"rsa.RSAPrivateKey",
|
"rsa.RSAPrivateKey",
|
||||||
"ec.EllipticCurvePrivateKey",
|
"ec.EllipticCurvePrivateKey",
|
||||||
"ed25519.Ed25519PrivateKey",
|
"ed25519.Ed25519PrivateKey",
|
||||||
"ed448.Ed448PrivateKey",
|
"ed448.Ed448PrivateKey",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
RRsetSigner = Callable[[dns.transaction.Transaction, dns.rrset.RRset], None]
|
||||||
|
|
||||||
|
|
||||||
def algorithm_from_text(text: str) -> Algorithm:
|
def algorithm_from_text(text: str) -> Algorithm:
|
||||||
"""Convert text into a DNSSEC algorithm value.
|
"""Convert text into a DNSSEC algorithm value.
|
||||||
|
@ -308,113 +305,13 @@ def _find_candidate_keys(
|
||||||
return [
|
return [
|
||||||
cast(DNSKEY, rd)
|
cast(DNSKEY, rd)
|
||||||
for rd in rdataset
|
for rd in rdataset
|
||||||
if rd.algorithm == rrsig.algorithm and key_id(rd) == rrsig.key_tag
|
if rd.algorithm == rrsig.algorithm
|
||||||
|
and key_id(rd) == rrsig.key_tag
|
||||||
|
and (rd.flags & Flag.ZONE) == Flag.ZONE # RFC 4034 2.1.1
|
||||||
|
and rd.protocol == 3 # RFC 4034 2.1.2
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def _is_rsa(algorithm: int) -> bool:
|
|
||||||
return algorithm in (
|
|
||||||
Algorithm.RSAMD5,
|
|
||||||
Algorithm.RSASHA1,
|
|
||||||
Algorithm.RSASHA1NSEC3SHA1,
|
|
||||||
Algorithm.RSASHA256,
|
|
||||||
Algorithm.RSASHA512,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _is_dsa(algorithm: int) -> bool:
|
|
||||||
return algorithm in (Algorithm.DSA, Algorithm.DSANSEC3SHA1)
|
|
||||||
|
|
||||||
|
|
||||||
def _is_ecdsa(algorithm: int) -> bool:
|
|
||||||
return algorithm in (Algorithm.ECDSAP256SHA256, Algorithm.ECDSAP384SHA384)
|
|
||||||
|
|
||||||
|
|
||||||
def _is_eddsa(algorithm: int) -> bool:
|
|
||||||
return algorithm in (Algorithm.ED25519, Algorithm.ED448)
|
|
||||||
|
|
||||||
|
|
||||||
def _is_gost(algorithm: int) -> bool:
|
|
||||||
return algorithm == Algorithm.ECCGOST
|
|
||||||
|
|
||||||
|
|
||||||
def _is_md5(algorithm: int) -> bool:
|
|
||||||
return algorithm == Algorithm.RSAMD5
|
|
||||||
|
|
||||||
|
|
||||||
def _is_sha1(algorithm: int) -> bool:
|
|
||||||
return algorithm in (
|
|
||||||
Algorithm.DSA,
|
|
||||||
Algorithm.RSASHA1,
|
|
||||||
Algorithm.DSANSEC3SHA1,
|
|
||||||
Algorithm.RSASHA1NSEC3SHA1,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _is_sha256(algorithm: int) -> bool:
|
|
||||||
return algorithm in (Algorithm.RSASHA256, Algorithm.ECDSAP256SHA256)
|
|
||||||
|
|
||||||
|
|
||||||
def _is_sha384(algorithm: int) -> bool:
|
|
||||||
return algorithm == Algorithm.ECDSAP384SHA384
|
|
||||||
|
|
||||||
|
|
||||||
def _is_sha512(algorithm: int) -> bool:
|
|
||||||
return algorithm == Algorithm.RSASHA512
|
|
||||||
|
|
||||||
|
|
||||||
def _ensure_algorithm_key_combination(algorithm: int, key: PublicKey) -> None:
|
|
||||||
"""Ensure algorithm is valid for key type, throwing an exception on
|
|
||||||
mismatch."""
|
|
||||||
if isinstance(key, rsa.RSAPublicKey):
|
|
||||||
if _is_rsa(algorithm):
|
|
||||||
return
|
|
||||||
raise AlgorithmKeyMismatch('algorithm "%s" not valid for RSA key' % algorithm)
|
|
||||||
if isinstance(key, dsa.DSAPublicKey):
|
|
||||||
if _is_dsa(algorithm):
|
|
||||||
return
|
|
||||||
raise AlgorithmKeyMismatch('algorithm "%s" not valid for DSA key' % algorithm)
|
|
||||||
if isinstance(key, ec.EllipticCurvePublicKey):
|
|
||||||
if _is_ecdsa(algorithm):
|
|
||||||
return
|
|
||||||
raise AlgorithmKeyMismatch('algorithm "%s" not valid for ECDSA key' % algorithm)
|
|
||||||
if isinstance(key, ed25519.Ed25519PublicKey):
|
|
||||||
if algorithm == Algorithm.ED25519:
|
|
||||||
return
|
|
||||||
raise AlgorithmKeyMismatch(
|
|
||||||
'algorithm "%s" not valid for ED25519 key' % algorithm
|
|
||||||
)
|
|
||||||
if isinstance(key, ed448.Ed448PublicKey):
|
|
||||||
if algorithm == Algorithm.ED448:
|
|
||||||
return
|
|
||||||
raise AlgorithmKeyMismatch('algorithm "%s" not valid for ED448 key' % algorithm)
|
|
||||||
|
|
||||||
raise TypeError("unsupported key type")
|
|
||||||
|
|
||||||
|
|
||||||
def _make_hash(algorithm: int) -> Any:
|
|
||||||
if _is_md5(algorithm):
|
|
||||||
return hashes.MD5()
|
|
||||||
if _is_sha1(algorithm):
|
|
||||||
return hashes.SHA1()
|
|
||||||
if _is_sha256(algorithm):
|
|
||||||
return hashes.SHA256()
|
|
||||||
if _is_sha384(algorithm):
|
|
||||||
return hashes.SHA384()
|
|
||||||
if _is_sha512(algorithm):
|
|
||||||
return hashes.SHA512()
|
|
||||||
if algorithm == Algorithm.ED25519:
|
|
||||||
return hashes.SHA512()
|
|
||||||
if algorithm == Algorithm.ED448:
|
|
||||||
return hashes.SHAKE256(114)
|
|
||||||
|
|
||||||
raise ValidationFailure("unknown hash for algorithm %u" % algorithm)
|
|
||||||
|
|
||||||
|
|
||||||
def _bytes_to_long(b: bytes) -> int:
|
|
||||||
return int.from_bytes(b, "big")
|
|
||||||
|
|
||||||
|
|
||||||
def _get_rrname_rdataset(
|
def _get_rrname_rdataset(
|
||||||
rrset: Union[dns.rrset.RRset, Tuple[dns.name.Name, dns.rdataset.Rdataset]],
|
rrset: Union[dns.rrset.RRset, Tuple[dns.name.Name, dns.rdataset.Rdataset]],
|
||||||
) -> Tuple[dns.name.Name, dns.rdataset.Rdataset]:
|
) -> Tuple[dns.name.Name, dns.rdataset.Rdataset]:
|
||||||
|
@ -424,85 +321,13 @@ def _get_rrname_rdataset(
|
||||||
return rrset.name, rrset
|
return rrset.name, rrset
|
||||||
|
|
||||||
|
|
||||||
def _validate_signature(sig: bytes, data: bytes, key: DNSKEY, chosen_hash: Any) -> None:
|
def _validate_signature(sig: bytes, data: bytes, key: DNSKEY) -> None:
|
||||||
keyptr: bytes
|
public_cls = get_algorithm_cls_from_dnskey(key).public_cls
|
||||||
if _is_rsa(key.algorithm):
|
try:
|
||||||
# we ignore because mypy is confused and thinks key.key is a str for unknown
|
public_key = public_cls.from_dnskey(key)
|
||||||
# reasons.
|
except ValueError:
|
||||||
keyptr = key.key
|
raise ValidationFailure("invalid public key")
|
||||||
(bytes_,) = struct.unpack("!B", keyptr[0:1])
|
public_key.verify(sig, data)
|
||||||
keyptr = keyptr[1:]
|
|
||||||
if bytes_ == 0:
|
|
||||||
(bytes_,) = struct.unpack("!H", keyptr[0:2])
|
|
||||||
keyptr = keyptr[2:]
|
|
||||||
rsa_e = keyptr[0:bytes_]
|
|
||||||
rsa_n = keyptr[bytes_:]
|
|
||||||
try:
|
|
||||||
rsa_public_key = rsa.RSAPublicNumbers(
|
|
||||||
_bytes_to_long(rsa_e), _bytes_to_long(rsa_n)
|
|
||||||
).public_key(default_backend())
|
|
||||||
except ValueError:
|
|
||||||
raise ValidationFailure("invalid public key")
|
|
||||||
rsa_public_key.verify(sig, data, padding.PKCS1v15(), chosen_hash)
|
|
||||||
elif _is_dsa(key.algorithm):
|
|
||||||
keyptr = key.key
|
|
||||||
(t,) = struct.unpack("!B", keyptr[0:1])
|
|
||||||
keyptr = keyptr[1:]
|
|
||||||
octets = 64 + t * 8
|
|
||||||
dsa_q = keyptr[0:20]
|
|
||||||
keyptr = keyptr[20:]
|
|
||||||
dsa_p = keyptr[0:octets]
|
|
||||||
keyptr = keyptr[octets:]
|
|
||||||
dsa_g = keyptr[0:octets]
|
|
||||||
keyptr = keyptr[octets:]
|
|
||||||
dsa_y = keyptr[0:octets]
|
|
||||||
try:
|
|
||||||
dsa_public_key = dsa.DSAPublicNumbers( # type: ignore
|
|
||||||
_bytes_to_long(dsa_y),
|
|
||||||
dsa.DSAParameterNumbers(
|
|
||||||
_bytes_to_long(dsa_p), _bytes_to_long(dsa_q), _bytes_to_long(dsa_g)
|
|
||||||
),
|
|
||||||
).public_key(default_backend())
|
|
||||||
except ValueError:
|
|
||||||
raise ValidationFailure("invalid public key")
|
|
||||||
dsa_public_key.verify(sig, data, chosen_hash)
|
|
||||||
elif _is_ecdsa(key.algorithm):
|
|
||||||
keyptr = key.key
|
|
||||||
curve: Any
|
|
||||||
if key.algorithm == Algorithm.ECDSAP256SHA256:
|
|
||||||
curve = ec.SECP256R1()
|
|
||||||
octets = 32
|
|
||||||
else:
|
|
||||||
curve = ec.SECP384R1()
|
|
||||||
octets = 48
|
|
||||||
ecdsa_x = keyptr[0:octets]
|
|
||||||
ecdsa_y = keyptr[octets : octets * 2]
|
|
||||||
try:
|
|
||||||
ecdsa_public_key = ec.EllipticCurvePublicNumbers(
|
|
||||||
curve=curve, x=_bytes_to_long(ecdsa_x), y=_bytes_to_long(ecdsa_y)
|
|
||||||
).public_key(default_backend())
|
|
||||||
except ValueError:
|
|
||||||
raise ValidationFailure("invalid public key")
|
|
||||||
ecdsa_public_key.verify(sig, data, ec.ECDSA(chosen_hash))
|
|
||||||
elif _is_eddsa(key.algorithm):
|
|
||||||
keyptr = key.key
|
|
||||||
loader: Any
|
|
||||||
if key.algorithm == Algorithm.ED25519:
|
|
||||||
loader = ed25519.Ed25519PublicKey
|
|
||||||
else:
|
|
||||||
loader = ed448.Ed448PublicKey
|
|
||||||
try:
|
|
||||||
eddsa_public_key = loader.from_public_bytes(keyptr)
|
|
||||||
except ValueError:
|
|
||||||
raise ValidationFailure("invalid public key")
|
|
||||||
eddsa_public_key.verify(sig, data)
|
|
||||||
elif _is_gost(key.algorithm):
|
|
||||||
raise UnsupportedAlgorithm(
|
|
||||||
'algorithm "%s" not supported by dnspython'
|
|
||||||
% algorithm_to_text(key.algorithm)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise ValidationFailure("unknown algorithm %u" % key.algorithm)
|
|
||||||
|
|
||||||
|
|
||||||
def _validate_rrsig(
|
def _validate_rrsig(
|
||||||
|
@ -559,29 +384,13 @@ def _validate_rrsig(
|
||||||
if rrsig.inception > now:
|
if rrsig.inception > now:
|
||||||
raise ValidationFailure("not yet valid")
|
raise ValidationFailure("not yet valid")
|
||||||
|
|
||||||
if _is_dsa(rrsig.algorithm):
|
|
||||||
sig_r = rrsig.signature[1:21]
|
|
||||||
sig_s = rrsig.signature[21:]
|
|
||||||
sig = utils.encode_dss_signature(_bytes_to_long(sig_r), _bytes_to_long(sig_s))
|
|
||||||
elif _is_ecdsa(rrsig.algorithm):
|
|
||||||
if rrsig.algorithm == Algorithm.ECDSAP256SHA256:
|
|
||||||
octets = 32
|
|
||||||
else:
|
|
||||||
octets = 48
|
|
||||||
sig_r = rrsig.signature[0:octets]
|
|
||||||
sig_s = rrsig.signature[octets:]
|
|
||||||
sig = utils.encode_dss_signature(_bytes_to_long(sig_r), _bytes_to_long(sig_s))
|
|
||||||
else:
|
|
||||||
sig = rrsig.signature
|
|
||||||
|
|
||||||
data = _make_rrsig_signature_data(rrset, rrsig, origin)
|
data = _make_rrsig_signature_data(rrset, rrsig, origin)
|
||||||
chosen_hash = _make_hash(rrsig.algorithm)
|
|
||||||
|
|
||||||
for candidate_key in candidate_keys:
|
for candidate_key in candidate_keys:
|
||||||
if not policy.ok_to_validate(candidate_key):
|
if not policy.ok_to_validate(candidate_key):
|
||||||
continue
|
continue
|
||||||
try:
|
try:
|
||||||
_validate_signature(sig, data, candidate_key, chosen_hash)
|
_validate_signature(rrsig.signature, data, candidate_key)
|
||||||
return
|
return
|
||||||
except (InvalidSignature, ValidationFailure):
|
except (InvalidSignature, ValidationFailure):
|
||||||
# this happens on an individual validation failure
|
# this happens on an individual validation failure
|
||||||
|
@ -673,6 +482,7 @@ def _sign(
|
||||||
lifetime: Optional[int] = None,
|
lifetime: Optional[int] = None,
|
||||||
verify: bool = False,
|
verify: bool = False,
|
||||||
policy: Optional[Policy] = None,
|
policy: Optional[Policy] = None,
|
||||||
|
origin: Optional[dns.name.Name] = None,
|
||||||
) -> RRSIG:
|
) -> RRSIG:
|
||||||
"""Sign RRset using private key.
|
"""Sign RRset using private key.
|
||||||
|
|
||||||
|
@ -708,6 +518,10 @@ def _sign(
|
||||||
*policy*, a ``dns.dnssec.Policy`` or ``None``. If ``None``, the default policy,
|
*policy*, a ``dns.dnssec.Policy`` or ``None``. If ``None``, the default policy,
|
||||||
``dns.dnssec.default_policy`` is used; this policy defaults to that of RFC 8624.
|
``dns.dnssec.default_policy`` is used; this policy defaults to that of RFC 8624.
|
||||||
|
|
||||||
|
*origin*, a ``dns.name.Name`` or ``None``. If ``None``, the default, then all
|
||||||
|
names in the rrset (including its owner name) must be absolute; otherwise the
|
||||||
|
specified origin will be used to make names absolute when signing.
|
||||||
|
|
||||||
Raises ``DeniedByPolicy`` if the signature is denied by policy.
|
Raises ``DeniedByPolicy`` if the signature is denied by policy.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@ -735,16 +549,26 @@ def _sign(
|
||||||
if expiration is not None:
|
if expiration is not None:
|
||||||
rrsig_expiration = to_timestamp(expiration)
|
rrsig_expiration = to_timestamp(expiration)
|
||||||
elif lifetime is not None:
|
elif lifetime is not None:
|
||||||
rrsig_expiration = int(time.time()) + lifetime
|
rrsig_expiration = rrsig_inception + lifetime
|
||||||
else:
|
else:
|
||||||
raise ValueError("expiration or lifetime must be specified")
|
raise ValueError("expiration or lifetime must be specified")
|
||||||
|
|
||||||
|
# Derelativize now because we need a correct labels length for the
|
||||||
|
# rrsig_template.
|
||||||
|
if origin is not None:
|
||||||
|
rrname = rrname.derelativize(origin)
|
||||||
|
labels = len(rrname) - 1
|
||||||
|
|
||||||
|
# Adjust labels appropriately for wildcards.
|
||||||
|
if rrname.is_wild():
|
||||||
|
labels -= 1
|
||||||
|
|
||||||
rrsig_template = RRSIG(
|
rrsig_template = RRSIG(
|
||||||
rdclass=rdclass,
|
rdclass=rdclass,
|
||||||
rdtype=dns.rdatatype.RRSIG,
|
rdtype=dns.rdatatype.RRSIG,
|
||||||
type_covered=rdtype,
|
type_covered=rdtype,
|
||||||
algorithm=dnskey.algorithm,
|
algorithm=dnskey.algorithm,
|
||||||
labels=len(rrname) - 1,
|
labels=labels,
|
||||||
original_ttl=original_ttl,
|
original_ttl=original_ttl,
|
||||||
expiration=rrsig_expiration,
|
expiration=rrsig_expiration,
|
||||||
inception=rrsig_inception,
|
inception=rrsig_inception,
|
||||||
|
@ -753,63 +577,18 @@ def _sign(
|
||||||
signature=b"",
|
signature=b"",
|
||||||
)
|
)
|
||||||
|
|
||||||
data = dns.dnssec._make_rrsig_signature_data(rrset, rrsig_template)
|
data = dns.dnssec._make_rrsig_signature_data(rrset, rrsig_template, origin)
|
||||||
chosen_hash = _make_hash(rrsig_template.algorithm)
|
|
||||||
signature = None
|
|
||||||
|
|
||||||
if isinstance(private_key, rsa.RSAPrivateKey):
|
if isinstance(private_key, GenericPrivateKey):
|
||||||
if not _is_rsa(dnskey.algorithm):
|
signing_key = private_key
|
||||||
raise ValueError("Invalid DNSKEY algorithm for RSA key")
|
|
||||||
signature = private_key.sign(data, padding.PKCS1v15(), chosen_hash)
|
|
||||||
if verify:
|
|
||||||
private_key.public_key().verify(
|
|
||||||
signature, data, padding.PKCS1v15(), chosen_hash
|
|
||||||
)
|
|
||||||
elif isinstance(private_key, dsa.DSAPrivateKey):
|
|
||||||
if not _is_dsa(dnskey.algorithm):
|
|
||||||
raise ValueError("Invalid DNSKEY algorithm for DSA key")
|
|
||||||
public_dsa_key = private_key.public_key()
|
|
||||||
if public_dsa_key.key_size > 1024:
|
|
||||||
raise ValueError("DSA key size overflow")
|
|
||||||
der_signature = private_key.sign(data, chosen_hash)
|
|
||||||
if verify:
|
|
||||||
public_dsa_key.verify(der_signature, data, chosen_hash)
|
|
||||||
dsa_r, dsa_s = utils.decode_dss_signature(der_signature)
|
|
||||||
dsa_t = (public_dsa_key.key_size // 8 - 64) // 8
|
|
||||||
octets = 20
|
|
||||||
signature = (
|
|
||||||
struct.pack("!B", dsa_t)
|
|
||||||
+ int.to_bytes(dsa_r, length=octets, byteorder="big")
|
|
||||||
+ int.to_bytes(dsa_s, length=octets, byteorder="big")
|
|
||||||
)
|
|
||||||
elif isinstance(private_key, ec.EllipticCurvePrivateKey):
|
|
||||||
if not _is_ecdsa(dnskey.algorithm):
|
|
||||||
raise ValueError("Invalid DNSKEY algorithm for EC key")
|
|
||||||
der_signature = private_key.sign(data, ec.ECDSA(chosen_hash))
|
|
||||||
if verify:
|
|
||||||
private_key.public_key().verify(der_signature, data, ec.ECDSA(chosen_hash))
|
|
||||||
if dnskey.algorithm == Algorithm.ECDSAP256SHA256:
|
|
||||||
octets = 32
|
|
||||||
else:
|
|
||||||
octets = 48
|
|
||||||
dsa_r, dsa_s = utils.decode_dss_signature(der_signature)
|
|
||||||
signature = int.to_bytes(dsa_r, length=octets, byteorder="big") + int.to_bytes(
|
|
||||||
dsa_s, length=octets, byteorder="big"
|
|
||||||
)
|
|
||||||
elif isinstance(private_key, ed25519.Ed25519PrivateKey):
|
|
||||||
if dnskey.algorithm != Algorithm.ED25519:
|
|
||||||
raise ValueError("Invalid DNSKEY algorithm for ED25519 key")
|
|
||||||
signature = private_key.sign(data)
|
|
||||||
if verify:
|
|
||||||
private_key.public_key().verify(signature, data)
|
|
||||||
elif isinstance(private_key, ed448.Ed448PrivateKey):
|
|
||||||
if dnskey.algorithm != Algorithm.ED448:
|
|
||||||
raise ValueError("Invalid DNSKEY algorithm for ED448 key")
|
|
||||||
signature = private_key.sign(data)
|
|
||||||
if verify:
|
|
||||||
private_key.public_key().verify(signature, data)
|
|
||||||
else:
|
else:
|
||||||
raise TypeError("Unsupported key algorithm")
|
try:
|
||||||
|
private_cls = get_algorithm_cls_from_dnskey(dnskey)
|
||||||
|
signing_key = private_cls(key=private_key)
|
||||||
|
except UnsupportedAlgorithm:
|
||||||
|
raise TypeError("Unsupported key algorithm")
|
||||||
|
|
||||||
|
signature = signing_key.sign(data, verify)
|
||||||
|
|
||||||
return cast(RRSIG, rrsig_template.replace(signature=signature))
|
return cast(RRSIG, rrsig_template.replace(signature=signature))
|
||||||
|
|
||||||
|
@ -858,9 +637,12 @@ def _make_rrsig_signature_data(
|
||||||
raise ValidationFailure("relative RR name without an origin specified")
|
raise ValidationFailure("relative RR name without an origin specified")
|
||||||
rrname = rrname.derelativize(origin)
|
rrname = rrname.derelativize(origin)
|
||||||
|
|
||||||
if len(rrname) - 1 < rrsig.labels:
|
name_len = len(rrname)
|
||||||
|
if rrname.is_wild() and rrsig.labels != name_len - 2:
|
||||||
|
raise ValidationFailure("wild owner name has wrong label length")
|
||||||
|
if name_len - 1 < rrsig.labels:
|
||||||
raise ValidationFailure("owner name longer than RRSIG labels")
|
raise ValidationFailure("owner name longer than RRSIG labels")
|
||||||
elif rrsig.labels < len(rrname) - 1:
|
elif rrsig.labels < name_len - 1:
|
||||||
suffix = rrname.split(rrsig.labels + 1)[1]
|
suffix = rrname.split(rrsig.labels + 1)[1]
|
||||||
rrname = dns.name.from_text("*", suffix)
|
rrname = dns.name.from_text("*", suffix)
|
||||||
rrnamebuf = rrname.to_digestable()
|
rrnamebuf = rrname.to_digestable()
|
||||||
|
@ -884,9 +666,8 @@ def _make_dnskey(
|
||||||
) -> DNSKEY:
|
) -> DNSKEY:
|
||||||
"""Convert a public key to DNSKEY Rdata
|
"""Convert a public key to DNSKEY Rdata
|
||||||
|
|
||||||
*public_key*, the public key to convert, a
|
*public_key*, a ``PublicKey`` (``GenericPublicKey`` or
|
||||||
``cryptography.hazmat.primitives.asymmetric`` public key class applicable
|
``cryptography.hazmat.primitives.asymmetric``) to convert.
|
||||||
for DNSSEC.
|
|
||||||
|
|
||||||
*algorithm*, a ``str`` or ``int`` specifying the DNSKEY algorithm.
|
*algorithm*, a ``str`` or ``int`` specifying the DNSKEY algorithm.
|
||||||
|
|
||||||
|
@ -902,72 +683,13 @@ def _make_dnskey(
|
||||||
Return DNSKEY ``Rdata``.
|
Return DNSKEY ``Rdata``.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def encode_rsa_public_key(public_key: "rsa.RSAPublicKey") -> bytes:
|
algorithm = Algorithm.make(algorithm)
|
||||||
"""Encode a public key per RFC 3110, section 2."""
|
|
||||||
pn = public_key.public_numbers()
|
|
||||||
_exp_len = math.ceil(int.bit_length(pn.e) / 8)
|
|
||||||
exp = int.to_bytes(pn.e, length=_exp_len, byteorder="big")
|
|
||||||
if _exp_len > 255:
|
|
||||||
exp_header = b"\0" + struct.pack("!H", _exp_len)
|
|
||||||
else:
|
|
||||||
exp_header = struct.pack("!B", _exp_len)
|
|
||||||
if pn.n.bit_length() < 512 or pn.n.bit_length() > 4096:
|
|
||||||
raise ValueError("unsupported RSA key length")
|
|
||||||
return exp_header + exp + pn.n.to_bytes((pn.n.bit_length() + 7) // 8, "big")
|
|
||||||
|
|
||||||
def encode_dsa_public_key(public_key: "dsa.DSAPublicKey") -> bytes:
|
if isinstance(public_key, GenericPublicKey):
|
||||||
"""Encode a public key per RFC 2536, section 2."""
|
return public_key.to_dnskey(flags=flags, protocol=protocol)
|
||||||
pn = public_key.public_numbers()
|
|
||||||
dsa_t = (public_key.key_size // 8 - 64) // 8
|
|
||||||
if dsa_t > 8:
|
|
||||||
raise ValueError("unsupported DSA key size")
|
|
||||||
octets = 64 + dsa_t * 8
|
|
||||||
res = struct.pack("!B", dsa_t)
|
|
||||||
res += pn.parameter_numbers.q.to_bytes(20, "big")
|
|
||||||
res += pn.parameter_numbers.p.to_bytes(octets, "big")
|
|
||||||
res += pn.parameter_numbers.g.to_bytes(octets, "big")
|
|
||||||
res += pn.y.to_bytes(octets, "big")
|
|
||||||
return res
|
|
||||||
|
|
||||||
def encode_ecdsa_public_key(public_key: "ec.EllipticCurvePublicKey") -> bytes:
|
|
||||||
"""Encode a public key per RFC 6605, section 4."""
|
|
||||||
pn = public_key.public_numbers()
|
|
||||||
if isinstance(public_key.curve, ec.SECP256R1):
|
|
||||||
return pn.x.to_bytes(32, "big") + pn.y.to_bytes(32, "big")
|
|
||||||
elif isinstance(public_key.curve, ec.SECP384R1):
|
|
||||||
return pn.x.to_bytes(48, "big") + pn.y.to_bytes(48, "big")
|
|
||||||
else:
|
|
||||||
raise ValueError("unsupported ECDSA curve")
|
|
||||||
|
|
||||||
the_algorithm = Algorithm.make(algorithm)
|
|
||||||
|
|
||||||
_ensure_algorithm_key_combination(the_algorithm, public_key)
|
|
||||||
|
|
||||||
if isinstance(public_key, rsa.RSAPublicKey):
|
|
||||||
key_bytes = encode_rsa_public_key(public_key)
|
|
||||||
elif isinstance(public_key, dsa.DSAPublicKey):
|
|
||||||
key_bytes = encode_dsa_public_key(public_key)
|
|
||||||
elif isinstance(public_key, ec.EllipticCurvePublicKey):
|
|
||||||
key_bytes = encode_ecdsa_public_key(public_key)
|
|
||||||
elif isinstance(public_key, ed25519.Ed25519PublicKey):
|
|
||||||
key_bytes = public_key.public_bytes(
|
|
||||||
encoding=serialization.Encoding.Raw, format=serialization.PublicFormat.Raw
|
|
||||||
)
|
|
||||||
elif isinstance(public_key, ed448.Ed448PublicKey):
|
|
||||||
key_bytes = public_key.public_bytes(
|
|
||||||
encoding=serialization.Encoding.Raw, format=serialization.PublicFormat.Raw
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
raise TypeError("unsupported key algorithm")
|
public_cls = get_algorithm_cls(algorithm).public_cls
|
||||||
|
return public_cls(key=public_key).to_dnskey(flags=flags, protocol=protocol)
|
||||||
return DNSKEY(
|
|
||||||
rdclass=dns.rdataclass.IN,
|
|
||||||
rdtype=dns.rdatatype.DNSKEY,
|
|
||||||
flags=flags,
|
|
||||||
protocol=protocol,
|
|
||||||
algorithm=the_algorithm,
|
|
||||||
key=key_bytes,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _make_cdnskey(
|
def _make_cdnskey(
|
||||||
|
@ -1216,23 +938,252 @@ def dnskey_rdataset_to_cdnskey_rdataset(
|
||||||
return dns.rdataset.from_rdata_list(rdataset.ttl, res)
|
return dns.rdataset.from_rdata_list(rdataset.ttl, res)
|
||||||
|
|
||||||
|
|
||||||
|
def default_rrset_signer(
|
||||||
|
txn: dns.transaction.Transaction,
|
||||||
|
rrset: dns.rrset.RRset,
|
||||||
|
signer: dns.name.Name,
|
||||||
|
ksks: List[Tuple[PrivateKey, DNSKEY]],
|
||||||
|
zsks: List[Tuple[PrivateKey, DNSKEY]],
|
||||||
|
inception: Optional[Union[datetime, str, int, float]] = None,
|
||||||
|
expiration: Optional[Union[datetime, str, int, float]] = None,
|
||||||
|
lifetime: Optional[int] = None,
|
||||||
|
policy: Optional[Policy] = None,
|
||||||
|
origin: Optional[dns.name.Name] = None,
|
||||||
|
) -> None:
|
||||||
|
"""Default RRset signer"""
|
||||||
|
|
||||||
|
if rrset.rdtype in set(
|
||||||
|
[
|
||||||
|
dns.rdatatype.RdataType.DNSKEY,
|
||||||
|
dns.rdatatype.RdataType.CDS,
|
||||||
|
dns.rdatatype.RdataType.CDNSKEY,
|
||||||
|
]
|
||||||
|
):
|
||||||
|
keys = ksks
|
||||||
|
else:
|
||||||
|
keys = zsks
|
||||||
|
|
||||||
|
for private_key, dnskey in keys:
|
||||||
|
rrsig = dns.dnssec.sign(
|
||||||
|
rrset=rrset,
|
||||||
|
private_key=private_key,
|
||||||
|
dnskey=dnskey,
|
||||||
|
inception=inception,
|
||||||
|
expiration=expiration,
|
||||||
|
lifetime=lifetime,
|
||||||
|
signer=signer,
|
||||||
|
policy=policy,
|
||||||
|
origin=origin,
|
||||||
|
)
|
||||||
|
txn.add(rrset.name, rrset.ttl, rrsig)
|
||||||
|
|
||||||
|
|
||||||
|
def sign_zone(
|
||||||
|
zone: dns.zone.Zone,
|
||||||
|
txn: Optional[dns.transaction.Transaction] = None,
|
||||||
|
keys: Optional[List[Tuple[PrivateKey, DNSKEY]]] = None,
|
||||||
|
add_dnskey: bool = True,
|
||||||
|
dnskey_ttl: Optional[int] = None,
|
||||||
|
inception: Optional[Union[datetime, str, int, float]] = None,
|
||||||
|
expiration: Optional[Union[datetime, str, int, float]] = None,
|
||||||
|
lifetime: Optional[int] = None,
|
||||||
|
nsec3: Optional[NSEC3PARAM] = None,
|
||||||
|
rrset_signer: Optional[RRsetSigner] = None,
|
||||||
|
policy: Optional[Policy] = None,
|
||||||
|
) -> None:
|
||||||
|
"""Sign zone.
|
||||||
|
|
||||||
|
*zone*, a ``dns.zone.Zone``, the zone to sign.
|
||||||
|
|
||||||
|
*txn*, a ``dns.transaction.Transaction``, an optional transaction to use for
|
||||||
|
signing.
|
||||||
|
|
||||||
|
*keys*, a list of (``PrivateKey``, ``DNSKEY``) tuples, to use for signing. KSK/ZSK
|
||||||
|
roles are assigned automatically if the SEP flag is used, otherwise all RRsets are
|
||||||
|
signed by all keys.
|
||||||
|
|
||||||
|
*add_dnskey*, a ``bool``. If ``True``, the default, all specified DNSKEYs are
|
||||||
|
automatically added to the zone on signing.
|
||||||
|
|
||||||
|
*dnskey_ttl*, a``int``, specifies the TTL for DNSKEY RRs. If not specified the TTL
|
||||||
|
of the existing DNSKEY RRset used or the TTL of the SOA RRset.
|
||||||
|
|
||||||
|
*inception*, a ``datetime``, ``str``, ``int``, ``float`` or ``None``, the signature
|
||||||
|
inception time. If ``None``, the current time is used. If a ``str``, the format is
|
||||||
|
"YYYYMMDDHHMMSS" or alternatively the number of seconds since the UNIX epoch in text
|
||||||
|
form; this is the same the RRSIG rdata's text form. Values of type `int` or `float`
|
||||||
|
are interpreted as seconds since the UNIX epoch.
|
||||||
|
|
||||||
|
*expiration*, a ``datetime``, ``str``, ``int``, ``float`` or ``None``, the signature
|
||||||
|
expiration time. If ``None``, the expiration time will be the inception time plus
|
||||||
|
the value of the *lifetime* parameter. See the description of *inception* above for
|
||||||
|
how the various parameter types are interpreted.
|
||||||
|
|
||||||
|
*lifetime*, an ``int`` or ``None``, the signature lifetime in seconds. This
|
||||||
|
parameter is only meaningful if *expiration* is ``None``.
|
||||||
|
|
||||||
|
*nsec3*, a ``NSEC3PARAM`` Rdata, configures signing using NSEC3. Not yet
|
||||||
|
implemented.
|
||||||
|
|
||||||
|
*rrset_signer*, a ``Callable``, an optional function for signing RRsets. The
|
||||||
|
function requires two arguments: transaction and RRset. If the not specified,
|
||||||
|
``dns.dnssec.default_rrset_signer`` will be used.
|
||||||
|
|
||||||
|
Returns ``None``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
ksks = []
|
||||||
|
zsks = []
|
||||||
|
|
||||||
|
# if we have both KSKs and ZSKs, split by SEP flag. if not, sign all
|
||||||
|
# records with all keys
|
||||||
|
if keys:
|
||||||
|
for key in keys:
|
||||||
|
if key[1].flags & Flag.SEP:
|
||||||
|
ksks.append(key)
|
||||||
|
else:
|
||||||
|
zsks.append(key)
|
||||||
|
if not ksks:
|
||||||
|
ksks = keys
|
||||||
|
if not zsks:
|
||||||
|
zsks = keys
|
||||||
|
else:
|
||||||
|
keys = []
|
||||||
|
|
||||||
|
if txn:
|
||||||
|
cm: contextlib.AbstractContextManager = contextlib.nullcontext(txn)
|
||||||
|
else:
|
||||||
|
cm = zone.writer()
|
||||||
|
|
||||||
|
with cm as _txn:
|
||||||
|
if add_dnskey:
|
||||||
|
if dnskey_ttl is None:
|
||||||
|
dnskey = _txn.get(zone.origin, dns.rdatatype.DNSKEY)
|
||||||
|
if dnskey:
|
||||||
|
dnskey_ttl = dnskey.ttl
|
||||||
|
else:
|
||||||
|
soa = _txn.get(zone.origin, dns.rdatatype.SOA)
|
||||||
|
dnskey_ttl = soa.ttl
|
||||||
|
for _, dnskey in keys:
|
||||||
|
_txn.add(zone.origin, dnskey_ttl, dnskey)
|
||||||
|
|
||||||
|
if nsec3:
|
||||||
|
raise NotImplementedError("Signing with NSEC3 not yet implemented")
|
||||||
|
else:
|
||||||
|
_rrset_signer = rrset_signer or functools.partial(
|
||||||
|
default_rrset_signer,
|
||||||
|
signer=zone.origin,
|
||||||
|
ksks=ksks,
|
||||||
|
zsks=zsks,
|
||||||
|
inception=inception,
|
||||||
|
expiration=expiration,
|
||||||
|
lifetime=lifetime,
|
||||||
|
policy=policy,
|
||||||
|
origin=zone.origin,
|
||||||
|
)
|
||||||
|
return _sign_zone_nsec(zone, _txn, _rrset_signer)
|
||||||
|
|
||||||
|
|
||||||
|
def _sign_zone_nsec(
|
||||||
|
zone: dns.zone.Zone,
|
||||||
|
txn: dns.transaction.Transaction,
|
||||||
|
rrset_signer: Optional[RRsetSigner] = None,
|
||||||
|
) -> None:
|
||||||
|
"""NSEC zone signer"""
|
||||||
|
|
||||||
|
def _txn_add_nsec(
|
||||||
|
txn: dns.transaction.Transaction,
|
||||||
|
name: dns.name.Name,
|
||||||
|
next_secure: Optional[dns.name.Name],
|
||||||
|
rdclass: dns.rdataclass.RdataClass,
|
||||||
|
ttl: int,
|
||||||
|
rrset_signer: Optional[RRsetSigner] = None,
|
||||||
|
) -> None:
|
||||||
|
"""NSEC zone signer helper"""
|
||||||
|
mandatory_types = set(
|
||||||
|
[dns.rdatatype.RdataType.RRSIG, dns.rdatatype.RdataType.NSEC]
|
||||||
|
)
|
||||||
|
node = txn.get_node(name)
|
||||||
|
if node and next_secure:
|
||||||
|
types = (
|
||||||
|
set([rdataset.rdtype for rdataset in node.rdatasets]) | mandatory_types
|
||||||
|
)
|
||||||
|
windows = Bitmap.from_rdtypes(list(types))
|
||||||
|
rrset = dns.rrset.from_rdata(
|
||||||
|
name,
|
||||||
|
ttl,
|
||||||
|
NSEC(
|
||||||
|
rdclass=rdclass,
|
||||||
|
rdtype=dns.rdatatype.RdataType.NSEC,
|
||||||
|
next=next_secure,
|
||||||
|
windows=windows,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
txn.add(rrset)
|
||||||
|
if rrset_signer:
|
||||||
|
rrset_signer(txn, rrset)
|
||||||
|
|
||||||
|
rrsig_ttl = zone.get_soa().minimum
|
||||||
|
delegation = None
|
||||||
|
last_secure = None
|
||||||
|
|
||||||
|
for name in sorted(txn.iterate_names()):
|
||||||
|
if delegation and name.is_subdomain(delegation):
|
||||||
|
# names below delegations are not secure
|
||||||
|
continue
|
||||||
|
elif txn.get(name, dns.rdatatype.NS) and name != zone.origin:
|
||||||
|
# inside delegation
|
||||||
|
delegation = name
|
||||||
|
else:
|
||||||
|
# outside delegation
|
||||||
|
delegation = None
|
||||||
|
|
||||||
|
if rrset_signer:
|
||||||
|
node = txn.get_node(name)
|
||||||
|
if node:
|
||||||
|
for rdataset in node.rdatasets:
|
||||||
|
if rdataset.rdtype == dns.rdatatype.RRSIG:
|
||||||
|
# do not sign RRSIGs
|
||||||
|
continue
|
||||||
|
elif delegation and rdataset.rdtype != dns.rdatatype.DS:
|
||||||
|
# do not sign delegations except DS records
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
rrset = dns.rrset.from_rdata(name, rdataset.ttl, *rdataset)
|
||||||
|
rrset_signer(txn, rrset)
|
||||||
|
|
||||||
|
# We need "is not None" as the empty name is False because its length is 0.
|
||||||
|
if last_secure is not None:
|
||||||
|
_txn_add_nsec(txn, last_secure, name, zone.rdclass, rrsig_ttl, rrset_signer)
|
||||||
|
last_secure = name
|
||||||
|
|
||||||
|
if last_secure:
|
||||||
|
_txn_add_nsec(
|
||||||
|
txn, last_secure, zone.origin, zone.rdclass, rrsig_ttl, rrset_signer
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _need_pyca(*args, **kwargs):
|
def _need_pyca(*args, **kwargs):
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
"DNSSEC validation requires " + "python cryptography"
|
"DNSSEC validation requires python cryptography"
|
||||||
) # pragma: no cover
|
) # pragma: no cover
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from cryptography.exceptions import InvalidSignature
|
from cryptography.exceptions import InvalidSignature
|
||||||
from cryptography.hazmat.backends import default_backend
|
from cryptography.hazmat.primitives.asymmetric import dsa # pylint: disable=W0611
|
||||||
from cryptography.hazmat.primitives import hashes, serialization
|
from cryptography.hazmat.primitives.asymmetric import ec # pylint: disable=W0611
|
||||||
from cryptography.hazmat.primitives.asymmetric import padding
|
from cryptography.hazmat.primitives.asymmetric import ed448 # pylint: disable=W0611
|
||||||
from cryptography.hazmat.primitives.asymmetric import utils
|
from cryptography.hazmat.primitives.asymmetric import rsa # pylint: disable=W0611
|
||||||
from cryptography.hazmat.primitives.asymmetric import dsa
|
from cryptography.hazmat.primitives.asymmetric import ( # pylint: disable=W0611
|
||||||
from cryptography.hazmat.primitives.asymmetric import ec
|
ed25519,
|
||||||
from cryptography.hazmat.primitives.asymmetric import ed25519
|
)
|
||||||
from cryptography.hazmat.primitives.asymmetric import ed448
|
|
||||||
from cryptography.hazmat.primitives.asymmetric import rsa
|
from dns.dnssecalgs import ( # pylint: disable=C0412
|
||||||
|
get_algorithm_cls,
|
||||||
|
get_algorithm_cls_from_dnskey,
|
||||||
|
)
|
||||||
|
from dns.dnssecalgs.base import GenericPrivateKey, GenericPublicKey
|
||||||
except ImportError: # pragma: no cover
|
except ImportError: # pragma: no cover
|
||||||
validate = _need_pyca
|
validate = _need_pyca
|
||||||
validate_rrsig = _need_pyca
|
validate_rrsig = _need_pyca
|
||||||
|
|
121
lib/dns/dnssecalgs/__init__.py
Normal file
121
lib/dns/dnssecalgs/__init__.py
Normal file
|
@ -0,0 +1,121 @@
|
||||||
|
from typing import Dict, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
|
import dns.name
|
||||||
|
|
||||||
|
try:
|
||||||
|
from dns.dnssecalgs.base import GenericPrivateKey
|
||||||
|
from dns.dnssecalgs.dsa import PrivateDSA, PrivateDSANSEC3SHA1
|
||||||
|
from dns.dnssecalgs.ecdsa import PrivateECDSAP256SHA256, PrivateECDSAP384SHA384
|
||||||
|
from dns.dnssecalgs.eddsa import PrivateED448, PrivateED25519
|
||||||
|
from dns.dnssecalgs.rsa import (
|
||||||
|
PrivateRSAMD5,
|
||||||
|
PrivateRSASHA1,
|
||||||
|
PrivateRSASHA1NSEC3SHA1,
|
||||||
|
PrivateRSASHA256,
|
||||||
|
PrivateRSASHA512,
|
||||||
|
)
|
||||||
|
|
||||||
|
_have_cryptography = True
|
||||||
|
except ImportError:
|
||||||
|
_have_cryptography = False
|
||||||
|
|
||||||
|
from dns.dnssectypes import Algorithm
|
||||||
|
from dns.exception import UnsupportedAlgorithm
|
||||||
|
from dns.rdtypes.ANY.DNSKEY import DNSKEY
|
||||||
|
|
||||||
|
AlgorithmPrefix = Optional[Union[bytes, dns.name.Name]]
|
||||||
|
|
||||||
|
algorithms: Dict[Tuple[Algorithm, AlgorithmPrefix], Type[GenericPrivateKey]] = {}
|
||||||
|
if _have_cryptography:
|
||||||
|
algorithms.update(
|
||||||
|
{
|
||||||
|
(Algorithm.RSAMD5, None): PrivateRSAMD5,
|
||||||
|
(Algorithm.DSA, None): PrivateDSA,
|
||||||
|
(Algorithm.RSASHA1, None): PrivateRSASHA1,
|
||||||
|
(Algorithm.DSANSEC3SHA1, None): PrivateDSANSEC3SHA1,
|
||||||
|
(Algorithm.RSASHA1NSEC3SHA1, None): PrivateRSASHA1NSEC3SHA1,
|
||||||
|
(Algorithm.RSASHA256, None): PrivateRSASHA256,
|
||||||
|
(Algorithm.RSASHA512, None): PrivateRSASHA512,
|
||||||
|
(Algorithm.ECDSAP256SHA256, None): PrivateECDSAP256SHA256,
|
||||||
|
(Algorithm.ECDSAP384SHA384, None): PrivateECDSAP384SHA384,
|
||||||
|
(Algorithm.ED25519, None): PrivateED25519,
|
||||||
|
(Algorithm.ED448, None): PrivateED448,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_algorithm_cls(
|
||||||
|
algorithm: Union[int, str], prefix: AlgorithmPrefix = None
|
||||||
|
) -> Type[GenericPrivateKey]:
|
||||||
|
"""Get Private Key class from Algorithm.
|
||||||
|
|
||||||
|
*algorithm*, a ``str`` or ``int`` specifying the DNSKEY algorithm.
|
||||||
|
|
||||||
|
Raises ``UnsupportedAlgorithm`` if the algorithm is unknown.
|
||||||
|
|
||||||
|
Returns a ``dns.dnssecalgs.GenericPrivateKey``
|
||||||
|
"""
|
||||||
|
algorithm = Algorithm.make(algorithm)
|
||||||
|
cls = algorithms.get((algorithm, prefix))
|
||||||
|
if cls:
|
||||||
|
return cls
|
||||||
|
raise UnsupportedAlgorithm(
|
||||||
|
'algorithm "%s" not supported by dnspython' % Algorithm.to_text(algorithm)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_algorithm_cls_from_dnskey(dnskey: DNSKEY) -> Type[GenericPrivateKey]:
|
||||||
|
"""Get Private Key class from DNSKEY.
|
||||||
|
|
||||||
|
*dnskey*, a ``DNSKEY`` to get Algorithm class for.
|
||||||
|
|
||||||
|
Raises ``UnsupportedAlgorithm`` if the algorithm is unknown.
|
||||||
|
|
||||||
|
Returns a ``dns.dnssecalgs.GenericPrivateKey``
|
||||||
|
"""
|
||||||
|
prefix: AlgorithmPrefix = None
|
||||||
|
if dnskey.algorithm == Algorithm.PRIVATEDNS:
|
||||||
|
prefix, _ = dns.name.from_wire(dnskey.key, 0)
|
||||||
|
elif dnskey.algorithm == Algorithm.PRIVATEOID:
|
||||||
|
length = int(dnskey.key[0])
|
||||||
|
prefix = dnskey.key[0 : length + 1]
|
||||||
|
return get_algorithm_cls(dnskey.algorithm, prefix)
|
||||||
|
|
||||||
|
|
||||||
|
def register_algorithm_cls(
|
||||||
|
algorithm: Union[int, str],
|
||||||
|
algorithm_cls: Type[GenericPrivateKey],
|
||||||
|
name: Optional[Union[dns.name.Name, str]] = None,
|
||||||
|
oid: Optional[bytes] = None,
|
||||||
|
) -> None:
|
||||||
|
"""Register Algorithm Private Key class.
|
||||||
|
|
||||||
|
*algorithm*, a ``str`` or ``int`` specifying the DNSKEY algorithm.
|
||||||
|
|
||||||
|
*algorithm_cls*: A `GenericPrivateKey` class.
|
||||||
|
|
||||||
|
*name*, an optional ``dns.name.Name`` or ``str``, for for PRIVATEDNS algorithms.
|
||||||
|
|
||||||
|
*oid*: an optional BER-encoded `bytes` for PRIVATEOID algorithms.
|
||||||
|
|
||||||
|
Raises ``ValueError`` if a name or oid is specified incorrectly.
|
||||||
|
"""
|
||||||
|
if not issubclass(algorithm_cls, GenericPrivateKey):
|
||||||
|
raise TypeError("Invalid algorithm class")
|
||||||
|
algorithm = Algorithm.make(algorithm)
|
||||||
|
prefix: AlgorithmPrefix = None
|
||||||
|
if algorithm == Algorithm.PRIVATEDNS:
|
||||||
|
if name is None:
|
||||||
|
raise ValueError("Name required for PRIVATEDNS algorithms")
|
||||||
|
if isinstance(name, str):
|
||||||
|
name = dns.name.from_text(name)
|
||||||
|
prefix = name
|
||||||
|
elif algorithm == Algorithm.PRIVATEOID:
|
||||||
|
if oid is None:
|
||||||
|
raise ValueError("OID required for PRIVATEOID algorithms")
|
||||||
|
prefix = bytes([len(oid)]) + oid
|
||||||
|
elif name:
|
||||||
|
raise ValueError("Name only supported for PRIVATEDNS algorithm")
|
||||||
|
elif oid:
|
||||||
|
raise ValueError("OID only supported for PRIVATEOID algorithm")
|
||||||
|
algorithms[(algorithm, prefix)] = algorithm_cls
|
84
lib/dns/dnssecalgs/base.py
Normal file
84
lib/dns/dnssecalgs/base.py
Normal file
|
@ -0,0 +1,84 @@
|
||||||
|
from abc import ABC, abstractmethod # pylint: disable=no-name-in-module
|
||||||
|
from typing import Any, Optional, Type
|
||||||
|
|
||||||
|
import dns.rdataclass
|
||||||
|
import dns.rdatatype
|
||||||
|
from dns.dnssectypes import Algorithm
|
||||||
|
from dns.exception import AlgorithmKeyMismatch
|
||||||
|
from dns.rdtypes.ANY.DNSKEY import DNSKEY
|
||||||
|
from dns.rdtypes.dnskeybase import Flag
|
||||||
|
|
||||||
|
|
||||||
|
class GenericPublicKey(ABC):
|
||||||
|
algorithm: Algorithm
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def __init__(self, key: Any) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def verify(self, signature: bytes, data: bytes) -> None:
|
||||||
|
"""Verify signed DNSSEC data"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def encode_key_bytes(self) -> bytes:
|
||||||
|
"""Encode key as bytes for DNSKEY"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _ensure_algorithm_key_combination(cls, key: DNSKEY) -> None:
|
||||||
|
if key.algorithm != cls.algorithm:
|
||||||
|
raise AlgorithmKeyMismatch
|
||||||
|
|
||||||
|
def to_dnskey(self, flags: int = Flag.ZONE, protocol: int = 3) -> DNSKEY:
|
||||||
|
"""Return public key as DNSKEY"""
|
||||||
|
return DNSKEY(
|
||||||
|
rdclass=dns.rdataclass.IN,
|
||||||
|
rdtype=dns.rdatatype.DNSKEY,
|
||||||
|
flags=flags,
|
||||||
|
protocol=protocol,
|
||||||
|
algorithm=self.algorithm,
|
||||||
|
key=self.encode_key_bytes(),
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@abstractmethod
|
||||||
|
def from_dnskey(cls, key: DNSKEY) -> "GenericPublicKey":
|
||||||
|
"""Create public key from DNSKEY"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@abstractmethod
|
||||||
|
def from_pem(cls, public_pem: bytes) -> "GenericPublicKey":
|
||||||
|
"""Create public key from PEM-encoded SubjectPublicKeyInfo as specified
|
||||||
|
in RFC 5280"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def to_pem(self) -> bytes:
|
||||||
|
"""Return public-key as PEM-encoded SubjectPublicKeyInfo as specified
|
||||||
|
in RFC 5280"""
|
||||||
|
|
||||||
|
|
||||||
|
class GenericPrivateKey(ABC):
|
||||||
|
public_cls: Type[GenericPublicKey]
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def __init__(self, key: Any) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def sign(self, data: bytes, verify: bool = False) -> bytes:
|
||||||
|
"""Sign DNSSEC data"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def public_key(self) -> "GenericPublicKey":
|
||||||
|
"""Return public key instance"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@abstractmethod
|
||||||
|
def from_pem(
|
||||||
|
cls, private_pem: bytes, password: Optional[bytes] = None
|
||||||
|
) -> "GenericPrivateKey":
|
||||||
|
"""Create private key from PEM-encoded PKCS#8"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def to_pem(self, password: Optional[bytes] = None) -> bytes:
|
||||||
|
"""Return private key as PEM-encoded PKCS#8"""
|
68
lib/dns/dnssecalgs/cryptography.py
Normal file
68
lib/dns/dnssecalgs/cryptography.py
Normal file
|
@ -0,0 +1,68 @@
|
||||||
|
from typing import Any, Optional, Type
|
||||||
|
|
||||||
|
from cryptography.hazmat.primitives import serialization
|
||||||
|
|
||||||
|
from dns.dnssecalgs.base import GenericPrivateKey, GenericPublicKey
|
||||||
|
from dns.exception import AlgorithmKeyMismatch
|
||||||
|
|
||||||
|
|
||||||
|
class CryptographyPublicKey(GenericPublicKey):
|
||||||
|
key: Any = None
|
||||||
|
key_cls: Any = None
|
||||||
|
|
||||||
|
def __init__(self, key: Any) -> None: # pylint: disable=super-init-not-called
|
||||||
|
if self.key_cls is None:
|
||||||
|
raise TypeError("Undefined private key class")
|
||||||
|
if not isinstance( # pylint: disable=isinstance-second-argument-not-valid-type
|
||||||
|
key, self.key_cls
|
||||||
|
):
|
||||||
|
raise AlgorithmKeyMismatch
|
||||||
|
self.key = key
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pem(cls, public_pem: bytes) -> "GenericPublicKey":
|
||||||
|
key = serialization.load_pem_public_key(public_pem)
|
||||||
|
return cls(key=key)
|
||||||
|
|
||||||
|
def to_pem(self) -> bytes:
|
||||||
|
return self.key.public_bytes(
|
||||||
|
encoding=serialization.Encoding.PEM,
|
||||||
|
format=serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class CryptographyPrivateKey(GenericPrivateKey):
|
||||||
|
key: Any = None
|
||||||
|
key_cls: Any = None
|
||||||
|
public_cls: Type[CryptographyPublicKey]
|
||||||
|
|
||||||
|
def __init__(self, key: Any) -> None: # pylint: disable=super-init-not-called
|
||||||
|
if self.key_cls is None:
|
||||||
|
raise TypeError("Undefined private key class")
|
||||||
|
if not isinstance( # pylint: disable=isinstance-second-argument-not-valid-type
|
||||||
|
key, self.key_cls
|
||||||
|
):
|
||||||
|
raise AlgorithmKeyMismatch
|
||||||
|
self.key = key
|
||||||
|
|
||||||
|
def public_key(self) -> "CryptographyPublicKey":
|
||||||
|
return self.public_cls(key=self.key.public_key())
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pem(
|
||||||
|
cls, private_pem: bytes, password: Optional[bytes] = None
|
||||||
|
) -> "GenericPrivateKey":
|
||||||
|
key = serialization.load_pem_private_key(private_pem, password=password)
|
||||||
|
return cls(key=key)
|
||||||
|
|
||||||
|
def to_pem(self, password: Optional[bytes] = None) -> bytes:
|
||||||
|
encryption_algorithm: serialization.KeySerializationEncryption
|
||||||
|
if password:
|
||||||
|
encryption_algorithm = serialization.BestAvailableEncryption(password)
|
||||||
|
else:
|
||||||
|
encryption_algorithm = serialization.NoEncryption()
|
||||||
|
return self.key.private_bytes(
|
||||||
|
encoding=serialization.Encoding.PEM,
|
||||||
|
format=serialization.PrivateFormat.PKCS8,
|
||||||
|
encryption_algorithm=encryption_algorithm,
|
||||||
|
)
|
101
lib/dns/dnssecalgs/dsa.py
Normal file
101
lib/dns/dnssecalgs/dsa.py
Normal file
|
@ -0,0 +1,101 @@
|
||||||
|
import struct
|
||||||
|
|
||||||
|
from cryptography.hazmat.backends import default_backend
|
||||||
|
from cryptography.hazmat.primitives import hashes
|
||||||
|
from cryptography.hazmat.primitives.asymmetric import dsa, utils
|
||||||
|
|
||||||
|
from dns.dnssecalgs.cryptography import CryptographyPrivateKey, CryptographyPublicKey
|
||||||
|
from dns.dnssectypes import Algorithm
|
||||||
|
from dns.rdtypes.ANY.DNSKEY import DNSKEY
|
||||||
|
|
||||||
|
|
||||||
|
class PublicDSA(CryptographyPublicKey):
|
||||||
|
key: dsa.DSAPublicKey
|
||||||
|
key_cls = dsa.DSAPublicKey
|
||||||
|
algorithm = Algorithm.DSA
|
||||||
|
chosen_hash = hashes.SHA1()
|
||||||
|
|
||||||
|
def verify(self, signature: bytes, data: bytes) -> None:
|
||||||
|
sig_r = signature[1:21]
|
||||||
|
sig_s = signature[21:]
|
||||||
|
sig = utils.encode_dss_signature(
|
||||||
|
int.from_bytes(sig_r, "big"), int.from_bytes(sig_s, "big")
|
||||||
|
)
|
||||||
|
self.key.verify(sig, data, self.chosen_hash)
|
||||||
|
|
||||||
|
def encode_key_bytes(self) -> bytes:
|
||||||
|
"""Encode a public key per RFC 2536, section 2."""
|
||||||
|
pn = self.key.public_numbers()
|
||||||
|
dsa_t = (self.key.key_size // 8 - 64) // 8
|
||||||
|
if dsa_t > 8:
|
||||||
|
raise ValueError("unsupported DSA key size")
|
||||||
|
octets = 64 + dsa_t * 8
|
||||||
|
res = struct.pack("!B", dsa_t)
|
||||||
|
res += pn.parameter_numbers.q.to_bytes(20, "big")
|
||||||
|
res += pn.parameter_numbers.p.to_bytes(octets, "big")
|
||||||
|
res += pn.parameter_numbers.g.to_bytes(octets, "big")
|
||||||
|
res += pn.y.to_bytes(octets, "big")
|
||||||
|
return res
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dnskey(cls, key: DNSKEY) -> "PublicDSA":
|
||||||
|
cls._ensure_algorithm_key_combination(key)
|
||||||
|
keyptr = key.key
|
||||||
|
(t,) = struct.unpack("!B", keyptr[0:1])
|
||||||
|
keyptr = keyptr[1:]
|
||||||
|
octets = 64 + t * 8
|
||||||
|
dsa_q = keyptr[0:20]
|
||||||
|
keyptr = keyptr[20:]
|
||||||
|
dsa_p = keyptr[0:octets]
|
||||||
|
keyptr = keyptr[octets:]
|
||||||
|
dsa_g = keyptr[0:octets]
|
||||||
|
keyptr = keyptr[octets:]
|
||||||
|
dsa_y = keyptr[0:octets]
|
||||||
|
return cls(
|
||||||
|
key=dsa.DSAPublicNumbers( # type: ignore
|
||||||
|
int.from_bytes(dsa_y, "big"),
|
||||||
|
dsa.DSAParameterNumbers(
|
||||||
|
int.from_bytes(dsa_p, "big"),
|
||||||
|
int.from_bytes(dsa_q, "big"),
|
||||||
|
int.from_bytes(dsa_g, "big"),
|
||||||
|
),
|
||||||
|
).public_key(default_backend()),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PrivateDSA(CryptographyPrivateKey):
|
||||||
|
key: dsa.DSAPrivateKey
|
||||||
|
key_cls = dsa.DSAPrivateKey
|
||||||
|
public_cls = PublicDSA
|
||||||
|
|
||||||
|
def sign(self, data: bytes, verify: bool = False) -> bytes:
|
||||||
|
"""Sign using a private key per RFC 2536, section 3."""
|
||||||
|
public_dsa_key = self.key.public_key()
|
||||||
|
if public_dsa_key.key_size > 1024:
|
||||||
|
raise ValueError("DSA key size overflow")
|
||||||
|
der_signature = self.key.sign(data, self.public_cls.chosen_hash)
|
||||||
|
dsa_r, dsa_s = utils.decode_dss_signature(der_signature)
|
||||||
|
dsa_t = (public_dsa_key.key_size // 8 - 64) // 8
|
||||||
|
octets = 20
|
||||||
|
signature = (
|
||||||
|
struct.pack("!B", dsa_t)
|
||||||
|
+ int.to_bytes(dsa_r, length=octets, byteorder="big")
|
||||||
|
+ int.to_bytes(dsa_s, length=octets, byteorder="big")
|
||||||
|
)
|
||||||
|
if verify:
|
||||||
|
self.public_key().verify(signature, data)
|
||||||
|
return signature
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def generate(cls, key_size: int) -> "PrivateDSA":
|
||||||
|
return cls(
|
||||||
|
key=dsa.generate_private_key(key_size=key_size),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PublicDSANSEC3SHA1(PublicDSA):
|
||||||
|
algorithm = Algorithm.DSANSEC3SHA1
|
||||||
|
|
||||||
|
|
||||||
|
class PrivateDSANSEC3SHA1(PrivateDSA):
|
||||||
|
public_cls = PublicDSANSEC3SHA1
|
89
lib/dns/dnssecalgs/ecdsa.py
Normal file
89
lib/dns/dnssecalgs/ecdsa.py
Normal file
|
@ -0,0 +1,89 @@
|
||||||
|
from cryptography.hazmat.backends import default_backend
|
||||||
|
from cryptography.hazmat.primitives import hashes
|
||||||
|
from cryptography.hazmat.primitives.asymmetric import ec, utils
|
||||||
|
|
||||||
|
from dns.dnssecalgs.cryptography import CryptographyPrivateKey, CryptographyPublicKey
|
||||||
|
from dns.dnssectypes import Algorithm
|
||||||
|
from dns.rdtypes.ANY.DNSKEY import DNSKEY
|
||||||
|
|
||||||
|
|
||||||
|
class PublicECDSA(CryptographyPublicKey):
|
||||||
|
key: ec.EllipticCurvePublicKey
|
||||||
|
key_cls = ec.EllipticCurvePublicKey
|
||||||
|
algorithm: Algorithm
|
||||||
|
chosen_hash: hashes.HashAlgorithm
|
||||||
|
curve: ec.EllipticCurve
|
||||||
|
octets: int
|
||||||
|
|
||||||
|
def verify(self, signature: bytes, data: bytes) -> None:
|
||||||
|
sig_r = signature[0 : self.octets]
|
||||||
|
sig_s = signature[self.octets :]
|
||||||
|
sig = utils.encode_dss_signature(
|
||||||
|
int.from_bytes(sig_r, "big"), int.from_bytes(sig_s, "big")
|
||||||
|
)
|
||||||
|
self.key.verify(sig, data, ec.ECDSA(self.chosen_hash))
|
||||||
|
|
||||||
|
def encode_key_bytes(self) -> bytes:
|
||||||
|
"""Encode a public key per RFC 6605, section 4."""
|
||||||
|
pn = self.key.public_numbers()
|
||||||
|
return pn.x.to_bytes(self.octets, "big") + pn.y.to_bytes(self.octets, "big")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dnskey(cls, key: DNSKEY) -> "PublicECDSA":
|
||||||
|
cls._ensure_algorithm_key_combination(key)
|
||||||
|
ecdsa_x = key.key[0 : cls.octets]
|
||||||
|
ecdsa_y = key.key[cls.octets : cls.octets * 2]
|
||||||
|
return cls(
|
||||||
|
key=ec.EllipticCurvePublicNumbers(
|
||||||
|
curve=cls.curve,
|
||||||
|
x=int.from_bytes(ecdsa_x, "big"),
|
||||||
|
y=int.from_bytes(ecdsa_y, "big"),
|
||||||
|
).public_key(default_backend()),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PrivateECDSA(CryptographyPrivateKey):
|
||||||
|
key: ec.EllipticCurvePrivateKey
|
||||||
|
key_cls = ec.EllipticCurvePrivateKey
|
||||||
|
public_cls = PublicECDSA
|
||||||
|
|
||||||
|
def sign(self, data: bytes, verify: bool = False) -> bytes:
|
||||||
|
"""Sign using a private key per RFC 6605, section 4."""
|
||||||
|
der_signature = self.key.sign(data, ec.ECDSA(self.public_cls.chosen_hash))
|
||||||
|
dsa_r, dsa_s = utils.decode_dss_signature(der_signature)
|
||||||
|
signature = int.to_bytes(
|
||||||
|
dsa_r, length=self.public_cls.octets, byteorder="big"
|
||||||
|
) + int.to_bytes(dsa_s, length=self.public_cls.octets, byteorder="big")
|
||||||
|
if verify:
|
||||||
|
self.public_key().verify(signature, data)
|
||||||
|
return signature
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def generate(cls) -> "PrivateECDSA":
|
||||||
|
return cls(
|
||||||
|
key=ec.generate_private_key(
|
||||||
|
curve=cls.public_cls.curve, backend=default_backend()
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PublicECDSAP256SHA256(PublicECDSA):
|
||||||
|
algorithm = Algorithm.ECDSAP256SHA256
|
||||||
|
chosen_hash = hashes.SHA256()
|
||||||
|
curve = ec.SECP256R1()
|
||||||
|
octets = 32
|
||||||
|
|
||||||
|
|
||||||
|
class PrivateECDSAP256SHA256(PrivateECDSA):
|
||||||
|
public_cls = PublicECDSAP256SHA256
|
||||||
|
|
||||||
|
|
||||||
|
class PublicECDSAP384SHA384(PublicECDSA):
|
||||||
|
algorithm = Algorithm.ECDSAP384SHA384
|
||||||
|
chosen_hash = hashes.SHA384()
|
||||||
|
curve = ec.SECP384R1()
|
||||||
|
octets = 48
|
||||||
|
|
||||||
|
|
||||||
|
class PrivateECDSAP384SHA384(PrivateECDSA):
|
||||||
|
public_cls = PublicECDSAP384SHA384
|
65
lib/dns/dnssecalgs/eddsa.py
Normal file
65
lib/dns/dnssecalgs/eddsa.py
Normal file
|
@ -0,0 +1,65 @@
|
||||||
|
from typing import Type
|
||||||
|
|
||||||
|
from cryptography.hazmat.primitives import serialization
|
||||||
|
from cryptography.hazmat.primitives.asymmetric import ed448, ed25519
|
||||||
|
|
||||||
|
from dns.dnssecalgs.cryptography import CryptographyPrivateKey, CryptographyPublicKey
|
||||||
|
from dns.dnssectypes import Algorithm
|
||||||
|
from dns.rdtypes.ANY.DNSKEY import DNSKEY
|
||||||
|
|
||||||
|
|
||||||
|
class PublicEDDSA(CryptographyPublicKey):
|
||||||
|
def verify(self, signature: bytes, data: bytes) -> None:
|
||||||
|
self.key.verify(signature, data)
|
||||||
|
|
||||||
|
def encode_key_bytes(self) -> bytes:
|
||||||
|
"""Encode a public key per RFC 8080, section 3."""
|
||||||
|
return self.key.public_bytes(
|
||||||
|
encoding=serialization.Encoding.Raw, format=serialization.PublicFormat.Raw
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dnskey(cls, key: DNSKEY) -> "PublicEDDSA":
|
||||||
|
cls._ensure_algorithm_key_combination(key)
|
||||||
|
return cls(
|
||||||
|
key=cls.key_cls.from_public_bytes(key.key),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PrivateEDDSA(CryptographyPrivateKey):
|
||||||
|
public_cls: Type[PublicEDDSA]
|
||||||
|
|
||||||
|
def sign(self, data: bytes, verify: bool = False) -> bytes:
|
||||||
|
"""Sign using a private key per RFC 8080, section 4."""
|
||||||
|
signature = self.key.sign(data)
|
||||||
|
if verify:
|
||||||
|
self.public_key().verify(signature, data)
|
||||||
|
return signature
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def generate(cls) -> "PrivateEDDSA":
|
||||||
|
return cls(key=cls.key_cls.generate())
|
||||||
|
|
||||||
|
|
||||||
|
class PublicED25519(PublicEDDSA):
|
||||||
|
key: ed25519.Ed25519PublicKey
|
||||||
|
key_cls = ed25519.Ed25519PublicKey
|
||||||
|
algorithm = Algorithm.ED25519
|
||||||
|
|
||||||
|
|
||||||
|
class PrivateED25519(PrivateEDDSA):
|
||||||
|
key: ed25519.Ed25519PrivateKey
|
||||||
|
key_cls = ed25519.Ed25519PrivateKey
|
||||||
|
public_cls = PublicED25519
|
||||||
|
|
||||||
|
|
||||||
|
class PublicED448(PublicEDDSA):
|
||||||
|
key: ed448.Ed448PublicKey
|
||||||
|
key_cls = ed448.Ed448PublicKey
|
||||||
|
algorithm = Algorithm.ED448
|
||||||
|
|
||||||
|
|
||||||
|
class PrivateED448(PrivateEDDSA):
|
||||||
|
key: ed448.Ed448PrivateKey
|
||||||
|
key_cls = ed448.Ed448PrivateKey
|
||||||
|
public_cls = PublicED448
|
119
lib/dns/dnssecalgs/rsa.py
Normal file
119
lib/dns/dnssecalgs/rsa.py
Normal file
|
@ -0,0 +1,119 @@
|
||||||
|
import math
|
||||||
|
import struct
|
||||||
|
|
||||||
|
from cryptography.hazmat.backends import default_backend
|
||||||
|
from cryptography.hazmat.primitives import hashes
|
||||||
|
from cryptography.hazmat.primitives.asymmetric import padding, rsa
|
||||||
|
|
||||||
|
from dns.dnssecalgs.cryptography import CryptographyPrivateKey, CryptographyPublicKey
|
||||||
|
from dns.dnssectypes import Algorithm
|
||||||
|
from dns.rdtypes.ANY.DNSKEY import DNSKEY
|
||||||
|
|
||||||
|
|
||||||
|
class PublicRSA(CryptographyPublicKey):
|
||||||
|
key: rsa.RSAPublicKey
|
||||||
|
key_cls = rsa.RSAPublicKey
|
||||||
|
algorithm: Algorithm
|
||||||
|
chosen_hash: hashes.HashAlgorithm
|
||||||
|
|
||||||
|
def verify(self, signature: bytes, data: bytes) -> None:
|
||||||
|
self.key.verify(signature, data, padding.PKCS1v15(), self.chosen_hash)
|
||||||
|
|
||||||
|
def encode_key_bytes(self) -> bytes:
|
||||||
|
"""Encode a public key per RFC 3110, section 2."""
|
||||||
|
pn = self.key.public_numbers()
|
||||||
|
_exp_len = math.ceil(int.bit_length(pn.e) / 8)
|
||||||
|
exp = int.to_bytes(pn.e, length=_exp_len, byteorder="big")
|
||||||
|
if _exp_len > 255:
|
||||||
|
exp_header = b"\0" + struct.pack("!H", _exp_len)
|
||||||
|
else:
|
||||||
|
exp_header = struct.pack("!B", _exp_len)
|
||||||
|
if pn.n.bit_length() < 512 or pn.n.bit_length() > 4096:
|
||||||
|
raise ValueError("unsupported RSA key length")
|
||||||
|
return exp_header + exp + pn.n.to_bytes((pn.n.bit_length() + 7) // 8, "big")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dnskey(cls, key: DNSKEY) -> "PublicRSA":
|
||||||
|
cls._ensure_algorithm_key_combination(key)
|
||||||
|
keyptr = key.key
|
||||||
|
(bytes_,) = struct.unpack("!B", keyptr[0:1])
|
||||||
|
keyptr = keyptr[1:]
|
||||||
|
if bytes_ == 0:
|
||||||
|
(bytes_,) = struct.unpack("!H", keyptr[0:2])
|
||||||
|
keyptr = keyptr[2:]
|
||||||
|
rsa_e = keyptr[0:bytes_]
|
||||||
|
rsa_n = keyptr[bytes_:]
|
||||||
|
return cls(
|
||||||
|
key=rsa.RSAPublicNumbers(
|
||||||
|
int.from_bytes(rsa_e, "big"), int.from_bytes(rsa_n, "big")
|
||||||
|
).public_key(default_backend())
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PrivateRSA(CryptographyPrivateKey):
|
||||||
|
key: rsa.RSAPrivateKey
|
||||||
|
key_cls = rsa.RSAPrivateKey
|
||||||
|
public_cls = PublicRSA
|
||||||
|
default_public_exponent = 65537
|
||||||
|
|
||||||
|
def sign(self, data: bytes, verify: bool = False) -> bytes:
|
||||||
|
"""Sign using a private key per RFC 3110, section 3."""
|
||||||
|
signature = self.key.sign(data, padding.PKCS1v15(), self.public_cls.chosen_hash)
|
||||||
|
if verify:
|
||||||
|
self.public_key().verify(signature, data)
|
||||||
|
return signature
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def generate(cls, key_size: int) -> "PrivateRSA":
|
||||||
|
return cls(
|
||||||
|
key=rsa.generate_private_key(
|
||||||
|
public_exponent=cls.default_public_exponent,
|
||||||
|
key_size=key_size,
|
||||||
|
backend=default_backend(),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PublicRSAMD5(PublicRSA):
|
||||||
|
algorithm = Algorithm.RSAMD5
|
||||||
|
chosen_hash = hashes.MD5()
|
||||||
|
|
||||||
|
|
||||||
|
class PrivateRSAMD5(PrivateRSA):
|
||||||
|
public_cls = PublicRSAMD5
|
||||||
|
|
||||||
|
|
||||||
|
class PublicRSASHA1(PublicRSA):
|
||||||
|
algorithm = Algorithm.RSASHA1
|
||||||
|
chosen_hash = hashes.SHA1()
|
||||||
|
|
||||||
|
|
||||||
|
class PrivateRSASHA1(PrivateRSA):
|
||||||
|
public_cls = PublicRSASHA1
|
||||||
|
|
||||||
|
|
||||||
|
class PublicRSASHA1NSEC3SHA1(PublicRSA):
|
||||||
|
algorithm = Algorithm.RSASHA1NSEC3SHA1
|
||||||
|
chosen_hash = hashes.SHA1()
|
||||||
|
|
||||||
|
|
||||||
|
class PrivateRSASHA1NSEC3SHA1(PrivateRSA):
|
||||||
|
public_cls = PublicRSASHA1NSEC3SHA1
|
||||||
|
|
||||||
|
|
||||||
|
class PublicRSASHA256(PublicRSA):
|
||||||
|
algorithm = Algorithm.RSASHA256
|
||||||
|
chosen_hash = hashes.SHA256()
|
||||||
|
|
||||||
|
|
||||||
|
class PrivateRSASHA256(PrivateRSA):
|
||||||
|
public_cls = PublicRSASHA256
|
||||||
|
|
||||||
|
|
||||||
|
class PublicRSASHA512(PublicRSA):
|
||||||
|
algorithm = Algorithm.RSASHA512
|
||||||
|
chosen_hash = hashes.SHA512()
|
||||||
|
|
||||||
|
|
||||||
|
class PrivateRSASHA512(PrivateRSA):
|
||||||
|
public_cls = PublicRSASHA512
|
|
@ -17,11 +17,10 @@
|
||||||
|
|
||||||
"""EDNS Options"""
|
"""EDNS Options"""
|
||||||
|
|
||||||
from typing import Any, Dict, Optional, Union
|
|
||||||
|
|
||||||
import math
|
import math
|
||||||
import socket
|
import socket
|
||||||
import struct
|
import struct
|
||||||
|
from typing import Any, Dict, Optional, Union
|
||||||
|
|
||||||
import dns.enum
|
import dns.enum
|
||||||
import dns.inet
|
import dns.inet
|
||||||
|
@ -380,7 +379,7 @@ class EDEOption(Option): # lgtm[py/missing-equals]
|
||||||
def from_wire_parser(
|
def from_wire_parser(
|
||||||
cls, otype: Union[OptionType, str], parser: "dns.wire.Parser"
|
cls, otype: Union[OptionType, str], parser: "dns.wire.Parser"
|
||||||
) -> Option:
|
) -> Option:
|
||||||
the_code = EDECode.make(parser.get_uint16())
|
code = EDECode.make(parser.get_uint16())
|
||||||
text = parser.get_remaining()
|
text = parser.get_remaining()
|
||||||
|
|
||||||
if text:
|
if text:
|
||||||
|
@ -390,7 +389,7 @@ class EDEOption(Option): # lgtm[py/missing-equals]
|
||||||
else:
|
else:
|
||||||
btext = None
|
btext = None
|
||||||
|
|
||||||
return cls(the_code, btext)
|
return cls(code, btext)
|
||||||
|
|
||||||
|
|
||||||
_type_to_class: Dict[OptionType, Any] = {
|
_type_to_class: Dict[OptionType, Any] = {
|
||||||
|
@ -424,8 +423,8 @@ def option_from_wire_parser(
|
||||||
|
|
||||||
Returns an instance of a subclass of ``dns.edns.Option``.
|
Returns an instance of a subclass of ``dns.edns.Option``.
|
||||||
"""
|
"""
|
||||||
the_otype = OptionType.make(otype)
|
otype = OptionType.make(otype)
|
||||||
cls = get_option_class(the_otype)
|
cls = get_option_class(otype)
|
||||||
return cls.from_wire_parser(otype, parser)
|
return cls.from_wire_parser(otype, parser)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -15,17 +15,15 @@
|
||||||
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
||||||
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||||
|
|
||||||
from typing import Any, Optional
|
|
||||||
|
|
||||||
import os
|
|
||||||
import hashlib
|
import hashlib
|
||||||
|
import os
|
||||||
import random
|
import random
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
|
||||||
class EntropyPool:
|
class EntropyPool:
|
||||||
|
|
||||||
# This is an entropy pool for Python implementations that do not
|
# This is an entropy pool for Python implementations that do not
|
||||||
# have a working SystemRandom. I'm not sure there are any, but
|
# have a working SystemRandom. I'm not sure there are any, but
|
||||||
# leaving this code doesn't hurt anything as the library code
|
# leaving this code doesn't hurt anything as the library code
|
||||||
|
|
|
@ -16,18 +16,31 @@
|
||||||
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||||
|
|
||||||
import enum
|
import enum
|
||||||
|
from typing import Type, TypeVar, Union
|
||||||
|
|
||||||
|
TIntEnum = TypeVar("TIntEnum", bound="IntEnum")
|
||||||
|
|
||||||
|
|
||||||
class IntEnum(enum.IntEnum):
|
class IntEnum(enum.IntEnum):
|
||||||
@classmethod
|
@classmethod
|
||||||
def _check_value(cls, value):
|
def _missing_(cls, value):
|
||||||
max = cls._maximum()
|
cls._check_value(value)
|
||||||
if value < 0 or value > max:
|
val = int.__new__(cls, value)
|
||||||
name = cls._short_name()
|
val._name_ = cls._extra_to_text(value, None) or f"{cls._prefix()}{value}"
|
||||||
raise ValueError(f"{name} must be between >= 0 and <= {max}")
|
val._value_ = value
|
||||||
|
return val
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_text(cls, text):
|
def _check_value(cls, value):
|
||||||
|
max = cls._maximum()
|
||||||
|
if not isinstance(value, int):
|
||||||
|
raise TypeError
|
||||||
|
if value < 0 or value > max:
|
||||||
|
name = cls._short_name()
|
||||||
|
raise ValueError(f"{name} must be an int between >= 0 and <= {max}")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_text(cls: Type[TIntEnum], text: str) -> TIntEnum:
|
||||||
text = text.upper()
|
text = text.upper()
|
||||||
try:
|
try:
|
||||||
return cls[text]
|
return cls[text]
|
||||||
|
@ -47,7 +60,7 @@ class IntEnum(enum.IntEnum):
|
||||||
raise cls._unknown_exception_class()
|
raise cls._unknown_exception_class()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def to_text(cls, value):
|
def to_text(cls: Type[TIntEnum], value: int) -> str:
|
||||||
cls._check_value(value)
|
cls._check_value(value)
|
||||||
try:
|
try:
|
||||||
text = cls(value).name
|
text = cls(value).name
|
||||||
|
@ -59,7 +72,7 @@ class IntEnum(enum.IntEnum):
|
||||||
return text
|
return text
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def make(cls, value):
|
def make(cls: Type[TIntEnum], value: Union[int, str]) -> TIntEnum:
|
||||||
"""Convert text or a value into an enumerated type, if possible.
|
"""Convert text or a value into an enumerated type, if possible.
|
||||||
|
|
||||||
*value*, the ``int`` or ``str`` to convert.
|
*value*, the ``int`` or ``str`` to convert.
|
||||||
|
@ -76,10 +89,7 @@ class IntEnum(enum.IntEnum):
|
||||||
if isinstance(value, str):
|
if isinstance(value, str):
|
||||||
return cls.from_text(value)
|
return cls.from_text(value)
|
||||||
cls._check_value(value)
|
cls._check_value(value)
|
||||||
try:
|
return cls(value)
|
||||||
return cls(value)
|
|
||||||
except ValueError:
|
|
||||||
return value
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _maximum(cls):
|
def _maximum(cls):
|
||||||
|
|
|
@ -140,6 +140,22 @@ class Timeout(DNSException):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class UnsupportedAlgorithm(DNSException):
|
||||||
|
"""The DNSSEC algorithm is not supported."""
|
||||||
|
|
||||||
|
|
||||||
|
class AlgorithmKeyMismatch(UnsupportedAlgorithm):
|
||||||
|
"""The DNSSEC algorithm is not supported for the given key type."""
|
||||||
|
|
||||||
|
|
||||||
|
class ValidationFailure(DNSException):
|
||||||
|
"""The DNSSEC signature is invalid."""
|
||||||
|
|
||||||
|
|
||||||
|
class DeniedByPolicy(DNSException):
|
||||||
|
"""Denied by DNSSEC policy."""
|
||||||
|
|
||||||
|
|
||||||
class ExceptionWrapper:
|
class ExceptionWrapper:
|
||||||
def __init__(self, exception_class):
|
def __init__(self, exception_class):
|
||||||
self.exception_class = exception_class
|
self.exception_class = exception_class
|
||||||
|
|
|
@ -17,9 +17,8 @@
|
||||||
|
|
||||||
"""DNS Message Flags."""
|
"""DNS Message Flags."""
|
||||||
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import enum
|
import enum
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
# Standard DNS flags
|
# Standard DNS flags
|
||||||
|
|
||||||
|
|
|
@ -1,8 +1,7 @@
|
||||||
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
|
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
|
||||||
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import collections.abc
|
import collections.abc
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from dns._immutable_ctx import immutable
|
from dns._immutable_ctx import immutable
|
||||||
|
|
||||||
|
|
|
@ -17,14 +17,12 @@
|
||||||
|
|
||||||
"""Generic Internet address helper functions."""
|
"""Generic Internet address helper functions."""
|
||||||
|
|
||||||
from typing import Any, Optional, Tuple
|
|
||||||
|
|
||||||
import socket
|
import socket
|
||||||
|
from typing import Any, Optional, Tuple
|
||||||
|
|
||||||
import dns.ipv4
|
import dns.ipv4
|
||||||
import dns.ipv6
|
import dns.ipv6
|
||||||
|
|
||||||
|
|
||||||
# We assume that AF_INET and AF_INET6 are always defined. We keep
|
# We assume that AF_INET and AF_INET6 are always defined. We keep
|
||||||
# these here for the benefit of any old code (unlikely though that
|
# these here for the benefit of any old code (unlikely though that
|
||||||
# is!).
|
# is!).
|
||||||
|
@ -171,3 +169,12 @@ def low_level_address_tuple(
|
||||||
return tup
|
return tup
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(f"unknown address family {af}")
|
raise NotImplementedError(f"unknown address family {af}")
|
||||||
|
|
||||||
|
|
||||||
|
def any_for_af(af):
|
||||||
|
"""Return the 'any' address for the specified address family."""
|
||||||
|
if af == socket.AF_INET:
|
||||||
|
return "0.0.0.0"
|
||||||
|
elif af == socket.AF_INET6:
|
||||||
|
return "::"
|
||||||
|
raise NotImplementedError(f"unknown address family {af}")
|
||||||
|
|
|
@ -17,9 +17,8 @@
|
||||||
|
|
||||||
"""IPv4 helper functions."""
|
"""IPv4 helper functions."""
|
||||||
|
|
||||||
from typing import Union
|
|
||||||
|
|
||||||
import struct
|
import struct
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
import dns.exception
|
import dns.exception
|
||||||
|
|
||||||
|
|
|
@ -17,10 +17,9 @@
|
||||||
|
|
||||||
"""IPv6 helper functions."""
|
"""IPv6 helper functions."""
|
||||||
|
|
||||||
from typing import List, Union
|
|
||||||
|
|
||||||
import re
|
|
||||||
import binascii
|
import binascii
|
||||||
|
import re
|
||||||
|
from typing import List, Union
|
||||||
|
|
||||||
import dns.exception
|
import dns.exception
|
||||||
import dns.ipv4
|
import dns.ipv4
|
||||||
|
|
|
@ -17,30 +17,29 @@
|
||||||
|
|
||||||
"""DNS Messages"""
|
"""DNS Messages"""
|
||||||
|
|
||||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
|
||||||
|
|
||||||
import contextlib
|
import contextlib
|
||||||
import io
|
import io
|
||||||
import time
|
import time
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import dns.wire
|
|
||||||
import dns.edns
|
import dns.edns
|
||||||
|
import dns.entropy
|
||||||
import dns.enum
|
import dns.enum
|
||||||
import dns.exception
|
import dns.exception
|
||||||
import dns.flags
|
import dns.flags
|
||||||
import dns.name
|
import dns.name
|
||||||
import dns.opcode
|
import dns.opcode
|
||||||
import dns.entropy
|
|
||||||
import dns.rcode
|
import dns.rcode
|
||||||
import dns.rdata
|
import dns.rdata
|
||||||
import dns.rdataclass
|
import dns.rdataclass
|
||||||
import dns.rdatatype
|
import dns.rdatatype
|
||||||
import dns.rrset
|
|
||||||
import dns.renderer
|
|
||||||
import dns.ttl
|
|
||||||
import dns.tsig
|
|
||||||
import dns.rdtypes.ANY.OPT
|
import dns.rdtypes.ANY.OPT
|
||||||
import dns.rdtypes.ANY.TSIG
|
import dns.rdtypes.ANY.TSIG
|
||||||
|
import dns.renderer
|
||||||
|
import dns.rrset
|
||||||
|
import dns.tsig
|
||||||
|
import dns.ttl
|
||||||
|
import dns.wire
|
||||||
|
|
||||||
|
|
||||||
class ShortHeader(dns.exception.FormError):
|
class ShortHeader(dns.exception.FormError):
|
||||||
|
@ -135,7 +134,7 @@ IndexKeyType = Tuple[
|
||||||
Optional[dns.rdataclass.RdataClass],
|
Optional[dns.rdataclass.RdataClass],
|
||||||
]
|
]
|
||||||
IndexType = Dict[IndexKeyType, dns.rrset.RRset]
|
IndexType = Dict[IndexKeyType, dns.rrset.RRset]
|
||||||
SectionType = Union[int, List[dns.rrset.RRset]]
|
SectionType = Union[int, str, List[dns.rrset.RRset]]
|
||||||
|
|
||||||
|
|
||||||
class Message:
|
class Message:
|
||||||
|
@ -231,7 +230,7 @@ class Message:
|
||||||
s.write("payload %d\n" % self.payload)
|
s.write("payload %d\n" % self.payload)
|
||||||
for opt in self.options:
|
for opt in self.options:
|
||||||
s.write("option %s\n" % opt.to_text())
|
s.write("option %s\n" % opt.to_text())
|
||||||
for (name, which) in self._section_enum.__members__.items():
|
for name, which in self._section_enum.__members__.items():
|
||||||
s.write(f";{name}\n")
|
s.write(f";{name}\n")
|
||||||
for rrset in self.section_from_number(which):
|
for rrset in self.section_from_number(which):
|
||||||
s.write(rrset.to_text(origin, relativize, **kw))
|
s.write(rrset.to_text(origin, relativize, **kw))
|
||||||
|
@ -348,27 +347,29 @@ class Message:
|
||||||
deleting: Optional[dns.rdataclass.RdataClass] = None,
|
deleting: Optional[dns.rdataclass.RdataClass] = None,
|
||||||
create: bool = False,
|
create: bool = False,
|
||||||
force_unique: bool = False,
|
force_unique: bool = False,
|
||||||
|
idna_codec: Optional[dns.name.IDNACodec] = None,
|
||||||
) -> dns.rrset.RRset:
|
) -> dns.rrset.RRset:
|
||||||
"""Find the RRset with the given attributes in the specified section.
|
"""Find the RRset with the given attributes in the specified section.
|
||||||
|
|
||||||
*section*, an ``int`` section number, or one of the section
|
*section*, an ``int`` section number, a ``str`` section name, or one of
|
||||||
attributes of this message. This specifies the
|
the section attributes of this message. This specifies the
|
||||||
the section of the message to search. For example::
|
the section of the message to search. For example::
|
||||||
|
|
||||||
my_message.find_rrset(my_message.answer, name, rdclass, rdtype)
|
my_message.find_rrset(my_message.answer, name, rdclass, rdtype)
|
||||||
my_message.find_rrset(dns.message.ANSWER, name, rdclass, rdtype)
|
my_message.find_rrset(dns.message.ANSWER, name, rdclass, rdtype)
|
||||||
|
my_message.find_rrset("ANSWER", name, rdclass, rdtype)
|
||||||
|
|
||||||
*name*, a ``dns.name.Name``, the name of the RRset.
|
*name*, a ``dns.name.Name`` or ``str``, the name of the RRset.
|
||||||
|
|
||||||
*rdclass*, an ``int``, the class of the RRset.
|
*rdclass*, an ``int`` or ``str``, the class of the RRset.
|
||||||
|
|
||||||
*rdtype*, an ``int``, the type of the RRset.
|
*rdtype*, an ``int`` or ``str``, the type of the RRset.
|
||||||
|
|
||||||
*covers*, an ``int`` or ``None``, the covers value of the RRset.
|
*covers*, an ``int`` or ``str``, the covers value of the RRset.
|
||||||
The default is ``None``.
|
The default is ``dns.rdatatype.NONE``.
|
||||||
|
|
||||||
*deleting*, an ``int`` or ``None``, the deleting value of the RRset.
|
*deleting*, an ``int``, ``str``, or ``None``, the deleting value of the
|
||||||
The default is ``None``.
|
RRset. The default is ``None``.
|
||||||
|
|
||||||
*create*, a ``bool``. If ``True``, create the RRset if it is not found.
|
*create*, a ``bool``. If ``True``, create the RRset if it is not found.
|
||||||
The created RRset is appended to *section*.
|
The created RRset is appended to *section*.
|
||||||
|
@ -378,6 +379,10 @@ class Message:
|
||||||
already. The default is ``False``. This is useful when creating
|
already. The default is ``False``. This is useful when creating
|
||||||
DDNS Update messages, as order matters for them.
|
DDNS Update messages, as order matters for them.
|
||||||
|
|
||||||
|
*idna_codec*, a ``dns.name.IDNACodec``, specifies the IDNA
|
||||||
|
encoder/decoder. If ``None``, the default IDNA 2003 encoder/decoder
|
||||||
|
is used.
|
||||||
|
|
||||||
Raises ``KeyError`` if the RRset was not found and create was
|
Raises ``KeyError`` if the RRset was not found and create was
|
||||||
``False``.
|
``False``.
|
||||||
|
|
||||||
|
@ -386,10 +391,19 @@ class Message:
|
||||||
|
|
||||||
if isinstance(section, int):
|
if isinstance(section, int):
|
||||||
section_number = section
|
section_number = section
|
||||||
the_section = self.section_from_number(section_number)
|
section = self.section_from_number(section_number)
|
||||||
|
elif isinstance(section, str):
|
||||||
|
section_number = MessageSection.from_text(section)
|
||||||
|
section = self.section_from_number(section_number)
|
||||||
else:
|
else:
|
||||||
section_number = self.section_number(section)
|
section_number = self.section_number(section)
|
||||||
the_section = section
|
if isinstance(name, str):
|
||||||
|
name = dns.name.from_text(name, idna_codec=idna_codec)
|
||||||
|
rdtype = dns.rdatatype.RdataType.make(rdtype)
|
||||||
|
rdclass = dns.rdataclass.RdataClass.make(rdclass)
|
||||||
|
covers = dns.rdatatype.RdataType.make(covers)
|
||||||
|
if deleting is not None:
|
||||||
|
deleting = dns.rdataclass.RdataClass.make(deleting)
|
||||||
key = (section_number, name, rdclass, rdtype, covers, deleting)
|
key = (section_number, name, rdclass, rdtype, covers, deleting)
|
||||||
if not force_unique:
|
if not force_unique:
|
||||||
if self.index is not None:
|
if self.index is not None:
|
||||||
|
@ -397,13 +411,13 @@ class Message:
|
||||||
if rrset is not None:
|
if rrset is not None:
|
||||||
return rrset
|
return rrset
|
||||||
else:
|
else:
|
||||||
for rrset in the_section:
|
for rrset in section:
|
||||||
if rrset.full_match(name, rdclass, rdtype, covers, deleting):
|
if rrset.full_match(name, rdclass, rdtype, covers, deleting):
|
||||||
return rrset
|
return rrset
|
||||||
if not create:
|
if not create:
|
||||||
raise KeyError
|
raise KeyError
|
||||||
rrset = dns.rrset.RRset(name, rdclass, rdtype, covers, deleting)
|
rrset = dns.rrset.RRset(name, rdclass, rdtype, covers, deleting)
|
||||||
the_section.append(rrset)
|
section.append(rrset)
|
||||||
if self.index is not None:
|
if self.index is not None:
|
||||||
self.index[key] = rrset
|
self.index[key] = rrset
|
||||||
return rrset
|
return rrset
|
||||||
|
@ -418,29 +432,31 @@ class Message:
|
||||||
deleting: Optional[dns.rdataclass.RdataClass] = None,
|
deleting: Optional[dns.rdataclass.RdataClass] = None,
|
||||||
create: bool = False,
|
create: bool = False,
|
||||||
force_unique: bool = False,
|
force_unique: bool = False,
|
||||||
|
idna_codec: Optional[dns.name.IDNACodec] = None,
|
||||||
) -> Optional[dns.rrset.RRset]:
|
) -> Optional[dns.rrset.RRset]:
|
||||||
"""Get the RRset with the given attributes in the specified section.
|
"""Get the RRset with the given attributes in the specified section.
|
||||||
|
|
||||||
If the RRset is not found, None is returned.
|
If the RRset is not found, None is returned.
|
||||||
|
|
||||||
*section*, an ``int`` section number, or one of the section
|
*section*, an ``int`` section number, a ``str`` section name, or one of
|
||||||
attributes of this message. This specifies the
|
the section attributes of this message. This specifies the
|
||||||
the section of the message to search. For example::
|
the section of the message to search. For example::
|
||||||
|
|
||||||
my_message.get_rrset(my_message.answer, name, rdclass, rdtype)
|
my_message.get_rrset(my_message.answer, name, rdclass, rdtype)
|
||||||
my_message.get_rrset(dns.message.ANSWER, name, rdclass, rdtype)
|
my_message.get_rrset(dns.message.ANSWER, name, rdclass, rdtype)
|
||||||
|
my_message.get_rrset("ANSWER", name, rdclass, rdtype)
|
||||||
|
|
||||||
*name*, a ``dns.name.Name``, the name of the RRset.
|
*name*, a ``dns.name.Name`` or ``str``, the name of the RRset.
|
||||||
|
|
||||||
*rdclass*, an ``int``, the class of the RRset.
|
*rdclass*, an ``int`` or ``str``, the class of the RRset.
|
||||||
|
|
||||||
*rdtype*, an ``int``, the type of the RRset.
|
*rdtype*, an ``int`` or ``str``, the type of the RRset.
|
||||||
|
|
||||||
*covers*, an ``int`` or ``None``, the covers value of the RRset.
|
*covers*, an ``int`` or ``str``, the covers value of the RRset.
|
||||||
The default is ``None``.
|
The default is ``dns.rdatatype.NONE``.
|
||||||
|
|
||||||
*deleting*, an ``int`` or ``None``, the deleting value of the RRset.
|
*deleting*, an ``int``, ``str``, or ``None``, the deleting value of the
|
||||||
The default is ``None``.
|
RRset. The default is ``None``.
|
||||||
|
|
||||||
*create*, a ``bool``. If ``True``, create the RRset if it is not found.
|
*create*, a ``bool``. If ``True``, create the RRset if it is not found.
|
||||||
The created RRset is appended to *section*.
|
The created RRset is appended to *section*.
|
||||||
|
@ -450,12 +466,24 @@ class Message:
|
||||||
already. The default is ``False``. This is useful when creating
|
already. The default is ``False``. This is useful when creating
|
||||||
DDNS Update messages, as order matters for them.
|
DDNS Update messages, as order matters for them.
|
||||||
|
|
||||||
|
*idna_codec*, a ``dns.name.IDNACodec``, specifies the IDNA
|
||||||
|
encoder/decoder. If ``None``, the default IDNA 2003 encoder/decoder
|
||||||
|
is used.
|
||||||
|
|
||||||
Returns a ``dns.rrset.RRset object`` or ``None``.
|
Returns a ``dns.rrset.RRset object`` or ``None``.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
rrset = self.find_rrset(
|
rrset = self.find_rrset(
|
||||||
section, name, rdclass, rdtype, covers, deleting, create, force_unique
|
section,
|
||||||
|
name,
|
||||||
|
rdclass,
|
||||||
|
rdtype,
|
||||||
|
covers,
|
||||||
|
deleting,
|
||||||
|
create,
|
||||||
|
force_unique,
|
||||||
|
idna_codec,
|
||||||
)
|
)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
rrset = None
|
rrset = None
|
||||||
|
@ -1708,13 +1736,11 @@ def make_query(
|
||||||
|
|
||||||
if isinstance(qname, str):
|
if isinstance(qname, str):
|
||||||
qname = dns.name.from_text(qname, idna_codec=idna_codec)
|
qname = dns.name.from_text(qname, idna_codec=idna_codec)
|
||||||
the_rdtype = dns.rdatatype.RdataType.make(rdtype)
|
rdtype = dns.rdatatype.RdataType.make(rdtype)
|
||||||
the_rdclass = dns.rdataclass.RdataClass.make(rdclass)
|
rdclass = dns.rdataclass.RdataClass.make(rdclass)
|
||||||
m = QueryMessage(id=id)
|
m = QueryMessage(id=id)
|
||||||
m.flags = dns.flags.Flag(flags)
|
m.flags = dns.flags.Flag(flags)
|
||||||
m.find_rrset(
|
m.find_rrset(m.question, qname, rdclass, rdtype, create=True, force_unique=True)
|
||||||
m.question, qname, the_rdclass, the_rdtype, create=True, force_unique=True
|
|
||||||
)
|
|
||||||
# only pass keywords on to use_edns if they have been set to a
|
# only pass keywords on to use_edns if they have been set to a
|
||||||
# non-None value. Setting a field will turn EDNS on if it hasn't
|
# non-None value. Setting a field will turn EDNS on if it hasn't
|
||||||
# been configured.
|
# been configured.
|
||||||
|
|
|
@ -18,12 +18,10 @@
|
||||||
"""DNS Names.
|
"""DNS Names.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any, Dict, Iterable, Optional, Tuple, Union
|
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
import struct
|
|
||||||
|
|
||||||
import encodings.idna # type: ignore
|
import encodings.idna # type: ignore
|
||||||
|
import struct
|
||||||
|
from typing import Any, Dict, Iterable, Optional, Tuple, Union
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import idna # type: ignore
|
import idna # type: ignore
|
||||||
|
@ -33,10 +31,9 @@ except ImportError: # pragma: no cover
|
||||||
have_idna_2008 = False
|
have_idna_2008 = False
|
||||||
|
|
||||||
import dns.enum
|
import dns.enum
|
||||||
import dns.wire
|
|
||||||
import dns.exception
|
import dns.exception
|
||||||
import dns.immutable
|
import dns.immutable
|
||||||
|
import dns.wire
|
||||||
|
|
||||||
CompressType = Dict["Name", int]
|
CompressType = Dict["Name", int]
|
||||||
|
|
||||||
|
|
329
lib/dns/nameserver.py
Normal file
329
lib/dns/nameserver.py
Normal file
|
@ -0,0 +1,329 @@
|
||||||
|
from typing import Optional, Union
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
import dns.asyncbackend
|
||||||
|
import dns.asyncquery
|
||||||
|
import dns.inet
|
||||||
|
import dns.message
|
||||||
|
import dns.query
|
||||||
|
|
||||||
|
|
||||||
|
class Nameserver:
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def kind(self) -> str:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def is_always_max_size(self) -> bool:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def answer_nameserver(self) -> str:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def answer_port(self) -> int:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def query(
|
||||||
|
self,
|
||||||
|
request: dns.message.QueryMessage,
|
||||||
|
timeout: float,
|
||||||
|
source: Optional[str],
|
||||||
|
source_port: int,
|
||||||
|
max_size: bool,
|
||||||
|
one_rr_per_rrset: bool = False,
|
||||||
|
ignore_trailing: bool = False,
|
||||||
|
) -> dns.message.Message:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def async_query(
|
||||||
|
self,
|
||||||
|
request: dns.message.QueryMessage,
|
||||||
|
timeout: float,
|
||||||
|
source: Optional[str],
|
||||||
|
source_port: int,
|
||||||
|
max_size: bool,
|
||||||
|
backend: dns.asyncbackend.Backend,
|
||||||
|
one_rr_per_rrset: bool = False,
|
||||||
|
ignore_trailing: bool = False,
|
||||||
|
) -> dns.message.Message:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class AddressAndPortNameserver(Nameserver):
|
||||||
|
def __init__(self, address: str, port: int):
|
||||||
|
super().__init__()
|
||||||
|
self.address = address
|
||||||
|
self.port = port
|
||||||
|
|
||||||
|
def kind(self) -> str:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def is_always_max_size(self) -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
ns_kind = self.kind()
|
||||||
|
return f"{ns_kind}:{self.address}@{self.port}"
|
||||||
|
|
||||||
|
def answer_nameserver(self) -> str:
|
||||||
|
return self.address
|
||||||
|
|
||||||
|
def answer_port(self) -> int:
|
||||||
|
return self.port
|
||||||
|
|
||||||
|
|
||||||
|
class Do53Nameserver(AddressAndPortNameserver):
|
||||||
|
def __init__(self, address: str, port: int = 53):
|
||||||
|
super().__init__(address, port)
|
||||||
|
|
||||||
|
def kind(self):
|
||||||
|
return "Do53"
|
||||||
|
|
||||||
|
def query(
|
||||||
|
self,
|
||||||
|
request: dns.message.QueryMessage,
|
||||||
|
timeout: float,
|
||||||
|
source: Optional[str],
|
||||||
|
source_port: int,
|
||||||
|
max_size: bool,
|
||||||
|
one_rr_per_rrset: bool = False,
|
||||||
|
ignore_trailing: bool = False,
|
||||||
|
) -> dns.message.Message:
|
||||||
|
if max_size:
|
||||||
|
response = dns.query.tcp(
|
||||||
|
request,
|
||||||
|
self.address,
|
||||||
|
timeout=timeout,
|
||||||
|
port=self.port,
|
||||||
|
source=source,
|
||||||
|
source_port=source_port,
|
||||||
|
one_rr_per_rrset=one_rr_per_rrset,
|
||||||
|
ignore_trailing=ignore_trailing,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
response = dns.query.udp(
|
||||||
|
request,
|
||||||
|
self.address,
|
||||||
|
timeout=timeout,
|
||||||
|
port=self.port,
|
||||||
|
source=source,
|
||||||
|
source_port=source_port,
|
||||||
|
raise_on_truncation=True,
|
||||||
|
one_rr_per_rrset=one_rr_per_rrset,
|
||||||
|
ignore_trailing=ignore_trailing,
|
||||||
|
)
|
||||||
|
return response
|
||||||
|
|
||||||
|
async def async_query(
|
||||||
|
self,
|
||||||
|
request: dns.message.QueryMessage,
|
||||||
|
timeout: float,
|
||||||
|
source: Optional[str],
|
||||||
|
source_port: int,
|
||||||
|
max_size: bool,
|
||||||
|
backend: dns.asyncbackend.Backend,
|
||||||
|
one_rr_per_rrset: bool = False,
|
||||||
|
ignore_trailing: bool = False,
|
||||||
|
) -> dns.message.Message:
|
||||||
|
if max_size:
|
||||||
|
response = await dns.asyncquery.tcp(
|
||||||
|
request,
|
||||||
|
self.address,
|
||||||
|
timeout=timeout,
|
||||||
|
port=self.port,
|
||||||
|
source=source,
|
||||||
|
source_port=source_port,
|
||||||
|
backend=backend,
|
||||||
|
one_rr_per_rrset=one_rr_per_rrset,
|
||||||
|
ignore_trailing=ignore_trailing,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
response = await dns.asyncquery.udp(
|
||||||
|
request,
|
||||||
|
self.address,
|
||||||
|
timeout=timeout,
|
||||||
|
port=self.port,
|
||||||
|
source=source,
|
||||||
|
source_port=source_port,
|
||||||
|
raise_on_truncation=True,
|
||||||
|
backend=backend,
|
||||||
|
one_rr_per_rrset=one_rr_per_rrset,
|
||||||
|
ignore_trailing=ignore_trailing,
|
||||||
|
)
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
class DoHNameserver(Nameserver):
|
||||||
|
def __init__(self, url: str, bootstrap_address: Optional[str] = None):
|
||||||
|
super().__init__()
|
||||||
|
self.url = url
|
||||||
|
self.bootstrap_address = bootstrap_address
|
||||||
|
|
||||||
|
def kind(self):
|
||||||
|
return "DoH"
|
||||||
|
|
||||||
|
def is_always_max_size(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return self.url
|
||||||
|
|
||||||
|
def answer_nameserver(self) -> str:
|
||||||
|
return self.url
|
||||||
|
|
||||||
|
def answer_port(self) -> int:
|
||||||
|
port = urlparse(self.url).port
|
||||||
|
if port is None:
|
||||||
|
port = 443
|
||||||
|
return port
|
||||||
|
|
||||||
|
def query(
|
||||||
|
self,
|
||||||
|
request: dns.message.QueryMessage,
|
||||||
|
timeout: float,
|
||||||
|
source: Optional[str],
|
||||||
|
source_port: int,
|
||||||
|
max_size: bool = False,
|
||||||
|
one_rr_per_rrset: bool = False,
|
||||||
|
ignore_trailing: bool = False,
|
||||||
|
) -> dns.message.Message:
|
||||||
|
return dns.query.https(
|
||||||
|
request,
|
||||||
|
self.url,
|
||||||
|
timeout=timeout,
|
||||||
|
bootstrap_address=self.bootstrap_address,
|
||||||
|
one_rr_per_rrset=one_rr_per_rrset,
|
||||||
|
ignore_trailing=ignore_trailing,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def async_query(
|
||||||
|
self,
|
||||||
|
request: dns.message.QueryMessage,
|
||||||
|
timeout: float,
|
||||||
|
source: Optional[str],
|
||||||
|
source_port: int,
|
||||||
|
max_size: bool,
|
||||||
|
backend: dns.asyncbackend.Backend,
|
||||||
|
one_rr_per_rrset: bool = False,
|
||||||
|
ignore_trailing: bool = False,
|
||||||
|
) -> dns.message.Message:
|
||||||
|
return await dns.asyncquery.https(
|
||||||
|
request,
|
||||||
|
self.url,
|
||||||
|
timeout=timeout,
|
||||||
|
one_rr_per_rrset=one_rr_per_rrset,
|
||||||
|
ignore_trailing=ignore_trailing,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DoTNameserver(AddressAndPortNameserver):
|
||||||
|
def __init__(self, address: str, port: int = 853, hostname: Optional[str] = None):
|
||||||
|
super().__init__(address, port)
|
||||||
|
self.hostname = hostname
|
||||||
|
|
||||||
|
def kind(self):
|
||||||
|
return "DoT"
|
||||||
|
|
||||||
|
def query(
|
||||||
|
self,
|
||||||
|
request: dns.message.QueryMessage,
|
||||||
|
timeout: float,
|
||||||
|
source: Optional[str],
|
||||||
|
source_port: int,
|
||||||
|
max_size: bool = False,
|
||||||
|
one_rr_per_rrset: bool = False,
|
||||||
|
ignore_trailing: bool = False,
|
||||||
|
) -> dns.message.Message:
|
||||||
|
return dns.query.tls(
|
||||||
|
request,
|
||||||
|
self.address,
|
||||||
|
port=self.port,
|
||||||
|
timeout=timeout,
|
||||||
|
one_rr_per_rrset=one_rr_per_rrset,
|
||||||
|
ignore_trailing=ignore_trailing,
|
||||||
|
server_hostname=self.hostname,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def async_query(
|
||||||
|
self,
|
||||||
|
request: dns.message.QueryMessage,
|
||||||
|
timeout: float,
|
||||||
|
source: Optional[str],
|
||||||
|
source_port: int,
|
||||||
|
max_size: bool,
|
||||||
|
backend: dns.asyncbackend.Backend,
|
||||||
|
one_rr_per_rrset: bool = False,
|
||||||
|
ignore_trailing: bool = False,
|
||||||
|
) -> dns.message.Message:
|
||||||
|
return await dns.asyncquery.tls(
|
||||||
|
request,
|
||||||
|
self.address,
|
||||||
|
port=self.port,
|
||||||
|
timeout=timeout,
|
||||||
|
one_rr_per_rrset=one_rr_per_rrset,
|
||||||
|
ignore_trailing=ignore_trailing,
|
||||||
|
server_hostname=self.hostname,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DoQNameserver(AddressAndPortNameserver):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
address: str,
|
||||||
|
port: int = 853,
|
||||||
|
verify: Union[bool, str] = True,
|
||||||
|
server_hostname: Optional[str] = None,
|
||||||
|
):
|
||||||
|
super().__init__(address, port)
|
||||||
|
self.verify = verify
|
||||||
|
self.server_hostname = server_hostname
|
||||||
|
|
||||||
|
def kind(self):
|
||||||
|
return "DoQ"
|
||||||
|
|
||||||
|
def query(
|
||||||
|
self,
|
||||||
|
request: dns.message.QueryMessage,
|
||||||
|
timeout: float,
|
||||||
|
source: Optional[str],
|
||||||
|
source_port: int,
|
||||||
|
max_size: bool = False,
|
||||||
|
one_rr_per_rrset: bool = False,
|
||||||
|
ignore_trailing: bool = False,
|
||||||
|
) -> dns.message.Message:
|
||||||
|
return dns.query.quic(
|
||||||
|
request,
|
||||||
|
self.address,
|
||||||
|
port=self.port,
|
||||||
|
timeout=timeout,
|
||||||
|
one_rr_per_rrset=one_rr_per_rrset,
|
||||||
|
ignore_trailing=ignore_trailing,
|
||||||
|
verify=self.verify,
|
||||||
|
server_hostname=self.server_hostname,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def async_query(
|
||||||
|
self,
|
||||||
|
request: dns.message.QueryMessage,
|
||||||
|
timeout: float,
|
||||||
|
source: Optional[str],
|
||||||
|
source_port: int,
|
||||||
|
max_size: bool,
|
||||||
|
backend: dns.asyncbackend.Backend,
|
||||||
|
one_rr_per_rrset: bool = False,
|
||||||
|
ignore_trailing: bool = False,
|
||||||
|
) -> dns.message.Message:
|
||||||
|
return await dns.asyncquery.quic(
|
||||||
|
request,
|
||||||
|
self.address,
|
||||||
|
port=self.port,
|
||||||
|
timeout=timeout,
|
||||||
|
one_rr_per_rrset=one_rr_per_rrset,
|
||||||
|
ignore_trailing=ignore_trailing,
|
||||||
|
verify=self.verify,
|
||||||
|
server_hostname=self.server_hostname,
|
||||||
|
)
|
|
@ -17,19 +17,17 @@
|
||||||
|
|
||||||
"""DNS nodes. A node is a set of rdatasets."""
|
"""DNS nodes. A node is a set of rdatasets."""
|
||||||
|
|
||||||
from typing import Any, Dict, Optional
|
|
||||||
|
|
||||||
import enum
|
import enum
|
||||||
import io
|
import io
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
import dns.immutable
|
import dns.immutable
|
||||||
import dns.name
|
import dns.name
|
||||||
import dns.rdataclass
|
import dns.rdataclass
|
||||||
import dns.rdataset
|
import dns.rdataset
|
||||||
import dns.rdatatype
|
import dns.rdatatype
|
||||||
import dns.rrset
|
|
||||||
import dns.renderer
|
import dns.renderer
|
||||||
|
import dns.rrset
|
||||||
|
|
||||||
_cname_types = {
|
_cname_types = {
|
||||||
dns.rdatatype.CNAME,
|
dns.rdatatype.CNAME,
|
||||||
|
|
260
lib/dns/query.py
260
lib/dns/query.py
|
@ -17,8 +17,6 @@
|
||||||
|
|
||||||
"""Talk to a DNS server."""
|
"""Talk to a DNS server."""
|
||||||
|
|
||||||
from typing import Any, Dict, Optional, Tuple, Union
|
|
||||||
|
|
||||||
import base64
|
import base64
|
||||||
import contextlib
|
import contextlib
|
||||||
import enum
|
import enum
|
||||||
|
@ -28,12 +26,12 @@ import selectors
|
||||||
import socket
|
import socket
|
||||||
import struct
|
import struct
|
||||||
import time
|
import time
|
||||||
import urllib.parse
|
from typing import Any, Dict, Optional, Tuple, Union
|
||||||
|
|
||||||
import dns.exception
|
import dns.exception
|
||||||
import dns.inet
|
import dns.inet
|
||||||
import dns.name
|
|
||||||
import dns.message
|
import dns.message
|
||||||
|
import dns.name
|
||||||
import dns.quic
|
import dns.quic
|
||||||
import dns.rcode
|
import dns.rcode
|
||||||
import dns.rdataclass
|
import dns.rdataclass
|
||||||
|
@ -43,20 +41,32 @@ import dns.transaction
|
||||||
import dns.tsig
|
import dns.tsig
|
||||||
import dns.xfr
|
import dns.xfr
|
||||||
|
|
||||||
try:
|
|
||||||
import requests
|
|
||||||
from requests_toolbelt.adapters.source import SourceAddressAdapter
|
|
||||||
from requests_toolbelt.adapters.host_header_ssl import HostHeaderSSLAdapter
|
|
||||||
|
|
||||||
_have_requests = True
|
def _remaining(expiration):
|
||||||
except ImportError: # pragma: no cover
|
if expiration is None:
|
||||||
_have_requests = False
|
return None
|
||||||
|
timeout = expiration - time.time()
|
||||||
|
if timeout <= 0.0:
|
||||||
|
raise dns.exception.Timeout
|
||||||
|
return timeout
|
||||||
|
|
||||||
|
|
||||||
|
def _expiration_for_this_attempt(timeout, expiration):
|
||||||
|
if expiration is None:
|
||||||
|
return None
|
||||||
|
return min(time.time() + timeout, expiration)
|
||||||
|
|
||||||
|
|
||||||
_have_httpx = False
|
_have_httpx = False
|
||||||
_have_http2 = False
|
_have_http2 = False
|
||||||
try:
|
try:
|
||||||
|
import httpcore
|
||||||
|
import httpcore._backends.sync
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
|
_CoreNetworkBackend = httpcore.NetworkBackend
|
||||||
|
_CoreSyncStream = httpcore._backends.sync.SyncStream
|
||||||
|
|
||||||
_have_httpx = True
|
_have_httpx = True
|
||||||
try:
|
try:
|
||||||
# See if http2 support is available.
|
# See if http2 support is available.
|
||||||
|
@ -64,10 +74,87 @@ try:
|
||||||
_have_http2 = True
|
_have_http2 = True
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
except ImportError: # pragma: no cover
|
|
||||||
pass
|
|
||||||
|
|
||||||
have_doh = _have_requests or _have_httpx
|
class _NetworkBackend(_CoreNetworkBackend):
|
||||||
|
def __init__(self, resolver, local_port, bootstrap_address, family):
|
||||||
|
super().__init__()
|
||||||
|
self._local_port = local_port
|
||||||
|
self._resolver = resolver
|
||||||
|
self._bootstrap_address = bootstrap_address
|
||||||
|
self._family = family
|
||||||
|
|
||||||
|
def connect_tcp(
|
||||||
|
self, host, port, timeout, local_address, socket_options=None
|
||||||
|
): # pylint: disable=signature-differs
|
||||||
|
addresses = []
|
||||||
|
_, expiration = _compute_times(timeout)
|
||||||
|
if dns.inet.is_address(host):
|
||||||
|
addresses.append(host)
|
||||||
|
elif self._bootstrap_address is not None:
|
||||||
|
addresses.append(self._bootstrap_address)
|
||||||
|
else:
|
||||||
|
timeout = _remaining(expiration)
|
||||||
|
family = self._family
|
||||||
|
if local_address:
|
||||||
|
family = dns.inet.af_for_address(local_address)
|
||||||
|
answers = self._resolver.resolve_name(
|
||||||
|
host, family=family, lifetime=timeout
|
||||||
|
)
|
||||||
|
addresses = answers.addresses()
|
||||||
|
for address in addresses:
|
||||||
|
af = dns.inet.af_for_address(address)
|
||||||
|
if local_address is not None or self._local_port != 0:
|
||||||
|
source = dns.inet.low_level_address_tuple(
|
||||||
|
(local_address, self._local_port), af
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
source = None
|
||||||
|
sock = _make_socket(af, socket.SOCK_STREAM, source)
|
||||||
|
attempt_expiration = _expiration_for_this_attempt(2.0, expiration)
|
||||||
|
try:
|
||||||
|
_connect(
|
||||||
|
sock,
|
||||||
|
dns.inet.low_level_address_tuple((address, port), af),
|
||||||
|
attempt_expiration,
|
||||||
|
)
|
||||||
|
return _CoreSyncStream(sock)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
raise httpcore.ConnectError
|
||||||
|
|
||||||
|
def connect_unix_socket(
|
||||||
|
self, path, timeout, socket_options=None
|
||||||
|
): # pylint: disable=signature-differs
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
class _HTTPTransport(httpx.HTTPTransport):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*args,
|
||||||
|
local_port=0,
|
||||||
|
bootstrap_address=None,
|
||||||
|
resolver=None,
|
||||||
|
family=socket.AF_UNSPEC,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
if resolver is None:
|
||||||
|
# pylint: disable=import-outside-toplevel,redefined-outer-name
|
||||||
|
import dns.resolver
|
||||||
|
|
||||||
|
resolver = dns.resolver.Resolver()
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self._pool._network_backend = _NetworkBackend(
|
||||||
|
resolver, local_port, bootstrap_address, family
|
||||||
|
)
|
||||||
|
|
||||||
|
except ImportError: # pragma: no cover
|
||||||
|
|
||||||
|
class _HTTPTransport: # type: ignore
|
||||||
|
def connect_tcp(self, host, port, timeout, local_address):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
have_doh = _have_httpx
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import ssl
|
import ssl
|
||||||
|
@ -88,7 +175,7 @@ except ImportError: # pragma: no cover
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create_default_context(cls, *args, **kwargs):
|
def create_default_context(cls, *args, **kwargs):
|
||||||
raise Exception("no ssl support")
|
raise Exception("no ssl support") # pylint: disable=broad-exception-raised
|
||||||
|
|
||||||
|
|
||||||
# Function used to create a socket. Can be overridden if needed in special
|
# Function used to create a socket. Can be overridden if needed in special
|
||||||
|
@ -105,7 +192,7 @@ class BadResponse(dns.exception.FormError):
|
||||||
|
|
||||||
|
|
||||||
class NoDOH(dns.exception.DNSException):
|
class NoDOH(dns.exception.DNSException):
|
||||||
"""DNS over HTTPS (DOH) was requested but the requests module is not
|
"""DNS over HTTPS (DOH) was requested but the httpx module is not
|
||||||
available."""
|
available."""
|
||||||
|
|
||||||
|
|
||||||
|
@ -230,7 +317,7 @@ def _destination_and_source(
|
||||||
# We know the destination af, so source had better agree!
|
# We know the destination af, so source had better agree!
|
||||||
if saf != af:
|
if saf != af:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"different address families for source " + "and destination"
|
"different address families for source and destination"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# We didn't know the destination af, but we know the source,
|
# We didn't know the destination af, but we know the source,
|
||||||
|
@ -240,11 +327,10 @@ def _destination_and_source(
|
||||||
# Caller has specified a source_port but not an address, so we
|
# Caller has specified a source_port but not an address, so we
|
||||||
# need to return a source, and we need to use the appropriate
|
# need to return a source, and we need to use the appropriate
|
||||||
# wildcard address as the address.
|
# wildcard address as the address.
|
||||||
if af == socket.AF_INET:
|
try:
|
||||||
source = "0.0.0.0"
|
source = dns.inet.any_for_af(af)
|
||||||
elif af == socket.AF_INET6:
|
except Exception:
|
||||||
source = "::"
|
# we catch this and raise ValueError for backwards compatibility
|
||||||
else:
|
|
||||||
raise ValueError("source_port specified but address family is unknown")
|
raise ValueError("source_port specified but address family is unknown")
|
||||||
# Convert high-level (address, port) tuples into low-level address
|
# Convert high-level (address, port) tuples into low-level address
|
||||||
# tuples.
|
# tuples.
|
||||||
|
@ -289,6 +375,8 @@ def https(
|
||||||
post: bool = True,
|
post: bool = True,
|
||||||
bootstrap_address: Optional[str] = None,
|
bootstrap_address: Optional[str] = None,
|
||||||
verify: Union[bool, str] = True,
|
verify: Union[bool, str] = True,
|
||||||
|
resolver: Optional["dns.resolver.Resolver"] = None,
|
||||||
|
family: Optional[int] = socket.AF_UNSPEC,
|
||||||
) -> dns.message.Message:
|
) -> dns.message.Message:
|
||||||
"""Return the response obtained after sending a query via DNS-over-HTTPS.
|
"""Return the response obtained after sending a query via DNS-over-HTTPS.
|
||||||
|
|
||||||
|
@ -314,91 +402,78 @@ def https(
|
||||||
*ignore_trailing*, a ``bool``. If ``True``, ignore trailing junk at end of the
|
*ignore_trailing*, a ``bool``. If ``True``, ignore trailing junk at end of the
|
||||||
received message.
|
received message.
|
||||||
|
|
||||||
*session*, an ``httpx.Client`` or ``requests.session.Session``. If provided, the
|
*session*, an ``httpx.Client``. If provided, the client session to use to send the
|
||||||
client/session to use to send the queries.
|
queries.
|
||||||
|
|
||||||
*path*, a ``str``. If *where* is an IP address, then *path* will be used to
|
*path*, a ``str``. If *where* is an IP address, then *path* will be used to
|
||||||
construct the URL to send the DNS query to.
|
construct the URL to send the DNS query to.
|
||||||
|
|
||||||
*post*, a ``bool``. If ``True``, the default, POST method will be used.
|
*post*, a ``bool``. If ``True``, the default, POST method will be used.
|
||||||
|
|
||||||
*bootstrap_address*, a ``str``, the IP address to use to bypass the system's DNS
|
*bootstrap_address*, a ``str``, the IP address to use to bypass resolution.
|
||||||
resolver.
|
|
||||||
|
|
||||||
*verify*, a ``bool`` or ``str``. If a ``True``, then TLS certificate verification
|
*verify*, a ``bool`` or ``str``. If a ``True``, then TLS certificate verification
|
||||||
of the server is done using the default CA bundle; if ``False``, then no
|
of the server is done using the default CA bundle; if ``False``, then no
|
||||||
verification is done; if a `str` then it specifies the path to a certificate file or
|
verification is done; if a `str` then it specifies the path to a certificate file or
|
||||||
directory which will be used for verification.
|
directory which will be used for verification.
|
||||||
|
|
||||||
|
*resolver*, a ``dns.resolver.Resolver`` or ``None``, the resolver to use for
|
||||||
|
resolution of hostnames in URLs. If not specified, a new resolver with a default
|
||||||
|
configuration will be used; note this is *not* the default resolver as that resolver
|
||||||
|
might have been configured to use DoH causing a chicken-and-egg problem. This
|
||||||
|
parameter only has an effect if the HTTP library is httpx.
|
||||||
|
|
||||||
|
*family*, an ``int``, the address family. If socket.AF_UNSPEC (the default), both A
|
||||||
|
and AAAA records will be retrieved.
|
||||||
|
|
||||||
Returns a ``dns.message.Message``.
|
Returns a ``dns.message.Message``.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if not have_doh:
|
if not have_doh:
|
||||||
raise NoDOH("Neither httpx nor requests is available.") # pragma: no cover
|
raise NoDOH # pragma: no cover
|
||||||
|
if session and not isinstance(session, httpx.Client):
|
||||||
_httpx_ok = _have_httpx
|
raise ValueError("session parameter must be an httpx.Client")
|
||||||
|
|
||||||
wire = q.to_wire()
|
wire = q.to_wire()
|
||||||
(af, _, source) = _destination_and_source(where, port, source, source_port, False)
|
(af, _, the_source) = _destination_and_source(
|
||||||
transport_adapter = None
|
where, port, source, source_port, False
|
||||||
|
)
|
||||||
transport = None
|
transport = None
|
||||||
headers = {"accept": "application/dns-message"}
|
headers = {"accept": "application/dns-message"}
|
||||||
if af is not None:
|
if af is not None and dns.inet.is_address(where):
|
||||||
if af == socket.AF_INET:
|
if af == socket.AF_INET:
|
||||||
url = "https://{}:{}{}".format(where, port, path)
|
url = "https://{}:{}{}".format(where, port, path)
|
||||||
elif af == socket.AF_INET6:
|
elif af == socket.AF_INET6:
|
||||||
url = "https://[{}]:{}{}".format(where, port, path)
|
url = "https://[{}]:{}{}".format(where, port, path)
|
||||||
elif bootstrap_address is not None:
|
|
||||||
_httpx_ok = False
|
|
||||||
split_url = urllib.parse.urlsplit(where)
|
|
||||||
if split_url.hostname is None:
|
|
||||||
raise ValueError("DoH URL has no hostname")
|
|
||||||
headers["Host"] = split_url.hostname
|
|
||||||
url = where.replace(split_url.hostname, bootstrap_address)
|
|
||||||
if _have_requests:
|
|
||||||
transport_adapter = HostHeaderSSLAdapter()
|
|
||||||
else:
|
else:
|
||||||
url = where
|
url = where
|
||||||
if source is not None:
|
|
||||||
# set source port and source address
|
|
||||||
if _have_httpx:
|
|
||||||
if source_port == 0:
|
|
||||||
transport = httpx.HTTPTransport(local_address=source[0], verify=verify)
|
|
||||||
else:
|
|
||||||
_httpx_ok = False
|
|
||||||
if _have_requests:
|
|
||||||
transport_adapter = SourceAddressAdapter(source)
|
|
||||||
|
|
||||||
if session:
|
# set source port and source address
|
||||||
if _have_httpx:
|
|
||||||
_is_httpx = isinstance(session, httpx.Client)
|
if the_source is None:
|
||||||
else:
|
local_address = None
|
||||||
_is_httpx = False
|
local_port = 0
|
||||||
if _is_httpx and not _httpx_ok:
|
|
||||||
raise NoDOH(
|
|
||||||
"Session is httpx, but httpx cannot be used for "
|
|
||||||
"the requested operation."
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
_is_httpx = _httpx_ok
|
local_address = the_source[0]
|
||||||
|
local_port = the_source[1]
|
||||||
if not _httpx_ok and not _have_requests:
|
transport = _HTTPTransport(
|
||||||
raise NoDOH(
|
local_address=local_address,
|
||||||
"Cannot use httpx for this operation, and requests is not available."
|
http1=True,
|
||||||
)
|
http2=_have_http2,
|
||||||
|
verify=verify,
|
||||||
|
local_port=local_port,
|
||||||
|
bootstrap_address=bootstrap_address,
|
||||||
|
resolver=resolver,
|
||||||
|
family=family,
|
||||||
|
)
|
||||||
|
|
||||||
if session:
|
if session:
|
||||||
cm: contextlib.AbstractContextManager = contextlib.nullcontext(session)
|
cm: contextlib.AbstractContextManager = contextlib.nullcontext(session)
|
||||||
elif _is_httpx:
|
else:
|
||||||
cm = httpx.Client(
|
cm = httpx.Client(
|
||||||
http1=True, http2=_have_http2, verify=verify, transport=transport
|
http1=True, http2=_have_http2, verify=verify, transport=transport
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
cm = requests.sessions.Session()
|
|
||||||
with cm as session:
|
with cm as session:
|
||||||
if transport_adapter and not _is_httpx:
|
|
||||||
session.mount(url, transport_adapter)
|
|
||||||
|
|
||||||
# see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH
|
# see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH
|
||||||
# GET and POST examples
|
# GET and POST examples
|
||||||
if post:
|
if post:
|
||||||
|
@ -408,29 +483,13 @@ def https(
|
||||||
"content-length": str(len(wire)),
|
"content-length": str(len(wire)),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
if _is_httpx:
|
response = session.post(url, headers=headers, content=wire, timeout=timeout)
|
||||||
response = session.post(
|
|
||||||
url, headers=headers, content=wire, timeout=timeout
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
response = session.post(
|
|
||||||
url, headers=headers, data=wire, timeout=timeout, verify=verify
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
wire = base64.urlsafe_b64encode(wire).rstrip(b"=")
|
wire = base64.urlsafe_b64encode(wire).rstrip(b"=")
|
||||||
if _is_httpx:
|
twire = wire.decode() # httpx does a repr() if we give it bytes
|
||||||
twire = wire.decode() # httpx does a repr() if we give it bytes
|
response = session.get(
|
||||||
response = session.get(
|
url, headers=headers, timeout=timeout, params={"dns": twire}
|
||||||
url, headers=headers, timeout=timeout, params={"dns": twire}
|
)
|
||||||
)
|
|
||||||
else:
|
|
||||||
response = session.get(
|
|
||||||
url,
|
|
||||||
headers=headers,
|
|
||||||
timeout=timeout,
|
|
||||||
verify=verify,
|
|
||||||
params={"dns": wire},
|
|
||||||
)
|
|
||||||
|
|
||||||
# see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH
|
# see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH
|
||||||
# status codes
|
# status codes
|
||||||
|
@ -1070,6 +1129,7 @@ def quic(
|
||||||
ignore_trailing: bool = False,
|
ignore_trailing: bool = False,
|
||||||
connection: Optional[dns.quic.SyncQuicConnection] = None,
|
connection: Optional[dns.quic.SyncQuicConnection] = None,
|
||||||
verify: Union[bool, str] = True,
|
verify: Union[bool, str] = True,
|
||||||
|
server_hostname: Optional[str] = None,
|
||||||
) -> dns.message.Message:
|
) -> dns.message.Message:
|
||||||
"""Return the response obtained after sending a query via DNS-over-QUIC.
|
"""Return the response obtained after sending a query via DNS-over-QUIC.
|
||||||
|
|
||||||
|
@ -1101,6 +1161,10 @@ def quic(
|
||||||
verification is done; if a `str` then it specifies the path to a certificate file or
|
verification is done; if a `str` then it specifies the path to a certificate file or
|
||||||
directory which will be used for verification.
|
directory which will be used for verification.
|
||||||
|
|
||||||
|
*server_hostname*, a ``str`` containing the server's hostname. The
|
||||||
|
default is ``None``, which means that no hostname is known, and if an
|
||||||
|
SSL context is created, hostname checking will be disabled.
|
||||||
|
|
||||||
Returns a ``dns.message.Message``.
|
Returns a ``dns.message.Message``.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@ -1115,16 +1179,18 @@ def quic(
|
||||||
manager: contextlib.AbstractContextManager = contextlib.nullcontext(None)
|
manager: contextlib.AbstractContextManager = contextlib.nullcontext(None)
|
||||||
the_connection = connection
|
the_connection = connection
|
||||||
else:
|
else:
|
||||||
manager = dns.quic.SyncQuicManager(verify_mode=verify)
|
manager = dns.quic.SyncQuicManager(
|
||||||
|
verify_mode=verify, server_name=server_hostname
|
||||||
|
)
|
||||||
the_manager = manager # for type checking happiness
|
the_manager = manager # for type checking happiness
|
||||||
|
|
||||||
with manager:
|
with manager:
|
||||||
if not connection:
|
if not connection:
|
||||||
the_connection = the_manager.connect(where, port, source, source_port)
|
the_connection = the_manager.connect(where, port, source, source_port)
|
||||||
start = time.time()
|
(start, expiration) = _compute_times(timeout)
|
||||||
with the_connection.make_stream() as stream:
|
with the_connection.make_stream(timeout) as stream:
|
||||||
stream.send(wire, True)
|
stream.send(wire, True)
|
||||||
wire = stream.receive(timeout)
|
wire = stream.receive(_remaining(expiration))
|
||||||
finish = time.time()
|
finish = time.time()
|
||||||
r = dns.message.from_wire(
|
r = dns.message.from_wire(
|
||||||
wire,
|
wire,
|
||||||
|
|
|
@ -5,13 +5,13 @@ try:
|
||||||
|
|
||||||
import dns.asyncbackend
|
import dns.asyncbackend
|
||||||
from dns._asyncbackend import NullContext
|
from dns._asyncbackend import NullContext
|
||||||
from dns.quic._sync import SyncQuicManager, SyncQuicConnection, SyncQuicStream
|
|
||||||
from dns.quic._asyncio import (
|
from dns.quic._asyncio import (
|
||||||
AsyncioQuicManager,
|
|
||||||
AsyncioQuicConnection,
|
AsyncioQuicConnection,
|
||||||
|
AsyncioQuicManager,
|
||||||
AsyncioQuicStream,
|
AsyncioQuicStream,
|
||||||
)
|
)
|
||||||
from dns.quic._common import AsyncQuicConnection, AsyncQuicManager
|
from dns.quic._common import AsyncQuicConnection, AsyncQuicManager
|
||||||
|
from dns.quic._sync import SyncQuicConnection, SyncQuicManager, SyncQuicStream
|
||||||
|
|
||||||
have_quic = True
|
have_quic = True
|
||||||
|
|
||||||
|
@ -33,9 +33,10 @@ try:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import trio
|
import trio
|
||||||
|
|
||||||
from dns.quic._trio import ( # pylint: disable=ungrouped-imports
|
from dns.quic._trio import ( # pylint: disable=ungrouped-imports
|
||||||
TrioQuicManager,
|
|
||||||
TrioQuicConnection,
|
TrioQuicConnection,
|
||||||
|
TrioQuicManager,
|
||||||
TrioQuicStream,
|
TrioQuicStream,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -9,14 +9,16 @@ import time
|
||||||
import aioquic.quic.configuration # type: ignore
|
import aioquic.quic.configuration # type: ignore
|
||||||
import aioquic.quic.connection # type: ignore
|
import aioquic.quic.connection # type: ignore
|
||||||
import aioquic.quic.events # type: ignore
|
import aioquic.quic.events # type: ignore
|
||||||
import dns.inet
|
|
||||||
import dns.asyncbackend
|
|
||||||
|
|
||||||
|
import dns.asyncbackend
|
||||||
|
import dns.exception
|
||||||
|
import dns.inet
|
||||||
from dns.quic._common import (
|
from dns.quic._common import (
|
||||||
BaseQuicStream,
|
QUIC_MAX_DATAGRAM,
|
||||||
AsyncQuicConnection,
|
AsyncQuicConnection,
|
||||||
AsyncQuicManager,
|
AsyncQuicManager,
|
||||||
QUIC_MAX_DATAGRAM,
|
BaseQuicStream,
|
||||||
|
UnexpectedEOF,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -30,15 +32,15 @@ class AsyncioQuicStream(BaseQuicStream):
|
||||||
await self._wake_up.wait()
|
await self._wake_up.wait()
|
||||||
|
|
||||||
async def wait_for(self, amount, expiration):
|
async def wait_for(self, amount, expiration):
|
||||||
timeout = self._timeout_from_expiration(expiration)
|
|
||||||
while True:
|
while True:
|
||||||
|
timeout = self._timeout_from_expiration(expiration)
|
||||||
if self._buffer.have(amount):
|
if self._buffer.have(amount):
|
||||||
return
|
return
|
||||||
self._expecting = amount
|
self._expecting = amount
|
||||||
try:
|
try:
|
||||||
await asyncio.wait_for(self._wait_for_wake_up(), timeout)
|
await asyncio.wait_for(self._wait_for_wake_up(), timeout)
|
||||||
except Exception:
|
except TimeoutError:
|
||||||
pass
|
raise dns.exception.Timeout
|
||||||
self._expecting = 0
|
self._expecting = 0
|
||||||
|
|
||||||
async def receive(self, timeout=None):
|
async def receive(self, timeout=None):
|
||||||
|
@ -86,8 +88,10 @@ class AsyncioQuicConnection(AsyncQuicConnection):
|
||||||
try:
|
try:
|
||||||
af = dns.inet.af_for_address(self._address)
|
af = dns.inet.af_for_address(self._address)
|
||||||
backend = dns.asyncbackend.get_backend("asyncio")
|
backend = dns.asyncbackend.get_backend("asyncio")
|
||||||
|
# Note that peer is a low-level address tuple, but make_socket() wants
|
||||||
|
# a high-level address tuple, so we convert.
|
||||||
self._socket = await backend.make_socket(
|
self._socket = await backend.make_socket(
|
||||||
af, socket.SOCK_DGRAM, 0, self._source, self._peer
|
af, socket.SOCK_DGRAM, 0, self._source, (self._peer[0], self._peer[1])
|
||||||
)
|
)
|
||||||
self._socket_created.set()
|
self._socket_created.set()
|
||||||
async with self._socket:
|
async with self._socket:
|
||||||
|
@ -106,6 +110,11 @@ class AsyncioQuicConnection(AsyncQuicConnection):
|
||||||
self._wake_timer.notify_all()
|
self._wake_timer.notify_all()
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
finally:
|
||||||
|
self._done = True
|
||||||
|
async with self._wake_timer:
|
||||||
|
self._wake_timer.notify_all()
|
||||||
|
self._handshake_complete.set()
|
||||||
|
|
||||||
async def _wait_for_wake_timer(self):
|
async def _wait_for_wake_timer(self):
|
||||||
async with self._wake_timer:
|
async with self._wake_timer:
|
||||||
|
@ -115,7 +124,7 @@ class AsyncioQuicConnection(AsyncQuicConnection):
|
||||||
await self._socket_created.wait()
|
await self._socket_created.wait()
|
||||||
while not self._done:
|
while not self._done:
|
||||||
datagrams = self._connection.datagrams_to_send(time.time())
|
datagrams = self._connection.datagrams_to_send(time.time())
|
||||||
for (datagram, address) in datagrams:
|
for datagram, address in datagrams:
|
||||||
assert address == self._peer[0]
|
assert address == self._peer[0]
|
||||||
await self._socket.sendto(datagram, self._peer, None)
|
await self._socket.sendto(datagram, self._peer, None)
|
||||||
(expiration, interval) = self._get_timer_values()
|
(expiration, interval) = self._get_timer_values()
|
||||||
|
@ -160,8 +169,13 @@ class AsyncioQuicConnection(AsyncQuicConnection):
|
||||||
self._receiver_task = asyncio.Task(self._receiver())
|
self._receiver_task = asyncio.Task(self._receiver())
|
||||||
self._sender_task = asyncio.Task(self._sender())
|
self._sender_task = asyncio.Task(self._sender())
|
||||||
|
|
||||||
async def make_stream(self):
|
async def make_stream(self, timeout=None):
|
||||||
await self._handshake_complete.wait()
|
try:
|
||||||
|
await asyncio.wait_for(self._handshake_complete.wait(), timeout)
|
||||||
|
except TimeoutError:
|
||||||
|
raise dns.exception.Timeout
|
||||||
|
if self._done:
|
||||||
|
raise UnexpectedEOF
|
||||||
stream_id = self._connection.get_next_available_stream_id(False)
|
stream_id = self._connection.get_next_available_stream_id(False)
|
||||||
stream = AsyncioQuicStream(self, stream_id)
|
stream = AsyncioQuicStream(self, stream_id)
|
||||||
self._streams[stream_id] = stream
|
self._streams[stream_id] = stream
|
||||||
|
@ -172,6 +186,9 @@ class AsyncioQuicConnection(AsyncQuicConnection):
|
||||||
self._manager.closed(self._peer[0], self._peer[1])
|
self._manager.closed(self._peer[0], self._peer[1])
|
||||||
self._closed = True
|
self._closed = True
|
||||||
self._connection.close()
|
self._connection.close()
|
||||||
|
# sender might be blocked on this, so set it
|
||||||
|
self._socket_created.set()
|
||||||
|
await self._socket.close()
|
||||||
async with self._wake_timer:
|
async with self._wake_timer:
|
||||||
self._wake_timer.notify_all()
|
self._wake_timer.notify_all()
|
||||||
try:
|
try:
|
||||||
|
@ -185,8 +202,8 @@ class AsyncioQuicConnection(AsyncQuicConnection):
|
||||||
|
|
||||||
|
|
||||||
class AsyncioQuicManager(AsyncQuicManager):
|
class AsyncioQuicManager(AsyncQuicManager):
|
||||||
def __init__(self, conf=None, verify_mode=ssl.CERT_REQUIRED):
|
def __init__(self, conf=None, verify_mode=ssl.CERT_REQUIRED, server_name=None):
|
||||||
super().__init__(conf, verify_mode, AsyncioQuicConnection)
|
super().__init__(conf, verify_mode, AsyncioQuicConnection, server_name)
|
||||||
|
|
||||||
def connect(self, address, port=853, source=None, source_port=0):
|
def connect(self, address, port=853, source=None, source_port=0):
|
||||||
(connection, start) = self._connect(address, port, source, source_port)
|
(connection, start) = self._connect(address, port, source, source_port)
|
||||||
|
@ -198,7 +215,7 @@ class AsyncioQuicManager(AsyncQuicManager):
|
||||||
return self
|
return self
|
||||||
|
|
||||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||||
# Copy the itertor into a list as exiting things will mutate the connections
|
# Copy the iterator into a list as exiting things will mutate the connections
|
||||||
# table.
|
# table.
|
||||||
connections = list(self._connections.values())
|
connections = list(self._connections.values())
|
||||||
for connection in connections:
|
for connection in connections:
|
||||||
|
|
|
@ -3,13 +3,12 @@
|
||||||
import socket
|
import socket
|
||||||
import struct
|
import struct
|
||||||
import time
|
import time
|
||||||
|
from typing import Any, Optional
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import aioquic.quic.configuration # type: ignore
|
import aioquic.quic.configuration # type: ignore
|
||||||
import aioquic.quic.connection # type: ignore
|
import aioquic.quic.connection # type: ignore
|
||||||
import dns.inet
|
|
||||||
|
|
||||||
|
import dns.inet
|
||||||
|
|
||||||
QUIC_MAX_DATAGRAM = 2048
|
QUIC_MAX_DATAGRAM = 2048
|
||||||
|
|
||||||
|
@ -135,12 +134,12 @@ class BaseQuicConnection:
|
||||||
|
|
||||||
|
|
||||||
class AsyncQuicConnection(BaseQuicConnection):
|
class AsyncQuicConnection(BaseQuicConnection):
|
||||||
async def make_stream(self) -> Any:
|
async def make_stream(self, timeout: Optional[float] = None) -> Any:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class BaseQuicManager:
|
class BaseQuicManager:
|
||||||
def __init__(self, conf, verify_mode, connection_factory):
|
def __init__(self, conf, verify_mode, connection_factory, server_name=None):
|
||||||
self._connections = {}
|
self._connections = {}
|
||||||
self._connection_factory = connection_factory
|
self._connection_factory = connection_factory
|
||||||
if conf is None:
|
if conf is None:
|
||||||
|
@ -151,6 +150,7 @@ class BaseQuicManager:
|
||||||
conf = aioquic.quic.configuration.QuicConfiguration(
|
conf = aioquic.quic.configuration.QuicConfiguration(
|
||||||
alpn_protocols=["doq", "doq-i03"],
|
alpn_protocols=["doq", "doq-i03"],
|
||||||
verify_mode=verify_mode,
|
verify_mode=verify_mode,
|
||||||
|
server_name=server_name,
|
||||||
)
|
)
|
||||||
if verify_path is not None:
|
if verify_path is not None:
|
||||||
conf.load_verify_locations(verify_path)
|
conf.load_verify_locations(verify_path)
|
||||||
|
|
|
@ -1,8 +1,8 @@
|
||||||
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
|
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
|
||||||
|
|
||||||
|
import selectors
|
||||||
import socket
|
import socket
|
||||||
import ssl
|
import ssl
|
||||||
import selectors
|
|
||||||
import struct
|
import struct
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
|
@ -10,13 +10,15 @@ import time
|
||||||
import aioquic.quic.configuration # type: ignore
|
import aioquic.quic.configuration # type: ignore
|
||||||
import aioquic.quic.connection # type: ignore
|
import aioquic.quic.connection # type: ignore
|
||||||
import aioquic.quic.events # type: ignore
|
import aioquic.quic.events # type: ignore
|
||||||
import dns.inet
|
|
||||||
|
|
||||||
|
import dns.exception
|
||||||
|
import dns.inet
|
||||||
from dns.quic._common import (
|
from dns.quic._common import (
|
||||||
BaseQuicStream,
|
QUIC_MAX_DATAGRAM,
|
||||||
BaseQuicConnection,
|
BaseQuicConnection,
|
||||||
BaseQuicManager,
|
BaseQuicManager,
|
||||||
QUIC_MAX_DATAGRAM,
|
BaseQuicStream,
|
||||||
|
UnexpectedEOF,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Avoid circularity with dns.query
|
# Avoid circularity with dns.query
|
||||||
|
@ -33,14 +35,15 @@ class SyncQuicStream(BaseQuicStream):
|
||||||
self._lock = threading.Lock()
|
self._lock = threading.Lock()
|
||||||
|
|
||||||
def wait_for(self, amount, expiration):
|
def wait_for(self, amount, expiration):
|
||||||
timeout = self._timeout_from_expiration(expiration)
|
|
||||||
while True:
|
while True:
|
||||||
|
timeout = self._timeout_from_expiration(expiration)
|
||||||
with self._lock:
|
with self._lock:
|
||||||
if self._buffer.have(amount):
|
if self._buffer.have(amount):
|
||||||
return
|
return
|
||||||
self._expecting = amount
|
self._expecting = amount
|
||||||
with self._wake_up:
|
with self._wake_up:
|
||||||
self._wake_up.wait(timeout)
|
if not self._wake_up.wait(timeout):
|
||||||
|
raise dns.exception.Timeout
|
||||||
self._expecting = 0
|
self._expecting = 0
|
||||||
|
|
||||||
def receive(self, timeout=None):
|
def receive(self, timeout=None):
|
||||||
|
@ -114,24 +117,30 @@ class SyncQuicConnection(BaseQuicConnection):
|
||||||
return
|
return
|
||||||
|
|
||||||
def _worker(self):
|
def _worker(self):
|
||||||
sel = _selector_class()
|
try:
|
||||||
sel.register(self._socket, selectors.EVENT_READ, self._read)
|
sel = _selector_class()
|
||||||
sel.register(self._receive_wakeup, selectors.EVENT_READ, self._drain_wakeup)
|
sel.register(self._socket, selectors.EVENT_READ, self._read)
|
||||||
while not self._done:
|
sel.register(self._receive_wakeup, selectors.EVENT_READ, self._drain_wakeup)
|
||||||
(expiration, interval) = self._get_timer_values(False)
|
while not self._done:
|
||||||
items = sel.select(interval)
|
(expiration, interval) = self._get_timer_values(False)
|
||||||
for (key, _) in items:
|
items = sel.select(interval)
|
||||||
key.data()
|
for key, _ in items:
|
||||||
|
key.data()
|
||||||
|
with self._lock:
|
||||||
|
self._handle_timer(expiration)
|
||||||
|
datagrams = self._connection.datagrams_to_send(time.time())
|
||||||
|
for datagram, _ in datagrams:
|
||||||
|
try:
|
||||||
|
self._socket.send(datagram)
|
||||||
|
except BlockingIOError:
|
||||||
|
# we let QUIC handle any lossage
|
||||||
|
pass
|
||||||
|
self._handle_events()
|
||||||
|
finally:
|
||||||
with self._lock:
|
with self._lock:
|
||||||
self._handle_timer(expiration)
|
self._done = True
|
||||||
datagrams = self._connection.datagrams_to_send(time.time())
|
# Ensure anyone waiting for this gets woken up.
|
||||||
for (datagram, _) in datagrams:
|
self._handshake_complete.set()
|
||||||
try:
|
|
||||||
self._socket.send(datagram)
|
|
||||||
except BlockingIOError:
|
|
||||||
# we let QUIC handle any lossage
|
|
||||||
pass
|
|
||||||
self._handle_events()
|
|
||||||
|
|
||||||
def _handle_events(self):
|
def _handle_events(self):
|
||||||
while True:
|
while True:
|
||||||
|
@ -163,9 +172,12 @@ class SyncQuicConnection(BaseQuicConnection):
|
||||||
self._worker_thread = threading.Thread(target=self._worker)
|
self._worker_thread = threading.Thread(target=self._worker)
|
||||||
self._worker_thread.start()
|
self._worker_thread.start()
|
||||||
|
|
||||||
def make_stream(self):
|
def make_stream(self, timeout=None):
|
||||||
self._handshake_complete.wait()
|
if not self._handshake_complete.wait(timeout):
|
||||||
|
raise dns.exception.Timeout
|
||||||
with self._lock:
|
with self._lock:
|
||||||
|
if self._done:
|
||||||
|
raise UnexpectedEOF
|
||||||
stream_id = self._connection.get_next_available_stream_id(False)
|
stream_id = self._connection.get_next_available_stream_id(False)
|
||||||
stream = SyncQuicStream(self, stream_id)
|
stream = SyncQuicStream(self, stream_id)
|
||||||
self._streams[stream_id] = stream
|
self._streams[stream_id] = stream
|
||||||
|
@ -187,8 +199,8 @@ class SyncQuicConnection(BaseQuicConnection):
|
||||||
|
|
||||||
|
|
||||||
class SyncQuicManager(BaseQuicManager):
|
class SyncQuicManager(BaseQuicManager):
|
||||||
def __init__(self, conf=None, verify_mode=ssl.CERT_REQUIRED):
|
def __init__(self, conf=None, verify_mode=ssl.CERT_REQUIRED, server_name=None):
|
||||||
super().__init__(conf, verify_mode, SyncQuicConnection)
|
super().__init__(conf, verify_mode, SyncQuicConnection, server_name)
|
||||||
self._lock = threading.Lock()
|
self._lock = threading.Lock()
|
||||||
|
|
||||||
def connect(self, address, port=853, source=None, source_port=0):
|
def connect(self, address, port=853, source=None, source_port=0):
|
||||||
|
@ -206,7 +218,7 @@ class SyncQuicManager(BaseQuicManager):
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
# Copy the itertor into a list as exiting things will mutate the connections
|
# Copy the iterator into a list as exiting things will mutate the connections
|
||||||
# table.
|
# table.
|
||||||
connections = list(self._connections.values())
|
connections = list(self._connections.values())
|
||||||
for connection in connections:
|
for connection in connections:
|
||||||
|
|
|
@ -10,13 +10,15 @@ import aioquic.quic.connection # type: ignore
|
||||||
import aioquic.quic.events # type: ignore
|
import aioquic.quic.events # type: ignore
|
||||||
import trio
|
import trio
|
||||||
|
|
||||||
|
import dns.exception
|
||||||
import dns.inet
|
import dns.inet
|
||||||
from dns._asyncbackend import NullContext
|
from dns._asyncbackend import NullContext
|
||||||
from dns.quic._common import (
|
from dns.quic._common import (
|
||||||
BaseQuicStream,
|
QUIC_MAX_DATAGRAM,
|
||||||
AsyncQuicConnection,
|
AsyncQuicConnection,
|
||||||
AsyncQuicManager,
|
AsyncQuicManager,
|
||||||
QUIC_MAX_DATAGRAM,
|
BaseQuicStream,
|
||||||
|
UnexpectedEOF,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -44,6 +46,7 @@ class TrioQuicStream(BaseQuicStream):
|
||||||
(size,) = struct.unpack("!H", self._buffer.get(2))
|
(size,) = struct.unpack("!H", self._buffer.get(2))
|
||||||
await self.wait_for(size)
|
await self.wait_for(size)
|
||||||
return self._buffer.get(size)
|
return self._buffer.get(size)
|
||||||
|
raise dns.exception.Timeout
|
||||||
|
|
||||||
async def send(self, datagram, is_end=False):
|
async def send(self, datagram, is_end=False):
|
||||||
data = self._encapsulate(datagram)
|
data = self._encapsulate(datagram)
|
||||||
|
@ -80,20 +83,26 @@ class TrioQuicConnection(AsyncQuicConnection):
|
||||||
self._worker_scope = None
|
self._worker_scope = None
|
||||||
|
|
||||||
async def _worker(self):
|
async def _worker(self):
|
||||||
await self._socket.connect(self._peer)
|
try:
|
||||||
while not self._done:
|
await self._socket.connect(self._peer)
|
||||||
(expiration, interval) = self._get_timer_values(False)
|
while not self._done:
|
||||||
with trio.CancelScope(
|
(expiration, interval) = self._get_timer_values(False)
|
||||||
deadline=trio.current_time() + interval
|
with trio.CancelScope(
|
||||||
) as self._worker_scope:
|
deadline=trio.current_time() + interval
|
||||||
datagram = await self._socket.recv(QUIC_MAX_DATAGRAM)
|
) as self._worker_scope:
|
||||||
self._connection.receive_datagram(datagram, self._peer[0], time.time())
|
datagram = await self._socket.recv(QUIC_MAX_DATAGRAM)
|
||||||
self._worker_scope = None
|
self._connection.receive_datagram(
|
||||||
self._handle_timer(expiration)
|
datagram, self._peer[0], time.time()
|
||||||
datagrams = self._connection.datagrams_to_send(time.time())
|
)
|
||||||
for (datagram, _) in datagrams:
|
self._worker_scope = None
|
||||||
await self._socket.send(datagram)
|
self._handle_timer(expiration)
|
||||||
await self._handle_events()
|
datagrams = self._connection.datagrams_to_send(time.time())
|
||||||
|
for datagram, _ in datagrams:
|
||||||
|
await self._socket.send(datagram)
|
||||||
|
await self._handle_events()
|
||||||
|
finally:
|
||||||
|
self._done = True
|
||||||
|
self._handshake_complete.set()
|
||||||
|
|
||||||
async def _handle_events(self):
|
async def _handle_events(self):
|
||||||
count = 0
|
count = 0
|
||||||
|
@ -130,12 +139,20 @@ class TrioQuicConnection(AsyncQuicConnection):
|
||||||
nursery.start_soon(self._worker)
|
nursery.start_soon(self._worker)
|
||||||
self._run_done.set()
|
self._run_done.set()
|
||||||
|
|
||||||
async def make_stream(self):
|
async def make_stream(self, timeout=None):
|
||||||
await self._handshake_complete.wait()
|
if timeout is None:
|
||||||
stream_id = self._connection.get_next_available_stream_id(False)
|
context = NullContext(None)
|
||||||
stream = TrioQuicStream(self, stream_id)
|
else:
|
||||||
self._streams[stream_id] = stream
|
context = trio.move_on_after(timeout)
|
||||||
return stream
|
with context:
|
||||||
|
await self._handshake_complete.wait()
|
||||||
|
if self._done:
|
||||||
|
raise UnexpectedEOF
|
||||||
|
stream_id = self._connection.get_next_available_stream_id(False)
|
||||||
|
stream = TrioQuicStream(self, stream_id)
|
||||||
|
self._streams[stream_id] = stream
|
||||||
|
return stream
|
||||||
|
raise dns.exception.Timeout
|
||||||
|
|
||||||
async def close(self):
|
async def close(self):
|
||||||
if not self._closed:
|
if not self._closed:
|
||||||
|
@ -148,8 +165,10 @@ class TrioQuicConnection(AsyncQuicConnection):
|
||||||
|
|
||||||
|
|
||||||
class TrioQuicManager(AsyncQuicManager):
|
class TrioQuicManager(AsyncQuicManager):
|
||||||
def __init__(self, nursery, conf=None, verify_mode=ssl.CERT_REQUIRED):
|
def __init__(
|
||||||
super().__init__(conf, verify_mode, TrioQuicConnection)
|
self, nursery, conf=None, verify_mode=ssl.CERT_REQUIRED, server_name=None
|
||||||
|
):
|
||||||
|
super().__init__(conf, verify_mode, TrioQuicConnection, server_name)
|
||||||
self._nursery = nursery
|
self._nursery = nursery
|
||||||
|
|
||||||
def connect(self, address, port=853, source=None, source_port=0):
|
def connect(self, address, port=853, source=None, source_port=0):
|
||||||
|
@ -162,7 +181,7 @@ class TrioQuicManager(AsyncQuicManager):
|
||||||
return self
|
return self
|
||||||
|
|
||||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||||
# Copy the itertor into a list as exiting things will mutate the connections
|
# Copy the iterator into a list as exiting things will mutate the connections
|
||||||
# table.
|
# table.
|
||||||
connections = list(self._connections.values())
|
connections = list(self._connections.values())
|
||||||
for connection in connections:
|
for connection in connections:
|
||||||
|
|
|
@ -17,17 +17,15 @@
|
||||||
|
|
||||||
"""DNS rdata."""
|
"""DNS rdata."""
|
||||||
|
|
||||||
from typing import Any, Dict, Optional, Tuple, Union
|
|
||||||
|
|
||||||
from importlib import import_module
|
|
||||||
import base64
|
import base64
|
||||||
import binascii
|
import binascii
|
||||||
import io
|
|
||||||
import inspect
|
import inspect
|
||||||
|
import io
|
||||||
import itertools
|
import itertools
|
||||||
import random
|
import random
|
||||||
|
from importlib import import_module
|
||||||
|
from typing import Any, Dict, Optional, Tuple, Union
|
||||||
|
|
||||||
import dns.wire
|
|
||||||
import dns.exception
|
import dns.exception
|
||||||
import dns.immutable
|
import dns.immutable
|
||||||
import dns.ipv4
|
import dns.ipv4
|
||||||
|
@ -37,6 +35,7 @@ import dns.rdataclass
|
||||||
import dns.rdatatype
|
import dns.rdatatype
|
||||||
import dns.tokenizer
|
import dns.tokenizer
|
||||||
import dns.ttl
|
import dns.ttl
|
||||||
|
import dns.wire
|
||||||
|
|
||||||
_chunksize = 32
|
_chunksize = 32
|
||||||
|
|
||||||
|
@ -358,7 +357,6 @@ class Rdata:
|
||||||
or self.rdclass != other.rdclass
|
or self.rdclass != other.rdclass
|
||||||
or self.rdtype != other.rdtype
|
or self.rdtype != other.rdtype
|
||||||
):
|
):
|
||||||
|
|
||||||
return NotImplemented
|
return NotImplemented
|
||||||
return self._cmp(other) < 0
|
return self._cmp(other) < 0
|
||||||
|
|
||||||
|
@ -881,16 +879,11 @@ def register_type(
|
||||||
it applies to all classes.
|
it applies to all classes.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
the_rdtype = dns.rdatatype.RdataType.make(rdtype)
|
rdtype = dns.rdatatype.RdataType.make(rdtype)
|
||||||
existing_cls = get_rdata_class(rdclass, the_rdtype)
|
existing_cls = get_rdata_class(rdclass, rdtype)
|
||||||
if existing_cls != GenericRdata or dns.rdatatype.is_metatype(the_rdtype):
|
if existing_cls != GenericRdata or dns.rdatatype.is_metatype(rdtype):
|
||||||
raise RdatatypeExists(rdclass=rdclass, rdtype=the_rdtype)
|
raise RdatatypeExists(rdclass=rdclass, rdtype=rdtype)
|
||||||
try:
|
_rdata_classes[(rdclass, rdtype)] = getattr(
|
||||||
if dns.rdatatype.RdataType(the_rdtype).name != rdtype_text:
|
|
||||||
raise RdatatypeExists(rdclass=rdclass, rdtype=the_rdtype)
|
|
||||||
except ValueError:
|
|
||||||
pass
|
|
||||||
_rdata_classes[(rdclass, the_rdtype)] = getattr(
|
|
||||||
implementation, rdtype_text.replace("-", "_")
|
implementation, rdtype_text.replace("-", "_")
|
||||||
)
|
)
|
||||||
dns.rdatatype.register_type(the_rdtype, rdtype_text, is_singleton)
|
dns.rdatatype.register_type(rdtype, rdtype_text, is_singleton)
|
||||||
|
|
|
@ -17,18 +17,17 @@
|
||||||
|
|
||||||
"""DNS rdatasets (an rdataset is a set of rdatas of a given type and class)"""
|
"""DNS rdatasets (an rdataset is a set of rdatas of a given type and class)"""
|
||||||
|
|
||||||
from typing import Any, cast, Collection, Dict, List, Optional, Union
|
|
||||||
|
|
||||||
import io
|
import io
|
||||||
import random
|
import random
|
||||||
import struct
|
import struct
|
||||||
|
from typing import Any, Collection, Dict, List, Optional, Union, cast
|
||||||
|
|
||||||
import dns.exception
|
import dns.exception
|
||||||
import dns.immutable
|
import dns.immutable
|
||||||
import dns.name
|
import dns.name
|
||||||
import dns.rdatatype
|
|
||||||
import dns.rdataclass
|
|
||||||
import dns.rdata
|
import dns.rdata
|
||||||
|
import dns.rdataclass
|
||||||
|
import dns.rdatatype
|
||||||
import dns.set
|
import dns.set
|
||||||
import dns.ttl
|
import dns.ttl
|
||||||
|
|
||||||
|
@ -471,9 +470,9 @@ def from_text_list(
|
||||||
Returns a ``dns.rdataset.Rdataset`` object.
|
Returns a ``dns.rdataset.Rdataset`` object.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
the_rdclass = dns.rdataclass.RdataClass.make(rdclass)
|
rdclass = dns.rdataclass.RdataClass.make(rdclass)
|
||||||
the_rdtype = dns.rdatatype.RdataType.make(rdtype)
|
rdtype = dns.rdatatype.RdataType.make(rdtype)
|
||||||
r = Rdataset(the_rdclass, the_rdtype)
|
r = Rdataset(rdclass, rdtype)
|
||||||
r.update_ttl(ttl)
|
r.update_ttl(ttl)
|
||||||
for t in text_rdatas:
|
for t in text_rdatas:
|
||||||
rd = dns.rdata.from_text(
|
rd = dns.rdata.from_text(
|
||||||
|
|
|
@ -15,8 +15,8 @@
|
||||||
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
||||||
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||||
|
|
||||||
import dns.rdtypes.mxbase
|
|
||||||
import dns.immutable
|
import dns.immutable
|
||||||
|
import dns.rdtypes.mxbase
|
||||||
|
|
||||||
|
|
||||||
@dns.immutable.immutable
|
@dns.immutable.immutable
|
||||||
|
|
|
@ -15,8 +15,8 @@
|
||||||
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
||||||
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||||
|
|
||||||
import dns.rdtypes.txtbase
|
|
||||||
import dns.immutable
|
import dns.immutable
|
||||||
|
import dns.rdtypes.txtbase
|
||||||
|
|
||||||
|
|
||||||
@dns.immutable.immutable
|
@dns.immutable.immutable
|
||||||
|
|
|
@ -15,15 +15,15 @@
|
||||||
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
||||||
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||||
|
|
||||||
import dns.rdtypes.dnskeybase # lgtm[py/import-and-import-from]
|
|
||||||
import dns.immutable
|
import dns.immutable
|
||||||
|
import dns.rdtypes.dnskeybase # lgtm[py/import-and-import-from]
|
||||||
|
|
||||||
# pylint: disable=unused-import
|
# pylint: disable=unused-import
|
||||||
from dns.rdtypes.dnskeybase import (
|
from dns.rdtypes.dnskeybase import ( # noqa: F401 lgtm[py/unused-import]
|
||||||
SEP,
|
|
||||||
REVOKE,
|
REVOKE,
|
||||||
|
SEP,
|
||||||
ZONE,
|
ZONE,
|
||||||
) # noqa: F401 lgtm[py/unused-import]
|
)
|
||||||
|
|
||||||
# pylint: enable=unused-import
|
# pylint: enable=unused-import
|
||||||
|
|
||||||
|
|
|
@ -15,8 +15,8 @@
|
||||||
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
||||||
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||||
|
|
||||||
import dns.rdtypes.dsbase
|
|
||||||
import dns.immutable
|
import dns.immutable
|
||||||
|
import dns.rdtypes.dsbase
|
||||||
|
|
||||||
|
|
||||||
@dns.immutable.immutable
|
@dns.immutable.immutable
|
||||||
|
|
|
@ -15,12 +15,12 @@
|
||||||
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
||||||
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||||
|
|
||||||
import struct
|
|
||||||
import base64
|
import base64
|
||||||
|
import struct
|
||||||
|
|
||||||
|
import dns.dnssectypes
|
||||||
import dns.exception
|
import dns.exception
|
||||||
import dns.immutable
|
import dns.immutable
|
||||||
import dns.dnssectypes
|
|
||||||
import dns.rdata
|
import dns.rdata
|
||||||
import dns.tokenizer
|
import dns.tokenizer
|
||||||
|
|
||||||
|
|
|
@ -15,8 +15,8 @@
|
||||||
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
||||||
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||||
|
|
||||||
import dns.rdtypes.nsbase
|
|
||||||
import dns.immutable
|
import dns.immutable
|
||||||
|
import dns.rdtypes.nsbase
|
||||||
|
|
||||||
|
|
||||||
@dns.immutable.immutable
|
@dns.immutable.immutable
|
||||||
|
|
|
@ -19,9 +19,9 @@ import struct
|
||||||
|
|
||||||
import dns.exception
|
import dns.exception
|
||||||
import dns.immutable
|
import dns.immutable
|
||||||
|
import dns.name
|
||||||
import dns.rdata
|
import dns.rdata
|
||||||
import dns.rdatatype
|
import dns.rdatatype
|
||||||
import dns.name
|
|
||||||
import dns.rdtypes.util
|
import dns.rdtypes.util
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -15,8 +15,8 @@
|
||||||
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
||||||
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||||
|
|
||||||
import dns.rdtypes.dsbase
|
|
||||||
import dns.immutable
|
import dns.immutable
|
||||||
|
import dns.rdtypes.dsbase
|
||||||
|
|
||||||
|
|
||||||
@dns.immutable.immutable
|
@dns.immutable.immutable
|
||||||
|
|
|
@ -15,8 +15,8 @@
|
||||||
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
||||||
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||||
|
|
||||||
import dns.rdtypes.nsbase
|
|
||||||
import dns.immutable
|
import dns.immutable
|
||||||
|
import dns.rdtypes.nsbase
|
||||||
|
|
||||||
|
|
||||||
@dns.immutable.immutable
|
@dns.immutable.immutable
|
||||||
|
|
|
@ -15,15 +15,15 @@
|
||||||
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
||||||
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||||
|
|
||||||
import dns.rdtypes.dnskeybase # lgtm[py/import-and-import-from]
|
|
||||||
import dns.immutable
|
import dns.immutable
|
||||||
|
import dns.rdtypes.dnskeybase # lgtm[py/import-and-import-from]
|
||||||
|
|
||||||
# pylint: disable=unused-import
|
# pylint: disable=unused-import
|
||||||
from dns.rdtypes.dnskeybase import (
|
from dns.rdtypes.dnskeybase import ( # noqa: F401 lgtm[py/unused-import]
|
||||||
SEP,
|
|
||||||
REVOKE,
|
REVOKE,
|
||||||
|
SEP,
|
||||||
ZONE,
|
ZONE,
|
||||||
) # noqa: F401 lgtm[py/unused-import]
|
)
|
||||||
|
|
||||||
# pylint: enable=unused-import
|
# pylint: enable=unused-import
|
||||||
|
|
||||||
|
|
|
@ -15,8 +15,8 @@
|
||||||
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
||||||
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||||
|
|
||||||
import dns.rdtypes.dsbase
|
|
||||||
import dns.immutable
|
import dns.immutable
|
||||||
|
import dns.rdtypes.dsbase
|
||||||
|
|
||||||
|
|
||||||
@dns.immutable.immutable
|
@dns.immutable.immutable
|
||||||
|
|
|
@ -16,8 +16,8 @@
|
||||||
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
||||||
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||||
|
|
||||||
import dns.rdtypes.euibase
|
|
||||||
import dns.immutable
|
import dns.immutable
|
||||||
|
import dns.rdtypes.euibase
|
||||||
|
|
||||||
|
|
||||||
@dns.immutable.immutable
|
@dns.immutable.immutable
|
||||||
|
|
|
@ -16,8 +16,8 @@
|
||||||
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
||||||
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||||
|
|
||||||
import dns.rdtypes.euibase
|
|
||||||
import dns.immutable
|
import dns.immutable
|
||||||
|
import dns.rdtypes.euibase
|
||||||
|
|
||||||
|
|
||||||
@dns.immutable.immutable
|
@dns.immutable.immutable
|
||||||
|
|
|
@ -15,9 +15,9 @@
|
||||||
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
||||||
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||||
|
|
||||||
import struct
|
|
||||||
import base64
|
import base64
|
||||||
import binascii
|
import binascii
|
||||||
|
import struct
|
||||||
|
|
||||||
import dns.exception
|
import dns.exception
|
||||||
import dns.immutable
|
import dns.immutable
|
||||||
|
|
|
@ -21,7 +21,6 @@ import dns.exception
|
||||||
import dns.immutable
|
import dns.immutable
|
||||||
import dns.rdata
|
import dns.rdata
|
||||||
|
|
||||||
|
|
||||||
_pows = tuple(10**i for i in range(0, 11))
|
_pows = tuple(10**i for i in range(0, 11))
|
||||||
|
|
||||||
# default values are in centimeters
|
# default values are in centimeters
|
||||||
|
@ -40,7 +39,7 @@ def _exponent_of(what, desc):
|
||||||
if what == 0:
|
if what == 0:
|
||||||
return 0
|
return 0
|
||||||
exp = None
|
exp = None
|
||||||
for (i, pow) in enumerate(_pows):
|
for i, pow in enumerate(_pows):
|
||||||
if what < pow:
|
if what < pow:
|
||||||
exp = i - 1
|
exp = i - 1
|
||||||
break
|
break
|
||||||
|
|
|
@ -15,8 +15,8 @@
|
||||||
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
||||||
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||||
|
|
||||||
import dns.rdtypes.mxbase
|
|
||||||
import dns.immutable
|
import dns.immutable
|
||||||
|
import dns.rdtypes.mxbase
|
||||||
|
|
||||||
|
|
||||||
@dns.immutable.immutable
|
@dns.immutable.immutable
|
||||||
|
|
|
@ -15,8 +15,8 @@
|
||||||
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
||||||
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||||
|
|
||||||
import dns.rdtypes.txtbase
|
|
||||||
import dns.immutable
|
import dns.immutable
|
||||||
|
import dns.rdtypes.txtbase
|
||||||
|
|
||||||
|
|
||||||
@dns.immutable.immutable
|
@dns.immutable.immutable
|
||||||
|
|
|
@ -15,8 +15,8 @@
|
||||||
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
||||||
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||||
|
|
||||||
import dns.rdtypes.nsbase
|
|
||||||
import dns.immutable
|
import dns.immutable
|
||||||
|
import dns.rdtypes.nsbase
|
||||||
|
|
||||||
|
|
||||||
@dns.immutable.immutable
|
@dns.immutable.immutable
|
||||||
|
|
|
@ -17,9 +17,9 @@
|
||||||
|
|
||||||
import dns.exception
|
import dns.exception
|
||||||
import dns.immutable
|
import dns.immutable
|
||||||
|
import dns.name
|
||||||
import dns.rdata
|
import dns.rdata
|
||||||
import dns.rdatatype
|
import dns.rdatatype
|
||||||
import dns.name
|
|
||||||
import dns.rdtypes.util
|
import dns.rdtypes.util
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -25,7 +25,6 @@ import dns.rdata
|
||||||
import dns.rdatatype
|
import dns.rdatatype
|
||||||
import dns.rdtypes.util
|
import dns.rdtypes.util
|
||||||
|
|
||||||
|
|
||||||
b32_hex_to_normal = bytes.maketrans(
|
b32_hex_to_normal = bytes.maketrans(
|
||||||
b"0123456789ABCDEFGHIJKLMNOPQRSTUV", b"ABCDEFGHIJKLMNOPQRSTUVWXYZ234567"
|
b"0123456789ABCDEFGHIJKLMNOPQRSTUV", b"ABCDEFGHIJKLMNOPQRSTUVWXYZ234567"
|
||||||
)
|
)
|
||||||
|
@ -67,6 +66,7 @@ class NSEC3(dns.rdata.Rdata):
|
||||||
|
|
||||||
def to_text(self, origin=None, relativize=True, **kw):
|
def to_text(self, origin=None, relativize=True, **kw):
|
||||||
next = base64.b32encode(self.next).translate(b32_normal_to_hex).lower().decode()
|
next = base64.b32encode(self.next).translate(b32_normal_to_hex).lower().decode()
|
||||||
|
next = next.rstrip("=")
|
||||||
if self.salt == b"":
|
if self.salt == b"":
|
||||||
salt = "-"
|
salt = "-"
|
||||||
else:
|
else:
|
||||||
|
@ -94,6 +94,10 @@ class NSEC3(dns.rdata.Rdata):
|
||||||
else:
|
else:
|
||||||
salt = binascii.unhexlify(salt.encode("ascii"))
|
salt = binascii.unhexlify(salt.encode("ascii"))
|
||||||
next = tok.get_string().encode("ascii").upper().translate(b32_hex_to_normal)
|
next = tok.get_string().encode("ascii").upper().translate(b32_hex_to_normal)
|
||||||
|
if next.endswith(b"="):
|
||||||
|
raise binascii.Error("Incorrect padding")
|
||||||
|
if len(next) % 8 != 0:
|
||||||
|
next += b"=" * (8 - len(next) % 8)
|
||||||
next = base64.b32decode(next)
|
next = base64.b32decode(next)
|
||||||
bitmap = Bitmap.from_text(tok)
|
bitmap = Bitmap.from_text(tok)
|
||||||
return cls(rdclass, rdtype, algorithm, flags, iterations, salt, next, bitmap)
|
return cls(rdclass, rdtype, algorithm, flags, iterations, salt, next, bitmap)
|
||||||
|
|
|
@ -15,8 +15,8 @@
|
||||||
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
||||||
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||||
|
|
||||||
import struct
|
|
||||||
import binascii
|
import binascii
|
||||||
|
import struct
|
||||||
|
|
||||||
import dns.exception
|
import dns.exception
|
||||||
import dns.immutable
|
import dns.immutable
|
||||||
|
|
|
@ -18,11 +18,10 @@
|
||||||
import struct
|
import struct
|
||||||
|
|
||||||
import dns.edns
|
import dns.edns
|
||||||
import dns.immutable
|
|
||||||
import dns.exception
|
import dns.exception
|
||||||
|
import dns.immutable
|
||||||
import dns.rdata
|
import dns.rdata
|
||||||
|
|
||||||
|
|
||||||
# We don't implement from_text, and that's ok.
|
# We don't implement from_text, and that's ok.
|
||||||
# pylint: disable=abstract-method
|
# pylint: disable=abstract-method
|
||||||
|
|
||||||
|
|
|
@ -15,8 +15,8 @@
|
||||||
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
||||||
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||||
|
|
||||||
import dns.rdtypes.nsbase
|
|
||||||
import dns.immutable
|
import dns.immutable
|
||||||
|
import dns.rdtypes.nsbase
|
||||||
|
|
||||||
|
|
||||||
@dns.immutable.immutable
|
@dns.immutable.immutable
|
||||||
|
|
|
@ -17,8 +17,8 @@
|
||||||
|
|
||||||
import dns.exception
|
import dns.exception
|
||||||
import dns.immutable
|
import dns.immutable
|
||||||
import dns.rdata
|
|
||||||
import dns.name
|
import dns.name
|
||||||
|
import dns.rdata
|
||||||
|
|
||||||
|
|
||||||
@dns.immutable.immutable
|
@dns.immutable.immutable
|
||||||
|
|
|
@ -21,8 +21,8 @@ import struct
|
||||||
import time
|
import time
|
||||||
|
|
||||||
import dns.dnssectypes
|
import dns.dnssectypes
|
||||||
import dns.immutable
|
|
||||||
import dns.exception
|
import dns.exception
|
||||||
|
import dns.immutable
|
||||||
import dns.rdata
|
import dns.rdata
|
||||||
import dns.rdatatype
|
import dns.rdatatype
|
||||||
|
|
||||||
|
|
|
@ -15,8 +15,8 @@
|
||||||
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
||||||
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||||
|
|
||||||
import dns.rdtypes.mxbase
|
|
||||||
import dns.immutable
|
import dns.immutable
|
||||||
|
import dns.rdtypes.mxbase
|
||||||
|
|
||||||
|
|
||||||
@dns.immutable.immutable
|
@dns.immutable.immutable
|
||||||
|
|
|
@ -19,8 +19,8 @@ import struct
|
||||||
|
|
||||||
import dns.exception
|
import dns.exception
|
||||||
import dns.immutable
|
import dns.immutable
|
||||||
import dns.rdata
|
|
||||||
import dns.name
|
import dns.name
|
||||||
|
import dns.rdata
|
||||||
|
|
||||||
|
|
||||||
@dns.immutable.immutable
|
@dns.immutable.immutable
|
||||||
|
|
|
@ -15,8 +15,8 @@
|
||||||
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
||||||
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||||
|
|
||||||
import dns.rdtypes.txtbase
|
|
||||||
import dns.immutable
|
import dns.immutable
|
||||||
|
import dns.rdtypes.txtbase
|
||||||
|
|
||||||
|
|
||||||
@dns.immutable.immutable
|
@dns.immutable.immutable
|
||||||
|
|
|
@ -15,11 +15,11 @@
|
||||||
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
||||||
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||||
|
|
||||||
import struct
|
|
||||||
import binascii
|
import binascii
|
||||||
|
import struct
|
||||||
|
|
||||||
import dns.rdata
|
|
||||||
import dns.immutable
|
import dns.immutable
|
||||||
|
import dns.rdata
|
||||||
import dns.rdatatype
|
import dns.rdatatype
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -18,8 +18,8 @@
|
||||||
import base64
|
import base64
|
||||||
import struct
|
import struct
|
||||||
|
|
||||||
import dns.immutable
|
|
||||||
import dns.exception
|
import dns.exception
|
||||||
|
import dns.immutable
|
||||||
import dns.rdata
|
import dns.rdata
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -15,8 +15,8 @@
|
||||||
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
||||||
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||||
|
|
||||||
import dns.rdtypes.txtbase
|
|
||||||
import dns.immutable
|
import dns.immutable
|
||||||
|
import dns.rdtypes.txtbase
|
||||||
|
|
||||||
|
|
||||||
@dns.immutable.immutable
|
@dns.immutable.immutable
|
||||||
|
|
|
@ -20,9 +20,9 @@ import struct
|
||||||
|
|
||||||
import dns.exception
|
import dns.exception
|
||||||
import dns.immutable
|
import dns.immutable
|
||||||
|
import dns.name
|
||||||
import dns.rdata
|
import dns.rdata
|
||||||
import dns.rdtypes.util
|
import dns.rdtypes.util
|
||||||
import dns.name
|
|
||||||
|
|
||||||
|
|
||||||
@dns.immutable.immutable
|
@dns.immutable.immutable
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
|
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
|
||||||
|
|
||||||
import struct
|
|
||||||
import binascii
|
import binascii
|
||||||
|
import struct
|
||||||
|
|
||||||
import dns.immutable
|
import dns.immutable
|
||||||
import dns.rdata
|
import dns.rdata
|
||||||
|
|
|
@ -17,8 +17,8 @@
|
||||||
|
|
||||||
import struct
|
import struct
|
||||||
|
|
||||||
import dns.rdtypes.mxbase
|
|
||||||
import dns.immutable
|
import dns.immutable
|
||||||
|
import dns.rdtypes.mxbase
|
||||||
|
|
||||||
|
|
||||||
@dns.immutable.immutable
|
@dns.immutable.immutable
|
||||||
|
|
|
@ -124,7 +124,6 @@ class APL(dns.rdata.Rdata):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
|
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
|
||||||
|
|
||||||
items = []
|
items = []
|
||||||
while parser.remaining() > 0:
|
while parser.remaining() > 0:
|
||||||
header = parser.get_struct("!HBB")
|
header = parser.get_struct("!HBB")
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
|
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
|
||||||
|
|
||||||
import dns.rdtypes.svcbbase
|
|
||||||
import dns.immutable
|
import dns.immutable
|
||||||
|
import dns.rdtypes.svcbbase
|
||||||
|
|
||||||
|
|
||||||
@dns.immutable.immutable
|
@dns.immutable.immutable
|
||||||
|
|
|
@ -15,8 +15,8 @@
|
||||||
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
||||||
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||||
|
|
||||||
import struct
|
|
||||||
import base64
|
import base64
|
||||||
|
import struct
|
||||||
|
|
||||||
import dns.exception
|
import dns.exception
|
||||||
import dns.immutable
|
import dns.immutable
|
||||||
|
|
|
@ -15,8 +15,8 @@
|
||||||
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
||||||
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||||
|
|
||||||
import dns.rdtypes.mxbase
|
|
||||||
import dns.immutable
|
import dns.immutable
|
||||||
|
import dns.rdtypes.mxbase
|
||||||
|
|
||||||
|
|
||||||
@dns.immutable.immutable
|
@dns.immutable.immutable
|
||||||
|
|
|
@ -15,8 +15,8 @@
|
||||||
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
||||||
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||||
|
|
||||||
import dns.rdtypes.nsbase
|
|
||||||
import dns.immutable
|
import dns.immutable
|
||||||
|
import dns.rdtypes.nsbase
|
||||||
|
|
||||||
|
|
||||||
@dns.immutable.immutable
|
@dns.immutable.immutable
|
||||||
|
|
|
@ -19,9 +19,9 @@ import struct
|
||||||
|
|
||||||
import dns.exception
|
import dns.exception
|
||||||
import dns.immutable
|
import dns.immutable
|
||||||
|
import dns.name
|
||||||
import dns.rdata
|
import dns.rdata
|
||||||
import dns.rdtypes.util
|
import dns.rdtypes.util
|
||||||
import dns.name
|
|
||||||
|
|
||||||
|
|
||||||
@dns.immutable.immutable
|
@dns.immutable.immutable
|
||||||
|
|
|
@ -19,9 +19,9 @@ import struct
|
||||||
|
|
||||||
import dns.exception
|
import dns.exception
|
||||||
import dns.immutable
|
import dns.immutable
|
||||||
|
import dns.name
|
||||||
import dns.rdata
|
import dns.rdata
|
||||||
import dns.rdtypes.util
|
import dns.rdtypes.util
|
||||||
import dns.name
|
|
||||||
|
|
||||||
|
|
||||||
@dns.immutable.immutable
|
@dns.immutable.immutable
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
|
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
|
||||||
|
|
||||||
import dns.rdtypes.svcbbase
|
|
||||||
import dns.immutable
|
import dns.immutable
|
||||||
|
import dns.rdtypes.svcbbase
|
||||||
|
|
||||||
|
|
||||||
@dns.immutable.immutable
|
@dns.immutable.immutable
|
||||||
|
|
|
@ -18,8 +18,8 @@
|
||||||
import socket
|
import socket
|
||||||
import struct
|
import struct
|
||||||
|
|
||||||
import dns.ipv4
|
|
||||||
import dns.immutable
|
import dns.immutable
|
||||||
|
import dns.ipv4
|
||||||
import dns.rdata
|
import dns.rdata
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -19,9 +19,9 @@ import base64
|
||||||
import enum
|
import enum
|
||||||
import struct
|
import struct
|
||||||
|
|
||||||
|
import dns.dnssectypes
|
||||||
import dns.exception
|
import dns.exception
|
||||||
import dns.immutable
|
import dns.immutable
|
||||||
import dns.dnssectypes
|
|
||||||
import dns.rdata
|
import dns.rdata
|
||||||
|
|
||||||
# wildcard import
|
# wildcard import
|
||||||
|
@ -43,7 +43,7 @@ class DNSKEYBase(dns.rdata.Rdata):
|
||||||
|
|
||||||
def __init__(self, rdclass, rdtype, flags, protocol, algorithm, key):
|
def __init__(self, rdclass, rdtype, flags, protocol, algorithm, key):
|
||||||
super().__init__(rdclass, rdtype)
|
super().__init__(rdclass, rdtype)
|
||||||
self.flags = self._as_uint16(flags)
|
self.flags = Flag(self._as_uint16(flags))
|
||||||
self.protocol = self._as_uint8(protocol)
|
self.protocol = self._as_uint8(protocol)
|
||||||
self.algorithm = dns.dnssectypes.Algorithm.make(algorithm)
|
self.algorithm = dns.dnssectypes.Algorithm.make(algorithm)
|
||||||
self.key = self._as_bytes(key)
|
self.key = self._as_bytes(key)
|
||||||
|
|
|
@ -15,8 +15,8 @@
|
||||||
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
||||||
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||||
|
|
||||||
import struct
|
|
||||||
import binascii
|
import binascii
|
||||||
|
import struct
|
||||||
|
|
||||||
import dns.dnssectypes
|
import dns.dnssectypes
|
||||||
import dns.immutable
|
import dns.immutable
|
||||||
|
@ -44,7 +44,7 @@ class DSBase(dns.rdata.Rdata):
|
||||||
super().__init__(rdclass, rdtype)
|
super().__init__(rdclass, rdtype)
|
||||||
self.key_tag = self._as_uint16(key_tag)
|
self.key_tag = self._as_uint16(key_tag)
|
||||||
self.algorithm = dns.dnssectypes.Algorithm.make(algorithm)
|
self.algorithm = dns.dnssectypes.Algorithm.make(algorithm)
|
||||||
self.digest_type = self._as_uint8(digest_type)
|
self.digest_type = dns.dnssectypes.DSDigest.make(self._as_uint8(digest_type))
|
||||||
self.digest = self._as_bytes(digest)
|
self.digest = self._as_bytes(digest)
|
||||||
try:
|
try:
|
||||||
if len(self.digest) != self._digest_length_by_type[self.digest_type]:
|
if len(self.digest) != self._digest_length_by_type[self.digest_type]:
|
||||||
|
|
|
@ -16,8 +16,8 @@
|
||||||
|
|
||||||
import binascii
|
import binascii
|
||||||
|
|
||||||
import dns.rdata
|
|
||||||
import dns.immutable
|
import dns.immutable
|
||||||
|
import dns.rdata
|
||||||
|
|
||||||
|
|
||||||
@dns.immutable.immutable
|
@dns.immutable.immutable
|
||||||
|
|
|
@ -21,8 +21,8 @@ import struct
|
||||||
|
|
||||||
import dns.exception
|
import dns.exception
|
||||||
import dns.immutable
|
import dns.immutable
|
||||||
import dns.rdata
|
|
||||||
import dns.name
|
import dns.name
|
||||||
|
import dns.rdata
|
||||||
import dns.rdtypes.util
|
import dns.rdtypes.util
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -19,8 +19,8 @@
|
||||||
|
|
||||||
import dns.exception
|
import dns.exception
|
||||||
import dns.immutable
|
import dns.immutable
|
||||||
import dns.rdata
|
|
||||||
import dns.name
|
import dns.name
|
||||||
|
import dns.rdata
|
||||||
|
|
||||||
|
|
||||||
@dns.immutable.immutable
|
@dns.immutable.immutable
|
||||||
|
|
|
@ -34,6 +34,7 @@ class ParamKey(dns.enum.IntEnum):
|
||||||
IPV4HINT = 4
|
IPV4HINT = 4
|
||||||
ECH = 5
|
ECH = 5
|
||||||
IPV6HINT = 6
|
IPV6HINT = 6
|
||||||
|
DOHPATH = 7
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _maximum(cls):
|
def _maximum(cls):
|
||||||
|
|
|
@ -15,11 +15,11 @@
|
||||||
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
||||||
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||||
|
|
||||||
import struct
|
|
||||||
import binascii
|
import binascii
|
||||||
|
import struct
|
||||||
|
|
||||||
import dns.rdata
|
|
||||||
import dns.immutable
|
import dns.immutable
|
||||||
|
import dns.rdata
|
||||||
import dns.rdatatype
|
import dns.rdatatype
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -17,9 +17,8 @@
|
||||||
|
|
||||||
"""TXT-like base class."""
|
"""TXT-like base class."""
|
||||||
|
|
||||||
from typing import Any, Dict, Iterable, Optional, Tuple, Union
|
|
||||||
|
|
||||||
import struct
|
import struct
|
||||||
|
from typing import Any, Dict, Iterable, Optional, Tuple, Union
|
||||||
|
|
||||||
import dns.exception
|
import dns.exception
|
||||||
import dns.immutable
|
import dns.immutable
|
||||||
|
|
|
@ -18,6 +18,7 @@
|
||||||
import collections
|
import collections
|
||||||
import random
|
import random
|
||||||
import struct
|
import struct
|
||||||
|
from typing import Any, List
|
||||||
|
|
||||||
import dns.exception
|
import dns.exception
|
||||||
import dns.ipv4
|
import dns.ipv4
|
||||||
|
@ -119,7 +120,7 @@ class Bitmap:
|
||||||
def __init__(self, windows=None):
|
def __init__(self, windows=None):
|
||||||
last_window = -1
|
last_window = -1
|
||||||
self.windows = windows
|
self.windows = windows
|
||||||
for (window, bitmap) in self.windows:
|
for window, bitmap in self.windows:
|
||||||
if not isinstance(window, int):
|
if not isinstance(window, int):
|
||||||
raise ValueError(f"bad {self.type_name} window type")
|
raise ValueError(f"bad {self.type_name} window type")
|
||||||
if window <= last_window:
|
if window <= last_window:
|
||||||
|
@ -132,11 +133,11 @@ class Bitmap:
|
||||||
if len(bitmap) == 0 or len(bitmap) > 32:
|
if len(bitmap) == 0 or len(bitmap) > 32:
|
||||||
raise ValueError(f"bad {self.type_name} octets")
|
raise ValueError(f"bad {self.type_name} octets")
|
||||||
|
|
||||||
def to_text(self):
|
def to_text(self) -> str:
|
||||||
text = ""
|
text = ""
|
||||||
for (window, bitmap) in self.windows:
|
for window, bitmap in self.windows:
|
||||||
bits = []
|
bits = []
|
||||||
for (i, byte) in enumerate(bitmap):
|
for i, byte in enumerate(bitmap):
|
||||||
for j in range(0, 8):
|
for j in range(0, 8):
|
||||||
if byte & (0x80 >> j):
|
if byte & (0x80 >> j):
|
||||||
rdtype = window * 256 + i * 8 + j
|
rdtype = window * 256 + i * 8 + j
|
||||||
|
@ -145,14 +146,18 @@ class Bitmap:
|
||||||
return text
|
return text
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_text(cls, tok):
|
def from_text(cls, tok: "dns.tokenizer.Tokenizer") -> "Bitmap":
|
||||||
rdtypes = []
|
rdtypes = []
|
||||||
for token in tok.get_remaining():
|
for token in tok.get_remaining():
|
||||||
rdtype = dns.rdatatype.from_text(token.unescape().value)
|
rdtype = dns.rdatatype.from_text(token.unescape().value)
|
||||||
if rdtype == 0:
|
if rdtype == 0:
|
||||||
raise dns.exception.SyntaxError(f"{cls.type_name} with bit 0")
|
raise dns.exception.SyntaxError(f"{cls.type_name} with bit 0")
|
||||||
rdtypes.append(rdtype)
|
rdtypes.append(rdtype)
|
||||||
rdtypes.sort()
|
return cls.from_rdtypes(rdtypes)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_rdtypes(cls, rdtypes: List[dns.rdatatype.RdataType]) -> "Bitmap":
|
||||||
|
rdtypes = sorted(rdtypes)
|
||||||
window = 0
|
window = 0
|
||||||
octets = 0
|
octets = 0
|
||||||
prior_rdtype = 0
|
prior_rdtype = 0
|
||||||
|
@ -177,13 +182,13 @@ class Bitmap:
|
||||||
windows.append((window, bytes(bitmap[0:octets])))
|
windows.append((window, bytes(bitmap[0:octets])))
|
||||||
return cls(windows)
|
return cls(windows)
|
||||||
|
|
||||||
def to_wire(self, file):
|
def to_wire(self, file: Any) -> None:
|
||||||
for (window, bitmap) in self.windows:
|
for window, bitmap in self.windows:
|
||||||
file.write(struct.pack("!BB", window, len(bitmap)))
|
file.write(struct.pack("!BB", window, len(bitmap)))
|
||||||
file.write(bitmap)
|
file.write(bitmap)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_wire_parser(cls, parser):
|
def from_wire_parser(cls, parser: "dns.wire.Parser") -> "Bitmap":
|
||||||
windows = []
|
windows = []
|
||||||
while parser.remaining() > 0:
|
while parser.remaining() > 0:
|
||||||
window = parser.get_uint8()
|
window = parser.get_uint8()
|
||||||
|
@ -226,7 +231,7 @@ def weighted_processing_order(iterable):
|
||||||
total = sum(rdata._processing_weight() or _no_weight for rdata in rdatas)
|
total = sum(rdata._processing_weight() or _no_weight for rdata in rdatas)
|
||||||
while len(rdatas) > 1:
|
while len(rdatas) > 1:
|
||||||
r = random.uniform(0, total)
|
r = random.uniform(0, total)
|
||||||
for (n, rdata) in enumerate(rdatas):
|
for n, rdata in enumerate(rdatas):
|
||||||
weight = rdata._processing_weight() or _no_weight
|
weight = rdata._processing_weight() or _no_weight
|
||||||
if weight > r:
|
if weight > r:
|
||||||
break
|
break
|
||||||
|
|
|
@ -19,14 +19,13 @@
|
||||||
|
|
||||||
import contextlib
|
import contextlib
|
||||||
import io
|
import io
|
||||||
import struct
|
|
||||||
import random
|
import random
|
||||||
|
import struct
|
||||||
import time
|
import time
|
||||||
|
|
||||||
import dns.exception
|
import dns.exception
|
||||||
import dns.tsig
|
import dns.tsig
|
||||||
|
|
||||||
|
|
||||||
QUESTION = 0
|
QUESTION = 0
|
||||||
ANSWER = 1
|
ANSWER = 1
|
||||||
AUTHORITY = 2
|
AUTHORITY = 2
|
||||||
|
|
|
@ -17,29 +17,31 @@
|
||||||
|
|
||||||
"""DNS stub resolver."""
|
"""DNS stub resolver."""
|
||||||
|
|
||||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
|
||||||
|
|
||||||
from urllib.parse import urlparse
|
|
||||||
import contextlib
|
import contextlib
|
||||||
|
import random
|
||||||
import socket
|
import socket
|
||||||
import sys
|
import sys
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
import random
|
|
||||||
import warnings
|
import warnings
|
||||||
|
from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple, Union
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
import dns.exception
|
import dns._ddr
|
||||||
import dns.edns
|
import dns.edns
|
||||||
|
import dns.exception
|
||||||
import dns.flags
|
import dns.flags
|
||||||
import dns.inet
|
import dns.inet
|
||||||
import dns.ipv4
|
import dns.ipv4
|
||||||
import dns.ipv6
|
import dns.ipv6
|
||||||
import dns.message
|
import dns.message
|
||||||
import dns.name
|
import dns.name
|
||||||
|
import dns.nameserver
|
||||||
import dns.query
|
import dns.query
|
||||||
import dns.rcode
|
import dns.rcode
|
||||||
import dns.rdataclass
|
import dns.rdataclass
|
||||||
import dns.rdatatype
|
import dns.rdatatype
|
||||||
|
import dns.rdtypes.svcbbase
|
||||||
import dns.reversename
|
import dns.reversename
|
||||||
import dns.tsig
|
import dns.tsig
|
||||||
|
|
||||||
|
@ -72,7 +74,7 @@ class NXDOMAIN(dns.exception.DNSException):
|
||||||
kwargs = dict(qnames=qnames, responses=responses)
|
kwargs = dict(qnames=qnames, responses=responses)
|
||||||
return kwargs
|
return kwargs
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self) -> str:
|
||||||
if "qnames" not in self.kwargs:
|
if "qnames" not in self.kwargs:
|
||||||
return super().__str__()
|
return super().__str__()
|
||||||
qnames = self.kwargs["qnames"]
|
qnames = self.kwargs["qnames"]
|
||||||
|
@ -140,7 +142,11 @@ class YXDOMAIN(dns.exception.DNSException):
|
||||||
|
|
||||||
|
|
||||||
ErrorTuple = Tuple[
|
ErrorTuple = Tuple[
|
||||||
Optional[str], bool, int, Union[Exception, str], Optional[dns.message.Message]
|
Optional[str],
|
||||||
|
bool,
|
||||||
|
int,
|
||||||
|
Union[Exception, str],
|
||||||
|
Optional[dns.message.Message],
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@ -148,11 +154,7 @@ def _errors_to_text(errors: List[ErrorTuple]) -> List[str]:
|
||||||
"""Turn a resolution errors trace into a list of text."""
|
"""Turn a resolution errors trace into a list of text."""
|
||||||
texts = []
|
texts = []
|
||||||
for err in errors:
|
for err in errors:
|
||||||
texts.append(
|
texts.append("Server {} answered {}".format(err[0], err[3]))
|
||||||
"Server {} {} port {} answered {}".format(
|
|
||||||
err[0], "TCP" if err[1] else "UDP", err[2], err[3]
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return texts
|
return texts
|
||||||
|
|
||||||
|
|
||||||
|
@ -184,7 +186,7 @@ Timeout = LifetimeTimeout
|
||||||
class NoAnswer(dns.exception.DNSException):
|
class NoAnswer(dns.exception.DNSException):
|
||||||
"""The DNS response does not contain an answer to the question."""
|
"""The DNS response does not contain an answer to the question."""
|
||||||
|
|
||||||
fmt = "The DNS response does not contain an answer " + "to the question: {query}"
|
fmt = "The DNS response does not contain an answer to the question: {query}"
|
||||||
supp_kwargs = {"response"}
|
supp_kwargs = {"response"}
|
||||||
|
|
||||||
# We do this as otherwise mypy complains about unexpected keyword argument
|
# We do this as otherwise mypy complains about unexpected keyword argument
|
||||||
|
@ -264,7 +266,7 @@ class Answer:
|
||||||
response: dns.message.QueryMessage,
|
response: dns.message.QueryMessage,
|
||||||
nameserver: Optional[str] = None,
|
nameserver: Optional[str] = None,
|
||||||
port: Optional[int] = None,
|
port: Optional[int] = None,
|
||||||
):
|
) -> None:
|
||||||
self.qname = qname
|
self.qname = qname
|
||||||
self.rdtype = rdtype
|
self.rdtype = rdtype
|
||||||
self.rdclass = rdclass
|
self.rdclass = rdclass
|
||||||
|
@ -292,7 +294,7 @@ class Answer:
|
||||||
else:
|
else:
|
||||||
raise AttributeError(attr)
|
raise AttributeError(attr)
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self) -> int:
|
||||||
return self.rrset and len(self.rrset) or 0
|
return self.rrset and len(self.rrset) or 0
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
|
@ -309,14 +311,67 @@ class Answer:
|
||||||
del self.rrset[i]
|
del self.rrset[i]
|
||||||
|
|
||||||
|
|
||||||
|
class Answers(dict):
|
||||||
|
"""A dict of DNS stub resolver answers, indexed by type."""
|
||||||
|
|
||||||
|
|
||||||
|
class HostAnswers(Answers):
|
||||||
|
"""A dict of DNS stub resolver answers to a host name lookup, indexed by
|
||||||
|
type.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def make(
|
||||||
|
cls,
|
||||||
|
v6: Optional[Answer] = None,
|
||||||
|
v4: Optional[Answer] = None,
|
||||||
|
add_empty: bool = True,
|
||||||
|
) -> "HostAnswers":
|
||||||
|
answers = HostAnswers()
|
||||||
|
if v6 is not None and (add_empty or v6.rrset):
|
||||||
|
answers[dns.rdatatype.AAAA] = v6
|
||||||
|
if v4 is not None and (add_empty or v4.rrset):
|
||||||
|
answers[dns.rdatatype.A] = v4
|
||||||
|
return answers
|
||||||
|
|
||||||
|
# Returns pairs of (address, family) from this result, potentiallys
|
||||||
|
# filtering by address family.
|
||||||
|
def addresses_and_families(
|
||||||
|
self, family: int = socket.AF_UNSPEC
|
||||||
|
) -> Iterator[Tuple[str, int]]:
|
||||||
|
if family == socket.AF_UNSPEC:
|
||||||
|
yield from self.addresses_and_families(socket.AF_INET6)
|
||||||
|
yield from self.addresses_and_families(socket.AF_INET)
|
||||||
|
return
|
||||||
|
elif family == socket.AF_INET6:
|
||||||
|
answer = self.get(dns.rdatatype.AAAA)
|
||||||
|
elif family == socket.AF_INET:
|
||||||
|
answer = self.get(dns.rdatatype.A)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f"unknown address family {family}")
|
||||||
|
if answer:
|
||||||
|
for rdata in answer:
|
||||||
|
yield (rdata.address, family)
|
||||||
|
|
||||||
|
# Returns addresses from this result, potentially filtering by
|
||||||
|
# address family.
|
||||||
|
def addresses(self, family: int = socket.AF_UNSPEC) -> Iterator[str]:
|
||||||
|
return (pair[0] for pair in self.addresses_and_families(family))
|
||||||
|
|
||||||
|
# Returns the canonical name from this result.
|
||||||
|
def canonical_name(self) -> dns.name.Name:
|
||||||
|
answer = self.get(dns.rdatatype.AAAA, self.get(dns.rdatatype.A))
|
||||||
|
return answer.canonical_name
|
||||||
|
|
||||||
|
|
||||||
class CacheStatistics:
|
class CacheStatistics:
|
||||||
"""Cache Statistics"""
|
"""Cache Statistics"""
|
||||||
|
|
||||||
def __init__(self, hits=0, misses=0):
|
def __init__(self, hits: int = 0, misses: int = 0) -> None:
|
||||||
self.hits = hits
|
self.hits = hits
|
||||||
self.misses = misses
|
self.misses = misses
|
||||||
|
|
||||||
def reset(self):
|
def reset(self) -> None:
|
||||||
self.hits = 0
|
self.hits = 0
|
||||||
self.misses = 0
|
self.misses = 0
|
||||||
|
|
||||||
|
@ -325,7 +380,7 @@ class CacheStatistics:
|
||||||
|
|
||||||
|
|
||||||
class CacheBase:
|
class CacheBase:
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
self.lock = threading.Lock()
|
self.lock = threading.Lock()
|
||||||
self.statistics = CacheStatistics()
|
self.statistics = CacheStatistics()
|
||||||
|
|
||||||
|
@ -361,7 +416,7 @@ CacheKey = Tuple[dns.name.Name, dns.rdatatype.RdataType, dns.rdataclass.RdataCla
|
||||||
class Cache(CacheBase):
|
class Cache(CacheBase):
|
||||||
"""Simple thread-safe DNS answer cache."""
|
"""Simple thread-safe DNS answer cache."""
|
||||||
|
|
||||||
def __init__(self, cleaning_interval: float = 300.0):
|
def __init__(self, cleaning_interval: float = 300.0) -> None:
|
||||||
"""*cleaning_interval*, a ``float`` is the number of seconds between
|
"""*cleaning_interval*, a ``float`` is the number of seconds between
|
||||||
periodic cleanings.
|
periodic cleanings.
|
||||||
"""
|
"""
|
||||||
|
@ -377,7 +432,7 @@ class Cache(CacheBase):
|
||||||
now = time.time()
|
now = time.time()
|
||||||
if self.next_cleaning <= now:
|
if self.next_cleaning <= now:
|
||||||
keys_to_delete = []
|
keys_to_delete = []
|
||||||
for (k, v) in self.data.items():
|
for k, v in self.data.items():
|
||||||
if v.expiration <= now:
|
if v.expiration <= now:
|
||||||
keys_to_delete.append(k)
|
keys_to_delete.append(k)
|
||||||
for k in keys_to_delete:
|
for k in keys_to_delete:
|
||||||
|
@ -447,13 +502,13 @@ class LRUCacheNode:
|
||||||
self.prev = self
|
self.prev = self
|
||||||
self.next = self
|
self.next = self
|
||||||
|
|
||||||
def link_after(self, node):
|
def link_after(self, node: "LRUCacheNode") -> None:
|
||||||
self.prev = node
|
self.prev = node
|
||||||
self.next = node.next
|
self.next = node.next
|
||||||
node.next.prev = self
|
node.next.prev = self
|
||||||
node.next = self
|
node.next = self
|
||||||
|
|
||||||
def unlink(self):
|
def unlink(self) -> None:
|
||||||
self.next.prev = self.prev
|
self.next.prev = self.prev
|
||||||
self.prev.next = self.next
|
self.prev.next = self.next
|
||||||
|
|
||||||
|
@ -468,7 +523,7 @@ class LRUCache(CacheBase):
|
||||||
for a new one.
|
for a new one.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, max_size: int = 100000):
|
def __init__(self, max_size: int = 100000) -> None:
|
||||||
"""*max_size*, an ``int``, is the maximum number of nodes to cache;
|
"""*max_size*, an ``int``, is the maximum number of nodes to cache;
|
||||||
it must be greater than 0.
|
it must be greater than 0.
|
||||||
"""
|
"""
|
||||||
|
@ -590,30 +645,29 @@ class _Resolution:
|
||||||
tcp: bool,
|
tcp: bool,
|
||||||
raise_on_no_answer: bool,
|
raise_on_no_answer: bool,
|
||||||
search: Optional[bool],
|
search: Optional[bool],
|
||||||
):
|
) -> None:
|
||||||
if isinstance(qname, str):
|
if isinstance(qname, str):
|
||||||
qname = dns.name.from_text(qname, None)
|
qname = dns.name.from_text(qname, None)
|
||||||
the_rdtype = dns.rdatatype.RdataType.make(rdtype)
|
rdtype = dns.rdatatype.RdataType.make(rdtype)
|
||||||
if dns.rdatatype.is_metatype(the_rdtype):
|
if dns.rdatatype.is_metatype(rdtype):
|
||||||
raise NoMetaqueries
|
raise NoMetaqueries
|
||||||
the_rdclass = dns.rdataclass.RdataClass.make(rdclass)
|
rdclass = dns.rdataclass.RdataClass.make(rdclass)
|
||||||
if dns.rdataclass.is_metaclass(the_rdclass):
|
if dns.rdataclass.is_metaclass(rdclass):
|
||||||
raise NoMetaqueries
|
raise NoMetaqueries
|
||||||
self.resolver = resolver
|
self.resolver = resolver
|
||||||
self.qnames_to_try = resolver._get_qnames_to_try(qname, search)
|
self.qnames_to_try = resolver._get_qnames_to_try(qname, search)
|
||||||
self.qnames = self.qnames_to_try[:]
|
self.qnames = self.qnames_to_try[:]
|
||||||
self.rdtype = the_rdtype
|
self.rdtype = rdtype
|
||||||
self.rdclass = the_rdclass
|
self.rdclass = rdclass
|
||||||
self.tcp = tcp
|
self.tcp = tcp
|
||||||
self.raise_on_no_answer = raise_on_no_answer
|
self.raise_on_no_answer = raise_on_no_answer
|
||||||
self.nxdomain_responses: Dict[dns.name.Name, dns.message.QueryMessage] = {}
|
self.nxdomain_responses: Dict[dns.name.Name, dns.message.QueryMessage] = {}
|
||||||
# Initialize other things to help analysis tools
|
# Initialize other things to help analysis tools
|
||||||
self.qname = dns.name.empty
|
self.qname = dns.name.empty
|
||||||
self.nameservers: List[str] = []
|
self.nameservers: List[dns.nameserver.Nameserver] = []
|
||||||
self.current_nameservers: List[str] = []
|
self.current_nameservers: List[dns.nameserver.Nameserver] = []
|
||||||
self.errors: List[ErrorTuple] = []
|
self.errors: List[ErrorTuple] = []
|
||||||
self.nameserver: Optional[str] = None
|
self.nameserver: Optional[dns.nameserver.Nameserver] = None
|
||||||
self.port = 0
|
|
||||||
self.tcp_attempt = False
|
self.tcp_attempt = False
|
||||||
self.retry_with_tcp = False
|
self.retry_with_tcp = False
|
||||||
self.request: Optional[dns.message.QueryMessage] = None
|
self.request: Optional[dns.message.QueryMessage] = None
|
||||||
|
@ -670,7 +724,11 @@ class _Resolution:
|
||||||
if self.resolver.flags is not None:
|
if self.resolver.flags is not None:
|
||||||
request.flags = self.resolver.flags
|
request.flags = self.resolver.flags
|
||||||
|
|
||||||
self.nameservers = self.resolver.nameservers[:]
|
self.nameservers = self.resolver._enrich_nameservers(
|
||||||
|
self.resolver._nameservers,
|
||||||
|
self.resolver.nameserver_ports,
|
||||||
|
self.resolver.port,
|
||||||
|
)
|
||||||
if self.resolver.rotate:
|
if self.resolver.rotate:
|
||||||
random.shuffle(self.nameservers)
|
random.shuffle(self.nameservers)
|
||||||
self.current_nameservers = self.nameservers[:]
|
self.current_nameservers = self.nameservers[:]
|
||||||
|
@ -690,12 +748,13 @@ class _Resolution:
|
||||||
#
|
#
|
||||||
raise NXDOMAIN(qnames=self.qnames_to_try, responses=self.nxdomain_responses)
|
raise NXDOMAIN(qnames=self.qnames_to_try, responses=self.nxdomain_responses)
|
||||||
|
|
||||||
def next_nameserver(self) -> Tuple[str, int, bool, float]:
|
def next_nameserver(self) -> Tuple[dns.nameserver.Nameserver, bool, float]:
|
||||||
if self.retry_with_tcp:
|
if self.retry_with_tcp:
|
||||||
assert self.nameserver is not None
|
assert self.nameserver is not None
|
||||||
|
assert not self.nameserver.is_always_max_size()
|
||||||
self.tcp_attempt = True
|
self.tcp_attempt = True
|
||||||
self.retry_with_tcp = False
|
self.retry_with_tcp = False
|
||||||
return (self.nameserver, self.port, True, 0)
|
return (self.nameserver, True, 0)
|
||||||
|
|
||||||
backoff = 0.0
|
backoff = 0.0
|
||||||
if not self.current_nameservers:
|
if not self.current_nameservers:
|
||||||
|
@ -707,11 +766,8 @@ class _Resolution:
|
||||||
self.backoff = min(self.backoff * 2, 2)
|
self.backoff = min(self.backoff * 2, 2)
|
||||||
|
|
||||||
self.nameserver = self.current_nameservers.pop(0)
|
self.nameserver = self.current_nameservers.pop(0)
|
||||||
self.port = self.resolver.nameserver_ports.get(
|
self.tcp_attempt = self.tcp or self.nameserver.is_always_max_size()
|
||||||
self.nameserver, self.resolver.port
|
return (self.nameserver, self.tcp_attempt, backoff)
|
||||||
)
|
|
||||||
self.tcp_attempt = self.tcp
|
|
||||||
return (self.nameserver, self.port, self.tcp_attempt, backoff)
|
|
||||||
|
|
||||||
def query_result(
|
def query_result(
|
||||||
self, response: Optional[dns.message.Message], ex: Optional[Exception]
|
self, response: Optional[dns.message.Message], ex: Optional[Exception]
|
||||||
|
@ -724,7 +780,13 @@ class _Resolution:
|
||||||
# Exception during I/O or from_wire()
|
# Exception during I/O or from_wire()
|
||||||
assert response is None
|
assert response is None
|
||||||
self.errors.append(
|
self.errors.append(
|
||||||
(self.nameserver, self.tcp_attempt, self.port, ex, response)
|
(
|
||||||
|
str(self.nameserver),
|
||||||
|
self.tcp_attempt,
|
||||||
|
self.nameserver.answer_port(),
|
||||||
|
ex,
|
||||||
|
response,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
if (
|
if (
|
||||||
isinstance(ex, dns.exception.FormError)
|
isinstance(ex, dns.exception.FormError)
|
||||||
|
@ -752,12 +814,18 @@ class _Resolution:
|
||||||
self.rdtype,
|
self.rdtype,
|
||||||
self.rdclass,
|
self.rdclass,
|
||||||
response,
|
response,
|
||||||
self.nameserver,
|
self.nameserver.answer_nameserver(),
|
||||||
self.port,
|
self.nameserver.answer_port(),
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.errors.append(
|
self.errors.append(
|
||||||
(self.nameserver, self.tcp_attempt, self.port, e, response)
|
(
|
||||||
|
str(self.nameserver),
|
||||||
|
self.tcp_attempt,
|
||||||
|
self.nameserver.answer_port(),
|
||||||
|
e,
|
||||||
|
response,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
# The nameserver is no good, take it out of the mix.
|
# The nameserver is no good, take it out of the mix.
|
||||||
self.nameservers.remove(self.nameserver)
|
self.nameservers.remove(self.nameserver)
|
||||||
|
@ -776,7 +844,13 @@ class _Resolution:
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.errors.append(
|
self.errors.append(
|
||||||
(self.nameserver, self.tcp_attempt, self.port, e, response)
|
(
|
||||||
|
str(self.nameserver),
|
||||||
|
self.tcp_attempt,
|
||||||
|
self.nameserver.answer_port(),
|
||||||
|
e,
|
||||||
|
response,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
# The nameserver is no good, take it out of the mix.
|
# The nameserver is no good, take it out of the mix.
|
||||||
self.nameservers.remove(self.nameserver)
|
self.nameservers.remove(self.nameserver)
|
||||||
|
@ -792,7 +866,13 @@ class _Resolution:
|
||||||
elif rcode == dns.rcode.YXDOMAIN:
|
elif rcode == dns.rcode.YXDOMAIN:
|
||||||
yex = YXDOMAIN()
|
yex = YXDOMAIN()
|
||||||
self.errors.append(
|
self.errors.append(
|
||||||
(self.nameserver, self.tcp_attempt, self.port, yex, response)
|
(
|
||||||
|
str(self.nameserver),
|
||||||
|
self.tcp_attempt,
|
||||||
|
self.nameserver.answer_port(),
|
||||||
|
yex,
|
||||||
|
response,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
raise yex
|
raise yex
|
||||||
else:
|
else:
|
||||||
|
@ -804,9 +884,9 @@ class _Resolution:
|
||||||
self.nameservers.remove(self.nameserver)
|
self.nameservers.remove(self.nameserver)
|
||||||
self.errors.append(
|
self.errors.append(
|
||||||
(
|
(
|
||||||
self.nameserver,
|
str(self.nameserver),
|
||||||
self.tcp_attempt,
|
self.tcp_attempt,
|
||||||
self.port,
|
self.nameserver.answer_port(),
|
||||||
dns.rcode.to_text(rcode),
|
dns.rcode.to_text(rcode),
|
||||||
response,
|
response,
|
||||||
)
|
)
|
||||||
|
@ -840,8 +920,11 @@ class BaseResolver:
|
||||||
retry_servfail: bool
|
retry_servfail: bool
|
||||||
rotate: bool
|
rotate: bool
|
||||||
ndots: Optional[int]
|
ndots: Optional[int]
|
||||||
|
_nameservers: Sequence[Union[str, dns.nameserver.Nameserver]]
|
||||||
|
|
||||||
def __init__(self, filename: str = "/etc/resolv.conf", configure: bool = True):
|
def __init__(
|
||||||
|
self, filename: str = "/etc/resolv.conf", configure: bool = True
|
||||||
|
) -> None:
|
||||||
"""*filename*, a ``str`` or file object, specifying a file
|
"""*filename*, a ``str`` or file object, specifying a file
|
||||||
in standard /etc/resolv.conf format. This parameter is meaningful
|
in standard /etc/resolv.conf format. This parameter is meaningful
|
||||||
only when *configure* is true and the platform is POSIX.
|
only when *configure* is true and the platform is POSIX.
|
||||||
|
@ -860,13 +943,13 @@ class BaseResolver:
|
||||||
elif filename:
|
elif filename:
|
||||||
self.read_resolv_conf(filename)
|
self.read_resolv_conf(filename)
|
||||||
|
|
||||||
def reset(self):
|
def reset(self) -> None:
|
||||||
"""Reset all resolver configuration to the defaults."""
|
"""Reset all resolver configuration to the defaults."""
|
||||||
|
|
||||||
self.domain = dns.name.Name(dns.name.from_text(socket.gethostname())[1:])
|
self.domain = dns.name.Name(dns.name.from_text(socket.gethostname())[1:])
|
||||||
if len(self.domain) == 0:
|
if len(self.domain) == 0:
|
||||||
self.domain = dns.name.root
|
self.domain = dns.name.root
|
||||||
self.nameservers = []
|
self._nameservers = []
|
||||||
self.nameserver_ports = {}
|
self.nameserver_ports = {}
|
||||||
self.port = 53
|
self.port = 53
|
||||||
self.search = []
|
self.search = []
|
||||||
|
@ -903,6 +986,7 @@ class BaseResolver:
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
nameservers = []
|
||||||
if isinstance(f, str):
|
if isinstance(f, str):
|
||||||
try:
|
try:
|
||||||
cm: contextlib.AbstractContextManager = open(f)
|
cm: contextlib.AbstractContextManager = open(f)
|
||||||
|
@ -922,7 +1006,7 @@ class BaseResolver:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if tokens[0] == "nameserver":
|
if tokens[0] == "nameserver":
|
||||||
self.nameservers.append(tokens[1])
|
nameservers.append(tokens[1])
|
||||||
elif tokens[0] == "domain":
|
elif tokens[0] == "domain":
|
||||||
self.domain = dns.name.from_text(tokens[1])
|
self.domain = dns.name.from_text(tokens[1])
|
||||||
# domain and search are exclusive
|
# domain and search are exclusive
|
||||||
|
@ -950,8 +1034,11 @@ class BaseResolver:
|
||||||
self.ndots = int(opt.split(":")[1])
|
self.ndots = int(opt.split(":")[1])
|
||||||
except (ValueError, IndexError):
|
except (ValueError, IndexError):
|
||||||
pass
|
pass
|
||||||
if len(self.nameservers) == 0:
|
if len(nameservers) == 0:
|
||||||
raise NoResolverConfiguration("no nameservers")
|
raise NoResolverConfiguration("no nameservers")
|
||||||
|
# Assigning directly instead of appending means we invoke the
|
||||||
|
# setter logic, with additonal checking and enrichment.
|
||||||
|
self.nameservers = nameservers
|
||||||
|
|
||||||
def read_registry(self) -> None:
|
def read_registry(self) -> None:
|
||||||
"""Extract resolver configuration from the Windows registry."""
|
"""Extract resolver configuration from the Windows registry."""
|
||||||
|
@ -1086,34 +1173,64 @@ class BaseResolver:
|
||||||
|
|
||||||
self.flags = flags
|
self.flags = flags
|
||||||
|
|
||||||
@property
|
@classmethod
|
||||||
def nameservers(self) -> List[str]:
|
def _enrich_nameservers(
|
||||||
return self._nameservers
|
cls,
|
||||||
|
nameservers: Sequence[Union[str, dns.nameserver.Nameserver]],
|
||||||
@nameservers.setter
|
nameserver_ports: Dict[str, int],
|
||||||
def nameservers(self, nameservers: List[str]) -> None:
|
default_port: int,
|
||||||
"""
|
) -> List[dns.nameserver.Nameserver]:
|
||||||
*nameservers*, a ``list`` of nameservers.
|
enriched_nameservers = []
|
||||||
|
|
||||||
Raises ``ValueError`` if *nameservers* is anything other than a
|
|
||||||
``list``.
|
|
||||||
"""
|
|
||||||
if isinstance(nameservers, list):
|
if isinstance(nameservers, list):
|
||||||
for nameserver in nameservers:
|
for nameserver in nameservers:
|
||||||
if not dns.inet.is_address(nameserver):
|
enriched_nameserver: dns.nameserver.Nameserver
|
||||||
|
if isinstance(nameserver, dns.nameserver.Nameserver):
|
||||||
|
enriched_nameserver = nameserver
|
||||||
|
elif dns.inet.is_address(nameserver):
|
||||||
|
port = nameserver_ports.get(nameserver, default_port)
|
||||||
|
enriched_nameserver = dns.nameserver.Do53Nameserver(
|
||||||
|
nameserver, port
|
||||||
|
)
|
||||||
|
else:
|
||||||
try:
|
try:
|
||||||
if urlparse(nameserver).scheme != "https":
|
if urlparse(nameserver).scheme != "https":
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
except Exception:
|
except Exception:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"nameserver {nameserver} is not an "
|
f"nameserver {nameserver} is not a "
|
||||||
"IP address or valid https URL"
|
"dns.nameserver.Nameserver instance or text form, "
|
||||||
|
"IP address, nor a valid https URL"
|
||||||
)
|
)
|
||||||
self._nameservers = nameservers
|
enriched_nameserver = dns.nameserver.DoHNameserver(nameserver)
|
||||||
|
enriched_nameservers.append(enriched_nameserver)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"nameservers must be a list (not a {})".format(type(nameservers))
|
"nameservers must be a list or tuple (not a {})".format(
|
||||||
|
type(nameservers)
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
return enriched_nameservers
|
||||||
|
|
||||||
|
@property
|
||||||
|
def nameservers(
|
||||||
|
self,
|
||||||
|
) -> Sequence[Union[str, dns.nameserver.Nameserver]]:
|
||||||
|
return self._nameservers
|
||||||
|
|
||||||
|
@nameservers.setter
|
||||||
|
def nameservers(
|
||||||
|
self, nameservers: Sequence[Union[str, dns.nameserver.Nameserver]]
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
*nameservers*, a ``list`` of nameservers, where a nameserver is either
|
||||||
|
a string interpretable as a nameserver, or a ``dns.nameserver.Nameserver``
|
||||||
|
instance.
|
||||||
|
|
||||||
|
Raises ``ValueError`` if *nameservers* is not a list of nameservers.
|
||||||
|
"""
|
||||||
|
# We just call _enrich_nameservers() for checking
|
||||||
|
self._enrich_nameservers(nameservers, self.nameserver_ports, self.port)
|
||||||
|
self._nameservers = nameservers
|
||||||
|
|
||||||
|
|
||||||
class Resolver(BaseResolver):
|
class Resolver(BaseResolver):
|
||||||
|
@ -1198,33 +1315,18 @@ class Resolver(BaseResolver):
|
||||||
assert request is not None # needed for type checking
|
assert request is not None # needed for type checking
|
||||||
done = False
|
done = False
|
||||||
while not done:
|
while not done:
|
||||||
(nameserver, port, tcp, backoff) = resolution.next_nameserver()
|
(nameserver, tcp, backoff) = resolution.next_nameserver()
|
||||||
if backoff:
|
if backoff:
|
||||||
time.sleep(backoff)
|
time.sleep(backoff)
|
||||||
timeout = self._compute_timeout(start, lifetime, resolution.errors)
|
timeout = self._compute_timeout(start, lifetime, resolution.errors)
|
||||||
try:
|
try:
|
||||||
if dns.inet.is_address(nameserver):
|
response = nameserver.query(
|
||||||
if tcp:
|
request,
|
||||||
response = dns.query.tcp(
|
timeout=timeout,
|
||||||
request,
|
source=source,
|
||||||
nameserver,
|
source_port=source_port,
|
||||||
timeout=timeout,
|
max_size=tcp,
|
||||||
port=port,
|
)
|
||||||
source=source,
|
|
||||||
source_port=source_port,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
response = dns.query.udp(
|
|
||||||
request,
|
|
||||||
nameserver,
|
|
||||||
timeout=timeout,
|
|
||||||
port=port,
|
|
||||||
source=source,
|
|
||||||
source_port=source_port,
|
|
||||||
raise_on_truncation=True,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
response = dns.query.https(request, nameserver, timeout=timeout)
|
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
(_, done) = resolution.query_result(None, ex)
|
(_, done) = resolution.query_result(None, ex)
|
||||||
continue
|
continue
|
||||||
|
@ -1293,7 +1395,72 @@ class Resolver(BaseResolver):
|
||||||
modified_kwargs["rdclass"] = dns.rdataclass.IN
|
modified_kwargs["rdclass"] = dns.rdataclass.IN
|
||||||
return self.resolve(
|
return self.resolve(
|
||||||
dns.reversename.from_address(ipaddr), *args, **modified_kwargs
|
dns.reversename.from_address(ipaddr), *args, **modified_kwargs
|
||||||
) # type: ignore[arg-type]
|
)
|
||||||
|
|
||||||
|
def resolve_name(
|
||||||
|
self,
|
||||||
|
name: Union[dns.name.Name, str],
|
||||||
|
family: int = socket.AF_UNSPEC,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> HostAnswers:
|
||||||
|
"""Use a resolver to query for address records.
|
||||||
|
|
||||||
|
This utilizes the resolve() method to perform A and/or AAAA lookups on
|
||||||
|
the specified name.
|
||||||
|
|
||||||
|
*qname*, a ``dns.name.Name`` or ``str``, the name to resolve.
|
||||||
|
|
||||||
|
*family*, an ``int``, the address family. If socket.AF_UNSPEC
|
||||||
|
(the default), both A and AAAA records will be retrieved.
|
||||||
|
|
||||||
|
All other arguments that can be passed to the resolve() function
|
||||||
|
except for rdtype and rdclass are also supported by this
|
||||||
|
function.
|
||||||
|
"""
|
||||||
|
# We make a modified kwargs for type checking happiness, as otherwise
|
||||||
|
# we get a legit warning about possibly having rdtype and rdclass
|
||||||
|
# in the kwargs more than once.
|
||||||
|
modified_kwargs: Dict[str, Any] = {}
|
||||||
|
modified_kwargs.update(kwargs)
|
||||||
|
modified_kwargs.pop("rdtype", None)
|
||||||
|
modified_kwargs["rdclass"] = dns.rdataclass.IN
|
||||||
|
|
||||||
|
if family == socket.AF_INET:
|
||||||
|
v4 = self.resolve(name, dns.rdatatype.A, **modified_kwargs)
|
||||||
|
return HostAnswers.make(v4=v4)
|
||||||
|
elif family == socket.AF_INET6:
|
||||||
|
v6 = self.resolve(name, dns.rdatatype.AAAA, **modified_kwargs)
|
||||||
|
return HostAnswers.make(v6=v6)
|
||||||
|
elif family != socket.AF_UNSPEC:
|
||||||
|
raise NotImplementedError(f"unknown address family {family}")
|
||||||
|
|
||||||
|
raise_on_no_answer = modified_kwargs.pop("raise_on_no_answer", True)
|
||||||
|
lifetime = modified_kwargs.pop("lifetime", None)
|
||||||
|
start = time.time()
|
||||||
|
v6 = self.resolve(
|
||||||
|
name,
|
||||||
|
dns.rdatatype.AAAA,
|
||||||
|
raise_on_no_answer=False,
|
||||||
|
lifetime=self._compute_timeout(start, lifetime),
|
||||||
|
**modified_kwargs,
|
||||||
|
)
|
||||||
|
# Note that setting name ensures we query the same name
|
||||||
|
# for A as we did for AAAA. (This is just in case search lists
|
||||||
|
# are active by default in the resolver configuration and
|
||||||
|
# we might be talking to a server that says NXDOMAIN when it
|
||||||
|
# wants to say NOERROR no data.
|
||||||
|
name = v6.qname
|
||||||
|
v4 = self.resolve(
|
||||||
|
name,
|
||||||
|
dns.rdatatype.A,
|
||||||
|
raise_on_no_answer=False,
|
||||||
|
lifetime=self._compute_timeout(start, lifetime),
|
||||||
|
**modified_kwargs,
|
||||||
|
)
|
||||||
|
answers = HostAnswers.make(v6=v6, v4=v4, add_empty=not raise_on_no_answer)
|
||||||
|
if not answers:
|
||||||
|
raise NoAnswer(response=v6.response)
|
||||||
|
return answers
|
||||||
|
|
||||||
# pylint: disable=redefined-outer-name
|
# pylint: disable=redefined-outer-name
|
||||||
|
|
||||||
|
@ -1320,6 +1487,37 @@ class Resolver(BaseResolver):
|
||||||
|
|
||||||
# pylint: enable=redefined-outer-name
|
# pylint: enable=redefined-outer-name
|
||||||
|
|
||||||
|
def try_ddr(self, lifetime: float = 5.0) -> None:
|
||||||
|
"""Try to update the resolver's nameservers using Discovery of Designated
|
||||||
|
Resolvers (DDR). If successful, the resolver will subsequently use
|
||||||
|
DNS-over-HTTPS or DNS-over-TLS for future queries.
|
||||||
|
|
||||||
|
*lifetime*, a float, is the maximum time to spend attempting DDR. The default
|
||||||
|
is 5 seconds.
|
||||||
|
|
||||||
|
If the SVCB query is successful and results in a non-empty list of nameservers,
|
||||||
|
then the resolver's nameservers are set to the returned servers in priority
|
||||||
|
order.
|
||||||
|
|
||||||
|
The current implementation does not use any address hints from the SVCB record,
|
||||||
|
nor does it resolve addresses for the SCVB target name, rather it assumes that
|
||||||
|
the bootstrap nameserver will always be one of the addresses and uses it.
|
||||||
|
A future revision to the code may offer fuller support. The code verifies that
|
||||||
|
the bootstrap nameserver is in the Subject Alternative Name field of the
|
||||||
|
TLS certficate.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
expiration = time.time() + lifetime
|
||||||
|
answer = self.resolve(
|
||||||
|
dns._ddr._local_resolver_name, "SVCB", lifetime=lifetime
|
||||||
|
)
|
||||||
|
timeout = dns.query._remaining(expiration)
|
||||||
|
nameservers = dns._ddr._get_nameservers_sync(answer, timeout)
|
||||||
|
if len(nameservers) > 0:
|
||||||
|
self.nameservers = nameservers
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
#: The default resolver.
|
#: The default resolver.
|
||||||
default_resolver: Optional[Resolver] = None
|
default_resolver: Optional[Resolver] = None
|
||||||
|
@ -1333,7 +1531,7 @@ def get_default_resolver() -> Resolver:
|
||||||
return default_resolver
|
return default_resolver
|
||||||
|
|
||||||
|
|
||||||
def reset_default_resolver():
|
def reset_default_resolver() -> None:
|
||||||
"""Re-initialize default resolver.
|
"""Re-initialize default resolver.
|
||||||
|
|
||||||
Note that the resolver configuration (i.e. /etc/resolv.conf on UNIX
|
Note that the resolver configuration (i.e. /etc/resolv.conf on UNIX
|
||||||
|
@ -1355,7 +1553,6 @@ def resolve(
|
||||||
lifetime: Optional[float] = None,
|
lifetime: Optional[float] = None,
|
||||||
search: Optional[bool] = None,
|
search: Optional[bool] = None,
|
||||||
) -> Answer: # pragma: no cover
|
) -> Answer: # pragma: no cover
|
||||||
|
|
||||||
"""Query nameservers to find the answer to the question.
|
"""Query nameservers to find the answer to the question.
|
||||||
|
|
||||||
This is a convenience function that uses the default resolver
|
This is a convenience function that uses the default resolver
|
||||||
|
@ -1421,6 +1618,18 @@ def resolve_address(ipaddr: str, *args: Any, **kwargs: Any) -> Answer:
|
||||||
return get_default_resolver().resolve_address(ipaddr, *args, **kwargs)
|
return get_default_resolver().resolve_address(ipaddr, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_name(
|
||||||
|
name: Union[dns.name.Name, str], family: int = socket.AF_UNSPEC, **kwargs: Any
|
||||||
|
) -> HostAnswers:
|
||||||
|
"""Use a resolver to query for address records.
|
||||||
|
|
||||||
|
See ``dns.resolver.Resolver.resolve_name`` for more information on the
|
||||||
|
parameters.
|
||||||
|
"""
|
||||||
|
|
||||||
|
return get_default_resolver().resolve_name(name, family, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def canonical_name(name: Union[dns.name.Name, str]) -> dns.name.Name:
|
def canonical_name(name: Union[dns.name.Name, str]) -> dns.name.Name:
|
||||||
"""Determine the canonical name of *name*.
|
"""Determine the canonical name of *name*.
|
||||||
|
|
||||||
|
@ -1431,6 +1640,16 @@ def canonical_name(name: Union[dns.name.Name, str]) -> dns.name.Name:
|
||||||
return get_default_resolver().canonical_name(name)
|
return get_default_resolver().canonical_name(name)
|
||||||
|
|
||||||
|
|
||||||
|
def try_ddr(lifetime: float = 5.0) -> None:
|
||||||
|
"""Try to update the default resolver's nameservers using Discovery of Designated
|
||||||
|
Resolvers (DDR). If successful, the resolver will subsequently use
|
||||||
|
DNS-over-HTTPS or DNS-over-TLS for future queries.
|
||||||
|
|
||||||
|
See :py:func:`dns.resolver.Resolver.try_ddr` for more information.
|
||||||
|
"""
|
||||||
|
return get_default_resolver().try_ddr(lifetime)
|
||||||
|
|
||||||
|
|
||||||
def zone_for_name(
|
def zone_for_name(
|
||||||
name: Union[dns.name.Name, str],
|
name: Union[dns.name.Name, str],
|
||||||
rdclass: dns.rdataclass.RdataClass = dns.rdataclass.IN,
|
rdclass: dns.rdataclass.RdataClass = dns.rdataclass.IN,
|
||||||
|
@ -1478,7 +1697,7 @@ def zone_for_name(
|
||||||
while 1:
|
while 1:
|
||||||
try:
|
try:
|
||||||
rlifetime: Optional[float]
|
rlifetime: Optional[float]
|
||||||
if expiration:
|
if expiration is not None:
|
||||||
rlifetime = expiration - time.time()
|
rlifetime = expiration - time.time()
|
||||||
if rlifetime <= 0:
|
if rlifetime <= 0:
|
||||||
rlifetime = 0
|
rlifetime = 0
|
||||||
|
@ -1516,6 +1735,83 @@ def zone_for_name(
|
||||||
raise NoRootSOA
|
raise NoRootSOA
|
||||||
|
|
||||||
|
|
||||||
|
def make_resolver_at(
|
||||||
|
where: Union[dns.name.Name, str],
|
||||||
|
port: int = 53,
|
||||||
|
family: int = socket.AF_UNSPEC,
|
||||||
|
resolver: Optional[Resolver] = None,
|
||||||
|
) -> Resolver:
|
||||||
|
"""Make a stub resolver using the specified destination as the full resolver.
|
||||||
|
|
||||||
|
*where*, a ``dns.name.Name`` or ``str`` the domain name or IP address of the
|
||||||
|
full resolver.
|
||||||
|
|
||||||
|
*port*, an ``int``, the port to use. If not specified, the default is 53.
|
||||||
|
|
||||||
|
*family*, an ``int``, the address family to use. This parameter is used if
|
||||||
|
*where* is not an address. The default is ``socket.AF_UNSPEC`` in which case
|
||||||
|
the first address returned by ``resolve_name()`` will be used, otherwise the
|
||||||
|
first address of the specified family will be used.
|
||||||
|
|
||||||
|
*resolver*, a ``dns.resolver.Resolver`` or ``None``, the resolver to use for
|
||||||
|
resolution of hostnames. If not specified, the default resolver will be used.
|
||||||
|
|
||||||
|
Returns a ``dns.resolver.Resolver`` or raises an exception.
|
||||||
|
"""
|
||||||
|
if resolver is None:
|
||||||
|
resolver = get_default_resolver()
|
||||||
|
nameservers: List[Union[str, dns.nameserver.Nameserver]] = []
|
||||||
|
if isinstance(where, str) and dns.inet.is_address(where):
|
||||||
|
nameservers.append(dns.nameserver.Do53Nameserver(where, port))
|
||||||
|
else:
|
||||||
|
for address in resolver.resolve_name(where, family).addresses():
|
||||||
|
nameservers.append(dns.nameserver.Do53Nameserver(address, port))
|
||||||
|
res = dns.resolver.Resolver(configure=False)
|
||||||
|
res.nameservers = nameservers
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_at(
|
||||||
|
where: Union[dns.name.Name, str],
|
||||||
|
qname: Union[dns.name.Name, str],
|
||||||
|
rdtype: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.A,
|
||||||
|
rdclass: Union[dns.rdataclass.RdataClass, str] = dns.rdataclass.IN,
|
||||||
|
tcp: bool = False,
|
||||||
|
source: Optional[str] = None,
|
||||||
|
raise_on_no_answer: bool = True,
|
||||||
|
source_port: int = 0,
|
||||||
|
lifetime: Optional[float] = None,
|
||||||
|
search: Optional[bool] = None,
|
||||||
|
port: int = 53,
|
||||||
|
family: int = socket.AF_UNSPEC,
|
||||||
|
resolver: Optional[Resolver] = None,
|
||||||
|
) -> Answer:
|
||||||
|
"""Query nameservers to find the answer to the question.
|
||||||
|
|
||||||
|
This is a convenience function that calls ``dns.resolver.make_resolver_at()`` to
|
||||||
|
make a resolver, and then uses it to resolve the query.
|
||||||
|
|
||||||
|
See ``dns.resolver.Resolver.resolve`` for more information on the resolution
|
||||||
|
parameters, and ``dns.resolver.make_resolver_at`` for information about the resolver
|
||||||
|
parameters *where*, *port*, *family*, and *resolver*.
|
||||||
|
|
||||||
|
If making more than one query, it is more efficient to call
|
||||||
|
``dns.resolver.make_resolver_at()`` and then use that resolver for the queries
|
||||||
|
instead of calling ``resolve_at()`` multiple times.
|
||||||
|
"""
|
||||||
|
return make_resolver_at(where, port, family, resolver).resolve(
|
||||||
|
qname,
|
||||||
|
rdtype,
|
||||||
|
rdclass,
|
||||||
|
tcp,
|
||||||
|
source,
|
||||||
|
raise_on_no_answer,
|
||||||
|
source_port,
|
||||||
|
lifetime,
|
||||||
|
search,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
#
|
#
|
||||||
# Support for overriding the system resolver for all python code in the
|
# Support for overriding the system resolver for all python code in the
|
||||||
# running process.
|
# running process.
|
||||||
|
@ -1559,8 +1855,7 @@ def _getaddrinfo(
|
||||||
)
|
)
|
||||||
if host is None and service is None:
|
if host is None and service is None:
|
||||||
raise socket.gaierror(socket.EAI_NONAME, "Name or service not known")
|
raise socket.gaierror(socket.EAI_NONAME, "Name or service not known")
|
||||||
v6addrs = []
|
addrs = []
|
||||||
v4addrs = []
|
|
||||||
canonical_name = None # pylint: disable=redefined-outer-name
|
canonical_name = None # pylint: disable=redefined-outer-name
|
||||||
# Is host None or an address literal? If so, use the system's
|
# Is host None or an address literal? If so, use the system's
|
||||||
# getaddrinfo().
|
# getaddrinfo().
|
||||||
|
@ -1576,24 +1871,9 @@ def _getaddrinfo(
|
||||||
pass
|
pass
|
||||||
# Something needs resolution!
|
# Something needs resolution!
|
||||||
try:
|
try:
|
||||||
if family == socket.AF_INET6 or family == socket.AF_UNSPEC:
|
answers = _resolver.resolve_name(host, family)
|
||||||
v6 = _resolver.resolve(host, dns.rdatatype.AAAA, raise_on_no_answer=False)
|
addrs = answers.addresses_and_families()
|
||||||
# Note that setting host ensures we query the same name
|
canonical_name = answers.canonical_name().to_text(True)
|
||||||
# for A as we did for AAAA. (This is just in case search lists
|
|
||||||
# are active by default in the resolver configuration and
|
|
||||||
# we might be talking to a server that says NXDOMAIN when it
|
|
||||||
# wants to say NOERROR no data.
|
|
||||||
host = v6.qname
|
|
||||||
canonical_name = v6.canonical_name.to_text(True)
|
|
||||||
if v6.rrset is not None:
|
|
||||||
for rdata in v6.rrset:
|
|
||||||
v6addrs.append(rdata.address)
|
|
||||||
if family == socket.AF_INET or family == socket.AF_UNSPEC:
|
|
||||||
v4 = _resolver.resolve(host, dns.rdatatype.A, raise_on_no_answer=False)
|
|
||||||
canonical_name = v4.canonical_name.to_text(True)
|
|
||||||
if v4.rrset is not None:
|
|
||||||
for rdata in v4.rrset:
|
|
||||||
v4addrs.append(rdata.address)
|
|
||||||
except dns.resolver.NXDOMAIN:
|
except dns.resolver.NXDOMAIN:
|
||||||
raise socket.gaierror(socket.EAI_NONAME, "Name or service not known")
|
raise socket.gaierror(socket.EAI_NONAME, "Name or service not known")
|
||||||
except Exception:
|
except Exception:
|
||||||
|
@ -1625,20 +1905,11 @@ def _getaddrinfo(
|
||||||
cname = canonical_name
|
cname = canonical_name
|
||||||
else:
|
else:
|
||||||
cname = ""
|
cname = ""
|
||||||
if family == socket.AF_INET6 or family == socket.AF_UNSPEC:
|
for addr, af in addrs:
|
||||||
for addr in v6addrs:
|
for socktype in socktypes:
|
||||||
for socktype in socktypes:
|
for proto in _protocols_for_socktype[socktype]:
|
||||||
for proto in _protocols_for_socktype[socktype]:
|
addr_tuple = dns.inet.low_level_address_tuple((addr, port), af)
|
||||||
tuples.append(
|
tuples.append((af, socktype, proto, cname, addr_tuple))
|
||||||
(socket.AF_INET6, socktype, proto, cname, (addr, port, 0, 0))
|
|
||||||
)
|
|
||||||
if family == socket.AF_INET or family == socket.AF_UNSPEC:
|
|
||||||
for addr in v4addrs:
|
|
||||||
for socktype in socktypes:
|
|
||||||
for proto in _protocols_for_socktype[socktype]:
|
|
||||||
tuples.append(
|
|
||||||
(socket.AF_INET, socktype, proto, cname, (addr, port))
|
|
||||||
)
|
|
||||||
if len(tuples) == 0:
|
if len(tuples) == 0:
|
||||||
raise socket.gaierror(socket.EAI_NONAME, "Name or service not known")
|
raise socket.gaierror(socket.EAI_NONAME, "Name or service not known")
|
||||||
return tuples
|
return tuples
|
||||||
|
|
|
@ -19,9 +19,9 @@
|
||||||
|
|
||||||
import binascii
|
import binascii
|
||||||
|
|
||||||
import dns.name
|
|
||||||
import dns.ipv6
|
|
||||||
import dns.ipv4
|
import dns.ipv4
|
||||||
|
import dns.ipv6
|
||||||
|
import dns.name
|
||||||
|
|
||||||
ipv4_reverse_domain = dns.name.from_text("in-addr.arpa.")
|
ipv4_reverse_domain = dns.name.from_text("in-addr.arpa.")
|
||||||
ipv6_reverse_domain = dns.name.from_text("ip6.arpa.")
|
ipv6_reverse_domain = dns.name.from_text("ip6.arpa.")
|
||||||
|
|
|
@ -17,11 +17,11 @@
|
||||||
|
|
||||||
"""DNS RRsets (an RRset is a named rdataset)"""
|
"""DNS RRsets (an RRset is a named rdataset)"""
|
||||||
|
|
||||||
from typing import Any, cast, Collection, Dict, Optional, Union
|
from typing import Any, Collection, Dict, Optional, Union, cast
|
||||||
|
|
||||||
import dns.name
|
import dns.name
|
||||||
import dns.rdataset
|
|
||||||
import dns.rdataclass
|
import dns.rdataclass
|
||||||
|
import dns.rdataset
|
||||||
import dns.renderer
|
import dns.renderer
|
||||||
|
|
||||||
|
|
||||||
|
@ -214,9 +214,9 @@ def from_text_list(
|
||||||
|
|
||||||
if isinstance(name, str):
|
if isinstance(name, str):
|
||||||
name = dns.name.from_text(name, None, idna_codec=idna_codec)
|
name = dns.name.from_text(name, None, idna_codec=idna_codec)
|
||||||
the_rdclass = dns.rdataclass.RdataClass.make(rdclass)
|
rdclass = dns.rdataclass.RdataClass.make(rdclass)
|
||||||
the_rdtype = dns.rdatatype.RdataType.make(rdtype)
|
rdtype = dns.rdatatype.RdataType.make(rdtype)
|
||||||
r = RRset(name, the_rdclass, the_rdtype)
|
r = RRset(name, rdclass, rdtype)
|
||||||
r.update_ttl(ttl)
|
r.update_ttl(ttl)
|
||||||
for t in text_rdatas:
|
for t in text_rdatas:
|
||||||
rd = dns.rdata.from_text(
|
rd = dns.rdata.from_text(
|
||||||
|
|
|
@ -17,10 +17,9 @@
|
||||||
|
|
||||||
"""Tokenize DNS zone file format"""
|
"""Tokenize DNS zone file format"""
|
||||||
|
|
||||||
from typing import Any, Optional, List, Tuple
|
|
||||||
|
|
||||||
import io
|
import io
|
||||||
import sys
|
import sys
|
||||||
|
from typing import Any, List, Optional, Tuple
|
||||||
|
|
||||||
import dns.exception
|
import dns.exception
|
||||||
import dns.name
|
import dns.name
|
||||||
|
|
|
@ -1,8 +1,7 @@
|
||||||
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
|
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
|
||||||
|
|
||||||
from typing import Any, Callable, List, Optional, Tuple, Union
|
|
||||||
|
|
||||||
import collections
|
import collections
|
||||||
|
from typing import Any, Callable, Iterator, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import dns.exception
|
import dns.exception
|
||||||
import dns.name
|
import dns.name
|
||||||
|
@ -357,6 +356,27 @@ class Transaction:
|
||||||
"""
|
"""
|
||||||
self._check_delete_name.append(check)
|
self._check_delete_name.append(check)
|
||||||
|
|
||||||
|
def iterate_rdatasets(
|
||||||
|
self,
|
||||||
|
) -> Iterator[Tuple[dns.name.Name, dns.rdataset.Rdataset]]:
|
||||||
|
"""Iterate all the rdatasets in the transaction, returning
|
||||||
|
(`dns.name.Name`, `dns.rdataset.Rdataset`) tuples.
|
||||||
|
|
||||||
|
Note that as is usual with python iterators, adding or removing items
|
||||||
|
while iterating will invalidate the iterator and may raise `RuntimeError`
|
||||||
|
or fail to iterate over all entries."""
|
||||||
|
self._check_ended()
|
||||||
|
return self._iterate_rdatasets()
|
||||||
|
|
||||||
|
def iterate_names(self) -> Iterator[dns.name.Name]:
|
||||||
|
"""Iterate all the names in the transaction.
|
||||||
|
|
||||||
|
Note that as is usual with python iterators, adding or removing names
|
||||||
|
while iterating will invalidate the iterator and may raise `RuntimeError`
|
||||||
|
or fail to iterate over all entries."""
|
||||||
|
self._check_ended()
|
||||||
|
return self._iterate_names()
|
||||||
|
|
||||||
#
|
#
|
||||||
# Helper methods
|
# Helper methods
|
||||||
#
|
#
|
||||||
|
@ -416,7 +436,7 @@ class Transaction:
|
||||||
rdataset = rrset.to_rdataset()
|
rdataset = rrset.to_rdataset()
|
||||||
else:
|
else:
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
f"{method} requires a name or RRset " + "as the first argument"
|
f"{method} requires a name or RRset as the first argument"
|
||||||
)
|
)
|
||||||
if rdataset.rdclass != self.manager.get_class():
|
if rdataset.rdclass != self.manager.get_class():
|
||||||
raise ValueError(f"{method} has objects of wrong RdataClass")
|
raise ValueError(f"{method} has objects of wrong RdataClass")
|
||||||
|
@ -475,7 +495,7 @@ class Transaction:
|
||||||
name = rdataset.name
|
name = rdataset.name
|
||||||
else:
|
else:
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
f"{method} requires a name or RRset " + "as the first argument"
|
f"{method} requires a name or RRset as the first argument"
|
||||||
)
|
)
|
||||||
self._raise_if_not_empty(method, args)
|
self._raise_if_not_empty(method, args)
|
||||||
if rdataset:
|
if rdataset:
|
||||||
|
@ -610,6 +630,10 @@ class Transaction:
|
||||||
"""Return an iterator that yields (name, rdataset) tuples."""
|
"""Return an iterator that yields (name, rdataset) tuples."""
|
||||||
raise NotImplementedError # pragma: no cover
|
raise NotImplementedError # pragma: no cover
|
||||||
|
|
||||||
|
def _iterate_names(self):
|
||||||
|
"""Return an iterator that yields a name."""
|
||||||
|
raise NotImplementedError # pragma: no cover
|
||||||
|
|
||||||
def _get_node(self, name):
|
def _get_node(self, name):
|
||||||
"""Return the node at *name*, if any.
|
"""Return the node at *name*, if any.
|
||||||
|
|
||||||
|
|
|
@ -23,9 +23,9 @@ import hmac
|
||||||
import struct
|
import struct
|
||||||
|
|
||||||
import dns.exception
|
import dns.exception
|
||||||
import dns.rdataclass
|
|
||||||
import dns.name
|
import dns.name
|
||||||
import dns.rcode
|
import dns.rcode
|
||||||
|
import dns.rdataclass
|
||||||
|
|
||||||
|
|
||||||
class BadTime(dns.exception.DNSException):
|
class BadTime(dns.exception.DNSException):
|
||||||
|
@ -187,9 +187,7 @@ class HMACTSig:
|
||||||
try:
|
try:
|
||||||
hashinfo = self._hashes[algorithm]
|
hashinfo = self._hashes[algorithm]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(f"TSIG algorithm {algorithm} is not supported")
|
||||||
f"TSIG algorithm {algorithm} " + "is not supported"
|
|
||||||
)
|
|
||||||
|
|
||||||
# create the HMAC context
|
# create the HMAC context
|
||||||
if isinstance(hashinfo, tuple):
|
if isinstance(hashinfo, tuple):
|
||||||
|
|
|
@ -17,9 +17,8 @@
|
||||||
|
|
||||||
"""A place to store TSIG keys."""
|
"""A place to store TSIG keys."""
|
||||||
|
|
||||||
from typing import Any, Dict
|
|
||||||
|
|
||||||
import base64
|
import base64
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
import dns.name
|
import dns.name
|
||||||
import dns.tsig
|
import dns.tsig
|
||||||
|
@ -33,7 +32,7 @@ def from_text(textring: Dict[str, Any]) -> Dict[dns.name.Name, dns.tsig.Key]:
|
||||||
@rtype: dict"""
|
@rtype: dict"""
|
||||||
|
|
||||||
keyring = {}
|
keyring = {}
|
||||||
for (name, value) in textring.items():
|
for name, value in textring.items():
|
||||||
kname = dns.name.from_text(name)
|
kname = dns.name.from_text(name)
|
||||||
if isinstance(value, str):
|
if isinstance(value, str):
|
||||||
keyring[kname] = dns.tsig.Key(kname, value).secret
|
keyring[kname] = dns.tsig.Key(kname, value).secret
|
||||||
|
@ -55,7 +54,7 @@ def to_text(keyring: Dict[dns.name.Name, Any]) -> Dict[str, Any]:
|
||||||
def b64encode(secret):
|
def b64encode(secret):
|
||||||
return base64.encodebytes(secret).decode().rstrip()
|
return base64.encodebytes(secret).decode().rstrip()
|
||||||
|
|
||||||
for (name, key) in keyring.items():
|
for name, key in keyring.items():
|
||||||
tname = name.to_text()
|
tname = name.to_text()
|
||||||
if isinstance(key, bytes):
|
if isinstance(key, bytes):
|
||||||
textring[tname] = b64encode(key)
|
textring[tname] = b64encode(key)
|
||||||
|
|
|
@ -24,8 +24,8 @@ import dns.name
|
||||||
import dns.opcode
|
import dns.opcode
|
||||||
import dns.rdata
|
import dns.rdata
|
||||||
import dns.rdataclass
|
import dns.rdataclass
|
||||||
import dns.rdatatype
|
|
||||||
import dns.rdataset
|
import dns.rdataset
|
||||||
|
import dns.rdatatype
|
||||||
import dns.tsig
|
import dns.tsig
|
||||||
|
|
||||||
|
|
||||||
|
@ -43,7 +43,6 @@ class UpdateSection(dns.enum.IntEnum):
|
||||||
|
|
||||||
|
|
||||||
class UpdateMessage(dns.message.Message): # lgtm[py/missing-equals]
|
class UpdateMessage(dns.message.Message): # lgtm[py/missing-equals]
|
||||||
|
|
||||||
# ignore the mypy error here as we mean to use a different enum
|
# ignore the mypy error here as we mean to use a different enum
|
||||||
_section_enum = UpdateSection # type: ignore
|
_section_enum = UpdateSection # type: ignore
|
||||||
|
|
||||||
|
@ -336,12 +335,12 @@ class UpdateMessage(dns.message.Message): # lgtm[py/missing-equals]
|
||||||
True,
|
True,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
the_rdtype = dns.rdatatype.RdataType.make(rdtype)
|
rdtype = dns.rdatatype.RdataType.make(rdtype)
|
||||||
self.find_rrset(
|
self.find_rrset(
|
||||||
self.prerequisite,
|
self.prerequisite,
|
||||||
name,
|
name,
|
||||||
dns.rdataclass.NONE,
|
dns.rdataclass.NONE,
|
||||||
the_rdtype,
|
rdtype,
|
||||||
dns.rdatatype.NONE,
|
dns.rdatatype.NONE,
|
||||||
None,
|
None,
|
||||||
True,
|
True,
|
||||||
|
|
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue