From c0aa4e49968b283d54a44669314a4ecb17fa697a Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 24 Aug 2023 12:05:11 -0700 Subject: [PATCH] Bump dnspython from 2.3.0 to 2.4.2 (#2123) * Bump dnspython from 2.3.0 to 2.4.2 Bumps [dnspython](https://github.com/rthalley/dnspython) from 2.3.0 to 2.4.2. - [Release notes](https://github.com/rthalley/dnspython/releases) - [Changelog](https://github.com/rthalley/dnspython/blob/master/doc/whatsnew.rst) - [Commits](https://github.com/rthalley/dnspython/compare/v2.3.0...v2.4.2) --- updated-dependencies: - dependency-name: dnspython dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] * Update dnspython==2.4.2 --------- Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: JonnyWong16 <9099342+JonnyWong16@users.noreply.github.com> [skip ci] --- lib/dns/__init__.py | 1 + lib/dns/_asyncbackend.py | 14 + lib/dns/_asyncio_backend.py | 118 ++++- lib/dns/_curio_backend.py | 122 ------ lib/dns/_ddr.py | 154 +++++++ lib/dns/_immutable_ctx.py | 1 - lib/dns/_trio_backend.py | 116 ++++- lib/dns/asyncbackend.py | 23 +- lib/dns/asyncquery.py | 77 ++-- lib/dns/asyncresolver.py | 245 +++++++++-- lib/dns/dnssec.py | 671 +++++++++++++---------------- lib/dns/dnssecalgs/__init__.py | 121 ++++++ lib/dns/dnssecalgs/base.py | 84 ++++ lib/dns/dnssecalgs/cryptography.py | 68 +++ lib/dns/dnssecalgs/dsa.py | 101 +++++ lib/dns/dnssecalgs/ecdsa.py | 89 ++++ lib/dns/dnssecalgs/eddsa.py | 65 +++ lib/dns/dnssecalgs/rsa.py | 119 +++++ lib/dns/edns.py | 11 +- lib/dns/entropy.py | 6 +- lib/dns/enum.py | 34 +- lib/dns/exception.py | 16 + lib/dns/flags.py | 3 +- lib/dns/immutable.py | 3 +- lib/dns/inet.py | 13 +- lib/dns/ipv4.py | 3 +- lib/dns/ipv6.py | 5 +- lib/dns/message.py | 102 +++-- lib/dns/name.py | 9 +- lib/dns/nameserver.py | 329 ++++++++++++++ lib/dns/node.py | 6 +- lib/dns/query.py | 260 ++++++----- lib/dns/quic/__init__.py | 7 +- lib/dns/quic/_asyncio.py | 45 +- lib/dns/quic/_common.py | 10 +- lib/dns/quic/_sync.py | 68 +-- lib/dns/quic/_trio.py | 69 +-- lib/dns/rdata.py | 27 +- lib/dns/rdataset.py | 13 +- lib/dns/rdtypes/ANY/AFSDB.py | 2 +- lib/dns/rdtypes/ANY/AVC.py | 2 +- lib/dns/rdtypes/ANY/CDNSKEY.py | 8 +- lib/dns/rdtypes/ANY/CDS.py | 2 +- lib/dns/rdtypes/ANY/CERT.py | 4 +- lib/dns/rdtypes/ANY/CNAME.py | 2 +- lib/dns/rdtypes/ANY/CSYNC.py | 2 +- lib/dns/rdtypes/ANY/DLV.py | 2 +- lib/dns/rdtypes/ANY/DNAME.py | 2 +- lib/dns/rdtypes/ANY/DNSKEY.py | 8 +- lib/dns/rdtypes/ANY/DS.py | 2 +- lib/dns/rdtypes/ANY/EUI48.py | 2 +- lib/dns/rdtypes/ANY/EUI64.py | 2 +- lib/dns/rdtypes/ANY/HIP.py | 2 +- lib/dns/rdtypes/ANY/LOC.py | 3 +- lib/dns/rdtypes/ANY/MX.py | 2 +- lib/dns/rdtypes/ANY/NINFO.py | 2 +- lib/dns/rdtypes/ANY/NS.py | 2 +- lib/dns/rdtypes/ANY/NSEC.py | 2 +- lib/dns/rdtypes/ANY/NSEC3.py | 6 +- lib/dns/rdtypes/ANY/NSEC3PARAM.py | 2 +- lib/dns/rdtypes/ANY/OPT.py | 3 +- lib/dns/rdtypes/ANY/PTR.py | 2 +- lib/dns/rdtypes/ANY/RP.py | 2 +- lib/dns/rdtypes/ANY/RRSIG.py | 2 +- lib/dns/rdtypes/ANY/RT.py | 2 +- lib/dns/rdtypes/ANY/SOA.py | 2 +- lib/dns/rdtypes/ANY/SPF.py | 2 +- lib/dns/rdtypes/ANY/SSHFP.py | 4 +- lib/dns/rdtypes/ANY/TKEY.py | 2 +- lib/dns/rdtypes/ANY/TXT.py | 2 +- lib/dns/rdtypes/ANY/URI.py | 2 +- lib/dns/rdtypes/ANY/ZONEMD.py | 2 +- lib/dns/rdtypes/CH/A.py | 2 +- lib/dns/rdtypes/IN/APL.py | 1 - lib/dns/rdtypes/IN/HTTPS.py | 2 +- lib/dns/rdtypes/IN/IPSECKEY.py | 2 +- lib/dns/rdtypes/IN/KX.py | 2 +- lib/dns/rdtypes/IN/NSAP_PTR.py | 2 +- lib/dns/rdtypes/IN/PX.py | 2 +- lib/dns/rdtypes/IN/SRV.py | 2 +- lib/dns/rdtypes/IN/SVCB.py | 2 +- lib/dns/rdtypes/IN/WKS.py | 2 +- lib/dns/rdtypes/dnskeybase.py | 4 +- lib/dns/rdtypes/dsbase.py | 4 +- lib/dns/rdtypes/euibase.py | 2 +- lib/dns/rdtypes/mxbase.py | 2 +- lib/dns/rdtypes/nsbase.py | 2 +- lib/dns/rdtypes/svcbbase.py | 1 + lib/dns/rdtypes/tlsabase.py | 4 +- lib/dns/rdtypes/txtbase.py | 3 +- lib/dns/rdtypes/util.py | 25 +- lib/dns/renderer.py | 3 +- lib/dns/resolver.py | 537 +++++++++++++++++------ lib/dns/reversename.py | 4 +- lib/dns/rrset.py | 10 +- lib/dns/tokenizer.py | 3 +- lib/dns/transaction.py | 32 +- lib/dns/tsig.py | 6 +- lib/dns/tsigkeyring.py | 7 +- lib/dns/update.py | 7 +- lib/dns/version.py | 6 +- lib/dns/versioned.py | 6 +- lib/dns/win32util.py | 8 +- lib/dns/wire.py | 3 +- lib/dns/xfr.py | 2 +- lib/dns/zone.py | 55 +-- lib/dns/zonefile.py | 48 ++- requirements.txt | 2 +- 108 files changed, 2985 insertions(+), 1136 deletions(-) delete mode 100644 lib/dns/_curio_backend.py create mode 100644 lib/dns/_ddr.py create mode 100644 lib/dns/dnssecalgs/__init__.py create mode 100644 lib/dns/dnssecalgs/base.py create mode 100644 lib/dns/dnssecalgs/cryptography.py create mode 100644 lib/dns/dnssecalgs/dsa.py create mode 100644 lib/dns/dnssecalgs/ecdsa.py create mode 100644 lib/dns/dnssecalgs/eddsa.py create mode 100644 lib/dns/dnssecalgs/rsa.py create mode 100644 lib/dns/nameserver.py diff --git a/lib/dns/__init__.py b/lib/dns/__init__.py index 9abdf018..a4249b9e 100644 --- a/lib/dns/__init__.py +++ b/lib/dns/__init__.py @@ -22,6 +22,7 @@ __all__ = [ "asyncquery", "asyncresolver", "dnssec", + "dnssecalgs", "dnssectypes", "e164", "edns", diff --git a/lib/dns/_asyncbackend.py b/lib/dns/_asyncbackend.py index ff24604f..49f14fed 100644 --- a/lib/dns/_asyncbackend.py +++ b/lib/dns/_asyncbackend.py @@ -35,6 +35,9 @@ class Socket: # pragma: no cover async def getsockname(self): raise NotImplementedError + async def getpeercert(self, timeout): + raise NotImplementedError + async def __aenter__(self): return self @@ -61,6 +64,11 @@ class StreamSocket(Socket): # pragma: no cover raise NotImplementedError +class NullTransport: + async def connect_tcp(self, host, port, timeout, local_address): + raise NotImplementedError + + class Backend: # pragma: no cover def name(self): return "unknown" @@ -83,3 +91,9 @@ class Backend: # pragma: no cover async def sleep(self, interval): raise NotImplementedError + + def get_transport_class(self): + raise NotImplementedError + + async def wait_for(self, awaitable, timeout): + raise NotImplementedError diff --git a/lib/dns/_asyncio_backend.py b/lib/dns/_asyncio_backend.py index 82a06249..2631228e 100644 --- a/lib/dns/_asyncio_backend.py +++ b/lib/dns/_asyncio_backend.py @@ -2,14 +2,13 @@ """asyncio library query support""" -import socket import asyncio +import socket import sys import dns._asyncbackend import dns.exception - _is_win32 = sys.platform == "win32" @@ -38,14 +37,21 @@ class _DatagramProtocol: def connection_lost(self, exc): if self.recvfrom and not self.recvfrom.done(): - self.recvfrom.set_exception(exc) + if exc is None: + # EOF we triggered. Is there a better way to do this? + try: + raise EOFError + except EOFError as e: + self.recvfrom.set_exception(e) + else: + self.recvfrom.set_exception(exc) def close(self): self.transport.close() async def _maybe_wait_for(awaitable, timeout): - if timeout: + if timeout is not None: try: return await asyncio.wait_for(awaitable, timeout) except asyncio.TimeoutError: @@ -85,6 +91,9 @@ class DatagramSocket(dns._asyncbackend.DatagramSocket): async def getsockname(self): return self.transport.get_extra_info("sockname") + async def getpeercert(self, timeout): + raise NotImplementedError + class StreamSocket(dns._asyncbackend.StreamSocket): def __init__(self, af, reader, writer): @@ -101,10 +110,6 @@ class StreamSocket(dns._asyncbackend.StreamSocket): async def close(self): self.writer.close() - try: - await self.writer.wait_closed() - except AttributeError: # pragma: no cover - pass async def getpeername(self): return self.writer.get_extra_info("peername") @@ -112,6 +117,97 @@ class StreamSocket(dns._asyncbackend.StreamSocket): async def getsockname(self): return self.writer.get_extra_info("sockname") + async def getpeercert(self, timeout): + return self.writer.get_extra_info("peercert") + + +try: + import anyio + import httpcore + import httpcore._backends.anyio + import httpx + + _CoreAsyncNetworkBackend = httpcore.AsyncNetworkBackend + _CoreAnyIOStream = httpcore._backends.anyio.AnyIOStream + + from dns.query import _compute_times, _expiration_for_this_attempt, _remaining + + class _NetworkBackend(_CoreAsyncNetworkBackend): + def __init__(self, resolver, local_port, bootstrap_address, family): + super().__init__() + self._local_port = local_port + self._resolver = resolver + self._bootstrap_address = bootstrap_address + self._family = family + if local_port != 0: + raise NotImplementedError( + "the asyncio transport for HTTPX cannot set the local port" + ) + + async def connect_tcp( + self, host, port, timeout, local_address, socket_options=None + ): # pylint: disable=signature-differs + addresses = [] + _, expiration = _compute_times(timeout) + if dns.inet.is_address(host): + addresses.append(host) + elif self._bootstrap_address is not None: + addresses.append(self._bootstrap_address) + else: + timeout = _remaining(expiration) + family = self._family + if local_address: + family = dns.inet.af_for_address(local_address) + answers = await self._resolver.resolve_name( + host, family=family, lifetime=timeout + ) + addresses = answers.addresses() + for address in addresses: + try: + attempt_expiration = _expiration_for_this_attempt(2.0, expiration) + timeout = _remaining(attempt_expiration) + with anyio.fail_after(timeout): + stream = await anyio.connect_tcp( + remote_host=address, + remote_port=port, + local_host=local_address, + ) + return _CoreAnyIOStream(stream) + except Exception: + pass + raise httpcore.ConnectError + + async def connect_unix_socket( + self, path, timeout, socket_options=None + ): # pylint: disable=signature-differs + raise NotImplementedError + + async def sleep(self, seconds): # pylint: disable=signature-differs + await anyio.sleep(seconds) + + class _HTTPTransport(httpx.AsyncHTTPTransport): + def __init__( + self, + *args, + local_port=0, + bootstrap_address=None, + resolver=None, + family=socket.AF_UNSPEC, + **kwargs, + ): + if resolver is None: + # pylint: disable=import-outside-toplevel,redefined-outer-name + import dns.asyncresolver + + resolver = dns.asyncresolver.Resolver() + super().__init__(*args, **kwargs) + self._pool._network_backend = _NetworkBackend( + resolver, local_port, bootstrap_address, family + ) + +except ImportError: + _HTTPTransport = dns._asyncbackend.NullTransport # type: ignore + class Backend(dns._asyncbackend.Backend): def name(self): @@ -171,3 +267,9 @@ class Backend(dns._asyncbackend.Backend): def datagram_connection_required(self): return _is_win32 + + def get_transport_class(self): + return _HTTPTransport + + async def wait_for(self, awaitable, timeout): + return await _maybe_wait_for(awaitable, timeout) diff --git a/lib/dns/_curio_backend.py b/lib/dns/_curio_backend.py deleted file mode 100644 index 765d6471..00000000 --- a/lib/dns/_curio_backend.py +++ /dev/null @@ -1,122 +0,0 @@ -# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license - -"""curio async I/O library query support""" - -import socket -import curio -import curio.socket # type: ignore - -import dns._asyncbackend -import dns.exception -import dns.inet - - -def _maybe_timeout(timeout): - if timeout: - return curio.ignore_after(timeout) - else: - return dns._asyncbackend.NullContext() - - -# for brevity -_lltuple = dns.inet.low_level_address_tuple - -# pylint: disable=redefined-outer-name - - -class DatagramSocket(dns._asyncbackend.DatagramSocket): - def __init__(self, socket): - super().__init__(socket.family) - self.socket = socket - - async def sendto(self, what, destination, timeout): - async with _maybe_timeout(timeout): - return await self.socket.sendto(what, destination) - raise dns.exception.Timeout( - timeout=timeout - ) # pragma: no cover lgtm[py/unreachable-statement] - - async def recvfrom(self, size, timeout): - async with _maybe_timeout(timeout): - return await self.socket.recvfrom(size) - raise dns.exception.Timeout(timeout=timeout) # lgtm[py/unreachable-statement] - - async def close(self): - await self.socket.close() - - async def getpeername(self): - return self.socket.getpeername() - - async def getsockname(self): - return self.socket.getsockname() - - -class StreamSocket(dns._asyncbackend.StreamSocket): - def __init__(self, socket): - self.socket = socket - self.family = socket.family - - async def sendall(self, what, timeout): - async with _maybe_timeout(timeout): - return await self.socket.sendall(what) - raise dns.exception.Timeout(timeout=timeout) # lgtm[py/unreachable-statement] - - async def recv(self, size, timeout): - async with _maybe_timeout(timeout): - return await self.socket.recv(size) - raise dns.exception.Timeout(timeout=timeout) # lgtm[py/unreachable-statement] - - async def close(self): - await self.socket.close() - - async def getpeername(self): - return self.socket.getpeername() - - async def getsockname(self): - return self.socket.getsockname() - - -class Backend(dns._asyncbackend.Backend): - def name(self): - return "curio" - - async def make_socket( - self, - af, - socktype, - proto=0, - source=None, - destination=None, - timeout=None, - ssl_context=None, - server_hostname=None, - ): - if socktype == socket.SOCK_DGRAM: - s = curio.socket.socket(af, socktype, proto) - try: - if source: - s.bind(_lltuple(source, af)) - except Exception: # pragma: no cover - await s.close() - raise - return DatagramSocket(s) - elif socktype == socket.SOCK_STREAM: - if source: - source_addr = _lltuple(source, af) - else: - source_addr = None - async with _maybe_timeout(timeout): - s = await curio.open_connection( - destination[0], - destination[1], - ssl=ssl_context, - source_addr=source_addr, - server_hostname=server_hostname, - ) - return StreamSocket(s) - raise NotImplementedError( - "unsupported socket " + f"type {socktype}" - ) # pragma: no cover - - async def sleep(self, interval): - await curio.sleep(interval) diff --git a/lib/dns/_ddr.py b/lib/dns/_ddr.py new file mode 100644 index 00000000..bf5c11eb --- /dev/null +++ b/lib/dns/_ddr.py @@ -0,0 +1,154 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license +# +# Support for Discovery of Designated Resolvers + +import socket +import time +from urllib.parse import urlparse + +import dns.asyncbackend +import dns.inet +import dns.name +import dns.nameserver +import dns.query +import dns.rdtypes.svcbbase + +# The special name of the local resolver when using DDR +_local_resolver_name = dns.name.from_text("_dns.resolver.arpa") + + +# +# Processing is split up into I/O independent and I/O dependent parts to +# make supporting sync and async versions easy. +# + + +class _SVCBInfo: + def __init__(self, bootstrap_address, port, hostname, nameservers): + self.bootstrap_address = bootstrap_address + self.port = port + self.hostname = hostname + self.nameservers = nameservers + + def ddr_check_certificate(self, cert): + """Verify that the _SVCBInfo's address is in the cert's subjectAltName (SAN)""" + for name, value in cert["subjectAltName"]: + if name == "IP Address" and value == self.bootstrap_address: + return True + return False + + def make_tls_context(self): + ssl = dns.query.ssl + ctx = ssl.create_default_context() + ctx.minimum_version = ssl.TLSVersion.TLSv1_2 + return ctx + + def ddr_tls_check_sync(self, lifetime): + ctx = self.make_tls_context() + expiration = time.time() + lifetime + with socket.create_connection( + (self.bootstrap_address, self.port), lifetime + ) as s: + with ctx.wrap_socket(s, server_hostname=self.hostname) as ts: + ts.settimeout(dns.query._remaining(expiration)) + ts.do_handshake() + cert = ts.getpeercert() + return self.ddr_check_certificate(cert) + + async def ddr_tls_check_async(self, lifetime, backend=None): + if backend is None: + backend = dns.asyncbackend.get_default_backend() + ctx = self.make_tls_context() + expiration = time.time() + lifetime + async with await backend.make_socket( + dns.inet.af_for_address(self.bootstrap_address), + socket.SOCK_STREAM, + 0, + None, + (self.bootstrap_address, self.port), + lifetime, + ctx, + self.hostname, + ) as ts: + cert = await ts.getpeercert(dns.query._remaining(expiration)) + return self.ddr_check_certificate(cert) + + +def _extract_nameservers_from_svcb(answer): + bootstrap_address = answer.nameserver + if not dns.inet.is_address(bootstrap_address): + return [] + infos = [] + for rr in answer.rrset.processing_order(): + nameservers = [] + param = rr.params.get(dns.rdtypes.svcbbase.ParamKey.ALPN) + if param is None: + continue + alpns = set(param.ids) + host = rr.target.to_text(omit_final_dot=True) + port = None + param = rr.params.get(dns.rdtypes.svcbbase.ParamKey.PORT) + if param is not None: + port = param.port + # For now we ignore address hints and address resolution and always use the + # bootstrap address + if b"h2" in alpns: + param = rr.params.get(dns.rdtypes.svcbbase.ParamKey.DOHPATH) + if param is None or not param.value.endswith(b"{?dns}"): + continue + path = param.value[:-6].decode() + if not path.startswith("/"): + path = "/" + path + if port is None: + port = 443 + url = f"https://{host}:{port}{path}" + # check the URL + try: + urlparse(url) + nameservers.append(dns.nameserver.DoHNameserver(url, bootstrap_address)) + except Exception: + # continue processing other ALPN types + pass + if b"dot" in alpns: + if port is None: + port = 853 + nameservers.append( + dns.nameserver.DoTNameserver(bootstrap_address, port, host) + ) + if b"doq" in alpns: + if port is None: + port = 853 + nameservers.append( + dns.nameserver.DoQNameserver(bootstrap_address, port, True, host) + ) + if len(nameservers) > 0: + infos.append(_SVCBInfo(bootstrap_address, port, host, nameservers)) + return infos + + +def _get_nameservers_sync(answer, lifetime): + """Return a list of TLS-validated resolver nameservers extracted from an SVCB + answer.""" + nameservers = [] + infos = _extract_nameservers_from_svcb(answer) + for info in infos: + try: + if info.ddr_tls_check_sync(lifetime): + nameservers.extend(info.nameservers) + except Exception: + pass + return nameservers + + +async def _get_nameservers_async(answer, lifetime): + """Return a list of TLS-validated resolver nameservers extracted from an SVCB + answer.""" + nameservers = [] + infos = _extract_nameservers_from_svcb(answer) + for info in infos: + try: + if await info.ddr_tls_check_async(lifetime): + nameservers.extend(info.nameservers) + except Exception: + pass + return nameservers diff --git a/lib/dns/_immutable_ctx.py b/lib/dns/_immutable_ctx.py index 63c0a2d3..ae7a33bf 100644 --- a/lib/dns/_immutable_ctx.py +++ b/lib/dns/_immutable_ctx.py @@ -7,7 +7,6 @@ import contextvars import inspect - _in__init__ = contextvars.ContextVar("_immutable_in__init__", default=False) diff --git a/lib/dns/_trio_backend.py b/lib/dns/_trio_backend.py index b0c02103..4d9fb820 100644 --- a/lib/dns/_trio_backend.py +++ b/lib/dns/_trio_backend.py @@ -3,6 +3,7 @@ """trio async I/O library query support""" import socket + import trio import trio.socket # type: ignore @@ -12,7 +13,7 @@ import dns.inet def _maybe_timeout(timeout): - if timeout: + if timeout is not None: return trio.move_on_after(timeout) else: return dns._asyncbackend.NullContext() @@ -50,6 +51,9 @@ class DatagramSocket(dns._asyncbackend.DatagramSocket): async def getsockname(self): return self.socket.getsockname() + async def getpeercert(self, timeout): + raise NotImplementedError + class StreamSocket(dns._asyncbackend.StreamSocket): def __init__(self, family, stream, tls=False): @@ -82,6 +86,100 @@ class StreamSocket(dns._asyncbackend.StreamSocket): else: return self.stream.socket.getsockname() + async def getpeercert(self, timeout): + if self.tls: + with _maybe_timeout(timeout): + await self.stream.do_handshake() + return self.stream.getpeercert() + else: + raise NotImplementedError + + +try: + import httpcore + import httpcore._backends.trio + import httpx + + _CoreAsyncNetworkBackend = httpcore.AsyncNetworkBackend + _CoreTrioStream = httpcore._backends.trio.TrioStream + + from dns.query import _compute_times, _expiration_for_this_attempt, _remaining + + class _NetworkBackend(_CoreAsyncNetworkBackend): + def __init__(self, resolver, local_port, bootstrap_address, family): + super().__init__() + self._local_port = local_port + self._resolver = resolver + self._bootstrap_address = bootstrap_address + self._family = family + + async def connect_tcp( + self, host, port, timeout, local_address, socket_options=None + ): # pylint: disable=signature-differs + addresses = [] + _, expiration = _compute_times(timeout) + if dns.inet.is_address(host): + addresses.append(host) + elif self._bootstrap_address is not None: + addresses.append(self._bootstrap_address) + else: + timeout = _remaining(expiration) + family = self._family + if local_address: + family = dns.inet.af_for_address(local_address) + answers = await self._resolver.resolve_name( + host, family=family, lifetime=timeout + ) + addresses = answers.addresses() + for address in addresses: + try: + af = dns.inet.af_for_address(address) + if local_address is not None or self._local_port != 0: + source = (local_address, self._local_port) + else: + source = None + destination = (address, port) + attempt_expiration = _expiration_for_this_attempt(2.0, expiration) + timeout = _remaining(attempt_expiration) + sock = await Backend().make_socket( + af, socket.SOCK_STREAM, 0, source, destination, timeout + ) + return _CoreTrioStream(sock.stream) + except Exception: + continue + raise httpcore.ConnectError + + async def connect_unix_socket( + self, path, timeout, socket_options=None + ): # pylint: disable=signature-differs + raise NotImplementedError + + async def sleep(self, seconds): # pylint: disable=signature-differs + await trio.sleep(seconds) + + class _HTTPTransport(httpx.AsyncHTTPTransport): + def __init__( + self, + *args, + local_port=0, + bootstrap_address=None, + resolver=None, + family=socket.AF_UNSPEC, + **kwargs, + ): + if resolver is None: + # pylint: disable=import-outside-toplevel,redefined-outer-name + import dns.asyncresolver + + resolver = dns.asyncresolver.Resolver() + super().__init__(*args, **kwargs) + self._pool._network_backend = _NetworkBackend( + resolver, local_port, bootstrap_address, family + ) + +except ImportError: + _HTTPTransport = dns._asyncbackend.NullTransport # type: ignore + class Backend(dns._asyncbackend.Backend): def name(self): @@ -104,8 +202,14 @@ class Backend(dns._asyncbackend.Backend): if source: await s.bind(_lltuple(source, af)) if socktype == socket.SOCK_STREAM: + connected = False with _maybe_timeout(timeout): await s.connect(_lltuple(destination, af)) + connected = True + if not connected: + raise dns.exception.Timeout( + timeout=timeout + ) # lgtm[py/unreachable-statement] except Exception: # pragma: no cover s.close() raise @@ -130,3 +234,13 @@ class Backend(dns._asyncbackend.Backend): async def sleep(self, interval): await trio.sleep(interval) + + def get_transport_class(self): + return _HTTPTransport + + async def wait_for(self, awaitable, timeout): + with _maybe_timeout(timeout): + return await awaitable + raise dns.exception.Timeout( + timeout=timeout + ) # pragma: no cover lgtm[py/unreachable-statement] diff --git a/lib/dns/asyncbackend.py b/lib/dns/asyncbackend.py index c7565a99..07d50e1e 100644 --- a/lib/dns/asyncbackend.py +++ b/lib/dns/asyncbackend.py @@ -5,13 +5,12 @@ from typing import Dict import dns.exception # pylint: disable=unused-import - -from dns._asyncbackend import ( - Socket, - DatagramSocket, - StreamSocket, +from dns._asyncbackend import ( # noqa: F401 lgtm[py/unused-import] Backend, -) # noqa: F401 lgtm[py/unused-import] + DatagramSocket, + Socket, + StreamSocket, +) # pylint: enable=unused-import @@ -30,8 +29,8 @@ class AsyncLibraryNotFoundError(dns.exception.DNSException): def get_backend(name: str) -> Backend: """Get the specified asynchronous backend. - *name*, a ``str``, the name of the backend. Currently the "trio", - "curio", and "asyncio" backends are available. + *name*, a ``str``, the name of the backend. Currently the "trio" + and "asyncio" backends are available. Raises NotImplementError if an unknown backend name is specified. """ @@ -43,10 +42,6 @@ def get_backend(name: str) -> Backend: import dns._trio_backend backend = dns._trio_backend.Backend() - elif name == "curio": - import dns._curio_backend - - backend = dns._curio_backend.Backend() elif name == "asyncio": import dns._asyncio_backend @@ -73,9 +68,7 @@ def sniff() -> str: try: return sniffio.current_async_library() except sniffio.AsyncLibraryNotFoundError: - raise AsyncLibraryNotFoundError( - "sniffio cannot determine " + "async library" - ) + raise AsyncLibraryNotFoundError("sniffio cannot determine async library") except ImportError: import asyncio diff --git a/lib/dns/asyncquery.py b/lib/dns/asyncquery.py index 459c611d..ecf9c1a5 100644 --- a/lib/dns/asyncquery.py +++ b/lib/dns/asyncquery.py @@ -17,39 +17,38 @@ """Talk to a DNS server.""" -from typing import Any, Dict, Optional, Tuple, Union - import base64 import contextlib import socket import struct import time +from typing import Any, Dict, Optional, Tuple, Union import dns.asyncbackend import dns.exception import dns.inet -import dns.name import dns.message +import dns.name import dns.quic import dns.rcode import dns.rdataclass import dns.rdatatype import dns.transaction - from dns._asyncbackend import NullContext from dns.query import ( - _compute_times, - _matches_destination, BadResponse, - ssl, - UDPMode, - _have_httpx, - _have_http2, NoDOH, NoDOQ, + UDPMode, + _compute_times, + _have_http2, + _matches_destination, + _remaining, + have_doh, + ssl, ) -if _have_httpx: +if have_doh: import httpx # for brevity @@ -73,7 +72,7 @@ def _source_tuple(af, address, port): def _timeout(expiration, now=None): - if expiration: + if expiration is not None: if not now: now = time.time() return max(expiration - now, 0) @@ -445,9 +444,6 @@ async def tls( ssl_context.minimum_version = ssl.TLSVersion.TLSv1_2 if server_hostname is None: ssl_context.check_hostname = False - else: - ssl_context = None - server_hostname = None af = dns.inet.af_for_address(where) stuple = _source_tuple(af, source, source_port) dtuple = (where, port) @@ -495,6 +491,9 @@ async def https( path: str = "/dns-query", post: bool = True, verify: Union[bool, str] = True, + bootstrap_address: Optional[str] = None, + resolver: Optional["dns.asyncresolver.Resolver"] = None, + family: Optional[int] = socket.AF_UNSPEC, ) -> dns.message.Message: """Return the response obtained after sending a query via DNS-over-HTTPS. @@ -508,8 +507,10 @@ async def https( parameters, exceptions, and return type of this method. """ - if not _have_httpx: - raise NoDOH("httpx is not available.") # pragma: no cover + if not have_doh: + raise NoDOH # pragma: no cover + if client and not isinstance(client, httpx.AsyncClient): + raise ValueError("session parameter must be an httpx.AsyncClient") wire = q.to_wire() try: @@ -518,15 +519,32 @@ async def https( af = None transport = None headers = {"accept": "application/dns-message"} - if af is not None: + if af is not None and dns.inet.is_address(where): if af == socket.AF_INET: url = "https://{}:{}{}".format(where, port, path) elif af == socket.AF_INET6: url = "https://[{}]:{}{}".format(where, port, path) else: url = where - if source is not None: - transport = httpx.AsyncHTTPTransport(local_address=source[0]) + + backend = dns.asyncbackend.get_default_backend() + + if source is None: + local_address = None + local_port = 0 + else: + local_address = source + local_port = source_port + transport = backend.get_transport_class()( + local_address=local_address, + http1=True, + http2=_have_http2, + verify=verify, + local_port=local_port, + bootstrap_address=bootstrap_address, + resolver=resolver, + family=family, + ) if client: cm: contextlib.AbstractAsyncContextManager = NullContext(client) @@ -545,14 +563,14 @@ async def https( "content-length": str(len(wire)), } ) - response = await the_client.post( - url, headers=headers, content=wire, timeout=timeout + response = await backend.wait_for( + the_client.post(url, headers=headers, content=wire), timeout ) else: wire = base64.urlsafe_b64encode(wire).rstrip(b"=") twire = wire.decode() # httpx does a repr() if we give it bytes - response = await the_client.get( - url, headers=headers, timeout=timeout, params={"dns": twire} + response = await backend.wait_for( + the_client.get(url, headers=headers, params={"dns": twire}), timeout ) # see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH @@ -690,6 +708,7 @@ async def quic( connection: Optional[dns.quic.AsyncQuicConnection] = None, verify: Union[bool, str] = True, backend: Optional[dns.asyncbackend.Backend] = None, + server_hostname: Optional[str] = None, ) -> dns.message.Message: """Return the response obtained after sending an asynchronous query via DNS-over-QUIC. @@ -715,14 +734,16 @@ async def quic( (cfactory, mfactory) = dns.quic.factories_for_backend(backend) async with cfactory() as context: - async with mfactory(context, verify_mode=verify) as the_manager: + async with mfactory( + context, verify_mode=verify, server_name=server_hostname + ) as the_manager: if not connection: the_connection = the_manager.connect(where, port, source, source_port) - start = time.time() - stream = await the_connection.make_stream() + (start, expiration) = _compute_times(timeout) + stream = await the_connection.make_stream(timeout) async with stream: await stream.send(wire, True) - wire = await stream.receive(timeout) + wire = await stream.receive(_remaining(expiration)) finish = time.time() r = dns.message.from_wire( wire, diff --git a/lib/dns/asyncresolver.py b/lib/dns/asyncresolver.py index 506530e2..8f5e062a 100644 --- a/lib/dns/asyncresolver.py +++ b/lib/dns/asyncresolver.py @@ -17,10 +17,11 @@ """Asynchronous DNS stub resolver.""" -from typing import Any, Dict, Optional, Union - +import socket import time +from typing import Any, Dict, List, Optional, Union +import dns._ddr import dns.asyncbackend import dns.asyncquery import dns.exception @@ -31,8 +32,7 @@ import dns.rdatatype import dns.resolver # lgtm[py/import-and-import-from] # import some resolver symbols for brevity -from dns.resolver import NXDOMAIN, NoAnswer, NotAbsolute, NoRootSOA - +from dns.resolver import NXDOMAIN, NoAnswer, NoRootSOA, NotAbsolute # for indentation purposes below _udp = dns.asyncquery.udp @@ -83,37 +83,19 @@ class Resolver(dns.resolver.BaseResolver): assert request is not None # needed for type checking done = False while not done: - (nameserver, port, tcp, backoff) = resolution.next_nameserver() + (nameserver, tcp, backoff) = resolution.next_nameserver() if backoff: await backend.sleep(backoff) timeout = self._compute_timeout(start, lifetime, resolution.errors) try: - if dns.inet.is_address(nameserver): - if tcp: - response = await _tcp( - request, - nameserver, - timeout, - port, - source, - source_port, - backend=backend, - ) - else: - response = await _udp( - request, - nameserver, - timeout, - port, - source, - source_port, - raise_on_truncation=True, - backend=backend, - ) - else: - response = await dns.asyncquery.https( - request, nameserver, timeout=timeout - ) + response = await nameserver.async_query( + request, + timeout=timeout, + source=source, + source_port=source_port, + max_size=tcp, + backend=backend, + ) except Exception as ex: (_, done) = resolution.query_result(None, ex) continue @@ -153,6 +135,73 @@ class Resolver(dns.resolver.BaseResolver): dns.reversename.from_address(ipaddr), *args, **modified_kwargs ) + async def resolve_name( + self, + name: Union[dns.name.Name, str], + family: int = socket.AF_UNSPEC, + **kwargs: Any, + ) -> dns.resolver.HostAnswers: + """Use an asynchronous resolver to query for address records. + + This utilizes the resolve() method to perform A and/or AAAA lookups on + the specified name. + + *qname*, a ``dns.name.Name`` or ``str``, the name to resolve. + + *family*, an ``int``, the address family. If socket.AF_UNSPEC + (the default), both A and AAAA records will be retrieved. + + All other arguments that can be passed to the resolve() function + except for rdtype and rdclass are also supported by this + function. + """ + # We make a modified kwargs for type checking happiness, as otherwise + # we get a legit warning about possibly having rdtype and rdclass + # in the kwargs more than once. + modified_kwargs: Dict[str, Any] = {} + modified_kwargs.update(kwargs) + modified_kwargs.pop("rdtype", None) + modified_kwargs["rdclass"] = dns.rdataclass.IN + + if family == socket.AF_INET: + v4 = await self.resolve(name, dns.rdatatype.A, **modified_kwargs) + return dns.resolver.HostAnswers.make(v4=v4) + elif family == socket.AF_INET6: + v6 = await self.resolve(name, dns.rdatatype.AAAA, **modified_kwargs) + return dns.resolver.HostAnswers.make(v6=v6) + elif family != socket.AF_UNSPEC: + raise NotImplementedError(f"unknown address family {family}") + + raise_on_no_answer = modified_kwargs.pop("raise_on_no_answer", True) + lifetime = modified_kwargs.pop("lifetime", None) + start = time.time() + v6 = await self.resolve( + name, + dns.rdatatype.AAAA, + raise_on_no_answer=False, + lifetime=self._compute_timeout(start, lifetime), + **modified_kwargs, + ) + # Note that setting name ensures we query the same name + # for A as we did for AAAA. (This is just in case search lists + # are active by default in the resolver configuration and + # we might be talking to a server that says NXDOMAIN when it + # wants to say NOERROR no data. + name = v6.qname + v4 = await self.resolve( + name, + dns.rdatatype.A, + raise_on_no_answer=False, + lifetime=self._compute_timeout(start, lifetime), + **modified_kwargs, + ) + answers = dns.resolver.HostAnswers.make( + v6=v6, v4=v4, add_empty=not raise_on_no_answer + ) + if not answers: + raise NoAnswer(response=v6.response) + return answers + # pylint: disable=redefined-outer-name async def canonical_name(self, name: Union[dns.name.Name, str]) -> dns.name.Name: @@ -176,6 +225,37 @@ class Resolver(dns.resolver.BaseResolver): canonical_name = e.canonical_name return canonical_name + async def try_ddr(self, lifetime: float = 5.0) -> None: + """Try to update the resolver's nameservers using Discovery of Designated + Resolvers (DDR). If successful, the resolver will subsequently use + DNS-over-HTTPS or DNS-over-TLS for future queries. + + *lifetime*, a float, is the maximum time to spend attempting DDR. The default + is 5 seconds. + + If the SVCB query is successful and results in a non-empty list of nameservers, + then the resolver's nameservers are set to the returned servers in priority + order. + + The current implementation does not use any address hints from the SVCB record, + nor does it resolve addresses for the SCVB target name, rather it assumes that + the bootstrap nameserver will always be one of the addresses and uses it. + A future revision to the code may offer fuller support. The code verifies that + the bootstrap nameserver is in the Subject Alternative Name field of the + TLS certficate. + """ + try: + expiration = time.time() + lifetime + answer = await self.resolve( + dns._ddr._local_resolver_name, "svcb", lifetime=lifetime + ) + timeout = dns.query._remaining(expiration) + nameservers = await dns._ddr._get_nameservers_async(answer, timeout) + if len(nameservers) > 0: + self.nameservers = nameservers + except Exception: + pass + default_resolver = None @@ -246,6 +326,18 @@ async def resolve_address( return await get_default_resolver().resolve_address(ipaddr, *args, **kwargs) +async def resolve_name( + name: Union[dns.name.Name, str], family: int = socket.AF_UNSPEC, **kwargs: Any +) -> dns.resolver.HostAnswers: + """Use a resolver to asynchronously query for address records. + + See :py:func:`dns.asyncresolver.Resolver.resolve_name` for more + information on the parameters. + """ + + return await get_default_resolver().resolve_name(name, family, **kwargs) + + async def canonical_name(name: Union[dns.name.Name, str]) -> dns.name.Name: """Determine the canonical name of *name*. @@ -256,6 +348,16 @@ async def canonical_name(name: Union[dns.name.Name, str]) -> dns.name.Name: return await get_default_resolver().canonical_name(name) +async def try_ddr(timeout: float = 5.0) -> None: + """Try to update the default resolver's nameservers using Discovery of Designated + Resolvers (DDR). If successful, the resolver will subsequently use + DNS-over-HTTPS or DNS-over-TLS for future queries. + + See :py:func:`dns.resolver.Resolver.try_ddr` for more information. + """ + return await get_default_resolver().try_ddr(timeout) + + async def zone_for_name( name: Union[dns.name.Name, str], rdclass: dns.rdataclass.RdataClass = dns.rdataclass.IN, @@ -290,3 +392,84 @@ async def zone_for_name( name = name.parent() except dns.name.NoParent: # pragma: no cover raise NoRootSOA + + +async def make_resolver_at( + where: Union[dns.name.Name, str], + port: int = 53, + family: int = socket.AF_UNSPEC, + resolver: Optional[Resolver] = None, +) -> Resolver: + """Make a stub resolver using the specified destination as the full resolver. + + *where*, a ``dns.name.Name`` or ``str`` the domain name or IP address of the + full resolver. + + *port*, an ``int``, the port to use. If not specified, the default is 53. + + *family*, an ``int``, the address family to use. This parameter is used if + *where* is not an address. The default is ``socket.AF_UNSPEC`` in which case + the first address returned by ``resolve_name()`` will be used, otherwise the + first address of the specified family will be used. + + *resolver*, a ``dns.asyncresolver.Resolver`` or ``None``, the resolver to use for + resolution of hostnames. If not specified, the default resolver will be used. + + Returns a ``dns.resolver.Resolver`` or raises an exception. + """ + if resolver is None: + resolver = get_default_resolver() + nameservers: List[Union[str, dns.nameserver.Nameserver]] = [] + if isinstance(where, str) and dns.inet.is_address(where): + nameservers.append(dns.nameserver.Do53Nameserver(where, port)) + else: + answers = await resolver.resolve_name(where, family) + for address in answers.addresses(): + nameservers.append(dns.nameserver.Do53Nameserver(address, port)) + res = dns.asyncresolver.Resolver(configure=False) + res.nameservers = nameservers + return res + + +async def resolve_at( + where: Union[dns.name.Name, str], + qname: Union[dns.name.Name, str], + rdtype: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.A, + rdclass: Union[dns.rdataclass.RdataClass, str] = dns.rdataclass.IN, + tcp: bool = False, + source: Optional[str] = None, + raise_on_no_answer: bool = True, + source_port: int = 0, + lifetime: Optional[float] = None, + search: Optional[bool] = None, + backend: Optional[dns.asyncbackend.Backend] = None, + port: int = 53, + family: int = socket.AF_UNSPEC, + resolver: Optional[Resolver] = None, +) -> dns.resolver.Answer: + """Query nameservers to find the answer to the question. + + This is a convenience function that calls ``dns.asyncresolver.make_resolver_at()`` + to make a resolver, and then uses it to resolve the query. + + See ``dns.asyncresolver.Resolver.resolve`` for more information on the resolution + parameters, and ``dns.asyncresolver.make_resolver_at`` for information about the + resolver parameters *where*, *port*, *family*, and *resolver*. + + If making more than one query, it is more efficient to call + ``dns.asyncresolver.make_resolver_at()`` and then use that resolver for the queries + instead of calling ``resolve_at()`` multiple times. + """ + res = await make_resolver_at(where, port, family, resolver) + return await res.resolve( + qname, + rdtype, + rdclass, + tcp, + source, + raise_on_no_answer, + source_port, + lifetime, + search, + backend, + ) diff --git a/lib/dns/dnssec.py b/lib/dns/dnssec.py index 5dc26223..2949f619 100644 --- a/lib/dns/dnssec.py +++ b/lib/dns/dnssec.py @@ -17,50 +17,44 @@ """Common DNSSEC-related functions and constants.""" -from typing import Any, cast, Dict, List, Optional, Set, Tuple, Union +import base64 +import contextlib +import functools import hashlib -import math import struct import time -import base64 from datetime import datetime - -from dns.dnssectypes import Algorithm, DSDigest, NSEC3Hash +from typing import Callable, Dict, List, Optional, Set, Tuple, Union, cast import dns.exception import dns.name import dns.node -import dns.rdataset import dns.rdata -import dns.rdatatype import dns.rdataclass +import dns.rdataset +import dns.rdatatype import dns.rrset +import dns.transaction +import dns.zone +from dns.dnssectypes import Algorithm, DSDigest, NSEC3Hash +from dns.exception import ( # pylint: disable=W0611 + AlgorithmKeyMismatch, + DeniedByPolicy, + UnsupportedAlgorithm, + ValidationFailure, +) from dns.rdtypes.ANY.CDNSKEY import CDNSKEY from dns.rdtypes.ANY.CDS import CDS from dns.rdtypes.ANY.DNSKEY import DNSKEY from dns.rdtypes.ANY.DS import DS +from dns.rdtypes.ANY.NSEC import NSEC, Bitmap +from dns.rdtypes.ANY.NSEC3PARAM import NSEC3PARAM from dns.rdtypes.ANY.RRSIG import RRSIG, sigtime_to_posixtime from dns.rdtypes.dnskeybase import Flag - -class UnsupportedAlgorithm(dns.exception.DNSException): - """The DNSSEC algorithm is not supported.""" - - -class AlgorithmKeyMismatch(UnsupportedAlgorithm): - """The DNSSEC algorithm is not supported for the given key type.""" - - -class ValidationFailure(dns.exception.DNSException): - """The DNSSEC signature is invalid.""" - - -class DeniedByPolicy(dns.exception.DNSException): - """Denied by DNSSEC policy.""" - - PublicKey = Union[ + "GenericPublicKey", "rsa.RSAPublicKey", "ec.EllipticCurvePublicKey", "ed25519.Ed25519PublicKey", @@ -68,12 +62,15 @@ PublicKey = Union[ ] PrivateKey = Union[ + "GenericPrivateKey", "rsa.RSAPrivateKey", "ec.EllipticCurvePrivateKey", "ed25519.Ed25519PrivateKey", "ed448.Ed448PrivateKey", ] +RRsetSigner = Callable[[dns.transaction.Transaction, dns.rrset.RRset], None] + def algorithm_from_text(text: str) -> Algorithm: """Convert text into a DNSSEC algorithm value. @@ -308,113 +305,13 @@ def _find_candidate_keys( return [ cast(DNSKEY, rd) for rd in rdataset - if rd.algorithm == rrsig.algorithm and key_id(rd) == rrsig.key_tag + if rd.algorithm == rrsig.algorithm + and key_id(rd) == rrsig.key_tag + and (rd.flags & Flag.ZONE) == Flag.ZONE # RFC 4034 2.1.1 + and rd.protocol == 3 # RFC 4034 2.1.2 ] -def _is_rsa(algorithm: int) -> bool: - return algorithm in ( - Algorithm.RSAMD5, - Algorithm.RSASHA1, - Algorithm.RSASHA1NSEC3SHA1, - Algorithm.RSASHA256, - Algorithm.RSASHA512, - ) - - -def _is_dsa(algorithm: int) -> bool: - return algorithm in (Algorithm.DSA, Algorithm.DSANSEC3SHA1) - - -def _is_ecdsa(algorithm: int) -> bool: - return algorithm in (Algorithm.ECDSAP256SHA256, Algorithm.ECDSAP384SHA384) - - -def _is_eddsa(algorithm: int) -> bool: - return algorithm in (Algorithm.ED25519, Algorithm.ED448) - - -def _is_gost(algorithm: int) -> bool: - return algorithm == Algorithm.ECCGOST - - -def _is_md5(algorithm: int) -> bool: - return algorithm == Algorithm.RSAMD5 - - -def _is_sha1(algorithm: int) -> bool: - return algorithm in ( - Algorithm.DSA, - Algorithm.RSASHA1, - Algorithm.DSANSEC3SHA1, - Algorithm.RSASHA1NSEC3SHA1, - ) - - -def _is_sha256(algorithm: int) -> bool: - return algorithm in (Algorithm.RSASHA256, Algorithm.ECDSAP256SHA256) - - -def _is_sha384(algorithm: int) -> bool: - return algorithm == Algorithm.ECDSAP384SHA384 - - -def _is_sha512(algorithm: int) -> bool: - return algorithm == Algorithm.RSASHA512 - - -def _ensure_algorithm_key_combination(algorithm: int, key: PublicKey) -> None: - """Ensure algorithm is valid for key type, throwing an exception on - mismatch.""" - if isinstance(key, rsa.RSAPublicKey): - if _is_rsa(algorithm): - return - raise AlgorithmKeyMismatch('algorithm "%s" not valid for RSA key' % algorithm) - if isinstance(key, dsa.DSAPublicKey): - if _is_dsa(algorithm): - return - raise AlgorithmKeyMismatch('algorithm "%s" not valid for DSA key' % algorithm) - if isinstance(key, ec.EllipticCurvePublicKey): - if _is_ecdsa(algorithm): - return - raise AlgorithmKeyMismatch('algorithm "%s" not valid for ECDSA key' % algorithm) - if isinstance(key, ed25519.Ed25519PublicKey): - if algorithm == Algorithm.ED25519: - return - raise AlgorithmKeyMismatch( - 'algorithm "%s" not valid for ED25519 key' % algorithm - ) - if isinstance(key, ed448.Ed448PublicKey): - if algorithm == Algorithm.ED448: - return - raise AlgorithmKeyMismatch('algorithm "%s" not valid for ED448 key' % algorithm) - - raise TypeError("unsupported key type") - - -def _make_hash(algorithm: int) -> Any: - if _is_md5(algorithm): - return hashes.MD5() - if _is_sha1(algorithm): - return hashes.SHA1() - if _is_sha256(algorithm): - return hashes.SHA256() - if _is_sha384(algorithm): - return hashes.SHA384() - if _is_sha512(algorithm): - return hashes.SHA512() - if algorithm == Algorithm.ED25519: - return hashes.SHA512() - if algorithm == Algorithm.ED448: - return hashes.SHAKE256(114) - - raise ValidationFailure("unknown hash for algorithm %u" % algorithm) - - -def _bytes_to_long(b: bytes) -> int: - return int.from_bytes(b, "big") - - def _get_rrname_rdataset( rrset: Union[dns.rrset.RRset, Tuple[dns.name.Name, dns.rdataset.Rdataset]], ) -> Tuple[dns.name.Name, dns.rdataset.Rdataset]: @@ -424,85 +321,13 @@ def _get_rrname_rdataset( return rrset.name, rrset -def _validate_signature(sig: bytes, data: bytes, key: DNSKEY, chosen_hash: Any) -> None: - keyptr: bytes - if _is_rsa(key.algorithm): - # we ignore because mypy is confused and thinks key.key is a str for unknown - # reasons. - keyptr = key.key - (bytes_,) = struct.unpack("!B", keyptr[0:1]) - keyptr = keyptr[1:] - if bytes_ == 0: - (bytes_,) = struct.unpack("!H", keyptr[0:2]) - keyptr = keyptr[2:] - rsa_e = keyptr[0:bytes_] - rsa_n = keyptr[bytes_:] - try: - rsa_public_key = rsa.RSAPublicNumbers( - _bytes_to_long(rsa_e), _bytes_to_long(rsa_n) - ).public_key(default_backend()) - except ValueError: - raise ValidationFailure("invalid public key") - rsa_public_key.verify(sig, data, padding.PKCS1v15(), chosen_hash) - elif _is_dsa(key.algorithm): - keyptr = key.key - (t,) = struct.unpack("!B", keyptr[0:1]) - keyptr = keyptr[1:] - octets = 64 + t * 8 - dsa_q = keyptr[0:20] - keyptr = keyptr[20:] - dsa_p = keyptr[0:octets] - keyptr = keyptr[octets:] - dsa_g = keyptr[0:octets] - keyptr = keyptr[octets:] - dsa_y = keyptr[0:octets] - try: - dsa_public_key = dsa.DSAPublicNumbers( # type: ignore - _bytes_to_long(dsa_y), - dsa.DSAParameterNumbers( - _bytes_to_long(dsa_p), _bytes_to_long(dsa_q), _bytes_to_long(dsa_g) - ), - ).public_key(default_backend()) - except ValueError: - raise ValidationFailure("invalid public key") - dsa_public_key.verify(sig, data, chosen_hash) - elif _is_ecdsa(key.algorithm): - keyptr = key.key - curve: Any - if key.algorithm == Algorithm.ECDSAP256SHA256: - curve = ec.SECP256R1() - octets = 32 - else: - curve = ec.SECP384R1() - octets = 48 - ecdsa_x = keyptr[0:octets] - ecdsa_y = keyptr[octets : octets * 2] - try: - ecdsa_public_key = ec.EllipticCurvePublicNumbers( - curve=curve, x=_bytes_to_long(ecdsa_x), y=_bytes_to_long(ecdsa_y) - ).public_key(default_backend()) - except ValueError: - raise ValidationFailure("invalid public key") - ecdsa_public_key.verify(sig, data, ec.ECDSA(chosen_hash)) - elif _is_eddsa(key.algorithm): - keyptr = key.key - loader: Any - if key.algorithm == Algorithm.ED25519: - loader = ed25519.Ed25519PublicKey - else: - loader = ed448.Ed448PublicKey - try: - eddsa_public_key = loader.from_public_bytes(keyptr) - except ValueError: - raise ValidationFailure("invalid public key") - eddsa_public_key.verify(sig, data) - elif _is_gost(key.algorithm): - raise UnsupportedAlgorithm( - 'algorithm "%s" not supported by dnspython' - % algorithm_to_text(key.algorithm) - ) - else: - raise ValidationFailure("unknown algorithm %u" % key.algorithm) +def _validate_signature(sig: bytes, data: bytes, key: DNSKEY) -> None: + public_cls = get_algorithm_cls_from_dnskey(key).public_cls + try: + public_key = public_cls.from_dnskey(key) + except ValueError: + raise ValidationFailure("invalid public key") + public_key.verify(sig, data) def _validate_rrsig( @@ -559,29 +384,13 @@ def _validate_rrsig( if rrsig.inception > now: raise ValidationFailure("not yet valid") - if _is_dsa(rrsig.algorithm): - sig_r = rrsig.signature[1:21] - sig_s = rrsig.signature[21:] - sig = utils.encode_dss_signature(_bytes_to_long(sig_r), _bytes_to_long(sig_s)) - elif _is_ecdsa(rrsig.algorithm): - if rrsig.algorithm == Algorithm.ECDSAP256SHA256: - octets = 32 - else: - octets = 48 - sig_r = rrsig.signature[0:octets] - sig_s = rrsig.signature[octets:] - sig = utils.encode_dss_signature(_bytes_to_long(sig_r), _bytes_to_long(sig_s)) - else: - sig = rrsig.signature - data = _make_rrsig_signature_data(rrset, rrsig, origin) - chosen_hash = _make_hash(rrsig.algorithm) for candidate_key in candidate_keys: if not policy.ok_to_validate(candidate_key): continue try: - _validate_signature(sig, data, candidate_key, chosen_hash) + _validate_signature(rrsig.signature, data, candidate_key) return except (InvalidSignature, ValidationFailure): # this happens on an individual validation failure @@ -673,6 +482,7 @@ def _sign( lifetime: Optional[int] = None, verify: bool = False, policy: Optional[Policy] = None, + origin: Optional[dns.name.Name] = None, ) -> RRSIG: """Sign RRset using private key. @@ -708,6 +518,10 @@ def _sign( *policy*, a ``dns.dnssec.Policy`` or ``None``. If ``None``, the default policy, ``dns.dnssec.default_policy`` is used; this policy defaults to that of RFC 8624. + *origin*, a ``dns.name.Name`` or ``None``. If ``None``, the default, then all + names in the rrset (including its owner name) must be absolute; otherwise the + specified origin will be used to make names absolute when signing. + Raises ``DeniedByPolicy`` if the signature is denied by policy. """ @@ -735,16 +549,26 @@ def _sign( if expiration is not None: rrsig_expiration = to_timestamp(expiration) elif lifetime is not None: - rrsig_expiration = int(time.time()) + lifetime + rrsig_expiration = rrsig_inception + lifetime else: raise ValueError("expiration or lifetime must be specified") + # Derelativize now because we need a correct labels length for the + # rrsig_template. + if origin is not None: + rrname = rrname.derelativize(origin) + labels = len(rrname) - 1 + + # Adjust labels appropriately for wildcards. + if rrname.is_wild(): + labels -= 1 + rrsig_template = RRSIG( rdclass=rdclass, rdtype=dns.rdatatype.RRSIG, type_covered=rdtype, algorithm=dnskey.algorithm, - labels=len(rrname) - 1, + labels=labels, original_ttl=original_ttl, expiration=rrsig_expiration, inception=rrsig_inception, @@ -753,63 +577,18 @@ def _sign( signature=b"", ) - data = dns.dnssec._make_rrsig_signature_data(rrset, rrsig_template) - chosen_hash = _make_hash(rrsig_template.algorithm) - signature = None + data = dns.dnssec._make_rrsig_signature_data(rrset, rrsig_template, origin) - if isinstance(private_key, rsa.RSAPrivateKey): - if not _is_rsa(dnskey.algorithm): - raise ValueError("Invalid DNSKEY algorithm for RSA key") - signature = private_key.sign(data, padding.PKCS1v15(), chosen_hash) - if verify: - private_key.public_key().verify( - signature, data, padding.PKCS1v15(), chosen_hash - ) - elif isinstance(private_key, dsa.DSAPrivateKey): - if not _is_dsa(dnskey.algorithm): - raise ValueError("Invalid DNSKEY algorithm for DSA key") - public_dsa_key = private_key.public_key() - if public_dsa_key.key_size > 1024: - raise ValueError("DSA key size overflow") - der_signature = private_key.sign(data, chosen_hash) - if verify: - public_dsa_key.verify(der_signature, data, chosen_hash) - dsa_r, dsa_s = utils.decode_dss_signature(der_signature) - dsa_t = (public_dsa_key.key_size // 8 - 64) // 8 - octets = 20 - signature = ( - struct.pack("!B", dsa_t) - + int.to_bytes(dsa_r, length=octets, byteorder="big") - + int.to_bytes(dsa_s, length=octets, byteorder="big") - ) - elif isinstance(private_key, ec.EllipticCurvePrivateKey): - if not _is_ecdsa(dnskey.algorithm): - raise ValueError("Invalid DNSKEY algorithm for EC key") - der_signature = private_key.sign(data, ec.ECDSA(chosen_hash)) - if verify: - private_key.public_key().verify(der_signature, data, ec.ECDSA(chosen_hash)) - if dnskey.algorithm == Algorithm.ECDSAP256SHA256: - octets = 32 - else: - octets = 48 - dsa_r, dsa_s = utils.decode_dss_signature(der_signature) - signature = int.to_bytes(dsa_r, length=octets, byteorder="big") + int.to_bytes( - dsa_s, length=octets, byteorder="big" - ) - elif isinstance(private_key, ed25519.Ed25519PrivateKey): - if dnskey.algorithm != Algorithm.ED25519: - raise ValueError("Invalid DNSKEY algorithm for ED25519 key") - signature = private_key.sign(data) - if verify: - private_key.public_key().verify(signature, data) - elif isinstance(private_key, ed448.Ed448PrivateKey): - if dnskey.algorithm != Algorithm.ED448: - raise ValueError("Invalid DNSKEY algorithm for ED448 key") - signature = private_key.sign(data) - if verify: - private_key.public_key().verify(signature, data) + if isinstance(private_key, GenericPrivateKey): + signing_key = private_key else: - raise TypeError("Unsupported key algorithm") + try: + private_cls = get_algorithm_cls_from_dnskey(dnskey) + signing_key = private_cls(key=private_key) + except UnsupportedAlgorithm: + raise TypeError("Unsupported key algorithm") + + signature = signing_key.sign(data, verify) return cast(RRSIG, rrsig_template.replace(signature=signature)) @@ -858,9 +637,12 @@ def _make_rrsig_signature_data( raise ValidationFailure("relative RR name without an origin specified") rrname = rrname.derelativize(origin) - if len(rrname) - 1 < rrsig.labels: + name_len = len(rrname) + if rrname.is_wild() and rrsig.labels != name_len - 2: + raise ValidationFailure("wild owner name has wrong label length") + if name_len - 1 < rrsig.labels: raise ValidationFailure("owner name longer than RRSIG labels") - elif rrsig.labels < len(rrname) - 1: + elif rrsig.labels < name_len - 1: suffix = rrname.split(rrsig.labels + 1)[1] rrname = dns.name.from_text("*", suffix) rrnamebuf = rrname.to_digestable() @@ -884,9 +666,8 @@ def _make_dnskey( ) -> DNSKEY: """Convert a public key to DNSKEY Rdata - *public_key*, the public key to convert, a - ``cryptography.hazmat.primitives.asymmetric`` public key class applicable - for DNSSEC. + *public_key*, a ``PublicKey`` (``GenericPublicKey`` or + ``cryptography.hazmat.primitives.asymmetric``) to convert. *algorithm*, a ``str`` or ``int`` specifying the DNSKEY algorithm. @@ -902,72 +683,13 @@ def _make_dnskey( Return DNSKEY ``Rdata``. """ - def encode_rsa_public_key(public_key: "rsa.RSAPublicKey") -> bytes: - """Encode a public key per RFC 3110, section 2.""" - pn = public_key.public_numbers() - _exp_len = math.ceil(int.bit_length(pn.e) / 8) - exp = int.to_bytes(pn.e, length=_exp_len, byteorder="big") - if _exp_len > 255: - exp_header = b"\0" + struct.pack("!H", _exp_len) - else: - exp_header = struct.pack("!B", _exp_len) - if pn.n.bit_length() < 512 or pn.n.bit_length() > 4096: - raise ValueError("unsupported RSA key length") - return exp_header + exp + pn.n.to_bytes((pn.n.bit_length() + 7) // 8, "big") + algorithm = Algorithm.make(algorithm) - def encode_dsa_public_key(public_key: "dsa.DSAPublicKey") -> bytes: - """Encode a public key per RFC 2536, section 2.""" - pn = public_key.public_numbers() - dsa_t = (public_key.key_size // 8 - 64) // 8 - if dsa_t > 8: - raise ValueError("unsupported DSA key size") - octets = 64 + dsa_t * 8 - res = struct.pack("!B", dsa_t) - res += pn.parameter_numbers.q.to_bytes(20, "big") - res += pn.parameter_numbers.p.to_bytes(octets, "big") - res += pn.parameter_numbers.g.to_bytes(octets, "big") - res += pn.y.to_bytes(octets, "big") - return res - - def encode_ecdsa_public_key(public_key: "ec.EllipticCurvePublicKey") -> bytes: - """Encode a public key per RFC 6605, section 4.""" - pn = public_key.public_numbers() - if isinstance(public_key.curve, ec.SECP256R1): - return pn.x.to_bytes(32, "big") + pn.y.to_bytes(32, "big") - elif isinstance(public_key.curve, ec.SECP384R1): - return pn.x.to_bytes(48, "big") + pn.y.to_bytes(48, "big") - else: - raise ValueError("unsupported ECDSA curve") - - the_algorithm = Algorithm.make(algorithm) - - _ensure_algorithm_key_combination(the_algorithm, public_key) - - if isinstance(public_key, rsa.RSAPublicKey): - key_bytes = encode_rsa_public_key(public_key) - elif isinstance(public_key, dsa.DSAPublicKey): - key_bytes = encode_dsa_public_key(public_key) - elif isinstance(public_key, ec.EllipticCurvePublicKey): - key_bytes = encode_ecdsa_public_key(public_key) - elif isinstance(public_key, ed25519.Ed25519PublicKey): - key_bytes = public_key.public_bytes( - encoding=serialization.Encoding.Raw, format=serialization.PublicFormat.Raw - ) - elif isinstance(public_key, ed448.Ed448PublicKey): - key_bytes = public_key.public_bytes( - encoding=serialization.Encoding.Raw, format=serialization.PublicFormat.Raw - ) + if isinstance(public_key, GenericPublicKey): + return public_key.to_dnskey(flags=flags, protocol=protocol) else: - raise TypeError("unsupported key algorithm") - - return DNSKEY( - rdclass=dns.rdataclass.IN, - rdtype=dns.rdatatype.DNSKEY, - flags=flags, - protocol=protocol, - algorithm=the_algorithm, - key=key_bytes, - ) + public_cls = get_algorithm_cls(algorithm).public_cls + return public_cls(key=public_key).to_dnskey(flags=flags, protocol=protocol) def _make_cdnskey( @@ -1216,23 +938,252 @@ def dnskey_rdataset_to_cdnskey_rdataset( return dns.rdataset.from_rdata_list(rdataset.ttl, res) +def default_rrset_signer( + txn: dns.transaction.Transaction, + rrset: dns.rrset.RRset, + signer: dns.name.Name, + ksks: List[Tuple[PrivateKey, DNSKEY]], + zsks: List[Tuple[PrivateKey, DNSKEY]], + inception: Optional[Union[datetime, str, int, float]] = None, + expiration: Optional[Union[datetime, str, int, float]] = None, + lifetime: Optional[int] = None, + policy: Optional[Policy] = None, + origin: Optional[dns.name.Name] = None, +) -> None: + """Default RRset signer""" + + if rrset.rdtype in set( + [ + dns.rdatatype.RdataType.DNSKEY, + dns.rdatatype.RdataType.CDS, + dns.rdatatype.RdataType.CDNSKEY, + ] + ): + keys = ksks + else: + keys = zsks + + for private_key, dnskey in keys: + rrsig = dns.dnssec.sign( + rrset=rrset, + private_key=private_key, + dnskey=dnskey, + inception=inception, + expiration=expiration, + lifetime=lifetime, + signer=signer, + policy=policy, + origin=origin, + ) + txn.add(rrset.name, rrset.ttl, rrsig) + + +def sign_zone( + zone: dns.zone.Zone, + txn: Optional[dns.transaction.Transaction] = None, + keys: Optional[List[Tuple[PrivateKey, DNSKEY]]] = None, + add_dnskey: bool = True, + dnskey_ttl: Optional[int] = None, + inception: Optional[Union[datetime, str, int, float]] = None, + expiration: Optional[Union[datetime, str, int, float]] = None, + lifetime: Optional[int] = None, + nsec3: Optional[NSEC3PARAM] = None, + rrset_signer: Optional[RRsetSigner] = None, + policy: Optional[Policy] = None, +) -> None: + """Sign zone. + + *zone*, a ``dns.zone.Zone``, the zone to sign. + + *txn*, a ``dns.transaction.Transaction``, an optional transaction to use for + signing. + + *keys*, a list of (``PrivateKey``, ``DNSKEY``) tuples, to use for signing. KSK/ZSK + roles are assigned automatically if the SEP flag is used, otherwise all RRsets are + signed by all keys. + + *add_dnskey*, a ``bool``. If ``True``, the default, all specified DNSKEYs are + automatically added to the zone on signing. + + *dnskey_ttl*, a``int``, specifies the TTL for DNSKEY RRs. If not specified the TTL + of the existing DNSKEY RRset used or the TTL of the SOA RRset. + + *inception*, a ``datetime``, ``str``, ``int``, ``float`` or ``None``, the signature + inception time. If ``None``, the current time is used. If a ``str``, the format is + "YYYYMMDDHHMMSS" or alternatively the number of seconds since the UNIX epoch in text + form; this is the same the RRSIG rdata's text form. Values of type `int` or `float` + are interpreted as seconds since the UNIX epoch. + + *expiration*, a ``datetime``, ``str``, ``int``, ``float`` or ``None``, the signature + expiration time. If ``None``, the expiration time will be the inception time plus + the value of the *lifetime* parameter. See the description of *inception* above for + how the various parameter types are interpreted. + + *lifetime*, an ``int`` or ``None``, the signature lifetime in seconds. This + parameter is only meaningful if *expiration* is ``None``. + + *nsec3*, a ``NSEC3PARAM`` Rdata, configures signing using NSEC3. Not yet + implemented. + + *rrset_signer*, a ``Callable``, an optional function for signing RRsets. The + function requires two arguments: transaction and RRset. If the not specified, + ``dns.dnssec.default_rrset_signer`` will be used. + + Returns ``None``. + """ + + ksks = [] + zsks = [] + + # if we have both KSKs and ZSKs, split by SEP flag. if not, sign all + # records with all keys + if keys: + for key in keys: + if key[1].flags & Flag.SEP: + ksks.append(key) + else: + zsks.append(key) + if not ksks: + ksks = keys + if not zsks: + zsks = keys + else: + keys = [] + + if txn: + cm: contextlib.AbstractContextManager = contextlib.nullcontext(txn) + else: + cm = zone.writer() + + with cm as _txn: + if add_dnskey: + if dnskey_ttl is None: + dnskey = _txn.get(zone.origin, dns.rdatatype.DNSKEY) + if dnskey: + dnskey_ttl = dnskey.ttl + else: + soa = _txn.get(zone.origin, dns.rdatatype.SOA) + dnskey_ttl = soa.ttl + for _, dnskey in keys: + _txn.add(zone.origin, dnskey_ttl, dnskey) + + if nsec3: + raise NotImplementedError("Signing with NSEC3 not yet implemented") + else: + _rrset_signer = rrset_signer or functools.partial( + default_rrset_signer, + signer=zone.origin, + ksks=ksks, + zsks=zsks, + inception=inception, + expiration=expiration, + lifetime=lifetime, + policy=policy, + origin=zone.origin, + ) + return _sign_zone_nsec(zone, _txn, _rrset_signer) + + +def _sign_zone_nsec( + zone: dns.zone.Zone, + txn: dns.transaction.Transaction, + rrset_signer: Optional[RRsetSigner] = None, +) -> None: + """NSEC zone signer""" + + def _txn_add_nsec( + txn: dns.transaction.Transaction, + name: dns.name.Name, + next_secure: Optional[dns.name.Name], + rdclass: dns.rdataclass.RdataClass, + ttl: int, + rrset_signer: Optional[RRsetSigner] = None, + ) -> None: + """NSEC zone signer helper""" + mandatory_types = set( + [dns.rdatatype.RdataType.RRSIG, dns.rdatatype.RdataType.NSEC] + ) + node = txn.get_node(name) + if node and next_secure: + types = ( + set([rdataset.rdtype for rdataset in node.rdatasets]) | mandatory_types + ) + windows = Bitmap.from_rdtypes(list(types)) + rrset = dns.rrset.from_rdata( + name, + ttl, + NSEC( + rdclass=rdclass, + rdtype=dns.rdatatype.RdataType.NSEC, + next=next_secure, + windows=windows, + ), + ) + txn.add(rrset) + if rrset_signer: + rrset_signer(txn, rrset) + + rrsig_ttl = zone.get_soa().minimum + delegation = None + last_secure = None + + for name in sorted(txn.iterate_names()): + if delegation and name.is_subdomain(delegation): + # names below delegations are not secure + continue + elif txn.get(name, dns.rdatatype.NS) and name != zone.origin: + # inside delegation + delegation = name + else: + # outside delegation + delegation = None + + if rrset_signer: + node = txn.get_node(name) + if node: + for rdataset in node.rdatasets: + if rdataset.rdtype == dns.rdatatype.RRSIG: + # do not sign RRSIGs + continue + elif delegation and rdataset.rdtype != dns.rdatatype.DS: + # do not sign delegations except DS records + continue + else: + rrset = dns.rrset.from_rdata(name, rdataset.ttl, *rdataset) + rrset_signer(txn, rrset) + + # We need "is not None" as the empty name is False because its length is 0. + if last_secure is not None: + _txn_add_nsec(txn, last_secure, name, zone.rdclass, rrsig_ttl, rrset_signer) + last_secure = name + + if last_secure: + _txn_add_nsec( + txn, last_secure, zone.origin, zone.rdclass, rrsig_ttl, rrset_signer + ) + + def _need_pyca(*args, **kwargs): raise ImportError( - "DNSSEC validation requires " + "python cryptography" + "DNSSEC validation requires python cryptography" ) # pragma: no cover try: from cryptography.exceptions import InvalidSignature - from cryptography.hazmat.backends import default_backend - from cryptography.hazmat.primitives import hashes, serialization - from cryptography.hazmat.primitives.asymmetric import padding - from cryptography.hazmat.primitives.asymmetric import utils - from cryptography.hazmat.primitives.asymmetric import dsa - from cryptography.hazmat.primitives.asymmetric import ec - from cryptography.hazmat.primitives.asymmetric import ed25519 - from cryptography.hazmat.primitives.asymmetric import ed448 - from cryptography.hazmat.primitives.asymmetric import rsa + from cryptography.hazmat.primitives.asymmetric import dsa # pylint: disable=W0611 + from cryptography.hazmat.primitives.asymmetric import ec # pylint: disable=W0611 + from cryptography.hazmat.primitives.asymmetric import ed448 # pylint: disable=W0611 + from cryptography.hazmat.primitives.asymmetric import rsa # pylint: disable=W0611 + from cryptography.hazmat.primitives.asymmetric import ( # pylint: disable=W0611 + ed25519, + ) + + from dns.dnssecalgs import ( # pylint: disable=C0412 + get_algorithm_cls, + get_algorithm_cls_from_dnskey, + ) + from dns.dnssecalgs.base import GenericPrivateKey, GenericPublicKey except ImportError: # pragma: no cover validate = _need_pyca validate_rrsig = _need_pyca diff --git a/lib/dns/dnssecalgs/__init__.py b/lib/dns/dnssecalgs/__init__.py new file mode 100644 index 00000000..d1ffd519 --- /dev/null +++ b/lib/dns/dnssecalgs/__init__.py @@ -0,0 +1,121 @@ +from typing import Dict, Optional, Tuple, Type, Union + +import dns.name + +try: + from dns.dnssecalgs.base import GenericPrivateKey + from dns.dnssecalgs.dsa import PrivateDSA, PrivateDSANSEC3SHA1 + from dns.dnssecalgs.ecdsa import PrivateECDSAP256SHA256, PrivateECDSAP384SHA384 + from dns.dnssecalgs.eddsa import PrivateED448, PrivateED25519 + from dns.dnssecalgs.rsa import ( + PrivateRSAMD5, + PrivateRSASHA1, + PrivateRSASHA1NSEC3SHA1, + PrivateRSASHA256, + PrivateRSASHA512, + ) + + _have_cryptography = True +except ImportError: + _have_cryptography = False + +from dns.dnssectypes import Algorithm +from dns.exception import UnsupportedAlgorithm +from dns.rdtypes.ANY.DNSKEY import DNSKEY + +AlgorithmPrefix = Optional[Union[bytes, dns.name.Name]] + +algorithms: Dict[Tuple[Algorithm, AlgorithmPrefix], Type[GenericPrivateKey]] = {} +if _have_cryptography: + algorithms.update( + { + (Algorithm.RSAMD5, None): PrivateRSAMD5, + (Algorithm.DSA, None): PrivateDSA, + (Algorithm.RSASHA1, None): PrivateRSASHA1, + (Algorithm.DSANSEC3SHA1, None): PrivateDSANSEC3SHA1, + (Algorithm.RSASHA1NSEC3SHA1, None): PrivateRSASHA1NSEC3SHA1, + (Algorithm.RSASHA256, None): PrivateRSASHA256, + (Algorithm.RSASHA512, None): PrivateRSASHA512, + (Algorithm.ECDSAP256SHA256, None): PrivateECDSAP256SHA256, + (Algorithm.ECDSAP384SHA384, None): PrivateECDSAP384SHA384, + (Algorithm.ED25519, None): PrivateED25519, + (Algorithm.ED448, None): PrivateED448, + } + ) + + +def get_algorithm_cls( + algorithm: Union[int, str], prefix: AlgorithmPrefix = None +) -> Type[GenericPrivateKey]: + """Get Private Key class from Algorithm. + + *algorithm*, a ``str`` or ``int`` specifying the DNSKEY algorithm. + + Raises ``UnsupportedAlgorithm`` if the algorithm is unknown. + + Returns a ``dns.dnssecalgs.GenericPrivateKey`` + """ + algorithm = Algorithm.make(algorithm) + cls = algorithms.get((algorithm, prefix)) + if cls: + return cls + raise UnsupportedAlgorithm( + 'algorithm "%s" not supported by dnspython' % Algorithm.to_text(algorithm) + ) + + +def get_algorithm_cls_from_dnskey(dnskey: DNSKEY) -> Type[GenericPrivateKey]: + """Get Private Key class from DNSKEY. + + *dnskey*, a ``DNSKEY`` to get Algorithm class for. + + Raises ``UnsupportedAlgorithm`` if the algorithm is unknown. + + Returns a ``dns.dnssecalgs.GenericPrivateKey`` + """ + prefix: AlgorithmPrefix = None + if dnskey.algorithm == Algorithm.PRIVATEDNS: + prefix, _ = dns.name.from_wire(dnskey.key, 0) + elif dnskey.algorithm == Algorithm.PRIVATEOID: + length = int(dnskey.key[0]) + prefix = dnskey.key[0 : length + 1] + return get_algorithm_cls(dnskey.algorithm, prefix) + + +def register_algorithm_cls( + algorithm: Union[int, str], + algorithm_cls: Type[GenericPrivateKey], + name: Optional[Union[dns.name.Name, str]] = None, + oid: Optional[bytes] = None, +) -> None: + """Register Algorithm Private Key class. + + *algorithm*, a ``str`` or ``int`` specifying the DNSKEY algorithm. + + *algorithm_cls*: A `GenericPrivateKey` class. + + *name*, an optional ``dns.name.Name`` or ``str``, for for PRIVATEDNS algorithms. + + *oid*: an optional BER-encoded `bytes` for PRIVATEOID algorithms. + + Raises ``ValueError`` if a name or oid is specified incorrectly. + """ + if not issubclass(algorithm_cls, GenericPrivateKey): + raise TypeError("Invalid algorithm class") + algorithm = Algorithm.make(algorithm) + prefix: AlgorithmPrefix = None + if algorithm == Algorithm.PRIVATEDNS: + if name is None: + raise ValueError("Name required for PRIVATEDNS algorithms") + if isinstance(name, str): + name = dns.name.from_text(name) + prefix = name + elif algorithm == Algorithm.PRIVATEOID: + if oid is None: + raise ValueError("OID required for PRIVATEOID algorithms") + prefix = bytes([len(oid)]) + oid + elif name: + raise ValueError("Name only supported for PRIVATEDNS algorithm") + elif oid: + raise ValueError("OID only supported for PRIVATEOID algorithm") + algorithms[(algorithm, prefix)] = algorithm_cls diff --git a/lib/dns/dnssecalgs/base.py b/lib/dns/dnssecalgs/base.py new file mode 100644 index 00000000..e990575a --- /dev/null +++ b/lib/dns/dnssecalgs/base.py @@ -0,0 +1,84 @@ +from abc import ABC, abstractmethod # pylint: disable=no-name-in-module +from typing import Any, Optional, Type + +import dns.rdataclass +import dns.rdatatype +from dns.dnssectypes import Algorithm +from dns.exception import AlgorithmKeyMismatch +from dns.rdtypes.ANY.DNSKEY import DNSKEY +from dns.rdtypes.dnskeybase import Flag + + +class GenericPublicKey(ABC): + algorithm: Algorithm + + @abstractmethod + def __init__(self, key: Any) -> None: + pass + + @abstractmethod + def verify(self, signature: bytes, data: bytes) -> None: + """Verify signed DNSSEC data""" + + @abstractmethod + def encode_key_bytes(self) -> bytes: + """Encode key as bytes for DNSKEY""" + + @classmethod + def _ensure_algorithm_key_combination(cls, key: DNSKEY) -> None: + if key.algorithm != cls.algorithm: + raise AlgorithmKeyMismatch + + def to_dnskey(self, flags: int = Flag.ZONE, protocol: int = 3) -> DNSKEY: + """Return public key as DNSKEY""" + return DNSKEY( + rdclass=dns.rdataclass.IN, + rdtype=dns.rdatatype.DNSKEY, + flags=flags, + protocol=protocol, + algorithm=self.algorithm, + key=self.encode_key_bytes(), + ) + + @classmethod + @abstractmethod + def from_dnskey(cls, key: DNSKEY) -> "GenericPublicKey": + """Create public key from DNSKEY""" + + @classmethod + @abstractmethod + def from_pem(cls, public_pem: bytes) -> "GenericPublicKey": + """Create public key from PEM-encoded SubjectPublicKeyInfo as specified + in RFC 5280""" + + @abstractmethod + def to_pem(self) -> bytes: + """Return public-key as PEM-encoded SubjectPublicKeyInfo as specified + in RFC 5280""" + + +class GenericPrivateKey(ABC): + public_cls: Type[GenericPublicKey] + + @abstractmethod + def __init__(self, key: Any) -> None: + pass + + @abstractmethod + def sign(self, data: bytes, verify: bool = False) -> bytes: + """Sign DNSSEC data""" + + @abstractmethod + def public_key(self) -> "GenericPublicKey": + """Return public key instance""" + + @classmethod + @abstractmethod + def from_pem( + cls, private_pem: bytes, password: Optional[bytes] = None + ) -> "GenericPrivateKey": + """Create private key from PEM-encoded PKCS#8""" + + @abstractmethod + def to_pem(self, password: Optional[bytes] = None) -> bytes: + """Return private key as PEM-encoded PKCS#8""" diff --git a/lib/dns/dnssecalgs/cryptography.py b/lib/dns/dnssecalgs/cryptography.py new file mode 100644 index 00000000..5a31a812 --- /dev/null +++ b/lib/dns/dnssecalgs/cryptography.py @@ -0,0 +1,68 @@ +from typing import Any, Optional, Type + +from cryptography.hazmat.primitives import serialization + +from dns.dnssecalgs.base import GenericPrivateKey, GenericPublicKey +from dns.exception import AlgorithmKeyMismatch + + +class CryptographyPublicKey(GenericPublicKey): + key: Any = None + key_cls: Any = None + + def __init__(self, key: Any) -> None: # pylint: disable=super-init-not-called + if self.key_cls is None: + raise TypeError("Undefined private key class") + if not isinstance( # pylint: disable=isinstance-second-argument-not-valid-type + key, self.key_cls + ): + raise AlgorithmKeyMismatch + self.key = key + + @classmethod + def from_pem(cls, public_pem: bytes) -> "GenericPublicKey": + key = serialization.load_pem_public_key(public_pem) + return cls(key=key) + + def to_pem(self) -> bytes: + return self.key.public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) + + +class CryptographyPrivateKey(GenericPrivateKey): + key: Any = None + key_cls: Any = None + public_cls: Type[CryptographyPublicKey] + + def __init__(self, key: Any) -> None: # pylint: disable=super-init-not-called + if self.key_cls is None: + raise TypeError("Undefined private key class") + if not isinstance( # pylint: disable=isinstance-second-argument-not-valid-type + key, self.key_cls + ): + raise AlgorithmKeyMismatch + self.key = key + + def public_key(self) -> "CryptographyPublicKey": + return self.public_cls(key=self.key.public_key()) + + @classmethod + def from_pem( + cls, private_pem: bytes, password: Optional[bytes] = None + ) -> "GenericPrivateKey": + key = serialization.load_pem_private_key(private_pem, password=password) + return cls(key=key) + + def to_pem(self, password: Optional[bytes] = None) -> bytes: + encryption_algorithm: serialization.KeySerializationEncryption + if password: + encryption_algorithm = serialization.BestAvailableEncryption(password) + else: + encryption_algorithm = serialization.NoEncryption() + return self.key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=encryption_algorithm, + ) diff --git a/lib/dns/dnssecalgs/dsa.py b/lib/dns/dnssecalgs/dsa.py new file mode 100644 index 00000000..0fe4690d --- /dev/null +++ b/lib/dns/dnssecalgs/dsa.py @@ -0,0 +1,101 @@ +import struct + +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives.asymmetric import dsa, utils + +from dns.dnssecalgs.cryptography import CryptographyPrivateKey, CryptographyPublicKey +from dns.dnssectypes import Algorithm +from dns.rdtypes.ANY.DNSKEY import DNSKEY + + +class PublicDSA(CryptographyPublicKey): + key: dsa.DSAPublicKey + key_cls = dsa.DSAPublicKey + algorithm = Algorithm.DSA + chosen_hash = hashes.SHA1() + + def verify(self, signature: bytes, data: bytes) -> None: + sig_r = signature[1:21] + sig_s = signature[21:] + sig = utils.encode_dss_signature( + int.from_bytes(sig_r, "big"), int.from_bytes(sig_s, "big") + ) + self.key.verify(sig, data, self.chosen_hash) + + def encode_key_bytes(self) -> bytes: + """Encode a public key per RFC 2536, section 2.""" + pn = self.key.public_numbers() + dsa_t = (self.key.key_size // 8 - 64) // 8 + if dsa_t > 8: + raise ValueError("unsupported DSA key size") + octets = 64 + dsa_t * 8 + res = struct.pack("!B", dsa_t) + res += pn.parameter_numbers.q.to_bytes(20, "big") + res += pn.parameter_numbers.p.to_bytes(octets, "big") + res += pn.parameter_numbers.g.to_bytes(octets, "big") + res += pn.y.to_bytes(octets, "big") + return res + + @classmethod + def from_dnskey(cls, key: DNSKEY) -> "PublicDSA": + cls._ensure_algorithm_key_combination(key) + keyptr = key.key + (t,) = struct.unpack("!B", keyptr[0:1]) + keyptr = keyptr[1:] + octets = 64 + t * 8 + dsa_q = keyptr[0:20] + keyptr = keyptr[20:] + dsa_p = keyptr[0:octets] + keyptr = keyptr[octets:] + dsa_g = keyptr[0:octets] + keyptr = keyptr[octets:] + dsa_y = keyptr[0:octets] + return cls( + key=dsa.DSAPublicNumbers( # type: ignore + int.from_bytes(dsa_y, "big"), + dsa.DSAParameterNumbers( + int.from_bytes(dsa_p, "big"), + int.from_bytes(dsa_q, "big"), + int.from_bytes(dsa_g, "big"), + ), + ).public_key(default_backend()), + ) + + +class PrivateDSA(CryptographyPrivateKey): + key: dsa.DSAPrivateKey + key_cls = dsa.DSAPrivateKey + public_cls = PublicDSA + + def sign(self, data: bytes, verify: bool = False) -> bytes: + """Sign using a private key per RFC 2536, section 3.""" + public_dsa_key = self.key.public_key() + if public_dsa_key.key_size > 1024: + raise ValueError("DSA key size overflow") + der_signature = self.key.sign(data, self.public_cls.chosen_hash) + dsa_r, dsa_s = utils.decode_dss_signature(der_signature) + dsa_t = (public_dsa_key.key_size // 8 - 64) // 8 + octets = 20 + signature = ( + struct.pack("!B", dsa_t) + + int.to_bytes(dsa_r, length=octets, byteorder="big") + + int.to_bytes(dsa_s, length=octets, byteorder="big") + ) + if verify: + self.public_key().verify(signature, data) + return signature + + @classmethod + def generate(cls, key_size: int) -> "PrivateDSA": + return cls( + key=dsa.generate_private_key(key_size=key_size), + ) + + +class PublicDSANSEC3SHA1(PublicDSA): + algorithm = Algorithm.DSANSEC3SHA1 + + +class PrivateDSANSEC3SHA1(PrivateDSA): + public_cls = PublicDSANSEC3SHA1 diff --git a/lib/dns/dnssecalgs/ecdsa.py b/lib/dns/dnssecalgs/ecdsa.py new file mode 100644 index 00000000..a31d79f2 --- /dev/null +++ b/lib/dns/dnssecalgs/ecdsa.py @@ -0,0 +1,89 @@ +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives.asymmetric import ec, utils + +from dns.dnssecalgs.cryptography import CryptographyPrivateKey, CryptographyPublicKey +from dns.dnssectypes import Algorithm +from dns.rdtypes.ANY.DNSKEY import DNSKEY + + +class PublicECDSA(CryptographyPublicKey): + key: ec.EllipticCurvePublicKey + key_cls = ec.EllipticCurvePublicKey + algorithm: Algorithm + chosen_hash: hashes.HashAlgorithm + curve: ec.EllipticCurve + octets: int + + def verify(self, signature: bytes, data: bytes) -> None: + sig_r = signature[0 : self.octets] + sig_s = signature[self.octets :] + sig = utils.encode_dss_signature( + int.from_bytes(sig_r, "big"), int.from_bytes(sig_s, "big") + ) + self.key.verify(sig, data, ec.ECDSA(self.chosen_hash)) + + def encode_key_bytes(self) -> bytes: + """Encode a public key per RFC 6605, section 4.""" + pn = self.key.public_numbers() + return pn.x.to_bytes(self.octets, "big") + pn.y.to_bytes(self.octets, "big") + + @classmethod + def from_dnskey(cls, key: DNSKEY) -> "PublicECDSA": + cls._ensure_algorithm_key_combination(key) + ecdsa_x = key.key[0 : cls.octets] + ecdsa_y = key.key[cls.octets : cls.octets * 2] + return cls( + key=ec.EllipticCurvePublicNumbers( + curve=cls.curve, + x=int.from_bytes(ecdsa_x, "big"), + y=int.from_bytes(ecdsa_y, "big"), + ).public_key(default_backend()), + ) + + +class PrivateECDSA(CryptographyPrivateKey): + key: ec.EllipticCurvePrivateKey + key_cls = ec.EllipticCurvePrivateKey + public_cls = PublicECDSA + + def sign(self, data: bytes, verify: bool = False) -> bytes: + """Sign using a private key per RFC 6605, section 4.""" + der_signature = self.key.sign(data, ec.ECDSA(self.public_cls.chosen_hash)) + dsa_r, dsa_s = utils.decode_dss_signature(der_signature) + signature = int.to_bytes( + dsa_r, length=self.public_cls.octets, byteorder="big" + ) + int.to_bytes(dsa_s, length=self.public_cls.octets, byteorder="big") + if verify: + self.public_key().verify(signature, data) + return signature + + @classmethod + def generate(cls) -> "PrivateECDSA": + return cls( + key=ec.generate_private_key( + curve=cls.public_cls.curve, backend=default_backend() + ), + ) + + +class PublicECDSAP256SHA256(PublicECDSA): + algorithm = Algorithm.ECDSAP256SHA256 + chosen_hash = hashes.SHA256() + curve = ec.SECP256R1() + octets = 32 + + +class PrivateECDSAP256SHA256(PrivateECDSA): + public_cls = PublicECDSAP256SHA256 + + +class PublicECDSAP384SHA384(PublicECDSA): + algorithm = Algorithm.ECDSAP384SHA384 + chosen_hash = hashes.SHA384() + curve = ec.SECP384R1() + octets = 48 + + +class PrivateECDSAP384SHA384(PrivateECDSA): + public_cls = PublicECDSAP384SHA384 diff --git a/lib/dns/dnssecalgs/eddsa.py b/lib/dns/dnssecalgs/eddsa.py new file mode 100644 index 00000000..70505342 --- /dev/null +++ b/lib/dns/dnssecalgs/eddsa.py @@ -0,0 +1,65 @@ +from typing import Type + +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import ed448, ed25519 + +from dns.dnssecalgs.cryptography import CryptographyPrivateKey, CryptographyPublicKey +from dns.dnssectypes import Algorithm +from dns.rdtypes.ANY.DNSKEY import DNSKEY + + +class PublicEDDSA(CryptographyPublicKey): + def verify(self, signature: bytes, data: bytes) -> None: + self.key.verify(signature, data) + + def encode_key_bytes(self) -> bytes: + """Encode a public key per RFC 8080, section 3.""" + return self.key.public_bytes( + encoding=serialization.Encoding.Raw, format=serialization.PublicFormat.Raw + ) + + @classmethod + def from_dnskey(cls, key: DNSKEY) -> "PublicEDDSA": + cls._ensure_algorithm_key_combination(key) + return cls( + key=cls.key_cls.from_public_bytes(key.key), + ) + + +class PrivateEDDSA(CryptographyPrivateKey): + public_cls: Type[PublicEDDSA] + + def sign(self, data: bytes, verify: bool = False) -> bytes: + """Sign using a private key per RFC 8080, section 4.""" + signature = self.key.sign(data) + if verify: + self.public_key().verify(signature, data) + return signature + + @classmethod + def generate(cls) -> "PrivateEDDSA": + return cls(key=cls.key_cls.generate()) + + +class PublicED25519(PublicEDDSA): + key: ed25519.Ed25519PublicKey + key_cls = ed25519.Ed25519PublicKey + algorithm = Algorithm.ED25519 + + +class PrivateED25519(PrivateEDDSA): + key: ed25519.Ed25519PrivateKey + key_cls = ed25519.Ed25519PrivateKey + public_cls = PublicED25519 + + +class PublicED448(PublicEDDSA): + key: ed448.Ed448PublicKey + key_cls = ed448.Ed448PublicKey + algorithm = Algorithm.ED448 + + +class PrivateED448(PrivateEDDSA): + key: ed448.Ed448PrivateKey + key_cls = ed448.Ed448PrivateKey + public_cls = PublicED448 diff --git a/lib/dns/dnssecalgs/rsa.py b/lib/dns/dnssecalgs/rsa.py new file mode 100644 index 00000000..e95dcf1d --- /dev/null +++ b/lib/dns/dnssecalgs/rsa.py @@ -0,0 +1,119 @@ +import math +import struct + +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives.asymmetric import padding, rsa + +from dns.dnssecalgs.cryptography import CryptographyPrivateKey, CryptographyPublicKey +from dns.dnssectypes import Algorithm +from dns.rdtypes.ANY.DNSKEY import DNSKEY + + +class PublicRSA(CryptographyPublicKey): + key: rsa.RSAPublicKey + key_cls = rsa.RSAPublicKey + algorithm: Algorithm + chosen_hash: hashes.HashAlgorithm + + def verify(self, signature: bytes, data: bytes) -> None: + self.key.verify(signature, data, padding.PKCS1v15(), self.chosen_hash) + + def encode_key_bytes(self) -> bytes: + """Encode a public key per RFC 3110, section 2.""" + pn = self.key.public_numbers() + _exp_len = math.ceil(int.bit_length(pn.e) / 8) + exp = int.to_bytes(pn.e, length=_exp_len, byteorder="big") + if _exp_len > 255: + exp_header = b"\0" + struct.pack("!H", _exp_len) + else: + exp_header = struct.pack("!B", _exp_len) + if pn.n.bit_length() < 512 or pn.n.bit_length() > 4096: + raise ValueError("unsupported RSA key length") + return exp_header + exp + pn.n.to_bytes((pn.n.bit_length() + 7) // 8, "big") + + @classmethod + def from_dnskey(cls, key: DNSKEY) -> "PublicRSA": + cls._ensure_algorithm_key_combination(key) + keyptr = key.key + (bytes_,) = struct.unpack("!B", keyptr[0:1]) + keyptr = keyptr[1:] + if bytes_ == 0: + (bytes_,) = struct.unpack("!H", keyptr[0:2]) + keyptr = keyptr[2:] + rsa_e = keyptr[0:bytes_] + rsa_n = keyptr[bytes_:] + return cls( + key=rsa.RSAPublicNumbers( + int.from_bytes(rsa_e, "big"), int.from_bytes(rsa_n, "big") + ).public_key(default_backend()) + ) + + +class PrivateRSA(CryptographyPrivateKey): + key: rsa.RSAPrivateKey + key_cls = rsa.RSAPrivateKey + public_cls = PublicRSA + default_public_exponent = 65537 + + def sign(self, data: bytes, verify: bool = False) -> bytes: + """Sign using a private key per RFC 3110, section 3.""" + signature = self.key.sign(data, padding.PKCS1v15(), self.public_cls.chosen_hash) + if verify: + self.public_key().verify(signature, data) + return signature + + @classmethod + def generate(cls, key_size: int) -> "PrivateRSA": + return cls( + key=rsa.generate_private_key( + public_exponent=cls.default_public_exponent, + key_size=key_size, + backend=default_backend(), + ) + ) + + +class PublicRSAMD5(PublicRSA): + algorithm = Algorithm.RSAMD5 + chosen_hash = hashes.MD5() + + +class PrivateRSAMD5(PrivateRSA): + public_cls = PublicRSAMD5 + + +class PublicRSASHA1(PublicRSA): + algorithm = Algorithm.RSASHA1 + chosen_hash = hashes.SHA1() + + +class PrivateRSASHA1(PrivateRSA): + public_cls = PublicRSASHA1 + + +class PublicRSASHA1NSEC3SHA1(PublicRSA): + algorithm = Algorithm.RSASHA1NSEC3SHA1 + chosen_hash = hashes.SHA1() + + +class PrivateRSASHA1NSEC3SHA1(PrivateRSA): + public_cls = PublicRSASHA1NSEC3SHA1 + + +class PublicRSASHA256(PublicRSA): + algorithm = Algorithm.RSASHA256 + chosen_hash = hashes.SHA256() + + +class PrivateRSASHA256(PrivateRSA): + public_cls = PublicRSASHA256 + + +class PublicRSASHA512(PublicRSA): + algorithm = Algorithm.RSASHA512 + chosen_hash = hashes.SHA512() + + +class PrivateRSASHA512(PrivateRSA): + public_cls = PublicRSASHA512 diff --git a/lib/dns/edns.py b/lib/dns/edns.py index 64436cde..f05baac4 100644 --- a/lib/dns/edns.py +++ b/lib/dns/edns.py @@ -17,11 +17,10 @@ """EDNS Options""" -from typing import Any, Dict, Optional, Union - import math import socket import struct +from typing import Any, Dict, Optional, Union import dns.enum import dns.inet @@ -380,7 +379,7 @@ class EDEOption(Option): # lgtm[py/missing-equals] def from_wire_parser( cls, otype: Union[OptionType, str], parser: "dns.wire.Parser" ) -> Option: - the_code = EDECode.make(parser.get_uint16()) + code = EDECode.make(parser.get_uint16()) text = parser.get_remaining() if text: @@ -390,7 +389,7 @@ class EDEOption(Option): # lgtm[py/missing-equals] else: btext = None - return cls(the_code, btext) + return cls(code, btext) _type_to_class: Dict[OptionType, Any] = { @@ -424,8 +423,8 @@ def option_from_wire_parser( Returns an instance of a subclass of ``dns.edns.Option``. """ - the_otype = OptionType.make(otype) - cls = get_option_class(the_otype) + otype = OptionType.make(otype) + cls = get_option_class(otype) return cls.from_wire_parser(otype, parser) diff --git a/lib/dns/entropy.py b/lib/dns/entropy.py index 5e1f5e23..4dcdc627 100644 --- a/lib/dns/entropy.py +++ b/lib/dns/entropy.py @@ -15,17 +15,15 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. -from typing import Any, Optional - -import os import hashlib +import os import random import threading import time +from typing import Any, Optional class EntropyPool: - # This is an entropy pool for Python implementations that do not # have a working SystemRandom. I'm not sure there are any, but # leaving this code doesn't hurt anything as the library code diff --git a/lib/dns/enum.py b/lib/dns/enum.py index b5a4aed8..71461f17 100644 --- a/lib/dns/enum.py +++ b/lib/dns/enum.py @@ -16,18 +16,31 @@ # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. import enum +from typing import Type, TypeVar, Union + +TIntEnum = TypeVar("TIntEnum", bound="IntEnum") class IntEnum(enum.IntEnum): @classmethod - def _check_value(cls, value): - max = cls._maximum() - if value < 0 or value > max: - name = cls._short_name() - raise ValueError(f"{name} must be between >= 0 and <= {max}") + def _missing_(cls, value): + cls._check_value(value) + val = int.__new__(cls, value) + val._name_ = cls._extra_to_text(value, None) or f"{cls._prefix()}{value}" + val._value_ = value + return val @classmethod - def from_text(cls, text): + def _check_value(cls, value): + max = cls._maximum() + if not isinstance(value, int): + raise TypeError + if value < 0 or value > max: + name = cls._short_name() + raise ValueError(f"{name} must be an int between >= 0 and <= {max}") + + @classmethod + def from_text(cls: Type[TIntEnum], text: str) -> TIntEnum: text = text.upper() try: return cls[text] @@ -47,7 +60,7 @@ class IntEnum(enum.IntEnum): raise cls._unknown_exception_class() @classmethod - def to_text(cls, value): + def to_text(cls: Type[TIntEnum], value: int) -> str: cls._check_value(value) try: text = cls(value).name @@ -59,7 +72,7 @@ class IntEnum(enum.IntEnum): return text @classmethod - def make(cls, value): + def make(cls: Type[TIntEnum], value: Union[int, str]) -> TIntEnum: """Convert text or a value into an enumerated type, if possible. *value*, the ``int`` or ``str`` to convert. @@ -76,10 +89,7 @@ class IntEnum(enum.IntEnum): if isinstance(value, str): return cls.from_text(value) cls._check_value(value) - try: - return cls(value) - except ValueError: - return value + return cls(value) @classmethod def _maximum(cls): diff --git a/lib/dns/exception.py b/lib/dns/exception.py index 4b1481d1..6982373d 100644 --- a/lib/dns/exception.py +++ b/lib/dns/exception.py @@ -140,6 +140,22 @@ class Timeout(DNSException): super().__init__(*args, **kwargs) +class UnsupportedAlgorithm(DNSException): + """The DNSSEC algorithm is not supported.""" + + +class AlgorithmKeyMismatch(UnsupportedAlgorithm): + """The DNSSEC algorithm is not supported for the given key type.""" + + +class ValidationFailure(DNSException): + """The DNSSEC signature is invalid.""" + + +class DeniedByPolicy(DNSException): + """Denied by DNSSEC policy.""" + + class ExceptionWrapper: def __init__(self, exception_class): self.exception_class = exception_class diff --git a/lib/dns/flags.py b/lib/dns/flags.py index b21b8e3b..4c60be13 100644 --- a/lib/dns/flags.py +++ b/lib/dns/flags.py @@ -17,9 +17,8 @@ """DNS Message Flags.""" -from typing import Any - import enum +from typing import Any # Standard DNS flags diff --git a/lib/dns/immutable.py b/lib/dns/immutable.py index 38fbe597..cab8d6fb 100644 --- a/lib/dns/immutable.py +++ b/lib/dns/immutable.py @@ -1,8 +1,7 @@ # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license -from typing import Any - import collections.abc +from typing import Any from dns._immutable_ctx import immutable diff --git a/lib/dns/inet.py b/lib/dns/inet.py index 11180c96..02e925c6 100644 --- a/lib/dns/inet.py +++ b/lib/dns/inet.py @@ -17,14 +17,12 @@ """Generic Internet address helper functions.""" -from typing import Any, Optional, Tuple - import socket +from typing import Any, Optional, Tuple import dns.ipv4 import dns.ipv6 - # We assume that AF_INET and AF_INET6 are always defined. We keep # these here for the benefit of any old code (unlikely though that # is!). @@ -171,3 +169,12 @@ def low_level_address_tuple( return tup else: raise NotImplementedError(f"unknown address family {af}") + + +def any_for_af(af): + """Return the 'any' address for the specified address family.""" + if af == socket.AF_INET: + return "0.0.0.0" + elif af == socket.AF_INET6: + return "::" + raise NotImplementedError(f"unknown address family {af}") diff --git a/lib/dns/ipv4.py b/lib/dns/ipv4.py index b8e148f3..f549150a 100644 --- a/lib/dns/ipv4.py +++ b/lib/dns/ipv4.py @@ -17,9 +17,8 @@ """IPv4 helper functions.""" -from typing import Union - import struct +from typing import Union import dns.exception diff --git a/lib/dns/ipv6.py b/lib/dns/ipv6.py index fbd49623..0cc3d868 100644 --- a/lib/dns/ipv6.py +++ b/lib/dns/ipv6.py @@ -17,10 +17,9 @@ """IPv6 helper functions.""" -from typing import List, Union - -import re import binascii +import re +from typing import List, Union import dns.exception import dns.ipv4 diff --git a/lib/dns/message.py b/lib/dns/message.py index 8250db3b..daae6363 100644 --- a/lib/dns/message.py +++ b/lib/dns/message.py @@ -17,30 +17,29 @@ """DNS Messages""" -from typing import Any, Dict, List, Optional, Tuple, Union - import contextlib import io import time +from typing import Any, Dict, List, Optional, Tuple, Union -import dns.wire import dns.edns +import dns.entropy import dns.enum import dns.exception import dns.flags import dns.name import dns.opcode -import dns.entropy import dns.rcode import dns.rdata import dns.rdataclass import dns.rdatatype -import dns.rrset -import dns.renderer -import dns.ttl -import dns.tsig import dns.rdtypes.ANY.OPT import dns.rdtypes.ANY.TSIG +import dns.renderer +import dns.rrset +import dns.tsig +import dns.ttl +import dns.wire class ShortHeader(dns.exception.FormError): @@ -135,7 +134,7 @@ IndexKeyType = Tuple[ Optional[dns.rdataclass.RdataClass], ] IndexType = Dict[IndexKeyType, dns.rrset.RRset] -SectionType = Union[int, List[dns.rrset.RRset]] +SectionType = Union[int, str, List[dns.rrset.RRset]] class Message: @@ -231,7 +230,7 @@ class Message: s.write("payload %d\n" % self.payload) for opt in self.options: s.write("option %s\n" % opt.to_text()) - for (name, which) in self._section_enum.__members__.items(): + for name, which in self._section_enum.__members__.items(): s.write(f";{name}\n") for rrset in self.section_from_number(which): s.write(rrset.to_text(origin, relativize, **kw)) @@ -348,27 +347,29 @@ class Message: deleting: Optional[dns.rdataclass.RdataClass] = None, create: bool = False, force_unique: bool = False, + idna_codec: Optional[dns.name.IDNACodec] = None, ) -> dns.rrset.RRset: """Find the RRset with the given attributes in the specified section. - *section*, an ``int`` section number, or one of the section - attributes of this message. This specifies the + *section*, an ``int`` section number, a ``str`` section name, or one of + the section attributes of this message. This specifies the the section of the message to search. For example:: my_message.find_rrset(my_message.answer, name, rdclass, rdtype) my_message.find_rrset(dns.message.ANSWER, name, rdclass, rdtype) + my_message.find_rrset("ANSWER", name, rdclass, rdtype) - *name*, a ``dns.name.Name``, the name of the RRset. + *name*, a ``dns.name.Name`` or ``str``, the name of the RRset. - *rdclass*, an ``int``, the class of the RRset. + *rdclass*, an ``int`` or ``str``, the class of the RRset. - *rdtype*, an ``int``, the type of the RRset. + *rdtype*, an ``int`` or ``str``, the type of the RRset. - *covers*, an ``int`` or ``None``, the covers value of the RRset. - The default is ``None``. + *covers*, an ``int`` or ``str``, the covers value of the RRset. + The default is ``dns.rdatatype.NONE``. - *deleting*, an ``int`` or ``None``, the deleting value of the RRset. - The default is ``None``. + *deleting*, an ``int``, ``str``, or ``None``, the deleting value of the + RRset. The default is ``None``. *create*, a ``bool``. If ``True``, create the RRset if it is not found. The created RRset is appended to *section*. @@ -378,6 +379,10 @@ class Message: already. The default is ``False``. This is useful when creating DDNS Update messages, as order matters for them. + *idna_codec*, a ``dns.name.IDNACodec``, specifies the IDNA + encoder/decoder. If ``None``, the default IDNA 2003 encoder/decoder + is used. + Raises ``KeyError`` if the RRset was not found and create was ``False``. @@ -386,10 +391,19 @@ class Message: if isinstance(section, int): section_number = section - the_section = self.section_from_number(section_number) + section = self.section_from_number(section_number) + elif isinstance(section, str): + section_number = MessageSection.from_text(section) + section = self.section_from_number(section_number) else: section_number = self.section_number(section) - the_section = section + if isinstance(name, str): + name = dns.name.from_text(name, idna_codec=idna_codec) + rdtype = dns.rdatatype.RdataType.make(rdtype) + rdclass = dns.rdataclass.RdataClass.make(rdclass) + covers = dns.rdatatype.RdataType.make(covers) + if deleting is not None: + deleting = dns.rdataclass.RdataClass.make(deleting) key = (section_number, name, rdclass, rdtype, covers, deleting) if not force_unique: if self.index is not None: @@ -397,13 +411,13 @@ class Message: if rrset is not None: return rrset else: - for rrset in the_section: + for rrset in section: if rrset.full_match(name, rdclass, rdtype, covers, deleting): return rrset if not create: raise KeyError rrset = dns.rrset.RRset(name, rdclass, rdtype, covers, deleting) - the_section.append(rrset) + section.append(rrset) if self.index is not None: self.index[key] = rrset return rrset @@ -418,29 +432,31 @@ class Message: deleting: Optional[dns.rdataclass.RdataClass] = None, create: bool = False, force_unique: bool = False, + idna_codec: Optional[dns.name.IDNACodec] = None, ) -> Optional[dns.rrset.RRset]: """Get the RRset with the given attributes in the specified section. If the RRset is not found, None is returned. - *section*, an ``int`` section number, or one of the section - attributes of this message. This specifies the + *section*, an ``int`` section number, a ``str`` section name, or one of + the section attributes of this message. This specifies the the section of the message to search. For example:: my_message.get_rrset(my_message.answer, name, rdclass, rdtype) my_message.get_rrset(dns.message.ANSWER, name, rdclass, rdtype) + my_message.get_rrset("ANSWER", name, rdclass, rdtype) - *name*, a ``dns.name.Name``, the name of the RRset. + *name*, a ``dns.name.Name`` or ``str``, the name of the RRset. - *rdclass*, an ``int``, the class of the RRset. + *rdclass*, an ``int`` or ``str``, the class of the RRset. - *rdtype*, an ``int``, the type of the RRset. + *rdtype*, an ``int`` or ``str``, the type of the RRset. - *covers*, an ``int`` or ``None``, the covers value of the RRset. - The default is ``None``. + *covers*, an ``int`` or ``str``, the covers value of the RRset. + The default is ``dns.rdatatype.NONE``. - *deleting*, an ``int`` or ``None``, the deleting value of the RRset. - The default is ``None``. + *deleting*, an ``int``, ``str``, or ``None``, the deleting value of the + RRset. The default is ``None``. *create*, a ``bool``. If ``True``, create the RRset if it is not found. The created RRset is appended to *section*. @@ -450,12 +466,24 @@ class Message: already. The default is ``False``. This is useful when creating DDNS Update messages, as order matters for them. + *idna_codec*, a ``dns.name.IDNACodec``, specifies the IDNA + encoder/decoder. If ``None``, the default IDNA 2003 encoder/decoder + is used. + Returns a ``dns.rrset.RRset object`` or ``None``. """ try: rrset = self.find_rrset( - section, name, rdclass, rdtype, covers, deleting, create, force_unique + section, + name, + rdclass, + rdtype, + covers, + deleting, + create, + force_unique, + idna_codec, ) except KeyError: rrset = None @@ -1708,13 +1736,11 @@ def make_query( if isinstance(qname, str): qname = dns.name.from_text(qname, idna_codec=idna_codec) - the_rdtype = dns.rdatatype.RdataType.make(rdtype) - the_rdclass = dns.rdataclass.RdataClass.make(rdclass) + rdtype = dns.rdatatype.RdataType.make(rdtype) + rdclass = dns.rdataclass.RdataClass.make(rdclass) m = QueryMessage(id=id) m.flags = dns.flags.Flag(flags) - m.find_rrset( - m.question, qname, the_rdclass, the_rdtype, create=True, force_unique=True - ) + m.find_rrset(m.question, qname, rdclass, rdtype, create=True, force_unique=True) # only pass keywords on to use_edns if they have been set to a # non-None value. Setting a field will turn EDNS on if it hasn't # been configured. diff --git a/lib/dns/name.py b/lib/dns/name.py index 612af021..f452bfed 100644 --- a/lib/dns/name.py +++ b/lib/dns/name.py @@ -18,12 +18,10 @@ """DNS Names. """ -from typing import Any, Dict, Iterable, Optional, Tuple, Union - import copy -import struct - import encodings.idna # type: ignore +import struct +from typing import Any, Dict, Iterable, Optional, Tuple, Union try: import idna # type: ignore @@ -33,10 +31,9 @@ except ImportError: # pragma: no cover have_idna_2008 = False import dns.enum -import dns.wire import dns.exception import dns.immutable - +import dns.wire CompressType = Dict["Name", int] diff --git a/lib/dns/nameserver.py b/lib/dns/nameserver.py new file mode 100644 index 00000000..5910139e --- /dev/null +++ b/lib/dns/nameserver.py @@ -0,0 +1,329 @@ +from typing import Optional, Union +from urllib.parse import urlparse + +import dns.asyncbackend +import dns.asyncquery +import dns.inet +import dns.message +import dns.query + + +class Nameserver: + def __init__(self): + pass + + def __str__(self): + raise NotImplementedError + + def kind(self) -> str: + raise NotImplementedError + + def is_always_max_size(self) -> bool: + raise NotImplementedError + + def answer_nameserver(self) -> str: + raise NotImplementedError + + def answer_port(self) -> int: + raise NotImplementedError + + def query( + self, + request: dns.message.QueryMessage, + timeout: float, + source: Optional[str], + source_port: int, + max_size: bool, + one_rr_per_rrset: bool = False, + ignore_trailing: bool = False, + ) -> dns.message.Message: + raise NotImplementedError + + async def async_query( + self, + request: dns.message.QueryMessage, + timeout: float, + source: Optional[str], + source_port: int, + max_size: bool, + backend: dns.asyncbackend.Backend, + one_rr_per_rrset: bool = False, + ignore_trailing: bool = False, + ) -> dns.message.Message: + raise NotImplementedError + + +class AddressAndPortNameserver(Nameserver): + def __init__(self, address: str, port: int): + super().__init__() + self.address = address + self.port = port + + def kind(self) -> str: + raise NotImplementedError + + def is_always_max_size(self) -> bool: + return False + + def __str__(self): + ns_kind = self.kind() + return f"{ns_kind}:{self.address}@{self.port}" + + def answer_nameserver(self) -> str: + return self.address + + def answer_port(self) -> int: + return self.port + + +class Do53Nameserver(AddressAndPortNameserver): + def __init__(self, address: str, port: int = 53): + super().__init__(address, port) + + def kind(self): + return "Do53" + + def query( + self, + request: dns.message.QueryMessage, + timeout: float, + source: Optional[str], + source_port: int, + max_size: bool, + one_rr_per_rrset: bool = False, + ignore_trailing: bool = False, + ) -> dns.message.Message: + if max_size: + response = dns.query.tcp( + request, + self.address, + timeout=timeout, + port=self.port, + source=source, + source_port=source_port, + one_rr_per_rrset=one_rr_per_rrset, + ignore_trailing=ignore_trailing, + ) + else: + response = dns.query.udp( + request, + self.address, + timeout=timeout, + port=self.port, + source=source, + source_port=source_port, + raise_on_truncation=True, + one_rr_per_rrset=one_rr_per_rrset, + ignore_trailing=ignore_trailing, + ) + return response + + async def async_query( + self, + request: dns.message.QueryMessage, + timeout: float, + source: Optional[str], + source_port: int, + max_size: bool, + backend: dns.asyncbackend.Backend, + one_rr_per_rrset: bool = False, + ignore_trailing: bool = False, + ) -> dns.message.Message: + if max_size: + response = await dns.asyncquery.tcp( + request, + self.address, + timeout=timeout, + port=self.port, + source=source, + source_port=source_port, + backend=backend, + one_rr_per_rrset=one_rr_per_rrset, + ignore_trailing=ignore_trailing, + ) + else: + response = await dns.asyncquery.udp( + request, + self.address, + timeout=timeout, + port=self.port, + source=source, + source_port=source_port, + raise_on_truncation=True, + backend=backend, + one_rr_per_rrset=one_rr_per_rrset, + ignore_trailing=ignore_trailing, + ) + return response + + +class DoHNameserver(Nameserver): + def __init__(self, url: str, bootstrap_address: Optional[str] = None): + super().__init__() + self.url = url + self.bootstrap_address = bootstrap_address + + def kind(self): + return "DoH" + + def is_always_max_size(self) -> bool: + return True + + def __str__(self): + return self.url + + def answer_nameserver(self) -> str: + return self.url + + def answer_port(self) -> int: + port = urlparse(self.url).port + if port is None: + port = 443 + return port + + def query( + self, + request: dns.message.QueryMessage, + timeout: float, + source: Optional[str], + source_port: int, + max_size: bool = False, + one_rr_per_rrset: bool = False, + ignore_trailing: bool = False, + ) -> dns.message.Message: + return dns.query.https( + request, + self.url, + timeout=timeout, + bootstrap_address=self.bootstrap_address, + one_rr_per_rrset=one_rr_per_rrset, + ignore_trailing=ignore_trailing, + ) + + async def async_query( + self, + request: dns.message.QueryMessage, + timeout: float, + source: Optional[str], + source_port: int, + max_size: bool, + backend: dns.asyncbackend.Backend, + one_rr_per_rrset: bool = False, + ignore_trailing: bool = False, + ) -> dns.message.Message: + return await dns.asyncquery.https( + request, + self.url, + timeout=timeout, + one_rr_per_rrset=one_rr_per_rrset, + ignore_trailing=ignore_trailing, + ) + + +class DoTNameserver(AddressAndPortNameserver): + def __init__(self, address: str, port: int = 853, hostname: Optional[str] = None): + super().__init__(address, port) + self.hostname = hostname + + def kind(self): + return "DoT" + + def query( + self, + request: dns.message.QueryMessage, + timeout: float, + source: Optional[str], + source_port: int, + max_size: bool = False, + one_rr_per_rrset: bool = False, + ignore_trailing: bool = False, + ) -> dns.message.Message: + return dns.query.tls( + request, + self.address, + port=self.port, + timeout=timeout, + one_rr_per_rrset=one_rr_per_rrset, + ignore_trailing=ignore_trailing, + server_hostname=self.hostname, + ) + + async def async_query( + self, + request: dns.message.QueryMessage, + timeout: float, + source: Optional[str], + source_port: int, + max_size: bool, + backend: dns.asyncbackend.Backend, + one_rr_per_rrset: bool = False, + ignore_trailing: bool = False, + ) -> dns.message.Message: + return await dns.asyncquery.tls( + request, + self.address, + port=self.port, + timeout=timeout, + one_rr_per_rrset=one_rr_per_rrset, + ignore_trailing=ignore_trailing, + server_hostname=self.hostname, + ) + + +class DoQNameserver(AddressAndPortNameserver): + def __init__( + self, + address: str, + port: int = 853, + verify: Union[bool, str] = True, + server_hostname: Optional[str] = None, + ): + super().__init__(address, port) + self.verify = verify + self.server_hostname = server_hostname + + def kind(self): + return "DoQ" + + def query( + self, + request: dns.message.QueryMessage, + timeout: float, + source: Optional[str], + source_port: int, + max_size: bool = False, + one_rr_per_rrset: bool = False, + ignore_trailing: bool = False, + ) -> dns.message.Message: + return dns.query.quic( + request, + self.address, + port=self.port, + timeout=timeout, + one_rr_per_rrset=one_rr_per_rrset, + ignore_trailing=ignore_trailing, + verify=self.verify, + server_hostname=self.server_hostname, + ) + + async def async_query( + self, + request: dns.message.QueryMessage, + timeout: float, + source: Optional[str], + source_port: int, + max_size: bool, + backend: dns.asyncbackend.Backend, + one_rr_per_rrset: bool = False, + ignore_trailing: bool = False, + ) -> dns.message.Message: + return await dns.asyncquery.quic( + request, + self.address, + port=self.port, + timeout=timeout, + one_rr_per_rrset=one_rr_per_rrset, + ignore_trailing=ignore_trailing, + verify=self.verify, + server_hostname=self.server_hostname, + ) diff --git a/lib/dns/node.py b/lib/dns/node.py index 22bbe7cb..c670243c 100644 --- a/lib/dns/node.py +++ b/lib/dns/node.py @@ -17,19 +17,17 @@ """DNS nodes. A node is a set of rdatasets.""" -from typing import Any, Dict, Optional - import enum import io +from typing import Any, Dict, Optional import dns.immutable import dns.name import dns.rdataclass import dns.rdataset import dns.rdatatype -import dns.rrset import dns.renderer - +import dns.rrset _cname_types = { dns.rdatatype.CNAME, diff --git a/lib/dns/query.py b/lib/dns/query.py index b4cd69f7..0d711251 100644 --- a/lib/dns/query.py +++ b/lib/dns/query.py @@ -17,8 +17,6 @@ """Talk to a DNS server.""" -from typing import Any, Dict, Optional, Tuple, Union - import base64 import contextlib import enum @@ -28,12 +26,12 @@ import selectors import socket import struct import time -import urllib.parse +from typing import Any, Dict, Optional, Tuple, Union import dns.exception import dns.inet -import dns.name import dns.message +import dns.name import dns.quic import dns.rcode import dns.rdataclass @@ -43,20 +41,32 @@ import dns.transaction import dns.tsig import dns.xfr -try: - import requests - from requests_toolbelt.adapters.source import SourceAddressAdapter - from requests_toolbelt.adapters.host_header_ssl import HostHeaderSSLAdapter - _have_requests = True -except ImportError: # pragma: no cover - _have_requests = False +def _remaining(expiration): + if expiration is None: + return None + timeout = expiration - time.time() + if timeout <= 0.0: + raise dns.exception.Timeout + return timeout + + +def _expiration_for_this_attempt(timeout, expiration): + if expiration is None: + return None + return min(time.time() + timeout, expiration) + _have_httpx = False _have_http2 = False try: + import httpcore + import httpcore._backends.sync import httpx + _CoreNetworkBackend = httpcore.NetworkBackend + _CoreSyncStream = httpcore._backends.sync.SyncStream + _have_httpx = True try: # See if http2 support is available. @@ -64,10 +74,87 @@ try: _have_http2 = True except Exception: pass -except ImportError: # pragma: no cover - pass -have_doh = _have_requests or _have_httpx + class _NetworkBackend(_CoreNetworkBackend): + def __init__(self, resolver, local_port, bootstrap_address, family): + super().__init__() + self._local_port = local_port + self._resolver = resolver + self._bootstrap_address = bootstrap_address + self._family = family + + def connect_tcp( + self, host, port, timeout, local_address, socket_options=None + ): # pylint: disable=signature-differs + addresses = [] + _, expiration = _compute_times(timeout) + if dns.inet.is_address(host): + addresses.append(host) + elif self._bootstrap_address is not None: + addresses.append(self._bootstrap_address) + else: + timeout = _remaining(expiration) + family = self._family + if local_address: + family = dns.inet.af_for_address(local_address) + answers = self._resolver.resolve_name( + host, family=family, lifetime=timeout + ) + addresses = answers.addresses() + for address in addresses: + af = dns.inet.af_for_address(address) + if local_address is not None or self._local_port != 0: + source = dns.inet.low_level_address_tuple( + (local_address, self._local_port), af + ) + else: + source = None + sock = _make_socket(af, socket.SOCK_STREAM, source) + attempt_expiration = _expiration_for_this_attempt(2.0, expiration) + try: + _connect( + sock, + dns.inet.low_level_address_tuple((address, port), af), + attempt_expiration, + ) + return _CoreSyncStream(sock) + except Exception: + pass + raise httpcore.ConnectError + + def connect_unix_socket( + self, path, timeout, socket_options=None + ): # pylint: disable=signature-differs + raise NotImplementedError + + class _HTTPTransport(httpx.HTTPTransport): + def __init__( + self, + *args, + local_port=0, + bootstrap_address=None, + resolver=None, + family=socket.AF_UNSPEC, + **kwargs, + ): + if resolver is None: + # pylint: disable=import-outside-toplevel,redefined-outer-name + import dns.resolver + + resolver = dns.resolver.Resolver() + super().__init__(*args, **kwargs) + self._pool._network_backend = _NetworkBackend( + resolver, local_port, bootstrap_address, family + ) + +except ImportError: # pragma: no cover + + class _HTTPTransport: # type: ignore + def connect_tcp(self, host, port, timeout, local_address): + raise NotImplementedError + + +have_doh = _have_httpx try: import ssl @@ -88,7 +175,7 @@ except ImportError: # pragma: no cover @classmethod def create_default_context(cls, *args, **kwargs): - raise Exception("no ssl support") + raise Exception("no ssl support") # pylint: disable=broad-exception-raised # Function used to create a socket. Can be overridden if needed in special @@ -105,7 +192,7 @@ class BadResponse(dns.exception.FormError): class NoDOH(dns.exception.DNSException): - """DNS over HTTPS (DOH) was requested but the requests module is not + """DNS over HTTPS (DOH) was requested but the httpx module is not available.""" @@ -230,7 +317,7 @@ def _destination_and_source( # We know the destination af, so source had better agree! if saf != af: raise ValueError( - "different address families for source " + "and destination" + "different address families for source and destination" ) else: # We didn't know the destination af, but we know the source, @@ -240,11 +327,10 @@ def _destination_and_source( # Caller has specified a source_port but not an address, so we # need to return a source, and we need to use the appropriate # wildcard address as the address. - if af == socket.AF_INET: - source = "0.0.0.0" - elif af == socket.AF_INET6: - source = "::" - else: + try: + source = dns.inet.any_for_af(af) + except Exception: + # we catch this and raise ValueError for backwards compatibility raise ValueError("source_port specified but address family is unknown") # Convert high-level (address, port) tuples into low-level address # tuples. @@ -289,6 +375,8 @@ def https( post: bool = True, bootstrap_address: Optional[str] = None, verify: Union[bool, str] = True, + resolver: Optional["dns.resolver.Resolver"] = None, + family: Optional[int] = socket.AF_UNSPEC, ) -> dns.message.Message: """Return the response obtained after sending a query via DNS-over-HTTPS. @@ -314,91 +402,78 @@ def https( *ignore_trailing*, a ``bool``. If ``True``, ignore trailing junk at end of the received message. - *session*, an ``httpx.Client`` or ``requests.session.Session``. If provided, the - client/session to use to send the queries. + *session*, an ``httpx.Client``. If provided, the client session to use to send the + queries. *path*, a ``str``. If *where* is an IP address, then *path* will be used to construct the URL to send the DNS query to. *post*, a ``bool``. If ``True``, the default, POST method will be used. - *bootstrap_address*, a ``str``, the IP address to use to bypass the system's DNS - resolver. + *bootstrap_address*, a ``str``, the IP address to use to bypass resolution. *verify*, a ``bool`` or ``str``. If a ``True``, then TLS certificate verification of the server is done using the default CA bundle; if ``False``, then no verification is done; if a `str` then it specifies the path to a certificate file or directory which will be used for verification. + *resolver*, a ``dns.resolver.Resolver`` or ``None``, the resolver to use for + resolution of hostnames in URLs. If not specified, a new resolver with a default + configuration will be used; note this is *not* the default resolver as that resolver + might have been configured to use DoH causing a chicken-and-egg problem. This + parameter only has an effect if the HTTP library is httpx. + + *family*, an ``int``, the address family. If socket.AF_UNSPEC (the default), both A + and AAAA records will be retrieved. + Returns a ``dns.message.Message``. """ if not have_doh: - raise NoDOH("Neither httpx nor requests is available.") # pragma: no cover - - _httpx_ok = _have_httpx + raise NoDOH # pragma: no cover + if session and not isinstance(session, httpx.Client): + raise ValueError("session parameter must be an httpx.Client") wire = q.to_wire() - (af, _, source) = _destination_and_source(where, port, source, source_port, False) - transport_adapter = None + (af, _, the_source) = _destination_and_source( + where, port, source, source_port, False + ) transport = None headers = {"accept": "application/dns-message"} - if af is not None: + if af is not None and dns.inet.is_address(where): if af == socket.AF_INET: url = "https://{}:{}{}".format(where, port, path) elif af == socket.AF_INET6: url = "https://[{}]:{}{}".format(where, port, path) - elif bootstrap_address is not None: - _httpx_ok = False - split_url = urllib.parse.urlsplit(where) - if split_url.hostname is None: - raise ValueError("DoH URL has no hostname") - headers["Host"] = split_url.hostname - url = where.replace(split_url.hostname, bootstrap_address) - if _have_requests: - transport_adapter = HostHeaderSSLAdapter() else: url = where - if source is not None: - # set source port and source address - if _have_httpx: - if source_port == 0: - transport = httpx.HTTPTransport(local_address=source[0], verify=verify) - else: - _httpx_ok = False - if _have_requests: - transport_adapter = SourceAddressAdapter(source) - if session: - if _have_httpx: - _is_httpx = isinstance(session, httpx.Client) - else: - _is_httpx = False - if _is_httpx and not _httpx_ok: - raise NoDOH( - "Session is httpx, but httpx cannot be used for " - "the requested operation." - ) + # set source port and source address + + if the_source is None: + local_address = None + local_port = 0 else: - _is_httpx = _httpx_ok - - if not _httpx_ok and not _have_requests: - raise NoDOH( - "Cannot use httpx for this operation, and requests is not available." - ) + local_address = the_source[0] + local_port = the_source[1] + transport = _HTTPTransport( + local_address=local_address, + http1=True, + http2=_have_http2, + verify=verify, + local_port=local_port, + bootstrap_address=bootstrap_address, + resolver=resolver, + family=family, + ) if session: cm: contextlib.AbstractContextManager = contextlib.nullcontext(session) - elif _is_httpx: + else: cm = httpx.Client( http1=True, http2=_have_http2, verify=verify, transport=transport ) - else: - cm = requests.sessions.Session() with cm as session: - if transport_adapter and not _is_httpx: - session.mount(url, transport_adapter) - # see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH # GET and POST examples if post: @@ -408,29 +483,13 @@ def https( "content-length": str(len(wire)), } ) - if _is_httpx: - response = session.post( - url, headers=headers, content=wire, timeout=timeout - ) - else: - response = session.post( - url, headers=headers, data=wire, timeout=timeout, verify=verify - ) + response = session.post(url, headers=headers, content=wire, timeout=timeout) else: wire = base64.urlsafe_b64encode(wire).rstrip(b"=") - if _is_httpx: - twire = wire.decode() # httpx does a repr() if we give it bytes - response = session.get( - url, headers=headers, timeout=timeout, params={"dns": twire} - ) - else: - response = session.get( - url, - headers=headers, - timeout=timeout, - verify=verify, - params={"dns": wire}, - ) + twire = wire.decode() # httpx does a repr() if we give it bytes + response = session.get( + url, headers=headers, timeout=timeout, params={"dns": twire} + ) # see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH # status codes @@ -1070,6 +1129,7 @@ def quic( ignore_trailing: bool = False, connection: Optional[dns.quic.SyncQuicConnection] = None, verify: Union[bool, str] = True, + server_hostname: Optional[str] = None, ) -> dns.message.Message: """Return the response obtained after sending a query via DNS-over-QUIC. @@ -1101,6 +1161,10 @@ def quic( verification is done; if a `str` then it specifies the path to a certificate file or directory which will be used for verification. + *server_hostname*, a ``str`` containing the server's hostname. The + default is ``None``, which means that no hostname is known, and if an + SSL context is created, hostname checking will be disabled. + Returns a ``dns.message.Message``. """ @@ -1115,16 +1179,18 @@ def quic( manager: contextlib.AbstractContextManager = contextlib.nullcontext(None) the_connection = connection else: - manager = dns.quic.SyncQuicManager(verify_mode=verify) + manager = dns.quic.SyncQuicManager( + verify_mode=verify, server_name=server_hostname + ) the_manager = manager # for type checking happiness with manager: if not connection: the_connection = the_manager.connect(where, port, source, source_port) - start = time.time() - with the_connection.make_stream() as stream: + (start, expiration) = _compute_times(timeout) + with the_connection.make_stream(timeout) as stream: stream.send(wire, True) - wire = stream.receive(timeout) + wire = stream.receive(_remaining(expiration)) finish = time.time() r = dns.message.from_wire( wire, diff --git a/lib/dns/quic/__init__.py b/lib/dns/quic/__init__.py index f48ecf57..69813f9f 100644 --- a/lib/dns/quic/__init__.py +++ b/lib/dns/quic/__init__.py @@ -5,13 +5,13 @@ try: import dns.asyncbackend from dns._asyncbackend import NullContext - from dns.quic._sync import SyncQuicManager, SyncQuicConnection, SyncQuicStream from dns.quic._asyncio import ( - AsyncioQuicManager, AsyncioQuicConnection, + AsyncioQuicManager, AsyncioQuicStream, ) from dns.quic._common import AsyncQuicConnection, AsyncQuicManager + from dns.quic._sync import SyncQuicConnection, SyncQuicManager, SyncQuicStream have_quic = True @@ -33,9 +33,10 @@ try: try: import trio + from dns.quic._trio import ( # pylint: disable=ungrouped-imports - TrioQuicManager, TrioQuicConnection, + TrioQuicManager, TrioQuicStream, ) diff --git a/lib/dns/quic/_asyncio.py b/lib/dns/quic/_asyncio.py index 0a2e220d..e1c52339 100644 --- a/lib/dns/quic/_asyncio.py +++ b/lib/dns/quic/_asyncio.py @@ -9,14 +9,16 @@ import time import aioquic.quic.configuration # type: ignore import aioquic.quic.connection # type: ignore import aioquic.quic.events # type: ignore -import dns.inet -import dns.asyncbackend +import dns.asyncbackend +import dns.exception +import dns.inet from dns.quic._common import ( - BaseQuicStream, + QUIC_MAX_DATAGRAM, AsyncQuicConnection, AsyncQuicManager, - QUIC_MAX_DATAGRAM, + BaseQuicStream, + UnexpectedEOF, ) @@ -30,15 +32,15 @@ class AsyncioQuicStream(BaseQuicStream): await self._wake_up.wait() async def wait_for(self, amount, expiration): - timeout = self._timeout_from_expiration(expiration) while True: + timeout = self._timeout_from_expiration(expiration) if self._buffer.have(amount): return self._expecting = amount try: await asyncio.wait_for(self._wait_for_wake_up(), timeout) - except Exception: - pass + except TimeoutError: + raise dns.exception.Timeout self._expecting = 0 async def receive(self, timeout=None): @@ -86,8 +88,10 @@ class AsyncioQuicConnection(AsyncQuicConnection): try: af = dns.inet.af_for_address(self._address) backend = dns.asyncbackend.get_backend("asyncio") + # Note that peer is a low-level address tuple, but make_socket() wants + # a high-level address tuple, so we convert. self._socket = await backend.make_socket( - af, socket.SOCK_DGRAM, 0, self._source, self._peer + af, socket.SOCK_DGRAM, 0, self._source, (self._peer[0], self._peer[1]) ) self._socket_created.set() async with self._socket: @@ -106,6 +110,11 @@ class AsyncioQuicConnection(AsyncQuicConnection): self._wake_timer.notify_all() except Exception: pass + finally: + self._done = True + async with self._wake_timer: + self._wake_timer.notify_all() + self._handshake_complete.set() async def _wait_for_wake_timer(self): async with self._wake_timer: @@ -115,7 +124,7 @@ class AsyncioQuicConnection(AsyncQuicConnection): await self._socket_created.wait() while not self._done: datagrams = self._connection.datagrams_to_send(time.time()) - for (datagram, address) in datagrams: + for datagram, address in datagrams: assert address == self._peer[0] await self._socket.sendto(datagram, self._peer, None) (expiration, interval) = self._get_timer_values() @@ -160,8 +169,13 @@ class AsyncioQuicConnection(AsyncQuicConnection): self._receiver_task = asyncio.Task(self._receiver()) self._sender_task = asyncio.Task(self._sender()) - async def make_stream(self): - await self._handshake_complete.wait() + async def make_stream(self, timeout=None): + try: + await asyncio.wait_for(self._handshake_complete.wait(), timeout) + except TimeoutError: + raise dns.exception.Timeout + if self._done: + raise UnexpectedEOF stream_id = self._connection.get_next_available_stream_id(False) stream = AsyncioQuicStream(self, stream_id) self._streams[stream_id] = stream @@ -172,6 +186,9 @@ class AsyncioQuicConnection(AsyncQuicConnection): self._manager.closed(self._peer[0], self._peer[1]) self._closed = True self._connection.close() + # sender might be blocked on this, so set it + self._socket_created.set() + await self._socket.close() async with self._wake_timer: self._wake_timer.notify_all() try: @@ -185,8 +202,8 @@ class AsyncioQuicConnection(AsyncQuicConnection): class AsyncioQuicManager(AsyncQuicManager): - def __init__(self, conf=None, verify_mode=ssl.CERT_REQUIRED): - super().__init__(conf, verify_mode, AsyncioQuicConnection) + def __init__(self, conf=None, verify_mode=ssl.CERT_REQUIRED, server_name=None): + super().__init__(conf, verify_mode, AsyncioQuicConnection, server_name) def connect(self, address, port=853, source=None, source_port=0): (connection, start) = self._connect(address, port, source, source_port) @@ -198,7 +215,7 @@ class AsyncioQuicManager(AsyncQuicManager): return self async def __aexit__(self, exc_type, exc_val, exc_tb): - # Copy the itertor into a list as exiting things will mutate the connections + # Copy the iterator into a list as exiting things will mutate the connections # table. connections = list(self._connections.values()) for connection in connections: diff --git a/lib/dns/quic/_common.py b/lib/dns/quic/_common.py index d8f6f7fd..38ec103f 100644 --- a/lib/dns/quic/_common.py +++ b/lib/dns/quic/_common.py @@ -3,13 +3,12 @@ import socket import struct import time - -from typing import Any +from typing import Any, Optional import aioquic.quic.configuration # type: ignore import aioquic.quic.connection # type: ignore -import dns.inet +import dns.inet QUIC_MAX_DATAGRAM = 2048 @@ -135,12 +134,12 @@ class BaseQuicConnection: class AsyncQuicConnection(BaseQuicConnection): - async def make_stream(self) -> Any: + async def make_stream(self, timeout: Optional[float] = None) -> Any: pass class BaseQuicManager: - def __init__(self, conf, verify_mode, connection_factory): + def __init__(self, conf, verify_mode, connection_factory, server_name=None): self._connections = {} self._connection_factory = connection_factory if conf is None: @@ -151,6 +150,7 @@ class BaseQuicManager: conf = aioquic.quic.configuration.QuicConfiguration( alpn_protocols=["doq", "doq-i03"], verify_mode=verify_mode, + server_name=server_name, ) if verify_path is not None: conf.load_verify_locations(verify_path) diff --git a/lib/dns/quic/_sync.py b/lib/dns/quic/_sync.py index be005ba9..e944784d 100644 --- a/lib/dns/quic/_sync.py +++ b/lib/dns/quic/_sync.py @@ -1,8 +1,8 @@ # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license +import selectors import socket import ssl -import selectors import struct import threading import time @@ -10,13 +10,15 @@ import time import aioquic.quic.configuration # type: ignore import aioquic.quic.connection # type: ignore import aioquic.quic.events # type: ignore -import dns.inet +import dns.exception +import dns.inet from dns.quic._common import ( - BaseQuicStream, + QUIC_MAX_DATAGRAM, BaseQuicConnection, BaseQuicManager, - QUIC_MAX_DATAGRAM, + BaseQuicStream, + UnexpectedEOF, ) # Avoid circularity with dns.query @@ -33,14 +35,15 @@ class SyncQuicStream(BaseQuicStream): self._lock = threading.Lock() def wait_for(self, amount, expiration): - timeout = self._timeout_from_expiration(expiration) while True: + timeout = self._timeout_from_expiration(expiration) with self._lock: if self._buffer.have(amount): return self._expecting = amount with self._wake_up: - self._wake_up.wait(timeout) + if not self._wake_up.wait(timeout): + raise dns.exception.Timeout self._expecting = 0 def receive(self, timeout=None): @@ -114,24 +117,30 @@ class SyncQuicConnection(BaseQuicConnection): return def _worker(self): - sel = _selector_class() - sel.register(self._socket, selectors.EVENT_READ, self._read) - sel.register(self._receive_wakeup, selectors.EVENT_READ, self._drain_wakeup) - while not self._done: - (expiration, interval) = self._get_timer_values(False) - items = sel.select(interval) - for (key, _) in items: - key.data() + try: + sel = _selector_class() + sel.register(self._socket, selectors.EVENT_READ, self._read) + sel.register(self._receive_wakeup, selectors.EVENT_READ, self._drain_wakeup) + while not self._done: + (expiration, interval) = self._get_timer_values(False) + items = sel.select(interval) + for key, _ in items: + key.data() + with self._lock: + self._handle_timer(expiration) + datagrams = self._connection.datagrams_to_send(time.time()) + for datagram, _ in datagrams: + try: + self._socket.send(datagram) + except BlockingIOError: + # we let QUIC handle any lossage + pass + self._handle_events() + finally: with self._lock: - self._handle_timer(expiration) - datagrams = self._connection.datagrams_to_send(time.time()) - for (datagram, _) in datagrams: - try: - self._socket.send(datagram) - except BlockingIOError: - # we let QUIC handle any lossage - pass - self._handle_events() + self._done = True + # Ensure anyone waiting for this gets woken up. + self._handshake_complete.set() def _handle_events(self): while True: @@ -163,9 +172,12 @@ class SyncQuicConnection(BaseQuicConnection): self._worker_thread = threading.Thread(target=self._worker) self._worker_thread.start() - def make_stream(self): - self._handshake_complete.wait() + def make_stream(self, timeout=None): + if not self._handshake_complete.wait(timeout): + raise dns.exception.Timeout with self._lock: + if self._done: + raise UnexpectedEOF stream_id = self._connection.get_next_available_stream_id(False) stream = SyncQuicStream(self, stream_id) self._streams[stream_id] = stream @@ -187,8 +199,8 @@ class SyncQuicConnection(BaseQuicConnection): class SyncQuicManager(BaseQuicManager): - def __init__(self, conf=None, verify_mode=ssl.CERT_REQUIRED): - super().__init__(conf, verify_mode, SyncQuicConnection) + def __init__(self, conf=None, verify_mode=ssl.CERT_REQUIRED, server_name=None): + super().__init__(conf, verify_mode, SyncQuicConnection, server_name) self._lock = threading.Lock() def connect(self, address, port=853, source=None, source_port=0): @@ -206,7 +218,7 @@ class SyncQuicManager(BaseQuicManager): return self def __exit__(self, exc_type, exc_val, exc_tb): - # Copy the itertor into a list as exiting things will mutate the connections + # Copy the iterator into a list as exiting things will mutate the connections # table. connections = list(self._connections.values()) for connection in connections: diff --git a/lib/dns/quic/_trio.py b/lib/dns/quic/_trio.py index 1e47a5a6..ee07e4f6 100644 --- a/lib/dns/quic/_trio.py +++ b/lib/dns/quic/_trio.py @@ -10,13 +10,15 @@ import aioquic.quic.connection # type: ignore import aioquic.quic.events # type: ignore import trio +import dns.exception import dns.inet from dns._asyncbackend import NullContext from dns.quic._common import ( - BaseQuicStream, + QUIC_MAX_DATAGRAM, AsyncQuicConnection, AsyncQuicManager, - QUIC_MAX_DATAGRAM, + BaseQuicStream, + UnexpectedEOF, ) @@ -44,6 +46,7 @@ class TrioQuicStream(BaseQuicStream): (size,) = struct.unpack("!H", self._buffer.get(2)) await self.wait_for(size) return self._buffer.get(size) + raise dns.exception.Timeout async def send(self, datagram, is_end=False): data = self._encapsulate(datagram) @@ -80,20 +83,26 @@ class TrioQuicConnection(AsyncQuicConnection): self._worker_scope = None async def _worker(self): - await self._socket.connect(self._peer) - while not self._done: - (expiration, interval) = self._get_timer_values(False) - with trio.CancelScope( - deadline=trio.current_time() + interval - ) as self._worker_scope: - datagram = await self._socket.recv(QUIC_MAX_DATAGRAM) - self._connection.receive_datagram(datagram, self._peer[0], time.time()) - self._worker_scope = None - self._handle_timer(expiration) - datagrams = self._connection.datagrams_to_send(time.time()) - for (datagram, _) in datagrams: - await self._socket.send(datagram) - await self._handle_events() + try: + await self._socket.connect(self._peer) + while not self._done: + (expiration, interval) = self._get_timer_values(False) + with trio.CancelScope( + deadline=trio.current_time() + interval + ) as self._worker_scope: + datagram = await self._socket.recv(QUIC_MAX_DATAGRAM) + self._connection.receive_datagram( + datagram, self._peer[0], time.time() + ) + self._worker_scope = None + self._handle_timer(expiration) + datagrams = self._connection.datagrams_to_send(time.time()) + for datagram, _ in datagrams: + await self._socket.send(datagram) + await self._handle_events() + finally: + self._done = True + self._handshake_complete.set() async def _handle_events(self): count = 0 @@ -130,12 +139,20 @@ class TrioQuicConnection(AsyncQuicConnection): nursery.start_soon(self._worker) self._run_done.set() - async def make_stream(self): - await self._handshake_complete.wait() - stream_id = self._connection.get_next_available_stream_id(False) - stream = TrioQuicStream(self, stream_id) - self._streams[stream_id] = stream - return stream + async def make_stream(self, timeout=None): + if timeout is None: + context = NullContext(None) + else: + context = trio.move_on_after(timeout) + with context: + await self._handshake_complete.wait() + if self._done: + raise UnexpectedEOF + stream_id = self._connection.get_next_available_stream_id(False) + stream = TrioQuicStream(self, stream_id) + self._streams[stream_id] = stream + return stream + raise dns.exception.Timeout async def close(self): if not self._closed: @@ -148,8 +165,10 @@ class TrioQuicConnection(AsyncQuicConnection): class TrioQuicManager(AsyncQuicManager): - def __init__(self, nursery, conf=None, verify_mode=ssl.CERT_REQUIRED): - super().__init__(conf, verify_mode, TrioQuicConnection) + def __init__( + self, nursery, conf=None, verify_mode=ssl.CERT_REQUIRED, server_name=None + ): + super().__init__(conf, verify_mode, TrioQuicConnection, server_name) self._nursery = nursery def connect(self, address, port=853, source=None, source_port=0): @@ -162,7 +181,7 @@ class TrioQuicManager(AsyncQuicManager): return self async def __aexit__(self, exc_type, exc_val, exc_tb): - # Copy the itertor into a list as exiting things will mutate the connections + # Copy the iterator into a list as exiting things will mutate the connections # table. connections = list(self._connections.values()) for connection in connections: diff --git a/lib/dns/rdata.py b/lib/dns/rdata.py index 1dd6ed90..0d262e8d 100644 --- a/lib/dns/rdata.py +++ b/lib/dns/rdata.py @@ -17,17 +17,15 @@ """DNS rdata.""" -from typing import Any, Dict, Optional, Tuple, Union - -from importlib import import_module import base64 import binascii -import io import inspect +import io import itertools import random +from importlib import import_module +from typing import Any, Dict, Optional, Tuple, Union -import dns.wire import dns.exception import dns.immutable import dns.ipv4 @@ -37,6 +35,7 @@ import dns.rdataclass import dns.rdatatype import dns.tokenizer import dns.ttl +import dns.wire _chunksize = 32 @@ -358,7 +357,6 @@ class Rdata: or self.rdclass != other.rdclass or self.rdtype != other.rdtype ): - return NotImplemented return self._cmp(other) < 0 @@ -881,16 +879,11 @@ def register_type( it applies to all classes. """ - the_rdtype = dns.rdatatype.RdataType.make(rdtype) - existing_cls = get_rdata_class(rdclass, the_rdtype) - if existing_cls != GenericRdata or dns.rdatatype.is_metatype(the_rdtype): - raise RdatatypeExists(rdclass=rdclass, rdtype=the_rdtype) - try: - if dns.rdatatype.RdataType(the_rdtype).name != rdtype_text: - raise RdatatypeExists(rdclass=rdclass, rdtype=the_rdtype) - except ValueError: - pass - _rdata_classes[(rdclass, the_rdtype)] = getattr( + rdtype = dns.rdatatype.RdataType.make(rdtype) + existing_cls = get_rdata_class(rdclass, rdtype) + if existing_cls != GenericRdata or dns.rdatatype.is_metatype(rdtype): + raise RdatatypeExists(rdclass=rdclass, rdtype=rdtype) + _rdata_classes[(rdclass, rdtype)] = getattr( implementation, rdtype_text.replace("-", "_") ) - dns.rdatatype.register_type(the_rdtype, rdtype_text, is_singleton) + dns.rdatatype.register_type(rdtype, rdtype_text, is_singleton) diff --git a/lib/dns/rdataset.py b/lib/dns/rdataset.py index c0ede425..31124afc 100644 --- a/lib/dns/rdataset.py +++ b/lib/dns/rdataset.py @@ -17,18 +17,17 @@ """DNS rdatasets (an rdataset is a set of rdatas of a given type and class)""" -from typing import Any, cast, Collection, Dict, List, Optional, Union - import io import random import struct +from typing import Any, Collection, Dict, List, Optional, Union, cast import dns.exception import dns.immutable import dns.name -import dns.rdatatype -import dns.rdataclass import dns.rdata +import dns.rdataclass +import dns.rdatatype import dns.set import dns.ttl @@ -471,9 +470,9 @@ def from_text_list( Returns a ``dns.rdataset.Rdataset`` object. """ - the_rdclass = dns.rdataclass.RdataClass.make(rdclass) - the_rdtype = dns.rdatatype.RdataType.make(rdtype) - r = Rdataset(the_rdclass, the_rdtype) + rdclass = dns.rdataclass.RdataClass.make(rdclass) + rdtype = dns.rdatatype.RdataType.make(rdtype) + r = Rdataset(rdclass, rdtype) r.update_ttl(ttl) for t in text_rdatas: rd = dns.rdata.from_text( diff --git a/lib/dns/rdtypes/ANY/AFSDB.py b/lib/dns/rdtypes/ANY/AFSDB.py index d7838e7e..3d287f6e 100644 --- a/lib/dns/rdtypes/ANY/AFSDB.py +++ b/lib/dns/rdtypes/ANY/AFSDB.py @@ -15,8 +15,8 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. -import dns.rdtypes.mxbase import dns.immutable +import dns.rdtypes.mxbase @dns.immutable.immutable diff --git a/lib/dns/rdtypes/ANY/AVC.py b/lib/dns/rdtypes/ANY/AVC.py index 11e026d0..766d5e2d 100644 --- a/lib/dns/rdtypes/ANY/AVC.py +++ b/lib/dns/rdtypes/ANY/AVC.py @@ -15,8 +15,8 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. -import dns.rdtypes.txtbase import dns.immutable +import dns.rdtypes.txtbase @dns.immutable.immutable diff --git a/lib/dns/rdtypes/ANY/CDNSKEY.py b/lib/dns/rdtypes/ANY/CDNSKEY.py index 869523fb..38b8a8da 100644 --- a/lib/dns/rdtypes/ANY/CDNSKEY.py +++ b/lib/dns/rdtypes/ANY/CDNSKEY.py @@ -15,15 +15,15 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. -import dns.rdtypes.dnskeybase # lgtm[py/import-and-import-from] import dns.immutable +import dns.rdtypes.dnskeybase # lgtm[py/import-and-import-from] # pylint: disable=unused-import -from dns.rdtypes.dnskeybase import ( - SEP, +from dns.rdtypes.dnskeybase import ( # noqa: F401 lgtm[py/unused-import] REVOKE, + SEP, ZONE, -) # noqa: F401 lgtm[py/unused-import] +) # pylint: enable=unused-import diff --git a/lib/dns/rdtypes/ANY/CDS.py b/lib/dns/rdtypes/ANY/CDS.py index 094de12b..2ff42d9a 100644 --- a/lib/dns/rdtypes/ANY/CDS.py +++ b/lib/dns/rdtypes/ANY/CDS.py @@ -15,8 +15,8 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. -import dns.rdtypes.dsbase import dns.immutable +import dns.rdtypes.dsbase @dns.immutable.immutable diff --git a/lib/dns/rdtypes/ANY/CERT.py b/lib/dns/rdtypes/ANY/CERT.py index 1b0cbeca..30fe863f 100644 --- a/lib/dns/rdtypes/ANY/CERT.py +++ b/lib/dns/rdtypes/ANY/CERT.py @@ -15,12 +15,12 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. -import struct import base64 +import struct +import dns.dnssectypes import dns.exception import dns.immutable -import dns.dnssectypes import dns.rdata import dns.tokenizer diff --git a/lib/dns/rdtypes/ANY/CNAME.py b/lib/dns/rdtypes/ANY/CNAME.py index a4fcfa88..759adb90 100644 --- a/lib/dns/rdtypes/ANY/CNAME.py +++ b/lib/dns/rdtypes/ANY/CNAME.py @@ -15,8 +15,8 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. -import dns.rdtypes.nsbase import dns.immutable +import dns.rdtypes.nsbase @dns.immutable.immutable diff --git a/lib/dns/rdtypes/ANY/CSYNC.py b/lib/dns/rdtypes/ANY/CSYNC.py index f819c08c..315da9ff 100644 --- a/lib/dns/rdtypes/ANY/CSYNC.py +++ b/lib/dns/rdtypes/ANY/CSYNC.py @@ -19,9 +19,9 @@ import struct import dns.exception import dns.immutable +import dns.name import dns.rdata import dns.rdatatype -import dns.name import dns.rdtypes.util diff --git a/lib/dns/rdtypes/ANY/DLV.py b/lib/dns/rdtypes/ANY/DLV.py index 947dc42e..632e90f8 100644 --- a/lib/dns/rdtypes/ANY/DLV.py +++ b/lib/dns/rdtypes/ANY/DLV.py @@ -15,8 +15,8 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. -import dns.rdtypes.dsbase import dns.immutable +import dns.rdtypes.dsbase @dns.immutable.immutable diff --git a/lib/dns/rdtypes/ANY/DNAME.py b/lib/dns/rdtypes/ANY/DNAME.py index f4984b55..556bff59 100644 --- a/lib/dns/rdtypes/ANY/DNAME.py +++ b/lib/dns/rdtypes/ANY/DNAME.py @@ -15,8 +15,8 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. -import dns.rdtypes.nsbase import dns.immutable +import dns.rdtypes.nsbase @dns.immutable.immutable diff --git a/lib/dns/rdtypes/ANY/DNSKEY.py b/lib/dns/rdtypes/ANY/DNSKEY.py index 50fa05b7..f1a63062 100644 --- a/lib/dns/rdtypes/ANY/DNSKEY.py +++ b/lib/dns/rdtypes/ANY/DNSKEY.py @@ -15,15 +15,15 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. -import dns.rdtypes.dnskeybase # lgtm[py/import-and-import-from] import dns.immutable +import dns.rdtypes.dnskeybase # lgtm[py/import-and-import-from] # pylint: disable=unused-import -from dns.rdtypes.dnskeybase import ( - SEP, +from dns.rdtypes.dnskeybase import ( # noqa: F401 lgtm[py/unused-import] REVOKE, + SEP, ZONE, -) # noqa: F401 lgtm[py/unused-import] +) # pylint: enable=unused-import diff --git a/lib/dns/rdtypes/ANY/DS.py b/lib/dns/rdtypes/ANY/DS.py index 3f6c3ee8..097ecfa0 100644 --- a/lib/dns/rdtypes/ANY/DS.py +++ b/lib/dns/rdtypes/ANY/DS.py @@ -15,8 +15,8 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. -import dns.rdtypes.dsbase import dns.immutable +import dns.rdtypes.dsbase @dns.immutable.immutable diff --git a/lib/dns/rdtypes/ANY/EUI48.py b/lib/dns/rdtypes/ANY/EUI48.py index 0ab88ad0..7e4e1ff3 100644 --- a/lib/dns/rdtypes/ANY/EUI48.py +++ b/lib/dns/rdtypes/ANY/EUI48.py @@ -16,8 +16,8 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. -import dns.rdtypes.euibase import dns.immutable +import dns.rdtypes.euibase @dns.immutable.immutable diff --git a/lib/dns/rdtypes/ANY/EUI64.py b/lib/dns/rdtypes/ANY/EUI64.py index c42957ef..68b5820f 100644 --- a/lib/dns/rdtypes/ANY/EUI64.py +++ b/lib/dns/rdtypes/ANY/EUI64.py @@ -16,8 +16,8 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. -import dns.rdtypes.euibase import dns.immutable +import dns.rdtypes.euibase @dns.immutable.immutable diff --git a/lib/dns/rdtypes/ANY/HIP.py b/lib/dns/rdtypes/ANY/HIP.py index 01fec822..a20aa1e5 100644 --- a/lib/dns/rdtypes/ANY/HIP.py +++ b/lib/dns/rdtypes/ANY/HIP.py @@ -15,9 +15,9 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. -import struct import base64 import binascii +import struct import dns.exception import dns.immutable diff --git a/lib/dns/rdtypes/ANY/LOC.py b/lib/dns/rdtypes/ANY/LOC.py index 52c97532..783d54af 100644 --- a/lib/dns/rdtypes/ANY/LOC.py +++ b/lib/dns/rdtypes/ANY/LOC.py @@ -21,7 +21,6 @@ import dns.exception import dns.immutable import dns.rdata - _pows = tuple(10**i for i in range(0, 11)) # default values are in centimeters @@ -40,7 +39,7 @@ def _exponent_of(what, desc): if what == 0: return 0 exp = None - for (i, pow) in enumerate(_pows): + for i, pow in enumerate(_pows): if what < pow: exp = i - 1 break diff --git a/lib/dns/rdtypes/ANY/MX.py b/lib/dns/rdtypes/ANY/MX.py index a697ea45..1f9df21f 100644 --- a/lib/dns/rdtypes/ANY/MX.py +++ b/lib/dns/rdtypes/ANY/MX.py @@ -15,8 +15,8 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. -import dns.rdtypes.mxbase import dns.immutable +import dns.rdtypes.mxbase @dns.immutable.immutable diff --git a/lib/dns/rdtypes/ANY/NINFO.py b/lib/dns/rdtypes/ANY/NINFO.py index d53e9676..55bc5614 100644 --- a/lib/dns/rdtypes/ANY/NINFO.py +++ b/lib/dns/rdtypes/ANY/NINFO.py @@ -15,8 +15,8 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. -import dns.rdtypes.txtbase import dns.immutable +import dns.rdtypes.txtbase @dns.immutable.immutable diff --git a/lib/dns/rdtypes/ANY/NS.py b/lib/dns/rdtypes/ANY/NS.py index a0cc232a..fe453f0d 100644 --- a/lib/dns/rdtypes/ANY/NS.py +++ b/lib/dns/rdtypes/ANY/NS.py @@ -15,8 +15,8 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. -import dns.rdtypes.nsbase import dns.immutable +import dns.rdtypes.nsbase @dns.immutable.immutable diff --git a/lib/dns/rdtypes/ANY/NSEC.py b/lib/dns/rdtypes/ANY/NSEC.py index 7af7b77f..a2d98fa7 100644 --- a/lib/dns/rdtypes/ANY/NSEC.py +++ b/lib/dns/rdtypes/ANY/NSEC.py @@ -17,9 +17,9 @@ import dns.exception import dns.immutable +import dns.name import dns.rdata import dns.rdatatype -import dns.name import dns.rdtypes.util diff --git a/lib/dns/rdtypes/ANY/NSEC3.py b/lib/dns/rdtypes/ANY/NSEC3.py index 6eae16e0..d32fe169 100644 --- a/lib/dns/rdtypes/ANY/NSEC3.py +++ b/lib/dns/rdtypes/ANY/NSEC3.py @@ -25,7 +25,6 @@ import dns.rdata import dns.rdatatype import dns.rdtypes.util - b32_hex_to_normal = bytes.maketrans( b"0123456789ABCDEFGHIJKLMNOPQRSTUV", b"ABCDEFGHIJKLMNOPQRSTUVWXYZ234567" ) @@ -67,6 +66,7 @@ class NSEC3(dns.rdata.Rdata): def to_text(self, origin=None, relativize=True, **kw): next = base64.b32encode(self.next).translate(b32_normal_to_hex).lower().decode() + next = next.rstrip("=") if self.salt == b"": salt = "-" else: @@ -94,6 +94,10 @@ class NSEC3(dns.rdata.Rdata): else: salt = binascii.unhexlify(salt.encode("ascii")) next = tok.get_string().encode("ascii").upper().translate(b32_hex_to_normal) + if next.endswith(b"="): + raise binascii.Error("Incorrect padding") + if len(next) % 8 != 0: + next += b"=" * (8 - len(next) % 8) next = base64.b32decode(next) bitmap = Bitmap.from_text(tok) return cls(rdclass, rdtype, algorithm, flags, iterations, salt, next, bitmap) diff --git a/lib/dns/rdtypes/ANY/NSEC3PARAM.py b/lib/dns/rdtypes/ANY/NSEC3PARAM.py index 1b7269a0..1a0c0e08 100644 --- a/lib/dns/rdtypes/ANY/NSEC3PARAM.py +++ b/lib/dns/rdtypes/ANY/NSEC3PARAM.py @@ -15,8 +15,8 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. -import struct import binascii +import struct import dns.exception import dns.immutable diff --git a/lib/dns/rdtypes/ANY/OPT.py b/lib/dns/rdtypes/ANY/OPT.py index 36d4c7c6..d70e5373 100644 --- a/lib/dns/rdtypes/ANY/OPT.py +++ b/lib/dns/rdtypes/ANY/OPT.py @@ -18,11 +18,10 @@ import struct import dns.edns -import dns.immutable import dns.exception +import dns.immutable import dns.rdata - # We don't implement from_text, and that's ok. # pylint: disable=abstract-method diff --git a/lib/dns/rdtypes/ANY/PTR.py b/lib/dns/rdtypes/ANY/PTR.py index 265bed03..7fd5547d 100644 --- a/lib/dns/rdtypes/ANY/PTR.py +++ b/lib/dns/rdtypes/ANY/PTR.py @@ -15,8 +15,8 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. -import dns.rdtypes.nsbase import dns.immutable +import dns.rdtypes.nsbase @dns.immutable.immutable diff --git a/lib/dns/rdtypes/ANY/RP.py b/lib/dns/rdtypes/ANY/RP.py index c0c316b5..9c64c6e2 100644 --- a/lib/dns/rdtypes/ANY/RP.py +++ b/lib/dns/rdtypes/ANY/RP.py @@ -17,8 +17,8 @@ import dns.exception import dns.immutable -import dns.rdata import dns.name +import dns.rdata @dns.immutable.immutable diff --git a/lib/dns/rdtypes/ANY/RRSIG.py b/lib/dns/rdtypes/ANY/RRSIG.py index 3d5ad0f3..11605026 100644 --- a/lib/dns/rdtypes/ANY/RRSIG.py +++ b/lib/dns/rdtypes/ANY/RRSIG.py @@ -21,8 +21,8 @@ import struct import time import dns.dnssectypes -import dns.immutable import dns.exception +import dns.immutable import dns.rdata import dns.rdatatype diff --git a/lib/dns/rdtypes/ANY/RT.py b/lib/dns/rdtypes/ANY/RT.py index 8d9c6bd0..950f2a06 100644 --- a/lib/dns/rdtypes/ANY/RT.py +++ b/lib/dns/rdtypes/ANY/RT.py @@ -15,8 +15,8 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. -import dns.rdtypes.mxbase import dns.immutable +import dns.rdtypes.mxbase @dns.immutable.immutable diff --git a/lib/dns/rdtypes/ANY/SOA.py b/lib/dns/rdtypes/ANY/SOA.py index 6f6fe58b..bde55e15 100644 --- a/lib/dns/rdtypes/ANY/SOA.py +++ b/lib/dns/rdtypes/ANY/SOA.py @@ -19,8 +19,8 @@ import struct import dns.exception import dns.immutable -import dns.rdata import dns.name +import dns.rdata @dns.immutable.immutable diff --git a/lib/dns/rdtypes/ANY/SPF.py b/lib/dns/rdtypes/ANY/SPF.py index 1190e0de..c403589a 100644 --- a/lib/dns/rdtypes/ANY/SPF.py +++ b/lib/dns/rdtypes/ANY/SPF.py @@ -15,8 +15,8 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. -import dns.rdtypes.txtbase import dns.immutable +import dns.rdtypes.txtbase @dns.immutable.immutable diff --git a/lib/dns/rdtypes/ANY/SSHFP.py b/lib/dns/rdtypes/ANY/SSHFP.py index 58ffcbbc..67805452 100644 --- a/lib/dns/rdtypes/ANY/SSHFP.py +++ b/lib/dns/rdtypes/ANY/SSHFP.py @@ -15,11 +15,11 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. -import struct import binascii +import struct -import dns.rdata import dns.immutable +import dns.rdata import dns.rdatatype diff --git a/lib/dns/rdtypes/ANY/TKEY.py b/lib/dns/rdtypes/ANY/TKEY.py index 070f03af..d5f5fc45 100644 --- a/lib/dns/rdtypes/ANY/TKEY.py +++ b/lib/dns/rdtypes/ANY/TKEY.py @@ -18,8 +18,8 @@ import base64 import struct -import dns.immutable import dns.exception +import dns.immutable import dns.rdata diff --git a/lib/dns/rdtypes/ANY/TXT.py b/lib/dns/rdtypes/ANY/TXT.py index cc4b6611..f4e61930 100644 --- a/lib/dns/rdtypes/ANY/TXT.py +++ b/lib/dns/rdtypes/ANY/TXT.py @@ -15,8 +15,8 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. -import dns.rdtypes.txtbase import dns.immutable +import dns.rdtypes.txtbase @dns.immutable.immutable diff --git a/lib/dns/rdtypes/ANY/URI.py b/lib/dns/rdtypes/ANY/URI.py index b4c95a3b..7463e277 100644 --- a/lib/dns/rdtypes/ANY/URI.py +++ b/lib/dns/rdtypes/ANY/URI.py @@ -20,9 +20,9 @@ import struct import dns.exception import dns.immutable +import dns.name import dns.rdata import dns.rdtypes.util -import dns.name @dns.immutable.immutable diff --git a/lib/dns/rdtypes/ANY/ZONEMD.py b/lib/dns/rdtypes/ANY/ZONEMD.py index 1f86ba49..3062843b 100644 --- a/lib/dns/rdtypes/ANY/ZONEMD.py +++ b/lib/dns/rdtypes/ANY/ZONEMD.py @@ -1,7 +1,7 @@ # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license -import struct import binascii +import struct import dns.immutable import dns.rdata diff --git a/lib/dns/rdtypes/CH/A.py b/lib/dns/rdtypes/CH/A.py index 9905c7c9..e457f38a 100644 --- a/lib/dns/rdtypes/CH/A.py +++ b/lib/dns/rdtypes/CH/A.py @@ -17,8 +17,8 @@ import struct -import dns.rdtypes.mxbase import dns.immutable +import dns.rdtypes.mxbase @dns.immutable.immutable diff --git a/lib/dns/rdtypes/IN/APL.py b/lib/dns/rdtypes/IN/APL.py index 05e1689f..f1bb01db 100644 --- a/lib/dns/rdtypes/IN/APL.py +++ b/lib/dns/rdtypes/IN/APL.py @@ -124,7 +124,6 @@ class APL(dns.rdata.Rdata): @classmethod def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): - items = [] while parser.remaining() > 0: header = parser.get_struct("!HBB") diff --git a/lib/dns/rdtypes/IN/HTTPS.py b/lib/dns/rdtypes/IN/HTTPS.py index 7797fbaf..15464cbd 100644 --- a/lib/dns/rdtypes/IN/HTTPS.py +++ b/lib/dns/rdtypes/IN/HTTPS.py @@ -1,7 +1,7 @@ # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license -import dns.rdtypes.svcbbase import dns.immutable +import dns.rdtypes.svcbbase @dns.immutable.immutable diff --git a/lib/dns/rdtypes/IN/IPSECKEY.py b/lib/dns/rdtypes/IN/IPSECKEY.py index 1255739f..8bb2bcb6 100644 --- a/lib/dns/rdtypes/IN/IPSECKEY.py +++ b/lib/dns/rdtypes/IN/IPSECKEY.py @@ -15,8 +15,8 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. -import struct import base64 +import struct import dns.exception import dns.immutable diff --git a/lib/dns/rdtypes/IN/KX.py b/lib/dns/rdtypes/IN/KX.py index c27e9215..a03d1d51 100644 --- a/lib/dns/rdtypes/IN/KX.py +++ b/lib/dns/rdtypes/IN/KX.py @@ -15,8 +15,8 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. -import dns.rdtypes.mxbase import dns.immutable +import dns.rdtypes.mxbase @dns.immutable.immutable diff --git a/lib/dns/rdtypes/IN/NSAP_PTR.py b/lib/dns/rdtypes/IN/NSAP_PTR.py index 57dadd47..0a18fdce 100644 --- a/lib/dns/rdtypes/IN/NSAP_PTR.py +++ b/lib/dns/rdtypes/IN/NSAP_PTR.py @@ -15,8 +15,8 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. -import dns.rdtypes.nsbase import dns.immutable +import dns.rdtypes.nsbase @dns.immutable.immutable diff --git a/lib/dns/rdtypes/IN/PX.py b/lib/dns/rdtypes/IN/PX.py index b2216d6b..5c0aa81e 100644 --- a/lib/dns/rdtypes/IN/PX.py +++ b/lib/dns/rdtypes/IN/PX.py @@ -19,9 +19,9 @@ import struct import dns.exception import dns.immutable +import dns.name import dns.rdata import dns.rdtypes.util -import dns.name @dns.immutable.immutable diff --git a/lib/dns/rdtypes/IN/SRV.py b/lib/dns/rdtypes/IN/SRV.py index 8b0b6bf7..84c54007 100644 --- a/lib/dns/rdtypes/IN/SRV.py +++ b/lib/dns/rdtypes/IN/SRV.py @@ -19,9 +19,9 @@ import struct import dns.exception import dns.immutable +import dns.name import dns.rdata import dns.rdtypes.util -import dns.name @dns.immutable.immutable diff --git a/lib/dns/rdtypes/IN/SVCB.py b/lib/dns/rdtypes/IN/SVCB.py index 9a1ad101..ff3e9327 100644 --- a/lib/dns/rdtypes/IN/SVCB.py +++ b/lib/dns/rdtypes/IN/SVCB.py @@ -1,7 +1,7 @@ # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license -import dns.rdtypes.svcbbase import dns.immutable +import dns.rdtypes.svcbbase @dns.immutable.immutable diff --git a/lib/dns/rdtypes/IN/WKS.py b/lib/dns/rdtypes/IN/WKS.py index a671e203..26d287a3 100644 --- a/lib/dns/rdtypes/IN/WKS.py +++ b/lib/dns/rdtypes/IN/WKS.py @@ -18,8 +18,8 @@ import socket import struct -import dns.ipv4 import dns.immutable +import dns.ipv4 import dns.rdata try: diff --git a/lib/dns/rdtypes/dnskeybase.py b/lib/dns/rdtypes/dnskeybase.py index 1d17f70f..3bfcf860 100644 --- a/lib/dns/rdtypes/dnskeybase.py +++ b/lib/dns/rdtypes/dnskeybase.py @@ -19,9 +19,9 @@ import base64 import enum import struct +import dns.dnssectypes import dns.exception import dns.immutable -import dns.dnssectypes import dns.rdata # wildcard import @@ -43,7 +43,7 @@ class DNSKEYBase(dns.rdata.Rdata): def __init__(self, rdclass, rdtype, flags, protocol, algorithm, key): super().__init__(rdclass, rdtype) - self.flags = self._as_uint16(flags) + self.flags = Flag(self._as_uint16(flags)) self.protocol = self._as_uint8(protocol) self.algorithm = dns.dnssectypes.Algorithm.make(algorithm) self.key = self._as_bytes(key) diff --git a/lib/dns/rdtypes/dsbase.py b/lib/dns/rdtypes/dsbase.py index b6032b0f..1ad0b7a5 100644 --- a/lib/dns/rdtypes/dsbase.py +++ b/lib/dns/rdtypes/dsbase.py @@ -15,8 +15,8 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. -import struct import binascii +import struct import dns.dnssectypes import dns.immutable @@ -44,7 +44,7 @@ class DSBase(dns.rdata.Rdata): super().__init__(rdclass, rdtype) self.key_tag = self._as_uint16(key_tag) self.algorithm = dns.dnssectypes.Algorithm.make(algorithm) - self.digest_type = self._as_uint8(digest_type) + self.digest_type = dns.dnssectypes.DSDigest.make(self._as_uint8(digest_type)) self.digest = self._as_bytes(digest) try: if len(self.digest) != self._digest_length_by_type[self.digest_type]: diff --git a/lib/dns/rdtypes/euibase.py b/lib/dns/rdtypes/euibase.py index e524aea9..4c4068b2 100644 --- a/lib/dns/rdtypes/euibase.py +++ b/lib/dns/rdtypes/euibase.py @@ -16,8 +16,8 @@ import binascii -import dns.rdata import dns.immutable +import dns.rdata @dns.immutable.immutable diff --git a/lib/dns/rdtypes/mxbase.py b/lib/dns/rdtypes/mxbase.py index b4b9b088..a6bae078 100644 --- a/lib/dns/rdtypes/mxbase.py +++ b/lib/dns/rdtypes/mxbase.py @@ -21,8 +21,8 @@ import struct import dns.exception import dns.immutable -import dns.rdata import dns.name +import dns.rdata import dns.rdtypes.util diff --git a/lib/dns/rdtypes/nsbase.py b/lib/dns/rdtypes/nsbase.py index ba7a2ab7..56d94235 100644 --- a/lib/dns/rdtypes/nsbase.py +++ b/lib/dns/rdtypes/nsbase.py @@ -19,8 +19,8 @@ import dns.exception import dns.immutable -import dns.rdata import dns.name +import dns.rdata @dns.immutable.immutable diff --git a/lib/dns/rdtypes/svcbbase.py b/lib/dns/rdtypes/svcbbase.py index 8d6fb1c6..ba5b53d2 100644 --- a/lib/dns/rdtypes/svcbbase.py +++ b/lib/dns/rdtypes/svcbbase.py @@ -34,6 +34,7 @@ class ParamKey(dns.enum.IntEnum): IPV4HINT = 4 ECH = 5 IPV6HINT = 6 + DOHPATH = 7 @classmethod def _maximum(cls): diff --git a/lib/dns/rdtypes/tlsabase.py b/lib/dns/rdtypes/tlsabase.py index a3fdc354..4cdb7ab3 100644 --- a/lib/dns/rdtypes/tlsabase.py +++ b/lib/dns/rdtypes/tlsabase.py @@ -15,11 +15,11 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. -import struct import binascii +import struct -import dns.rdata import dns.immutable +import dns.rdata import dns.rdatatype diff --git a/lib/dns/rdtypes/txtbase.py b/lib/dns/rdtypes/txtbase.py index d4cb9bb2..fdbfb646 100644 --- a/lib/dns/rdtypes/txtbase.py +++ b/lib/dns/rdtypes/txtbase.py @@ -17,9 +17,8 @@ """TXT-like base class.""" -from typing import Any, Dict, Iterable, Optional, Tuple, Union - import struct +from typing import Any, Dict, Iterable, Optional, Tuple, Union import dns.exception import dns.immutable diff --git a/lib/dns/rdtypes/util.py b/lib/dns/rdtypes/util.py index 74596f05..54908fdc 100644 --- a/lib/dns/rdtypes/util.py +++ b/lib/dns/rdtypes/util.py @@ -18,6 +18,7 @@ import collections import random import struct +from typing import Any, List import dns.exception import dns.ipv4 @@ -119,7 +120,7 @@ class Bitmap: def __init__(self, windows=None): last_window = -1 self.windows = windows - for (window, bitmap) in self.windows: + for window, bitmap in self.windows: if not isinstance(window, int): raise ValueError(f"bad {self.type_name} window type") if window <= last_window: @@ -132,11 +133,11 @@ class Bitmap: if len(bitmap) == 0 or len(bitmap) > 32: raise ValueError(f"bad {self.type_name} octets") - def to_text(self): + def to_text(self) -> str: text = "" - for (window, bitmap) in self.windows: + for window, bitmap in self.windows: bits = [] - for (i, byte) in enumerate(bitmap): + for i, byte in enumerate(bitmap): for j in range(0, 8): if byte & (0x80 >> j): rdtype = window * 256 + i * 8 + j @@ -145,14 +146,18 @@ class Bitmap: return text @classmethod - def from_text(cls, tok): + def from_text(cls, tok: "dns.tokenizer.Tokenizer") -> "Bitmap": rdtypes = [] for token in tok.get_remaining(): rdtype = dns.rdatatype.from_text(token.unescape().value) if rdtype == 0: raise dns.exception.SyntaxError(f"{cls.type_name} with bit 0") rdtypes.append(rdtype) - rdtypes.sort() + return cls.from_rdtypes(rdtypes) + + @classmethod + def from_rdtypes(cls, rdtypes: List[dns.rdatatype.RdataType]) -> "Bitmap": + rdtypes = sorted(rdtypes) window = 0 octets = 0 prior_rdtype = 0 @@ -177,13 +182,13 @@ class Bitmap: windows.append((window, bytes(bitmap[0:octets]))) return cls(windows) - def to_wire(self, file): - for (window, bitmap) in self.windows: + def to_wire(self, file: Any) -> None: + for window, bitmap in self.windows: file.write(struct.pack("!BB", window, len(bitmap))) file.write(bitmap) @classmethod - def from_wire_parser(cls, parser): + def from_wire_parser(cls, parser: "dns.wire.Parser") -> "Bitmap": windows = [] while parser.remaining() > 0: window = parser.get_uint8() @@ -226,7 +231,7 @@ def weighted_processing_order(iterable): total = sum(rdata._processing_weight() or _no_weight for rdata in rdatas) while len(rdatas) > 1: r = random.uniform(0, total) - for (n, rdata) in enumerate(rdatas): + for n, rdata in enumerate(rdatas): weight = rdata._processing_weight() or _no_weight if weight > r: break diff --git a/lib/dns/renderer.py b/lib/dns/renderer.py index 3c495f61..53e7c0f6 100644 --- a/lib/dns/renderer.py +++ b/lib/dns/renderer.py @@ -19,14 +19,13 @@ import contextlib import io -import struct import random +import struct import time import dns.exception import dns.tsig - QUESTION = 0 ANSWER = 1 AUTHORITY = 2 diff --git a/lib/dns/resolver.py b/lib/dns/resolver.py index a5b66c1d..f08f824d 100644 --- a/lib/dns/resolver.py +++ b/lib/dns/resolver.py @@ -17,29 +17,31 @@ """DNS stub resolver.""" -from typing import Any, Dict, List, Optional, Tuple, Union - -from urllib.parse import urlparse import contextlib +import random import socket import sys import threading import time -import random import warnings +from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple, Union +from urllib.parse import urlparse -import dns.exception +import dns._ddr import dns.edns +import dns.exception import dns.flags import dns.inet import dns.ipv4 import dns.ipv6 import dns.message import dns.name +import dns.nameserver import dns.query import dns.rcode import dns.rdataclass import dns.rdatatype +import dns.rdtypes.svcbbase import dns.reversename import dns.tsig @@ -72,7 +74,7 @@ class NXDOMAIN(dns.exception.DNSException): kwargs = dict(qnames=qnames, responses=responses) return kwargs - def __str__(self): + def __str__(self) -> str: if "qnames" not in self.kwargs: return super().__str__() qnames = self.kwargs["qnames"] @@ -140,7 +142,11 @@ class YXDOMAIN(dns.exception.DNSException): ErrorTuple = Tuple[ - Optional[str], bool, int, Union[Exception, str], Optional[dns.message.Message] + Optional[str], + bool, + int, + Union[Exception, str], + Optional[dns.message.Message], ] @@ -148,11 +154,7 @@ def _errors_to_text(errors: List[ErrorTuple]) -> List[str]: """Turn a resolution errors trace into a list of text.""" texts = [] for err in errors: - texts.append( - "Server {} {} port {} answered {}".format( - err[0], "TCP" if err[1] else "UDP", err[2], err[3] - ) - ) + texts.append("Server {} answered {}".format(err[0], err[3])) return texts @@ -184,7 +186,7 @@ Timeout = LifetimeTimeout class NoAnswer(dns.exception.DNSException): """The DNS response does not contain an answer to the question.""" - fmt = "The DNS response does not contain an answer " + "to the question: {query}" + fmt = "The DNS response does not contain an answer to the question: {query}" supp_kwargs = {"response"} # We do this as otherwise mypy complains about unexpected keyword argument @@ -264,7 +266,7 @@ class Answer: response: dns.message.QueryMessage, nameserver: Optional[str] = None, port: Optional[int] = None, - ): + ) -> None: self.qname = qname self.rdtype = rdtype self.rdclass = rdclass @@ -292,7 +294,7 @@ class Answer: else: raise AttributeError(attr) - def __len__(self): + def __len__(self) -> int: return self.rrset and len(self.rrset) or 0 def __iter__(self): @@ -309,14 +311,67 @@ class Answer: del self.rrset[i] +class Answers(dict): + """A dict of DNS stub resolver answers, indexed by type.""" + + +class HostAnswers(Answers): + """A dict of DNS stub resolver answers to a host name lookup, indexed by + type. + """ + + @classmethod + def make( + cls, + v6: Optional[Answer] = None, + v4: Optional[Answer] = None, + add_empty: bool = True, + ) -> "HostAnswers": + answers = HostAnswers() + if v6 is not None and (add_empty or v6.rrset): + answers[dns.rdatatype.AAAA] = v6 + if v4 is not None and (add_empty or v4.rrset): + answers[dns.rdatatype.A] = v4 + return answers + + # Returns pairs of (address, family) from this result, potentiallys + # filtering by address family. + def addresses_and_families( + self, family: int = socket.AF_UNSPEC + ) -> Iterator[Tuple[str, int]]: + if family == socket.AF_UNSPEC: + yield from self.addresses_and_families(socket.AF_INET6) + yield from self.addresses_and_families(socket.AF_INET) + return + elif family == socket.AF_INET6: + answer = self.get(dns.rdatatype.AAAA) + elif family == socket.AF_INET: + answer = self.get(dns.rdatatype.A) + else: + raise NotImplementedError(f"unknown address family {family}") + if answer: + for rdata in answer: + yield (rdata.address, family) + + # Returns addresses from this result, potentially filtering by + # address family. + def addresses(self, family: int = socket.AF_UNSPEC) -> Iterator[str]: + return (pair[0] for pair in self.addresses_and_families(family)) + + # Returns the canonical name from this result. + def canonical_name(self) -> dns.name.Name: + answer = self.get(dns.rdatatype.AAAA, self.get(dns.rdatatype.A)) + return answer.canonical_name + + class CacheStatistics: """Cache Statistics""" - def __init__(self, hits=0, misses=0): + def __init__(self, hits: int = 0, misses: int = 0) -> None: self.hits = hits self.misses = misses - def reset(self): + def reset(self) -> None: self.hits = 0 self.misses = 0 @@ -325,7 +380,7 @@ class CacheStatistics: class CacheBase: - def __init__(self): + def __init__(self) -> None: self.lock = threading.Lock() self.statistics = CacheStatistics() @@ -361,7 +416,7 @@ CacheKey = Tuple[dns.name.Name, dns.rdatatype.RdataType, dns.rdataclass.RdataCla class Cache(CacheBase): """Simple thread-safe DNS answer cache.""" - def __init__(self, cleaning_interval: float = 300.0): + def __init__(self, cleaning_interval: float = 300.0) -> None: """*cleaning_interval*, a ``float`` is the number of seconds between periodic cleanings. """ @@ -377,7 +432,7 @@ class Cache(CacheBase): now = time.time() if self.next_cleaning <= now: keys_to_delete = [] - for (k, v) in self.data.items(): + for k, v in self.data.items(): if v.expiration <= now: keys_to_delete.append(k) for k in keys_to_delete: @@ -447,13 +502,13 @@ class LRUCacheNode: self.prev = self self.next = self - def link_after(self, node): + def link_after(self, node: "LRUCacheNode") -> None: self.prev = node self.next = node.next node.next.prev = self node.next = self - def unlink(self): + def unlink(self) -> None: self.next.prev = self.prev self.prev.next = self.next @@ -468,7 +523,7 @@ class LRUCache(CacheBase): for a new one. """ - def __init__(self, max_size: int = 100000): + def __init__(self, max_size: int = 100000) -> None: """*max_size*, an ``int``, is the maximum number of nodes to cache; it must be greater than 0. """ @@ -590,30 +645,29 @@ class _Resolution: tcp: bool, raise_on_no_answer: bool, search: Optional[bool], - ): + ) -> None: if isinstance(qname, str): qname = dns.name.from_text(qname, None) - the_rdtype = dns.rdatatype.RdataType.make(rdtype) - if dns.rdatatype.is_metatype(the_rdtype): + rdtype = dns.rdatatype.RdataType.make(rdtype) + if dns.rdatatype.is_metatype(rdtype): raise NoMetaqueries - the_rdclass = dns.rdataclass.RdataClass.make(rdclass) - if dns.rdataclass.is_metaclass(the_rdclass): + rdclass = dns.rdataclass.RdataClass.make(rdclass) + if dns.rdataclass.is_metaclass(rdclass): raise NoMetaqueries self.resolver = resolver self.qnames_to_try = resolver._get_qnames_to_try(qname, search) self.qnames = self.qnames_to_try[:] - self.rdtype = the_rdtype - self.rdclass = the_rdclass + self.rdtype = rdtype + self.rdclass = rdclass self.tcp = tcp self.raise_on_no_answer = raise_on_no_answer self.nxdomain_responses: Dict[dns.name.Name, dns.message.QueryMessage] = {} # Initialize other things to help analysis tools self.qname = dns.name.empty - self.nameservers: List[str] = [] - self.current_nameservers: List[str] = [] + self.nameservers: List[dns.nameserver.Nameserver] = [] + self.current_nameservers: List[dns.nameserver.Nameserver] = [] self.errors: List[ErrorTuple] = [] - self.nameserver: Optional[str] = None - self.port = 0 + self.nameserver: Optional[dns.nameserver.Nameserver] = None self.tcp_attempt = False self.retry_with_tcp = False self.request: Optional[dns.message.QueryMessage] = None @@ -670,7 +724,11 @@ class _Resolution: if self.resolver.flags is not None: request.flags = self.resolver.flags - self.nameservers = self.resolver.nameservers[:] + self.nameservers = self.resolver._enrich_nameservers( + self.resolver._nameservers, + self.resolver.nameserver_ports, + self.resolver.port, + ) if self.resolver.rotate: random.shuffle(self.nameservers) self.current_nameservers = self.nameservers[:] @@ -690,12 +748,13 @@ class _Resolution: # raise NXDOMAIN(qnames=self.qnames_to_try, responses=self.nxdomain_responses) - def next_nameserver(self) -> Tuple[str, int, bool, float]: + def next_nameserver(self) -> Tuple[dns.nameserver.Nameserver, bool, float]: if self.retry_with_tcp: assert self.nameserver is not None + assert not self.nameserver.is_always_max_size() self.tcp_attempt = True self.retry_with_tcp = False - return (self.nameserver, self.port, True, 0) + return (self.nameserver, True, 0) backoff = 0.0 if not self.current_nameservers: @@ -707,11 +766,8 @@ class _Resolution: self.backoff = min(self.backoff * 2, 2) self.nameserver = self.current_nameservers.pop(0) - self.port = self.resolver.nameserver_ports.get( - self.nameserver, self.resolver.port - ) - self.tcp_attempt = self.tcp - return (self.nameserver, self.port, self.tcp_attempt, backoff) + self.tcp_attempt = self.tcp or self.nameserver.is_always_max_size() + return (self.nameserver, self.tcp_attempt, backoff) def query_result( self, response: Optional[dns.message.Message], ex: Optional[Exception] @@ -724,7 +780,13 @@ class _Resolution: # Exception during I/O or from_wire() assert response is None self.errors.append( - (self.nameserver, self.tcp_attempt, self.port, ex, response) + ( + str(self.nameserver), + self.tcp_attempt, + self.nameserver.answer_port(), + ex, + response, + ) ) if ( isinstance(ex, dns.exception.FormError) @@ -752,12 +814,18 @@ class _Resolution: self.rdtype, self.rdclass, response, - self.nameserver, - self.port, + self.nameserver.answer_nameserver(), + self.nameserver.answer_port(), ) except Exception as e: self.errors.append( - (self.nameserver, self.tcp_attempt, self.port, e, response) + ( + str(self.nameserver), + self.tcp_attempt, + self.nameserver.answer_port(), + e, + response, + ) ) # The nameserver is no good, take it out of the mix. self.nameservers.remove(self.nameserver) @@ -776,7 +844,13 @@ class _Resolution: ) except Exception as e: self.errors.append( - (self.nameserver, self.tcp_attempt, self.port, e, response) + ( + str(self.nameserver), + self.tcp_attempt, + self.nameserver.answer_port(), + e, + response, + ) ) # The nameserver is no good, take it out of the mix. self.nameservers.remove(self.nameserver) @@ -792,7 +866,13 @@ class _Resolution: elif rcode == dns.rcode.YXDOMAIN: yex = YXDOMAIN() self.errors.append( - (self.nameserver, self.tcp_attempt, self.port, yex, response) + ( + str(self.nameserver), + self.tcp_attempt, + self.nameserver.answer_port(), + yex, + response, + ) ) raise yex else: @@ -804,9 +884,9 @@ class _Resolution: self.nameservers.remove(self.nameserver) self.errors.append( ( - self.nameserver, + str(self.nameserver), self.tcp_attempt, - self.port, + self.nameserver.answer_port(), dns.rcode.to_text(rcode), response, ) @@ -840,8 +920,11 @@ class BaseResolver: retry_servfail: bool rotate: bool ndots: Optional[int] + _nameservers: Sequence[Union[str, dns.nameserver.Nameserver]] - def __init__(self, filename: str = "/etc/resolv.conf", configure: bool = True): + def __init__( + self, filename: str = "/etc/resolv.conf", configure: bool = True + ) -> None: """*filename*, a ``str`` or file object, specifying a file in standard /etc/resolv.conf format. This parameter is meaningful only when *configure* is true and the platform is POSIX. @@ -860,13 +943,13 @@ class BaseResolver: elif filename: self.read_resolv_conf(filename) - def reset(self): + def reset(self) -> None: """Reset all resolver configuration to the defaults.""" self.domain = dns.name.Name(dns.name.from_text(socket.gethostname())[1:]) if len(self.domain) == 0: self.domain = dns.name.root - self.nameservers = [] + self._nameservers = [] self.nameserver_ports = {} self.port = 53 self.search = [] @@ -903,6 +986,7 @@ class BaseResolver: """ + nameservers = [] if isinstance(f, str): try: cm: contextlib.AbstractContextManager = open(f) @@ -922,7 +1006,7 @@ class BaseResolver: continue if tokens[0] == "nameserver": - self.nameservers.append(tokens[1]) + nameservers.append(tokens[1]) elif tokens[0] == "domain": self.domain = dns.name.from_text(tokens[1]) # domain and search are exclusive @@ -950,8 +1034,11 @@ class BaseResolver: self.ndots = int(opt.split(":")[1]) except (ValueError, IndexError): pass - if len(self.nameservers) == 0: + if len(nameservers) == 0: raise NoResolverConfiguration("no nameservers") + # Assigning directly instead of appending means we invoke the + # setter logic, with additonal checking and enrichment. + self.nameservers = nameservers def read_registry(self) -> None: """Extract resolver configuration from the Windows registry.""" @@ -1086,34 +1173,64 @@ class BaseResolver: self.flags = flags - @property - def nameservers(self) -> List[str]: - return self._nameservers - - @nameservers.setter - def nameservers(self, nameservers: List[str]) -> None: - """ - *nameservers*, a ``list`` of nameservers. - - Raises ``ValueError`` if *nameservers* is anything other than a - ``list``. - """ + @classmethod + def _enrich_nameservers( + cls, + nameservers: Sequence[Union[str, dns.nameserver.Nameserver]], + nameserver_ports: Dict[str, int], + default_port: int, + ) -> List[dns.nameserver.Nameserver]: + enriched_nameservers = [] if isinstance(nameservers, list): for nameserver in nameservers: - if not dns.inet.is_address(nameserver): + enriched_nameserver: dns.nameserver.Nameserver + if isinstance(nameserver, dns.nameserver.Nameserver): + enriched_nameserver = nameserver + elif dns.inet.is_address(nameserver): + port = nameserver_ports.get(nameserver, default_port) + enriched_nameserver = dns.nameserver.Do53Nameserver( + nameserver, port + ) + else: try: if urlparse(nameserver).scheme != "https": raise NotImplementedError except Exception: raise ValueError( - f"nameserver {nameserver} is not an " - "IP address or valid https URL" + f"nameserver {nameserver} is not a " + "dns.nameserver.Nameserver instance or text form, " + "IP address, nor a valid https URL" ) - self._nameservers = nameservers + enriched_nameserver = dns.nameserver.DoHNameserver(nameserver) + enriched_nameservers.append(enriched_nameserver) else: raise ValueError( - "nameservers must be a list (not a {})".format(type(nameservers)) + "nameservers must be a list or tuple (not a {})".format( + type(nameservers) + ) ) + return enriched_nameservers + + @property + def nameservers( + self, + ) -> Sequence[Union[str, dns.nameserver.Nameserver]]: + return self._nameservers + + @nameservers.setter + def nameservers( + self, nameservers: Sequence[Union[str, dns.nameserver.Nameserver]] + ) -> None: + """ + *nameservers*, a ``list`` of nameservers, where a nameserver is either + a string interpretable as a nameserver, or a ``dns.nameserver.Nameserver`` + instance. + + Raises ``ValueError`` if *nameservers* is not a list of nameservers. + """ + # We just call _enrich_nameservers() for checking + self._enrich_nameservers(nameservers, self.nameserver_ports, self.port) + self._nameservers = nameservers class Resolver(BaseResolver): @@ -1198,33 +1315,18 @@ class Resolver(BaseResolver): assert request is not None # needed for type checking done = False while not done: - (nameserver, port, tcp, backoff) = resolution.next_nameserver() + (nameserver, tcp, backoff) = resolution.next_nameserver() if backoff: time.sleep(backoff) timeout = self._compute_timeout(start, lifetime, resolution.errors) try: - if dns.inet.is_address(nameserver): - if tcp: - response = dns.query.tcp( - request, - nameserver, - timeout=timeout, - port=port, - source=source, - source_port=source_port, - ) - else: - response = dns.query.udp( - request, - nameserver, - timeout=timeout, - port=port, - source=source, - source_port=source_port, - raise_on_truncation=True, - ) - else: - response = dns.query.https(request, nameserver, timeout=timeout) + response = nameserver.query( + request, + timeout=timeout, + source=source, + source_port=source_port, + max_size=tcp, + ) except Exception as ex: (_, done) = resolution.query_result(None, ex) continue @@ -1293,7 +1395,72 @@ class Resolver(BaseResolver): modified_kwargs["rdclass"] = dns.rdataclass.IN return self.resolve( dns.reversename.from_address(ipaddr), *args, **modified_kwargs - ) # type: ignore[arg-type] + ) + + def resolve_name( + self, + name: Union[dns.name.Name, str], + family: int = socket.AF_UNSPEC, + **kwargs: Any, + ) -> HostAnswers: + """Use a resolver to query for address records. + + This utilizes the resolve() method to perform A and/or AAAA lookups on + the specified name. + + *qname*, a ``dns.name.Name`` or ``str``, the name to resolve. + + *family*, an ``int``, the address family. If socket.AF_UNSPEC + (the default), both A and AAAA records will be retrieved. + + All other arguments that can be passed to the resolve() function + except for rdtype and rdclass are also supported by this + function. + """ + # We make a modified kwargs for type checking happiness, as otherwise + # we get a legit warning about possibly having rdtype and rdclass + # in the kwargs more than once. + modified_kwargs: Dict[str, Any] = {} + modified_kwargs.update(kwargs) + modified_kwargs.pop("rdtype", None) + modified_kwargs["rdclass"] = dns.rdataclass.IN + + if family == socket.AF_INET: + v4 = self.resolve(name, dns.rdatatype.A, **modified_kwargs) + return HostAnswers.make(v4=v4) + elif family == socket.AF_INET6: + v6 = self.resolve(name, dns.rdatatype.AAAA, **modified_kwargs) + return HostAnswers.make(v6=v6) + elif family != socket.AF_UNSPEC: + raise NotImplementedError(f"unknown address family {family}") + + raise_on_no_answer = modified_kwargs.pop("raise_on_no_answer", True) + lifetime = modified_kwargs.pop("lifetime", None) + start = time.time() + v6 = self.resolve( + name, + dns.rdatatype.AAAA, + raise_on_no_answer=False, + lifetime=self._compute_timeout(start, lifetime), + **modified_kwargs, + ) + # Note that setting name ensures we query the same name + # for A as we did for AAAA. (This is just in case search lists + # are active by default in the resolver configuration and + # we might be talking to a server that says NXDOMAIN when it + # wants to say NOERROR no data. + name = v6.qname + v4 = self.resolve( + name, + dns.rdatatype.A, + raise_on_no_answer=False, + lifetime=self._compute_timeout(start, lifetime), + **modified_kwargs, + ) + answers = HostAnswers.make(v6=v6, v4=v4, add_empty=not raise_on_no_answer) + if not answers: + raise NoAnswer(response=v6.response) + return answers # pylint: disable=redefined-outer-name @@ -1320,6 +1487,37 @@ class Resolver(BaseResolver): # pylint: enable=redefined-outer-name + def try_ddr(self, lifetime: float = 5.0) -> None: + """Try to update the resolver's nameservers using Discovery of Designated + Resolvers (DDR). If successful, the resolver will subsequently use + DNS-over-HTTPS or DNS-over-TLS for future queries. + + *lifetime*, a float, is the maximum time to spend attempting DDR. The default + is 5 seconds. + + If the SVCB query is successful and results in a non-empty list of nameservers, + then the resolver's nameservers are set to the returned servers in priority + order. + + The current implementation does not use any address hints from the SVCB record, + nor does it resolve addresses for the SCVB target name, rather it assumes that + the bootstrap nameserver will always be one of the addresses and uses it. + A future revision to the code may offer fuller support. The code verifies that + the bootstrap nameserver is in the Subject Alternative Name field of the + TLS certficate. + """ + try: + expiration = time.time() + lifetime + answer = self.resolve( + dns._ddr._local_resolver_name, "SVCB", lifetime=lifetime + ) + timeout = dns.query._remaining(expiration) + nameservers = dns._ddr._get_nameservers_sync(answer, timeout) + if len(nameservers) > 0: + self.nameservers = nameservers + except Exception: + pass + #: The default resolver. default_resolver: Optional[Resolver] = None @@ -1333,7 +1531,7 @@ def get_default_resolver() -> Resolver: return default_resolver -def reset_default_resolver(): +def reset_default_resolver() -> None: """Re-initialize default resolver. Note that the resolver configuration (i.e. /etc/resolv.conf on UNIX @@ -1355,7 +1553,6 @@ def resolve( lifetime: Optional[float] = None, search: Optional[bool] = None, ) -> Answer: # pragma: no cover - """Query nameservers to find the answer to the question. This is a convenience function that uses the default resolver @@ -1421,6 +1618,18 @@ def resolve_address(ipaddr: str, *args: Any, **kwargs: Any) -> Answer: return get_default_resolver().resolve_address(ipaddr, *args, **kwargs) +def resolve_name( + name: Union[dns.name.Name, str], family: int = socket.AF_UNSPEC, **kwargs: Any +) -> HostAnswers: + """Use a resolver to query for address records. + + See ``dns.resolver.Resolver.resolve_name`` for more information on the + parameters. + """ + + return get_default_resolver().resolve_name(name, family, **kwargs) + + def canonical_name(name: Union[dns.name.Name, str]) -> dns.name.Name: """Determine the canonical name of *name*. @@ -1431,6 +1640,16 @@ def canonical_name(name: Union[dns.name.Name, str]) -> dns.name.Name: return get_default_resolver().canonical_name(name) +def try_ddr(lifetime: float = 5.0) -> None: + """Try to update the default resolver's nameservers using Discovery of Designated + Resolvers (DDR). If successful, the resolver will subsequently use + DNS-over-HTTPS or DNS-over-TLS for future queries. + + See :py:func:`dns.resolver.Resolver.try_ddr` for more information. + """ + return get_default_resolver().try_ddr(lifetime) + + def zone_for_name( name: Union[dns.name.Name, str], rdclass: dns.rdataclass.RdataClass = dns.rdataclass.IN, @@ -1478,7 +1697,7 @@ def zone_for_name( while 1: try: rlifetime: Optional[float] - if expiration: + if expiration is not None: rlifetime = expiration - time.time() if rlifetime <= 0: rlifetime = 0 @@ -1516,6 +1735,83 @@ def zone_for_name( raise NoRootSOA +def make_resolver_at( + where: Union[dns.name.Name, str], + port: int = 53, + family: int = socket.AF_UNSPEC, + resolver: Optional[Resolver] = None, +) -> Resolver: + """Make a stub resolver using the specified destination as the full resolver. + + *where*, a ``dns.name.Name`` or ``str`` the domain name or IP address of the + full resolver. + + *port*, an ``int``, the port to use. If not specified, the default is 53. + + *family*, an ``int``, the address family to use. This parameter is used if + *where* is not an address. The default is ``socket.AF_UNSPEC`` in which case + the first address returned by ``resolve_name()`` will be used, otherwise the + first address of the specified family will be used. + + *resolver*, a ``dns.resolver.Resolver`` or ``None``, the resolver to use for + resolution of hostnames. If not specified, the default resolver will be used. + + Returns a ``dns.resolver.Resolver`` or raises an exception. + """ + if resolver is None: + resolver = get_default_resolver() + nameservers: List[Union[str, dns.nameserver.Nameserver]] = [] + if isinstance(where, str) and dns.inet.is_address(where): + nameservers.append(dns.nameserver.Do53Nameserver(where, port)) + else: + for address in resolver.resolve_name(where, family).addresses(): + nameservers.append(dns.nameserver.Do53Nameserver(address, port)) + res = dns.resolver.Resolver(configure=False) + res.nameservers = nameservers + return res + + +def resolve_at( + where: Union[dns.name.Name, str], + qname: Union[dns.name.Name, str], + rdtype: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.A, + rdclass: Union[dns.rdataclass.RdataClass, str] = dns.rdataclass.IN, + tcp: bool = False, + source: Optional[str] = None, + raise_on_no_answer: bool = True, + source_port: int = 0, + lifetime: Optional[float] = None, + search: Optional[bool] = None, + port: int = 53, + family: int = socket.AF_UNSPEC, + resolver: Optional[Resolver] = None, +) -> Answer: + """Query nameservers to find the answer to the question. + + This is a convenience function that calls ``dns.resolver.make_resolver_at()`` to + make a resolver, and then uses it to resolve the query. + + See ``dns.resolver.Resolver.resolve`` for more information on the resolution + parameters, and ``dns.resolver.make_resolver_at`` for information about the resolver + parameters *where*, *port*, *family*, and *resolver*. + + If making more than one query, it is more efficient to call + ``dns.resolver.make_resolver_at()`` and then use that resolver for the queries + instead of calling ``resolve_at()`` multiple times. + """ + return make_resolver_at(where, port, family, resolver).resolve( + qname, + rdtype, + rdclass, + tcp, + source, + raise_on_no_answer, + source_port, + lifetime, + search, + ) + + # # Support for overriding the system resolver for all python code in the # running process. @@ -1559,8 +1855,7 @@ def _getaddrinfo( ) if host is None and service is None: raise socket.gaierror(socket.EAI_NONAME, "Name or service not known") - v6addrs = [] - v4addrs = [] + addrs = [] canonical_name = None # pylint: disable=redefined-outer-name # Is host None or an address literal? If so, use the system's # getaddrinfo(). @@ -1576,24 +1871,9 @@ def _getaddrinfo( pass # Something needs resolution! try: - if family == socket.AF_INET6 or family == socket.AF_UNSPEC: - v6 = _resolver.resolve(host, dns.rdatatype.AAAA, raise_on_no_answer=False) - # Note that setting host ensures we query the same name - # for A as we did for AAAA. (This is just in case search lists - # are active by default in the resolver configuration and - # we might be talking to a server that says NXDOMAIN when it - # wants to say NOERROR no data. - host = v6.qname - canonical_name = v6.canonical_name.to_text(True) - if v6.rrset is not None: - for rdata in v6.rrset: - v6addrs.append(rdata.address) - if family == socket.AF_INET or family == socket.AF_UNSPEC: - v4 = _resolver.resolve(host, dns.rdatatype.A, raise_on_no_answer=False) - canonical_name = v4.canonical_name.to_text(True) - if v4.rrset is not None: - for rdata in v4.rrset: - v4addrs.append(rdata.address) + answers = _resolver.resolve_name(host, family) + addrs = answers.addresses_and_families() + canonical_name = answers.canonical_name().to_text(True) except dns.resolver.NXDOMAIN: raise socket.gaierror(socket.EAI_NONAME, "Name or service not known") except Exception: @@ -1625,20 +1905,11 @@ def _getaddrinfo( cname = canonical_name else: cname = "" - if family == socket.AF_INET6 or family == socket.AF_UNSPEC: - for addr in v6addrs: - for socktype in socktypes: - for proto in _protocols_for_socktype[socktype]: - tuples.append( - (socket.AF_INET6, socktype, proto, cname, (addr, port, 0, 0)) - ) - if family == socket.AF_INET or family == socket.AF_UNSPEC: - for addr in v4addrs: - for socktype in socktypes: - for proto in _protocols_for_socktype[socktype]: - tuples.append( - (socket.AF_INET, socktype, proto, cname, (addr, port)) - ) + for addr, af in addrs: + for socktype in socktypes: + for proto in _protocols_for_socktype[socktype]: + addr_tuple = dns.inet.low_level_address_tuple((addr, port), af) + tuples.append((af, socktype, proto, cname, addr_tuple)) if len(tuples) == 0: raise socket.gaierror(socket.EAI_NONAME, "Name or service not known") return tuples diff --git a/lib/dns/reversename.py b/lib/dns/reversename.py index eb6a3b6b..8236c711 100644 --- a/lib/dns/reversename.py +++ b/lib/dns/reversename.py @@ -19,9 +19,9 @@ import binascii -import dns.name -import dns.ipv6 import dns.ipv4 +import dns.ipv6 +import dns.name ipv4_reverse_domain = dns.name.from_text("in-addr.arpa.") ipv6_reverse_domain = dns.name.from_text("ip6.arpa.") diff --git a/lib/dns/rrset.py b/lib/dns/rrset.py index 3f22a90c..350de13e 100644 --- a/lib/dns/rrset.py +++ b/lib/dns/rrset.py @@ -17,11 +17,11 @@ """DNS RRsets (an RRset is a named rdataset)""" -from typing import Any, cast, Collection, Dict, Optional, Union +from typing import Any, Collection, Dict, Optional, Union, cast import dns.name -import dns.rdataset import dns.rdataclass +import dns.rdataset import dns.renderer @@ -214,9 +214,9 @@ def from_text_list( if isinstance(name, str): name = dns.name.from_text(name, None, idna_codec=idna_codec) - the_rdclass = dns.rdataclass.RdataClass.make(rdclass) - the_rdtype = dns.rdatatype.RdataType.make(rdtype) - r = RRset(name, the_rdclass, the_rdtype) + rdclass = dns.rdataclass.RdataClass.make(rdclass) + rdtype = dns.rdatatype.RdataType.make(rdtype) + r = RRset(name, rdclass, rdtype) r.update_ttl(ttl) for t in text_rdatas: rd = dns.rdata.from_text( diff --git a/lib/dns/tokenizer.py b/lib/dns/tokenizer.py index 0551578a..454cac4a 100644 --- a/lib/dns/tokenizer.py +++ b/lib/dns/tokenizer.py @@ -17,10 +17,9 @@ """Tokenize DNS zone file format""" -from typing import Any, Optional, List, Tuple - import io import sys +from typing import Any, List, Optional, Tuple import dns.exception import dns.name diff --git a/lib/dns/transaction.py b/lib/dns/transaction.py index c4a9e1f6..21dea775 100644 --- a/lib/dns/transaction.py +++ b/lib/dns/transaction.py @@ -1,8 +1,7 @@ # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license -from typing import Any, Callable, List, Optional, Tuple, Union - import collections +from typing import Any, Callable, Iterator, List, Optional, Tuple, Union import dns.exception import dns.name @@ -357,6 +356,27 @@ class Transaction: """ self._check_delete_name.append(check) + def iterate_rdatasets( + self, + ) -> Iterator[Tuple[dns.name.Name, dns.rdataset.Rdataset]]: + """Iterate all the rdatasets in the transaction, returning + (`dns.name.Name`, `dns.rdataset.Rdataset`) tuples. + + Note that as is usual with python iterators, adding or removing items + while iterating will invalidate the iterator and may raise `RuntimeError` + or fail to iterate over all entries.""" + self._check_ended() + return self._iterate_rdatasets() + + def iterate_names(self) -> Iterator[dns.name.Name]: + """Iterate all the names in the transaction. + + Note that as is usual with python iterators, adding or removing names + while iterating will invalidate the iterator and may raise `RuntimeError` + or fail to iterate over all entries.""" + self._check_ended() + return self._iterate_names() + # # Helper methods # @@ -416,7 +436,7 @@ class Transaction: rdataset = rrset.to_rdataset() else: raise TypeError( - f"{method} requires a name or RRset " + "as the first argument" + f"{method} requires a name or RRset as the first argument" ) if rdataset.rdclass != self.manager.get_class(): raise ValueError(f"{method} has objects of wrong RdataClass") @@ -475,7 +495,7 @@ class Transaction: name = rdataset.name else: raise TypeError( - f"{method} requires a name or RRset " + "as the first argument" + f"{method} requires a name or RRset as the first argument" ) self._raise_if_not_empty(method, args) if rdataset: @@ -610,6 +630,10 @@ class Transaction: """Return an iterator that yields (name, rdataset) tuples.""" raise NotImplementedError # pragma: no cover + def _iterate_names(self): + """Return an iterator that yields a name.""" + raise NotImplementedError # pragma: no cover + def _get_node(self, name): """Return the node at *name*, if any. diff --git a/lib/dns/tsig.py b/lib/dns/tsig.py index 2476fdfb..58760f5f 100644 --- a/lib/dns/tsig.py +++ b/lib/dns/tsig.py @@ -23,9 +23,9 @@ import hmac import struct import dns.exception -import dns.rdataclass import dns.name import dns.rcode +import dns.rdataclass class BadTime(dns.exception.DNSException): @@ -187,9 +187,7 @@ class HMACTSig: try: hashinfo = self._hashes[algorithm] except KeyError: - raise NotImplementedError( - f"TSIG algorithm {algorithm} " + "is not supported" - ) + raise NotImplementedError(f"TSIG algorithm {algorithm} is not supported") # create the HMAC context if isinstance(hashinfo, tuple): diff --git a/lib/dns/tsigkeyring.py b/lib/dns/tsigkeyring.py index 6adba284..1010a79f 100644 --- a/lib/dns/tsigkeyring.py +++ b/lib/dns/tsigkeyring.py @@ -17,9 +17,8 @@ """A place to store TSIG keys.""" -from typing import Any, Dict - import base64 +from typing import Any, Dict import dns.name import dns.tsig @@ -33,7 +32,7 @@ def from_text(textring: Dict[str, Any]) -> Dict[dns.name.Name, dns.tsig.Key]: @rtype: dict""" keyring = {} - for (name, value) in textring.items(): + for name, value in textring.items(): kname = dns.name.from_text(name) if isinstance(value, str): keyring[kname] = dns.tsig.Key(kname, value).secret @@ -55,7 +54,7 @@ def to_text(keyring: Dict[dns.name.Name, Any]) -> Dict[str, Any]: def b64encode(secret): return base64.encodebytes(secret).decode().rstrip() - for (name, key) in keyring.items(): + for name, key in keyring.items(): tname = name.to_text() if isinstance(key, bytes): textring[tname] = b64encode(key) diff --git a/lib/dns/update.py b/lib/dns/update.py index 647e5b19..bf1157ac 100644 --- a/lib/dns/update.py +++ b/lib/dns/update.py @@ -24,8 +24,8 @@ import dns.name import dns.opcode import dns.rdata import dns.rdataclass -import dns.rdatatype import dns.rdataset +import dns.rdatatype import dns.tsig @@ -43,7 +43,6 @@ class UpdateSection(dns.enum.IntEnum): class UpdateMessage(dns.message.Message): # lgtm[py/missing-equals] - # ignore the mypy error here as we mean to use a different enum _section_enum = UpdateSection # type: ignore @@ -336,12 +335,12 @@ class UpdateMessage(dns.message.Message): # lgtm[py/missing-equals] True, ) else: - the_rdtype = dns.rdatatype.RdataType.make(rdtype) + rdtype = dns.rdatatype.RdataType.make(rdtype) self.find_rrset( self.prerequisite, name, dns.rdataclass.NONE, - the_rdtype, + rdtype, dns.rdatatype.NONE, None, True, diff --git a/lib/dns/version.py b/lib/dns/version.py index 89d4cf1a..1f1fbf2d 100644 --- a/lib/dns/version.py +++ b/lib/dns/version.py @@ -20,13 +20,13 @@ #: MAJOR MAJOR = 2 #: MINOR -MINOR = 3 +MINOR = 4 #: MICRO -MICRO = 0 +MICRO = 2 #: RELEASELEVEL RELEASELEVEL = 0x0F #: SERIAL -SERIAL = 1 +SERIAL = 0 if RELEASELEVEL == 0x0F: # pragma: no cover lgtm[py/unreachable-statement] #: version diff --git a/lib/dns/versioned.py b/lib/dns/versioned.py index 02e24122..fd78e674 100644 --- a/lib/dns/versioned.py +++ b/lib/dns/versioned.py @@ -2,10 +2,9 @@ """DNS Versioned Zones.""" -from typing import Callable, Deque, Optional, Set, Union - import collections import threading +from typing import Callable, Deque, Optional, Set, Union import dns.exception import dns.immutable @@ -32,7 +31,6 @@ Transaction = dns.zone.Transaction class Zone(dns.zone.Zone): # lgtm[py/missing-equals] - __slots__ = [ "_versions", "_versions_lock", @@ -152,7 +150,7 @@ class Zone(dns.zone.Zone): # lgtm[py/missing-equals] # # This is not a problem with Threading module threads as # they cannot be canceled, but could be an issue with trio - # or curio tasks when we do the async version of writer(). + # tasks when we do the async version of writer(). # I.e. we'd need to do something like: # # try: diff --git a/lib/dns/win32util.py b/lib/dns/win32util.py index ac314750..b2ca61da 100644 --- a/lib/dns/win32util.py +++ b/lib/dns/win32util.py @@ -1,7 +1,6 @@ import sys if sys.platform == "win32": - from typing import Any import dns.name @@ -18,6 +17,7 @@ if sys.platform == "win32": try: import threading + import pythoncom # pylint: disable=import-error import wmi # pylint: disable=import-error @@ -206,7 +206,7 @@ if sys.platform == "win32": lm = winreg.ConnectRegistry(None, winreg.HKEY_LOCAL_MACHINE) try: tcp_params = winreg.OpenKey( - lm, r"SYSTEM\CurrentControlSet" r"\Services\Tcpip\Parameters" + lm, r"SYSTEM\CurrentControlSet\Services\Tcpip\Parameters" ) try: self._config_fromkey(tcp_params, True) @@ -214,9 +214,7 @@ if sys.platform == "win32": tcp_params.Close() interfaces = winreg.OpenKey( lm, - r"SYSTEM\CurrentControlSet" - r"\Services\Tcpip\Parameters" - r"\Interfaces", + r"SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces", ) try: i = 0 diff --git a/lib/dns/wire.py b/lib/dns/wire.py index cadf1686..9f9b1573 100644 --- a/lib/dns/wire.py +++ b/lib/dns/wire.py @@ -1,9 +1,8 @@ # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license -from typing import Iterator, Optional, Tuple - import contextlib import struct +from typing import Iterator, Optional, Tuple import dns.exception import dns.name diff --git a/lib/dns/xfr.py b/lib/dns/xfr.py index bb165888..dd247d33 100644 --- a/lib/dns/xfr.py +++ b/lib/dns/xfr.py @@ -21,9 +21,9 @@ import dns.exception import dns.message import dns.name import dns.rcode -import dns.serial import dns.rdataset import dns.rdatatype +import dns.serial import dns.transaction import dns.tsig import dns.zone diff --git a/lib/dns/zone.py b/lib/dns/zone.py index cc8268da..9e763f5f 100644 --- a/lib/dns/zone.py +++ b/lib/dns/zone.py @@ -17,30 +17,29 @@ """DNS Zones.""" -from typing import Any, Dict, Iterator, Iterable, List, Optional, Set, Tuple, Union - import contextlib import io import os import struct +from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Tuple, Union import dns.exception +import dns.grange import dns.immutable import dns.name import dns.node -import dns.rdataclass -import dns.rdatatype import dns.rdata +import dns.rdataclass import dns.rdataset +import dns.rdatatype import dns.rdtypes.ANY.SOA import dns.rdtypes.ANY.ZONEMD import dns.rrset import dns.tokenizer import dns.transaction import dns.ttl -import dns.grange import dns.zonefile -from dns.zonetypes import DigestScheme, DigestHashAlgorithm, _digest_hashers +from dns.zonetypes import DigestHashAlgorithm, DigestScheme, _digest_hashers class BadZone(dns.exception.DNSException): @@ -321,11 +320,11 @@ class Zone(dns.transaction.TransactionManager): Returns a ``dns.rdataset.Rdataset``. """ - the_name = self._validate_name(name) - the_rdtype = dns.rdatatype.RdataType.make(rdtype) - the_covers = dns.rdatatype.RdataType.make(covers) - node = self.find_node(the_name, create) - return node.find_rdataset(self.rdclass, the_rdtype, the_covers, create) + name = self._validate_name(name) + rdtype = dns.rdatatype.RdataType.make(rdtype) + covers = dns.rdatatype.RdataType.make(covers) + node = self.find_node(name, create) + return node.find_rdataset(self.rdclass, rdtype, covers, create) def get_rdataset( self, @@ -404,14 +403,14 @@ class Zone(dns.transaction.TransactionManager): types were aggregated into a single RRSIG rdataset. """ - the_name = self._validate_name(name) - the_rdtype = dns.rdatatype.RdataType.make(rdtype) - the_covers = dns.rdatatype.RdataType.make(covers) - node = self.get_node(the_name) + name = self._validate_name(name) + rdtype = dns.rdatatype.RdataType.make(rdtype) + covers = dns.rdatatype.RdataType.make(covers) + node = self.get_node(name) if node is not None: - node.delete_rdataset(self.rdclass, the_rdtype, the_covers) + node.delete_rdataset(self.rdclass, rdtype, covers) if len(node) == 0: - self.delete_node(the_name) + self.delete_node(name) def replace_rdataset( self, name: Union[dns.name.Name, str], replacement: dns.rdataset.Rdataset @@ -484,10 +483,10 @@ class Zone(dns.transaction.TransactionManager): """ vname = self._validate_name(name) - the_rdtype = dns.rdatatype.RdataType.make(rdtype) - the_covers = dns.rdatatype.RdataType.make(covers) - rdataset = self.nodes[vname].find_rdataset(self.rdclass, the_rdtype, the_covers) - rrset = dns.rrset.RRset(vname, self.rdclass, the_rdtype, the_covers) + rdtype = dns.rdatatype.RdataType.make(rdtype) + covers = dns.rdatatype.RdataType.make(covers) + rdataset = self.nodes[vname].find_rdataset(self.rdclass, rdtype, covers) + rrset = dns.rrset.RRset(vname, self.rdclass, rdtype, covers) rrset.update(rdataset) return rrset @@ -565,7 +564,7 @@ class Zone(dns.transaction.TransactionManager): rdtype = dns.rdatatype.RdataType.make(rdtype) covers = dns.rdatatype.RdataType.make(covers) - for (name, node) in self.items(): + for name, node in self.items(): for rds in node: if rdtype == dns.rdatatype.ANY or ( rds.rdtype == rdtype and rds.covers == covers @@ -597,7 +596,7 @@ class Zone(dns.transaction.TransactionManager): rdtype = dns.rdatatype.RdataType.make(rdtype) covers = dns.rdatatype.RdataType.make(covers) - for (name, node) in self.items(): + for name, node in self.items(): for rds in node: if rdtype == dns.rdatatype.ANY or ( rds.rdtype == rdtype and rds.covers == covers @@ -795,7 +794,7 @@ class Zone(dns.transaction.TransactionManager): assert self.origin is not None origin_name = self.origin hasher = hashinfo() - for (name, node) in sorted(self.items()): + for name, node in sorted(self.items()): rrnamebuf = name.to_digestable(self.origin) for rdataset in sorted(node, key=lambda rds: (rds.rdtype, rds.covers)): if name == origin_name and dns.rdatatype.ZONEMD in ( @@ -997,6 +996,9 @@ class Version: return None return node.get_rdataset(self.zone.rdclass, rdtype, covers) + def keys(self): + return self.nodes.keys() + def items(self): return self.nodes.items() @@ -1143,10 +1145,13 @@ class Transaction(dns.transaction.Transaction): self.version.origin = origin def _iterate_rdatasets(self): - for (name, node) in self.version.items(): + for name, node in self.version.items(): for rdataset in node: yield (name, rdataset) + def _iterate_names(self): + return self.version.keys() + def _get_node(self, name): return self.version.get_node(name) diff --git a/lib/dns/zonefile.py b/lib/dns/zonefile.py index 1a53f5bc..27f04924 100644 --- a/lib/dns/zonefile.py +++ b/lib/dns/zonefile.py @@ -17,23 +17,22 @@ """DNS Zones.""" -from typing import Any, Iterable, List, Optional, Set, Tuple, Union - import re import sys +from typing import Any, Iterable, List, Optional, Set, Tuple, Union import dns.exception +import dns.grange import dns.name import dns.node +import dns.rdata import dns.rdataclass import dns.rdatatype -import dns.rdata import dns.rdtypes.ANY.SOA import dns.rrset import dns.tokenizer import dns.transaction import dns.ttl -import dns.grange class UnknownOrigin(dns.exception.DNSException): @@ -191,10 +190,6 @@ class Reader: self.last_ttl_known = True token = None except dns.ttl.BadTTL: - if self.default_ttl_known: - ttl = self.default_ttl - elif self.last_ttl_known: - ttl = self.last_ttl self.tok.unget(token) # Class @@ -212,6 +207,22 @@ class Reader: if rdclass != self.zone_rdclass: raise dns.exception.SyntaxError("RR class is not zone's class") + if ttl is None: + # support for syntax + token = self._get_identifier() + ttl = None + try: + ttl = dns.ttl.from_text(token.value) + self.last_ttl = ttl + self.last_ttl_known = True + token = None + except dns.ttl.BadTTL: + if self.default_ttl_known: + ttl = self.default_ttl + elif self.last_ttl_known: + ttl = self.last_ttl + self.tok.unget(token) + # Type if self.force_rdtype is not None: rdtype = self.force_rdtype @@ -581,7 +592,7 @@ class RRsetsReaderTransaction(dns.transaction.Transaction): pass def _name_exists(self, name): - for (n, _, _) in self.rdatasets: + for n, _, _ in self.rdatasets: if n == name: return True return False @@ -606,6 +617,9 @@ class RRsetsReaderTransaction(dns.transaction.Transaction): def _iterate_rdatasets(self): raise NotImplementedError # pragma: no cover + def _iterate_names(self): + raise NotImplementedError # pragma: no cover + class RRSetsReaderManager(dns.transaction.TransactionManager): def __init__( @@ -707,26 +721,26 @@ def read_rrsets( if isinstance(default_ttl, str): default_ttl = dns.ttl.from_text(default_ttl) if rdclass is not None: - the_rdclass = dns.rdataclass.RdataClass.make(rdclass) + rdclass = dns.rdataclass.RdataClass.make(rdclass) else: - the_rdclass = None - the_default_rdclass = dns.rdataclass.RdataClass.make(default_rdclass) + rdclass = None + default_rdclass = dns.rdataclass.RdataClass.make(default_rdclass) if rdtype is not None: - the_rdtype = dns.rdatatype.RdataType.make(rdtype) + rdtype = dns.rdatatype.RdataType.make(rdtype) else: - the_rdtype = None + rdtype = None manager = RRSetsReaderManager(origin, relativize, default_rdclass) with manager.writer(True) as txn: tok = dns.tokenizer.Tokenizer(text, "", idna_codec=idna_codec) reader = Reader( tok, - the_default_rdclass, + default_rdclass, txn, allow_directives=False, force_name=name, force_ttl=ttl, - force_rdclass=the_rdclass, - force_rdtype=the_rdtype, + force_rdclass=rdclass, + force_rdtype=rdtype, default_ttl=default_ttl, ) reader.read() diff --git a/requirements.txt b/requirements.txt index 03b49577..d4e4442f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,7 +11,7 @@ cheroot==9.0.0 cherrypy==18.8.0 cloudinary==1.34.0 distro==1.8.0 -dnspython==2.3.0 +dnspython==2.4.2 facebook-sdk==3.1.0 future==0.18.3 ga4mp==2.0.4