Bump dnspython from 2.3.0 to 2.4.2 (#2123)

* Bump dnspython from 2.3.0 to 2.4.2

Bumps [dnspython](https://github.com/rthalley/dnspython) from 2.3.0 to 2.4.2.
- [Release notes](https://github.com/rthalley/dnspython/releases)
- [Changelog](https://github.com/rthalley/dnspython/blob/master/doc/whatsnew.rst)
- [Commits](https://github.com/rthalley/dnspython/compare/v2.3.0...v2.4.2)

---
updated-dependencies:
- dependency-name: dnspython
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>

* Update dnspython==2.4.2

---------

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: JonnyWong16 <9099342+JonnyWong16@users.noreply.github.com>

[skip ci]
This commit is contained in:
dependabot[bot] 2023-08-24 12:05:11 -07:00 committed by GitHub
parent 9f00f5dafa
commit c0aa4e4996
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
108 changed files with 2985 additions and 1136 deletions

View file

@ -22,6 +22,7 @@ __all__ = [
"asyncquery", "asyncquery",
"asyncresolver", "asyncresolver",
"dnssec", "dnssec",
"dnssecalgs",
"dnssectypes", "dnssectypes",
"e164", "e164",
"edns", "edns",

View file

@ -35,6 +35,9 @@ class Socket: # pragma: no cover
async def getsockname(self): async def getsockname(self):
raise NotImplementedError raise NotImplementedError
async def getpeercert(self, timeout):
raise NotImplementedError
async def __aenter__(self): async def __aenter__(self):
return self return self
@ -61,6 +64,11 @@ class StreamSocket(Socket): # pragma: no cover
raise NotImplementedError raise NotImplementedError
class NullTransport:
async def connect_tcp(self, host, port, timeout, local_address):
raise NotImplementedError
class Backend: # pragma: no cover class Backend: # pragma: no cover
def name(self): def name(self):
return "unknown" return "unknown"
@ -83,3 +91,9 @@ class Backend: # pragma: no cover
async def sleep(self, interval): async def sleep(self, interval):
raise NotImplementedError raise NotImplementedError
def get_transport_class(self):
raise NotImplementedError
async def wait_for(self, awaitable, timeout):
raise NotImplementedError

View file

@ -2,14 +2,13 @@
"""asyncio library query support""" """asyncio library query support"""
import socket
import asyncio import asyncio
import socket
import sys import sys
import dns._asyncbackend import dns._asyncbackend
import dns.exception import dns.exception
_is_win32 = sys.platform == "win32" _is_win32 = sys.platform == "win32"
@ -38,14 +37,21 @@ class _DatagramProtocol:
def connection_lost(self, exc): def connection_lost(self, exc):
if self.recvfrom and not self.recvfrom.done(): if self.recvfrom and not self.recvfrom.done():
self.recvfrom.set_exception(exc) if exc is None:
# EOF we triggered. Is there a better way to do this?
try:
raise EOFError
except EOFError as e:
self.recvfrom.set_exception(e)
else:
self.recvfrom.set_exception(exc)
def close(self): def close(self):
self.transport.close() self.transport.close()
async def _maybe_wait_for(awaitable, timeout): async def _maybe_wait_for(awaitable, timeout):
if timeout: if timeout is not None:
try: try:
return await asyncio.wait_for(awaitable, timeout) return await asyncio.wait_for(awaitable, timeout)
except asyncio.TimeoutError: except asyncio.TimeoutError:
@ -85,6 +91,9 @@ class DatagramSocket(dns._asyncbackend.DatagramSocket):
async def getsockname(self): async def getsockname(self):
return self.transport.get_extra_info("sockname") return self.transport.get_extra_info("sockname")
async def getpeercert(self, timeout):
raise NotImplementedError
class StreamSocket(dns._asyncbackend.StreamSocket): class StreamSocket(dns._asyncbackend.StreamSocket):
def __init__(self, af, reader, writer): def __init__(self, af, reader, writer):
@ -101,10 +110,6 @@ class StreamSocket(dns._asyncbackend.StreamSocket):
async def close(self): async def close(self):
self.writer.close() self.writer.close()
try:
await self.writer.wait_closed()
except AttributeError: # pragma: no cover
pass
async def getpeername(self): async def getpeername(self):
return self.writer.get_extra_info("peername") return self.writer.get_extra_info("peername")
@ -112,6 +117,97 @@ class StreamSocket(dns._asyncbackend.StreamSocket):
async def getsockname(self): async def getsockname(self):
return self.writer.get_extra_info("sockname") return self.writer.get_extra_info("sockname")
async def getpeercert(self, timeout):
return self.writer.get_extra_info("peercert")
try:
import anyio
import httpcore
import httpcore._backends.anyio
import httpx
_CoreAsyncNetworkBackend = httpcore.AsyncNetworkBackend
_CoreAnyIOStream = httpcore._backends.anyio.AnyIOStream
from dns.query import _compute_times, _expiration_for_this_attempt, _remaining
class _NetworkBackend(_CoreAsyncNetworkBackend):
def __init__(self, resolver, local_port, bootstrap_address, family):
super().__init__()
self._local_port = local_port
self._resolver = resolver
self._bootstrap_address = bootstrap_address
self._family = family
if local_port != 0:
raise NotImplementedError(
"the asyncio transport for HTTPX cannot set the local port"
)
async def connect_tcp(
self, host, port, timeout, local_address, socket_options=None
): # pylint: disable=signature-differs
addresses = []
_, expiration = _compute_times(timeout)
if dns.inet.is_address(host):
addresses.append(host)
elif self._bootstrap_address is not None:
addresses.append(self._bootstrap_address)
else:
timeout = _remaining(expiration)
family = self._family
if local_address:
family = dns.inet.af_for_address(local_address)
answers = await self._resolver.resolve_name(
host, family=family, lifetime=timeout
)
addresses = answers.addresses()
for address in addresses:
try:
attempt_expiration = _expiration_for_this_attempt(2.0, expiration)
timeout = _remaining(attempt_expiration)
with anyio.fail_after(timeout):
stream = await anyio.connect_tcp(
remote_host=address,
remote_port=port,
local_host=local_address,
)
return _CoreAnyIOStream(stream)
except Exception:
pass
raise httpcore.ConnectError
async def connect_unix_socket(
self, path, timeout, socket_options=None
): # pylint: disable=signature-differs
raise NotImplementedError
async def sleep(self, seconds): # pylint: disable=signature-differs
await anyio.sleep(seconds)
class _HTTPTransport(httpx.AsyncHTTPTransport):
def __init__(
self,
*args,
local_port=0,
bootstrap_address=None,
resolver=None,
family=socket.AF_UNSPEC,
**kwargs,
):
if resolver is None:
# pylint: disable=import-outside-toplevel,redefined-outer-name
import dns.asyncresolver
resolver = dns.asyncresolver.Resolver()
super().__init__(*args, **kwargs)
self._pool._network_backend = _NetworkBackend(
resolver, local_port, bootstrap_address, family
)
except ImportError:
_HTTPTransport = dns._asyncbackend.NullTransport # type: ignore
class Backend(dns._asyncbackend.Backend): class Backend(dns._asyncbackend.Backend):
def name(self): def name(self):
@ -171,3 +267,9 @@ class Backend(dns._asyncbackend.Backend):
def datagram_connection_required(self): def datagram_connection_required(self):
return _is_win32 return _is_win32
def get_transport_class(self):
return _HTTPTransport
async def wait_for(self, awaitable, timeout):
return await _maybe_wait_for(awaitable, timeout)

View file

@ -1,122 +0,0 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
"""curio async I/O library query support"""
import socket
import curio
import curio.socket # type: ignore
import dns._asyncbackend
import dns.exception
import dns.inet
def _maybe_timeout(timeout):
if timeout:
return curio.ignore_after(timeout)
else:
return dns._asyncbackend.NullContext()
# for brevity
_lltuple = dns.inet.low_level_address_tuple
# pylint: disable=redefined-outer-name
class DatagramSocket(dns._asyncbackend.DatagramSocket):
def __init__(self, socket):
super().__init__(socket.family)
self.socket = socket
async def sendto(self, what, destination, timeout):
async with _maybe_timeout(timeout):
return await self.socket.sendto(what, destination)
raise dns.exception.Timeout(
timeout=timeout
) # pragma: no cover lgtm[py/unreachable-statement]
async def recvfrom(self, size, timeout):
async with _maybe_timeout(timeout):
return await self.socket.recvfrom(size)
raise dns.exception.Timeout(timeout=timeout) # lgtm[py/unreachable-statement]
async def close(self):
await self.socket.close()
async def getpeername(self):
return self.socket.getpeername()
async def getsockname(self):
return self.socket.getsockname()
class StreamSocket(dns._asyncbackend.StreamSocket):
def __init__(self, socket):
self.socket = socket
self.family = socket.family
async def sendall(self, what, timeout):
async with _maybe_timeout(timeout):
return await self.socket.sendall(what)
raise dns.exception.Timeout(timeout=timeout) # lgtm[py/unreachable-statement]
async def recv(self, size, timeout):
async with _maybe_timeout(timeout):
return await self.socket.recv(size)
raise dns.exception.Timeout(timeout=timeout) # lgtm[py/unreachable-statement]
async def close(self):
await self.socket.close()
async def getpeername(self):
return self.socket.getpeername()
async def getsockname(self):
return self.socket.getsockname()
class Backend(dns._asyncbackend.Backend):
def name(self):
return "curio"
async def make_socket(
self,
af,
socktype,
proto=0,
source=None,
destination=None,
timeout=None,
ssl_context=None,
server_hostname=None,
):
if socktype == socket.SOCK_DGRAM:
s = curio.socket.socket(af, socktype, proto)
try:
if source:
s.bind(_lltuple(source, af))
except Exception: # pragma: no cover
await s.close()
raise
return DatagramSocket(s)
elif socktype == socket.SOCK_STREAM:
if source:
source_addr = _lltuple(source, af)
else:
source_addr = None
async with _maybe_timeout(timeout):
s = await curio.open_connection(
destination[0],
destination[1],
ssl=ssl_context,
source_addr=source_addr,
server_hostname=server_hostname,
)
return StreamSocket(s)
raise NotImplementedError(
"unsupported socket " + f"type {socktype}"
) # pragma: no cover
async def sleep(self, interval):
await curio.sleep(interval)

154
lib/dns/_ddr.py Normal file
View file

@ -0,0 +1,154 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
#
# Support for Discovery of Designated Resolvers
import socket
import time
from urllib.parse import urlparse
import dns.asyncbackend
import dns.inet
import dns.name
import dns.nameserver
import dns.query
import dns.rdtypes.svcbbase
# The special name of the local resolver when using DDR
_local_resolver_name = dns.name.from_text("_dns.resolver.arpa")
#
# Processing is split up into I/O independent and I/O dependent parts to
# make supporting sync and async versions easy.
#
class _SVCBInfo:
def __init__(self, bootstrap_address, port, hostname, nameservers):
self.bootstrap_address = bootstrap_address
self.port = port
self.hostname = hostname
self.nameservers = nameservers
def ddr_check_certificate(self, cert):
"""Verify that the _SVCBInfo's address is in the cert's subjectAltName (SAN)"""
for name, value in cert["subjectAltName"]:
if name == "IP Address" and value == self.bootstrap_address:
return True
return False
def make_tls_context(self):
ssl = dns.query.ssl
ctx = ssl.create_default_context()
ctx.minimum_version = ssl.TLSVersion.TLSv1_2
return ctx
def ddr_tls_check_sync(self, lifetime):
ctx = self.make_tls_context()
expiration = time.time() + lifetime
with socket.create_connection(
(self.bootstrap_address, self.port), lifetime
) as s:
with ctx.wrap_socket(s, server_hostname=self.hostname) as ts:
ts.settimeout(dns.query._remaining(expiration))
ts.do_handshake()
cert = ts.getpeercert()
return self.ddr_check_certificate(cert)
async def ddr_tls_check_async(self, lifetime, backend=None):
if backend is None:
backend = dns.asyncbackend.get_default_backend()
ctx = self.make_tls_context()
expiration = time.time() + lifetime
async with await backend.make_socket(
dns.inet.af_for_address(self.bootstrap_address),
socket.SOCK_STREAM,
0,
None,
(self.bootstrap_address, self.port),
lifetime,
ctx,
self.hostname,
) as ts:
cert = await ts.getpeercert(dns.query._remaining(expiration))
return self.ddr_check_certificate(cert)
def _extract_nameservers_from_svcb(answer):
bootstrap_address = answer.nameserver
if not dns.inet.is_address(bootstrap_address):
return []
infos = []
for rr in answer.rrset.processing_order():
nameservers = []
param = rr.params.get(dns.rdtypes.svcbbase.ParamKey.ALPN)
if param is None:
continue
alpns = set(param.ids)
host = rr.target.to_text(omit_final_dot=True)
port = None
param = rr.params.get(dns.rdtypes.svcbbase.ParamKey.PORT)
if param is not None:
port = param.port
# For now we ignore address hints and address resolution and always use the
# bootstrap address
if b"h2" in alpns:
param = rr.params.get(dns.rdtypes.svcbbase.ParamKey.DOHPATH)
if param is None or not param.value.endswith(b"{?dns}"):
continue
path = param.value[:-6].decode()
if not path.startswith("/"):
path = "/" + path
if port is None:
port = 443
url = f"https://{host}:{port}{path}"
# check the URL
try:
urlparse(url)
nameservers.append(dns.nameserver.DoHNameserver(url, bootstrap_address))
except Exception:
# continue processing other ALPN types
pass
if b"dot" in alpns:
if port is None:
port = 853
nameservers.append(
dns.nameserver.DoTNameserver(bootstrap_address, port, host)
)
if b"doq" in alpns:
if port is None:
port = 853
nameservers.append(
dns.nameserver.DoQNameserver(bootstrap_address, port, True, host)
)
if len(nameservers) > 0:
infos.append(_SVCBInfo(bootstrap_address, port, host, nameservers))
return infos
def _get_nameservers_sync(answer, lifetime):
"""Return a list of TLS-validated resolver nameservers extracted from an SVCB
answer."""
nameservers = []
infos = _extract_nameservers_from_svcb(answer)
for info in infos:
try:
if info.ddr_tls_check_sync(lifetime):
nameservers.extend(info.nameservers)
except Exception:
pass
return nameservers
async def _get_nameservers_async(answer, lifetime):
"""Return a list of TLS-validated resolver nameservers extracted from an SVCB
answer."""
nameservers = []
infos = _extract_nameservers_from_svcb(answer)
for info in infos:
try:
if await info.ddr_tls_check_async(lifetime):
nameservers.extend(info.nameservers)
except Exception:
pass
return nameservers

View file

@ -7,7 +7,6 @@
import contextvars import contextvars
import inspect import inspect
_in__init__ = contextvars.ContextVar("_immutable_in__init__", default=False) _in__init__ = contextvars.ContextVar("_immutable_in__init__", default=False)

View file

@ -3,6 +3,7 @@
"""trio async I/O library query support""" """trio async I/O library query support"""
import socket import socket
import trio import trio
import trio.socket # type: ignore import trio.socket # type: ignore
@ -12,7 +13,7 @@ import dns.inet
def _maybe_timeout(timeout): def _maybe_timeout(timeout):
if timeout: if timeout is not None:
return trio.move_on_after(timeout) return trio.move_on_after(timeout)
else: else:
return dns._asyncbackend.NullContext() return dns._asyncbackend.NullContext()
@ -50,6 +51,9 @@ class DatagramSocket(dns._asyncbackend.DatagramSocket):
async def getsockname(self): async def getsockname(self):
return self.socket.getsockname() return self.socket.getsockname()
async def getpeercert(self, timeout):
raise NotImplementedError
class StreamSocket(dns._asyncbackend.StreamSocket): class StreamSocket(dns._asyncbackend.StreamSocket):
def __init__(self, family, stream, tls=False): def __init__(self, family, stream, tls=False):
@ -82,6 +86,100 @@ class StreamSocket(dns._asyncbackend.StreamSocket):
else: else:
return self.stream.socket.getsockname() return self.stream.socket.getsockname()
async def getpeercert(self, timeout):
if self.tls:
with _maybe_timeout(timeout):
await self.stream.do_handshake()
return self.stream.getpeercert()
else:
raise NotImplementedError
try:
import httpcore
import httpcore._backends.trio
import httpx
_CoreAsyncNetworkBackend = httpcore.AsyncNetworkBackend
_CoreTrioStream = httpcore._backends.trio.TrioStream
from dns.query import _compute_times, _expiration_for_this_attempt, _remaining
class _NetworkBackend(_CoreAsyncNetworkBackend):
def __init__(self, resolver, local_port, bootstrap_address, family):
super().__init__()
self._local_port = local_port
self._resolver = resolver
self._bootstrap_address = bootstrap_address
self._family = family
async def connect_tcp(
self, host, port, timeout, local_address, socket_options=None
): # pylint: disable=signature-differs
addresses = []
_, expiration = _compute_times(timeout)
if dns.inet.is_address(host):
addresses.append(host)
elif self._bootstrap_address is not None:
addresses.append(self._bootstrap_address)
else:
timeout = _remaining(expiration)
family = self._family
if local_address:
family = dns.inet.af_for_address(local_address)
answers = await self._resolver.resolve_name(
host, family=family, lifetime=timeout
)
addresses = answers.addresses()
for address in addresses:
try:
af = dns.inet.af_for_address(address)
if local_address is not None or self._local_port != 0:
source = (local_address, self._local_port)
else:
source = None
destination = (address, port)
attempt_expiration = _expiration_for_this_attempt(2.0, expiration)
timeout = _remaining(attempt_expiration)
sock = await Backend().make_socket(
af, socket.SOCK_STREAM, 0, source, destination, timeout
)
return _CoreTrioStream(sock.stream)
except Exception:
continue
raise httpcore.ConnectError
async def connect_unix_socket(
self, path, timeout, socket_options=None
): # pylint: disable=signature-differs
raise NotImplementedError
async def sleep(self, seconds): # pylint: disable=signature-differs
await trio.sleep(seconds)
class _HTTPTransport(httpx.AsyncHTTPTransport):
def __init__(
self,
*args,
local_port=0,
bootstrap_address=None,
resolver=None,
family=socket.AF_UNSPEC,
**kwargs,
):
if resolver is None:
# pylint: disable=import-outside-toplevel,redefined-outer-name
import dns.asyncresolver
resolver = dns.asyncresolver.Resolver()
super().__init__(*args, **kwargs)
self._pool._network_backend = _NetworkBackend(
resolver, local_port, bootstrap_address, family
)
except ImportError:
_HTTPTransport = dns._asyncbackend.NullTransport # type: ignore
class Backend(dns._asyncbackend.Backend): class Backend(dns._asyncbackend.Backend):
def name(self): def name(self):
@ -104,8 +202,14 @@ class Backend(dns._asyncbackend.Backend):
if source: if source:
await s.bind(_lltuple(source, af)) await s.bind(_lltuple(source, af))
if socktype == socket.SOCK_STREAM: if socktype == socket.SOCK_STREAM:
connected = False
with _maybe_timeout(timeout): with _maybe_timeout(timeout):
await s.connect(_lltuple(destination, af)) await s.connect(_lltuple(destination, af))
connected = True
if not connected:
raise dns.exception.Timeout(
timeout=timeout
) # lgtm[py/unreachable-statement]
except Exception: # pragma: no cover except Exception: # pragma: no cover
s.close() s.close()
raise raise
@ -130,3 +234,13 @@ class Backend(dns._asyncbackend.Backend):
async def sleep(self, interval): async def sleep(self, interval):
await trio.sleep(interval) await trio.sleep(interval)
def get_transport_class(self):
return _HTTPTransport
async def wait_for(self, awaitable, timeout):
with _maybe_timeout(timeout):
return await awaitable
raise dns.exception.Timeout(
timeout=timeout
) # pragma: no cover lgtm[py/unreachable-statement]

View file

@ -5,13 +5,12 @@ from typing import Dict
import dns.exception import dns.exception
# pylint: disable=unused-import # pylint: disable=unused-import
from dns._asyncbackend import ( # noqa: F401 lgtm[py/unused-import]
from dns._asyncbackend import (
Socket,
DatagramSocket,
StreamSocket,
Backend, Backend,
) # noqa: F401 lgtm[py/unused-import] DatagramSocket,
Socket,
StreamSocket,
)
# pylint: enable=unused-import # pylint: enable=unused-import
@ -30,8 +29,8 @@ class AsyncLibraryNotFoundError(dns.exception.DNSException):
def get_backend(name: str) -> Backend: def get_backend(name: str) -> Backend:
"""Get the specified asynchronous backend. """Get the specified asynchronous backend.
*name*, a ``str``, the name of the backend. Currently the "trio", *name*, a ``str``, the name of the backend. Currently the "trio"
"curio", and "asyncio" backends are available. and "asyncio" backends are available.
Raises NotImplementError if an unknown backend name is specified. Raises NotImplementError if an unknown backend name is specified.
""" """
@ -43,10 +42,6 @@ def get_backend(name: str) -> Backend:
import dns._trio_backend import dns._trio_backend
backend = dns._trio_backend.Backend() backend = dns._trio_backend.Backend()
elif name == "curio":
import dns._curio_backend
backend = dns._curio_backend.Backend()
elif name == "asyncio": elif name == "asyncio":
import dns._asyncio_backend import dns._asyncio_backend
@ -73,9 +68,7 @@ def sniff() -> str:
try: try:
return sniffio.current_async_library() return sniffio.current_async_library()
except sniffio.AsyncLibraryNotFoundError: except sniffio.AsyncLibraryNotFoundError:
raise AsyncLibraryNotFoundError( raise AsyncLibraryNotFoundError("sniffio cannot determine async library")
"sniffio cannot determine " + "async library"
)
except ImportError: except ImportError:
import asyncio import asyncio

View file

@ -17,39 +17,38 @@
"""Talk to a DNS server.""" """Talk to a DNS server."""
from typing import Any, Dict, Optional, Tuple, Union
import base64 import base64
import contextlib import contextlib
import socket import socket
import struct import struct
import time import time
from typing import Any, Dict, Optional, Tuple, Union
import dns.asyncbackend import dns.asyncbackend
import dns.exception import dns.exception
import dns.inet import dns.inet
import dns.name
import dns.message import dns.message
import dns.name
import dns.quic import dns.quic
import dns.rcode import dns.rcode
import dns.rdataclass import dns.rdataclass
import dns.rdatatype import dns.rdatatype
import dns.transaction import dns.transaction
from dns._asyncbackend import NullContext from dns._asyncbackend import NullContext
from dns.query import ( from dns.query import (
_compute_times,
_matches_destination,
BadResponse, BadResponse,
ssl,
UDPMode,
_have_httpx,
_have_http2,
NoDOH, NoDOH,
NoDOQ, NoDOQ,
UDPMode,
_compute_times,
_have_http2,
_matches_destination,
_remaining,
have_doh,
ssl,
) )
if _have_httpx: if have_doh:
import httpx import httpx
# for brevity # for brevity
@ -73,7 +72,7 @@ def _source_tuple(af, address, port):
def _timeout(expiration, now=None): def _timeout(expiration, now=None):
if expiration: if expiration is not None:
if not now: if not now:
now = time.time() now = time.time()
return max(expiration - now, 0) return max(expiration - now, 0)
@ -445,9 +444,6 @@ async def tls(
ssl_context.minimum_version = ssl.TLSVersion.TLSv1_2 ssl_context.minimum_version = ssl.TLSVersion.TLSv1_2
if server_hostname is None: if server_hostname is None:
ssl_context.check_hostname = False ssl_context.check_hostname = False
else:
ssl_context = None
server_hostname = None
af = dns.inet.af_for_address(where) af = dns.inet.af_for_address(where)
stuple = _source_tuple(af, source, source_port) stuple = _source_tuple(af, source, source_port)
dtuple = (where, port) dtuple = (where, port)
@ -495,6 +491,9 @@ async def https(
path: str = "/dns-query", path: str = "/dns-query",
post: bool = True, post: bool = True,
verify: Union[bool, str] = True, verify: Union[bool, str] = True,
bootstrap_address: Optional[str] = None,
resolver: Optional["dns.asyncresolver.Resolver"] = None,
family: Optional[int] = socket.AF_UNSPEC,
) -> dns.message.Message: ) -> dns.message.Message:
"""Return the response obtained after sending a query via DNS-over-HTTPS. """Return the response obtained after sending a query via DNS-over-HTTPS.
@ -508,8 +507,10 @@ async def https(
parameters, exceptions, and return type of this method. parameters, exceptions, and return type of this method.
""" """
if not _have_httpx: if not have_doh:
raise NoDOH("httpx is not available.") # pragma: no cover raise NoDOH # pragma: no cover
if client and not isinstance(client, httpx.AsyncClient):
raise ValueError("session parameter must be an httpx.AsyncClient")
wire = q.to_wire() wire = q.to_wire()
try: try:
@ -518,15 +519,32 @@ async def https(
af = None af = None
transport = None transport = None
headers = {"accept": "application/dns-message"} headers = {"accept": "application/dns-message"}
if af is not None: if af is not None and dns.inet.is_address(where):
if af == socket.AF_INET: if af == socket.AF_INET:
url = "https://{}:{}{}".format(where, port, path) url = "https://{}:{}{}".format(where, port, path)
elif af == socket.AF_INET6: elif af == socket.AF_INET6:
url = "https://[{}]:{}{}".format(where, port, path) url = "https://[{}]:{}{}".format(where, port, path)
else: else:
url = where url = where
if source is not None:
transport = httpx.AsyncHTTPTransport(local_address=source[0]) backend = dns.asyncbackend.get_default_backend()
if source is None:
local_address = None
local_port = 0
else:
local_address = source
local_port = source_port
transport = backend.get_transport_class()(
local_address=local_address,
http1=True,
http2=_have_http2,
verify=verify,
local_port=local_port,
bootstrap_address=bootstrap_address,
resolver=resolver,
family=family,
)
if client: if client:
cm: contextlib.AbstractAsyncContextManager = NullContext(client) cm: contextlib.AbstractAsyncContextManager = NullContext(client)
@ -545,14 +563,14 @@ async def https(
"content-length": str(len(wire)), "content-length": str(len(wire)),
} }
) )
response = await the_client.post( response = await backend.wait_for(
url, headers=headers, content=wire, timeout=timeout the_client.post(url, headers=headers, content=wire), timeout
) )
else: else:
wire = base64.urlsafe_b64encode(wire).rstrip(b"=") wire = base64.urlsafe_b64encode(wire).rstrip(b"=")
twire = wire.decode() # httpx does a repr() if we give it bytes twire = wire.decode() # httpx does a repr() if we give it bytes
response = await the_client.get( response = await backend.wait_for(
url, headers=headers, timeout=timeout, params={"dns": twire} the_client.get(url, headers=headers, params={"dns": twire}), timeout
) )
# see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH # see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH
@ -690,6 +708,7 @@ async def quic(
connection: Optional[dns.quic.AsyncQuicConnection] = None, connection: Optional[dns.quic.AsyncQuicConnection] = None,
verify: Union[bool, str] = True, verify: Union[bool, str] = True,
backend: Optional[dns.asyncbackend.Backend] = None, backend: Optional[dns.asyncbackend.Backend] = None,
server_hostname: Optional[str] = None,
) -> dns.message.Message: ) -> dns.message.Message:
"""Return the response obtained after sending an asynchronous query via """Return the response obtained after sending an asynchronous query via
DNS-over-QUIC. DNS-over-QUIC.
@ -715,14 +734,16 @@ async def quic(
(cfactory, mfactory) = dns.quic.factories_for_backend(backend) (cfactory, mfactory) = dns.quic.factories_for_backend(backend)
async with cfactory() as context: async with cfactory() as context:
async with mfactory(context, verify_mode=verify) as the_manager: async with mfactory(
context, verify_mode=verify, server_name=server_hostname
) as the_manager:
if not connection: if not connection:
the_connection = the_manager.connect(where, port, source, source_port) the_connection = the_manager.connect(where, port, source, source_port)
start = time.time() (start, expiration) = _compute_times(timeout)
stream = await the_connection.make_stream() stream = await the_connection.make_stream(timeout)
async with stream: async with stream:
await stream.send(wire, True) await stream.send(wire, True)
wire = await stream.receive(timeout) wire = await stream.receive(_remaining(expiration))
finish = time.time() finish = time.time()
r = dns.message.from_wire( r = dns.message.from_wire(
wire, wire,

View file

@ -17,10 +17,11 @@
"""Asynchronous DNS stub resolver.""" """Asynchronous DNS stub resolver."""
from typing import Any, Dict, Optional, Union import socket
import time import time
from typing import Any, Dict, List, Optional, Union
import dns._ddr
import dns.asyncbackend import dns.asyncbackend
import dns.asyncquery import dns.asyncquery
import dns.exception import dns.exception
@ -31,8 +32,7 @@ import dns.rdatatype
import dns.resolver # lgtm[py/import-and-import-from] import dns.resolver # lgtm[py/import-and-import-from]
# import some resolver symbols for brevity # import some resolver symbols for brevity
from dns.resolver import NXDOMAIN, NoAnswer, NotAbsolute, NoRootSOA from dns.resolver import NXDOMAIN, NoAnswer, NoRootSOA, NotAbsolute
# for indentation purposes below # for indentation purposes below
_udp = dns.asyncquery.udp _udp = dns.asyncquery.udp
@ -83,37 +83,19 @@ class Resolver(dns.resolver.BaseResolver):
assert request is not None # needed for type checking assert request is not None # needed for type checking
done = False done = False
while not done: while not done:
(nameserver, port, tcp, backoff) = resolution.next_nameserver() (nameserver, tcp, backoff) = resolution.next_nameserver()
if backoff: if backoff:
await backend.sleep(backoff) await backend.sleep(backoff)
timeout = self._compute_timeout(start, lifetime, resolution.errors) timeout = self._compute_timeout(start, lifetime, resolution.errors)
try: try:
if dns.inet.is_address(nameserver): response = await nameserver.async_query(
if tcp: request,
response = await _tcp( timeout=timeout,
request, source=source,
nameserver, source_port=source_port,
timeout, max_size=tcp,
port, backend=backend,
source, )
source_port,
backend=backend,
)
else:
response = await _udp(
request,
nameserver,
timeout,
port,
source,
source_port,
raise_on_truncation=True,
backend=backend,
)
else:
response = await dns.asyncquery.https(
request, nameserver, timeout=timeout
)
except Exception as ex: except Exception as ex:
(_, done) = resolution.query_result(None, ex) (_, done) = resolution.query_result(None, ex)
continue continue
@ -153,6 +135,73 @@ class Resolver(dns.resolver.BaseResolver):
dns.reversename.from_address(ipaddr), *args, **modified_kwargs dns.reversename.from_address(ipaddr), *args, **modified_kwargs
) )
async def resolve_name(
self,
name: Union[dns.name.Name, str],
family: int = socket.AF_UNSPEC,
**kwargs: Any,
) -> dns.resolver.HostAnswers:
"""Use an asynchronous resolver to query for address records.
This utilizes the resolve() method to perform A and/or AAAA lookups on
the specified name.
*qname*, a ``dns.name.Name`` or ``str``, the name to resolve.
*family*, an ``int``, the address family. If socket.AF_UNSPEC
(the default), both A and AAAA records will be retrieved.
All other arguments that can be passed to the resolve() function
except for rdtype and rdclass are also supported by this
function.
"""
# We make a modified kwargs for type checking happiness, as otherwise
# we get a legit warning about possibly having rdtype and rdclass
# in the kwargs more than once.
modified_kwargs: Dict[str, Any] = {}
modified_kwargs.update(kwargs)
modified_kwargs.pop("rdtype", None)
modified_kwargs["rdclass"] = dns.rdataclass.IN
if family == socket.AF_INET:
v4 = await self.resolve(name, dns.rdatatype.A, **modified_kwargs)
return dns.resolver.HostAnswers.make(v4=v4)
elif family == socket.AF_INET6:
v6 = await self.resolve(name, dns.rdatatype.AAAA, **modified_kwargs)
return dns.resolver.HostAnswers.make(v6=v6)
elif family != socket.AF_UNSPEC:
raise NotImplementedError(f"unknown address family {family}")
raise_on_no_answer = modified_kwargs.pop("raise_on_no_answer", True)
lifetime = modified_kwargs.pop("lifetime", None)
start = time.time()
v6 = await self.resolve(
name,
dns.rdatatype.AAAA,
raise_on_no_answer=False,
lifetime=self._compute_timeout(start, lifetime),
**modified_kwargs,
)
# Note that setting name ensures we query the same name
# for A as we did for AAAA. (This is just in case search lists
# are active by default in the resolver configuration and
# we might be talking to a server that says NXDOMAIN when it
# wants to say NOERROR no data.
name = v6.qname
v4 = await self.resolve(
name,
dns.rdatatype.A,
raise_on_no_answer=False,
lifetime=self._compute_timeout(start, lifetime),
**modified_kwargs,
)
answers = dns.resolver.HostAnswers.make(
v6=v6, v4=v4, add_empty=not raise_on_no_answer
)
if not answers:
raise NoAnswer(response=v6.response)
return answers
# pylint: disable=redefined-outer-name # pylint: disable=redefined-outer-name
async def canonical_name(self, name: Union[dns.name.Name, str]) -> dns.name.Name: async def canonical_name(self, name: Union[dns.name.Name, str]) -> dns.name.Name:
@ -176,6 +225,37 @@ class Resolver(dns.resolver.BaseResolver):
canonical_name = e.canonical_name canonical_name = e.canonical_name
return canonical_name return canonical_name
async def try_ddr(self, lifetime: float = 5.0) -> None:
"""Try to update the resolver's nameservers using Discovery of Designated
Resolvers (DDR). If successful, the resolver will subsequently use
DNS-over-HTTPS or DNS-over-TLS for future queries.
*lifetime*, a float, is the maximum time to spend attempting DDR. The default
is 5 seconds.
If the SVCB query is successful and results in a non-empty list of nameservers,
then the resolver's nameservers are set to the returned servers in priority
order.
The current implementation does not use any address hints from the SVCB record,
nor does it resolve addresses for the SCVB target name, rather it assumes that
the bootstrap nameserver will always be one of the addresses and uses it.
A future revision to the code may offer fuller support. The code verifies that
the bootstrap nameserver is in the Subject Alternative Name field of the
TLS certficate.
"""
try:
expiration = time.time() + lifetime
answer = await self.resolve(
dns._ddr._local_resolver_name, "svcb", lifetime=lifetime
)
timeout = dns.query._remaining(expiration)
nameservers = await dns._ddr._get_nameservers_async(answer, timeout)
if len(nameservers) > 0:
self.nameservers = nameservers
except Exception:
pass
default_resolver = None default_resolver = None
@ -246,6 +326,18 @@ async def resolve_address(
return await get_default_resolver().resolve_address(ipaddr, *args, **kwargs) return await get_default_resolver().resolve_address(ipaddr, *args, **kwargs)
async def resolve_name(
name: Union[dns.name.Name, str], family: int = socket.AF_UNSPEC, **kwargs: Any
) -> dns.resolver.HostAnswers:
"""Use a resolver to asynchronously query for address records.
See :py:func:`dns.asyncresolver.Resolver.resolve_name` for more
information on the parameters.
"""
return await get_default_resolver().resolve_name(name, family, **kwargs)
async def canonical_name(name: Union[dns.name.Name, str]) -> dns.name.Name: async def canonical_name(name: Union[dns.name.Name, str]) -> dns.name.Name:
"""Determine the canonical name of *name*. """Determine the canonical name of *name*.
@ -256,6 +348,16 @@ async def canonical_name(name: Union[dns.name.Name, str]) -> dns.name.Name:
return await get_default_resolver().canonical_name(name) return await get_default_resolver().canonical_name(name)
async def try_ddr(timeout: float = 5.0) -> None:
"""Try to update the default resolver's nameservers using Discovery of Designated
Resolvers (DDR). If successful, the resolver will subsequently use
DNS-over-HTTPS or DNS-over-TLS for future queries.
See :py:func:`dns.resolver.Resolver.try_ddr` for more information.
"""
return await get_default_resolver().try_ddr(timeout)
async def zone_for_name( async def zone_for_name(
name: Union[dns.name.Name, str], name: Union[dns.name.Name, str],
rdclass: dns.rdataclass.RdataClass = dns.rdataclass.IN, rdclass: dns.rdataclass.RdataClass = dns.rdataclass.IN,
@ -290,3 +392,84 @@ async def zone_for_name(
name = name.parent() name = name.parent()
except dns.name.NoParent: # pragma: no cover except dns.name.NoParent: # pragma: no cover
raise NoRootSOA raise NoRootSOA
async def make_resolver_at(
where: Union[dns.name.Name, str],
port: int = 53,
family: int = socket.AF_UNSPEC,
resolver: Optional[Resolver] = None,
) -> Resolver:
"""Make a stub resolver using the specified destination as the full resolver.
*where*, a ``dns.name.Name`` or ``str`` the domain name or IP address of the
full resolver.
*port*, an ``int``, the port to use. If not specified, the default is 53.
*family*, an ``int``, the address family to use. This parameter is used if
*where* is not an address. The default is ``socket.AF_UNSPEC`` in which case
the first address returned by ``resolve_name()`` will be used, otherwise the
first address of the specified family will be used.
*resolver*, a ``dns.asyncresolver.Resolver`` or ``None``, the resolver to use for
resolution of hostnames. If not specified, the default resolver will be used.
Returns a ``dns.resolver.Resolver`` or raises an exception.
"""
if resolver is None:
resolver = get_default_resolver()
nameservers: List[Union[str, dns.nameserver.Nameserver]] = []
if isinstance(where, str) and dns.inet.is_address(where):
nameservers.append(dns.nameserver.Do53Nameserver(where, port))
else:
answers = await resolver.resolve_name(where, family)
for address in answers.addresses():
nameservers.append(dns.nameserver.Do53Nameserver(address, port))
res = dns.asyncresolver.Resolver(configure=False)
res.nameservers = nameservers
return res
async def resolve_at(
where: Union[dns.name.Name, str],
qname: Union[dns.name.Name, str],
rdtype: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.A,
rdclass: Union[dns.rdataclass.RdataClass, str] = dns.rdataclass.IN,
tcp: bool = False,
source: Optional[str] = None,
raise_on_no_answer: bool = True,
source_port: int = 0,
lifetime: Optional[float] = None,
search: Optional[bool] = None,
backend: Optional[dns.asyncbackend.Backend] = None,
port: int = 53,
family: int = socket.AF_UNSPEC,
resolver: Optional[Resolver] = None,
) -> dns.resolver.Answer:
"""Query nameservers to find the answer to the question.
This is a convenience function that calls ``dns.asyncresolver.make_resolver_at()``
to make a resolver, and then uses it to resolve the query.
See ``dns.asyncresolver.Resolver.resolve`` for more information on the resolution
parameters, and ``dns.asyncresolver.make_resolver_at`` for information about the
resolver parameters *where*, *port*, *family*, and *resolver*.
If making more than one query, it is more efficient to call
``dns.asyncresolver.make_resolver_at()`` and then use that resolver for the queries
instead of calling ``resolve_at()`` multiple times.
"""
res = await make_resolver_at(where, port, family, resolver)
return await res.resolve(
qname,
rdtype,
rdclass,
tcp,
source,
raise_on_no_answer,
source_port,
lifetime,
search,
backend,
)

View file

@ -17,50 +17,44 @@
"""Common DNSSEC-related functions and constants.""" """Common DNSSEC-related functions and constants."""
from typing import Any, cast, Dict, List, Optional, Set, Tuple, Union
import base64
import contextlib
import functools
import hashlib import hashlib
import math
import struct import struct
import time import time
import base64
from datetime import datetime from datetime import datetime
from typing import Callable, Dict, List, Optional, Set, Tuple, Union, cast
from dns.dnssectypes import Algorithm, DSDigest, NSEC3Hash
import dns.exception import dns.exception
import dns.name import dns.name
import dns.node import dns.node
import dns.rdataset
import dns.rdata import dns.rdata
import dns.rdatatype
import dns.rdataclass import dns.rdataclass
import dns.rdataset
import dns.rdatatype
import dns.rrset import dns.rrset
import dns.transaction
import dns.zone
from dns.dnssectypes import Algorithm, DSDigest, NSEC3Hash
from dns.exception import ( # pylint: disable=W0611
AlgorithmKeyMismatch,
DeniedByPolicy,
UnsupportedAlgorithm,
ValidationFailure,
)
from dns.rdtypes.ANY.CDNSKEY import CDNSKEY from dns.rdtypes.ANY.CDNSKEY import CDNSKEY
from dns.rdtypes.ANY.CDS import CDS from dns.rdtypes.ANY.CDS import CDS
from dns.rdtypes.ANY.DNSKEY import DNSKEY from dns.rdtypes.ANY.DNSKEY import DNSKEY
from dns.rdtypes.ANY.DS import DS from dns.rdtypes.ANY.DS import DS
from dns.rdtypes.ANY.NSEC import NSEC, Bitmap
from dns.rdtypes.ANY.NSEC3PARAM import NSEC3PARAM
from dns.rdtypes.ANY.RRSIG import RRSIG, sigtime_to_posixtime from dns.rdtypes.ANY.RRSIG import RRSIG, sigtime_to_posixtime
from dns.rdtypes.dnskeybase import Flag from dns.rdtypes.dnskeybase import Flag
class UnsupportedAlgorithm(dns.exception.DNSException):
"""The DNSSEC algorithm is not supported."""
class AlgorithmKeyMismatch(UnsupportedAlgorithm):
"""The DNSSEC algorithm is not supported for the given key type."""
class ValidationFailure(dns.exception.DNSException):
"""The DNSSEC signature is invalid."""
class DeniedByPolicy(dns.exception.DNSException):
"""Denied by DNSSEC policy."""
PublicKey = Union[ PublicKey = Union[
"GenericPublicKey",
"rsa.RSAPublicKey", "rsa.RSAPublicKey",
"ec.EllipticCurvePublicKey", "ec.EllipticCurvePublicKey",
"ed25519.Ed25519PublicKey", "ed25519.Ed25519PublicKey",
@ -68,12 +62,15 @@ PublicKey = Union[
] ]
PrivateKey = Union[ PrivateKey = Union[
"GenericPrivateKey",
"rsa.RSAPrivateKey", "rsa.RSAPrivateKey",
"ec.EllipticCurvePrivateKey", "ec.EllipticCurvePrivateKey",
"ed25519.Ed25519PrivateKey", "ed25519.Ed25519PrivateKey",
"ed448.Ed448PrivateKey", "ed448.Ed448PrivateKey",
] ]
RRsetSigner = Callable[[dns.transaction.Transaction, dns.rrset.RRset], None]
def algorithm_from_text(text: str) -> Algorithm: def algorithm_from_text(text: str) -> Algorithm:
"""Convert text into a DNSSEC algorithm value. """Convert text into a DNSSEC algorithm value.
@ -308,113 +305,13 @@ def _find_candidate_keys(
return [ return [
cast(DNSKEY, rd) cast(DNSKEY, rd)
for rd in rdataset for rd in rdataset
if rd.algorithm == rrsig.algorithm and key_id(rd) == rrsig.key_tag if rd.algorithm == rrsig.algorithm
and key_id(rd) == rrsig.key_tag
and (rd.flags & Flag.ZONE) == Flag.ZONE # RFC 4034 2.1.1
and rd.protocol == 3 # RFC 4034 2.1.2
] ]
def _is_rsa(algorithm: int) -> bool:
return algorithm in (
Algorithm.RSAMD5,
Algorithm.RSASHA1,
Algorithm.RSASHA1NSEC3SHA1,
Algorithm.RSASHA256,
Algorithm.RSASHA512,
)
def _is_dsa(algorithm: int) -> bool:
return algorithm in (Algorithm.DSA, Algorithm.DSANSEC3SHA1)
def _is_ecdsa(algorithm: int) -> bool:
return algorithm in (Algorithm.ECDSAP256SHA256, Algorithm.ECDSAP384SHA384)
def _is_eddsa(algorithm: int) -> bool:
return algorithm in (Algorithm.ED25519, Algorithm.ED448)
def _is_gost(algorithm: int) -> bool:
return algorithm == Algorithm.ECCGOST
def _is_md5(algorithm: int) -> bool:
return algorithm == Algorithm.RSAMD5
def _is_sha1(algorithm: int) -> bool:
return algorithm in (
Algorithm.DSA,
Algorithm.RSASHA1,
Algorithm.DSANSEC3SHA1,
Algorithm.RSASHA1NSEC3SHA1,
)
def _is_sha256(algorithm: int) -> bool:
return algorithm in (Algorithm.RSASHA256, Algorithm.ECDSAP256SHA256)
def _is_sha384(algorithm: int) -> bool:
return algorithm == Algorithm.ECDSAP384SHA384
def _is_sha512(algorithm: int) -> bool:
return algorithm == Algorithm.RSASHA512
def _ensure_algorithm_key_combination(algorithm: int, key: PublicKey) -> None:
"""Ensure algorithm is valid for key type, throwing an exception on
mismatch."""
if isinstance(key, rsa.RSAPublicKey):
if _is_rsa(algorithm):
return
raise AlgorithmKeyMismatch('algorithm "%s" not valid for RSA key' % algorithm)
if isinstance(key, dsa.DSAPublicKey):
if _is_dsa(algorithm):
return
raise AlgorithmKeyMismatch('algorithm "%s" not valid for DSA key' % algorithm)
if isinstance(key, ec.EllipticCurvePublicKey):
if _is_ecdsa(algorithm):
return
raise AlgorithmKeyMismatch('algorithm "%s" not valid for ECDSA key' % algorithm)
if isinstance(key, ed25519.Ed25519PublicKey):
if algorithm == Algorithm.ED25519:
return
raise AlgorithmKeyMismatch(
'algorithm "%s" not valid for ED25519 key' % algorithm
)
if isinstance(key, ed448.Ed448PublicKey):
if algorithm == Algorithm.ED448:
return
raise AlgorithmKeyMismatch('algorithm "%s" not valid for ED448 key' % algorithm)
raise TypeError("unsupported key type")
def _make_hash(algorithm: int) -> Any:
if _is_md5(algorithm):
return hashes.MD5()
if _is_sha1(algorithm):
return hashes.SHA1()
if _is_sha256(algorithm):
return hashes.SHA256()
if _is_sha384(algorithm):
return hashes.SHA384()
if _is_sha512(algorithm):
return hashes.SHA512()
if algorithm == Algorithm.ED25519:
return hashes.SHA512()
if algorithm == Algorithm.ED448:
return hashes.SHAKE256(114)
raise ValidationFailure("unknown hash for algorithm %u" % algorithm)
def _bytes_to_long(b: bytes) -> int:
return int.from_bytes(b, "big")
def _get_rrname_rdataset( def _get_rrname_rdataset(
rrset: Union[dns.rrset.RRset, Tuple[dns.name.Name, dns.rdataset.Rdataset]], rrset: Union[dns.rrset.RRset, Tuple[dns.name.Name, dns.rdataset.Rdataset]],
) -> Tuple[dns.name.Name, dns.rdataset.Rdataset]: ) -> Tuple[dns.name.Name, dns.rdataset.Rdataset]:
@ -424,85 +321,13 @@ def _get_rrname_rdataset(
return rrset.name, rrset return rrset.name, rrset
def _validate_signature(sig: bytes, data: bytes, key: DNSKEY, chosen_hash: Any) -> None: def _validate_signature(sig: bytes, data: bytes, key: DNSKEY) -> None:
keyptr: bytes public_cls = get_algorithm_cls_from_dnskey(key).public_cls
if _is_rsa(key.algorithm): try:
# we ignore because mypy is confused and thinks key.key is a str for unknown public_key = public_cls.from_dnskey(key)
# reasons. except ValueError:
keyptr = key.key raise ValidationFailure("invalid public key")
(bytes_,) = struct.unpack("!B", keyptr[0:1]) public_key.verify(sig, data)
keyptr = keyptr[1:]
if bytes_ == 0:
(bytes_,) = struct.unpack("!H", keyptr[0:2])
keyptr = keyptr[2:]
rsa_e = keyptr[0:bytes_]
rsa_n = keyptr[bytes_:]
try:
rsa_public_key = rsa.RSAPublicNumbers(
_bytes_to_long(rsa_e), _bytes_to_long(rsa_n)
).public_key(default_backend())
except ValueError:
raise ValidationFailure("invalid public key")
rsa_public_key.verify(sig, data, padding.PKCS1v15(), chosen_hash)
elif _is_dsa(key.algorithm):
keyptr = key.key
(t,) = struct.unpack("!B", keyptr[0:1])
keyptr = keyptr[1:]
octets = 64 + t * 8
dsa_q = keyptr[0:20]
keyptr = keyptr[20:]
dsa_p = keyptr[0:octets]
keyptr = keyptr[octets:]
dsa_g = keyptr[0:octets]
keyptr = keyptr[octets:]
dsa_y = keyptr[0:octets]
try:
dsa_public_key = dsa.DSAPublicNumbers( # type: ignore
_bytes_to_long(dsa_y),
dsa.DSAParameterNumbers(
_bytes_to_long(dsa_p), _bytes_to_long(dsa_q), _bytes_to_long(dsa_g)
),
).public_key(default_backend())
except ValueError:
raise ValidationFailure("invalid public key")
dsa_public_key.verify(sig, data, chosen_hash)
elif _is_ecdsa(key.algorithm):
keyptr = key.key
curve: Any
if key.algorithm == Algorithm.ECDSAP256SHA256:
curve = ec.SECP256R1()
octets = 32
else:
curve = ec.SECP384R1()
octets = 48
ecdsa_x = keyptr[0:octets]
ecdsa_y = keyptr[octets : octets * 2]
try:
ecdsa_public_key = ec.EllipticCurvePublicNumbers(
curve=curve, x=_bytes_to_long(ecdsa_x), y=_bytes_to_long(ecdsa_y)
).public_key(default_backend())
except ValueError:
raise ValidationFailure("invalid public key")
ecdsa_public_key.verify(sig, data, ec.ECDSA(chosen_hash))
elif _is_eddsa(key.algorithm):
keyptr = key.key
loader: Any
if key.algorithm == Algorithm.ED25519:
loader = ed25519.Ed25519PublicKey
else:
loader = ed448.Ed448PublicKey
try:
eddsa_public_key = loader.from_public_bytes(keyptr)
except ValueError:
raise ValidationFailure("invalid public key")
eddsa_public_key.verify(sig, data)
elif _is_gost(key.algorithm):
raise UnsupportedAlgorithm(
'algorithm "%s" not supported by dnspython'
% algorithm_to_text(key.algorithm)
)
else:
raise ValidationFailure("unknown algorithm %u" % key.algorithm)
def _validate_rrsig( def _validate_rrsig(
@ -559,29 +384,13 @@ def _validate_rrsig(
if rrsig.inception > now: if rrsig.inception > now:
raise ValidationFailure("not yet valid") raise ValidationFailure("not yet valid")
if _is_dsa(rrsig.algorithm):
sig_r = rrsig.signature[1:21]
sig_s = rrsig.signature[21:]
sig = utils.encode_dss_signature(_bytes_to_long(sig_r), _bytes_to_long(sig_s))
elif _is_ecdsa(rrsig.algorithm):
if rrsig.algorithm == Algorithm.ECDSAP256SHA256:
octets = 32
else:
octets = 48
sig_r = rrsig.signature[0:octets]
sig_s = rrsig.signature[octets:]
sig = utils.encode_dss_signature(_bytes_to_long(sig_r), _bytes_to_long(sig_s))
else:
sig = rrsig.signature
data = _make_rrsig_signature_data(rrset, rrsig, origin) data = _make_rrsig_signature_data(rrset, rrsig, origin)
chosen_hash = _make_hash(rrsig.algorithm)
for candidate_key in candidate_keys: for candidate_key in candidate_keys:
if not policy.ok_to_validate(candidate_key): if not policy.ok_to_validate(candidate_key):
continue continue
try: try:
_validate_signature(sig, data, candidate_key, chosen_hash) _validate_signature(rrsig.signature, data, candidate_key)
return return
except (InvalidSignature, ValidationFailure): except (InvalidSignature, ValidationFailure):
# this happens on an individual validation failure # this happens on an individual validation failure
@ -673,6 +482,7 @@ def _sign(
lifetime: Optional[int] = None, lifetime: Optional[int] = None,
verify: bool = False, verify: bool = False,
policy: Optional[Policy] = None, policy: Optional[Policy] = None,
origin: Optional[dns.name.Name] = None,
) -> RRSIG: ) -> RRSIG:
"""Sign RRset using private key. """Sign RRset using private key.
@ -708,6 +518,10 @@ def _sign(
*policy*, a ``dns.dnssec.Policy`` or ``None``. If ``None``, the default policy, *policy*, a ``dns.dnssec.Policy`` or ``None``. If ``None``, the default policy,
``dns.dnssec.default_policy`` is used; this policy defaults to that of RFC 8624. ``dns.dnssec.default_policy`` is used; this policy defaults to that of RFC 8624.
*origin*, a ``dns.name.Name`` or ``None``. If ``None``, the default, then all
names in the rrset (including its owner name) must be absolute; otherwise the
specified origin will be used to make names absolute when signing.
Raises ``DeniedByPolicy`` if the signature is denied by policy. Raises ``DeniedByPolicy`` if the signature is denied by policy.
""" """
@ -735,16 +549,26 @@ def _sign(
if expiration is not None: if expiration is not None:
rrsig_expiration = to_timestamp(expiration) rrsig_expiration = to_timestamp(expiration)
elif lifetime is not None: elif lifetime is not None:
rrsig_expiration = int(time.time()) + lifetime rrsig_expiration = rrsig_inception + lifetime
else: else:
raise ValueError("expiration or lifetime must be specified") raise ValueError("expiration or lifetime must be specified")
# Derelativize now because we need a correct labels length for the
# rrsig_template.
if origin is not None:
rrname = rrname.derelativize(origin)
labels = len(rrname) - 1
# Adjust labels appropriately for wildcards.
if rrname.is_wild():
labels -= 1
rrsig_template = RRSIG( rrsig_template = RRSIG(
rdclass=rdclass, rdclass=rdclass,
rdtype=dns.rdatatype.RRSIG, rdtype=dns.rdatatype.RRSIG,
type_covered=rdtype, type_covered=rdtype,
algorithm=dnskey.algorithm, algorithm=dnskey.algorithm,
labels=len(rrname) - 1, labels=labels,
original_ttl=original_ttl, original_ttl=original_ttl,
expiration=rrsig_expiration, expiration=rrsig_expiration,
inception=rrsig_inception, inception=rrsig_inception,
@ -753,63 +577,18 @@ def _sign(
signature=b"", signature=b"",
) )
data = dns.dnssec._make_rrsig_signature_data(rrset, rrsig_template) data = dns.dnssec._make_rrsig_signature_data(rrset, rrsig_template, origin)
chosen_hash = _make_hash(rrsig_template.algorithm)
signature = None
if isinstance(private_key, rsa.RSAPrivateKey): if isinstance(private_key, GenericPrivateKey):
if not _is_rsa(dnskey.algorithm): signing_key = private_key
raise ValueError("Invalid DNSKEY algorithm for RSA key")
signature = private_key.sign(data, padding.PKCS1v15(), chosen_hash)
if verify:
private_key.public_key().verify(
signature, data, padding.PKCS1v15(), chosen_hash
)
elif isinstance(private_key, dsa.DSAPrivateKey):
if not _is_dsa(dnskey.algorithm):
raise ValueError("Invalid DNSKEY algorithm for DSA key")
public_dsa_key = private_key.public_key()
if public_dsa_key.key_size > 1024:
raise ValueError("DSA key size overflow")
der_signature = private_key.sign(data, chosen_hash)
if verify:
public_dsa_key.verify(der_signature, data, chosen_hash)
dsa_r, dsa_s = utils.decode_dss_signature(der_signature)
dsa_t = (public_dsa_key.key_size // 8 - 64) // 8
octets = 20
signature = (
struct.pack("!B", dsa_t)
+ int.to_bytes(dsa_r, length=octets, byteorder="big")
+ int.to_bytes(dsa_s, length=octets, byteorder="big")
)
elif isinstance(private_key, ec.EllipticCurvePrivateKey):
if not _is_ecdsa(dnskey.algorithm):
raise ValueError("Invalid DNSKEY algorithm for EC key")
der_signature = private_key.sign(data, ec.ECDSA(chosen_hash))
if verify:
private_key.public_key().verify(der_signature, data, ec.ECDSA(chosen_hash))
if dnskey.algorithm == Algorithm.ECDSAP256SHA256:
octets = 32
else:
octets = 48
dsa_r, dsa_s = utils.decode_dss_signature(der_signature)
signature = int.to_bytes(dsa_r, length=octets, byteorder="big") + int.to_bytes(
dsa_s, length=octets, byteorder="big"
)
elif isinstance(private_key, ed25519.Ed25519PrivateKey):
if dnskey.algorithm != Algorithm.ED25519:
raise ValueError("Invalid DNSKEY algorithm for ED25519 key")
signature = private_key.sign(data)
if verify:
private_key.public_key().verify(signature, data)
elif isinstance(private_key, ed448.Ed448PrivateKey):
if dnskey.algorithm != Algorithm.ED448:
raise ValueError("Invalid DNSKEY algorithm for ED448 key")
signature = private_key.sign(data)
if verify:
private_key.public_key().verify(signature, data)
else: else:
raise TypeError("Unsupported key algorithm") try:
private_cls = get_algorithm_cls_from_dnskey(dnskey)
signing_key = private_cls(key=private_key)
except UnsupportedAlgorithm:
raise TypeError("Unsupported key algorithm")
signature = signing_key.sign(data, verify)
return cast(RRSIG, rrsig_template.replace(signature=signature)) return cast(RRSIG, rrsig_template.replace(signature=signature))
@ -858,9 +637,12 @@ def _make_rrsig_signature_data(
raise ValidationFailure("relative RR name without an origin specified") raise ValidationFailure("relative RR name without an origin specified")
rrname = rrname.derelativize(origin) rrname = rrname.derelativize(origin)
if len(rrname) - 1 < rrsig.labels: name_len = len(rrname)
if rrname.is_wild() and rrsig.labels != name_len - 2:
raise ValidationFailure("wild owner name has wrong label length")
if name_len - 1 < rrsig.labels:
raise ValidationFailure("owner name longer than RRSIG labels") raise ValidationFailure("owner name longer than RRSIG labels")
elif rrsig.labels < len(rrname) - 1: elif rrsig.labels < name_len - 1:
suffix = rrname.split(rrsig.labels + 1)[1] suffix = rrname.split(rrsig.labels + 1)[1]
rrname = dns.name.from_text("*", suffix) rrname = dns.name.from_text("*", suffix)
rrnamebuf = rrname.to_digestable() rrnamebuf = rrname.to_digestable()
@ -884,9 +666,8 @@ def _make_dnskey(
) -> DNSKEY: ) -> DNSKEY:
"""Convert a public key to DNSKEY Rdata """Convert a public key to DNSKEY Rdata
*public_key*, the public key to convert, a *public_key*, a ``PublicKey`` (``GenericPublicKey`` or
``cryptography.hazmat.primitives.asymmetric`` public key class applicable ``cryptography.hazmat.primitives.asymmetric``) to convert.
for DNSSEC.
*algorithm*, a ``str`` or ``int`` specifying the DNSKEY algorithm. *algorithm*, a ``str`` or ``int`` specifying the DNSKEY algorithm.
@ -902,72 +683,13 @@ def _make_dnskey(
Return DNSKEY ``Rdata``. Return DNSKEY ``Rdata``.
""" """
def encode_rsa_public_key(public_key: "rsa.RSAPublicKey") -> bytes: algorithm = Algorithm.make(algorithm)
"""Encode a public key per RFC 3110, section 2."""
pn = public_key.public_numbers()
_exp_len = math.ceil(int.bit_length(pn.e) / 8)
exp = int.to_bytes(pn.e, length=_exp_len, byteorder="big")
if _exp_len > 255:
exp_header = b"\0" + struct.pack("!H", _exp_len)
else:
exp_header = struct.pack("!B", _exp_len)
if pn.n.bit_length() < 512 or pn.n.bit_length() > 4096:
raise ValueError("unsupported RSA key length")
return exp_header + exp + pn.n.to_bytes((pn.n.bit_length() + 7) // 8, "big")
def encode_dsa_public_key(public_key: "dsa.DSAPublicKey") -> bytes: if isinstance(public_key, GenericPublicKey):
"""Encode a public key per RFC 2536, section 2.""" return public_key.to_dnskey(flags=flags, protocol=protocol)
pn = public_key.public_numbers()
dsa_t = (public_key.key_size // 8 - 64) // 8
if dsa_t > 8:
raise ValueError("unsupported DSA key size")
octets = 64 + dsa_t * 8
res = struct.pack("!B", dsa_t)
res += pn.parameter_numbers.q.to_bytes(20, "big")
res += pn.parameter_numbers.p.to_bytes(octets, "big")
res += pn.parameter_numbers.g.to_bytes(octets, "big")
res += pn.y.to_bytes(octets, "big")
return res
def encode_ecdsa_public_key(public_key: "ec.EllipticCurvePublicKey") -> bytes:
"""Encode a public key per RFC 6605, section 4."""
pn = public_key.public_numbers()
if isinstance(public_key.curve, ec.SECP256R1):
return pn.x.to_bytes(32, "big") + pn.y.to_bytes(32, "big")
elif isinstance(public_key.curve, ec.SECP384R1):
return pn.x.to_bytes(48, "big") + pn.y.to_bytes(48, "big")
else:
raise ValueError("unsupported ECDSA curve")
the_algorithm = Algorithm.make(algorithm)
_ensure_algorithm_key_combination(the_algorithm, public_key)
if isinstance(public_key, rsa.RSAPublicKey):
key_bytes = encode_rsa_public_key(public_key)
elif isinstance(public_key, dsa.DSAPublicKey):
key_bytes = encode_dsa_public_key(public_key)
elif isinstance(public_key, ec.EllipticCurvePublicKey):
key_bytes = encode_ecdsa_public_key(public_key)
elif isinstance(public_key, ed25519.Ed25519PublicKey):
key_bytes = public_key.public_bytes(
encoding=serialization.Encoding.Raw, format=serialization.PublicFormat.Raw
)
elif isinstance(public_key, ed448.Ed448PublicKey):
key_bytes = public_key.public_bytes(
encoding=serialization.Encoding.Raw, format=serialization.PublicFormat.Raw
)
else: else:
raise TypeError("unsupported key algorithm") public_cls = get_algorithm_cls(algorithm).public_cls
return public_cls(key=public_key).to_dnskey(flags=flags, protocol=protocol)
return DNSKEY(
rdclass=dns.rdataclass.IN,
rdtype=dns.rdatatype.DNSKEY,
flags=flags,
protocol=protocol,
algorithm=the_algorithm,
key=key_bytes,
)
def _make_cdnskey( def _make_cdnskey(
@ -1216,23 +938,252 @@ def dnskey_rdataset_to_cdnskey_rdataset(
return dns.rdataset.from_rdata_list(rdataset.ttl, res) return dns.rdataset.from_rdata_list(rdataset.ttl, res)
def default_rrset_signer(
txn: dns.transaction.Transaction,
rrset: dns.rrset.RRset,
signer: dns.name.Name,
ksks: List[Tuple[PrivateKey, DNSKEY]],
zsks: List[Tuple[PrivateKey, DNSKEY]],
inception: Optional[Union[datetime, str, int, float]] = None,
expiration: Optional[Union[datetime, str, int, float]] = None,
lifetime: Optional[int] = None,
policy: Optional[Policy] = None,
origin: Optional[dns.name.Name] = None,
) -> None:
"""Default RRset signer"""
if rrset.rdtype in set(
[
dns.rdatatype.RdataType.DNSKEY,
dns.rdatatype.RdataType.CDS,
dns.rdatatype.RdataType.CDNSKEY,
]
):
keys = ksks
else:
keys = zsks
for private_key, dnskey in keys:
rrsig = dns.dnssec.sign(
rrset=rrset,
private_key=private_key,
dnskey=dnskey,
inception=inception,
expiration=expiration,
lifetime=lifetime,
signer=signer,
policy=policy,
origin=origin,
)
txn.add(rrset.name, rrset.ttl, rrsig)
def sign_zone(
zone: dns.zone.Zone,
txn: Optional[dns.transaction.Transaction] = None,
keys: Optional[List[Tuple[PrivateKey, DNSKEY]]] = None,
add_dnskey: bool = True,
dnskey_ttl: Optional[int] = None,
inception: Optional[Union[datetime, str, int, float]] = None,
expiration: Optional[Union[datetime, str, int, float]] = None,
lifetime: Optional[int] = None,
nsec3: Optional[NSEC3PARAM] = None,
rrset_signer: Optional[RRsetSigner] = None,
policy: Optional[Policy] = None,
) -> None:
"""Sign zone.
*zone*, a ``dns.zone.Zone``, the zone to sign.
*txn*, a ``dns.transaction.Transaction``, an optional transaction to use for
signing.
*keys*, a list of (``PrivateKey``, ``DNSKEY``) tuples, to use for signing. KSK/ZSK
roles are assigned automatically if the SEP flag is used, otherwise all RRsets are
signed by all keys.
*add_dnskey*, a ``bool``. If ``True``, the default, all specified DNSKEYs are
automatically added to the zone on signing.
*dnskey_ttl*, a``int``, specifies the TTL for DNSKEY RRs. If not specified the TTL
of the existing DNSKEY RRset used or the TTL of the SOA RRset.
*inception*, a ``datetime``, ``str``, ``int``, ``float`` or ``None``, the signature
inception time. If ``None``, the current time is used. If a ``str``, the format is
"YYYYMMDDHHMMSS" or alternatively the number of seconds since the UNIX epoch in text
form; this is the same the RRSIG rdata's text form. Values of type `int` or `float`
are interpreted as seconds since the UNIX epoch.
*expiration*, a ``datetime``, ``str``, ``int``, ``float`` or ``None``, the signature
expiration time. If ``None``, the expiration time will be the inception time plus
the value of the *lifetime* parameter. See the description of *inception* above for
how the various parameter types are interpreted.
*lifetime*, an ``int`` or ``None``, the signature lifetime in seconds. This
parameter is only meaningful if *expiration* is ``None``.
*nsec3*, a ``NSEC3PARAM`` Rdata, configures signing using NSEC3. Not yet
implemented.
*rrset_signer*, a ``Callable``, an optional function for signing RRsets. The
function requires two arguments: transaction and RRset. If the not specified,
``dns.dnssec.default_rrset_signer`` will be used.
Returns ``None``.
"""
ksks = []
zsks = []
# if we have both KSKs and ZSKs, split by SEP flag. if not, sign all
# records with all keys
if keys:
for key in keys:
if key[1].flags & Flag.SEP:
ksks.append(key)
else:
zsks.append(key)
if not ksks:
ksks = keys
if not zsks:
zsks = keys
else:
keys = []
if txn:
cm: contextlib.AbstractContextManager = contextlib.nullcontext(txn)
else:
cm = zone.writer()
with cm as _txn:
if add_dnskey:
if dnskey_ttl is None:
dnskey = _txn.get(zone.origin, dns.rdatatype.DNSKEY)
if dnskey:
dnskey_ttl = dnskey.ttl
else:
soa = _txn.get(zone.origin, dns.rdatatype.SOA)
dnskey_ttl = soa.ttl
for _, dnskey in keys:
_txn.add(zone.origin, dnskey_ttl, dnskey)
if nsec3:
raise NotImplementedError("Signing with NSEC3 not yet implemented")
else:
_rrset_signer = rrset_signer or functools.partial(
default_rrset_signer,
signer=zone.origin,
ksks=ksks,
zsks=zsks,
inception=inception,
expiration=expiration,
lifetime=lifetime,
policy=policy,
origin=zone.origin,
)
return _sign_zone_nsec(zone, _txn, _rrset_signer)
def _sign_zone_nsec(
zone: dns.zone.Zone,
txn: dns.transaction.Transaction,
rrset_signer: Optional[RRsetSigner] = None,
) -> None:
"""NSEC zone signer"""
def _txn_add_nsec(
txn: dns.transaction.Transaction,
name: dns.name.Name,
next_secure: Optional[dns.name.Name],
rdclass: dns.rdataclass.RdataClass,
ttl: int,
rrset_signer: Optional[RRsetSigner] = None,
) -> None:
"""NSEC zone signer helper"""
mandatory_types = set(
[dns.rdatatype.RdataType.RRSIG, dns.rdatatype.RdataType.NSEC]
)
node = txn.get_node(name)
if node and next_secure:
types = (
set([rdataset.rdtype for rdataset in node.rdatasets]) | mandatory_types
)
windows = Bitmap.from_rdtypes(list(types))
rrset = dns.rrset.from_rdata(
name,
ttl,
NSEC(
rdclass=rdclass,
rdtype=dns.rdatatype.RdataType.NSEC,
next=next_secure,
windows=windows,
),
)
txn.add(rrset)
if rrset_signer:
rrset_signer(txn, rrset)
rrsig_ttl = zone.get_soa().minimum
delegation = None
last_secure = None
for name in sorted(txn.iterate_names()):
if delegation and name.is_subdomain(delegation):
# names below delegations are not secure
continue
elif txn.get(name, dns.rdatatype.NS) and name != zone.origin:
# inside delegation
delegation = name
else:
# outside delegation
delegation = None
if rrset_signer:
node = txn.get_node(name)
if node:
for rdataset in node.rdatasets:
if rdataset.rdtype == dns.rdatatype.RRSIG:
# do not sign RRSIGs
continue
elif delegation and rdataset.rdtype != dns.rdatatype.DS:
# do not sign delegations except DS records
continue
else:
rrset = dns.rrset.from_rdata(name, rdataset.ttl, *rdataset)
rrset_signer(txn, rrset)
# We need "is not None" as the empty name is False because its length is 0.
if last_secure is not None:
_txn_add_nsec(txn, last_secure, name, zone.rdclass, rrsig_ttl, rrset_signer)
last_secure = name
if last_secure:
_txn_add_nsec(
txn, last_secure, zone.origin, zone.rdclass, rrsig_ttl, rrset_signer
)
def _need_pyca(*args, **kwargs): def _need_pyca(*args, **kwargs):
raise ImportError( raise ImportError(
"DNSSEC validation requires " + "python cryptography" "DNSSEC validation requires python cryptography"
) # pragma: no cover ) # pragma: no cover
try: try:
from cryptography.exceptions import InvalidSignature from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives.asymmetric import dsa # pylint: disable=W0611
from cryptography.hazmat.primitives import hashes, serialization from cryptography.hazmat.primitives.asymmetric import ec # pylint: disable=W0611
from cryptography.hazmat.primitives.asymmetric import padding from cryptography.hazmat.primitives.asymmetric import ed448 # pylint: disable=W0611
from cryptography.hazmat.primitives.asymmetric import utils from cryptography.hazmat.primitives.asymmetric import rsa # pylint: disable=W0611
from cryptography.hazmat.primitives.asymmetric import dsa from cryptography.hazmat.primitives.asymmetric import ( # pylint: disable=W0611
from cryptography.hazmat.primitives.asymmetric import ec ed25519,
from cryptography.hazmat.primitives.asymmetric import ed25519 )
from cryptography.hazmat.primitives.asymmetric import ed448
from cryptography.hazmat.primitives.asymmetric import rsa from dns.dnssecalgs import ( # pylint: disable=C0412
get_algorithm_cls,
get_algorithm_cls_from_dnskey,
)
from dns.dnssecalgs.base import GenericPrivateKey, GenericPublicKey
except ImportError: # pragma: no cover except ImportError: # pragma: no cover
validate = _need_pyca validate = _need_pyca
validate_rrsig = _need_pyca validate_rrsig = _need_pyca

View file

@ -0,0 +1,121 @@
from typing import Dict, Optional, Tuple, Type, Union
import dns.name
try:
from dns.dnssecalgs.base import GenericPrivateKey
from dns.dnssecalgs.dsa import PrivateDSA, PrivateDSANSEC3SHA1
from dns.dnssecalgs.ecdsa import PrivateECDSAP256SHA256, PrivateECDSAP384SHA384
from dns.dnssecalgs.eddsa import PrivateED448, PrivateED25519
from dns.dnssecalgs.rsa import (
PrivateRSAMD5,
PrivateRSASHA1,
PrivateRSASHA1NSEC3SHA1,
PrivateRSASHA256,
PrivateRSASHA512,
)
_have_cryptography = True
except ImportError:
_have_cryptography = False
from dns.dnssectypes import Algorithm
from dns.exception import UnsupportedAlgorithm
from dns.rdtypes.ANY.DNSKEY import DNSKEY
AlgorithmPrefix = Optional[Union[bytes, dns.name.Name]]
algorithms: Dict[Tuple[Algorithm, AlgorithmPrefix], Type[GenericPrivateKey]] = {}
if _have_cryptography:
algorithms.update(
{
(Algorithm.RSAMD5, None): PrivateRSAMD5,
(Algorithm.DSA, None): PrivateDSA,
(Algorithm.RSASHA1, None): PrivateRSASHA1,
(Algorithm.DSANSEC3SHA1, None): PrivateDSANSEC3SHA1,
(Algorithm.RSASHA1NSEC3SHA1, None): PrivateRSASHA1NSEC3SHA1,
(Algorithm.RSASHA256, None): PrivateRSASHA256,
(Algorithm.RSASHA512, None): PrivateRSASHA512,
(Algorithm.ECDSAP256SHA256, None): PrivateECDSAP256SHA256,
(Algorithm.ECDSAP384SHA384, None): PrivateECDSAP384SHA384,
(Algorithm.ED25519, None): PrivateED25519,
(Algorithm.ED448, None): PrivateED448,
}
)
def get_algorithm_cls(
algorithm: Union[int, str], prefix: AlgorithmPrefix = None
) -> Type[GenericPrivateKey]:
"""Get Private Key class from Algorithm.
*algorithm*, a ``str`` or ``int`` specifying the DNSKEY algorithm.
Raises ``UnsupportedAlgorithm`` if the algorithm is unknown.
Returns a ``dns.dnssecalgs.GenericPrivateKey``
"""
algorithm = Algorithm.make(algorithm)
cls = algorithms.get((algorithm, prefix))
if cls:
return cls
raise UnsupportedAlgorithm(
'algorithm "%s" not supported by dnspython' % Algorithm.to_text(algorithm)
)
def get_algorithm_cls_from_dnskey(dnskey: DNSKEY) -> Type[GenericPrivateKey]:
"""Get Private Key class from DNSKEY.
*dnskey*, a ``DNSKEY`` to get Algorithm class for.
Raises ``UnsupportedAlgorithm`` if the algorithm is unknown.
Returns a ``dns.dnssecalgs.GenericPrivateKey``
"""
prefix: AlgorithmPrefix = None
if dnskey.algorithm == Algorithm.PRIVATEDNS:
prefix, _ = dns.name.from_wire(dnskey.key, 0)
elif dnskey.algorithm == Algorithm.PRIVATEOID:
length = int(dnskey.key[0])
prefix = dnskey.key[0 : length + 1]
return get_algorithm_cls(dnskey.algorithm, prefix)
def register_algorithm_cls(
algorithm: Union[int, str],
algorithm_cls: Type[GenericPrivateKey],
name: Optional[Union[dns.name.Name, str]] = None,
oid: Optional[bytes] = None,
) -> None:
"""Register Algorithm Private Key class.
*algorithm*, a ``str`` or ``int`` specifying the DNSKEY algorithm.
*algorithm_cls*: A `GenericPrivateKey` class.
*name*, an optional ``dns.name.Name`` or ``str``, for for PRIVATEDNS algorithms.
*oid*: an optional BER-encoded `bytes` for PRIVATEOID algorithms.
Raises ``ValueError`` if a name or oid is specified incorrectly.
"""
if not issubclass(algorithm_cls, GenericPrivateKey):
raise TypeError("Invalid algorithm class")
algorithm = Algorithm.make(algorithm)
prefix: AlgorithmPrefix = None
if algorithm == Algorithm.PRIVATEDNS:
if name is None:
raise ValueError("Name required for PRIVATEDNS algorithms")
if isinstance(name, str):
name = dns.name.from_text(name)
prefix = name
elif algorithm == Algorithm.PRIVATEOID:
if oid is None:
raise ValueError("OID required for PRIVATEOID algorithms")
prefix = bytes([len(oid)]) + oid
elif name:
raise ValueError("Name only supported for PRIVATEDNS algorithm")
elif oid:
raise ValueError("OID only supported for PRIVATEOID algorithm")
algorithms[(algorithm, prefix)] = algorithm_cls

View file

@ -0,0 +1,84 @@
from abc import ABC, abstractmethod # pylint: disable=no-name-in-module
from typing import Any, Optional, Type
import dns.rdataclass
import dns.rdatatype
from dns.dnssectypes import Algorithm
from dns.exception import AlgorithmKeyMismatch
from dns.rdtypes.ANY.DNSKEY import DNSKEY
from dns.rdtypes.dnskeybase import Flag
class GenericPublicKey(ABC):
algorithm: Algorithm
@abstractmethod
def __init__(self, key: Any) -> None:
pass
@abstractmethod
def verify(self, signature: bytes, data: bytes) -> None:
"""Verify signed DNSSEC data"""
@abstractmethod
def encode_key_bytes(self) -> bytes:
"""Encode key as bytes for DNSKEY"""
@classmethod
def _ensure_algorithm_key_combination(cls, key: DNSKEY) -> None:
if key.algorithm != cls.algorithm:
raise AlgorithmKeyMismatch
def to_dnskey(self, flags: int = Flag.ZONE, protocol: int = 3) -> DNSKEY:
"""Return public key as DNSKEY"""
return DNSKEY(
rdclass=dns.rdataclass.IN,
rdtype=dns.rdatatype.DNSKEY,
flags=flags,
protocol=protocol,
algorithm=self.algorithm,
key=self.encode_key_bytes(),
)
@classmethod
@abstractmethod
def from_dnskey(cls, key: DNSKEY) -> "GenericPublicKey":
"""Create public key from DNSKEY"""
@classmethod
@abstractmethod
def from_pem(cls, public_pem: bytes) -> "GenericPublicKey":
"""Create public key from PEM-encoded SubjectPublicKeyInfo as specified
in RFC 5280"""
@abstractmethod
def to_pem(self) -> bytes:
"""Return public-key as PEM-encoded SubjectPublicKeyInfo as specified
in RFC 5280"""
class GenericPrivateKey(ABC):
public_cls: Type[GenericPublicKey]
@abstractmethod
def __init__(self, key: Any) -> None:
pass
@abstractmethod
def sign(self, data: bytes, verify: bool = False) -> bytes:
"""Sign DNSSEC data"""
@abstractmethod
def public_key(self) -> "GenericPublicKey":
"""Return public key instance"""
@classmethod
@abstractmethod
def from_pem(
cls, private_pem: bytes, password: Optional[bytes] = None
) -> "GenericPrivateKey":
"""Create private key from PEM-encoded PKCS#8"""
@abstractmethod
def to_pem(self, password: Optional[bytes] = None) -> bytes:
"""Return private key as PEM-encoded PKCS#8"""

View file

@ -0,0 +1,68 @@
from typing import Any, Optional, Type
from cryptography.hazmat.primitives import serialization
from dns.dnssecalgs.base import GenericPrivateKey, GenericPublicKey
from dns.exception import AlgorithmKeyMismatch
class CryptographyPublicKey(GenericPublicKey):
key: Any = None
key_cls: Any = None
def __init__(self, key: Any) -> None: # pylint: disable=super-init-not-called
if self.key_cls is None:
raise TypeError("Undefined private key class")
if not isinstance( # pylint: disable=isinstance-second-argument-not-valid-type
key, self.key_cls
):
raise AlgorithmKeyMismatch
self.key = key
@classmethod
def from_pem(cls, public_pem: bytes) -> "GenericPublicKey":
key = serialization.load_pem_public_key(public_pem)
return cls(key=key)
def to_pem(self) -> bytes:
return self.key.public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo,
)
class CryptographyPrivateKey(GenericPrivateKey):
key: Any = None
key_cls: Any = None
public_cls: Type[CryptographyPublicKey]
def __init__(self, key: Any) -> None: # pylint: disable=super-init-not-called
if self.key_cls is None:
raise TypeError("Undefined private key class")
if not isinstance( # pylint: disable=isinstance-second-argument-not-valid-type
key, self.key_cls
):
raise AlgorithmKeyMismatch
self.key = key
def public_key(self) -> "CryptographyPublicKey":
return self.public_cls(key=self.key.public_key())
@classmethod
def from_pem(
cls, private_pem: bytes, password: Optional[bytes] = None
) -> "GenericPrivateKey":
key = serialization.load_pem_private_key(private_pem, password=password)
return cls(key=key)
def to_pem(self, password: Optional[bytes] = None) -> bytes:
encryption_algorithm: serialization.KeySerializationEncryption
if password:
encryption_algorithm = serialization.BestAvailableEncryption(password)
else:
encryption_algorithm = serialization.NoEncryption()
return self.key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=encryption_algorithm,
)

101
lib/dns/dnssecalgs/dsa.py Normal file
View file

@ -0,0 +1,101 @@
import struct
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import dsa, utils
from dns.dnssecalgs.cryptography import CryptographyPrivateKey, CryptographyPublicKey
from dns.dnssectypes import Algorithm
from dns.rdtypes.ANY.DNSKEY import DNSKEY
class PublicDSA(CryptographyPublicKey):
key: dsa.DSAPublicKey
key_cls = dsa.DSAPublicKey
algorithm = Algorithm.DSA
chosen_hash = hashes.SHA1()
def verify(self, signature: bytes, data: bytes) -> None:
sig_r = signature[1:21]
sig_s = signature[21:]
sig = utils.encode_dss_signature(
int.from_bytes(sig_r, "big"), int.from_bytes(sig_s, "big")
)
self.key.verify(sig, data, self.chosen_hash)
def encode_key_bytes(self) -> bytes:
"""Encode a public key per RFC 2536, section 2."""
pn = self.key.public_numbers()
dsa_t = (self.key.key_size // 8 - 64) // 8
if dsa_t > 8:
raise ValueError("unsupported DSA key size")
octets = 64 + dsa_t * 8
res = struct.pack("!B", dsa_t)
res += pn.parameter_numbers.q.to_bytes(20, "big")
res += pn.parameter_numbers.p.to_bytes(octets, "big")
res += pn.parameter_numbers.g.to_bytes(octets, "big")
res += pn.y.to_bytes(octets, "big")
return res
@classmethod
def from_dnskey(cls, key: DNSKEY) -> "PublicDSA":
cls._ensure_algorithm_key_combination(key)
keyptr = key.key
(t,) = struct.unpack("!B", keyptr[0:1])
keyptr = keyptr[1:]
octets = 64 + t * 8
dsa_q = keyptr[0:20]
keyptr = keyptr[20:]
dsa_p = keyptr[0:octets]
keyptr = keyptr[octets:]
dsa_g = keyptr[0:octets]
keyptr = keyptr[octets:]
dsa_y = keyptr[0:octets]
return cls(
key=dsa.DSAPublicNumbers( # type: ignore
int.from_bytes(dsa_y, "big"),
dsa.DSAParameterNumbers(
int.from_bytes(dsa_p, "big"),
int.from_bytes(dsa_q, "big"),
int.from_bytes(dsa_g, "big"),
),
).public_key(default_backend()),
)
class PrivateDSA(CryptographyPrivateKey):
key: dsa.DSAPrivateKey
key_cls = dsa.DSAPrivateKey
public_cls = PublicDSA
def sign(self, data: bytes, verify: bool = False) -> bytes:
"""Sign using a private key per RFC 2536, section 3."""
public_dsa_key = self.key.public_key()
if public_dsa_key.key_size > 1024:
raise ValueError("DSA key size overflow")
der_signature = self.key.sign(data, self.public_cls.chosen_hash)
dsa_r, dsa_s = utils.decode_dss_signature(der_signature)
dsa_t = (public_dsa_key.key_size // 8 - 64) // 8
octets = 20
signature = (
struct.pack("!B", dsa_t)
+ int.to_bytes(dsa_r, length=octets, byteorder="big")
+ int.to_bytes(dsa_s, length=octets, byteorder="big")
)
if verify:
self.public_key().verify(signature, data)
return signature
@classmethod
def generate(cls, key_size: int) -> "PrivateDSA":
return cls(
key=dsa.generate_private_key(key_size=key_size),
)
class PublicDSANSEC3SHA1(PublicDSA):
algorithm = Algorithm.DSANSEC3SHA1
class PrivateDSANSEC3SHA1(PrivateDSA):
public_cls = PublicDSANSEC3SHA1

View file

@ -0,0 +1,89 @@
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import ec, utils
from dns.dnssecalgs.cryptography import CryptographyPrivateKey, CryptographyPublicKey
from dns.dnssectypes import Algorithm
from dns.rdtypes.ANY.DNSKEY import DNSKEY
class PublicECDSA(CryptographyPublicKey):
key: ec.EllipticCurvePublicKey
key_cls = ec.EllipticCurvePublicKey
algorithm: Algorithm
chosen_hash: hashes.HashAlgorithm
curve: ec.EllipticCurve
octets: int
def verify(self, signature: bytes, data: bytes) -> None:
sig_r = signature[0 : self.octets]
sig_s = signature[self.octets :]
sig = utils.encode_dss_signature(
int.from_bytes(sig_r, "big"), int.from_bytes(sig_s, "big")
)
self.key.verify(sig, data, ec.ECDSA(self.chosen_hash))
def encode_key_bytes(self) -> bytes:
"""Encode a public key per RFC 6605, section 4."""
pn = self.key.public_numbers()
return pn.x.to_bytes(self.octets, "big") + pn.y.to_bytes(self.octets, "big")
@classmethod
def from_dnskey(cls, key: DNSKEY) -> "PublicECDSA":
cls._ensure_algorithm_key_combination(key)
ecdsa_x = key.key[0 : cls.octets]
ecdsa_y = key.key[cls.octets : cls.octets * 2]
return cls(
key=ec.EllipticCurvePublicNumbers(
curve=cls.curve,
x=int.from_bytes(ecdsa_x, "big"),
y=int.from_bytes(ecdsa_y, "big"),
).public_key(default_backend()),
)
class PrivateECDSA(CryptographyPrivateKey):
key: ec.EllipticCurvePrivateKey
key_cls = ec.EllipticCurvePrivateKey
public_cls = PublicECDSA
def sign(self, data: bytes, verify: bool = False) -> bytes:
"""Sign using a private key per RFC 6605, section 4."""
der_signature = self.key.sign(data, ec.ECDSA(self.public_cls.chosen_hash))
dsa_r, dsa_s = utils.decode_dss_signature(der_signature)
signature = int.to_bytes(
dsa_r, length=self.public_cls.octets, byteorder="big"
) + int.to_bytes(dsa_s, length=self.public_cls.octets, byteorder="big")
if verify:
self.public_key().verify(signature, data)
return signature
@classmethod
def generate(cls) -> "PrivateECDSA":
return cls(
key=ec.generate_private_key(
curve=cls.public_cls.curve, backend=default_backend()
),
)
class PublicECDSAP256SHA256(PublicECDSA):
algorithm = Algorithm.ECDSAP256SHA256
chosen_hash = hashes.SHA256()
curve = ec.SECP256R1()
octets = 32
class PrivateECDSAP256SHA256(PrivateECDSA):
public_cls = PublicECDSAP256SHA256
class PublicECDSAP384SHA384(PublicECDSA):
algorithm = Algorithm.ECDSAP384SHA384
chosen_hash = hashes.SHA384()
curve = ec.SECP384R1()
octets = 48
class PrivateECDSAP384SHA384(PrivateECDSA):
public_cls = PublicECDSAP384SHA384

View file

@ -0,0 +1,65 @@
from typing import Type
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import ed448, ed25519
from dns.dnssecalgs.cryptography import CryptographyPrivateKey, CryptographyPublicKey
from dns.dnssectypes import Algorithm
from dns.rdtypes.ANY.DNSKEY import DNSKEY
class PublicEDDSA(CryptographyPublicKey):
def verify(self, signature: bytes, data: bytes) -> None:
self.key.verify(signature, data)
def encode_key_bytes(self) -> bytes:
"""Encode a public key per RFC 8080, section 3."""
return self.key.public_bytes(
encoding=serialization.Encoding.Raw, format=serialization.PublicFormat.Raw
)
@classmethod
def from_dnskey(cls, key: DNSKEY) -> "PublicEDDSA":
cls._ensure_algorithm_key_combination(key)
return cls(
key=cls.key_cls.from_public_bytes(key.key),
)
class PrivateEDDSA(CryptographyPrivateKey):
public_cls: Type[PublicEDDSA]
def sign(self, data: bytes, verify: bool = False) -> bytes:
"""Sign using a private key per RFC 8080, section 4."""
signature = self.key.sign(data)
if verify:
self.public_key().verify(signature, data)
return signature
@classmethod
def generate(cls) -> "PrivateEDDSA":
return cls(key=cls.key_cls.generate())
class PublicED25519(PublicEDDSA):
key: ed25519.Ed25519PublicKey
key_cls = ed25519.Ed25519PublicKey
algorithm = Algorithm.ED25519
class PrivateED25519(PrivateEDDSA):
key: ed25519.Ed25519PrivateKey
key_cls = ed25519.Ed25519PrivateKey
public_cls = PublicED25519
class PublicED448(PublicEDDSA):
key: ed448.Ed448PublicKey
key_cls = ed448.Ed448PublicKey
algorithm = Algorithm.ED448
class PrivateED448(PrivateEDDSA):
key: ed448.Ed448PrivateKey
key_cls = ed448.Ed448PrivateKey
public_cls = PublicED448

119
lib/dns/dnssecalgs/rsa.py Normal file
View file

@ -0,0 +1,119 @@
import math
import struct
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import padding, rsa
from dns.dnssecalgs.cryptography import CryptographyPrivateKey, CryptographyPublicKey
from dns.dnssectypes import Algorithm
from dns.rdtypes.ANY.DNSKEY import DNSKEY
class PublicRSA(CryptographyPublicKey):
key: rsa.RSAPublicKey
key_cls = rsa.RSAPublicKey
algorithm: Algorithm
chosen_hash: hashes.HashAlgorithm
def verify(self, signature: bytes, data: bytes) -> None:
self.key.verify(signature, data, padding.PKCS1v15(), self.chosen_hash)
def encode_key_bytes(self) -> bytes:
"""Encode a public key per RFC 3110, section 2."""
pn = self.key.public_numbers()
_exp_len = math.ceil(int.bit_length(pn.e) / 8)
exp = int.to_bytes(pn.e, length=_exp_len, byteorder="big")
if _exp_len > 255:
exp_header = b"\0" + struct.pack("!H", _exp_len)
else:
exp_header = struct.pack("!B", _exp_len)
if pn.n.bit_length() < 512 or pn.n.bit_length() > 4096:
raise ValueError("unsupported RSA key length")
return exp_header + exp + pn.n.to_bytes((pn.n.bit_length() + 7) // 8, "big")
@classmethod
def from_dnskey(cls, key: DNSKEY) -> "PublicRSA":
cls._ensure_algorithm_key_combination(key)
keyptr = key.key
(bytes_,) = struct.unpack("!B", keyptr[0:1])
keyptr = keyptr[1:]
if bytes_ == 0:
(bytes_,) = struct.unpack("!H", keyptr[0:2])
keyptr = keyptr[2:]
rsa_e = keyptr[0:bytes_]
rsa_n = keyptr[bytes_:]
return cls(
key=rsa.RSAPublicNumbers(
int.from_bytes(rsa_e, "big"), int.from_bytes(rsa_n, "big")
).public_key(default_backend())
)
class PrivateRSA(CryptographyPrivateKey):
key: rsa.RSAPrivateKey
key_cls = rsa.RSAPrivateKey
public_cls = PublicRSA
default_public_exponent = 65537
def sign(self, data: bytes, verify: bool = False) -> bytes:
"""Sign using a private key per RFC 3110, section 3."""
signature = self.key.sign(data, padding.PKCS1v15(), self.public_cls.chosen_hash)
if verify:
self.public_key().verify(signature, data)
return signature
@classmethod
def generate(cls, key_size: int) -> "PrivateRSA":
return cls(
key=rsa.generate_private_key(
public_exponent=cls.default_public_exponent,
key_size=key_size,
backend=default_backend(),
)
)
class PublicRSAMD5(PublicRSA):
algorithm = Algorithm.RSAMD5
chosen_hash = hashes.MD5()
class PrivateRSAMD5(PrivateRSA):
public_cls = PublicRSAMD5
class PublicRSASHA1(PublicRSA):
algorithm = Algorithm.RSASHA1
chosen_hash = hashes.SHA1()
class PrivateRSASHA1(PrivateRSA):
public_cls = PublicRSASHA1
class PublicRSASHA1NSEC3SHA1(PublicRSA):
algorithm = Algorithm.RSASHA1NSEC3SHA1
chosen_hash = hashes.SHA1()
class PrivateRSASHA1NSEC3SHA1(PrivateRSA):
public_cls = PublicRSASHA1NSEC3SHA1
class PublicRSASHA256(PublicRSA):
algorithm = Algorithm.RSASHA256
chosen_hash = hashes.SHA256()
class PrivateRSASHA256(PrivateRSA):
public_cls = PublicRSASHA256
class PublicRSASHA512(PublicRSA):
algorithm = Algorithm.RSASHA512
chosen_hash = hashes.SHA512()
class PrivateRSASHA512(PrivateRSA):
public_cls = PublicRSASHA512

View file

@ -17,11 +17,10 @@
"""EDNS Options""" """EDNS Options"""
from typing import Any, Dict, Optional, Union
import math import math
import socket import socket
import struct import struct
from typing import Any, Dict, Optional, Union
import dns.enum import dns.enum
import dns.inet import dns.inet
@ -380,7 +379,7 @@ class EDEOption(Option): # lgtm[py/missing-equals]
def from_wire_parser( def from_wire_parser(
cls, otype: Union[OptionType, str], parser: "dns.wire.Parser" cls, otype: Union[OptionType, str], parser: "dns.wire.Parser"
) -> Option: ) -> Option:
the_code = EDECode.make(parser.get_uint16()) code = EDECode.make(parser.get_uint16())
text = parser.get_remaining() text = parser.get_remaining()
if text: if text:
@ -390,7 +389,7 @@ class EDEOption(Option): # lgtm[py/missing-equals]
else: else:
btext = None btext = None
return cls(the_code, btext) return cls(code, btext)
_type_to_class: Dict[OptionType, Any] = { _type_to_class: Dict[OptionType, Any] = {
@ -424,8 +423,8 @@ def option_from_wire_parser(
Returns an instance of a subclass of ``dns.edns.Option``. Returns an instance of a subclass of ``dns.edns.Option``.
""" """
the_otype = OptionType.make(otype) otype = OptionType.make(otype)
cls = get_option_class(the_otype) cls = get_option_class(otype)
return cls.from_wire_parser(otype, parser) return cls.from_wire_parser(otype, parser)

View file

@ -15,17 +15,15 @@
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
from typing import Any, Optional
import os
import hashlib import hashlib
import os
import random import random
import threading import threading
import time import time
from typing import Any, Optional
class EntropyPool: class EntropyPool:
# This is an entropy pool for Python implementations that do not # This is an entropy pool for Python implementations that do not
# have a working SystemRandom. I'm not sure there are any, but # have a working SystemRandom. I'm not sure there are any, but
# leaving this code doesn't hurt anything as the library code # leaving this code doesn't hurt anything as the library code

View file

@ -16,18 +16,31 @@
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import enum import enum
from typing import Type, TypeVar, Union
TIntEnum = TypeVar("TIntEnum", bound="IntEnum")
class IntEnum(enum.IntEnum): class IntEnum(enum.IntEnum):
@classmethod @classmethod
def _check_value(cls, value): def _missing_(cls, value):
max = cls._maximum() cls._check_value(value)
if value < 0 or value > max: val = int.__new__(cls, value)
name = cls._short_name() val._name_ = cls._extra_to_text(value, None) or f"{cls._prefix()}{value}"
raise ValueError(f"{name} must be between >= 0 and <= {max}") val._value_ = value
return val
@classmethod @classmethod
def from_text(cls, text): def _check_value(cls, value):
max = cls._maximum()
if not isinstance(value, int):
raise TypeError
if value < 0 or value > max:
name = cls._short_name()
raise ValueError(f"{name} must be an int between >= 0 and <= {max}")
@classmethod
def from_text(cls: Type[TIntEnum], text: str) -> TIntEnum:
text = text.upper() text = text.upper()
try: try:
return cls[text] return cls[text]
@ -47,7 +60,7 @@ class IntEnum(enum.IntEnum):
raise cls._unknown_exception_class() raise cls._unknown_exception_class()
@classmethod @classmethod
def to_text(cls, value): def to_text(cls: Type[TIntEnum], value: int) -> str:
cls._check_value(value) cls._check_value(value)
try: try:
text = cls(value).name text = cls(value).name
@ -59,7 +72,7 @@ class IntEnum(enum.IntEnum):
return text return text
@classmethod @classmethod
def make(cls, value): def make(cls: Type[TIntEnum], value: Union[int, str]) -> TIntEnum:
"""Convert text or a value into an enumerated type, if possible. """Convert text or a value into an enumerated type, if possible.
*value*, the ``int`` or ``str`` to convert. *value*, the ``int`` or ``str`` to convert.
@ -76,10 +89,7 @@ class IntEnum(enum.IntEnum):
if isinstance(value, str): if isinstance(value, str):
return cls.from_text(value) return cls.from_text(value)
cls._check_value(value) cls._check_value(value)
try: return cls(value)
return cls(value)
except ValueError:
return value
@classmethod @classmethod
def _maximum(cls): def _maximum(cls):

View file

@ -140,6 +140,22 @@ class Timeout(DNSException):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
class UnsupportedAlgorithm(DNSException):
"""The DNSSEC algorithm is not supported."""
class AlgorithmKeyMismatch(UnsupportedAlgorithm):
"""The DNSSEC algorithm is not supported for the given key type."""
class ValidationFailure(DNSException):
"""The DNSSEC signature is invalid."""
class DeniedByPolicy(DNSException):
"""Denied by DNSSEC policy."""
class ExceptionWrapper: class ExceptionWrapper:
def __init__(self, exception_class): def __init__(self, exception_class):
self.exception_class = exception_class self.exception_class = exception_class

View file

@ -17,9 +17,8 @@
"""DNS Message Flags.""" """DNS Message Flags."""
from typing import Any
import enum import enum
from typing import Any
# Standard DNS flags # Standard DNS flags

View file

@ -1,8 +1,7 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
from typing import Any
import collections.abc import collections.abc
from typing import Any
from dns._immutable_ctx import immutable from dns._immutable_ctx import immutable

View file

@ -17,14 +17,12 @@
"""Generic Internet address helper functions.""" """Generic Internet address helper functions."""
from typing import Any, Optional, Tuple
import socket import socket
from typing import Any, Optional, Tuple
import dns.ipv4 import dns.ipv4
import dns.ipv6 import dns.ipv6
# We assume that AF_INET and AF_INET6 are always defined. We keep # We assume that AF_INET and AF_INET6 are always defined. We keep
# these here for the benefit of any old code (unlikely though that # these here for the benefit of any old code (unlikely though that
# is!). # is!).
@ -171,3 +169,12 @@ def low_level_address_tuple(
return tup return tup
else: else:
raise NotImplementedError(f"unknown address family {af}") raise NotImplementedError(f"unknown address family {af}")
def any_for_af(af):
"""Return the 'any' address for the specified address family."""
if af == socket.AF_INET:
return "0.0.0.0"
elif af == socket.AF_INET6:
return "::"
raise NotImplementedError(f"unknown address family {af}")

View file

@ -17,9 +17,8 @@
"""IPv4 helper functions.""" """IPv4 helper functions."""
from typing import Union
import struct import struct
from typing import Union
import dns.exception import dns.exception

View file

@ -17,10 +17,9 @@
"""IPv6 helper functions.""" """IPv6 helper functions."""
from typing import List, Union
import re
import binascii import binascii
import re
from typing import List, Union
import dns.exception import dns.exception
import dns.ipv4 import dns.ipv4

View file

@ -17,30 +17,29 @@
"""DNS Messages""" """DNS Messages"""
from typing import Any, Dict, List, Optional, Tuple, Union
import contextlib import contextlib
import io import io
import time import time
from typing import Any, Dict, List, Optional, Tuple, Union
import dns.wire
import dns.edns import dns.edns
import dns.entropy
import dns.enum import dns.enum
import dns.exception import dns.exception
import dns.flags import dns.flags
import dns.name import dns.name
import dns.opcode import dns.opcode
import dns.entropy
import dns.rcode import dns.rcode
import dns.rdata import dns.rdata
import dns.rdataclass import dns.rdataclass
import dns.rdatatype import dns.rdatatype
import dns.rrset
import dns.renderer
import dns.ttl
import dns.tsig
import dns.rdtypes.ANY.OPT import dns.rdtypes.ANY.OPT
import dns.rdtypes.ANY.TSIG import dns.rdtypes.ANY.TSIG
import dns.renderer
import dns.rrset
import dns.tsig
import dns.ttl
import dns.wire
class ShortHeader(dns.exception.FormError): class ShortHeader(dns.exception.FormError):
@ -135,7 +134,7 @@ IndexKeyType = Tuple[
Optional[dns.rdataclass.RdataClass], Optional[dns.rdataclass.RdataClass],
] ]
IndexType = Dict[IndexKeyType, dns.rrset.RRset] IndexType = Dict[IndexKeyType, dns.rrset.RRset]
SectionType = Union[int, List[dns.rrset.RRset]] SectionType = Union[int, str, List[dns.rrset.RRset]]
class Message: class Message:
@ -231,7 +230,7 @@ class Message:
s.write("payload %d\n" % self.payload) s.write("payload %d\n" % self.payload)
for opt in self.options: for opt in self.options:
s.write("option %s\n" % opt.to_text()) s.write("option %s\n" % opt.to_text())
for (name, which) in self._section_enum.__members__.items(): for name, which in self._section_enum.__members__.items():
s.write(f";{name}\n") s.write(f";{name}\n")
for rrset in self.section_from_number(which): for rrset in self.section_from_number(which):
s.write(rrset.to_text(origin, relativize, **kw)) s.write(rrset.to_text(origin, relativize, **kw))
@ -348,27 +347,29 @@ class Message:
deleting: Optional[dns.rdataclass.RdataClass] = None, deleting: Optional[dns.rdataclass.RdataClass] = None,
create: bool = False, create: bool = False,
force_unique: bool = False, force_unique: bool = False,
idna_codec: Optional[dns.name.IDNACodec] = None,
) -> dns.rrset.RRset: ) -> dns.rrset.RRset:
"""Find the RRset with the given attributes in the specified section. """Find the RRset with the given attributes in the specified section.
*section*, an ``int`` section number, or one of the section *section*, an ``int`` section number, a ``str`` section name, or one of
attributes of this message. This specifies the the section attributes of this message. This specifies the
the section of the message to search. For example:: the section of the message to search. For example::
my_message.find_rrset(my_message.answer, name, rdclass, rdtype) my_message.find_rrset(my_message.answer, name, rdclass, rdtype)
my_message.find_rrset(dns.message.ANSWER, name, rdclass, rdtype) my_message.find_rrset(dns.message.ANSWER, name, rdclass, rdtype)
my_message.find_rrset("ANSWER", name, rdclass, rdtype)
*name*, a ``dns.name.Name``, the name of the RRset. *name*, a ``dns.name.Name`` or ``str``, the name of the RRset.
*rdclass*, an ``int``, the class of the RRset. *rdclass*, an ``int`` or ``str``, the class of the RRset.
*rdtype*, an ``int``, the type of the RRset. *rdtype*, an ``int`` or ``str``, the type of the RRset.
*covers*, an ``int`` or ``None``, the covers value of the RRset. *covers*, an ``int`` or ``str``, the covers value of the RRset.
The default is ``None``. The default is ``dns.rdatatype.NONE``.
*deleting*, an ``int`` or ``None``, the deleting value of the RRset. *deleting*, an ``int``, ``str``, or ``None``, the deleting value of the
The default is ``None``. RRset. The default is ``None``.
*create*, a ``bool``. If ``True``, create the RRset if it is not found. *create*, a ``bool``. If ``True``, create the RRset if it is not found.
The created RRset is appended to *section*. The created RRset is appended to *section*.
@ -378,6 +379,10 @@ class Message:
already. The default is ``False``. This is useful when creating already. The default is ``False``. This is useful when creating
DDNS Update messages, as order matters for them. DDNS Update messages, as order matters for them.
*idna_codec*, a ``dns.name.IDNACodec``, specifies the IDNA
encoder/decoder. If ``None``, the default IDNA 2003 encoder/decoder
is used.
Raises ``KeyError`` if the RRset was not found and create was Raises ``KeyError`` if the RRset was not found and create was
``False``. ``False``.
@ -386,10 +391,19 @@ class Message:
if isinstance(section, int): if isinstance(section, int):
section_number = section section_number = section
the_section = self.section_from_number(section_number) section = self.section_from_number(section_number)
elif isinstance(section, str):
section_number = MessageSection.from_text(section)
section = self.section_from_number(section_number)
else: else:
section_number = self.section_number(section) section_number = self.section_number(section)
the_section = section if isinstance(name, str):
name = dns.name.from_text(name, idna_codec=idna_codec)
rdtype = dns.rdatatype.RdataType.make(rdtype)
rdclass = dns.rdataclass.RdataClass.make(rdclass)
covers = dns.rdatatype.RdataType.make(covers)
if deleting is not None:
deleting = dns.rdataclass.RdataClass.make(deleting)
key = (section_number, name, rdclass, rdtype, covers, deleting) key = (section_number, name, rdclass, rdtype, covers, deleting)
if not force_unique: if not force_unique:
if self.index is not None: if self.index is not None:
@ -397,13 +411,13 @@ class Message:
if rrset is not None: if rrset is not None:
return rrset return rrset
else: else:
for rrset in the_section: for rrset in section:
if rrset.full_match(name, rdclass, rdtype, covers, deleting): if rrset.full_match(name, rdclass, rdtype, covers, deleting):
return rrset return rrset
if not create: if not create:
raise KeyError raise KeyError
rrset = dns.rrset.RRset(name, rdclass, rdtype, covers, deleting) rrset = dns.rrset.RRset(name, rdclass, rdtype, covers, deleting)
the_section.append(rrset) section.append(rrset)
if self.index is not None: if self.index is not None:
self.index[key] = rrset self.index[key] = rrset
return rrset return rrset
@ -418,29 +432,31 @@ class Message:
deleting: Optional[dns.rdataclass.RdataClass] = None, deleting: Optional[dns.rdataclass.RdataClass] = None,
create: bool = False, create: bool = False,
force_unique: bool = False, force_unique: bool = False,
idna_codec: Optional[dns.name.IDNACodec] = None,
) -> Optional[dns.rrset.RRset]: ) -> Optional[dns.rrset.RRset]:
"""Get the RRset with the given attributes in the specified section. """Get the RRset with the given attributes in the specified section.
If the RRset is not found, None is returned. If the RRset is not found, None is returned.
*section*, an ``int`` section number, or one of the section *section*, an ``int`` section number, a ``str`` section name, or one of
attributes of this message. This specifies the the section attributes of this message. This specifies the
the section of the message to search. For example:: the section of the message to search. For example::
my_message.get_rrset(my_message.answer, name, rdclass, rdtype) my_message.get_rrset(my_message.answer, name, rdclass, rdtype)
my_message.get_rrset(dns.message.ANSWER, name, rdclass, rdtype) my_message.get_rrset(dns.message.ANSWER, name, rdclass, rdtype)
my_message.get_rrset("ANSWER", name, rdclass, rdtype)
*name*, a ``dns.name.Name``, the name of the RRset. *name*, a ``dns.name.Name`` or ``str``, the name of the RRset.
*rdclass*, an ``int``, the class of the RRset. *rdclass*, an ``int`` or ``str``, the class of the RRset.
*rdtype*, an ``int``, the type of the RRset. *rdtype*, an ``int`` or ``str``, the type of the RRset.
*covers*, an ``int`` or ``None``, the covers value of the RRset. *covers*, an ``int`` or ``str``, the covers value of the RRset.
The default is ``None``. The default is ``dns.rdatatype.NONE``.
*deleting*, an ``int`` or ``None``, the deleting value of the RRset. *deleting*, an ``int``, ``str``, or ``None``, the deleting value of the
The default is ``None``. RRset. The default is ``None``.
*create*, a ``bool``. If ``True``, create the RRset if it is not found. *create*, a ``bool``. If ``True``, create the RRset if it is not found.
The created RRset is appended to *section*. The created RRset is appended to *section*.
@ -450,12 +466,24 @@ class Message:
already. The default is ``False``. This is useful when creating already. The default is ``False``. This is useful when creating
DDNS Update messages, as order matters for them. DDNS Update messages, as order matters for them.
*idna_codec*, a ``dns.name.IDNACodec``, specifies the IDNA
encoder/decoder. If ``None``, the default IDNA 2003 encoder/decoder
is used.
Returns a ``dns.rrset.RRset object`` or ``None``. Returns a ``dns.rrset.RRset object`` or ``None``.
""" """
try: try:
rrset = self.find_rrset( rrset = self.find_rrset(
section, name, rdclass, rdtype, covers, deleting, create, force_unique section,
name,
rdclass,
rdtype,
covers,
deleting,
create,
force_unique,
idna_codec,
) )
except KeyError: except KeyError:
rrset = None rrset = None
@ -1708,13 +1736,11 @@ def make_query(
if isinstance(qname, str): if isinstance(qname, str):
qname = dns.name.from_text(qname, idna_codec=idna_codec) qname = dns.name.from_text(qname, idna_codec=idna_codec)
the_rdtype = dns.rdatatype.RdataType.make(rdtype) rdtype = dns.rdatatype.RdataType.make(rdtype)
the_rdclass = dns.rdataclass.RdataClass.make(rdclass) rdclass = dns.rdataclass.RdataClass.make(rdclass)
m = QueryMessage(id=id) m = QueryMessage(id=id)
m.flags = dns.flags.Flag(flags) m.flags = dns.flags.Flag(flags)
m.find_rrset( m.find_rrset(m.question, qname, rdclass, rdtype, create=True, force_unique=True)
m.question, qname, the_rdclass, the_rdtype, create=True, force_unique=True
)
# only pass keywords on to use_edns if they have been set to a # only pass keywords on to use_edns if they have been set to a
# non-None value. Setting a field will turn EDNS on if it hasn't # non-None value. Setting a field will turn EDNS on if it hasn't
# been configured. # been configured.

View file

@ -18,12 +18,10 @@
"""DNS Names. """DNS Names.
""" """
from typing import Any, Dict, Iterable, Optional, Tuple, Union
import copy import copy
import struct
import encodings.idna # type: ignore import encodings.idna # type: ignore
import struct
from typing import Any, Dict, Iterable, Optional, Tuple, Union
try: try:
import idna # type: ignore import idna # type: ignore
@ -33,10 +31,9 @@ except ImportError: # pragma: no cover
have_idna_2008 = False have_idna_2008 = False
import dns.enum import dns.enum
import dns.wire
import dns.exception import dns.exception
import dns.immutable import dns.immutable
import dns.wire
CompressType = Dict["Name", int] CompressType = Dict["Name", int]

329
lib/dns/nameserver.py Normal file
View file

@ -0,0 +1,329 @@
from typing import Optional, Union
from urllib.parse import urlparse
import dns.asyncbackend
import dns.asyncquery
import dns.inet
import dns.message
import dns.query
class Nameserver:
def __init__(self):
pass
def __str__(self):
raise NotImplementedError
def kind(self) -> str:
raise NotImplementedError
def is_always_max_size(self) -> bool:
raise NotImplementedError
def answer_nameserver(self) -> str:
raise NotImplementedError
def answer_port(self) -> int:
raise NotImplementedError
def query(
self,
request: dns.message.QueryMessage,
timeout: float,
source: Optional[str],
source_port: int,
max_size: bool,
one_rr_per_rrset: bool = False,
ignore_trailing: bool = False,
) -> dns.message.Message:
raise NotImplementedError
async def async_query(
self,
request: dns.message.QueryMessage,
timeout: float,
source: Optional[str],
source_port: int,
max_size: bool,
backend: dns.asyncbackend.Backend,
one_rr_per_rrset: bool = False,
ignore_trailing: bool = False,
) -> dns.message.Message:
raise NotImplementedError
class AddressAndPortNameserver(Nameserver):
def __init__(self, address: str, port: int):
super().__init__()
self.address = address
self.port = port
def kind(self) -> str:
raise NotImplementedError
def is_always_max_size(self) -> bool:
return False
def __str__(self):
ns_kind = self.kind()
return f"{ns_kind}:{self.address}@{self.port}"
def answer_nameserver(self) -> str:
return self.address
def answer_port(self) -> int:
return self.port
class Do53Nameserver(AddressAndPortNameserver):
def __init__(self, address: str, port: int = 53):
super().__init__(address, port)
def kind(self):
return "Do53"
def query(
self,
request: dns.message.QueryMessage,
timeout: float,
source: Optional[str],
source_port: int,
max_size: bool,
one_rr_per_rrset: bool = False,
ignore_trailing: bool = False,
) -> dns.message.Message:
if max_size:
response = dns.query.tcp(
request,
self.address,
timeout=timeout,
port=self.port,
source=source,
source_port=source_port,
one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing,
)
else:
response = dns.query.udp(
request,
self.address,
timeout=timeout,
port=self.port,
source=source,
source_port=source_port,
raise_on_truncation=True,
one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing,
)
return response
async def async_query(
self,
request: dns.message.QueryMessage,
timeout: float,
source: Optional[str],
source_port: int,
max_size: bool,
backend: dns.asyncbackend.Backend,
one_rr_per_rrset: bool = False,
ignore_trailing: bool = False,
) -> dns.message.Message:
if max_size:
response = await dns.asyncquery.tcp(
request,
self.address,
timeout=timeout,
port=self.port,
source=source,
source_port=source_port,
backend=backend,
one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing,
)
else:
response = await dns.asyncquery.udp(
request,
self.address,
timeout=timeout,
port=self.port,
source=source,
source_port=source_port,
raise_on_truncation=True,
backend=backend,
one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing,
)
return response
class DoHNameserver(Nameserver):
def __init__(self, url: str, bootstrap_address: Optional[str] = None):
super().__init__()
self.url = url
self.bootstrap_address = bootstrap_address
def kind(self):
return "DoH"
def is_always_max_size(self) -> bool:
return True
def __str__(self):
return self.url
def answer_nameserver(self) -> str:
return self.url
def answer_port(self) -> int:
port = urlparse(self.url).port
if port is None:
port = 443
return port
def query(
self,
request: dns.message.QueryMessage,
timeout: float,
source: Optional[str],
source_port: int,
max_size: bool = False,
one_rr_per_rrset: bool = False,
ignore_trailing: bool = False,
) -> dns.message.Message:
return dns.query.https(
request,
self.url,
timeout=timeout,
bootstrap_address=self.bootstrap_address,
one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing,
)
async def async_query(
self,
request: dns.message.QueryMessage,
timeout: float,
source: Optional[str],
source_port: int,
max_size: bool,
backend: dns.asyncbackend.Backend,
one_rr_per_rrset: bool = False,
ignore_trailing: bool = False,
) -> dns.message.Message:
return await dns.asyncquery.https(
request,
self.url,
timeout=timeout,
one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing,
)
class DoTNameserver(AddressAndPortNameserver):
def __init__(self, address: str, port: int = 853, hostname: Optional[str] = None):
super().__init__(address, port)
self.hostname = hostname
def kind(self):
return "DoT"
def query(
self,
request: dns.message.QueryMessage,
timeout: float,
source: Optional[str],
source_port: int,
max_size: bool = False,
one_rr_per_rrset: bool = False,
ignore_trailing: bool = False,
) -> dns.message.Message:
return dns.query.tls(
request,
self.address,
port=self.port,
timeout=timeout,
one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing,
server_hostname=self.hostname,
)
async def async_query(
self,
request: dns.message.QueryMessage,
timeout: float,
source: Optional[str],
source_port: int,
max_size: bool,
backend: dns.asyncbackend.Backend,
one_rr_per_rrset: bool = False,
ignore_trailing: bool = False,
) -> dns.message.Message:
return await dns.asyncquery.tls(
request,
self.address,
port=self.port,
timeout=timeout,
one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing,
server_hostname=self.hostname,
)
class DoQNameserver(AddressAndPortNameserver):
def __init__(
self,
address: str,
port: int = 853,
verify: Union[bool, str] = True,
server_hostname: Optional[str] = None,
):
super().__init__(address, port)
self.verify = verify
self.server_hostname = server_hostname
def kind(self):
return "DoQ"
def query(
self,
request: dns.message.QueryMessage,
timeout: float,
source: Optional[str],
source_port: int,
max_size: bool = False,
one_rr_per_rrset: bool = False,
ignore_trailing: bool = False,
) -> dns.message.Message:
return dns.query.quic(
request,
self.address,
port=self.port,
timeout=timeout,
one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing,
verify=self.verify,
server_hostname=self.server_hostname,
)
async def async_query(
self,
request: dns.message.QueryMessage,
timeout: float,
source: Optional[str],
source_port: int,
max_size: bool,
backend: dns.asyncbackend.Backend,
one_rr_per_rrset: bool = False,
ignore_trailing: bool = False,
) -> dns.message.Message:
return await dns.asyncquery.quic(
request,
self.address,
port=self.port,
timeout=timeout,
one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing,
verify=self.verify,
server_hostname=self.server_hostname,
)

View file

@ -17,19 +17,17 @@
"""DNS nodes. A node is a set of rdatasets.""" """DNS nodes. A node is a set of rdatasets."""
from typing import Any, Dict, Optional
import enum import enum
import io import io
from typing import Any, Dict, Optional
import dns.immutable import dns.immutable
import dns.name import dns.name
import dns.rdataclass import dns.rdataclass
import dns.rdataset import dns.rdataset
import dns.rdatatype import dns.rdatatype
import dns.rrset
import dns.renderer import dns.renderer
import dns.rrset
_cname_types = { _cname_types = {
dns.rdatatype.CNAME, dns.rdatatype.CNAME,

View file

@ -17,8 +17,6 @@
"""Talk to a DNS server.""" """Talk to a DNS server."""
from typing import Any, Dict, Optional, Tuple, Union
import base64 import base64
import contextlib import contextlib
import enum import enum
@ -28,12 +26,12 @@ import selectors
import socket import socket
import struct import struct
import time import time
import urllib.parse from typing import Any, Dict, Optional, Tuple, Union
import dns.exception import dns.exception
import dns.inet import dns.inet
import dns.name
import dns.message import dns.message
import dns.name
import dns.quic import dns.quic
import dns.rcode import dns.rcode
import dns.rdataclass import dns.rdataclass
@ -43,20 +41,32 @@ import dns.transaction
import dns.tsig import dns.tsig
import dns.xfr import dns.xfr
try:
import requests
from requests_toolbelt.adapters.source import SourceAddressAdapter
from requests_toolbelt.adapters.host_header_ssl import HostHeaderSSLAdapter
_have_requests = True def _remaining(expiration):
except ImportError: # pragma: no cover if expiration is None:
_have_requests = False return None
timeout = expiration - time.time()
if timeout <= 0.0:
raise dns.exception.Timeout
return timeout
def _expiration_for_this_attempt(timeout, expiration):
if expiration is None:
return None
return min(time.time() + timeout, expiration)
_have_httpx = False _have_httpx = False
_have_http2 = False _have_http2 = False
try: try:
import httpcore
import httpcore._backends.sync
import httpx import httpx
_CoreNetworkBackend = httpcore.NetworkBackend
_CoreSyncStream = httpcore._backends.sync.SyncStream
_have_httpx = True _have_httpx = True
try: try:
# See if http2 support is available. # See if http2 support is available.
@ -64,10 +74,87 @@ try:
_have_http2 = True _have_http2 = True
except Exception: except Exception:
pass pass
except ImportError: # pragma: no cover
pass
have_doh = _have_requests or _have_httpx class _NetworkBackend(_CoreNetworkBackend):
def __init__(self, resolver, local_port, bootstrap_address, family):
super().__init__()
self._local_port = local_port
self._resolver = resolver
self._bootstrap_address = bootstrap_address
self._family = family
def connect_tcp(
self, host, port, timeout, local_address, socket_options=None
): # pylint: disable=signature-differs
addresses = []
_, expiration = _compute_times(timeout)
if dns.inet.is_address(host):
addresses.append(host)
elif self._bootstrap_address is not None:
addresses.append(self._bootstrap_address)
else:
timeout = _remaining(expiration)
family = self._family
if local_address:
family = dns.inet.af_for_address(local_address)
answers = self._resolver.resolve_name(
host, family=family, lifetime=timeout
)
addresses = answers.addresses()
for address in addresses:
af = dns.inet.af_for_address(address)
if local_address is not None or self._local_port != 0:
source = dns.inet.low_level_address_tuple(
(local_address, self._local_port), af
)
else:
source = None
sock = _make_socket(af, socket.SOCK_STREAM, source)
attempt_expiration = _expiration_for_this_attempt(2.0, expiration)
try:
_connect(
sock,
dns.inet.low_level_address_tuple((address, port), af),
attempt_expiration,
)
return _CoreSyncStream(sock)
except Exception:
pass
raise httpcore.ConnectError
def connect_unix_socket(
self, path, timeout, socket_options=None
): # pylint: disable=signature-differs
raise NotImplementedError
class _HTTPTransport(httpx.HTTPTransport):
def __init__(
self,
*args,
local_port=0,
bootstrap_address=None,
resolver=None,
family=socket.AF_UNSPEC,
**kwargs,
):
if resolver is None:
# pylint: disable=import-outside-toplevel,redefined-outer-name
import dns.resolver
resolver = dns.resolver.Resolver()
super().__init__(*args, **kwargs)
self._pool._network_backend = _NetworkBackend(
resolver, local_port, bootstrap_address, family
)
except ImportError: # pragma: no cover
class _HTTPTransport: # type: ignore
def connect_tcp(self, host, port, timeout, local_address):
raise NotImplementedError
have_doh = _have_httpx
try: try:
import ssl import ssl
@ -88,7 +175,7 @@ except ImportError: # pragma: no cover
@classmethod @classmethod
def create_default_context(cls, *args, **kwargs): def create_default_context(cls, *args, **kwargs):
raise Exception("no ssl support") raise Exception("no ssl support") # pylint: disable=broad-exception-raised
# Function used to create a socket. Can be overridden if needed in special # Function used to create a socket. Can be overridden if needed in special
@ -105,7 +192,7 @@ class BadResponse(dns.exception.FormError):
class NoDOH(dns.exception.DNSException): class NoDOH(dns.exception.DNSException):
"""DNS over HTTPS (DOH) was requested but the requests module is not """DNS over HTTPS (DOH) was requested but the httpx module is not
available.""" available."""
@ -230,7 +317,7 @@ def _destination_and_source(
# We know the destination af, so source had better agree! # We know the destination af, so source had better agree!
if saf != af: if saf != af:
raise ValueError( raise ValueError(
"different address families for source " + "and destination" "different address families for source and destination"
) )
else: else:
# We didn't know the destination af, but we know the source, # We didn't know the destination af, but we know the source,
@ -240,11 +327,10 @@ def _destination_and_source(
# Caller has specified a source_port but not an address, so we # Caller has specified a source_port but not an address, so we
# need to return a source, and we need to use the appropriate # need to return a source, and we need to use the appropriate
# wildcard address as the address. # wildcard address as the address.
if af == socket.AF_INET: try:
source = "0.0.0.0" source = dns.inet.any_for_af(af)
elif af == socket.AF_INET6: except Exception:
source = "::" # we catch this and raise ValueError for backwards compatibility
else:
raise ValueError("source_port specified but address family is unknown") raise ValueError("source_port specified but address family is unknown")
# Convert high-level (address, port) tuples into low-level address # Convert high-level (address, port) tuples into low-level address
# tuples. # tuples.
@ -289,6 +375,8 @@ def https(
post: bool = True, post: bool = True,
bootstrap_address: Optional[str] = None, bootstrap_address: Optional[str] = None,
verify: Union[bool, str] = True, verify: Union[bool, str] = True,
resolver: Optional["dns.resolver.Resolver"] = None,
family: Optional[int] = socket.AF_UNSPEC,
) -> dns.message.Message: ) -> dns.message.Message:
"""Return the response obtained after sending a query via DNS-over-HTTPS. """Return the response obtained after sending a query via DNS-over-HTTPS.
@ -314,91 +402,78 @@ def https(
*ignore_trailing*, a ``bool``. If ``True``, ignore trailing junk at end of the *ignore_trailing*, a ``bool``. If ``True``, ignore trailing junk at end of the
received message. received message.
*session*, an ``httpx.Client`` or ``requests.session.Session``. If provided, the *session*, an ``httpx.Client``. If provided, the client session to use to send the
client/session to use to send the queries. queries.
*path*, a ``str``. If *where* is an IP address, then *path* will be used to *path*, a ``str``. If *where* is an IP address, then *path* will be used to
construct the URL to send the DNS query to. construct the URL to send the DNS query to.
*post*, a ``bool``. If ``True``, the default, POST method will be used. *post*, a ``bool``. If ``True``, the default, POST method will be used.
*bootstrap_address*, a ``str``, the IP address to use to bypass the system's DNS *bootstrap_address*, a ``str``, the IP address to use to bypass resolution.
resolver.
*verify*, a ``bool`` or ``str``. If a ``True``, then TLS certificate verification *verify*, a ``bool`` or ``str``. If a ``True``, then TLS certificate verification
of the server is done using the default CA bundle; if ``False``, then no of the server is done using the default CA bundle; if ``False``, then no
verification is done; if a `str` then it specifies the path to a certificate file or verification is done; if a `str` then it specifies the path to a certificate file or
directory which will be used for verification. directory which will be used for verification.
*resolver*, a ``dns.resolver.Resolver`` or ``None``, the resolver to use for
resolution of hostnames in URLs. If not specified, a new resolver with a default
configuration will be used; note this is *not* the default resolver as that resolver
might have been configured to use DoH causing a chicken-and-egg problem. This
parameter only has an effect if the HTTP library is httpx.
*family*, an ``int``, the address family. If socket.AF_UNSPEC (the default), both A
and AAAA records will be retrieved.
Returns a ``dns.message.Message``. Returns a ``dns.message.Message``.
""" """
if not have_doh: if not have_doh:
raise NoDOH("Neither httpx nor requests is available.") # pragma: no cover raise NoDOH # pragma: no cover
if session and not isinstance(session, httpx.Client):
_httpx_ok = _have_httpx raise ValueError("session parameter must be an httpx.Client")
wire = q.to_wire() wire = q.to_wire()
(af, _, source) = _destination_and_source(where, port, source, source_port, False) (af, _, the_source) = _destination_and_source(
transport_adapter = None where, port, source, source_port, False
)
transport = None transport = None
headers = {"accept": "application/dns-message"} headers = {"accept": "application/dns-message"}
if af is not None: if af is not None and dns.inet.is_address(where):
if af == socket.AF_INET: if af == socket.AF_INET:
url = "https://{}:{}{}".format(where, port, path) url = "https://{}:{}{}".format(where, port, path)
elif af == socket.AF_INET6: elif af == socket.AF_INET6:
url = "https://[{}]:{}{}".format(where, port, path) url = "https://[{}]:{}{}".format(where, port, path)
elif bootstrap_address is not None:
_httpx_ok = False
split_url = urllib.parse.urlsplit(where)
if split_url.hostname is None:
raise ValueError("DoH URL has no hostname")
headers["Host"] = split_url.hostname
url = where.replace(split_url.hostname, bootstrap_address)
if _have_requests:
transport_adapter = HostHeaderSSLAdapter()
else: else:
url = where url = where
if source is not None:
# set source port and source address
if _have_httpx:
if source_port == 0:
transport = httpx.HTTPTransport(local_address=source[0], verify=verify)
else:
_httpx_ok = False
if _have_requests:
transport_adapter = SourceAddressAdapter(source)
if session: # set source port and source address
if _have_httpx:
_is_httpx = isinstance(session, httpx.Client) if the_source is None:
else: local_address = None
_is_httpx = False local_port = 0
if _is_httpx and not _httpx_ok:
raise NoDOH(
"Session is httpx, but httpx cannot be used for "
"the requested operation."
)
else: else:
_is_httpx = _httpx_ok local_address = the_source[0]
local_port = the_source[1]
if not _httpx_ok and not _have_requests: transport = _HTTPTransport(
raise NoDOH( local_address=local_address,
"Cannot use httpx for this operation, and requests is not available." http1=True,
) http2=_have_http2,
verify=verify,
local_port=local_port,
bootstrap_address=bootstrap_address,
resolver=resolver,
family=family,
)
if session: if session:
cm: contextlib.AbstractContextManager = contextlib.nullcontext(session) cm: contextlib.AbstractContextManager = contextlib.nullcontext(session)
elif _is_httpx: else:
cm = httpx.Client( cm = httpx.Client(
http1=True, http2=_have_http2, verify=verify, transport=transport http1=True, http2=_have_http2, verify=verify, transport=transport
) )
else:
cm = requests.sessions.Session()
with cm as session: with cm as session:
if transport_adapter and not _is_httpx:
session.mount(url, transport_adapter)
# see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH # see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH
# GET and POST examples # GET and POST examples
if post: if post:
@ -408,29 +483,13 @@ def https(
"content-length": str(len(wire)), "content-length": str(len(wire)),
} }
) )
if _is_httpx: response = session.post(url, headers=headers, content=wire, timeout=timeout)
response = session.post(
url, headers=headers, content=wire, timeout=timeout
)
else:
response = session.post(
url, headers=headers, data=wire, timeout=timeout, verify=verify
)
else: else:
wire = base64.urlsafe_b64encode(wire).rstrip(b"=") wire = base64.urlsafe_b64encode(wire).rstrip(b"=")
if _is_httpx: twire = wire.decode() # httpx does a repr() if we give it bytes
twire = wire.decode() # httpx does a repr() if we give it bytes response = session.get(
response = session.get( url, headers=headers, timeout=timeout, params={"dns": twire}
url, headers=headers, timeout=timeout, params={"dns": twire} )
)
else:
response = session.get(
url,
headers=headers,
timeout=timeout,
verify=verify,
params={"dns": wire},
)
# see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH # see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH
# status codes # status codes
@ -1070,6 +1129,7 @@ def quic(
ignore_trailing: bool = False, ignore_trailing: bool = False,
connection: Optional[dns.quic.SyncQuicConnection] = None, connection: Optional[dns.quic.SyncQuicConnection] = None,
verify: Union[bool, str] = True, verify: Union[bool, str] = True,
server_hostname: Optional[str] = None,
) -> dns.message.Message: ) -> dns.message.Message:
"""Return the response obtained after sending a query via DNS-over-QUIC. """Return the response obtained after sending a query via DNS-over-QUIC.
@ -1101,6 +1161,10 @@ def quic(
verification is done; if a `str` then it specifies the path to a certificate file or verification is done; if a `str` then it specifies the path to a certificate file or
directory which will be used for verification. directory which will be used for verification.
*server_hostname*, a ``str`` containing the server's hostname. The
default is ``None``, which means that no hostname is known, and if an
SSL context is created, hostname checking will be disabled.
Returns a ``dns.message.Message``. Returns a ``dns.message.Message``.
""" """
@ -1115,16 +1179,18 @@ def quic(
manager: contextlib.AbstractContextManager = contextlib.nullcontext(None) manager: contextlib.AbstractContextManager = contextlib.nullcontext(None)
the_connection = connection the_connection = connection
else: else:
manager = dns.quic.SyncQuicManager(verify_mode=verify) manager = dns.quic.SyncQuicManager(
verify_mode=verify, server_name=server_hostname
)
the_manager = manager # for type checking happiness the_manager = manager # for type checking happiness
with manager: with manager:
if not connection: if not connection:
the_connection = the_manager.connect(where, port, source, source_port) the_connection = the_manager.connect(where, port, source, source_port)
start = time.time() (start, expiration) = _compute_times(timeout)
with the_connection.make_stream() as stream: with the_connection.make_stream(timeout) as stream:
stream.send(wire, True) stream.send(wire, True)
wire = stream.receive(timeout) wire = stream.receive(_remaining(expiration))
finish = time.time() finish = time.time()
r = dns.message.from_wire( r = dns.message.from_wire(
wire, wire,

View file

@ -5,13 +5,13 @@ try:
import dns.asyncbackend import dns.asyncbackend
from dns._asyncbackend import NullContext from dns._asyncbackend import NullContext
from dns.quic._sync import SyncQuicManager, SyncQuicConnection, SyncQuicStream
from dns.quic._asyncio import ( from dns.quic._asyncio import (
AsyncioQuicManager,
AsyncioQuicConnection, AsyncioQuicConnection,
AsyncioQuicManager,
AsyncioQuicStream, AsyncioQuicStream,
) )
from dns.quic._common import AsyncQuicConnection, AsyncQuicManager from dns.quic._common import AsyncQuicConnection, AsyncQuicManager
from dns.quic._sync import SyncQuicConnection, SyncQuicManager, SyncQuicStream
have_quic = True have_quic = True
@ -33,9 +33,10 @@ try:
try: try:
import trio import trio
from dns.quic._trio import ( # pylint: disable=ungrouped-imports from dns.quic._trio import ( # pylint: disable=ungrouped-imports
TrioQuicManager,
TrioQuicConnection, TrioQuicConnection,
TrioQuicManager,
TrioQuicStream, TrioQuicStream,
) )

View file

@ -9,14 +9,16 @@ import time
import aioquic.quic.configuration # type: ignore import aioquic.quic.configuration # type: ignore
import aioquic.quic.connection # type: ignore import aioquic.quic.connection # type: ignore
import aioquic.quic.events # type: ignore import aioquic.quic.events # type: ignore
import dns.inet
import dns.asyncbackend
import dns.asyncbackend
import dns.exception
import dns.inet
from dns.quic._common import ( from dns.quic._common import (
BaseQuicStream, QUIC_MAX_DATAGRAM,
AsyncQuicConnection, AsyncQuicConnection,
AsyncQuicManager, AsyncQuicManager,
QUIC_MAX_DATAGRAM, BaseQuicStream,
UnexpectedEOF,
) )
@ -30,15 +32,15 @@ class AsyncioQuicStream(BaseQuicStream):
await self._wake_up.wait() await self._wake_up.wait()
async def wait_for(self, amount, expiration): async def wait_for(self, amount, expiration):
timeout = self._timeout_from_expiration(expiration)
while True: while True:
timeout = self._timeout_from_expiration(expiration)
if self._buffer.have(amount): if self._buffer.have(amount):
return return
self._expecting = amount self._expecting = amount
try: try:
await asyncio.wait_for(self._wait_for_wake_up(), timeout) await asyncio.wait_for(self._wait_for_wake_up(), timeout)
except Exception: except TimeoutError:
pass raise dns.exception.Timeout
self._expecting = 0 self._expecting = 0
async def receive(self, timeout=None): async def receive(self, timeout=None):
@ -86,8 +88,10 @@ class AsyncioQuicConnection(AsyncQuicConnection):
try: try:
af = dns.inet.af_for_address(self._address) af = dns.inet.af_for_address(self._address)
backend = dns.asyncbackend.get_backend("asyncio") backend = dns.asyncbackend.get_backend("asyncio")
# Note that peer is a low-level address tuple, but make_socket() wants
# a high-level address tuple, so we convert.
self._socket = await backend.make_socket( self._socket = await backend.make_socket(
af, socket.SOCK_DGRAM, 0, self._source, self._peer af, socket.SOCK_DGRAM, 0, self._source, (self._peer[0], self._peer[1])
) )
self._socket_created.set() self._socket_created.set()
async with self._socket: async with self._socket:
@ -106,6 +110,11 @@ class AsyncioQuicConnection(AsyncQuicConnection):
self._wake_timer.notify_all() self._wake_timer.notify_all()
except Exception: except Exception:
pass pass
finally:
self._done = True
async with self._wake_timer:
self._wake_timer.notify_all()
self._handshake_complete.set()
async def _wait_for_wake_timer(self): async def _wait_for_wake_timer(self):
async with self._wake_timer: async with self._wake_timer:
@ -115,7 +124,7 @@ class AsyncioQuicConnection(AsyncQuicConnection):
await self._socket_created.wait() await self._socket_created.wait()
while not self._done: while not self._done:
datagrams = self._connection.datagrams_to_send(time.time()) datagrams = self._connection.datagrams_to_send(time.time())
for (datagram, address) in datagrams: for datagram, address in datagrams:
assert address == self._peer[0] assert address == self._peer[0]
await self._socket.sendto(datagram, self._peer, None) await self._socket.sendto(datagram, self._peer, None)
(expiration, interval) = self._get_timer_values() (expiration, interval) = self._get_timer_values()
@ -160,8 +169,13 @@ class AsyncioQuicConnection(AsyncQuicConnection):
self._receiver_task = asyncio.Task(self._receiver()) self._receiver_task = asyncio.Task(self._receiver())
self._sender_task = asyncio.Task(self._sender()) self._sender_task = asyncio.Task(self._sender())
async def make_stream(self): async def make_stream(self, timeout=None):
await self._handshake_complete.wait() try:
await asyncio.wait_for(self._handshake_complete.wait(), timeout)
except TimeoutError:
raise dns.exception.Timeout
if self._done:
raise UnexpectedEOF
stream_id = self._connection.get_next_available_stream_id(False) stream_id = self._connection.get_next_available_stream_id(False)
stream = AsyncioQuicStream(self, stream_id) stream = AsyncioQuicStream(self, stream_id)
self._streams[stream_id] = stream self._streams[stream_id] = stream
@ -172,6 +186,9 @@ class AsyncioQuicConnection(AsyncQuicConnection):
self._manager.closed(self._peer[0], self._peer[1]) self._manager.closed(self._peer[0], self._peer[1])
self._closed = True self._closed = True
self._connection.close() self._connection.close()
# sender might be blocked on this, so set it
self._socket_created.set()
await self._socket.close()
async with self._wake_timer: async with self._wake_timer:
self._wake_timer.notify_all() self._wake_timer.notify_all()
try: try:
@ -185,8 +202,8 @@ class AsyncioQuicConnection(AsyncQuicConnection):
class AsyncioQuicManager(AsyncQuicManager): class AsyncioQuicManager(AsyncQuicManager):
def __init__(self, conf=None, verify_mode=ssl.CERT_REQUIRED): def __init__(self, conf=None, verify_mode=ssl.CERT_REQUIRED, server_name=None):
super().__init__(conf, verify_mode, AsyncioQuicConnection) super().__init__(conf, verify_mode, AsyncioQuicConnection, server_name)
def connect(self, address, port=853, source=None, source_port=0): def connect(self, address, port=853, source=None, source_port=0):
(connection, start) = self._connect(address, port, source, source_port) (connection, start) = self._connect(address, port, source, source_port)
@ -198,7 +215,7 @@ class AsyncioQuicManager(AsyncQuicManager):
return self return self
async def __aexit__(self, exc_type, exc_val, exc_tb): async def __aexit__(self, exc_type, exc_val, exc_tb):
# Copy the itertor into a list as exiting things will mutate the connections # Copy the iterator into a list as exiting things will mutate the connections
# table. # table.
connections = list(self._connections.values()) connections = list(self._connections.values())
for connection in connections: for connection in connections:

View file

@ -3,13 +3,12 @@
import socket import socket
import struct import struct
import time import time
from typing import Any, Optional
from typing import Any
import aioquic.quic.configuration # type: ignore import aioquic.quic.configuration # type: ignore
import aioquic.quic.connection # type: ignore import aioquic.quic.connection # type: ignore
import dns.inet
import dns.inet
QUIC_MAX_DATAGRAM = 2048 QUIC_MAX_DATAGRAM = 2048
@ -135,12 +134,12 @@ class BaseQuicConnection:
class AsyncQuicConnection(BaseQuicConnection): class AsyncQuicConnection(BaseQuicConnection):
async def make_stream(self) -> Any: async def make_stream(self, timeout: Optional[float] = None) -> Any:
pass pass
class BaseQuicManager: class BaseQuicManager:
def __init__(self, conf, verify_mode, connection_factory): def __init__(self, conf, verify_mode, connection_factory, server_name=None):
self._connections = {} self._connections = {}
self._connection_factory = connection_factory self._connection_factory = connection_factory
if conf is None: if conf is None:
@ -151,6 +150,7 @@ class BaseQuicManager:
conf = aioquic.quic.configuration.QuicConfiguration( conf = aioquic.quic.configuration.QuicConfiguration(
alpn_protocols=["doq", "doq-i03"], alpn_protocols=["doq", "doq-i03"],
verify_mode=verify_mode, verify_mode=verify_mode,
server_name=server_name,
) )
if verify_path is not None: if verify_path is not None:
conf.load_verify_locations(verify_path) conf.load_verify_locations(verify_path)

View file

@ -1,8 +1,8 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
import selectors
import socket import socket
import ssl import ssl
import selectors
import struct import struct
import threading import threading
import time import time
@ -10,13 +10,15 @@ import time
import aioquic.quic.configuration # type: ignore import aioquic.quic.configuration # type: ignore
import aioquic.quic.connection # type: ignore import aioquic.quic.connection # type: ignore
import aioquic.quic.events # type: ignore import aioquic.quic.events # type: ignore
import dns.inet
import dns.exception
import dns.inet
from dns.quic._common import ( from dns.quic._common import (
BaseQuicStream, QUIC_MAX_DATAGRAM,
BaseQuicConnection, BaseQuicConnection,
BaseQuicManager, BaseQuicManager,
QUIC_MAX_DATAGRAM, BaseQuicStream,
UnexpectedEOF,
) )
# Avoid circularity with dns.query # Avoid circularity with dns.query
@ -33,14 +35,15 @@ class SyncQuicStream(BaseQuicStream):
self._lock = threading.Lock() self._lock = threading.Lock()
def wait_for(self, amount, expiration): def wait_for(self, amount, expiration):
timeout = self._timeout_from_expiration(expiration)
while True: while True:
timeout = self._timeout_from_expiration(expiration)
with self._lock: with self._lock:
if self._buffer.have(amount): if self._buffer.have(amount):
return return
self._expecting = amount self._expecting = amount
with self._wake_up: with self._wake_up:
self._wake_up.wait(timeout) if not self._wake_up.wait(timeout):
raise dns.exception.Timeout
self._expecting = 0 self._expecting = 0
def receive(self, timeout=None): def receive(self, timeout=None):
@ -114,24 +117,30 @@ class SyncQuicConnection(BaseQuicConnection):
return return
def _worker(self): def _worker(self):
sel = _selector_class() try:
sel.register(self._socket, selectors.EVENT_READ, self._read) sel = _selector_class()
sel.register(self._receive_wakeup, selectors.EVENT_READ, self._drain_wakeup) sel.register(self._socket, selectors.EVENT_READ, self._read)
while not self._done: sel.register(self._receive_wakeup, selectors.EVENT_READ, self._drain_wakeup)
(expiration, interval) = self._get_timer_values(False) while not self._done:
items = sel.select(interval) (expiration, interval) = self._get_timer_values(False)
for (key, _) in items: items = sel.select(interval)
key.data() for key, _ in items:
key.data()
with self._lock:
self._handle_timer(expiration)
datagrams = self._connection.datagrams_to_send(time.time())
for datagram, _ in datagrams:
try:
self._socket.send(datagram)
except BlockingIOError:
# we let QUIC handle any lossage
pass
self._handle_events()
finally:
with self._lock: with self._lock:
self._handle_timer(expiration) self._done = True
datagrams = self._connection.datagrams_to_send(time.time()) # Ensure anyone waiting for this gets woken up.
for (datagram, _) in datagrams: self._handshake_complete.set()
try:
self._socket.send(datagram)
except BlockingIOError:
# we let QUIC handle any lossage
pass
self._handle_events()
def _handle_events(self): def _handle_events(self):
while True: while True:
@ -163,9 +172,12 @@ class SyncQuicConnection(BaseQuicConnection):
self._worker_thread = threading.Thread(target=self._worker) self._worker_thread = threading.Thread(target=self._worker)
self._worker_thread.start() self._worker_thread.start()
def make_stream(self): def make_stream(self, timeout=None):
self._handshake_complete.wait() if not self._handshake_complete.wait(timeout):
raise dns.exception.Timeout
with self._lock: with self._lock:
if self._done:
raise UnexpectedEOF
stream_id = self._connection.get_next_available_stream_id(False) stream_id = self._connection.get_next_available_stream_id(False)
stream = SyncQuicStream(self, stream_id) stream = SyncQuicStream(self, stream_id)
self._streams[stream_id] = stream self._streams[stream_id] = stream
@ -187,8 +199,8 @@ class SyncQuicConnection(BaseQuicConnection):
class SyncQuicManager(BaseQuicManager): class SyncQuicManager(BaseQuicManager):
def __init__(self, conf=None, verify_mode=ssl.CERT_REQUIRED): def __init__(self, conf=None, verify_mode=ssl.CERT_REQUIRED, server_name=None):
super().__init__(conf, verify_mode, SyncQuicConnection) super().__init__(conf, verify_mode, SyncQuicConnection, server_name)
self._lock = threading.Lock() self._lock = threading.Lock()
def connect(self, address, port=853, source=None, source_port=0): def connect(self, address, port=853, source=None, source_port=0):
@ -206,7 +218,7 @@ class SyncQuicManager(BaseQuicManager):
return self return self
def __exit__(self, exc_type, exc_val, exc_tb): def __exit__(self, exc_type, exc_val, exc_tb):
# Copy the itertor into a list as exiting things will mutate the connections # Copy the iterator into a list as exiting things will mutate the connections
# table. # table.
connections = list(self._connections.values()) connections = list(self._connections.values())
for connection in connections: for connection in connections:

View file

@ -10,13 +10,15 @@ import aioquic.quic.connection # type: ignore
import aioquic.quic.events # type: ignore import aioquic.quic.events # type: ignore
import trio import trio
import dns.exception
import dns.inet import dns.inet
from dns._asyncbackend import NullContext from dns._asyncbackend import NullContext
from dns.quic._common import ( from dns.quic._common import (
BaseQuicStream, QUIC_MAX_DATAGRAM,
AsyncQuicConnection, AsyncQuicConnection,
AsyncQuicManager, AsyncQuicManager,
QUIC_MAX_DATAGRAM, BaseQuicStream,
UnexpectedEOF,
) )
@ -44,6 +46,7 @@ class TrioQuicStream(BaseQuicStream):
(size,) = struct.unpack("!H", self._buffer.get(2)) (size,) = struct.unpack("!H", self._buffer.get(2))
await self.wait_for(size) await self.wait_for(size)
return self._buffer.get(size) return self._buffer.get(size)
raise dns.exception.Timeout
async def send(self, datagram, is_end=False): async def send(self, datagram, is_end=False):
data = self._encapsulate(datagram) data = self._encapsulate(datagram)
@ -80,20 +83,26 @@ class TrioQuicConnection(AsyncQuicConnection):
self._worker_scope = None self._worker_scope = None
async def _worker(self): async def _worker(self):
await self._socket.connect(self._peer) try:
while not self._done: await self._socket.connect(self._peer)
(expiration, interval) = self._get_timer_values(False) while not self._done:
with trio.CancelScope( (expiration, interval) = self._get_timer_values(False)
deadline=trio.current_time() + interval with trio.CancelScope(
) as self._worker_scope: deadline=trio.current_time() + interval
datagram = await self._socket.recv(QUIC_MAX_DATAGRAM) ) as self._worker_scope:
self._connection.receive_datagram(datagram, self._peer[0], time.time()) datagram = await self._socket.recv(QUIC_MAX_DATAGRAM)
self._worker_scope = None self._connection.receive_datagram(
self._handle_timer(expiration) datagram, self._peer[0], time.time()
datagrams = self._connection.datagrams_to_send(time.time()) )
for (datagram, _) in datagrams: self._worker_scope = None
await self._socket.send(datagram) self._handle_timer(expiration)
await self._handle_events() datagrams = self._connection.datagrams_to_send(time.time())
for datagram, _ in datagrams:
await self._socket.send(datagram)
await self._handle_events()
finally:
self._done = True
self._handshake_complete.set()
async def _handle_events(self): async def _handle_events(self):
count = 0 count = 0
@ -130,12 +139,20 @@ class TrioQuicConnection(AsyncQuicConnection):
nursery.start_soon(self._worker) nursery.start_soon(self._worker)
self._run_done.set() self._run_done.set()
async def make_stream(self): async def make_stream(self, timeout=None):
await self._handshake_complete.wait() if timeout is None:
stream_id = self._connection.get_next_available_stream_id(False) context = NullContext(None)
stream = TrioQuicStream(self, stream_id) else:
self._streams[stream_id] = stream context = trio.move_on_after(timeout)
return stream with context:
await self._handshake_complete.wait()
if self._done:
raise UnexpectedEOF
stream_id = self._connection.get_next_available_stream_id(False)
stream = TrioQuicStream(self, stream_id)
self._streams[stream_id] = stream
return stream
raise dns.exception.Timeout
async def close(self): async def close(self):
if not self._closed: if not self._closed:
@ -148,8 +165,10 @@ class TrioQuicConnection(AsyncQuicConnection):
class TrioQuicManager(AsyncQuicManager): class TrioQuicManager(AsyncQuicManager):
def __init__(self, nursery, conf=None, verify_mode=ssl.CERT_REQUIRED): def __init__(
super().__init__(conf, verify_mode, TrioQuicConnection) self, nursery, conf=None, verify_mode=ssl.CERT_REQUIRED, server_name=None
):
super().__init__(conf, verify_mode, TrioQuicConnection, server_name)
self._nursery = nursery self._nursery = nursery
def connect(self, address, port=853, source=None, source_port=0): def connect(self, address, port=853, source=None, source_port=0):
@ -162,7 +181,7 @@ class TrioQuicManager(AsyncQuicManager):
return self return self
async def __aexit__(self, exc_type, exc_val, exc_tb): async def __aexit__(self, exc_type, exc_val, exc_tb):
# Copy the itertor into a list as exiting things will mutate the connections # Copy the iterator into a list as exiting things will mutate the connections
# table. # table.
connections = list(self._connections.values()) connections = list(self._connections.values())
for connection in connections: for connection in connections:

View file

@ -17,17 +17,15 @@
"""DNS rdata.""" """DNS rdata."""
from typing import Any, Dict, Optional, Tuple, Union
from importlib import import_module
import base64 import base64
import binascii import binascii
import io
import inspect import inspect
import io
import itertools import itertools
import random import random
from importlib import import_module
from typing import Any, Dict, Optional, Tuple, Union
import dns.wire
import dns.exception import dns.exception
import dns.immutable import dns.immutable
import dns.ipv4 import dns.ipv4
@ -37,6 +35,7 @@ import dns.rdataclass
import dns.rdatatype import dns.rdatatype
import dns.tokenizer import dns.tokenizer
import dns.ttl import dns.ttl
import dns.wire
_chunksize = 32 _chunksize = 32
@ -358,7 +357,6 @@ class Rdata:
or self.rdclass != other.rdclass or self.rdclass != other.rdclass
or self.rdtype != other.rdtype or self.rdtype != other.rdtype
): ):
return NotImplemented return NotImplemented
return self._cmp(other) < 0 return self._cmp(other) < 0
@ -881,16 +879,11 @@ def register_type(
it applies to all classes. it applies to all classes.
""" """
the_rdtype = dns.rdatatype.RdataType.make(rdtype) rdtype = dns.rdatatype.RdataType.make(rdtype)
existing_cls = get_rdata_class(rdclass, the_rdtype) existing_cls = get_rdata_class(rdclass, rdtype)
if existing_cls != GenericRdata or dns.rdatatype.is_metatype(the_rdtype): if existing_cls != GenericRdata or dns.rdatatype.is_metatype(rdtype):
raise RdatatypeExists(rdclass=rdclass, rdtype=the_rdtype) raise RdatatypeExists(rdclass=rdclass, rdtype=rdtype)
try: _rdata_classes[(rdclass, rdtype)] = getattr(
if dns.rdatatype.RdataType(the_rdtype).name != rdtype_text:
raise RdatatypeExists(rdclass=rdclass, rdtype=the_rdtype)
except ValueError:
pass
_rdata_classes[(rdclass, the_rdtype)] = getattr(
implementation, rdtype_text.replace("-", "_") implementation, rdtype_text.replace("-", "_")
) )
dns.rdatatype.register_type(the_rdtype, rdtype_text, is_singleton) dns.rdatatype.register_type(rdtype, rdtype_text, is_singleton)

View file

@ -17,18 +17,17 @@
"""DNS rdatasets (an rdataset is a set of rdatas of a given type and class)""" """DNS rdatasets (an rdataset is a set of rdatas of a given type and class)"""
from typing import Any, cast, Collection, Dict, List, Optional, Union
import io import io
import random import random
import struct import struct
from typing import Any, Collection, Dict, List, Optional, Union, cast
import dns.exception import dns.exception
import dns.immutable import dns.immutable
import dns.name import dns.name
import dns.rdatatype
import dns.rdataclass
import dns.rdata import dns.rdata
import dns.rdataclass
import dns.rdatatype
import dns.set import dns.set
import dns.ttl import dns.ttl
@ -471,9 +470,9 @@ def from_text_list(
Returns a ``dns.rdataset.Rdataset`` object. Returns a ``dns.rdataset.Rdataset`` object.
""" """
the_rdclass = dns.rdataclass.RdataClass.make(rdclass) rdclass = dns.rdataclass.RdataClass.make(rdclass)
the_rdtype = dns.rdatatype.RdataType.make(rdtype) rdtype = dns.rdatatype.RdataType.make(rdtype)
r = Rdataset(the_rdclass, the_rdtype) r = Rdataset(rdclass, rdtype)
r.update_ttl(ttl) r.update_ttl(ttl)
for t in text_rdatas: for t in text_rdatas:
rd = dns.rdata.from_text( rd = dns.rdata.from_text(

View file

@ -15,8 +15,8 @@
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import dns.rdtypes.mxbase
import dns.immutable import dns.immutable
import dns.rdtypes.mxbase
@dns.immutable.immutable @dns.immutable.immutable

View file

@ -15,8 +15,8 @@
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import dns.rdtypes.txtbase
import dns.immutable import dns.immutable
import dns.rdtypes.txtbase
@dns.immutable.immutable @dns.immutable.immutable

View file

@ -15,15 +15,15 @@
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import dns.rdtypes.dnskeybase # lgtm[py/import-and-import-from]
import dns.immutable import dns.immutable
import dns.rdtypes.dnskeybase # lgtm[py/import-and-import-from]
# pylint: disable=unused-import # pylint: disable=unused-import
from dns.rdtypes.dnskeybase import ( from dns.rdtypes.dnskeybase import ( # noqa: F401 lgtm[py/unused-import]
SEP,
REVOKE, REVOKE,
SEP,
ZONE, ZONE,
) # noqa: F401 lgtm[py/unused-import] )
# pylint: enable=unused-import # pylint: enable=unused-import

View file

@ -15,8 +15,8 @@
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import dns.rdtypes.dsbase
import dns.immutable import dns.immutable
import dns.rdtypes.dsbase
@dns.immutable.immutable @dns.immutable.immutable

View file

@ -15,12 +15,12 @@
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import struct
import base64 import base64
import struct
import dns.dnssectypes
import dns.exception import dns.exception
import dns.immutable import dns.immutable
import dns.dnssectypes
import dns.rdata import dns.rdata
import dns.tokenizer import dns.tokenizer

View file

@ -15,8 +15,8 @@
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import dns.rdtypes.nsbase
import dns.immutable import dns.immutable
import dns.rdtypes.nsbase
@dns.immutable.immutable @dns.immutable.immutable

View file

@ -19,9 +19,9 @@ import struct
import dns.exception import dns.exception
import dns.immutable import dns.immutable
import dns.name
import dns.rdata import dns.rdata
import dns.rdatatype import dns.rdatatype
import dns.name
import dns.rdtypes.util import dns.rdtypes.util

View file

@ -15,8 +15,8 @@
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import dns.rdtypes.dsbase
import dns.immutable import dns.immutable
import dns.rdtypes.dsbase
@dns.immutable.immutable @dns.immutable.immutable

View file

@ -15,8 +15,8 @@
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import dns.rdtypes.nsbase
import dns.immutable import dns.immutable
import dns.rdtypes.nsbase
@dns.immutable.immutable @dns.immutable.immutable

View file

@ -15,15 +15,15 @@
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import dns.rdtypes.dnskeybase # lgtm[py/import-and-import-from]
import dns.immutable import dns.immutable
import dns.rdtypes.dnskeybase # lgtm[py/import-and-import-from]
# pylint: disable=unused-import # pylint: disable=unused-import
from dns.rdtypes.dnskeybase import ( from dns.rdtypes.dnskeybase import ( # noqa: F401 lgtm[py/unused-import]
SEP,
REVOKE, REVOKE,
SEP,
ZONE, ZONE,
) # noqa: F401 lgtm[py/unused-import] )
# pylint: enable=unused-import # pylint: enable=unused-import

View file

@ -15,8 +15,8 @@
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import dns.rdtypes.dsbase
import dns.immutable import dns.immutable
import dns.rdtypes.dsbase
@dns.immutable.immutable @dns.immutable.immutable

View file

@ -16,8 +16,8 @@
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import dns.rdtypes.euibase
import dns.immutable import dns.immutable
import dns.rdtypes.euibase
@dns.immutable.immutable @dns.immutable.immutable

View file

@ -16,8 +16,8 @@
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import dns.rdtypes.euibase
import dns.immutable import dns.immutable
import dns.rdtypes.euibase
@dns.immutable.immutable @dns.immutable.immutable

View file

@ -15,9 +15,9 @@
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import struct
import base64 import base64
import binascii import binascii
import struct
import dns.exception import dns.exception
import dns.immutable import dns.immutable

View file

@ -21,7 +21,6 @@ import dns.exception
import dns.immutable import dns.immutable
import dns.rdata import dns.rdata
_pows = tuple(10**i for i in range(0, 11)) _pows = tuple(10**i for i in range(0, 11))
# default values are in centimeters # default values are in centimeters
@ -40,7 +39,7 @@ def _exponent_of(what, desc):
if what == 0: if what == 0:
return 0 return 0
exp = None exp = None
for (i, pow) in enumerate(_pows): for i, pow in enumerate(_pows):
if what < pow: if what < pow:
exp = i - 1 exp = i - 1
break break

View file

@ -15,8 +15,8 @@
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import dns.rdtypes.mxbase
import dns.immutable import dns.immutable
import dns.rdtypes.mxbase
@dns.immutable.immutable @dns.immutable.immutable

View file

@ -15,8 +15,8 @@
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import dns.rdtypes.txtbase
import dns.immutable import dns.immutable
import dns.rdtypes.txtbase
@dns.immutable.immutable @dns.immutable.immutable

View file

@ -15,8 +15,8 @@
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import dns.rdtypes.nsbase
import dns.immutable import dns.immutable
import dns.rdtypes.nsbase
@dns.immutable.immutable @dns.immutable.immutable

View file

@ -17,9 +17,9 @@
import dns.exception import dns.exception
import dns.immutable import dns.immutable
import dns.name
import dns.rdata import dns.rdata
import dns.rdatatype import dns.rdatatype
import dns.name
import dns.rdtypes.util import dns.rdtypes.util

View file

@ -25,7 +25,6 @@ import dns.rdata
import dns.rdatatype import dns.rdatatype
import dns.rdtypes.util import dns.rdtypes.util
b32_hex_to_normal = bytes.maketrans( b32_hex_to_normal = bytes.maketrans(
b"0123456789ABCDEFGHIJKLMNOPQRSTUV", b"ABCDEFGHIJKLMNOPQRSTUVWXYZ234567" b"0123456789ABCDEFGHIJKLMNOPQRSTUV", b"ABCDEFGHIJKLMNOPQRSTUVWXYZ234567"
) )
@ -67,6 +66,7 @@ class NSEC3(dns.rdata.Rdata):
def to_text(self, origin=None, relativize=True, **kw): def to_text(self, origin=None, relativize=True, **kw):
next = base64.b32encode(self.next).translate(b32_normal_to_hex).lower().decode() next = base64.b32encode(self.next).translate(b32_normal_to_hex).lower().decode()
next = next.rstrip("=")
if self.salt == b"": if self.salt == b"":
salt = "-" salt = "-"
else: else:
@ -94,6 +94,10 @@ class NSEC3(dns.rdata.Rdata):
else: else:
salt = binascii.unhexlify(salt.encode("ascii")) salt = binascii.unhexlify(salt.encode("ascii"))
next = tok.get_string().encode("ascii").upper().translate(b32_hex_to_normal) next = tok.get_string().encode("ascii").upper().translate(b32_hex_to_normal)
if next.endswith(b"="):
raise binascii.Error("Incorrect padding")
if len(next) % 8 != 0:
next += b"=" * (8 - len(next) % 8)
next = base64.b32decode(next) next = base64.b32decode(next)
bitmap = Bitmap.from_text(tok) bitmap = Bitmap.from_text(tok)
return cls(rdclass, rdtype, algorithm, flags, iterations, salt, next, bitmap) return cls(rdclass, rdtype, algorithm, flags, iterations, salt, next, bitmap)

View file

@ -15,8 +15,8 @@
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import struct
import binascii import binascii
import struct
import dns.exception import dns.exception
import dns.immutable import dns.immutable

View file

@ -18,11 +18,10 @@
import struct import struct
import dns.edns import dns.edns
import dns.immutable
import dns.exception import dns.exception
import dns.immutable
import dns.rdata import dns.rdata
# We don't implement from_text, and that's ok. # We don't implement from_text, and that's ok.
# pylint: disable=abstract-method # pylint: disable=abstract-method

View file

@ -15,8 +15,8 @@
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import dns.rdtypes.nsbase
import dns.immutable import dns.immutable
import dns.rdtypes.nsbase
@dns.immutable.immutable @dns.immutable.immutable

View file

@ -17,8 +17,8 @@
import dns.exception import dns.exception
import dns.immutable import dns.immutable
import dns.rdata
import dns.name import dns.name
import dns.rdata
@dns.immutable.immutable @dns.immutable.immutable

View file

@ -21,8 +21,8 @@ import struct
import time import time
import dns.dnssectypes import dns.dnssectypes
import dns.immutable
import dns.exception import dns.exception
import dns.immutable
import dns.rdata import dns.rdata
import dns.rdatatype import dns.rdatatype

View file

@ -15,8 +15,8 @@
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import dns.rdtypes.mxbase
import dns.immutable import dns.immutable
import dns.rdtypes.mxbase
@dns.immutable.immutable @dns.immutable.immutable

View file

@ -19,8 +19,8 @@ import struct
import dns.exception import dns.exception
import dns.immutable import dns.immutable
import dns.rdata
import dns.name import dns.name
import dns.rdata
@dns.immutable.immutable @dns.immutable.immutable

View file

@ -15,8 +15,8 @@
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import dns.rdtypes.txtbase
import dns.immutable import dns.immutable
import dns.rdtypes.txtbase
@dns.immutable.immutable @dns.immutable.immutable

View file

@ -15,11 +15,11 @@
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import struct
import binascii import binascii
import struct
import dns.rdata
import dns.immutable import dns.immutable
import dns.rdata
import dns.rdatatype import dns.rdatatype

View file

@ -18,8 +18,8 @@
import base64 import base64
import struct import struct
import dns.immutable
import dns.exception import dns.exception
import dns.immutable
import dns.rdata import dns.rdata

View file

@ -15,8 +15,8 @@
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import dns.rdtypes.txtbase
import dns.immutable import dns.immutable
import dns.rdtypes.txtbase
@dns.immutable.immutable @dns.immutable.immutable

View file

@ -20,9 +20,9 @@ import struct
import dns.exception import dns.exception
import dns.immutable import dns.immutable
import dns.name
import dns.rdata import dns.rdata
import dns.rdtypes.util import dns.rdtypes.util
import dns.name
@dns.immutable.immutable @dns.immutable.immutable

View file

@ -1,7 +1,7 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
import struct
import binascii import binascii
import struct
import dns.immutable import dns.immutable
import dns.rdata import dns.rdata

View file

@ -17,8 +17,8 @@
import struct import struct
import dns.rdtypes.mxbase
import dns.immutable import dns.immutable
import dns.rdtypes.mxbase
@dns.immutable.immutable @dns.immutable.immutable

View file

@ -124,7 +124,6 @@ class APL(dns.rdata.Rdata):
@classmethod @classmethod
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
items = [] items = []
while parser.remaining() > 0: while parser.remaining() > 0:
header = parser.get_struct("!HBB") header = parser.get_struct("!HBB")

View file

@ -1,7 +1,7 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
import dns.rdtypes.svcbbase
import dns.immutable import dns.immutable
import dns.rdtypes.svcbbase
@dns.immutable.immutable @dns.immutable.immutable

View file

@ -15,8 +15,8 @@
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import struct
import base64 import base64
import struct
import dns.exception import dns.exception
import dns.immutable import dns.immutable

View file

@ -15,8 +15,8 @@
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import dns.rdtypes.mxbase
import dns.immutable import dns.immutable
import dns.rdtypes.mxbase
@dns.immutable.immutable @dns.immutable.immutable

View file

@ -15,8 +15,8 @@
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import dns.rdtypes.nsbase
import dns.immutable import dns.immutable
import dns.rdtypes.nsbase
@dns.immutable.immutable @dns.immutable.immutable

View file

@ -19,9 +19,9 @@ import struct
import dns.exception import dns.exception
import dns.immutable import dns.immutable
import dns.name
import dns.rdata import dns.rdata
import dns.rdtypes.util import dns.rdtypes.util
import dns.name
@dns.immutable.immutable @dns.immutable.immutable

View file

@ -19,9 +19,9 @@ import struct
import dns.exception import dns.exception
import dns.immutable import dns.immutable
import dns.name
import dns.rdata import dns.rdata
import dns.rdtypes.util import dns.rdtypes.util
import dns.name
@dns.immutable.immutable @dns.immutable.immutable

View file

@ -1,7 +1,7 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
import dns.rdtypes.svcbbase
import dns.immutable import dns.immutable
import dns.rdtypes.svcbbase
@dns.immutable.immutable @dns.immutable.immutable

View file

@ -18,8 +18,8 @@
import socket import socket
import struct import struct
import dns.ipv4
import dns.immutable import dns.immutable
import dns.ipv4
import dns.rdata import dns.rdata
try: try:

View file

@ -19,9 +19,9 @@ import base64
import enum import enum
import struct import struct
import dns.dnssectypes
import dns.exception import dns.exception
import dns.immutable import dns.immutable
import dns.dnssectypes
import dns.rdata import dns.rdata
# wildcard import # wildcard import
@ -43,7 +43,7 @@ class DNSKEYBase(dns.rdata.Rdata):
def __init__(self, rdclass, rdtype, flags, protocol, algorithm, key): def __init__(self, rdclass, rdtype, flags, protocol, algorithm, key):
super().__init__(rdclass, rdtype) super().__init__(rdclass, rdtype)
self.flags = self._as_uint16(flags) self.flags = Flag(self._as_uint16(flags))
self.protocol = self._as_uint8(protocol) self.protocol = self._as_uint8(protocol)
self.algorithm = dns.dnssectypes.Algorithm.make(algorithm) self.algorithm = dns.dnssectypes.Algorithm.make(algorithm)
self.key = self._as_bytes(key) self.key = self._as_bytes(key)

View file

@ -15,8 +15,8 @@
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import struct
import binascii import binascii
import struct
import dns.dnssectypes import dns.dnssectypes
import dns.immutable import dns.immutable
@ -44,7 +44,7 @@ class DSBase(dns.rdata.Rdata):
super().__init__(rdclass, rdtype) super().__init__(rdclass, rdtype)
self.key_tag = self._as_uint16(key_tag) self.key_tag = self._as_uint16(key_tag)
self.algorithm = dns.dnssectypes.Algorithm.make(algorithm) self.algorithm = dns.dnssectypes.Algorithm.make(algorithm)
self.digest_type = self._as_uint8(digest_type) self.digest_type = dns.dnssectypes.DSDigest.make(self._as_uint8(digest_type))
self.digest = self._as_bytes(digest) self.digest = self._as_bytes(digest)
try: try:
if len(self.digest) != self._digest_length_by_type[self.digest_type]: if len(self.digest) != self._digest_length_by_type[self.digest_type]:

View file

@ -16,8 +16,8 @@
import binascii import binascii
import dns.rdata
import dns.immutable import dns.immutable
import dns.rdata
@dns.immutable.immutable @dns.immutable.immutable

View file

@ -21,8 +21,8 @@ import struct
import dns.exception import dns.exception
import dns.immutable import dns.immutable
import dns.rdata
import dns.name import dns.name
import dns.rdata
import dns.rdtypes.util import dns.rdtypes.util

View file

@ -19,8 +19,8 @@
import dns.exception import dns.exception
import dns.immutable import dns.immutable
import dns.rdata
import dns.name import dns.name
import dns.rdata
@dns.immutable.immutable @dns.immutable.immutable

View file

@ -34,6 +34,7 @@ class ParamKey(dns.enum.IntEnum):
IPV4HINT = 4 IPV4HINT = 4
ECH = 5 ECH = 5
IPV6HINT = 6 IPV6HINT = 6
DOHPATH = 7
@classmethod @classmethod
def _maximum(cls): def _maximum(cls):

View file

@ -15,11 +15,11 @@
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import struct
import binascii import binascii
import struct
import dns.rdata
import dns.immutable import dns.immutable
import dns.rdata
import dns.rdatatype import dns.rdatatype

View file

@ -17,9 +17,8 @@
"""TXT-like base class.""" """TXT-like base class."""
from typing import Any, Dict, Iterable, Optional, Tuple, Union
import struct import struct
from typing import Any, Dict, Iterable, Optional, Tuple, Union
import dns.exception import dns.exception
import dns.immutable import dns.immutable

View file

@ -18,6 +18,7 @@
import collections import collections
import random import random
import struct import struct
from typing import Any, List
import dns.exception import dns.exception
import dns.ipv4 import dns.ipv4
@ -119,7 +120,7 @@ class Bitmap:
def __init__(self, windows=None): def __init__(self, windows=None):
last_window = -1 last_window = -1
self.windows = windows self.windows = windows
for (window, bitmap) in self.windows: for window, bitmap in self.windows:
if not isinstance(window, int): if not isinstance(window, int):
raise ValueError(f"bad {self.type_name} window type") raise ValueError(f"bad {self.type_name} window type")
if window <= last_window: if window <= last_window:
@ -132,11 +133,11 @@ class Bitmap:
if len(bitmap) == 0 or len(bitmap) > 32: if len(bitmap) == 0 or len(bitmap) > 32:
raise ValueError(f"bad {self.type_name} octets") raise ValueError(f"bad {self.type_name} octets")
def to_text(self): def to_text(self) -> str:
text = "" text = ""
for (window, bitmap) in self.windows: for window, bitmap in self.windows:
bits = [] bits = []
for (i, byte) in enumerate(bitmap): for i, byte in enumerate(bitmap):
for j in range(0, 8): for j in range(0, 8):
if byte & (0x80 >> j): if byte & (0x80 >> j):
rdtype = window * 256 + i * 8 + j rdtype = window * 256 + i * 8 + j
@ -145,14 +146,18 @@ class Bitmap:
return text return text
@classmethod @classmethod
def from_text(cls, tok): def from_text(cls, tok: "dns.tokenizer.Tokenizer") -> "Bitmap":
rdtypes = [] rdtypes = []
for token in tok.get_remaining(): for token in tok.get_remaining():
rdtype = dns.rdatatype.from_text(token.unescape().value) rdtype = dns.rdatatype.from_text(token.unescape().value)
if rdtype == 0: if rdtype == 0:
raise dns.exception.SyntaxError(f"{cls.type_name} with bit 0") raise dns.exception.SyntaxError(f"{cls.type_name} with bit 0")
rdtypes.append(rdtype) rdtypes.append(rdtype)
rdtypes.sort() return cls.from_rdtypes(rdtypes)
@classmethod
def from_rdtypes(cls, rdtypes: List[dns.rdatatype.RdataType]) -> "Bitmap":
rdtypes = sorted(rdtypes)
window = 0 window = 0
octets = 0 octets = 0
prior_rdtype = 0 prior_rdtype = 0
@ -177,13 +182,13 @@ class Bitmap:
windows.append((window, bytes(bitmap[0:octets]))) windows.append((window, bytes(bitmap[0:octets])))
return cls(windows) return cls(windows)
def to_wire(self, file): def to_wire(self, file: Any) -> None:
for (window, bitmap) in self.windows: for window, bitmap in self.windows:
file.write(struct.pack("!BB", window, len(bitmap))) file.write(struct.pack("!BB", window, len(bitmap)))
file.write(bitmap) file.write(bitmap)
@classmethod @classmethod
def from_wire_parser(cls, parser): def from_wire_parser(cls, parser: "dns.wire.Parser") -> "Bitmap":
windows = [] windows = []
while parser.remaining() > 0: while parser.remaining() > 0:
window = parser.get_uint8() window = parser.get_uint8()
@ -226,7 +231,7 @@ def weighted_processing_order(iterable):
total = sum(rdata._processing_weight() or _no_weight for rdata in rdatas) total = sum(rdata._processing_weight() or _no_weight for rdata in rdatas)
while len(rdatas) > 1: while len(rdatas) > 1:
r = random.uniform(0, total) r = random.uniform(0, total)
for (n, rdata) in enumerate(rdatas): for n, rdata in enumerate(rdatas):
weight = rdata._processing_weight() or _no_weight weight = rdata._processing_weight() or _no_weight
if weight > r: if weight > r:
break break

View file

@ -19,14 +19,13 @@
import contextlib import contextlib
import io import io
import struct
import random import random
import struct
import time import time
import dns.exception import dns.exception
import dns.tsig import dns.tsig
QUESTION = 0 QUESTION = 0
ANSWER = 1 ANSWER = 1
AUTHORITY = 2 AUTHORITY = 2

View file

@ -17,29 +17,31 @@
"""DNS stub resolver.""" """DNS stub resolver."""
from typing import Any, Dict, List, Optional, Tuple, Union
from urllib.parse import urlparse
import contextlib import contextlib
import random
import socket import socket
import sys import sys
import threading import threading
import time import time
import random
import warnings import warnings
from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple, Union
from urllib.parse import urlparse
import dns.exception import dns._ddr
import dns.edns import dns.edns
import dns.exception
import dns.flags import dns.flags
import dns.inet import dns.inet
import dns.ipv4 import dns.ipv4
import dns.ipv6 import dns.ipv6
import dns.message import dns.message
import dns.name import dns.name
import dns.nameserver
import dns.query import dns.query
import dns.rcode import dns.rcode
import dns.rdataclass import dns.rdataclass
import dns.rdatatype import dns.rdatatype
import dns.rdtypes.svcbbase
import dns.reversename import dns.reversename
import dns.tsig import dns.tsig
@ -72,7 +74,7 @@ class NXDOMAIN(dns.exception.DNSException):
kwargs = dict(qnames=qnames, responses=responses) kwargs = dict(qnames=qnames, responses=responses)
return kwargs return kwargs
def __str__(self): def __str__(self) -> str:
if "qnames" not in self.kwargs: if "qnames" not in self.kwargs:
return super().__str__() return super().__str__()
qnames = self.kwargs["qnames"] qnames = self.kwargs["qnames"]
@ -140,7 +142,11 @@ class YXDOMAIN(dns.exception.DNSException):
ErrorTuple = Tuple[ ErrorTuple = Tuple[
Optional[str], bool, int, Union[Exception, str], Optional[dns.message.Message] Optional[str],
bool,
int,
Union[Exception, str],
Optional[dns.message.Message],
] ]
@ -148,11 +154,7 @@ def _errors_to_text(errors: List[ErrorTuple]) -> List[str]:
"""Turn a resolution errors trace into a list of text.""" """Turn a resolution errors trace into a list of text."""
texts = [] texts = []
for err in errors: for err in errors:
texts.append( texts.append("Server {} answered {}".format(err[0], err[3]))
"Server {} {} port {} answered {}".format(
err[0], "TCP" if err[1] else "UDP", err[2], err[3]
)
)
return texts return texts
@ -184,7 +186,7 @@ Timeout = LifetimeTimeout
class NoAnswer(dns.exception.DNSException): class NoAnswer(dns.exception.DNSException):
"""The DNS response does not contain an answer to the question.""" """The DNS response does not contain an answer to the question."""
fmt = "The DNS response does not contain an answer " + "to the question: {query}" fmt = "The DNS response does not contain an answer to the question: {query}"
supp_kwargs = {"response"} supp_kwargs = {"response"}
# We do this as otherwise mypy complains about unexpected keyword argument # We do this as otherwise mypy complains about unexpected keyword argument
@ -264,7 +266,7 @@ class Answer:
response: dns.message.QueryMessage, response: dns.message.QueryMessage,
nameserver: Optional[str] = None, nameserver: Optional[str] = None,
port: Optional[int] = None, port: Optional[int] = None,
): ) -> None:
self.qname = qname self.qname = qname
self.rdtype = rdtype self.rdtype = rdtype
self.rdclass = rdclass self.rdclass = rdclass
@ -292,7 +294,7 @@ class Answer:
else: else:
raise AttributeError(attr) raise AttributeError(attr)
def __len__(self): def __len__(self) -> int:
return self.rrset and len(self.rrset) or 0 return self.rrset and len(self.rrset) or 0
def __iter__(self): def __iter__(self):
@ -309,14 +311,67 @@ class Answer:
del self.rrset[i] del self.rrset[i]
class Answers(dict):
"""A dict of DNS stub resolver answers, indexed by type."""
class HostAnswers(Answers):
"""A dict of DNS stub resolver answers to a host name lookup, indexed by
type.
"""
@classmethod
def make(
cls,
v6: Optional[Answer] = None,
v4: Optional[Answer] = None,
add_empty: bool = True,
) -> "HostAnswers":
answers = HostAnswers()
if v6 is not None and (add_empty or v6.rrset):
answers[dns.rdatatype.AAAA] = v6
if v4 is not None and (add_empty or v4.rrset):
answers[dns.rdatatype.A] = v4
return answers
# Returns pairs of (address, family) from this result, potentiallys
# filtering by address family.
def addresses_and_families(
self, family: int = socket.AF_UNSPEC
) -> Iterator[Tuple[str, int]]:
if family == socket.AF_UNSPEC:
yield from self.addresses_and_families(socket.AF_INET6)
yield from self.addresses_and_families(socket.AF_INET)
return
elif family == socket.AF_INET6:
answer = self.get(dns.rdatatype.AAAA)
elif family == socket.AF_INET:
answer = self.get(dns.rdatatype.A)
else:
raise NotImplementedError(f"unknown address family {family}")
if answer:
for rdata in answer:
yield (rdata.address, family)
# Returns addresses from this result, potentially filtering by
# address family.
def addresses(self, family: int = socket.AF_UNSPEC) -> Iterator[str]:
return (pair[0] for pair in self.addresses_and_families(family))
# Returns the canonical name from this result.
def canonical_name(self) -> dns.name.Name:
answer = self.get(dns.rdatatype.AAAA, self.get(dns.rdatatype.A))
return answer.canonical_name
class CacheStatistics: class CacheStatistics:
"""Cache Statistics""" """Cache Statistics"""
def __init__(self, hits=0, misses=0): def __init__(self, hits: int = 0, misses: int = 0) -> None:
self.hits = hits self.hits = hits
self.misses = misses self.misses = misses
def reset(self): def reset(self) -> None:
self.hits = 0 self.hits = 0
self.misses = 0 self.misses = 0
@ -325,7 +380,7 @@ class CacheStatistics:
class CacheBase: class CacheBase:
def __init__(self): def __init__(self) -> None:
self.lock = threading.Lock() self.lock = threading.Lock()
self.statistics = CacheStatistics() self.statistics = CacheStatistics()
@ -361,7 +416,7 @@ CacheKey = Tuple[dns.name.Name, dns.rdatatype.RdataType, dns.rdataclass.RdataCla
class Cache(CacheBase): class Cache(CacheBase):
"""Simple thread-safe DNS answer cache.""" """Simple thread-safe DNS answer cache."""
def __init__(self, cleaning_interval: float = 300.0): def __init__(self, cleaning_interval: float = 300.0) -> None:
"""*cleaning_interval*, a ``float`` is the number of seconds between """*cleaning_interval*, a ``float`` is the number of seconds between
periodic cleanings. periodic cleanings.
""" """
@ -377,7 +432,7 @@ class Cache(CacheBase):
now = time.time() now = time.time()
if self.next_cleaning <= now: if self.next_cleaning <= now:
keys_to_delete = [] keys_to_delete = []
for (k, v) in self.data.items(): for k, v in self.data.items():
if v.expiration <= now: if v.expiration <= now:
keys_to_delete.append(k) keys_to_delete.append(k)
for k in keys_to_delete: for k in keys_to_delete:
@ -447,13 +502,13 @@ class LRUCacheNode:
self.prev = self self.prev = self
self.next = self self.next = self
def link_after(self, node): def link_after(self, node: "LRUCacheNode") -> None:
self.prev = node self.prev = node
self.next = node.next self.next = node.next
node.next.prev = self node.next.prev = self
node.next = self node.next = self
def unlink(self): def unlink(self) -> None:
self.next.prev = self.prev self.next.prev = self.prev
self.prev.next = self.next self.prev.next = self.next
@ -468,7 +523,7 @@ class LRUCache(CacheBase):
for a new one. for a new one.
""" """
def __init__(self, max_size: int = 100000): def __init__(self, max_size: int = 100000) -> None:
"""*max_size*, an ``int``, is the maximum number of nodes to cache; """*max_size*, an ``int``, is the maximum number of nodes to cache;
it must be greater than 0. it must be greater than 0.
""" """
@ -590,30 +645,29 @@ class _Resolution:
tcp: bool, tcp: bool,
raise_on_no_answer: bool, raise_on_no_answer: bool,
search: Optional[bool], search: Optional[bool],
): ) -> None:
if isinstance(qname, str): if isinstance(qname, str):
qname = dns.name.from_text(qname, None) qname = dns.name.from_text(qname, None)
the_rdtype = dns.rdatatype.RdataType.make(rdtype) rdtype = dns.rdatatype.RdataType.make(rdtype)
if dns.rdatatype.is_metatype(the_rdtype): if dns.rdatatype.is_metatype(rdtype):
raise NoMetaqueries raise NoMetaqueries
the_rdclass = dns.rdataclass.RdataClass.make(rdclass) rdclass = dns.rdataclass.RdataClass.make(rdclass)
if dns.rdataclass.is_metaclass(the_rdclass): if dns.rdataclass.is_metaclass(rdclass):
raise NoMetaqueries raise NoMetaqueries
self.resolver = resolver self.resolver = resolver
self.qnames_to_try = resolver._get_qnames_to_try(qname, search) self.qnames_to_try = resolver._get_qnames_to_try(qname, search)
self.qnames = self.qnames_to_try[:] self.qnames = self.qnames_to_try[:]
self.rdtype = the_rdtype self.rdtype = rdtype
self.rdclass = the_rdclass self.rdclass = rdclass
self.tcp = tcp self.tcp = tcp
self.raise_on_no_answer = raise_on_no_answer self.raise_on_no_answer = raise_on_no_answer
self.nxdomain_responses: Dict[dns.name.Name, dns.message.QueryMessage] = {} self.nxdomain_responses: Dict[dns.name.Name, dns.message.QueryMessage] = {}
# Initialize other things to help analysis tools # Initialize other things to help analysis tools
self.qname = dns.name.empty self.qname = dns.name.empty
self.nameservers: List[str] = [] self.nameservers: List[dns.nameserver.Nameserver] = []
self.current_nameservers: List[str] = [] self.current_nameservers: List[dns.nameserver.Nameserver] = []
self.errors: List[ErrorTuple] = [] self.errors: List[ErrorTuple] = []
self.nameserver: Optional[str] = None self.nameserver: Optional[dns.nameserver.Nameserver] = None
self.port = 0
self.tcp_attempt = False self.tcp_attempt = False
self.retry_with_tcp = False self.retry_with_tcp = False
self.request: Optional[dns.message.QueryMessage] = None self.request: Optional[dns.message.QueryMessage] = None
@ -670,7 +724,11 @@ class _Resolution:
if self.resolver.flags is not None: if self.resolver.flags is not None:
request.flags = self.resolver.flags request.flags = self.resolver.flags
self.nameservers = self.resolver.nameservers[:] self.nameservers = self.resolver._enrich_nameservers(
self.resolver._nameservers,
self.resolver.nameserver_ports,
self.resolver.port,
)
if self.resolver.rotate: if self.resolver.rotate:
random.shuffle(self.nameservers) random.shuffle(self.nameservers)
self.current_nameservers = self.nameservers[:] self.current_nameservers = self.nameservers[:]
@ -690,12 +748,13 @@ class _Resolution:
# #
raise NXDOMAIN(qnames=self.qnames_to_try, responses=self.nxdomain_responses) raise NXDOMAIN(qnames=self.qnames_to_try, responses=self.nxdomain_responses)
def next_nameserver(self) -> Tuple[str, int, bool, float]: def next_nameserver(self) -> Tuple[dns.nameserver.Nameserver, bool, float]:
if self.retry_with_tcp: if self.retry_with_tcp:
assert self.nameserver is not None assert self.nameserver is not None
assert not self.nameserver.is_always_max_size()
self.tcp_attempt = True self.tcp_attempt = True
self.retry_with_tcp = False self.retry_with_tcp = False
return (self.nameserver, self.port, True, 0) return (self.nameserver, True, 0)
backoff = 0.0 backoff = 0.0
if not self.current_nameservers: if not self.current_nameservers:
@ -707,11 +766,8 @@ class _Resolution:
self.backoff = min(self.backoff * 2, 2) self.backoff = min(self.backoff * 2, 2)
self.nameserver = self.current_nameservers.pop(0) self.nameserver = self.current_nameservers.pop(0)
self.port = self.resolver.nameserver_ports.get( self.tcp_attempt = self.tcp or self.nameserver.is_always_max_size()
self.nameserver, self.resolver.port return (self.nameserver, self.tcp_attempt, backoff)
)
self.tcp_attempt = self.tcp
return (self.nameserver, self.port, self.tcp_attempt, backoff)
def query_result( def query_result(
self, response: Optional[dns.message.Message], ex: Optional[Exception] self, response: Optional[dns.message.Message], ex: Optional[Exception]
@ -724,7 +780,13 @@ class _Resolution:
# Exception during I/O or from_wire() # Exception during I/O or from_wire()
assert response is None assert response is None
self.errors.append( self.errors.append(
(self.nameserver, self.tcp_attempt, self.port, ex, response) (
str(self.nameserver),
self.tcp_attempt,
self.nameserver.answer_port(),
ex,
response,
)
) )
if ( if (
isinstance(ex, dns.exception.FormError) isinstance(ex, dns.exception.FormError)
@ -752,12 +814,18 @@ class _Resolution:
self.rdtype, self.rdtype,
self.rdclass, self.rdclass,
response, response,
self.nameserver, self.nameserver.answer_nameserver(),
self.port, self.nameserver.answer_port(),
) )
except Exception as e: except Exception as e:
self.errors.append( self.errors.append(
(self.nameserver, self.tcp_attempt, self.port, e, response) (
str(self.nameserver),
self.tcp_attempt,
self.nameserver.answer_port(),
e,
response,
)
) )
# The nameserver is no good, take it out of the mix. # The nameserver is no good, take it out of the mix.
self.nameservers.remove(self.nameserver) self.nameservers.remove(self.nameserver)
@ -776,7 +844,13 @@ class _Resolution:
) )
except Exception as e: except Exception as e:
self.errors.append( self.errors.append(
(self.nameserver, self.tcp_attempt, self.port, e, response) (
str(self.nameserver),
self.tcp_attempt,
self.nameserver.answer_port(),
e,
response,
)
) )
# The nameserver is no good, take it out of the mix. # The nameserver is no good, take it out of the mix.
self.nameservers.remove(self.nameserver) self.nameservers.remove(self.nameserver)
@ -792,7 +866,13 @@ class _Resolution:
elif rcode == dns.rcode.YXDOMAIN: elif rcode == dns.rcode.YXDOMAIN:
yex = YXDOMAIN() yex = YXDOMAIN()
self.errors.append( self.errors.append(
(self.nameserver, self.tcp_attempt, self.port, yex, response) (
str(self.nameserver),
self.tcp_attempt,
self.nameserver.answer_port(),
yex,
response,
)
) )
raise yex raise yex
else: else:
@ -804,9 +884,9 @@ class _Resolution:
self.nameservers.remove(self.nameserver) self.nameservers.remove(self.nameserver)
self.errors.append( self.errors.append(
( (
self.nameserver, str(self.nameserver),
self.tcp_attempt, self.tcp_attempt,
self.port, self.nameserver.answer_port(),
dns.rcode.to_text(rcode), dns.rcode.to_text(rcode),
response, response,
) )
@ -840,8 +920,11 @@ class BaseResolver:
retry_servfail: bool retry_servfail: bool
rotate: bool rotate: bool
ndots: Optional[int] ndots: Optional[int]
_nameservers: Sequence[Union[str, dns.nameserver.Nameserver]]
def __init__(self, filename: str = "/etc/resolv.conf", configure: bool = True): def __init__(
self, filename: str = "/etc/resolv.conf", configure: bool = True
) -> None:
"""*filename*, a ``str`` or file object, specifying a file """*filename*, a ``str`` or file object, specifying a file
in standard /etc/resolv.conf format. This parameter is meaningful in standard /etc/resolv.conf format. This parameter is meaningful
only when *configure* is true and the platform is POSIX. only when *configure* is true and the platform is POSIX.
@ -860,13 +943,13 @@ class BaseResolver:
elif filename: elif filename:
self.read_resolv_conf(filename) self.read_resolv_conf(filename)
def reset(self): def reset(self) -> None:
"""Reset all resolver configuration to the defaults.""" """Reset all resolver configuration to the defaults."""
self.domain = dns.name.Name(dns.name.from_text(socket.gethostname())[1:]) self.domain = dns.name.Name(dns.name.from_text(socket.gethostname())[1:])
if len(self.domain) == 0: if len(self.domain) == 0:
self.domain = dns.name.root self.domain = dns.name.root
self.nameservers = [] self._nameservers = []
self.nameserver_ports = {} self.nameserver_ports = {}
self.port = 53 self.port = 53
self.search = [] self.search = []
@ -903,6 +986,7 @@ class BaseResolver:
""" """
nameservers = []
if isinstance(f, str): if isinstance(f, str):
try: try:
cm: contextlib.AbstractContextManager = open(f) cm: contextlib.AbstractContextManager = open(f)
@ -922,7 +1006,7 @@ class BaseResolver:
continue continue
if tokens[0] == "nameserver": if tokens[0] == "nameserver":
self.nameservers.append(tokens[1]) nameservers.append(tokens[1])
elif tokens[0] == "domain": elif tokens[0] == "domain":
self.domain = dns.name.from_text(tokens[1]) self.domain = dns.name.from_text(tokens[1])
# domain and search are exclusive # domain and search are exclusive
@ -950,8 +1034,11 @@ class BaseResolver:
self.ndots = int(opt.split(":")[1]) self.ndots = int(opt.split(":")[1])
except (ValueError, IndexError): except (ValueError, IndexError):
pass pass
if len(self.nameservers) == 0: if len(nameservers) == 0:
raise NoResolverConfiguration("no nameservers") raise NoResolverConfiguration("no nameservers")
# Assigning directly instead of appending means we invoke the
# setter logic, with additonal checking and enrichment.
self.nameservers = nameservers
def read_registry(self) -> None: def read_registry(self) -> None:
"""Extract resolver configuration from the Windows registry.""" """Extract resolver configuration from the Windows registry."""
@ -1086,34 +1173,64 @@ class BaseResolver:
self.flags = flags self.flags = flags
@property @classmethod
def nameservers(self) -> List[str]: def _enrich_nameservers(
return self._nameservers cls,
nameservers: Sequence[Union[str, dns.nameserver.Nameserver]],
@nameservers.setter nameserver_ports: Dict[str, int],
def nameservers(self, nameservers: List[str]) -> None: default_port: int,
""" ) -> List[dns.nameserver.Nameserver]:
*nameservers*, a ``list`` of nameservers. enriched_nameservers = []
Raises ``ValueError`` if *nameservers* is anything other than a
``list``.
"""
if isinstance(nameservers, list): if isinstance(nameservers, list):
for nameserver in nameservers: for nameserver in nameservers:
if not dns.inet.is_address(nameserver): enriched_nameserver: dns.nameserver.Nameserver
if isinstance(nameserver, dns.nameserver.Nameserver):
enriched_nameserver = nameserver
elif dns.inet.is_address(nameserver):
port = nameserver_ports.get(nameserver, default_port)
enriched_nameserver = dns.nameserver.Do53Nameserver(
nameserver, port
)
else:
try: try:
if urlparse(nameserver).scheme != "https": if urlparse(nameserver).scheme != "https":
raise NotImplementedError raise NotImplementedError
except Exception: except Exception:
raise ValueError( raise ValueError(
f"nameserver {nameserver} is not an " f"nameserver {nameserver} is not a "
"IP address or valid https URL" "dns.nameserver.Nameserver instance or text form, "
"IP address, nor a valid https URL"
) )
self._nameservers = nameservers enriched_nameserver = dns.nameserver.DoHNameserver(nameserver)
enriched_nameservers.append(enriched_nameserver)
else: else:
raise ValueError( raise ValueError(
"nameservers must be a list (not a {})".format(type(nameservers)) "nameservers must be a list or tuple (not a {})".format(
type(nameservers)
)
) )
return enriched_nameservers
@property
def nameservers(
self,
) -> Sequence[Union[str, dns.nameserver.Nameserver]]:
return self._nameservers
@nameservers.setter
def nameservers(
self, nameservers: Sequence[Union[str, dns.nameserver.Nameserver]]
) -> None:
"""
*nameservers*, a ``list`` of nameservers, where a nameserver is either
a string interpretable as a nameserver, or a ``dns.nameserver.Nameserver``
instance.
Raises ``ValueError`` if *nameservers* is not a list of nameservers.
"""
# We just call _enrich_nameservers() for checking
self._enrich_nameservers(nameservers, self.nameserver_ports, self.port)
self._nameservers = nameservers
class Resolver(BaseResolver): class Resolver(BaseResolver):
@ -1198,33 +1315,18 @@ class Resolver(BaseResolver):
assert request is not None # needed for type checking assert request is not None # needed for type checking
done = False done = False
while not done: while not done:
(nameserver, port, tcp, backoff) = resolution.next_nameserver() (nameserver, tcp, backoff) = resolution.next_nameserver()
if backoff: if backoff:
time.sleep(backoff) time.sleep(backoff)
timeout = self._compute_timeout(start, lifetime, resolution.errors) timeout = self._compute_timeout(start, lifetime, resolution.errors)
try: try:
if dns.inet.is_address(nameserver): response = nameserver.query(
if tcp: request,
response = dns.query.tcp( timeout=timeout,
request, source=source,
nameserver, source_port=source_port,
timeout=timeout, max_size=tcp,
port=port, )
source=source,
source_port=source_port,
)
else:
response = dns.query.udp(
request,
nameserver,
timeout=timeout,
port=port,
source=source,
source_port=source_port,
raise_on_truncation=True,
)
else:
response = dns.query.https(request, nameserver, timeout=timeout)
except Exception as ex: except Exception as ex:
(_, done) = resolution.query_result(None, ex) (_, done) = resolution.query_result(None, ex)
continue continue
@ -1293,7 +1395,72 @@ class Resolver(BaseResolver):
modified_kwargs["rdclass"] = dns.rdataclass.IN modified_kwargs["rdclass"] = dns.rdataclass.IN
return self.resolve( return self.resolve(
dns.reversename.from_address(ipaddr), *args, **modified_kwargs dns.reversename.from_address(ipaddr), *args, **modified_kwargs
) # type: ignore[arg-type] )
def resolve_name(
self,
name: Union[dns.name.Name, str],
family: int = socket.AF_UNSPEC,
**kwargs: Any,
) -> HostAnswers:
"""Use a resolver to query for address records.
This utilizes the resolve() method to perform A and/or AAAA lookups on
the specified name.
*qname*, a ``dns.name.Name`` or ``str``, the name to resolve.
*family*, an ``int``, the address family. If socket.AF_UNSPEC
(the default), both A and AAAA records will be retrieved.
All other arguments that can be passed to the resolve() function
except for rdtype and rdclass are also supported by this
function.
"""
# We make a modified kwargs for type checking happiness, as otherwise
# we get a legit warning about possibly having rdtype and rdclass
# in the kwargs more than once.
modified_kwargs: Dict[str, Any] = {}
modified_kwargs.update(kwargs)
modified_kwargs.pop("rdtype", None)
modified_kwargs["rdclass"] = dns.rdataclass.IN
if family == socket.AF_INET:
v4 = self.resolve(name, dns.rdatatype.A, **modified_kwargs)
return HostAnswers.make(v4=v4)
elif family == socket.AF_INET6:
v6 = self.resolve(name, dns.rdatatype.AAAA, **modified_kwargs)
return HostAnswers.make(v6=v6)
elif family != socket.AF_UNSPEC:
raise NotImplementedError(f"unknown address family {family}")
raise_on_no_answer = modified_kwargs.pop("raise_on_no_answer", True)
lifetime = modified_kwargs.pop("lifetime", None)
start = time.time()
v6 = self.resolve(
name,
dns.rdatatype.AAAA,
raise_on_no_answer=False,
lifetime=self._compute_timeout(start, lifetime),
**modified_kwargs,
)
# Note that setting name ensures we query the same name
# for A as we did for AAAA. (This is just in case search lists
# are active by default in the resolver configuration and
# we might be talking to a server that says NXDOMAIN when it
# wants to say NOERROR no data.
name = v6.qname
v4 = self.resolve(
name,
dns.rdatatype.A,
raise_on_no_answer=False,
lifetime=self._compute_timeout(start, lifetime),
**modified_kwargs,
)
answers = HostAnswers.make(v6=v6, v4=v4, add_empty=not raise_on_no_answer)
if not answers:
raise NoAnswer(response=v6.response)
return answers
# pylint: disable=redefined-outer-name # pylint: disable=redefined-outer-name
@ -1320,6 +1487,37 @@ class Resolver(BaseResolver):
# pylint: enable=redefined-outer-name # pylint: enable=redefined-outer-name
def try_ddr(self, lifetime: float = 5.0) -> None:
"""Try to update the resolver's nameservers using Discovery of Designated
Resolvers (DDR). If successful, the resolver will subsequently use
DNS-over-HTTPS or DNS-over-TLS for future queries.
*lifetime*, a float, is the maximum time to spend attempting DDR. The default
is 5 seconds.
If the SVCB query is successful and results in a non-empty list of nameservers,
then the resolver's nameservers are set to the returned servers in priority
order.
The current implementation does not use any address hints from the SVCB record,
nor does it resolve addresses for the SCVB target name, rather it assumes that
the bootstrap nameserver will always be one of the addresses and uses it.
A future revision to the code may offer fuller support. The code verifies that
the bootstrap nameserver is in the Subject Alternative Name field of the
TLS certficate.
"""
try:
expiration = time.time() + lifetime
answer = self.resolve(
dns._ddr._local_resolver_name, "SVCB", lifetime=lifetime
)
timeout = dns.query._remaining(expiration)
nameservers = dns._ddr._get_nameservers_sync(answer, timeout)
if len(nameservers) > 0:
self.nameservers = nameservers
except Exception:
pass
#: The default resolver. #: The default resolver.
default_resolver: Optional[Resolver] = None default_resolver: Optional[Resolver] = None
@ -1333,7 +1531,7 @@ def get_default_resolver() -> Resolver:
return default_resolver return default_resolver
def reset_default_resolver(): def reset_default_resolver() -> None:
"""Re-initialize default resolver. """Re-initialize default resolver.
Note that the resolver configuration (i.e. /etc/resolv.conf on UNIX Note that the resolver configuration (i.e. /etc/resolv.conf on UNIX
@ -1355,7 +1553,6 @@ def resolve(
lifetime: Optional[float] = None, lifetime: Optional[float] = None,
search: Optional[bool] = None, search: Optional[bool] = None,
) -> Answer: # pragma: no cover ) -> Answer: # pragma: no cover
"""Query nameservers to find the answer to the question. """Query nameservers to find the answer to the question.
This is a convenience function that uses the default resolver This is a convenience function that uses the default resolver
@ -1421,6 +1618,18 @@ def resolve_address(ipaddr: str, *args: Any, **kwargs: Any) -> Answer:
return get_default_resolver().resolve_address(ipaddr, *args, **kwargs) return get_default_resolver().resolve_address(ipaddr, *args, **kwargs)
def resolve_name(
name: Union[dns.name.Name, str], family: int = socket.AF_UNSPEC, **kwargs: Any
) -> HostAnswers:
"""Use a resolver to query for address records.
See ``dns.resolver.Resolver.resolve_name`` for more information on the
parameters.
"""
return get_default_resolver().resolve_name(name, family, **kwargs)
def canonical_name(name: Union[dns.name.Name, str]) -> dns.name.Name: def canonical_name(name: Union[dns.name.Name, str]) -> dns.name.Name:
"""Determine the canonical name of *name*. """Determine the canonical name of *name*.
@ -1431,6 +1640,16 @@ def canonical_name(name: Union[dns.name.Name, str]) -> dns.name.Name:
return get_default_resolver().canonical_name(name) return get_default_resolver().canonical_name(name)
def try_ddr(lifetime: float = 5.0) -> None:
"""Try to update the default resolver's nameservers using Discovery of Designated
Resolvers (DDR). If successful, the resolver will subsequently use
DNS-over-HTTPS or DNS-over-TLS for future queries.
See :py:func:`dns.resolver.Resolver.try_ddr` for more information.
"""
return get_default_resolver().try_ddr(lifetime)
def zone_for_name( def zone_for_name(
name: Union[dns.name.Name, str], name: Union[dns.name.Name, str],
rdclass: dns.rdataclass.RdataClass = dns.rdataclass.IN, rdclass: dns.rdataclass.RdataClass = dns.rdataclass.IN,
@ -1478,7 +1697,7 @@ def zone_for_name(
while 1: while 1:
try: try:
rlifetime: Optional[float] rlifetime: Optional[float]
if expiration: if expiration is not None:
rlifetime = expiration - time.time() rlifetime = expiration - time.time()
if rlifetime <= 0: if rlifetime <= 0:
rlifetime = 0 rlifetime = 0
@ -1516,6 +1735,83 @@ def zone_for_name(
raise NoRootSOA raise NoRootSOA
def make_resolver_at(
where: Union[dns.name.Name, str],
port: int = 53,
family: int = socket.AF_UNSPEC,
resolver: Optional[Resolver] = None,
) -> Resolver:
"""Make a stub resolver using the specified destination as the full resolver.
*where*, a ``dns.name.Name`` or ``str`` the domain name or IP address of the
full resolver.
*port*, an ``int``, the port to use. If not specified, the default is 53.
*family*, an ``int``, the address family to use. This parameter is used if
*where* is not an address. The default is ``socket.AF_UNSPEC`` in which case
the first address returned by ``resolve_name()`` will be used, otherwise the
first address of the specified family will be used.
*resolver*, a ``dns.resolver.Resolver`` or ``None``, the resolver to use for
resolution of hostnames. If not specified, the default resolver will be used.
Returns a ``dns.resolver.Resolver`` or raises an exception.
"""
if resolver is None:
resolver = get_default_resolver()
nameservers: List[Union[str, dns.nameserver.Nameserver]] = []
if isinstance(where, str) and dns.inet.is_address(where):
nameservers.append(dns.nameserver.Do53Nameserver(where, port))
else:
for address in resolver.resolve_name(where, family).addresses():
nameservers.append(dns.nameserver.Do53Nameserver(address, port))
res = dns.resolver.Resolver(configure=False)
res.nameservers = nameservers
return res
def resolve_at(
where: Union[dns.name.Name, str],
qname: Union[dns.name.Name, str],
rdtype: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.A,
rdclass: Union[dns.rdataclass.RdataClass, str] = dns.rdataclass.IN,
tcp: bool = False,
source: Optional[str] = None,
raise_on_no_answer: bool = True,
source_port: int = 0,
lifetime: Optional[float] = None,
search: Optional[bool] = None,
port: int = 53,
family: int = socket.AF_UNSPEC,
resolver: Optional[Resolver] = None,
) -> Answer:
"""Query nameservers to find the answer to the question.
This is a convenience function that calls ``dns.resolver.make_resolver_at()`` to
make a resolver, and then uses it to resolve the query.
See ``dns.resolver.Resolver.resolve`` for more information on the resolution
parameters, and ``dns.resolver.make_resolver_at`` for information about the resolver
parameters *where*, *port*, *family*, and *resolver*.
If making more than one query, it is more efficient to call
``dns.resolver.make_resolver_at()`` and then use that resolver for the queries
instead of calling ``resolve_at()`` multiple times.
"""
return make_resolver_at(where, port, family, resolver).resolve(
qname,
rdtype,
rdclass,
tcp,
source,
raise_on_no_answer,
source_port,
lifetime,
search,
)
# #
# Support for overriding the system resolver for all python code in the # Support for overriding the system resolver for all python code in the
# running process. # running process.
@ -1559,8 +1855,7 @@ def _getaddrinfo(
) )
if host is None and service is None: if host is None and service is None:
raise socket.gaierror(socket.EAI_NONAME, "Name or service not known") raise socket.gaierror(socket.EAI_NONAME, "Name or service not known")
v6addrs = [] addrs = []
v4addrs = []
canonical_name = None # pylint: disable=redefined-outer-name canonical_name = None # pylint: disable=redefined-outer-name
# Is host None or an address literal? If so, use the system's # Is host None or an address literal? If so, use the system's
# getaddrinfo(). # getaddrinfo().
@ -1576,24 +1871,9 @@ def _getaddrinfo(
pass pass
# Something needs resolution! # Something needs resolution!
try: try:
if family == socket.AF_INET6 or family == socket.AF_UNSPEC: answers = _resolver.resolve_name(host, family)
v6 = _resolver.resolve(host, dns.rdatatype.AAAA, raise_on_no_answer=False) addrs = answers.addresses_and_families()
# Note that setting host ensures we query the same name canonical_name = answers.canonical_name().to_text(True)
# for A as we did for AAAA. (This is just in case search lists
# are active by default in the resolver configuration and
# we might be talking to a server that says NXDOMAIN when it
# wants to say NOERROR no data.
host = v6.qname
canonical_name = v6.canonical_name.to_text(True)
if v6.rrset is not None:
for rdata in v6.rrset:
v6addrs.append(rdata.address)
if family == socket.AF_INET or family == socket.AF_UNSPEC:
v4 = _resolver.resolve(host, dns.rdatatype.A, raise_on_no_answer=False)
canonical_name = v4.canonical_name.to_text(True)
if v4.rrset is not None:
for rdata in v4.rrset:
v4addrs.append(rdata.address)
except dns.resolver.NXDOMAIN: except dns.resolver.NXDOMAIN:
raise socket.gaierror(socket.EAI_NONAME, "Name or service not known") raise socket.gaierror(socket.EAI_NONAME, "Name or service not known")
except Exception: except Exception:
@ -1625,20 +1905,11 @@ def _getaddrinfo(
cname = canonical_name cname = canonical_name
else: else:
cname = "" cname = ""
if family == socket.AF_INET6 or family == socket.AF_UNSPEC: for addr, af in addrs:
for addr in v6addrs: for socktype in socktypes:
for socktype in socktypes: for proto in _protocols_for_socktype[socktype]:
for proto in _protocols_for_socktype[socktype]: addr_tuple = dns.inet.low_level_address_tuple((addr, port), af)
tuples.append( tuples.append((af, socktype, proto, cname, addr_tuple))
(socket.AF_INET6, socktype, proto, cname, (addr, port, 0, 0))
)
if family == socket.AF_INET or family == socket.AF_UNSPEC:
for addr in v4addrs:
for socktype in socktypes:
for proto in _protocols_for_socktype[socktype]:
tuples.append(
(socket.AF_INET, socktype, proto, cname, (addr, port))
)
if len(tuples) == 0: if len(tuples) == 0:
raise socket.gaierror(socket.EAI_NONAME, "Name or service not known") raise socket.gaierror(socket.EAI_NONAME, "Name or service not known")
return tuples return tuples

View file

@ -19,9 +19,9 @@
import binascii import binascii
import dns.name
import dns.ipv6
import dns.ipv4 import dns.ipv4
import dns.ipv6
import dns.name
ipv4_reverse_domain = dns.name.from_text("in-addr.arpa.") ipv4_reverse_domain = dns.name.from_text("in-addr.arpa.")
ipv6_reverse_domain = dns.name.from_text("ip6.arpa.") ipv6_reverse_domain = dns.name.from_text("ip6.arpa.")

View file

@ -17,11 +17,11 @@
"""DNS RRsets (an RRset is a named rdataset)""" """DNS RRsets (an RRset is a named rdataset)"""
from typing import Any, cast, Collection, Dict, Optional, Union from typing import Any, Collection, Dict, Optional, Union, cast
import dns.name import dns.name
import dns.rdataset
import dns.rdataclass import dns.rdataclass
import dns.rdataset
import dns.renderer import dns.renderer
@ -214,9 +214,9 @@ def from_text_list(
if isinstance(name, str): if isinstance(name, str):
name = dns.name.from_text(name, None, idna_codec=idna_codec) name = dns.name.from_text(name, None, idna_codec=idna_codec)
the_rdclass = dns.rdataclass.RdataClass.make(rdclass) rdclass = dns.rdataclass.RdataClass.make(rdclass)
the_rdtype = dns.rdatatype.RdataType.make(rdtype) rdtype = dns.rdatatype.RdataType.make(rdtype)
r = RRset(name, the_rdclass, the_rdtype) r = RRset(name, rdclass, rdtype)
r.update_ttl(ttl) r.update_ttl(ttl)
for t in text_rdatas: for t in text_rdatas:
rd = dns.rdata.from_text( rd = dns.rdata.from_text(

View file

@ -17,10 +17,9 @@
"""Tokenize DNS zone file format""" """Tokenize DNS zone file format"""
from typing import Any, Optional, List, Tuple
import io import io
import sys import sys
from typing import Any, List, Optional, Tuple
import dns.exception import dns.exception
import dns.name import dns.name

View file

@ -1,8 +1,7 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
from typing import Any, Callable, List, Optional, Tuple, Union
import collections import collections
from typing import Any, Callable, Iterator, List, Optional, Tuple, Union
import dns.exception import dns.exception
import dns.name import dns.name
@ -357,6 +356,27 @@ class Transaction:
""" """
self._check_delete_name.append(check) self._check_delete_name.append(check)
def iterate_rdatasets(
self,
) -> Iterator[Tuple[dns.name.Name, dns.rdataset.Rdataset]]:
"""Iterate all the rdatasets in the transaction, returning
(`dns.name.Name`, `dns.rdataset.Rdataset`) tuples.
Note that as is usual with python iterators, adding or removing items
while iterating will invalidate the iterator and may raise `RuntimeError`
or fail to iterate over all entries."""
self._check_ended()
return self._iterate_rdatasets()
def iterate_names(self) -> Iterator[dns.name.Name]:
"""Iterate all the names in the transaction.
Note that as is usual with python iterators, adding or removing names
while iterating will invalidate the iterator and may raise `RuntimeError`
or fail to iterate over all entries."""
self._check_ended()
return self._iterate_names()
# #
# Helper methods # Helper methods
# #
@ -416,7 +436,7 @@ class Transaction:
rdataset = rrset.to_rdataset() rdataset = rrset.to_rdataset()
else: else:
raise TypeError( raise TypeError(
f"{method} requires a name or RRset " + "as the first argument" f"{method} requires a name or RRset as the first argument"
) )
if rdataset.rdclass != self.manager.get_class(): if rdataset.rdclass != self.manager.get_class():
raise ValueError(f"{method} has objects of wrong RdataClass") raise ValueError(f"{method} has objects of wrong RdataClass")
@ -475,7 +495,7 @@ class Transaction:
name = rdataset.name name = rdataset.name
else: else:
raise TypeError( raise TypeError(
f"{method} requires a name or RRset " + "as the first argument" f"{method} requires a name or RRset as the first argument"
) )
self._raise_if_not_empty(method, args) self._raise_if_not_empty(method, args)
if rdataset: if rdataset:
@ -610,6 +630,10 @@ class Transaction:
"""Return an iterator that yields (name, rdataset) tuples.""" """Return an iterator that yields (name, rdataset) tuples."""
raise NotImplementedError # pragma: no cover raise NotImplementedError # pragma: no cover
def _iterate_names(self):
"""Return an iterator that yields a name."""
raise NotImplementedError # pragma: no cover
def _get_node(self, name): def _get_node(self, name):
"""Return the node at *name*, if any. """Return the node at *name*, if any.

View file

@ -23,9 +23,9 @@ import hmac
import struct import struct
import dns.exception import dns.exception
import dns.rdataclass
import dns.name import dns.name
import dns.rcode import dns.rcode
import dns.rdataclass
class BadTime(dns.exception.DNSException): class BadTime(dns.exception.DNSException):
@ -187,9 +187,7 @@ class HMACTSig:
try: try:
hashinfo = self._hashes[algorithm] hashinfo = self._hashes[algorithm]
except KeyError: except KeyError:
raise NotImplementedError( raise NotImplementedError(f"TSIG algorithm {algorithm} is not supported")
f"TSIG algorithm {algorithm} " + "is not supported"
)
# create the HMAC context # create the HMAC context
if isinstance(hashinfo, tuple): if isinstance(hashinfo, tuple):

View file

@ -17,9 +17,8 @@
"""A place to store TSIG keys.""" """A place to store TSIG keys."""
from typing import Any, Dict
import base64 import base64
from typing import Any, Dict
import dns.name import dns.name
import dns.tsig import dns.tsig
@ -33,7 +32,7 @@ def from_text(textring: Dict[str, Any]) -> Dict[dns.name.Name, dns.tsig.Key]:
@rtype: dict""" @rtype: dict"""
keyring = {} keyring = {}
for (name, value) in textring.items(): for name, value in textring.items():
kname = dns.name.from_text(name) kname = dns.name.from_text(name)
if isinstance(value, str): if isinstance(value, str):
keyring[kname] = dns.tsig.Key(kname, value).secret keyring[kname] = dns.tsig.Key(kname, value).secret
@ -55,7 +54,7 @@ def to_text(keyring: Dict[dns.name.Name, Any]) -> Dict[str, Any]:
def b64encode(secret): def b64encode(secret):
return base64.encodebytes(secret).decode().rstrip() return base64.encodebytes(secret).decode().rstrip()
for (name, key) in keyring.items(): for name, key in keyring.items():
tname = name.to_text() tname = name.to_text()
if isinstance(key, bytes): if isinstance(key, bytes):
textring[tname] = b64encode(key) textring[tname] = b64encode(key)

View file

@ -24,8 +24,8 @@ import dns.name
import dns.opcode import dns.opcode
import dns.rdata import dns.rdata
import dns.rdataclass import dns.rdataclass
import dns.rdatatype
import dns.rdataset import dns.rdataset
import dns.rdatatype
import dns.tsig import dns.tsig
@ -43,7 +43,6 @@ class UpdateSection(dns.enum.IntEnum):
class UpdateMessage(dns.message.Message): # lgtm[py/missing-equals] class UpdateMessage(dns.message.Message): # lgtm[py/missing-equals]
# ignore the mypy error here as we mean to use a different enum # ignore the mypy error here as we mean to use a different enum
_section_enum = UpdateSection # type: ignore _section_enum = UpdateSection # type: ignore
@ -336,12 +335,12 @@ class UpdateMessage(dns.message.Message): # lgtm[py/missing-equals]
True, True,
) )
else: else:
the_rdtype = dns.rdatatype.RdataType.make(rdtype) rdtype = dns.rdatatype.RdataType.make(rdtype)
self.find_rrset( self.find_rrset(
self.prerequisite, self.prerequisite,
name, name,
dns.rdataclass.NONE, dns.rdataclass.NONE,
the_rdtype, rdtype,
dns.rdatatype.NONE, dns.rdatatype.NONE,
None, None,
True, True,

Some files were not shown because too many files have changed in this diff Show more