Bump dnspython from 2.6.1 to 2.7.0 (#2440)

* Bump dnspython from 2.6.1 to 2.7.0

Bumps [dnspython](https://github.com/rthalley/dnspython) from 2.6.1 to 2.7.0.
- [Release notes](https://github.com/rthalley/dnspython/releases)
- [Changelog](https://github.com/rthalley/dnspython/blob/main/doc/whatsnew.rst)
- [Commits](https://github.com/rthalley/dnspython/compare/v2.6.1...v2.7.0)

---
updated-dependencies:
- dependency-name: dnspython
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>

* Update dnspython==2.7.0

---------

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: JonnyWong16 <9099342+JonnyWong16@users.noreply.github.com>

[skip ci]
This commit is contained in:
dependabot[bot] 2024-11-19 10:00:50 -08:00 committed by GitHub
parent 0836fb902c
commit feca713b76
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
56 changed files with 1382 additions and 665 deletions

View file

@ -26,6 +26,10 @@ class NullContext:
class Socket: # pragma: no cover class Socket: # pragma: no cover
def __init__(self, family: int, type: int):
self.family = family
self.type = type
async def close(self): async def close(self):
pass pass
@ -46,9 +50,6 @@ class Socket: # pragma: no cover
class DatagramSocket(Socket): # pragma: no cover class DatagramSocket(Socket): # pragma: no cover
def __init__(self, family: int):
self.family = family
async def sendto(self, what, destination, timeout): async def sendto(self, what, destination, timeout):
raise NotImplementedError raise NotImplementedError

View file

@ -42,7 +42,7 @@ class _DatagramProtocol:
if exc is None: if exc is None:
# EOF we triggered. Is there a better way to do this? # EOF we triggered. Is there a better way to do this?
try: try:
raise EOFError raise EOFError("EOF")
except EOFError as e: except EOFError as e:
self.recvfrom.set_exception(e) self.recvfrom.set_exception(e)
else: else:
@ -64,7 +64,7 @@ async def _maybe_wait_for(awaitable, timeout):
class DatagramSocket(dns._asyncbackend.DatagramSocket): class DatagramSocket(dns._asyncbackend.DatagramSocket):
def __init__(self, family, transport, protocol): def __init__(self, family, transport, protocol):
super().__init__(family) super().__init__(family, socket.SOCK_DGRAM)
self.transport = transport self.transport = transport
self.protocol = protocol self.protocol = protocol
@ -99,7 +99,7 @@ class DatagramSocket(dns._asyncbackend.DatagramSocket):
class StreamSocket(dns._asyncbackend.StreamSocket): class StreamSocket(dns._asyncbackend.StreamSocket):
def __init__(self, af, reader, writer): def __init__(self, af, reader, writer):
self.family = af super().__init__(af, socket.SOCK_STREAM)
self.reader = reader self.reader = reader
self.writer = writer self.writer = writer
@ -197,7 +197,7 @@ if dns._features.have("doh"):
family=socket.AF_UNSPEC, family=socket.AF_UNSPEC,
**kwargs, **kwargs,
): ):
if resolver is None: if resolver is None and bootstrap_address is None:
# pylint: disable=import-outside-toplevel,redefined-outer-name # pylint: disable=import-outside-toplevel,redefined-outer-name
import dns.asyncresolver import dns.asyncresolver

View file

@ -32,6 +32,9 @@ def _version_check(
package, minimum = requirement.split(">=") package, minimum = requirement.split(">=")
try: try:
version = importlib.metadata.version(package) version = importlib.metadata.version(package)
# This shouldn't happen, but it apparently can.
if version is None:
return False
except Exception: except Exception:
return False return False
t_version = _tuple_from_text(version) t_version = _tuple_from_text(version)
@ -82,10 +85,10 @@ def force(feature: str, enabled: bool) -> None:
_requirements: Dict[str, List[str]] = { _requirements: Dict[str, List[str]] = {
### BEGIN generated requirements ### BEGIN generated requirements
"dnssec": ["cryptography>=41"], "dnssec": ["cryptography>=43"],
"doh": ["httpcore>=1.0.0", "httpx>=0.26.0", "h2>=4.1.0"], "doh": ["httpcore>=1.0.0", "httpx>=0.26.0", "h2>=4.1.0"],
"doq": ["aioquic>=0.9.25"], "doq": ["aioquic>=1.0.0"],
"idna": ["idna>=3.6"], "idna": ["idna>=3.7"],
"trio": ["trio>=0.23"], "trio": ["trio>=0.23"],
"wmi": ["wmi>=1.5.1"], "wmi": ["wmi>=1.5.1"],
### END generated requirements ### END generated requirements

View file

@ -30,13 +30,16 @@ _lltuple = dns.inet.low_level_address_tuple
class DatagramSocket(dns._asyncbackend.DatagramSocket): class DatagramSocket(dns._asyncbackend.DatagramSocket):
def __init__(self, socket): def __init__(self, sock):
super().__init__(socket.family) super().__init__(sock.family, socket.SOCK_DGRAM)
self.socket = socket self.socket = sock
async def sendto(self, what, destination, timeout): async def sendto(self, what, destination, timeout):
with _maybe_timeout(timeout): with _maybe_timeout(timeout):
return await self.socket.sendto(what, destination) if destination is None:
return await self.socket.send(what)
else:
return await self.socket.sendto(what, destination)
raise dns.exception.Timeout( raise dns.exception.Timeout(
timeout=timeout timeout=timeout
) # pragma: no cover lgtm[py/unreachable-statement] ) # pragma: no cover lgtm[py/unreachable-statement]
@ -61,7 +64,7 @@ class DatagramSocket(dns._asyncbackend.DatagramSocket):
class StreamSocket(dns._asyncbackend.StreamSocket): class StreamSocket(dns._asyncbackend.StreamSocket):
def __init__(self, family, stream, tls=False): def __init__(self, family, stream, tls=False):
self.family = family super().__init__(family, socket.SOCK_STREAM)
self.stream = stream self.stream = stream
self.tls = tls self.tls = tls
@ -171,7 +174,7 @@ if dns._features.have("doh"):
family=socket.AF_UNSPEC, family=socket.AF_UNSPEC,
**kwargs, **kwargs,
): ):
if resolver is None: if resolver is None and bootstrap_address is None:
# pylint: disable=import-outside-toplevel,redefined-outer-name # pylint: disable=import-outside-toplevel,redefined-outer-name
import dns.asyncresolver import dns.asyncresolver
@ -205,7 +208,7 @@ class Backend(dns._asyncbackend.Backend):
try: try:
if source: if source:
await s.bind(_lltuple(source, af)) await s.bind(_lltuple(source, af))
if socktype == socket.SOCK_STREAM: if socktype == socket.SOCK_STREAM or destination is not None:
connected = False connected = False
with _maybe_timeout(timeout): with _maybe_timeout(timeout):
await s.connect(_lltuple(destination, af)) await s.connect(_lltuple(destination, af))

View file

@ -19,10 +19,12 @@
import base64 import base64
import contextlib import contextlib
import random
import socket import socket
import struct import struct
import time import time
from typing import Any, Dict, Optional, Tuple, Union import urllib.parse
from typing import Any, Dict, Optional, Tuple, Union, cast
import dns.asyncbackend import dns.asyncbackend
import dns.exception import dns.exception
@ -37,9 +39,11 @@ import dns.transaction
from dns._asyncbackend import NullContext from dns._asyncbackend import NullContext
from dns.query import ( from dns.query import (
BadResponse, BadResponse,
HTTPVersion,
NoDOH, NoDOH,
NoDOQ, NoDOQ,
UDPMode, UDPMode,
_check_status,
_compute_times, _compute_times,
_make_dot_ssl_context, _make_dot_ssl_context,
_matches_destination, _matches_destination,
@ -338,7 +342,7 @@ async def _read_exactly(sock, count, expiration):
while count > 0: while count > 0:
n = await sock.recv(count, _timeout(expiration)) n = await sock.recv(count, _timeout(expiration))
if n == b"": if n == b"":
raise EOFError raise EOFError("EOF")
count = count - len(n) count = count - len(n)
s = s + n s = s + n
return s return s
@ -500,6 +504,20 @@ async def tls(
return response return response
def _maybe_get_resolver(
resolver: Optional["dns.asyncresolver.Resolver"],
) -> "dns.asyncresolver.Resolver":
# We need a separate method for this to avoid overriding the global
# variable "dns" with the as-yet undefined local variable "dns"
# in https().
if resolver is None:
# pylint: disable=import-outside-toplevel,redefined-outer-name
import dns.asyncresolver
resolver = dns.asyncresolver.Resolver()
return resolver
async def https( async def https(
q: dns.message.Message, q: dns.message.Message,
where: str, where: str,
@ -515,7 +533,8 @@ async def https(
verify: Union[bool, str] = True, verify: Union[bool, str] = True,
bootstrap_address: Optional[str] = None, bootstrap_address: Optional[str] = None,
resolver: Optional["dns.asyncresolver.Resolver"] = None, resolver: Optional["dns.asyncresolver.Resolver"] = None,
family: Optional[int] = socket.AF_UNSPEC, family: int = socket.AF_UNSPEC,
http_version: HTTPVersion = HTTPVersion.DEFAULT,
) -> dns.message.Message: ) -> dns.message.Message:
"""Return the response obtained after sending a query via DNS-over-HTTPS. """Return the response obtained after sending a query via DNS-over-HTTPS.
@ -529,26 +548,65 @@ async def https(
parameters, exceptions, and return type of this method. parameters, exceptions, and return type of this method.
""" """
if not have_doh:
raise NoDOH # pragma: no cover
if client and not isinstance(client, httpx.AsyncClient):
raise ValueError("session parameter must be an httpx.AsyncClient")
wire = q.to_wire()
try: try:
af = dns.inet.af_for_address(where) af = dns.inet.af_for_address(where)
except ValueError: except ValueError:
af = None af = None
transport = None
headers = {"accept": "application/dns-message"}
if af is not None and dns.inet.is_address(where): if af is not None and dns.inet.is_address(where):
if af == socket.AF_INET: if af == socket.AF_INET:
url = "https://{}:{}{}".format(where, port, path) url = f"https://{where}:{port}{path}"
elif af == socket.AF_INET6: elif af == socket.AF_INET6:
url = "https://[{}]:{}{}".format(where, port, path) url = f"https://[{where}]:{port}{path}"
else: else:
url = where url = where
extensions = {}
if bootstrap_address is None:
# pylint: disable=possibly-used-before-assignment
parsed = urllib.parse.urlparse(url)
if parsed.hostname is None:
raise ValueError("no hostname in URL")
if dns.inet.is_address(parsed.hostname):
bootstrap_address = parsed.hostname
extensions["sni_hostname"] = parsed.hostname
if parsed.port is not None:
port = parsed.port
if http_version == HTTPVersion.H3 or (
http_version == HTTPVersion.DEFAULT and not have_doh
):
if bootstrap_address is None:
resolver = _maybe_get_resolver(resolver)
assert parsed.hostname is not None # for mypy
answers = await resolver.resolve_name(parsed.hostname, family)
bootstrap_address = random.choice(list(answers.addresses()))
return await _http3(
q,
bootstrap_address,
url,
timeout,
port,
source,
source_port,
one_rr_per_rrset,
ignore_trailing,
verify=verify,
post=post,
)
if not have_doh:
raise NoDOH # pragma: no cover
# pylint: disable=possibly-used-before-assignment
if client and not isinstance(client, httpx.AsyncClient):
raise ValueError("session parameter must be an httpx.AsyncClient")
# pylint: enable=possibly-used-before-assignment
wire = q.to_wire()
headers = {"accept": "application/dns-message"}
h1 = http_version in (HTTPVersion.H1, HTTPVersion.DEFAULT)
h2 = http_version in (HTTPVersion.H2, HTTPVersion.DEFAULT)
backend = dns.asyncbackend.get_default_backend() backend = dns.asyncbackend.get_default_backend()
if source is None: if source is None:
@ -557,24 +615,23 @@ async def https(
else: else:
local_address = source local_address = source
local_port = source_port local_port = source_port
transport = backend.get_transport_class()(
local_address=local_address,
http1=True,
http2=True,
verify=verify,
local_port=local_port,
bootstrap_address=bootstrap_address,
resolver=resolver,
family=family,
)
if client: if client:
cm: contextlib.AbstractAsyncContextManager = NullContext(client) cm: contextlib.AbstractAsyncContextManager = NullContext(client)
else: else:
cm = httpx.AsyncClient( transport = backend.get_transport_class()(
http1=True, http2=True, verify=verify, transport=transport local_address=local_address,
http1=h1,
http2=h2,
verify=verify,
local_port=local_port,
bootstrap_address=bootstrap_address,
resolver=resolver,
family=family,
) )
cm = httpx.AsyncClient(http1=h1, http2=h2, verify=verify, transport=transport)
async with cm as the_client: async with cm as the_client:
# see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH # see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH
# GET and POST examples # GET and POST examples
@ -586,23 +643,33 @@ async def https(
} }
) )
response = await backend.wait_for( response = await backend.wait_for(
the_client.post(url, headers=headers, content=wire), timeout the_client.post(
url,
headers=headers,
content=wire,
extensions=extensions,
),
timeout,
) )
else: else:
wire = base64.urlsafe_b64encode(wire).rstrip(b"=") wire = base64.urlsafe_b64encode(wire).rstrip(b"=")
twire = wire.decode() # httpx does a repr() if we give it bytes twire = wire.decode() # httpx does a repr() if we give it bytes
response = await backend.wait_for( response = await backend.wait_for(
the_client.get(url, headers=headers, params={"dns": twire}), timeout the_client.get(
url,
headers=headers,
params={"dns": twire},
extensions=extensions,
),
timeout,
) )
# see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH # see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH
# status codes # status codes
if response.status_code < 200 or response.status_code > 299: if response.status_code < 200 or response.status_code > 299:
raise ValueError( raise ValueError(
"{} responded with status code {}" f"{where} responded with status code {response.status_code}"
"\nResponse body: {!r}".format( f"\nResponse body: {response.content!r}"
where, response.status_code, response.content
)
) )
r = dns.message.from_wire( r = dns.message.from_wire(
response.content, response.content,
@ -617,6 +684,181 @@ async def https(
return r return r
async def _http3(
q: dns.message.Message,
where: str,
url: str,
timeout: Optional[float] = None,
port: int = 853,
source: Optional[str] = None,
source_port: int = 0,
one_rr_per_rrset: bool = False,
ignore_trailing: bool = False,
verify: Union[bool, str] = True,
backend: Optional[dns.asyncbackend.Backend] = None,
hostname: Optional[str] = None,
post: bool = True,
) -> dns.message.Message:
if not dns.quic.have_quic:
raise NoDOH("DNS-over-HTTP3 is not available.") # pragma: no cover
url_parts = urllib.parse.urlparse(url)
hostname = url_parts.hostname
if url_parts.port is not None:
port = url_parts.port
q.id = 0
wire = q.to_wire()
(cfactory, mfactory) = dns.quic.factories_for_backend(backend)
async with cfactory() as context:
async with mfactory(
context, verify_mode=verify, server_name=hostname, h3=True
) as the_manager:
the_connection = the_manager.connect(where, port, source, source_port)
(start, expiration) = _compute_times(timeout)
stream = await the_connection.make_stream(timeout)
async with stream:
# note that send_h3() does not need await
stream.send_h3(url, wire, post)
wire = await stream.receive(_remaining(expiration))
_check_status(stream.headers(), where, wire)
finish = time.time()
r = dns.message.from_wire(
wire,
keyring=q.keyring,
request_mac=q.request_mac,
one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing,
)
r.time = max(finish - start, 0.0)
if not q.is_response(r):
raise BadResponse
return r
async def quic(
q: dns.message.Message,
where: str,
timeout: Optional[float] = None,
port: int = 853,
source: Optional[str] = None,
source_port: int = 0,
one_rr_per_rrset: bool = False,
ignore_trailing: bool = False,
connection: Optional[dns.quic.AsyncQuicConnection] = None,
verify: Union[bool, str] = True,
backend: Optional[dns.asyncbackend.Backend] = None,
hostname: Optional[str] = None,
server_hostname: Optional[str] = None,
) -> dns.message.Message:
"""Return the response obtained after sending an asynchronous query via
DNS-over-QUIC.
*backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``,
the default, then dnspython will use the default backend.
See :py:func:`dns.query.quic()` for the documentation of the other
parameters, exceptions, and return type of this method.
"""
if not dns.quic.have_quic:
raise NoDOQ("DNS-over-QUIC is not available.") # pragma: no cover
if server_hostname is not None and hostname is None:
hostname = server_hostname
q.id = 0
wire = q.to_wire()
the_connection: dns.quic.AsyncQuicConnection
if connection:
cfactory = dns.quic.null_factory
mfactory = dns.quic.null_factory
the_connection = connection
else:
(cfactory, mfactory) = dns.quic.factories_for_backend(backend)
async with cfactory() as context:
async with mfactory(
context,
verify_mode=verify,
server_name=server_hostname,
) as the_manager:
if not connection:
the_connection = the_manager.connect(where, port, source, source_port)
(start, expiration) = _compute_times(timeout)
stream = await the_connection.make_stream(timeout)
async with stream:
await stream.send(wire, True)
wire = await stream.receive(_remaining(expiration))
finish = time.time()
r = dns.message.from_wire(
wire,
keyring=q.keyring,
request_mac=q.request_mac,
one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing,
)
r.time = max(finish - start, 0.0)
if not q.is_response(r):
raise BadResponse
return r
async def _inbound_xfr(
txn_manager: dns.transaction.TransactionManager,
s: dns.asyncbackend.Socket,
query: dns.message.Message,
serial: Optional[int],
timeout: Optional[float],
expiration: float,
) -> Any:
"""Given a socket, does the zone transfer."""
rdtype = query.question[0].rdtype
is_ixfr = rdtype == dns.rdatatype.IXFR
origin = txn_manager.from_wire_origin()
wire = query.to_wire()
is_udp = s.type == socket.SOCK_DGRAM
if is_udp:
udp_sock = cast(dns.asyncbackend.DatagramSocket, s)
await udp_sock.sendto(wire, None, _timeout(expiration))
else:
tcp_sock = cast(dns.asyncbackend.StreamSocket, s)
tcpmsg = struct.pack("!H", len(wire)) + wire
await tcp_sock.sendall(tcpmsg, expiration)
with dns.xfr.Inbound(txn_manager, rdtype, serial, is_udp) as inbound:
done = False
tsig_ctx = None
while not done:
(_, mexpiration) = _compute_times(timeout)
if mexpiration is None or (
expiration is not None and mexpiration > expiration
):
mexpiration = expiration
if is_udp:
timeout = _timeout(mexpiration)
(rwire, _) = await udp_sock.recvfrom(65535, timeout)
else:
ldata = await _read_exactly(tcp_sock, 2, mexpiration)
(l,) = struct.unpack("!H", ldata)
rwire = await _read_exactly(tcp_sock, l, mexpiration)
r = dns.message.from_wire(
rwire,
keyring=query.keyring,
request_mac=query.mac,
xfr=True,
origin=origin,
tsig_ctx=tsig_ctx,
multi=(not is_udp),
one_rr_per_rrset=is_ixfr,
)
done = inbound.process_message(r)
yield r
tsig_ctx = r.tsig_ctx
if query.keyring and not r.had_tsig:
raise dns.exception.FormError("missing TSIG")
async def inbound_xfr( async def inbound_xfr(
where: str, where: str,
txn_manager: dns.transaction.TransactionManager, txn_manager: dns.transaction.TransactionManager,
@ -642,139 +884,30 @@ async def inbound_xfr(
(query, serial) = dns.xfr.make_query(txn_manager) (query, serial) = dns.xfr.make_query(txn_manager)
else: else:
serial = dns.xfr.extract_serial_from_query(query) serial = dns.xfr.extract_serial_from_query(query)
rdtype = query.question[0].rdtype
is_ixfr = rdtype == dns.rdatatype.IXFR
origin = txn_manager.from_wire_origin()
wire = query.to_wire()
af = dns.inet.af_for_address(where) af = dns.inet.af_for_address(where)
stuple = _source_tuple(af, source, source_port) stuple = _source_tuple(af, source, source_port)
dtuple = (where, port) dtuple = (where, port)
if not backend:
backend = dns.asyncbackend.get_default_backend()
(_, expiration) = _compute_times(lifetime) (_, expiration) = _compute_times(lifetime)
retry = True if query.question[0].rdtype == dns.rdatatype.IXFR and udp_mode != UDPMode.NEVER:
while retry:
retry = False
if is_ixfr and udp_mode != UDPMode.NEVER:
sock_type = socket.SOCK_DGRAM
is_udp = True
else:
sock_type = socket.SOCK_STREAM
is_udp = False
if not backend:
backend = dns.asyncbackend.get_default_backend()
s = await backend.make_socket( s = await backend.make_socket(
af, sock_type, 0, stuple, dtuple, _timeout(expiration) af, socket.SOCK_DGRAM, 0, stuple, dtuple, _timeout(expiration)
) )
async with s: async with s:
if is_udp: try:
await s.sendto(wire, dtuple, _timeout(expiration)) async for _ in _inbound_xfr(
else: txn_manager, s, query, serial, timeout, expiration
tcpmsg = struct.pack("!H", len(wire)) + wire ):
await s.sendall(tcpmsg, expiration) pass
with dns.xfr.Inbound(txn_manager, rdtype, serial, is_udp) as inbound: return
done = False except dns.xfr.UseTCP:
tsig_ctx = None if udp_mode == UDPMode.ONLY:
while not done: raise
(_, mexpiration) = _compute_times(timeout)
if mexpiration is None or (
expiration is not None and mexpiration > expiration
):
mexpiration = expiration
if is_udp:
destination = _lltuple((where, port), af)
while True:
timeout = _timeout(mexpiration)
(rwire, from_address) = await s.recvfrom(65535, timeout)
if _matches_destination(
af, from_address, destination, True
):
break
else:
ldata = await _read_exactly(s, 2, mexpiration)
(l,) = struct.unpack("!H", ldata)
rwire = await _read_exactly(s, l, mexpiration)
is_ixfr = rdtype == dns.rdatatype.IXFR
r = dns.message.from_wire(
rwire,
keyring=query.keyring,
request_mac=query.mac,
xfr=True,
origin=origin,
tsig_ctx=tsig_ctx,
multi=(not is_udp),
one_rr_per_rrset=is_ixfr,
)
try:
done = inbound.process_message(r)
except dns.xfr.UseTCP:
assert is_udp # should not happen if we used TCP!
if udp_mode == UDPMode.ONLY:
raise
done = True
retry = True
udp_mode = UDPMode.NEVER
continue
tsig_ctx = r.tsig_ctx
if not retry and query.keyring and not r.had_tsig:
raise dns.exception.FormError("missing TSIG")
s = await backend.make_socket(
async def quic( af, socket.SOCK_STREAM, 0, stuple, dtuple, _timeout(expiration)
q: dns.message.Message, )
where: str, async with s:
timeout: Optional[float] = None, async for _ in _inbound_xfr(txn_manager, s, query, serial, timeout, expiration):
port: int = 853, pass
source: Optional[str] = None,
source_port: int = 0,
one_rr_per_rrset: bool = False,
ignore_trailing: bool = False,
connection: Optional[dns.quic.AsyncQuicConnection] = None,
verify: Union[bool, str] = True,
backend: Optional[dns.asyncbackend.Backend] = None,
server_hostname: Optional[str] = None,
) -> dns.message.Message:
"""Return the response obtained after sending an asynchronous query via
DNS-over-QUIC.
*backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``,
the default, then dnspython will use the default backend.
See :py:func:`dns.query.quic()` for the documentation of the other
parameters, exceptions, and return type of this method.
"""
if not dns.quic.have_quic:
raise NoDOQ("DNS-over-QUIC is not available.") # pragma: no cover
q.id = 0
wire = q.to_wire()
the_connection: dns.quic.AsyncQuicConnection
if connection:
cfactory = dns.quic.null_factory
mfactory = dns.quic.null_factory
the_connection = connection
else:
(cfactory, mfactory) = dns.quic.factories_for_backend(backend)
async with cfactory() as context:
async with mfactory(
context, verify_mode=verify, server_name=server_hostname
) as the_manager:
if not connection:
the_connection = the_manager.connect(where, port, source, source_port)
(start, expiration) = _compute_times(timeout)
stream = await the_connection.make_stream(timeout)
async with stream:
await stream.send(wire, True)
wire = await stream.receive(_remaining(expiration))
finish = time.time()
r = dns.message.from_wire(
wire,
keyring=q.keyring,
request_mac=q.request_mac,
one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing,
)
r.time = max(finish - start, 0.0)
if not q.is_response(r):
raise BadResponse
return r

View file

@ -118,6 +118,7 @@ def key_id(key: Union[DNSKEY, CDNSKEY]) -> int:
""" """
rdata = key.to_wire() rdata = key.to_wire()
assert rdata is not None # for mypy
if key.algorithm == Algorithm.RSAMD5: if key.algorithm == Algorithm.RSAMD5:
return (rdata[-3] << 8) + rdata[-2] return (rdata[-3] << 8) + rdata[-2]
else: else:
@ -224,7 +225,7 @@ def make_ds(
if isinstance(algorithm, str): if isinstance(algorithm, str):
algorithm = DSDigest[algorithm.upper()] algorithm = DSDigest[algorithm.upper()]
except Exception: except Exception:
raise UnsupportedAlgorithm('unsupported algorithm "%s"' % algorithm) raise UnsupportedAlgorithm(f'unsupported algorithm "{algorithm}"')
if validating: if validating:
check = policy.ok_to_validate_ds check = policy.ok_to_validate_ds
else: else:
@ -240,14 +241,15 @@ def make_ds(
elif algorithm == DSDigest.SHA384: elif algorithm == DSDigest.SHA384:
dshash = hashlib.sha384() dshash = hashlib.sha384()
else: else:
raise UnsupportedAlgorithm('unsupported algorithm "%s"' % algorithm) raise UnsupportedAlgorithm(f'unsupported algorithm "{algorithm}"')
if isinstance(name, str): if isinstance(name, str):
name = dns.name.from_text(name, origin) name = dns.name.from_text(name, origin)
wire = name.canonicalize().to_wire() wire = name.canonicalize().to_wire()
assert wire is not None kwire = key.to_wire(origin=origin)
assert wire is not None and kwire is not None # for mypy
dshash.update(wire) dshash.update(wire)
dshash.update(key.to_wire(origin=origin)) dshash.update(kwire)
digest = dshash.digest() digest = dshash.digest()
dsrdata = struct.pack("!HBB", key_id(key), key.algorithm, algorithm) + digest dsrdata = struct.pack("!HBB", key_id(key), key.algorithm, algorithm) + digest
@ -323,6 +325,7 @@ def _get_rrname_rdataset(
def _validate_signature(sig: bytes, data: bytes, key: DNSKEY) -> None: def _validate_signature(sig: bytes, data: bytes, key: DNSKEY) -> None:
# pylint: disable=possibly-used-before-assignment
public_cls = get_algorithm_cls_from_dnskey(key).public_cls public_cls = get_algorithm_cls_from_dnskey(key).public_cls
try: try:
public_key = public_cls.from_dnskey(key) public_key = public_cls.from_dnskey(key)
@ -387,6 +390,7 @@ def _validate_rrsig(
data = _make_rrsig_signature_data(rrset, rrsig, origin) data = _make_rrsig_signature_data(rrset, rrsig, origin)
# pylint: disable=possibly-used-before-assignment
for candidate_key in candidate_keys: for candidate_key in candidate_keys:
if not policy.ok_to_validate(candidate_key): if not policy.ok_to_validate(candidate_key):
continue continue
@ -484,6 +488,7 @@ def _sign(
verify: bool = False, verify: bool = False,
policy: Optional[Policy] = None, policy: Optional[Policy] = None,
origin: Optional[dns.name.Name] = None, origin: Optional[dns.name.Name] = None,
deterministic: bool = True,
) -> RRSIG: ) -> RRSIG:
"""Sign RRset using private key. """Sign RRset using private key.
@ -523,6 +528,10 @@ def _sign(
names in the rrset (including its owner name) must be absolute; otherwise the names in the rrset (including its owner name) must be absolute; otherwise the
specified origin will be used to make names absolute when signing. specified origin will be used to make names absolute when signing.
*deterministic*, a ``bool``. If ``True``, the default, use deterministic
(reproducible) signatures when supported by the algorithm used for signing.
Currently, this only affects ECDSA.
Raises ``DeniedByPolicy`` if the signature is denied by policy. Raises ``DeniedByPolicy`` if the signature is denied by policy.
""" """
@ -580,6 +589,7 @@ def _sign(
data = dns.dnssec._make_rrsig_signature_data(rrset, rrsig_template, origin) data = dns.dnssec._make_rrsig_signature_data(rrset, rrsig_template, origin)
# pylint: disable=possibly-used-before-assignment
if isinstance(private_key, GenericPrivateKey): if isinstance(private_key, GenericPrivateKey):
signing_key = private_key signing_key = private_key
else: else:
@ -589,7 +599,7 @@ def _sign(
except UnsupportedAlgorithm: except UnsupportedAlgorithm:
raise TypeError("Unsupported key algorithm") raise TypeError("Unsupported key algorithm")
signature = signing_key.sign(data, verify) signature = signing_key.sign(data, verify, deterministic)
return cast(RRSIG, rrsig_template.replace(signature=signature)) return cast(RRSIG, rrsig_template.replace(signature=signature))
@ -629,7 +639,9 @@ def _make_rrsig_signature_data(
rrname, rdataset = _get_rrname_rdataset(rrset) rrname, rdataset = _get_rrname_rdataset(rrset)
data = b"" data = b""
data += rrsig.to_wire(origin=signer)[:18] wire = rrsig.to_wire(origin=signer)
assert wire is not None # for mypy
data += wire[:18]
data += rrsig.signer.to_digestable(signer) data += rrsig.signer.to_digestable(signer)
# Derelativize the name before considering labels. # Derelativize the name before considering labels.
@ -686,6 +698,7 @@ def _make_dnskey(
algorithm = Algorithm.make(algorithm) algorithm = Algorithm.make(algorithm)
# pylint: disable=possibly-used-before-assignment
if isinstance(public_key, GenericPublicKey): if isinstance(public_key, GenericPublicKey):
return public_key.to_dnskey(flags=flags, protocol=protocol) return public_key.to_dnskey(flags=flags, protocol=protocol)
else: else:
@ -832,7 +845,7 @@ def make_ds_rdataset(
if isinstance(algorithm, str): if isinstance(algorithm, str):
algorithm = DSDigest[algorithm.upper()] algorithm = DSDigest[algorithm.upper()]
except Exception: except Exception:
raise UnsupportedAlgorithm('unsupported algorithm "%s"' % algorithm) raise UnsupportedAlgorithm(f'unsupported algorithm "{algorithm}"')
_algorithms.add(algorithm) _algorithms.add(algorithm)
if rdataset.rdtype == dns.rdatatype.CDS: if rdataset.rdtype == dns.rdatatype.CDS:
@ -950,6 +963,7 @@ def default_rrset_signer(
lifetime: Optional[int] = None, lifetime: Optional[int] = None,
policy: Optional[Policy] = None, policy: Optional[Policy] = None,
origin: Optional[dns.name.Name] = None, origin: Optional[dns.name.Name] = None,
deterministic: bool = True,
) -> None: ) -> None:
"""Default RRset signer""" """Default RRset signer"""
@ -975,6 +989,7 @@ def default_rrset_signer(
signer=signer, signer=signer,
policy=policy, policy=policy,
origin=origin, origin=origin,
deterministic=deterministic,
) )
txn.add(rrset.name, rrset.ttl, rrsig) txn.add(rrset.name, rrset.ttl, rrsig)
@ -991,6 +1006,7 @@ def sign_zone(
nsec3: Optional[NSEC3PARAM] = None, nsec3: Optional[NSEC3PARAM] = None,
rrset_signer: Optional[RRsetSigner] = None, rrset_signer: Optional[RRsetSigner] = None,
policy: Optional[Policy] = None, policy: Optional[Policy] = None,
deterministic: bool = True,
) -> None: ) -> None:
"""Sign zone. """Sign zone.
@ -1030,6 +1046,10 @@ def sign_zone(
function requires two arguments: transaction and RRset. If the not specified, function requires two arguments: transaction and RRset. If the not specified,
``dns.dnssec.default_rrset_signer`` will be used. ``dns.dnssec.default_rrset_signer`` will be used.
*deterministic*, a ``bool``. If ``True``, the default, use deterministic
(reproducible) signatures when supported by the algorithm used for signing.
Currently, this only affects ECDSA.
Returns ``None``. Returns ``None``.
""" """
@ -1056,6 +1076,9 @@ def sign_zone(
else: else:
cm = zone.writer() cm = zone.writer()
if zone.origin is None:
raise ValueError("no zone origin")
with cm as _txn: with cm as _txn:
if add_dnskey: if add_dnskey:
if dnskey_ttl is None: if dnskey_ttl is None:
@ -1081,6 +1104,7 @@ def sign_zone(
lifetime=lifetime, lifetime=lifetime,
policy=policy, policy=policy,
origin=zone.origin, origin=zone.origin,
deterministic=deterministic,
) )
return _sign_zone_nsec(zone, _txn, _rrset_signer) return _sign_zone_nsec(zone, _txn, _rrset_signer)

View file

@ -26,6 +26,7 @@ AlgorithmPrefix = Optional[Union[bytes, dns.name.Name]]
algorithms: Dict[Tuple[Algorithm, AlgorithmPrefix], Type[GenericPrivateKey]] = {} algorithms: Dict[Tuple[Algorithm, AlgorithmPrefix], Type[GenericPrivateKey]] = {}
if _have_cryptography: if _have_cryptography:
# pylint: disable=possibly-used-before-assignment
algorithms.update( algorithms.update(
{ {
(Algorithm.RSAMD5, None): PrivateRSAMD5, (Algorithm.RSAMD5, None): PrivateRSAMD5,
@ -59,7 +60,7 @@ def get_algorithm_cls(
if cls: if cls:
return cls return cls
raise UnsupportedAlgorithm( raise UnsupportedAlgorithm(
'algorithm "%s" not supported by dnspython' % Algorithm.to_text(algorithm) f'algorithm "{Algorithm.to_text(algorithm)}" not supported by dnspython'
) )

View file

@ -65,7 +65,12 @@ class GenericPrivateKey(ABC):
pass pass
@abstractmethod @abstractmethod
def sign(self, data: bytes, verify: bool = False) -> bytes: def sign(
self,
data: bytes,
verify: bool = False,
deterministic: bool = True,
) -> bytes:
"""Sign DNSSEC data""" """Sign DNSSEC data"""
@abstractmethod @abstractmethod

View file

@ -68,7 +68,12 @@ class PrivateDSA(CryptographyPrivateKey):
key_cls = dsa.DSAPrivateKey key_cls = dsa.DSAPrivateKey
public_cls = PublicDSA public_cls = PublicDSA
def sign(self, data: bytes, verify: bool = False) -> bytes: def sign(
self,
data: bytes,
verify: bool = False,
deterministic: bool = True,
) -> bytes:
"""Sign using a private key per RFC 2536, section 3.""" """Sign using a private key per RFC 2536, section 3."""
public_dsa_key = self.key.public_key() public_dsa_key = self.key.public_key()
if public_dsa_key.key_size > 1024: if public_dsa_key.key_size > 1024:

View file

@ -47,9 +47,17 @@ class PrivateECDSA(CryptographyPrivateKey):
key_cls = ec.EllipticCurvePrivateKey key_cls = ec.EllipticCurvePrivateKey
public_cls = PublicECDSA public_cls = PublicECDSA
def sign(self, data: bytes, verify: bool = False) -> bytes: def sign(
self,
data: bytes,
verify: bool = False,
deterministic: bool = True,
) -> bytes:
"""Sign using a private key per RFC 6605, section 4.""" """Sign using a private key per RFC 6605, section 4."""
der_signature = self.key.sign(data, ec.ECDSA(self.public_cls.chosen_hash)) algorithm = ec.ECDSA(
self.public_cls.chosen_hash, deterministic_signing=deterministic
)
der_signature = self.key.sign(data, algorithm)
dsa_r, dsa_s = utils.decode_dss_signature(der_signature) dsa_r, dsa_s = utils.decode_dss_signature(der_signature)
signature = int.to_bytes( signature = int.to_bytes(
dsa_r, length=self.public_cls.octets, byteorder="big" dsa_r, length=self.public_cls.octets, byteorder="big"

View file

@ -29,7 +29,12 @@ class PublicEDDSA(CryptographyPublicKey):
class PrivateEDDSA(CryptographyPrivateKey): class PrivateEDDSA(CryptographyPrivateKey):
public_cls: Type[PublicEDDSA] public_cls: Type[PublicEDDSA]
def sign(self, data: bytes, verify: bool = False) -> bytes: def sign(
self,
data: bytes,
verify: bool = False,
deterministic: bool = True,
) -> bytes:
"""Sign using a private key per RFC 8080, section 4.""" """Sign using a private key per RFC 8080, section 4."""
signature = self.key.sign(data) signature = self.key.sign(data)
if verify: if verify:

View file

@ -56,7 +56,12 @@ class PrivateRSA(CryptographyPrivateKey):
public_cls = PublicRSA public_cls = PublicRSA
default_public_exponent = 65537 default_public_exponent = 65537
def sign(self, data: bytes, verify: bool = False) -> bytes: def sign(
self,
data: bytes,
verify: bool = False,
deterministic: bool = True,
) -> bytes:
"""Sign using a private key per RFC 3110, section 3.""" """Sign using a private key per RFC 3110, section 3."""
signature = self.key.sign(data, padding.PKCS1v15(), self.public_cls.chosen_hash) signature = self.key.sign(data, padding.PKCS1v15(), self.public_cls.chosen_hash)
if verify: if verify:

View file

@ -52,6 +52,8 @@ class OptionType(dns.enum.IntEnum):
CHAIN = 13 CHAIN = 13
#: EDE (extended-dns-error) #: EDE (extended-dns-error)
EDE = 15 EDE = 15
#: REPORTCHANNEL
REPORTCHANNEL = 18
@classmethod @classmethod
def _maximum(cls): def _maximum(cls):
@ -222,7 +224,7 @@ class ECSOption(Option): # lgtm[py/missing-equals]
self.addrdata = self.addrdata[:-1] + last self.addrdata = self.addrdata[:-1] + last
def to_text(self) -> str: def to_text(self) -> str:
return "ECS {}/{} scope/{}".format(self.address, self.srclen, self.scopelen) return f"ECS {self.address}/{self.srclen} scope/{self.scopelen}"
@staticmethod @staticmethod
def from_text(text: str) -> Option: def from_text(text: str) -> Option:
@ -255,10 +257,10 @@ class ECSOption(Option): # lgtm[py/missing-equals]
ecs_text = tokens[0] ecs_text = tokens[0]
elif len(tokens) == 2: elif len(tokens) == 2:
if tokens[0] != optional_prefix: if tokens[0] != optional_prefix:
raise ValueError('could not parse ECS from "{}"'.format(text)) raise ValueError(f'could not parse ECS from "{text}"')
ecs_text = tokens[1] ecs_text = tokens[1]
else: else:
raise ValueError('could not parse ECS from "{}"'.format(text)) raise ValueError(f'could not parse ECS from "{text}"')
n_slashes = ecs_text.count("/") n_slashes = ecs_text.count("/")
if n_slashes == 1: if n_slashes == 1:
address, tsrclen = ecs_text.split("/") address, tsrclen = ecs_text.split("/")
@ -266,18 +268,16 @@ class ECSOption(Option): # lgtm[py/missing-equals]
elif n_slashes == 2: elif n_slashes == 2:
address, tsrclen, tscope = ecs_text.split("/") address, tsrclen, tscope = ecs_text.split("/")
else: else:
raise ValueError('could not parse ECS from "{}"'.format(text)) raise ValueError(f'could not parse ECS from "{text}"')
try: try:
scope = int(tscope) scope = int(tscope)
except ValueError: except ValueError:
raise ValueError( raise ValueError("invalid scope " + f'"{tscope}": scope must be an integer')
"invalid scope " + '"{}": scope must be an integer'.format(tscope)
)
try: try:
srclen = int(tsrclen) srclen = int(tsrclen)
except ValueError: except ValueError:
raise ValueError( raise ValueError(
"invalid srclen " + '"{}": srclen must be an integer'.format(tsrclen) "invalid srclen " + f'"{tsrclen}": srclen must be an integer'
) )
return ECSOption(address, srclen, scope) return ECSOption(address, srclen, scope)
@ -430,10 +430,65 @@ class NSIDOption(Option):
return cls(parser.get_remaining()) return cls(parser.get_remaining())
class CookieOption(Option):
def __init__(self, client: bytes, server: bytes):
super().__init__(dns.edns.OptionType.COOKIE)
self.client = client
self.server = server
if len(client) != 8:
raise ValueError("client cookie must be 8 bytes")
if len(server) != 0 and (len(server) < 8 or len(server) > 32):
raise ValueError("server cookie must be empty or between 8 and 32 bytes")
def to_wire(self, file: Any = None) -> Optional[bytes]:
if file:
file.write(self.client)
if len(self.server) > 0:
file.write(self.server)
return None
else:
return self.client + self.server
def to_text(self) -> str:
client = binascii.hexlify(self.client).decode()
if len(self.server) > 0:
server = binascii.hexlify(self.server).decode()
else:
server = ""
return f"COOKIE {client}{server}"
@classmethod
def from_wire_parser(
cls, otype: Union[OptionType, str], parser: dns.wire.Parser
) -> Option:
return cls(parser.get_bytes(8), parser.get_remaining())
class ReportChannelOption(Option):
# RFC 9567
def __init__(self, agent_domain: dns.name.Name):
super().__init__(OptionType.REPORTCHANNEL)
self.agent_domain = agent_domain
def to_wire(self, file: Any = None) -> Optional[bytes]:
return self.agent_domain.to_wire(file)
def to_text(self) -> str:
return "REPORTCHANNEL " + self.agent_domain.to_text()
@classmethod
def from_wire_parser(
cls, otype: Union[OptionType, str], parser: dns.wire.Parser
) -> Option:
return cls(parser.get_name())
_type_to_class: Dict[OptionType, Any] = { _type_to_class: Dict[OptionType, Any] = {
OptionType.ECS: ECSOption, OptionType.ECS: ECSOption,
OptionType.EDE: EDEOption, OptionType.EDE: EDEOption,
OptionType.NSID: NSIDOption, OptionType.NSID: NSIDOption,
OptionType.COOKIE: CookieOption,
OptionType.REPORTCHANNEL: ReportChannelOption,
} }
@ -512,5 +567,6 @@ KEEPALIVE = OptionType.KEEPALIVE
PADDING = OptionType.PADDING PADDING = OptionType.PADDING
CHAIN = OptionType.CHAIN CHAIN = OptionType.CHAIN
EDE = OptionType.EDE EDE = OptionType.EDE
REPORTCHANNEL = OptionType.REPORTCHANNEL
### END generated OptionType constants ### END generated OptionType constants

View file

@ -81,7 +81,7 @@ class DNSException(Exception):
if kwargs: if kwargs:
assert ( assert (
set(kwargs.keys()) == self.supp_kwargs set(kwargs.keys()) == self.supp_kwargs
), "following set of keyword args is required: %s" % (self.supp_kwargs) ), f"following set of keyword args is required: {self.supp_kwargs}"
return kwargs return kwargs
def _fmt_kwargs(self, **kwargs): def _fmt_kwargs(self, **kwargs):

View file

@ -54,7 +54,7 @@ def from_text(text: str) -> Tuple[int, int, int]:
elif c.isdigit(): elif c.isdigit():
cur += c cur += c
else: else:
raise dns.exception.SyntaxError("Could not parse %s" % (c)) raise dns.exception.SyntaxError(f"Could not parse {c}")
if state == 0: if state == 0:
raise dns.exception.SyntaxError("no stop value specified") raise dns.exception.SyntaxError("no stop value specified")

View file

@ -143,9 +143,7 @@ def inet_aton(text: Union[str, bytes], ignore_scope: bool = False) -> bytes:
if m is not None: if m is not None:
b = dns.ipv4.inet_aton(m.group(2)) b = dns.ipv4.inet_aton(m.group(2))
btext = ( btext = (
"{}:{:02x}{:02x}:{:02x}{:02x}".format( f"{m.group(1).decode()}:{b[0]:02x}{b[1]:02x}:{b[2]:02x}{b[3]:02x}"
m.group(1).decode(), b[0], b[1], b[2], b[3]
)
).encode() ).encode()
# #
# Try to turn '::<whatever>' into ':<whatever>'; if no match try to # Try to turn '::<whatever>' into ':<whatever>'; if no match try to

View file

@ -18,9 +18,10 @@
"""DNS Messages""" """DNS Messages"""
import contextlib import contextlib
import enum
import io import io
import time import time
from typing import Any, Dict, List, Optional, Tuple, Union from typing import Any, Dict, List, Optional, Tuple, Union, cast
import dns.edns import dns.edns
import dns.entropy import dns.entropy
@ -161,6 +162,7 @@ class Message:
self.index: IndexType = {} self.index: IndexType = {}
self.errors: List[MessageError] = [] self.errors: List[MessageError] = []
self.time = 0.0 self.time = 0.0
self.wire: Optional[bytes] = None
@property @property
def question(self) -> List[dns.rrset.RRset]: def question(self) -> List[dns.rrset.RRset]:
@ -220,16 +222,16 @@ class Message:
s = io.StringIO() s = io.StringIO()
s.write("id %d\n" % self.id) s.write("id %d\n" % self.id)
s.write("opcode %s\n" % dns.opcode.to_text(self.opcode())) s.write(f"opcode {dns.opcode.to_text(self.opcode())}\n")
s.write("rcode %s\n" % dns.rcode.to_text(self.rcode())) s.write(f"rcode {dns.rcode.to_text(self.rcode())}\n")
s.write("flags %s\n" % dns.flags.to_text(self.flags)) s.write(f"flags {dns.flags.to_text(self.flags)}\n")
if self.edns >= 0: if self.edns >= 0:
s.write("edns %s\n" % self.edns) s.write(f"edns {self.edns}\n")
if self.ednsflags != 0: if self.ednsflags != 0:
s.write("eflags %s\n" % dns.flags.edns_to_text(self.ednsflags)) s.write(f"eflags {dns.flags.edns_to_text(self.ednsflags)}\n")
s.write("payload %d\n" % self.payload) s.write("payload %d\n" % self.payload)
for opt in self.options: for opt in self.options:
s.write("option %s\n" % opt.to_text()) s.write(f"option {opt.to_text()}\n")
for name, which in self._section_enum.__members__.items(): for name, which in self._section_enum.__members__.items():
s.write(f";{name}\n") s.write(f";{name}\n")
for rrset in self.section_from_number(which): for rrset in self.section_from_number(which):
@ -645,6 +647,7 @@ class Message:
if multi: if multi:
self.tsig_ctx = ctx self.tsig_ctx = ctx
wire = r.get_wire() wire = r.get_wire()
self.wire = wire
if prepend_length: if prepend_length:
wire = len(wire).to_bytes(2, "big") + wire wire = len(wire).to_bytes(2, "big") + wire
return wire return wire
@ -912,6 +915,14 @@ class Message:
self.flags &= 0x87FF self.flags &= 0x87FF
self.flags |= dns.opcode.to_flags(opcode) self.flags |= dns.opcode.to_flags(opcode)
def get_options(self, otype: dns.edns.OptionType) -> List[dns.edns.Option]:
"""Return the list of options of the specified type."""
return [option for option in self.options if option.otype == otype]
def extended_errors(self) -> List[dns.edns.EDEOption]:
"""Return the list of Extended DNS Error (EDE) options in the message"""
return cast(List[dns.edns.EDEOption], self.get_options(dns.edns.OptionType.EDE))
def _get_one_rr_per_rrset(self, value): def _get_one_rr_per_rrset(self, value):
# What the caller picked is fine. # What the caller picked is fine.
return value return value
@ -1192,9 +1203,9 @@ class _WireReader:
if rdtype == dns.rdatatype.OPT: if rdtype == dns.rdatatype.OPT:
self.message.opt = dns.rrset.from_rdata(name, ttl, rd) self.message.opt = dns.rrset.from_rdata(name, ttl, rd)
elif rdtype == dns.rdatatype.TSIG: elif rdtype == dns.rdatatype.TSIG:
if self.keyring is None: if self.keyring is None or self.keyring is True:
raise UnknownTSIGKey("got signed message without keyring") raise UnknownTSIGKey("got signed message without keyring")
if isinstance(self.keyring, dict): elif isinstance(self.keyring, dict):
key = self.keyring.get(absolute_name) key = self.keyring.get(absolute_name)
if isinstance(key, bytes): if isinstance(key, bytes):
key = dns.tsig.Key(absolute_name, key, rd.algorithm) key = dns.tsig.Key(absolute_name, key, rd.algorithm)
@ -1203,19 +1214,20 @@ class _WireReader:
else: else:
key = self.keyring key = self.keyring
if key is None: if key is None:
raise UnknownTSIGKey("key '%s' unknown" % name) raise UnknownTSIGKey(f"key '{name}' unknown")
self.message.keyring = key if key:
self.message.tsig_ctx = dns.tsig.validate( self.message.keyring = key
self.parser.wire, self.message.tsig_ctx = dns.tsig.validate(
key, self.parser.wire,
absolute_name, key,
rd, absolute_name,
int(time.time()), rd,
self.message.request_mac, int(time.time()),
rr_start, self.message.request_mac,
self.message.tsig_ctx, rr_start,
self.multi, self.message.tsig_ctx,
) self.multi,
)
self.message.tsig = dns.rrset.from_rdata(absolute_name, 0, rd) self.message.tsig = dns.rrset.from_rdata(absolute_name, 0, rd)
else: else:
rrset = self.message.find_rrset( rrset = self.message.find_rrset(
@ -1251,6 +1263,7 @@ class _WireReader:
factory = _message_factory_from_opcode(dns.opcode.from_flags(flags)) factory = _message_factory_from_opcode(dns.opcode.from_flags(flags))
self.message = factory(id=id) self.message = factory(id=id)
self.message.flags = dns.flags.Flag(flags) self.message.flags = dns.flags.Flag(flags)
self.message.wire = self.parser.wire
self.initialize_message(self.message) self.initialize_message(self.message)
self.one_rr_per_rrset = self.message._get_one_rr_per_rrset( self.one_rr_per_rrset = self.message._get_one_rr_per_rrset(
self.one_rr_per_rrset self.one_rr_per_rrset
@ -1290,8 +1303,10 @@ def from_wire(
) -> Message: ) -> Message:
"""Convert a DNS wire format message into a message object. """Convert a DNS wire format message into a message object.
*keyring*, a ``dns.tsig.Key`` or ``dict``, the key or keyring to use if the message *keyring*, a ``dns.tsig.Key``, ``dict``, ``bool``, or ``None``, the key or keyring
is signed. to use if the message is signed. If ``None`` or ``True``, then trying to decode
a message with a TSIG will fail as it cannot be validated. If ``False``, then
TSIG validation is disabled.
*request_mac*, a ``bytes`` or ``None``. If the message is a response to a *request_mac*, a ``bytes`` or ``None``. If the message is a response to a
TSIG-signed request, *request_mac* should be set to the MAC of that request. TSIG-signed request, *request_mac* should be set to the MAC of that request.
@ -1811,6 +1826,16 @@ def make_query(
return m return m
class CopyMode(enum.Enum):
"""
How should sections be copied when making an update response?
"""
NOTHING = 0
QUESTION = 1
EVERYTHING = 2
def make_response( def make_response(
query: Message, query: Message,
recursion_available: bool = False, recursion_available: bool = False,
@ -1818,13 +1843,14 @@ def make_response(
fudge: int = 300, fudge: int = 300,
tsig_error: int = 0, tsig_error: int = 0,
pad: Optional[int] = None, pad: Optional[int] = None,
copy_mode: Optional[CopyMode] = None,
) -> Message: ) -> Message:
"""Make a message which is a response for the specified query. """Make a message which is a response for the specified query.
The message returned is really a response skeleton; it has all of the infrastructure The message returned is really a response skeleton; it has all of the infrastructure
required of a response, but none of the content. required of a response, but none of the content.
The response's question section is a shallow copy of the query's question section, Response section(s) which are copied are shallow copies of the matching section(s)
so the query's question RRsets should not be changed. in the query, so the query's RRsets should not be changed.
*query*, a ``dns.message.Message``, the query to respond to. *query*, a ``dns.message.Message``, the query to respond to.
@ -1837,25 +1863,44 @@ def make_response(
*tsig_error*, an ``int``, the TSIG error. *tsig_error*, an ``int``, the TSIG error.
*pad*, a non-negative ``int`` or ``None``. If 0, the default, do not pad; otherwise *pad*, a non-negative ``int`` or ``None``. If 0, the default, do not pad; otherwise
if not ``None`` add padding bytes to make the message size a multiple of *pad*. if not ``None`` add padding bytes to make the message size a multiple of *pad*. Note
Note that if padding is non-zero, an EDNS PADDING option will always be added to the that if padding is non-zero, an EDNS PADDING option will always be added to the
message. If ``None``, add padding following RFC 8467, namely if the request is message. If ``None``, add padding following RFC 8467, namely if the request is
padded, pad the response to 468 otherwise do not pad. padded, pad the response to 468 otherwise do not pad.
*copy_mode*, a ``dns.message.CopyMode`` or ``None``, determines how sections are
copied. The default, ``None`` copies sections according to the default for the
message's opcode, which is currently ``dns.message.CopyMode.QUESTION`` for all
opcodes. ``dns.message.CopyMode.QUESTION`` copies only the question section.
``dns.message.CopyMode.EVERYTHING`` copies all sections other than OPT or TSIG
records, which are created appropriately if needed. ``dns.message.CopyMode.NOTHING``
copies no sections; note that this mode is for server testing purposes and is
otherwise not recommended for use. In particular, ``dns.message.is_response()``
will be ``False`` if you create a response this way and the rcode is not
``FORMERR``, ``SERVFAIL``, ``NOTIMP``, or ``REFUSED``.
Returns a ``dns.message.Message`` object whose specific class is appropriate for the Returns a ``dns.message.Message`` object whose specific class is appropriate for the
query. For example, if query is a ``dns.update.UpdateMessage``, response will be query. For example, if query is a ``dns.update.UpdateMessage``, the response will
too. be one too.
""" """
if query.flags & dns.flags.QR: if query.flags & dns.flags.QR:
raise dns.exception.FormError("specified query message is not a query") raise dns.exception.FormError("specified query message is not a query")
factory = _message_factory_from_opcode(query.opcode()) opcode = query.opcode()
factory = _message_factory_from_opcode(opcode)
response = factory(id=query.id) response = factory(id=query.id)
response.flags = dns.flags.QR | (query.flags & dns.flags.RD) response.flags = dns.flags.QR | (query.flags & dns.flags.RD)
if recursion_available: if recursion_available:
response.flags |= dns.flags.RA response.flags |= dns.flags.RA
response.set_opcode(query.opcode()) response.set_opcode(opcode)
response.question = list(query.question) if copy_mode is None:
copy_mode = CopyMode.QUESTION
if copy_mode != CopyMode.NOTHING:
response.question = list(query.question)
if copy_mode == CopyMode.EVERYTHING:
response.answer = list(query.answer)
response.authority = list(query.authority)
response.additional = list(query.additional)
if query.edns >= 0: if query.edns >= 0:
if pad is None: if pad is None:
# Set response padding per RFC 8467 # Set response padding per RFC 8467

View file

@ -59,11 +59,11 @@ class NameRelation(dns.enum.IntEnum):
@classmethod @classmethod
def _maximum(cls): def _maximum(cls):
return cls.COMMONANCESTOR return cls.COMMONANCESTOR # pragma: no cover
@classmethod @classmethod
def _short_name(cls): def _short_name(cls):
return cls.__name__ return cls.__name__ # pragma: no cover
# Backwards compatibility # Backwards compatibility
@ -277,6 +277,7 @@ class IDNA2008Codec(IDNACodec):
raise NoIDNA2008 raise NoIDNA2008
try: try:
if self.uts_46: if self.uts_46:
# pylint: disable=possibly-used-before-assignment
label = idna.uts46_remap(label, False, self.transitional) label = idna.uts46_remap(label, False, self.transitional)
return idna.alabel(label) return idna.alabel(label)
except idna.IDNAError as e: except idna.IDNAError as e:

View file

@ -168,12 +168,14 @@ class DoHNameserver(Nameserver):
bootstrap_address: Optional[str] = None, bootstrap_address: Optional[str] = None,
verify: Union[bool, str] = True, verify: Union[bool, str] = True,
want_get: bool = False, want_get: bool = False,
http_version: dns.query.HTTPVersion = dns.query.HTTPVersion.DEFAULT,
): ):
super().__init__() super().__init__()
self.url = url self.url = url
self.bootstrap_address = bootstrap_address self.bootstrap_address = bootstrap_address
self.verify = verify self.verify = verify
self.want_get = want_get self.want_get = want_get
self.http_version = http_version
def kind(self): def kind(self):
return "DoH" return "DoH"
@ -214,6 +216,7 @@ class DoHNameserver(Nameserver):
ignore_trailing=ignore_trailing, ignore_trailing=ignore_trailing,
verify=self.verify, verify=self.verify,
post=(not self.want_get), post=(not self.want_get),
http_version=self.http_version,
) )
async def async_query( async def async_query(
@ -238,6 +241,7 @@ class DoHNameserver(Nameserver):
ignore_trailing=ignore_trailing, ignore_trailing=ignore_trailing,
verify=self.verify, verify=self.verify,
post=(not self.want_get), post=(not self.want_get),
http_version=self.http_version,
) )

View file

@ -23,11 +23,13 @@ import enum
import errno import errno
import os import os
import os.path import os.path
import random
import selectors import selectors
import socket import socket
import struct import struct
import time import time
from typing import Any, Dict, Optional, Tuple, Union import urllib.parse
from typing import Any, Dict, Optional, Tuple, Union, cast
import dns._features import dns._features
import dns.exception import dns.exception
@ -129,7 +131,7 @@ if _have_httpx:
family=socket.AF_UNSPEC, family=socket.AF_UNSPEC,
**kwargs, **kwargs,
): ):
if resolver is None: if resolver is None and bootstrap_address is None:
# pylint: disable=import-outside-toplevel,redefined-outer-name # pylint: disable=import-outside-toplevel,redefined-outer-name
import dns.resolver import dns.resolver
@ -217,7 +219,7 @@ def _wait_for(fd, readable, writable, _, expiration):
if readable and isinstance(fd, ssl.SSLSocket) and fd.pending() > 0: if readable and isinstance(fd, ssl.SSLSocket) and fd.pending() > 0:
return True return True
sel = _selector_class() sel = selectors.DefaultSelector()
events = 0 events = 0
if readable: if readable:
events |= selectors.EVENT_READ events |= selectors.EVENT_READ
@ -235,26 +237,6 @@ def _wait_for(fd, readable, writable, _, expiration):
raise dns.exception.Timeout raise dns.exception.Timeout
def _set_selector_class(selector_class):
# Internal API. Do not use.
global _selector_class
_selector_class = selector_class
if hasattr(selectors, "PollSelector"):
# Prefer poll() on platforms that support it because it has no
# limits on the maximum value of a file descriptor (plus it will
# be more efficient for high values).
#
# We ignore typing here as we can't say _selector_class is Any
# on python < 3.8 due to a bug.
_selector_class = selectors.PollSelector # type: ignore
else:
_selector_class = selectors.SelectSelector # type: ignore
def _wait_for_readable(s, expiration): def _wait_for_readable(s, expiration):
_wait_for(s, True, False, True, expiration) _wait_for(s, True, False, True, expiration)
@ -355,6 +337,36 @@ def _make_socket(af, type, source, ssl_context=None, server_hostname=None):
raise raise
def _maybe_get_resolver(
resolver: Optional["dns.resolver.Resolver"],
) -> "dns.resolver.Resolver":
# We need a separate method for this to avoid overriding the global
# variable "dns" with the as-yet undefined local variable "dns"
# in https().
if resolver is None:
# pylint: disable=import-outside-toplevel,redefined-outer-name
import dns.resolver
resolver = dns.resolver.Resolver()
return resolver
class HTTPVersion(enum.IntEnum):
"""Which version of HTTP should be used?
DEFAULT will select the first version from the list [2, 1.1, 3] that
is available.
"""
DEFAULT = 0
HTTP_1 = 1
H1 = 1
HTTP_2 = 2
H2 = 2
HTTP_3 = 3
H3 = 3
def https( def https(
q: dns.message.Message, q: dns.message.Message,
where: str, where: str,
@ -370,7 +382,8 @@ def https(
bootstrap_address: Optional[str] = None, bootstrap_address: Optional[str] = None,
verify: Union[bool, str] = True, verify: Union[bool, str] = True,
resolver: Optional["dns.resolver.Resolver"] = None, resolver: Optional["dns.resolver.Resolver"] = None,
family: Optional[int] = socket.AF_UNSPEC, family: int = socket.AF_UNSPEC,
http_version: HTTPVersion = HTTPVersion.DEFAULT,
) -> dns.message.Message: ) -> dns.message.Message:
"""Return the response obtained after sending a query via DNS-over-HTTPS. """Return the response obtained after sending a query via DNS-over-HTTPS.
@ -420,27 +433,66 @@ def https(
*family*, an ``int``, the address family. If socket.AF_UNSPEC (the default), both A *family*, an ``int``, the address family. If socket.AF_UNSPEC (the default), both A
and AAAA records will be retrieved. and AAAA records will be retrieved.
*http_version*, a ``dns.query.HTTPVersion``, indicating which HTTP version to use.
Returns a ``dns.message.Message``. Returns a ``dns.message.Message``.
""" """
(af, _, the_source) = _destination_and_source(
where, port, source, source_port, False
)
if af is not None and dns.inet.is_address(where):
if af == socket.AF_INET:
url = f"https://{where}:{port}{path}"
elif af == socket.AF_INET6:
url = f"https://[{where}]:{port}{path}"
else:
url = where
extensions = {}
if bootstrap_address is None:
# pylint: disable=possibly-used-before-assignment
parsed = urllib.parse.urlparse(url)
if parsed.hostname is None:
raise ValueError("no hostname in URL")
if dns.inet.is_address(parsed.hostname):
bootstrap_address = parsed.hostname
extensions["sni_hostname"] = parsed.hostname
if parsed.port is not None:
port = parsed.port
if http_version == HTTPVersion.H3 or (
http_version == HTTPVersion.DEFAULT and not have_doh
):
if bootstrap_address is None:
resolver = _maybe_get_resolver(resolver)
assert parsed.hostname is not None # for mypy
answers = resolver.resolve_name(parsed.hostname, family)
bootstrap_address = random.choice(list(answers.addresses()))
return _http3(
q,
bootstrap_address,
url,
timeout,
port,
source,
source_port,
one_rr_per_rrset,
ignore_trailing,
verify=verify,
post=post,
)
if not have_doh: if not have_doh:
raise NoDOH # pragma: no cover raise NoDOH # pragma: no cover
if session and not isinstance(session, httpx.Client): if session and not isinstance(session, httpx.Client):
raise ValueError("session parameter must be an httpx.Client") raise ValueError("session parameter must be an httpx.Client")
wire = q.to_wire() wire = q.to_wire()
(af, _, the_source) = _destination_and_source(
where, port, source, source_port, False
)
transport = None
headers = {"accept": "application/dns-message"} headers = {"accept": "application/dns-message"}
if af is not None and dns.inet.is_address(where):
if af == socket.AF_INET: h1 = http_version in (HTTPVersion.H1, HTTPVersion.DEFAULT)
url = "https://{}:{}{}".format(where, port, path) h2 = http_version in (HTTPVersion.H2, HTTPVersion.DEFAULT)
elif af == socket.AF_INET6:
url = "https://[{}]:{}{}".format(where, port, path)
else:
url = where
# set source port and source address # set source port and source address
@ -450,21 +502,22 @@ def https(
else: else:
local_address = the_source[0] local_address = the_source[0]
local_port = the_source[1] local_port = the_source[1]
transport = _HTTPTransport(
local_address=local_address,
http1=True,
http2=True,
verify=verify,
local_port=local_port,
bootstrap_address=bootstrap_address,
resolver=resolver,
family=family,
)
if session: if session:
cm: contextlib.AbstractContextManager = contextlib.nullcontext(session) cm: contextlib.AbstractContextManager = contextlib.nullcontext(session)
else: else:
cm = httpx.Client(http1=True, http2=True, verify=verify, transport=transport) transport = _HTTPTransport(
local_address=local_address,
http1=h1,
http2=h2,
verify=verify,
local_port=local_port,
bootstrap_address=bootstrap_address,
resolver=resolver,
family=family,
)
cm = httpx.Client(http1=h1, http2=h2, verify=verify, transport=transport)
with cm as session: with cm as session:
# see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH # see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH
# GET and POST examples # GET and POST examples
@ -475,20 +528,30 @@ def https(
"content-length": str(len(wire)), "content-length": str(len(wire)),
} }
) )
response = session.post(url, headers=headers, content=wire, timeout=timeout) response = session.post(
url,
headers=headers,
content=wire,
timeout=timeout,
extensions=extensions,
)
else: else:
wire = base64.urlsafe_b64encode(wire).rstrip(b"=") wire = base64.urlsafe_b64encode(wire).rstrip(b"=")
twire = wire.decode() # httpx does a repr() if we give it bytes twire = wire.decode() # httpx does a repr() if we give it bytes
response = session.get( response = session.get(
url, headers=headers, timeout=timeout, params={"dns": twire} url,
headers=headers,
timeout=timeout,
params={"dns": twire},
extensions=extensions,
) )
# see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH # see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH
# status codes # status codes
if response.status_code < 200 or response.status_code > 299: if response.status_code < 200 or response.status_code > 299:
raise ValueError( raise ValueError(
"{} responded with status code {}" f"{where} responded with status code {response.status_code}"
"\nResponse body: {}".format(where, response.status_code, response.content) f"\nResponse body: {response.content}"
) )
r = dns.message.from_wire( r = dns.message.from_wire(
response.content, response.content,
@ -503,6 +566,81 @@ def https(
return r return r
def _find_header(headers: dns.quic.Headers, name: bytes) -> bytes:
if headers is None:
raise KeyError
for header, value in headers:
if header == name:
return value
raise KeyError
def _check_status(headers: dns.quic.Headers, peer: str, wire: bytes) -> None:
value = _find_header(headers, b":status")
if value is None:
raise SyntaxError("no :status header in response")
status = int(value)
if status < 0:
raise SyntaxError("status is negative")
if status < 200 or status > 299:
error = ""
if len(wire) > 0:
try:
error = ": " + wire.decode()
except Exception:
pass
raise ValueError(f"{peer} responded with status code {status}{error}")
def _http3(
q: dns.message.Message,
where: str,
url: str,
timeout: Optional[float] = None,
port: int = 853,
source: Optional[str] = None,
source_port: int = 0,
one_rr_per_rrset: bool = False,
ignore_trailing: bool = False,
verify: Union[bool, str] = True,
hostname: Optional[str] = None,
post: bool = True,
) -> dns.message.Message:
if not dns.quic.have_quic:
raise NoDOH("DNS-over-HTTP3 is not available.") # pragma: no cover
url_parts = urllib.parse.urlparse(url)
hostname = url_parts.hostname
if url_parts.port is not None:
port = url_parts.port
q.id = 0
wire = q.to_wire()
manager = dns.quic.SyncQuicManager(
verify_mode=verify, server_name=hostname, h3=True
)
with manager:
connection = manager.connect(where, port, source, source_port)
(start, expiration) = _compute_times(timeout)
with connection.make_stream(timeout) as stream:
stream.send_h3(url, wire, post)
wire = stream.receive(_remaining(expiration))
_check_status(stream.headers(), where, wire)
finish = time.time()
r = dns.message.from_wire(
wire,
keyring=q.keyring,
request_mac=q.request_mac,
one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing,
)
r.time = max(finish - start, 0.0)
if not q.is_response(r):
raise BadResponse
return r
def _udp_recv(sock, max_size, expiration): def _udp_recv(sock, max_size, expiration):
"""Reads a datagram from the socket. """Reads a datagram from the socket.
A Timeout exception will be raised if the operation is not completed A Timeout exception will be raised if the operation is not completed
@ -855,7 +993,7 @@ def _net_read(sock, count, expiration):
try: try:
n = sock.recv(count) n = sock.recv(count)
if n == b"": if n == b"":
raise EOFError raise EOFError("EOF")
count -= len(n) count -= len(n)
s += n s += n
except (BlockingIOError, ssl.SSLWantReadError): except (BlockingIOError, ssl.SSLWantReadError):
@ -1023,6 +1161,7 @@ def tcp(
cm = _make_socket(af, socket.SOCK_STREAM, source) cm = _make_socket(af, socket.SOCK_STREAM, source)
with cm as s: with cm as s:
if not sock: if not sock:
# pylint: disable=possibly-used-before-assignment
_connect(s, destination, expiration) _connect(s, destination, expiration)
send_tcp(s, wire, expiration) send_tcp(s, wire, expiration)
(r, received_time) = receive_tcp( (r, received_time) = receive_tcp(
@ -1188,6 +1327,7 @@ def quic(
ignore_trailing: bool = False, ignore_trailing: bool = False,
connection: Optional[dns.quic.SyncQuicConnection] = None, connection: Optional[dns.quic.SyncQuicConnection] = None,
verify: Union[bool, str] = True, verify: Union[bool, str] = True,
hostname: Optional[str] = None,
server_hostname: Optional[str] = None, server_hostname: Optional[str] = None,
) -> dns.message.Message: ) -> dns.message.Message:
"""Return the response obtained after sending a query via DNS-over-QUIC. """Return the response obtained after sending a query via DNS-over-QUIC.
@ -1212,17 +1352,21 @@ def quic(
*ignore_trailing*, a ``bool``. If ``True``, ignore trailing junk at end of the *ignore_trailing*, a ``bool``. If ``True``, ignore trailing junk at end of the
received message. received message.
*connection*, a ``dns.quic.SyncQuicConnection``. If provided, the *connection*, a ``dns.quic.SyncQuicConnection``. If provided, the connection to use
connection to use to send the query. to send the query.
*verify*, a ``bool`` or ``str``. If a ``True``, then TLS certificate verification *verify*, a ``bool`` or ``str``. If a ``True``, then TLS certificate verification
of the server is done using the default CA bundle; if ``False``, then no of the server is done using the default CA bundle; if ``False``, then no
verification is done; if a `str` then it specifies the path to a certificate file or verification is done; if a `str` then it specifies the path to a certificate file or
directory which will be used for verification. directory which will be used for verification.
*server_hostname*, a ``str`` containing the server's hostname. The *hostname*, a ``str`` containing the server's hostname or ``None``. The default is
default is ``None``, which means that no hostname is known, and if an ``None``, which means that no hostname is known, and if an SSL context is created,
SSL context is created, hostname checking will be disabled. hostname checking will be disabled. This value is ignored if *url* is not
``None``.
*server_hostname*, a ``str`` or ``None``. This item is for backwards compatibility
only, and has the same meaning as *hostname*.
Returns a ``dns.message.Message``. Returns a ``dns.message.Message``.
""" """
@ -1230,6 +1374,9 @@ def quic(
if not dns.quic.have_quic: if not dns.quic.have_quic:
raise NoDOQ("DNS-over-QUIC is not available.") # pragma: no cover raise NoDOQ("DNS-over-QUIC is not available.") # pragma: no cover
if server_hostname is not None and hostname is None:
hostname = server_hostname
q.id = 0 q.id = 0
wire = q.to_wire() wire = q.to_wire()
the_connection: dns.quic.SyncQuicConnection the_connection: dns.quic.SyncQuicConnection
@ -1238,9 +1385,7 @@ def quic(
manager: contextlib.AbstractContextManager = contextlib.nullcontext(None) manager: contextlib.AbstractContextManager = contextlib.nullcontext(None)
the_connection = connection the_connection = connection
else: else:
manager = dns.quic.SyncQuicManager( manager = dns.quic.SyncQuicManager(verify_mode=verify, server_name=hostname)
verify_mode=verify, server_name=server_hostname
)
the_manager = manager # for type checking happiness the_manager = manager # for type checking happiness
with manager: with manager:
@ -1264,6 +1409,70 @@ def quic(
return r return r
class UDPMode(enum.IntEnum):
"""How should UDP be used in an IXFR from :py:func:`inbound_xfr()`?
NEVER means "never use UDP; always use TCP"
TRY_FIRST means "try to use UDP but fall back to TCP if needed"
ONLY means "raise ``dns.xfr.UseTCP`` if trying UDP does not succeed"
"""
NEVER = 0
TRY_FIRST = 1
ONLY = 2
def _inbound_xfr(
txn_manager: dns.transaction.TransactionManager,
s: socket.socket,
query: dns.message.Message,
serial: Optional[int],
timeout: Optional[float],
expiration: float,
) -> Any:
"""Given a socket, does the zone transfer."""
rdtype = query.question[0].rdtype
is_ixfr = rdtype == dns.rdatatype.IXFR
origin = txn_manager.from_wire_origin()
wire = query.to_wire()
is_udp = s.type == socket.SOCK_DGRAM
if is_udp:
_udp_send(s, wire, None, expiration)
else:
tcpmsg = struct.pack("!H", len(wire)) + wire
_net_write(s, tcpmsg, expiration)
with dns.xfr.Inbound(txn_manager, rdtype, serial, is_udp) as inbound:
done = False
tsig_ctx = None
while not done:
(_, mexpiration) = _compute_times(timeout)
if mexpiration is None or (
expiration is not None and mexpiration > expiration
):
mexpiration = expiration
if is_udp:
(rwire, _) = _udp_recv(s, 65535, mexpiration)
else:
ldata = _net_read(s, 2, mexpiration)
(l,) = struct.unpack("!H", ldata)
rwire = _net_read(s, l, mexpiration)
r = dns.message.from_wire(
rwire,
keyring=query.keyring,
request_mac=query.mac,
xfr=True,
origin=origin,
tsig_ctx=tsig_ctx,
multi=(not is_udp),
one_rr_per_rrset=is_ixfr,
)
done = inbound.process_message(r)
yield r
tsig_ctx = r.tsig_ctx
if query.keyring and not r.had_tsig:
raise dns.exception.FormError("missing TSIG")
def xfr( def xfr(
where: str, where: str,
zone: Union[dns.name.Name, str], zone: Union[dns.name.Name, str],
@ -1333,134 +1542,52 @@ def xfr(
Returns a generator of ``dns.message.Message`` objects. Returns a generator of ``dns.message.Message`` objects.
""" """
class DummyTransactionManager(dns.transaction.TransactionManager):
def __init__(self, origin, relativize):
self.info = (origin, relativize, dns.name.empty if relativize else origin)
def origin_information(self):
return self.info
def get_class(self) -> dns.rdataclass.RdataClass:
raise NotImplementedError # pragma: no cover
def reader(self):
raise NotImplementedError # pragma: no cover
def writer(self, replacement: bool = False) -> dns.transaction.Transaction:
class DummyTransaction:
def nop(self, *args, **kw):
pass
def __getattr__(self, _):
return self.nop
return cast(dns.transaction.Transaction, DummyTransaction())
if isinstance(zone, str): if isinstance(zone, str):
zone = dns.name.from_text(zone) zone = dns.name.from_text(zone)
rdtype = dns.rdatatype.RdataType.make(rdtype) rdtype = dns.rdatatype.RdataType.make(rdtype)
q = dns.message.make_query(zone, rdtype, rdclass) q = dns.message.make_query(zone, rdtype, rdclass)
if rdtype == dns.rdatatype.IXFR: if rdtype == dns.rdatatype.IXFR:
rrset = dns.rrset.from_text(zone, 0, "IN", "SOA", ". . %u 0 0 0 0" % serial) rrset = q.find_rrset(
q.authority.append(rrset) q.authority, zone, dns.rdataclass.IN, dns.rdatatype.SOA, create=True
)
soa = dns.rdata.from_text("IN", "SOA", ". . %u 0 0 0 0" % serial)
rrset.add(soa, 0)
if keyring is not None: if keyring is not None:
q.use_tsig(keyring, keyname, algorithm=keyalgorithm) q.use_tsig(keyring, keyname, algorithm=keyalgorithm)
wire = q.to_wire()
(af, destination, source) = _destination_and_source( (af, destination, source) = _destination_and_source(
where, port, source, source_port where, port, source, source_port
) )
(_, expiration) = _compute_times(lifetime)
tm = DummyTransactionManager(zone, relativize)
if use_udp and rdtype != dns.rdatatype.IXFR: if use_udp and rdtype != dns.rdatatype.IXFR:
raise ValueError("cannot do a UDP AXFR") raise ValueError("cannot do a UDP AXFR")
sock_type = socket.SOCK_DGRAM if use_udp else socket.SOCK_STREAM sock_type = socket.SOCK_DGRAM if use_udp else socket.SOCK_STREAM
with _make_socket(af, sock_type, source) as s: with _make_socket(af, sock_type, source) as s:
(_, expiration) = _compute_times(lifetime)
_connect(s, destination, expiration) _connect(s, destination, expiration)
l = len(wire) yield from _inbound_xfr(tm, s, q, serial, timeout, expiration)
if use_udp:
_udp_send(s, wire, None, expiration)
else:
tcpmsg = struct.pack("!H", l) + wire
_net_write(s, tcpmsg, expiration)
done = False
delete_mode = True
expecting_SOA = False
soa_rrset = None
if relativize:
origin = zone
oname = dns.name.empty
else:
origin = None
oname = zone
tsig_ctx = None
while not done:
(_, mexpiration) = _compute_times(timeout)
if mexpiration is None or (
expiration is not None and mexpiration > expiration
):
mexpiration = expiration
if use_udp:
(wire, _) = _udp_recv(s, 65535, mexpiration)
else:
ldata = _net_read(s, 2, mexpiration)
(l,) = struct.unpack("!H", ldata)
wire = _net_read(s, l, mexpiration)
is_ixfr = rdtype == dns.rdatatype.IXFR
r = dns.message.from_wire(
wire,
keyring=q.keyring,
request_mac=q.mac,
xfr=True,
origin=origin,
tsig_ctx=tsig_ctx,
multi=True,
one_rr_per_rrset=is_ixfr,
)
rcode = r.rcode()
if rcode != dns.rcode.NOERROR:
raise TransferError(rcode)
tsig_ctx = r.tsig_ctx
answer_index = 0
if soa_rrset is None:
if not r.answer or r.answer[0].name != oname:
raise dns.exception.FormError("No answer or RRset not for qname")
rrset = r.answer[0]
if rrset.rdtype != dns.rdatatype.SOA:
raise dns.exception.FormError("first RRset is not an SOA")
answer_index = 1
soa_rrset = rrset.copy()
if rdtype == dns.rdatatype.IXFR:
if dns.serial.Serial(soa_rrset[0].serial) <= serial:
#
# We're already up-to-date.
#
done = True
else:
expecting_SOA = True
#
# Process SOAs in the answer section (other than the initial
# SOA in the first message).
#
for rrset in r.answer[answer_index:]:
if done:
raise dns.exception.FormError("answers after final SOA")
if rrset.rdtype == dns.rdatatype.SOA and rrset.name == oname:
if expecting_SOA:
if rrset[0].serial != serial:
raise dns.exception.FormError("IXFR base serial mismatch")
expecting_SOA = False
elif rdtype == dns.rdatatype.IXFR:
delete_mode = not delete_mode
#
# If this SOA RRset is equal to the first we saw then we're
# finished. If this is an IXFR we also check that we're
# seeing the record in the expected part of the response.
#
if rrset == soa_rrset and (
rdtype == dns.rdatatype.AXFR
or (rdtype == dns.rdatatype.IXFR and delete_mode)
):
done = True
elif expecting_SOA:
#
# We made an IXFR request and are expecting another
# SOA RR, but saw something else, so this must be an
# AXFR response.
#
rdtype = dns.rdatatype.AXFR
expecting_SOA = False
if done and q.keyring and not r.had_tsig:
raise dns.exception.FormError("missing TSIG")
yield r
class UDPMode(enum.IntEnum):
"""How should UDP be used in an IXFR from :py:func:`inbound_xfr()`?
NEVER means "never use UDP; always use TCP"
TRY_FIRST means "try to use UDP but fall back to TCP if needed"
ONLY means "raise ``dns.xfr.UseTCP`` if trying UDP does not succeed"
"""
NEVER = 0
TRY_FIRST = 1
ONLY = 2
def inbound_xfr( def inbound_xfr(
@ -1514,65 +1641,25 @@ def inbound_xfr(
(query, serial) = dns.xfr.make_query(txn_manager) (query, serial) = dns.xfr.make_query(txn_manager)
else: else:
serial = dns.xfr.extract_serial_from_query(query) serial = dns.xfr.extract_serial_from_query(query)
rdtype = query.question[0].rdtype
is_ixfr = rdtype == dns.rdatatype.IXFR
origin = txn_manager.from_wire_origin()
wire = query.to_wire()
(af, destination, source) = _destination_and_source( (af, destination, source) = _destination_and_source(
where, port, source, source_port where, port, source, source_port
) )
(_, expiration) = _compute_times(lifetime) (_, expiration) = _compute_times(lifetime)
retry = True if query.question[0].rdtype == dns.rdatatype.IXFR and udp_mode != UDPMode.NEVER:
while retry: with _make_socket(af, socket.SOCK_DGRAM, source) as s:
retry = False
if is_ixfr and udp_mode != UDPMode.NEVER:
sock_type = socket.SOCK_DGRAM
is_udp = True
else:
sock_type = socket.SOCK_STREAM
is_udp = False
with _make_socket(af, sock_type, source) as s:
_connect(s, destination, expiration) _connect(s, destination, expiration)
if is_udp: try:
_udp_send(s, wire, None, expiration) for _ in _inbound_xfr(
else: txn_manager, s, query, serial, timeout, expiration
tcpmsg = struct.pack("!H", len(wire)) + wire ):
_net_write(s, tcpmsg, expiration) pass
with dns.xfr.Inbound(txn_manager, rdtype, serial, is_udp) as inbound: return
done = False except dns.xfr.UseTCP:
tsig_ctx = None if udp_mode == UDPMode.ONLY:
while not done: raise
(_, mexpiration) = _compute_times(timeout)
if mexpiration is None or ( with _make_socket(af, socket.SOCK_STREAM, source) as s:
expiration is not None and mexpiration > expiration _connect(s, destination, expiration)
): for _ in _inbound_xfr(txn_manager, s, query, serial, timeout, expiration):
mexpiration = expiration pass
if is_udp:
(rwire, _) = _udp_recv(s, 65535, mexpiration)
else:
ldata = _net_read(s, 2, mexpiration)
(l,) = struct.unpack("!H", ldata)
rwire = _net_read(s, l, mexpiration)
r = dns.message.from_wire(
rwire,
keyring=query.keyring,
request_mac=query.mac,
xfr=True,
origin=origin,
tsig_ctx=tsig_ctx,
multi=(not is_udp),
one_rr_per_rrset=is_ixfr,
)
try:
done = inbound.process_message(r)
except dns.xfr.UseTCP:
assert is_udp # should not happen if we used TCP!
if udp_mode == UDPMode.ONLY:
raise
done = True
retry = True
udp_mode = UDPMode.NEVER
continue
tsig_ctx = r.tsig_ctx
if not retry and query.keyring and not r.had_tsig:
raise dns.exception.FormError("missing TSIG")

View file

@ -1,5 +1,7 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
from typing import List, Tuple
import dns._features import dns._features
import dns.asyncbackend import dns.asyncbackend
@ -73,3 +75,6 @@ else: # pragma: no cover
class SyncQuicConnection: # type: ignore class SyncQuicConnection: # type: ignore
def make_stream(self) -> Any: def make_stream(self) -> Any:
raise NotImplementedError raise NotImplementedError
Headers = List[Tuple[bytes, bytes]]

View file

@ -43,12 +43,26 @@ class AsyncioQuicStream(BaseQuicStream):
raise dns.exception.Timeout raise dns.exception.Timeout
self._expecting = 0 self._expecting = 0
async def wait_for_end(self, expiration):
while True:
timeout = self._timeout_from_expiration(expiration)
if self._buffer.seen_end():
return
try:
await asyncio.wait_for(self._wait_for_wake_up(), timeout)
except TimeoutError:
raise dns.exception.Timeout
async def receive(self, timeout=None): async def receive(self, timeout=None):
expiration = self._expiration_from_timeout(timeout) expiration = self._expiration_from_timeout(timeout)
await self.wait_for(2, expiration) if self._connection.is_h3():
(size,) = struct.unpack("!H", self._buffer.get(2)) await self.wait_for_end(expiration)
await self.wait_for(size, expiration) return self._buffer.get_all()
return self._buffer.get(size) else:
await self.wait_for(2, expiration)
(size,) = struct.unpack("!H", self._buffer.get(2))
await self.wait_for(size, expiration)
return self._buffer.get(size)
async def send(self, datagram, is_end=False): async def send(self, datagram, is_end=False):
data = self._encapsulate(datagram) data = self._encapsulate(datagram)
@ -83,6 +97,7 @@ class AsyncioQuicConnection(AsyncQuicConnection):
self._wake_timer = asyncio.Condition() self._wake_timer = asyncio.Condition()
self._receiver_task = None self._receiver_task = None
self._sender_task = None self._sender_task = None
self._wake_pending = False
async def _receiver(self): async def _receiver(self):
try: try:
@ -104,19 +119,24 @@ class AsyncioQuicConnection(AsyncQuicConnection):
self._connection.receive_datagram(datagram, address, time.time()) self._connection.receive_datagram(datagram, address, time.time())
# Wake up the timer in case the sender is sleeping, as there may be # Wake up the timer in case the sender is sleeping, as there may be
# stuff to send now. # stuff to send now.
async with self._wake_timer: await self._wakeup()
self._wake_timer.notify_all()
except Exception: except Exception:
pass pass
finally: finally:
self._done = True self._done = True
async with self._wake_timer: await self._wakeup()
self._wake_timer.notify_all()
self._handshake_complete.set() self._handshake_complete.set()
async def _wakeup(self):
self._wake_pending = True
async with self._wake_timer:
self._wake_timer.notify_all()
async def _wait_for_wake_timer(self): async def _wait_for_wake_timer(self):
async with self._wake_timer: async with self._wake_timer:
await self._wake_timer.wait() if not self._wake_pending:
await self._wake_timer.wait()
self._wake_pending = False
async def _sender(self): async def _sender(self):
await self._socket_created.wait() await self._socket_created.wait()
@ -140,9 +160,28 @@ class AsyncioQuicConnection(AsyncQuicConnection):
if event is None: if event is None:
return return
if isinstance(event, aioquic.quic.events.StreamDataReceived): if isinstance(event, aioquic.quic.events.StreamDataReceived):
stream = self._streams.get(event.stream_id) if self.is_h3():
if stream: h3_events = self._h3_conn.handle_event(event)
await stream._add_input(event.data, event.end_stream) for h3_event in h3_events:
if isinstance(h3_event, aioquic.h3.events.HeadersReceived):
stream = self._streams.get(event.stream_id)
if stream:
if stream._headers is None:
stream._headers = h3_event.headers
elif stream._trailers is None:
stream._trailers = h3_event.headers
if h3_event.stream_ended:
await stream._add_input(b"", True)
elif isinstance(h3_event, aioquic.h3.events.DataReceived):
stream = self._streams.get(event.stream_id)
if stream:
await stream._add_input(
h3_event.data, h3_event.stream_ended
)
else:
stream = self._streams.get(event.stream_id)
if stream:
await stream._add_input(event.data, event.end_stream)
elif isinstance(event, aioquic.quic.events.HandshakeCompleted): elif isinstance(event, aioquic.quic.events.HandshakeCompleted):
self._handshake_complete.set() self._handshake_complete.set()
elif isinstance(event, aioquic.quic.events.ConnectionTerminated): elif isinstance(event, aioquic.quic.events.ConnectionTerminated):
@ -161,8 +200,7 @@ class AsyncioQuicConnection(AsyncQuicConnection):
async def write(self, stream, data, is_end=False): async def write(self, stream, data, is_end=False):
self._connection.send_stream_data(stream, data, is_end) self._connection.send_stream_data(stream, data, is_end)
async with self._wake_timer: await self._wakeup()
self._wake_timer.notify_all()
def run(self): def run(self):
if self._closed: if self._closed:
@ -189,8 +227,7 @@ class AsyncioQuicConnection(AsyncQuicConnection):
self._connection.close() self._connection.close()
# sender might be blocked on this, so set it # sender might be blocked on this, so set it
self._socket_created.set() self._socket_created.set()
async with self._wake_timer: await self._wakeup()
self._wake_timer.notify_all()
try: try:
await self._receiver_task await self._receiver_task
except asyncio.CancelledError: except asyncio.CancelledError:
@ -203,8 +240,10 @@ class AsyncioQuicConnection(AsyncQuicConnection):
class AsyncioQuicManager(AsyncQuicManager): class AsyncioQuicManager(AsyncQuicManager):
def __init__(self, conf=None, verify_mode=ssl.CERT_REQUIRED, server_name=None): def __init__(
super().__init__(conf, verify_mode, AsyncioQuicConnection, server_name) self, conf=None, verify_mode=ssl.CERT_REQUIRED, server_name=None, h3=False
):
super().__init__(conf, verify_mode, AsyncioQuicConnection, server_name, h3)
def connect( def connect(
self, address, port=853, source=None, source_port=0, want_session_ticket=True self, address, port=853, source=None, source_port=0, want_session_ticket=True

View file

@ -1,12 +1,16 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
import base64
import copy import copy
import functools import functools
import socket import socket
import struct import struct
import time import time
import urllib
from typing import Any, Optional from typing import Any, Optional
import aioquic.h3.connection # type: ignore
import aioquic.h3.events # type: ignore
import aioquic.quic.configuration # type: ignore import aioquic.quic.configuration # type: ignore
import aioquic.quic.connection # type: ignore import aioquic.quic.connection # type: ignore
@ -51,6 +55,12 @@ class Buffer:
self._buffer = self._buffer[amount:] self._buffer = self._buffer[amount:]
return data return data
def get_all(self):
assert self.seen_end()
data = self._buffer
self._buffer = b""
return data
class BaseQuicStream: class BaseQuicStream:
def __init__(self, connection, stream_id): def __init__(self, connection, stream_id):
@ -58,10 +68,18 @@ class BaseQuicStream:
self._stream_id = stream_id self._stream_id = stream_id
self._buffer = Buffer() self._buffer = Buffer()
self._expecting = 0 self._expecting = 0
self._headers = None
self._trailers = None
def id(self): def id(self):
return self._stream_id return self._stream_id
def headers(self):
return self._headers
def trailers(self):
return self._trailers
def _expiration_from_timeout(self, timeout): def _expiration_from_timeout(self, timeout):
if timeout is not None: if timeout is not None:
expiration = time.time() + timeout expiration = time.time() + timeout
@ -77,16 +95,51 @@ class BaseQuicStream:
return timeout return timeout
# Subclass must implement receive() as sync / async and which returns a message # Subclass must implement receive() as sync / async and which returns a message
# or raises UnexpectedEOF. # or raises.
# Subclass must implement send() as sync / async and which takes a message and
# an EOF indicator.
def send_h3(self, url, datagram, post=True):
if not self._connection.is_h3():
raise SyntaxError("cannot send H3 to a non-H3 connection")
url_parts = urllib.parse.urlparse(url)
path = url_parts.path.encode()
if post:
method = b"POST"
else:
method = b"GET"
path += b"?dns=" + base64.urlsafe_b64encode(datagram).rstrip(b"=")
headers = [
(b":method", method),
(b":scheme", url_parts.scheme.encode()),
(b":authority", url_parts.netloc.encode()),
(b":path", path),
(b"accept", b"application/dns-message"),
]
if post:
headers.extend(
[
(b"content-type", b"application/dns-message"),
(b"content-length", str(len(datagram)).encode()),
]
)
self._connection.send_headers(self._stream_id, headers, not post)
if post:
self._connection.send_data(self._stream_id, datagram, True)
def _encapsulate(self, datagram): def _encapsulate(self, datagram):
if self._connection.is_h3():
return datagram
l = len(datagram) l = len(datagram)
return struct.pack("!H", l) + datagram return struct.pack("!H", l) + datagram
def _common_add_input(self, data, is_end): def _common_add_input(self, data, is_end):
self._buffer.put(data, is_end) self._buffer.put(data, is_end)
try: try:
return self._expecting > 0 and self._buffer.have(self._expecting) return (
self._expecting > 0 and self._buffer.have(self._expecting)
) or self._buffer.seen_end
except UnexpectedEOF: except UnexpectedEOF:
return True return True
@ -97,7 +150,13 @@ class BaseQuicStream:
class BaseQuicConnection: class BaseQuicConnection:
def __init__( def __init__(
self, connection, address, port, source=None, source_port=0, manager=None self,
connection,
address,
port,
source=None,
source_port=0,
manager=None,
): ):
self._done = False self._done = False
self._connection = connection self._connection = connection
@ -106,6 +165,10 @@ class BaseQuicConnection:
self._closed = False self._closed = False
self._manager = manager self._manager = manager
self._streams = {} self._streams = {}
if manager.is_h3():
self._h3_conn = aioquic.h3.connection.H3Connection(connection, False)
else:
self._h3_conn = None
self._af = dns.inet.af_for_address(address) self._af = dns.inet.af_for_address(address)
self._peer = dns.inet.low_level_address_tuple((address, port)) self._peer = dns.inet.low_level_address_tuple((address, port))
if source is None and source_port != 0: if source is None and source_port != 0:
@ -120,9 +183,18 @@ class BaseQuicConnection:
else: else:
self._source = None self._source = None
def is_h3(self):
return self._h3_conn is not None
def close_stream(self, stream_id): def close_stream(self, stream_id):
del self._streams[stream_id] del self._streams[stream_id]
def send_headers(self, stream_id, headers, is_end=False):
self._h3_conn.send_headers(stream_id, headers, is_end)
def send_data(self, stream_id, data, is_end=False):
self._h3_conn.send_data(stream_id, data, is_end)
def _get_timer_values(self, closed_is_special=True): def _get_timer_values(self, closed_is_special=True):
now = time.time() now = time.time()
expiration = self._connection.get_timer() expiration = self._connection.get_timer()
@ -148,17 +220,25 @@ class AsyncQuicConnection(BaseQuicConnection):
class BaseQuicManager: class BaseQuicManager:
def __init__(self, conf, verify_mode, connection_factory, server_name=None): def __init__(
self, conf, verify_mode, connection_factory, server_name=None, h3=False
):
self._connections = {} self._connections = {}
self._connection_factory = connection_factory self._connection_factory = connection_factory
self._session_tickets = {} self._session_tickets = {}
self._tokens = {}
self._h3 = h3
if conf is None: if conf is None:
verify_path = None verify_path = None
if isinstance(verify_mode, str): if isinstance(verify_mode, str):
verify_path = verify_mode verify_path = verify_mode
verify_mode = True verify_mode = True
if h3:
alpn_protocols = ["h3"]
else:
alpn_protocols = ["doq", "doq-i03"]
conf = aioquic.quic.configuration.QuicConfiguration( conf = aioquic.quic.configuration.QuicConfiguration(
alpn_protocols=["doq", "doq-i03"], alpn_protocols=alpn_protocols,
verify_mode=verify_mode, verify_mode=verify_mode,
server_name=server_name, server_name=server_name,
) )
@ -167,7 +247,13 @@ class BaseQuicManager:
self._conf = conf self._conf = conf
def _connect( def _connect(
self, address, port=853, source=None, source_port=0, want_session_ticket=True self,
address,
port=853,
source=None,
source_port=0,
want_session_ticket=True,
want_token=True,
): ):
connection = self._connections.get((address, port)) connection = self._connections.get((address, port))
if connection is not None: if connection is not None:
@ -189,9 +275,24 @@ class BaseQuicManager:
) )
else: else:
session_ticket_handler = None session_ticket_handler = None
if want_token:
try:
token = self._tokens.pop((address, port))
# We found a token, so make a configuration that uses it.
conf = copy.copy(conf)
conf.token = token
except KeyError:
# No token
pass
# Whether or not we found a token, we want a handler to save # one.
token_handler = functools.partial(self.save_token, address, port)
else:
token_handler = None
qconn = aioquic.quic.connection.QuicConnection( qconn = aioquic.quic.connection.QuicConnection(
configuration=conf, configuration=conf,
session_ticket_handler=session_ticket_handler, session_ticket_handler=session_ticket_handler,
token_handler=token_handler,
) )
lladdress = dns.inet.low_level_address_tuple((address, port)) lladdress = dns.inet.low_level_address_tuple((address, port))
qconn.connect(lladdress, time.time()) qconn.connect(lladdress, time.time())
@ -207,6 +308,9 @@ class BaseQuicManager:
except KeyError: except KeyError:
pass pass
def is_h3(self):
return self._h3
def save_session_ticket(self, address, port, ticket): def save_session_ticket(self, address, port, ticket):
# We rely on dictionaries keys() being in insertion order here. We # We rely on dictionaries keys() being in insertion order here. We
# can't just popitem() as that would be LIFO which is the opposite of # can't just popitem() as that would be LIFO which is the opposite of
@ -218,6 +322,17 @@ class BaseQuicManager:
del self._session_tickets[key] del self._session_tickets[key]
self._session_tickets[(address, port)] = ticket self._session_tickets[(address, port)] = ticket
def save_token(self, address, port, token):
# We rely on dictionaries keys() being in insertion order here. We
# can't just popitem() as that would be LIFO which is the opposite of
# what we want.
l = len(self._tokens)
if l >= MAX_SESSION_TICKETS:
keys_to_delete = list(self._tokens.keys())[0:SESSIONS_TO_DELETE]
for key in keys_to_delete:
del self._tokens[key]
self._tokens[(address, port)] = token
class AsyncQuicManager(BaseQuicManager): class AsyncQuicManager(BaseQuicManager):
def connect(self, address, port=853, source=None, source_port=0): def connect(self, address, port=853, source=None, source_port=0):

View file

@ -21,11 +21,9 @@ from dns.quic._common import (
UnexpectedEOF, UnexpectedEOF,
) )
# Avoid circularity with dns.query # Function used to create a socket. Can be overridden if needed in special
if hasattr(selectors, "PollSelector"): # situations.
_selector_class = selectors.PollSelector # type: ignore socket_factory = socket.socket
else:
_selector_class = selectors.SelectSelector # type: ignore
class SyncQuicStream(BaseQuicStream): class SyncQuicStream(BaseQuicStream):
@ -46,14 +44,29 @@ class SyncQuicStream(BaseQuicStream):
raise dns.exception.Timeout raise dns.exception.Timeout
self._expecting = 0 self._expecting = 0
def wait_for_end(self, expiration):
while True:
timeout = self._timeout_from_expiration(expiration)
with self._lock:
if self._buffer.seen_end():
return
with self._wake_up:
if not self._wake_up.wait(timeout):
raise dns.exception.Timeout
def receive(self, timeout=None): def receive(self, timeout=None):
expiration = self._expiration_from_timeout(timeout) expiration = self._expiration_from_timeout(timeout)
self.wait_for(2, expiration) if self._connection.is_h3():
with self._lock: self.wait_for_end(expiration)
(size,) = struct.unpack("!H", self._buffer.get(2)) with self._lock:
self.wait_for(size, expiration) return self._buffer.get_all()
with self._lock: else:
return self._buffer.get(size) self.wait_for(2, expiration)
with self._lock:
(size,) = struct.unpack("!H", self._buffer.get(2))
self.wait_for(size, expiration)
with self._lock:
return self._buffer.get(size)
def send(self, datagram, is_end=False): def send(self, datagram, is_end=False):
data = self._encapsulate(datagram) data = self._encapsulate(datagram)
@ -81,7 +94,7 @@ class SyncQuicStream(BaseQuicStream):
class SyncQuicConnection(BaseQuicConnection): class SyncQuicConnection(BaseQuicConnection):
def __init__(self, connection, address, port, source, source_port, manager): def __init__(self, connection, address, port, source, source_port, manager):
super().__init__(connection, address, port, source, source_port, manager) super().__init__(connection, address, port, source, source_port, manager)
self._socket = socket.socket(self._af, socket.SOCK_DGRAM, 0) self._socket = socket_factory(self._af, socket.SOCK_DGRAM, 0)
if self._source is not None: if self._source is not None:
try: try:
self._socket.bind( self._socket.bind(
@ -118,7 +131,7 @@ class SyncQuicConnection(BaseQuicConnection):
def _worker(self): def _worker(self):
try: try:
sel = _selector_class() sel = selectors.DefaultSelector()
sel.register(self._socket, selectors.EVENT_READ, self._read) sel.register(self._socket, selectors.EVENT_READ, self._read)
sel.register(self._receive_wakeup, selectors.EVENT_READ, self._drain_wakeup) sel.register(self._receive_wakeup, selectors.EVENT_READ, self._drain_wakeup)
while not self._done: while not self._done:
@ -140,6 +153,7 @@ class SyncQuicConnection(BaseQuicConnection):
finally: finally:
with self._lock: with self._lock:
self._done = True self._done = True
self._socket.close()
# Ensure anyone waiting for this gets woken up. # Ensure anyone waiting for this gets woken up.
self._handshake_complete.set() self._handshake_complete.set()
@ -150,10 +164,29 @@ class SyncQuicConnection(BaseQuicConnection):
if event is None: if event is None:
return return
if isinstance(event, aioquic.quic.events.StreamDataReceived): if isinstance(event, aioquic.quic.events.StreamDataReceived):
with self._lock: if self.is_h3():
stream = self._streams.get(event.stream_id) h3_events = self._h3_conn.handle_event(event)
if stream: for h3_event in h3_events:
stream._add_input(event.data, event.end_stream) if isinstance(h3_event, aioquic.h3.events.HeadersReceived):
with self._lock:
stream = self._streams.get(event.stream_id)
if stream:
if stream._headers is None:
stream._headers = h3_event.headers
elif stream._trailers is None:
stream._trailers = h3_event.headers
if h3_event.stream_ended:
stream._add_input(b"", True)
elif isinstance(h3_event, aioquic.h3.events.DataReceived):
with self._lock:
stream = self._streams.get(event.stream_id)
if stream:
stream._add_input(h3_event.data, h3_event.stream_ended)
else:
with self._lock:
stream = self._streams.get(event.stream_id)
if stream:
stream._add_input(event.data, event.end_stream)
elif isinstance(event, aioquic.quic.events.HandshakeCompleted): elif isinstance(event, aioquic.quic.events.HandshakeCompleted):
self._handshake_complete.set() self._handshake_complete.set()
elif isinstance(event, aioquic.quic.events.ConnectionTerminated): elif isinstance(event, aioquic.quic.events.ConnectionTerminated):
@ -170,6 +203,18 @@ class SyncQuicConnection(BaseQuicConnection):
self._connection.send_stream_data(stream, data, is_end) self._connection.send_stream_data(stream, data, is_end)
self._send_wakeup.send(b"\x01") self._send_wakeup.send(b"\x01")
def send_headers(self, stream_id, headers, is_end=False):
with self._lock:
super().send_headers(stream_id, headers, is_end)
if is_end:
self._send_wakeup.send(b"\x01")
def send_data(self, stream_id, data, is_end=False):
with self._lock:
super().send_data(stream_id, data, is_end)
if is_end:
self._send_wakeup.send(b"\x01")
def run(self): def run(self):
if self._closed: if self._closed:
return return
@ -203,16 +248,24 @@ class SyncQuicConnection(BaseQuicConnection):
class SyncQuicManager(BaseQuicManager): class SyncQuicManager(BaseQuicManager):
def __init__(self, conf=None, verify_mode=ssl.CERT_REQUIRED, server_name=None): def __init__(
super().__init__(conf, verify_mode, SyncQuicConnection, server_name) self, conf=None, verify_mode=ssl.CERT_REQUIRED, server_name=None, h3=False
):
super().__init__(conf, verify_mode, SyncQuicConnection, server_name, h3)
self._lock = threading.Lock() self._lock = threading.Lock()
def connect( def connect(
self, address, port=853, source=None, source_port=0, want_session_ticket=True self,
address,
port=853,
source=None,
source_port=0,
want_session_ticket=True,
want_token=True,
): ):
with self._lock: with self._lock:
(connection, start) = self._connect( (connection, start) = self._connect(
address, port, source, source_port, want_session_ticket address, port, source, source_port, want_session_ticket, want_token
) )
if start: if start:
connection.run() connection.run()
@ -226,6 +279,10 @@ class SyncQuicManager(BaseQuicManager):
with self._lock: with self._lock:
super().save_session_ticket(address, port, ticket) super().save_session_ticket(address, port, ticket)
def save_token(self, address, port, token):
with self._lock:
super().save_token(address, port, token)
def __enter__(self): def __enter__(self):
return self return self

View file

@ -36,16 +36,27 @@ class TrioQuicStream(BaseQuicStream):
await self._wake_up.wait() await self._wake_up.wait()
self._expecting = 0 self._expecting = 0
async def wait_for_end(self):
while True:
if self._buffer.seen_end():
return
async with self._wake_up:
await self._wake_up.wait()
async def receive(self, timeout=None): async def receive(self, timeout=None):
if timeout is None: if timeout is None:
context = NullContext(None) context = NullContext(None)
else: else:
context = trio.move_on_after(timeout) context = trio.move_on_after(timeout)
with context: with context:
await self.wait_for(2) if self._connection.is_h3():
(size,) = struct.unpack("!H", self._buffer.get(2)) await self.wait_for_end()
await self.wait_for(size) return self._buffer.get_all()
return self._buffer.get(size) else:
await self.wait_for(2)
(size,) = struct.unpack("!H", self._buffer.get(2))
await self.wait_for(size)
return self._buffer.get(size)
raise dns.exception.Timeout raise dns.exception.Timeout
async def send(self, datagram, is_end=False): async def send(self, datagram, is_end=False):
@ -115,6 +126,7 @@ class TrioQuicConnection(AsyncQuicConnection):
await self._socket.send(datagram) await self._socket.send(datagram)
finally: finally:
self._done = True self._done = True
self._socket.close()
self._handshake_complete.set() self._handshake_complete.set()
async def _handle_events(self): async def _handle_events(self):
@ -124,9 +136,28 @@ class TrioQuicConnection(AsyncQuicConnection):
if event is None: if event is None:
return return
if isinstance(event, aioquic.quic.events.StreamDataReceived): if isinstance(event, aioquic.quic.events.StreamDataReceived):
stream = self._streams.get(event.stream_id) if self.is_h3():
if stream: h3_events = self._h3_conn.handle_event(event)
await stream._add_input(event.data, event.end_stream) for h3_event in h3_events:
if isinstance(h3_event, aioquic.h3.events.HeadersReceived):
stream = self._streams.get(event.stream_id)
if stream:
if stream._headers is None:
stream._headers = h3_event.headers
elif stream._trailers is None:
stream._trailers = h3_event.headers
if h3_event.stream_ended:
await stream._add_input(b"", True)
elif isinstance(h3_event, aioquic.h3.events.DataReceived):
stream = self._streams.get(event.stream_id)
if stream:
await stream._add_input(
h3_event.data, h3_event.stream_ended
)
else:
stream = self._streams.get(event.stream_id)
if stream:
await stream._add_input(event.data, event.end_stream)
elif isinstance(event, aioquic.quic.events.HandshakeCompleted): elif isinstance(event, aioquic.quic.events.HandshakeCompleted):
self._handshake_complete.set() self._handshake_complete.set()
elif isinstance(event, aioquic.quic.events.ConnectionTerminated): elif isinstance(event, aioquic.quic.events.ConnectionTerminated):
@ -183,9 +214,14 @@ class TrioQuicConnection(AsyncQuicConnection):
class TrioQuicManager(AsyncQuicManager): class TrioQuicManager(AsyncQuicManager):
def __init__( def __init__(
self, nursery, conf=None, verify_mode=ssl.CERT_REQUIRED, server_name=None self,
nursery,
conf=None,
verify_mode=ssl.CERT_REQUIRED,
server_name=None,
h3=False,
): ):
super().__init__(conf, verify_mode, TrioQuicConnection, server_name) super().__init__(conf, verify_mode, TrioQuicConnection, server_name, h3)
self._nursery = nursery self._nursery = nursery
def connect( def connect(

View file

@ -214,7 +214,7 @@ class Rdata:
compress: Optional[dns.name.CompressType] = None, compress: Optional[dns.name.CompressType] = None,
origin: Optional[dns.name.Name] = None, origin: Optional[dns.name.Name] = None,
canonicalize: bool = False, canonicalize: bool = False,
) -> bytes: ) -> None:
raise NotImplementedError # pragma: no cover raise NotImplementedError # pragma: no cover
def to_wire( def to_wire(
@ -223,14 +223,19 @@ class Rdata:
compress: Optional[dns.name.CompressType] = None, compress: Optional[dns.name.CompressType] = None,
origin: Optional[dns.name.Name] = None, origin: Optional[dns.name.Name] = None,
canonicalize: bool = False, canonicalize: bool = False,
) -> bytes: ) -> Optional[bytes]:
"""Convert an rdata to wire format. """Convert an rdata to wire format.
Returns a ``bytes`` or ``None``. Returns a ``bytes`` if no output file was specified, or ``None`` otherwise.
""" """
if file: if file:
return self._to_wire(file, compress, origin, canonicalize) # We call _to_wire() and then return None explicitly instead of
# of just returning the None from _to_wire() as mypy's func-returns-value
# unhelpfully errors out with "error: "_to_wire" of "Rdata" does not return
# a value (it only ever returns None)"
self._to_wire(file, compress, origin, canonicalize)
return None
else: else:
f = io.BytesIO() f = io.BytesIO()
self._to_wire(f, compress, origin, canonicalize) self._to_wire(f, compress, origin, canonicalize)
@ -253,8 +258,9 @@ class Rdata:
Returns a ``bytes``. Returns a ``bytes``.
""" """
wire = self.to_wire(origin=origin, canonicalize=True)
return self.to_wire(origin=origin, canonicalize=True) assert wire is not None # for mypy
return wire
def __repr__(self): def __repr__(self):
covers = self.covers() covers = self.covers()
@ -434,15 +440,11 @@ class Rdata:
continue continue
if key not in parameters: if key not in parameters:
raise AttributeError( raise AttributeError(
"'{}' object has no attribute '{}'".format( f"'{self.__class__.__name__}' object has no attribute '{key}'"
self.__class__.__name__, key
)
) )
if key in ("rdclass", "rdtype"): if key in ("rdclass", "rdtype"):
raise AttributeError( raise AttributeError(
"Cannot overwrite '{}' attribute '{}'".format( f"Cannot overwrite '{self.__class__.__name__}' attribute '{key}'"
self.__class__.__name__, key
)
) )
# Construct the parameter list. For each field, use the value in # Construct the parameter list. For each field, use the value in
@ -646,13 +648,14 @@ _rdata_classes: Dict[Tuple[dns.rdataclass.RdataClass, dns.rdatatype.RdataType],
{} {}
) )
_module_prefix = "dns.rdtypes" _module_prefix = "dns.rdtypes"
_dynamic_load_allowed = True
def get_rdata_class(rdclass, rdtype): def get_rdata_class(rdclass, rdtype, use_generic=True):
cls = _rdata_classes.get((rdclass, rdtype)) cls = _rdata_classes.get((rdclass, rdtype))
if not cls: if not cls:
cls = _rdata_classes.get((dns.rdatatype.ANY, rdtype)) cls = _rdata_classes.get((dns.rdatatype.ANY, rdtype))
if not cls: if not cls and _dynamic_load_allowed:
rdclass_text = dns.rdataclass.to_text(rdclass) rdclass_text = dns.rdataclass.to_text(rdclass)
rdtype_text = dns.rdatatype.to_text(rdtype) rdtype_text = dns.rdatatype.to_text(rdtype)
rdtype_text = rdtype_text.replace("-", "_") rdtype_text = rdtype_text.replace("-", "_")
@ -670,12 +673,36 @@ def get_rdata_class(rdclass, rdtype):
_rdata_classes[(rdclass, rdtype)] = cls _rdata_classes[(rdclass, rdtype)] = cls
except ImportError: except ImportError:
pass pass
if not cls: if not cls and use_generic:
cls = GenericRdata cls = GenericRdata
_rdata_classes[(rdclass, rdtype)] = cls _rdata_classes[(rdclass, rdtype)] = cls
return cls return cls
def load_all_types(disable_dynamic_load=True):
"""Load all rdata types for which dnspython has a non-generic implementation.
Normally dnspython loads DNS rdatatype implementations on demand, but in some
specialized cases loading all types at an application-controlled time is preferred.
If *disable_dynamic_load*, a ``bool``, is ``True`` then dnspython will not attempt
to use its dynamic loading mechanism if an unknown type is subsequently encountered,
and will simply use the ``GenericRdata`` class.
"""
# Load class IN and ANY types.
for rdtype in dns.rdatatype.RdataType:
get_rdata_class(dns.rdataclass.IN, rdtype, False)
# Load the one non-ANY implementation we have in CH. Everything
# else in CH is an ANY type, and we'll discover those on demand but won't
# have to import anything.
get_rdata_class(dns.rdataclass.CH, dns.rdatatype.A, False)
if disable_dynamic_load:
# Now disable dynamic loading so any subsequent unknown type immediately becomes
# GenericRdata without a load attempt.
global _dynamic_load_allowed
_dynamic_load_allowed = False
def from_text( def from_text(
rdclass: Union[dns.rdataclass.RdataClass, str], rdclass: Union[dns.rdataclass.RdataClass, str],
rdtype: Union[dns.rdatatype.RdataType, str], rdtype: Union[dns.rdatatype.RdataType, str],

View file

@ -160,7 +160,7 @@ class Rdataset(dns.set.Set):
return s[:100] + "..." return s[:100] + "..."
return s return s
return "[%s]" % ", ".join("<%s>" % maybe_truncate(str(rr)) for rr in self) return "[" + ", ".join(f"<{maybe_truncate(str(rr))}>" for rr in self) + "]"
def __repr__(self): def __repr__(self):
if self.covers == 0: if self.covers == 0:
@ -248,12 +248,8 @@ class Rdataset(dns.set.Set):
# (which is meaningless anyway). # (which is meaningless anyway).
# #
s.write( s.write(
"{}{}{} {}\n".format( f"{ntext}{pad}{dns.rdataclass.to_text(rdclass)} "
ntext, f"{dns.rdatatype.to_text(self.rdtype)}\n"
pad,
dns.rdataclass.to_text(rdclass),
dns.rdatatype.to_text(self.rdtype),
)
) )
else: else:
for rd in self: for rd in self:

View file

@ -105,6 +105,8 @@ class RdataType(dns.enum.IntEnum):
CAA = 257 CAA = 257
AVC = 258 AVC = 258
AMTRELAY = 260 AMTRELAY = 260
RESINFO = 261
WALLET = 262
TA = 32768 TA = 32768
DLV = 32769 DLV = 32769
@ -125,7 +127,7 @@ class RdataType(dns.enum.IntEnum):
if text.find("-") >= 0: if text.find("-") >= 0:
try: try:
return cls[text.replace("-", "_")] return cls[text.replace("-", "_")]
except KeyError: except KeyError: # pragma: no cover
pass pass
return _registered_by_text.get(text) return _registered_by_text.get(text)
@ -326,6 +328,8 @@ URI = RdataType.URI
CAA = RdataType.CAA CAA = RdataType.CAA
AVC = RdataType.AVC AVC = RdataType.AVC
AMTRELAY = RdataType.AMTRELAY AMTRELAY = RdataType.AMTRELAY
RESINFO = RdataType.RESINFO
WALLET = RdataType.WALLET
TA = RdataType.TA TA = RdataType.TA
DLV = RdataType.DLV DLV = RdataType.DLV

View file

@ -75,8 +75,9 @@ class GPOS(dns.rdata.Rdata):
raise dns.exception.FormError("bad longitude") raise dns.exception.FormError("bad longitude")
def to_text(self, origin=None, relativize=True, **kw): def to_text(self, origin=None, relativize=True, **kw):
return "{} {} {}".format( return (
self.latitude.decode(), self.longitude.decode(), self.altitude.decode() f"{self.latitude.decode()} {self.longitude.decode()} "
f"{self.altitude.decode()}"
) )
@classmethod @classmethod

View file

@ -37,9 +37,7 @@ class HINFO(dns.rdata.Rdata):
self.os = self._as_bytes(os, True, 255) self.os = self._as_bytes(os, True, 255)
def to_text(self, origin=None, relativize=True, **kw): def to_text(self, origin=None, relativize=True, **kw):
return '"{}" "{}"'.format( return f'"{dns.rdata._escapify(self.cpu)}" "{dns.rdata._escapify(self.os)}"'
dns.rdata._escapify(self.cpu), dns.rdata._escapify(self.os)
)
@classmethod @classmethod
def from_text( def from_text(

View file

@ -48,7 +48,7 @@ class HIP(dns.rdata.Rdata):
for server in self.servers: for server in self.servers:
servers.append(server.choose_relativity(origin, relativize)) servers.append(server.choose_relativity(origin, relativize))
if len(servers) > 0: if len(servers) > 0:
text += " " + " ".join((x.to_unicode() for x in servers)) text += " " + " ".join(x.to_unicode() for x in servers)
return "%u %s %s%s" % (self.algorithm, hit, key, text) return "%u %s %s%s" % (self.algorithm, hit, key, text)
@classmethod @classmethod

View file

@ -38,11 +38,12 @@ class ISDN(dns.rdata.Rdata):
def to_text(self, origin=None, relativize=True, **kw): def to_text(self, origin=None, relativize=True, **kw):
if self.subaddress: if self.subaddress:
return '"{}" "{}"'.format( return (
dns.rdata._escapify(self.address), dns.rdata._escapify(self.subaddress) f'"{dns.rdata._escapify(self.address)}" '
f'"{dns.rdata._escapify(self.subaddress)}"'
) )
else: else:
return '"%s"' % dns.rdata._escapify(self.address) return f'"{dns.rdata._escapify(self.address)}"'
@classmethod @classmethod
def from_text( def from_text(

View file

@ -44,7 +44,7 @@ def _exponent_of(what, desc):
exp = i - 1 exp = i - 1
break break
if exp is None or exp < 0: if exp is None or exp < 0:
raise dns.exception.SyntaxError("%s value out of bounds" % desc) raise dns.exception.SyntaxError(f"{desc} value out of bounds")
return exp return exp
@ -83,10 +83,10 @@ def _encode_size(what, desc):
def _decode_size(what, desc): def _decode_size(what, desc):
exponent = what & 0x0F exponent = what & 0x0F
if exponent > 9: if exponent > 9:
raise dns.exception.FormError("bad %s exponent" % desc) raise dns.exception.FormError(f"bad {desc} exponent")
base = (what & 0xF0) >> 4 base = (what & 0xF0) >> 4
if base > 9: if base > 9:
raise dns.exception.FormError("bad %s base" % desc) raise dns.exception.FormError(f"bad {desc} base")
return base * pow(10, exponent) return base * pow(10, exponent)
@ -184,10 +184,9 @@ class LOC(dns.rdata.Rdata):
or self.horizontal_precision != _default_hprec or self.horizontal_precision != _default_hprec
or self.vertical_precision != _default_vprec or self.vertical_precision != _default_vprec
): ):
text += " {:0.2f}m {:0.2f}m {:0.2f}m".format( text += (
self.size / 100.0, f" {self.size / 100.0:0.2f}m {self.horizontal_precision / 100.0:0.2f}m"
self.horizontal_precision / 100.0, f" {self.vertical_precision / 100.0:0.2f}m"
self.vertical_precision / 100.0,
) )
return text return text

View file

@ -44,7 +44,7 @@ class NSEC(dns.rdata.Rdata):
def to_text(self, origin=None, relativize=True, **kw): def to_text(self, origin=None, relativize=True, **kw):
next = self.next.choose_relativity(origin, relativize) next = self.next.choose_relativity(origin, relativize)
text = Bitmap(self.windows).to_text() text = Bitmap(self.windows).to_text()
return "{}{}".format(next, text) return f"{next}{text}"
@classmethod @classmethod
def from_text( def from_text(

View file

@ -0,0 +1,24 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import dns.immutable
import dns.rdtypes.txtbase
@dns.immutable.immutable
class RESINFO(dns.rdtypes.txtbase.TXTBase):
"""RESINFO record"""

View file

@ -37,7 +37,7 @@ class RP(dns.rdata.Rdata):
def to_text(self, origin=None, relativize=True, **kw): def to_text(self, origin=None, relativize=True, **kw):
mbox = self.mbox.choose_relativity(origin, relativize) mbox = self.mbox.choose_relativity(origin, relativize)
txt = self.txt.choose_relativity(origin, relativize) txt = self.txt.choose_relativity(origin, relativize)
return "{} {}".format(str(mbox), str(txt)) return f"{str(mbox)} {str(txt)}"
@classmethod @classmethod
def from_text( def from_text(

View file

@ -69,7 +69,7 @@ class TKEY(dns.rdata.Rdata):
dns.rdata._base64ify(self.key, 0), dns.rdata._base64ify(self.key, 0),
) )
if len(self.other) > 0: if len(self.other) > 0:
text += " %s" % (dns.rdata._base64ify(self.other, 0)) text += f" {dns.rdata._base64ify(self.other, 0)}"
return text return text

View file

@ -0,0 +1,9 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
import dns.immutable
import dns.rdtypes.txtbase
@dns.immutable.immutable
class WALLET(dns.rdtypes.txtbase.TXTBase):
"""WALLET record"""

View file

@ -36,7 +36,7 @@ class X25(dns.rdata.Rdata):
self.address = self._as_bytes(address, True, 255) self.address = self._as_bytes(address, True, 255)
def to_text(self, origin=None, relativize=True, **kw): def to_text(self, origin=None, relativize=True, **kw):
return '"%s"' % dns.rdata._escapify(self.address) return f'"{dns.rdata._escapify(self.address)}"'
@classmethod @classmethod
def from_text( def from_text(

View file

@ -51,6 +51,7 @@ __all__ = [
"OPENPGPKEY", "OPENPGPKEY",
"OPT", "OPT",
"PTR", "PTR",
"RESINFO",
"RP", "RP",
"RRSIG", "RRSIG",
"RT", "RT",
@ -63,6 +64,7 @@ __all__ = [
"TSIG", "TSIG",
"TXT", "TXT",
"URI", "URI",
"WALLET",
"X25", "X25",
"ZONEMD", "ZONEMD",
] ]

View file

@ -37,7 +37,7 @@ class A(dns.rdata.Rdata):
def to_text(self, origin=None, relativize=True, **kw): def to_text(self, origin=None, relativize=True, **kw):
domain = self.domain.choose_relativity(origin, relativize) domain = self.domain.choose_relativity(origin, relativize)
return "%s %o" % (domain, self.address) return f"{domain} {self.address:o}"
@classmethod @classmethod
def from_text( def from_text(

View file

@ -36,7 +36,7 @@ class NSAP(dns.rdata.Rdata):
self.address = self._as_bytes(address) self.address = self._as_bytes(address)
def to_text(self, origin=None, relativize=True, **kw): def to_text(self, origin=None, relativize=True, **kw):
return "0x%s" % binascii.hexlify(self.address).decode() return f"0x{binascii.hexlify(self.address).decode()}"
@classmethod @classmethod
def from_text( def from_text(

View file

@ -36,7 +36,7 @@ class EUIBase(dns.rdata.Rdata):
self.eui = self._as_bytes(eui) self.eui = self._as_bytes(eui)
if len(self.eui) != self.byte_len: if len(self.eui) != self.byte_len:
raise dns.exception.FormError( raise dns.exception.FormError(
"EUI%s rdata has to have %s bytes" % (self.byte_len * 8, self.byte_len) f"EUI{self.byte_len * 8} rdata has to have {self.byte_len} bytes"
) )
def to_text(self, origin=None, relativize=True, **kw): def to_text(self, origin=None, relativize=True, **kw):
@ -49,16 +49,16 @@ class EUIBase(dns.rdata.Rdata):
text = tok.get_string() text = tok.get_string()
if len(text) != cls.text_len: if len(text) != cls.text_len:
raise dns.exception.SyntaxError( raise dns.exception.SyntaxError(
"Input text must have %s characters" % cls.text_len f"Input text must have {cls.text_len} characters"
) )
for i in range(2, cls.byte_len * 3 - 1, 3): for i in range(2, cls.byte_len * 3 - 1, 3):
if text[i] != "-": if text[i] != "-":
raise dns.exception.SyntaxError("Dash expected at position %s" % i) raise dns.exception.SyntaxError(f"Dash expected at position {i}")
text = text.replace("-", "") text = text.replace("-", "")
try: try:
data = binascii.unhexlify(text.encode()) data = binascii.unhexlify(text.encode())
except (ValueError, TypeError) as ex: except (ValueError, TypeError) as ex:
raise dns.exception.SyntaxError("Hex decoding error: %s" % str(ex)) raise dns.exception.SyntaxError(f"Hex decoding error: {str(ex)}")
return cls(rdclass, rdtype, data) return cls(rdclass, rdtype, data)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False): def _to_wire(self, file, compress=None, origin=None, canonicalize=False):

View file

@ -35,6 +35,7 @@ class ParamKey(dns.enum.IntEnum):
ECH = 5 ECH = 5
IPV6HINT = 6 IPV6HINT = 6
DOHPATH = 7 DOHPATH = 7
OHTTP = 8
@classmethod @classmethod
def _maximum(cls): def _maximum(cls):
@ -396,6 +397,36 @@ class ECHParam(Param):
file.write(self.ech) file.write(self.ech)
@dns.immutable.immutable
class OHTTPParam(Param):
# We don't ever expect to instantiate this class, but we need
# a from_value() and a from_wire_parser(), so we just return None
# from the class methods when things are OK.
@classmethod
def emptiness(cls):
return Emptiness.ALWAYS
@classmethod
def from_value(cls, value):
if value is None or value == "":
return None
else:
raise ValueError("ohttp with non-empty value")
def to_text(self):
raise NotImplementedError # pragma: no cover
@classmethod
def from_wire_parser(cls, parser, origin=None): # pylint: disable=W0613
if parser.remaining() != 0:
raise dns.exception.FormError
return None
def to_wire(self, file, origin=None): # pylint: disable=W0613
raise NotImplementedError # pragma: no cover
_class_for_key = { _class_for_key = {
ParamKey.MANDATORY: MandatoryParam, ParamKey.MANDATORY: MandatoryParam,
ParamKey.ALPN: ALPNParam, ParamKey.ALPN: ALPNParam,
@ -404,6 +435,7 @@ _class_for_key = {
ParamKey.IPV4HINT: IPv4HintParam, ParamKey.IPV4HINT: IPv4HintParam,
ParamKey.ECH: ECHParam, ParamKey.ECH: ECHParam,
ParamKey.IPV6HINT: IPv6HintParam, ParamKey.IPV6HINT: IPv6HintParam,
ParamKey.OHTTP: OHTTPParam,
} }

View file

@ -50,6 +50,8 @@ class TXTBase(dns.rdata.Rdata):
self.strings: Tuple[bytes] = self._as_tuple( self.strings: Tuple[bytes] = self._as_tuple(
strings, lambda x: self._as_bytes(x, True, 255) strings, lambda x: self._as_bytes(x, True, 255)
) )
if len(self.strings) == 0:
raise ValueError("the list of strings must not be empty")
def to_text( def to_text(
self, self,
@ -60,7 +62,7 @@ class TXTBase(dns.rdata.Rdata):
txt = "" txt = ""
prefix = "" prefix = ""
for s in self.strings: for s in self.strings:
txt += '{}"{}"'.format(prefix, dns.rdata._escapify(s)) txt += f'{prefix}"{dns.rdata._escapify(s)}"'
prefix = " " prefix = " "
return txt return txt

View file

@ -231,7 +231,7 @@ def weighted_processing_order(iterable):
total = sum(rdata._processing_weight() or _no_weight for rdata in rdatas) total = sum(rdata._processing_weight() or _no_weight for rdata in rdatas)
while len(rdatas) > 1: while len(rdatas) > 1:
r = random.uniform(0, total) r = random.uniform(0, total)
for n, rdata in enumerate(rdatas): for n, rdata in enumerate(rdatas): # noqa: B007
weight = rdata._processing_weight() or _no_weight weight = rdata._processing_weight() or _no_weight
if weight > r: if weight > r:
break break

View file

@ -36,6 +36,7 @@ import dns.ipv4
import dns.ipv6 import dns.ipv6
import dns.message import dns.message
import dns.name import dns.name
import dns.rdata
import dns.nameserver import dns.nameserver
import dns.query import dns.query
import dns.rcode import dns.rcode
@ -45,7 +46,7 @@ import dns.rdtypes.svcbbase
import dns.reversename import dns.reversename
import dns.tsig import dns.tsig
if sys.platform == "win32": if sys.platform == "win32": # pragma: no cover
import dns.win32util import dns.win32util
@ -83,7 +84,7 @@ class NXDOMAIN(dns.exception.DNSException):
else: else:
msg = "The DNS query name does not exist" msg = "The DNS query name does not exist"
qnames = ", ".join(map(str, qnames)) qnames = ", ".join(map(str, qnames))
return "{}: {}".format(msg, qnames) return f"{msg}: {qnames}"
@property @property
def canonical_name(self): def canonical_name(self):
@ -96,7 +97,7 @@ class NXDOMAIN(dns.exception.DNSException):
cname = response.canonical_name() cname = response.canonical_name()
if cname != qname: if cname != qname:
return cname return cname
except Exception: except Exception: # pragma: no cover
# We can just eat this exception as it means there was # We can just eat this exception as it means there was
# something wrong with the response. # something wrong with the response.
pass pass
@ -154,7 +155,7 @@ def _errors_to_text(errors: List[ErrorTuple]) -> List[str]:
"""Turn a resolution errors trace into a list of text.""" """Turn a resolution errors trace into a list of text."""
texts = [] texts = []
for err in errors: for err in errors:
texts.append("Server {} answered {}".format(err[0], err[3])) texts.append(f"Server {err[0]} answered {err[3]}")
return texts return texts
@ -162,7 +163,7 @@ class LifetimeTimeout(dns.exception.Timeout):
"""The resolution lifetime expired.""" """The resolution lifetime expired."""
msg = "The resolution lifetime expired." msg = "The resolution lifetime expired."
fmt = "%s after {timeout:.3f} seconds: {errors}" % msg[:-1] fmt = f"{msg[:-1]} after {{timeout:.3f}} seconds: {{errors}}"
supp_kwargs = {"timeout", "errors"} supp_kwargs = {"timeout", "errors"}
# We do this as otherwise mypy complains about unexpected keyword argument # We do this as otherwise mypy complains about unexpected keyword argument
@ -211,7 +212,7 @@ class NoNameservers(dns.exception.DNSException):
""" """
msg = "All nameservers failed to answer the query." msg = "All nameservers failed to answer the query."
fmt = "%s {query}: {errors}" % msg[:-1] fmt = f"{msg[:-1]} {{query}}: {{errors}}"
supp_kwargs = {"request", "errors"} supp_kwargs = {"request", "errors"}
# We do this as otherwise mypy complains about unexpected keyword argument # We do this as otherwise mypy complains about unexpected keyword argument
@ -297,7 +298,7 @@ class Answer:
def __len__(self) -> int: def __len__(self) -> int:
return self.rrset and len(self.rrset) or 0 return self.rrset and len(self.rrset) or 0
def __iter__(self): def __iter__(self) -> Iterator[dns.rdata.Rdata]:
return self.rrset and iter(self.rrset) or iter(tuple()) return self.rrset and iter(self.rrset) or iter(tuple())
def __getitem__(self, i): def __getitem__(self, i):
@ -334,7 +335,7 @@ class HostAnswers(Answers):
answers[dns.rdatatype.A] = v4 answers[dns.rdatatype.A] = v4
return answers return answers
# Returns pairs of (address, family) from this result, potentiallys # Returns pairs of (address, family) from this result, potentially
# filtering by address family. # filtering by address family.
def addresses_and_families( def addresses_and_families(
self, family: int = socket.AF_UNSPEC self, family: int = socket.AF_UNSPEC
@ -347,7 +348,7 @@ class HostAnswers(Answers):
answer = self.get(dns.rdatatype.AAAA) answer = self.get(dns.rdatatype.AAAA)
elif family == socket.AF_INET: elif family == socket.AF_INET:
answer = self.get(dns.rdatatype.A) answer = self.get(dns.rdatatype.A)
else: else: # pragma: no cover
raise NotImplementedError(f"unknown address family {family}") raise NotImplementedError(f"unknown address family {family}")
if answer: if answer:
for rdata in answer: for rdata in answer:
@ -938,7 +939,7 @@ class BaseResolver:
self.reset() self.reset()
if configure: if configure:
if sys.platform == "win32": if sys.platform == "win32": # pragma: no cover
self.read_registry() self.read_registry()
elif filename: elif filename:
self.read_resolv_conf(filename) self.read_resolv_conf(filename)
@ -947,7 +948,7 @@ class BaseResolver:
"""Reset all resolver configuration to the defaults.""" """Reset all resolver configuration to the defaults."""
self.domain = dns.name.Name(dns.name.from_text(socket.gethostname())[1:]) self.domain = dns.name.Name(dns.name.from_text(socket.gethostname())[1:])
if len(self.domain) == 0: if len(self.domain) == 0: # pragma: no cover
self.domain = dns.name.root self.domain = dns.name.root
self._nameservers = [] self._nameservers = []
self.nameserver_ports = {} self.nameserver_ports = {}
@ -1040,7 +1041,7 @@ class BaseResolver:
# setter logic, with additonal checking and enrichment. # setter logic, with additonal checking and enrichment.
self.nameservers = nameservers self.nameservers = nameservers
def read_registry(self) -> None: def read_registry(self) -> None: # pragma: no cover
"""Extract resolver configuration from the Windows registry.""" """Extract resolver configuration from the Windows registry."""
try: try:
info = dns.win32util.get_dns_info() # type: ignore info = dns.win32util.get_dns_info() # type: ignore
@ -1205,9 +1206,7 @@ class BaseResolver:
enriched_nameservers.append(enriched_nameserver) enriched_nameservers.append(enriched_nameserver)
else: else:
raise ValueError( raise ValueError(
"nameservers must be a list or tuple (not a {})".format( f"nameservers must be a list or tuple (not a {type(nameservers)})"
type(nameservers)
)
) )
return enriched_nameservers return enriched_nameservers
@ -1431,7 +1430,7 @@ class Resolver(BaseResolver):
elif family == socket.AF_INET6: elif family == socket.AF_INET6:
v6 = self.resolve(name, dns.rdatatype.AAAA, **modified_kwargs) v6 = self.resolve(name, dns.rdatatype.AAAA, **modified_kwargs)
return HostAnswers.make(v6=v6) return HostAnswers.make(v6=v6)
elif family != socket.AF_UNSPEC: elif family != socket.AF_UNSPEC: # pragma: no cover
raise NotImplementedError(f"unknown address family {family}") raise NotImplementedError(f"unknown address family {family}")
raise_on_no_answer = modified_kwargs.pop("raise_on_no_answer", True) raise_on_no_answer = modified_kwargs.pop("raise_on_no_answer", True)
@ -1515,7 +1514,7 @@ class Resolver(BaseResolver):
nameservers = dns._ddr._get_nameservers_sync(answer, timeout) nameservers = dns._ddr._get_nameservers_sync(answer, timeout)
if len(nameservers) > 0: if len(nameservers) > 0:
self.nameservers = nameservers self.nameservers = nameservers
except Exception: except Exception: # pragma: no cover
pass pass
@ -1640,7 +1639,7 @@ def canonical_name(name: Union[dns.name.Name, str]) -> dns.name.Name:
return get_default_resolver().canonical_name(name) return get_default_resolver().canonical_name(name)
def try_ddr(lifetime: float = 5.0) -> None: def try_ddr(lifetime: float = 5.0) -> None: # pragma: no cover
"""Try to update the default resolver's nameservers using Discovery of Designated """Try to update the default resolver's nameservers using Discovery of Designated
Resolvers (DDR). If successful, the resolver will subsequently use Resolvers (DDR). If successful, the resolver will subsequently use
DNS-over-HTTPS or DNS-over-TLS for future queries. DNS-over-HTTPS or DNS-over-TLS for future queries.
@ -1926,7 +1925,7 @@ def _getnameinfo(sockaddr, flags=0):
family = socket.AF_INET family = socket.AF_INET
tuples = _getaddrinfo(host, port, family, socket.SOCK_STREAM, socket.SOL_TCP, 0) tuples = _getaddrinfo(host, port, family, socket.SOCK_STREAM, socket.SOL_TCP, 0)
if len(tuples) > 1: if len(tuples) > 1:
raise socket.error("sockaddr resolved to multiple addresses") raise OSError("sockaddr resolved to multiple addresses")
addr = tuples[0][4][0] addr = tuples[0][4][0]
if flags & socket.NI_DGRAM: if flags & socket.NI_DGRAM:
pname = "udp" pname = "udp"
@ -1961,7 +1960,7 @@ def _getfqdn(name=None):
(name, _, _) = _gethostbyaddr(name) (name, _, _) = _gethostbyaddr(name)
# Python's version checks aliases too, but our gethostbyname # Python's version checks aliases too, but our gethostbyname
# ignores them, so we do so here as well. # ignores them, so we do so here as well.
except Exception: except Exception: # pragma: no cover
pass pass
return name return name

View file

@ -21,10 +21,11 @@ import itertools
class Set: class Set:
"""A simple set class. """A simple set class.
This class was originally used to deal with sets being missing in This class was originally used to deal with python not having a set class, and
ancient versions of python, but dnspython will continue to use it originally the class used lists in its implementation. The ordered and indexable
as these sets are based on lists and are thus indexable, and this nature of RRsets and Rdatasets is unfortunately widely used in dnspython
ability is widely used in dnspython applications. applications, so for backwards compatibility sets continue to be a custom class, now
based on an ordered dictionary.
""" """
__slots__ = ["items"] __slots__ = ["items"]
@ -43,7 +44,7 @@ class Set:
self.add(item) # lgtm[py/init-calls-subclass] self.add(item) # lgtm[py/init-calls-subclass]
def __repr__(self): def __repr__(self):
return "dns.set.Set(%s)" % repr(list(self.items.keys())) return f"dns.set.Set({repr(list(self.items.keys()))})" # pragma: no cover
def add(self, item): def add(self, item):
"""Add an item to the set.""" """Add an item to the set."""

View file

@ -528,7 +528,7 @@ class Tokenizer:
if value < 0 or value > 65535: if value < 0 or value > 65535:
if base == 8: if base == 8:
raise dns.exception.SyntaxError( raise dns.exception.SyntaxError(
"%o is not an octal unsigned 16-bit integer" % value f"{value:o} is not an octal unsigned 16-bit integer"
) )
else: else:
raise dns.exception.SyntaxError( raise dns.exception.SyntaxError(

View file

@ -486,7 +486,7 @@ class Transaction:
if exact: if exact:
raise DeleteNotExact(f"{method}: missing rdataset") raise DeleteNotExact(f"{method}: missing rdataset")
else: else:
self._delete_rdataset(name, rdtype, covers) self._checked_delete_rdataset(name, rdtype, covers)
return return
else: else:
rdataset = self._rdataset_from_args(method, True, args) rdataset = self._rdataset_from_args(method, True, args)
@ -529,8 +529,6 @@ class Transaction:
def _end(self, commit): def _end(self, commit):
self._check_ended() self._check_ended()
if self._ended:
raise AlreadyEnded
try: try:
self._end_transaction(commit) self._end_transaction(commit)
finally: finally:

View file

@ -73,7 +73,7 @@ def from_text(text: str) -> int:
elif c == "s": elif c == "s":
total += current total += current
else: else:
raise BadTTL("unknown unit '%s'" % c) raise BadTTL(f"unknown unit '{c}'")
current = 0 current = 0
need_digit = True need_digit = True
if not current == 0: if not current == 0:

View file

@ -20,9 +20,9 @@
#: MAJOR #: MAJOR
MAJOR = 2 MAJOR = 2
#: MINOR #: MINOR
MINOR = 6 MINOR = 7
#: MICRO #: MICRO
MICRO = 1 MICRO = 0
#: RELEASELEVEL #: RELEASELEVEL
RELEASELEVEL = 0x0F RELEASELEVEL = 0x0F
#: SERIAL #: SERIAL

View file

@ -13,8 +13,8 @@ if sys.platform == "win32":
# Keep pylint quiet on non-windows. # Keep pylint quiet on non-windows.
try: try:
WindowsError is None # pylint: disable=used-before-assignment _ = WindowsError # pylint: disable=used-before-assignment
except KeyError: except NameError:
WindowsError = Exception WindowsError = Exception
if dns._features.have("wmi"): if dns._features.have("wmi"):
@ -44,6 +44,7 @@ if sys.platform == "win32":
if _have_wmi: if _have_wmi:
class _WMIGetter(threading.Thread): class _WMIGetter(threading.Thread):
# pylint: disable=possibly-used-before-assignment
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.info = DnsInfo() self.info = DnsInfo()
@ -82,32 +83,21 @@ if sys.platform == "win32":
def __init__(self): def __init__(self):
self.info = DnsInfo() self.info = DnsInfo()
def _determine_split_char(self, entry): def _split(self, text):
# # The windows registry has used both " " and "," as a delimiter, and while
# The windows registry irritatingly changes the list element # it is currently using "," in Windows 10 and later, updates can seemingly
# delimiter in between ' ' and ',' (and vice-versa) in various # leave a space in too, e.g. "a, b". So we just convert all commas to
# versions of windows. # spaces, and use split() in its default configuration, which splits on
# # all whitespace and ignores empty strings.
if entry.find(" ") >= 0: return text.replace(",", " ").split()
split_char = " "
elif entry.find(",") >= 0:
split_char = ","
else:
# probably a singleton; treat as a space-separated list.
split_char = " "
return split_char
def _config_nameservers(self, nameservers): def _config_nameservers(self, nameservers):
split_char = self._determine_split_char(nameservers) for ns in self._split(nameservers):
ns_list = nameservers.split(split_char)
for ns in ns_list:
if ns not in self.info.nameservers: if ns not in self.info.nameservers:
self.info.nameservers.append(ns) self.info.nameservers.append(ns)
def _config_search(self, search): def _config_search(self, search):
split_char = self._determine_split_char(search) for s in self._split(search):
search_list = search.split(split_char)
for s in search_list:
s = _config_domain(s) s = _config_domain(s)
if s not in self.info.search: if s not in self.info.search:
self.info.search.append(s) self.info.search.append(s)
@ -164,7 +154,7 @@ if sys.platform == "win32":
lm, lm,
r"SYSTEM\CurrentControlSet\Control\Network" r"SYSTEM\CurrentControlSet\Control\Network"
r"\{4D36E972-E325-11CE-BFC1-08002BE10318}" r"\{4D36E972-E325-11CE-BFC1-08002BE10318}"
r"\%s\Connection" % guid, rf"\{guid}\Connection",
) )
try: try:
@ -177,7 +167,7 @@ if sys.platform == "win32":
raise ValueError # pragma: no cover raise ValueError # pragma: no cover
device_key = winreg.OpenKey( device_key = winreg.OpenKey(
lm, r"SYSTEM\CurrentControlSet\Enum\%s" % pnp_id lm, rf"SYSTEM\CurrentControlSet\Enum\{pnp_id}"
) )
try: try:
@ -232,7 +222,7 @@ if sys.platform == "win32":
self._config_fromkey(key, False) self._config_fromkey(key, False)
finally: finally:
key.Close() key.Close()
except EnvironmentError: except OSError:
break break
finally: finally:
interfaces.Close() interfaces.Close()

View file

@ -33,7 +33,7 @@ class TransferError(dns.exception.DNSException):
"""A zone transfer response got a non-zero rcode.""" """A zone transfer response got a non-zero rcode."""
def __init__(self, rcode): def __init__(self, rcode):
message = "Zone transfer error: %s" % dns.rcode.to_text(rcode) message = f"Zone transfer error: {dns.rcode.to_text(rcode)}"
super().__init__(message) super().__init__(message)
self.rcode = rcode self.rcode = rcode

View file

@ -230,7 +230,7 @@ class Reader:
try: try:
rdtype = dns.rdatatype.from_text(token.value) rdtype = dns.rdatatype.from_text(token.value)
except Exception: except Exception:
raise dns.exception.SyntaxError("unknown rdatatype '%s'" % token.value) raise dns.exception.SyntaxError(f"unknown rdatatype '{token.value}'")
try: try:
rd = dns.rdata.from_text( rd = dns.rdata.from_text(
@ -251,9 +251,7 @@ class Reader:
# We convert them to syntax errors so that we can emit # We convert them to syntax errors so that we can emit
# helpful filename:line info. # helpful filename:line info.
(ty, va) = sys.exc_info()[:2] (ty, va) = sys.exc_info()[:2]
raise dns.exception.SyntaxError( raise dns.exception.SyntaxError(f"caught exception {str(ty)}: {str(va)}")
"caught exception {}: {}".format(str(ty), str(va))
)
if not self.default_ttl_known and rdtype == dns.rdatatype.SOA: if not self.default_ttl_known and rdtype == dns.rdatatype.SOA:
# The pre-RFC2308 and pre-BIND9 behavior inherits the zone default # The pre-RFC2308 and pre-BIND9 behavior inherits the zone default
@ -281,41 +279,41 @@ class Reader:
# Sometimes there are modifiers in the hostname. These come after # Sometimes there are modifiers in the hostname. These come after
# the dollar sign. They are in the form: ${offset[,width[,base]]}. # the dollar sign. They are in the form: ${offset[,width[,base]]}.
# Make names # Make names
mod = ""
sign = "+"
offset = "0"
width = "0"
base = "d"
g1 = is_generate1.match(side) g1 = is_generate1.match(side)
if g1: if g1:
mod, sign, offset, width, base = g1.groups() mod, sign, offset, width, base = g1.groups()
if sign == "": if sign == "":
sign = "+" sign = "+"
g2 = is_generate2.match(side) else:
if g2: g2 = is_generate2.match(side)
mod, sign, offset = g2.groups() if g2:
if sign == "": mod, sign, offset = g2.groups()
sign = "+" if sign == "":
width = 0 sign = "+"
base = "d" width = "0"
g3 = is_generate3.match(side) base = "d"
if g3: else:
mod, sign, offset, width = g3.groups() g3 = is_generate3.match(side)
if sign == "": if g3:
sign = "+" mod, sign, offset, width = g3.groups()
base = "d" if sign == "":
sign = "+"
base = "d"
if not (g1 or g2 or g3): ioffset = int(offset)
mod = "" iwidth = int(width)
sign = "+"
offset = 0
width = 0
base = "d"
offset = int(offset)
width = int(width)
if sign not in ["+", "-"]: if sign not in ["+", "-"]:
raise dns.exception.SyntaxError("invalid offset sign %s" % sign) raise dns.exception.SyntaxError(f"invalid offset sign {sign}")
if base not in ["d", "o", "x", "X", "n", "N"]: if base not in ["d", "o", "x", "X", "n", "N"]:
raise dns.exception.SyntaxError("invalid type %s" % base) raise dns.exception.SyntaxError(f"invalid type {base}")
return mod, sign, offset, width, base return mod, sign, ioffset, iwidth, base
def _generate_line(self): def _generate_line(self):
# range lhs [ttl] [class] type rhs [ comment ] # range lhs [ttl] [class] type rhs [ comment ]
@ -377,7 +375,7 @@ class Reader:
if not token.is_identifier(): if not token.is_identifier():
raise dns.exception.SyntaxError raise dns.exception.SyntaxError
except Exception: except Exception:
raise dns.exception.SyntaxError("unknown rdatatype '%s'" % token.value) raise dns.exception.SyntaxError(f"unknown rdatatype '{token.value}'")
# rhs (required) # rhs (required)
rhs = token.value rhs = token.value
@ -412,8 +410,8 @@ class Reader:
lzfindex = _format_index(lindex, lbase, lwidth) lzfindex = _format_index(lindex, lbase, lwidth)
rzfindex = _format_index(rindex, rbase, rwidth) rzfindex = _format_index(rindex, rbase, rwidth)
name = lhs.replace("$%s" % (lmod), lzfindex) name = lhs.replace(f"${lmod}", lzfindex)
rdata = rhs.replace("$%s" % (rmod), rzfindex) rdata = rhs.replace(f"${rmod}", rzfindex)
self.last_name = dns.name.from_text( self.last_name = dns.name.from_text(
name, self.current_origin, self.tok.idna_codec name, self.current_origin, self.tok.idna_codec
@ -445,7 +443,7 @@ class Reader:
# helpful filename:line info. # helpful filename:line info.
(ty, va) = sys.exc_info()[:2] (ty, va) = sys.exc_info()[:2]
raise dns.exception.SyntaxError( raise dns.exception.SyntaxError(
"caught exception %s: %s" % (str(ty), str(va)) f"caught exception {str(ty)}: {str(va)}"
) )
self.txn.add(name, ttl, rd) self.txn.add(name, ttl, rd)
@ -528,7 +526,7 @@ class Reader:
self.default_ttl_known, self.default_ttl_known,
) )
) )
self.current_file = open(filename, "r") self.current_file = open(filename)
self.tok = dns.tokenizer.Tokenizer(self.current_file, filename) self.tok = dns.tokenizer.Tokenizer(self.current_file, filename)
self.current_origin = new_origin self.current_origin = new_origin
elif c == "$GENERATE": elif c == "$GENERATE":

View file

@ -7,7 +7,7 @@ cheroot==10.0.1
cherrypy==18.10.0 cherrypy==18.10.0
cloudinary==1.41.0 cloudinary==1.41.0
distro==1.9.0 distro==1.9.0
dnspython==2.6.1 dnspython==2.7.0
facebook-sdk==3.1.0 facebook-sdk==3.1.0
future==1.0.0 future==1.0.0
ga4mp==2.0.4 ga4mp==2.0.4