diff --git a/lib/dns/_asyncbackend.py b/lib/dns/_asyncbackend.py index 49f14fed..f6760fd0 100644 --- a/lib/dns/_asyncbackend.py +++ b/lib/dns/_asyncbackend.py @@ -26,6 +26,10 @@ class NullContext: class Socket: # pragma: no cover + def __init__(self, family: int, type: int): + self.family = family + self.type = type + async def close(self): pass @@ -46,9 +50,6 @@ class Socket: # pragma: no cover class DatagramSocket(Socket): # pragma: no cover - def __init__(self, family: int): - self.family = family - async def sendto(self, what, destination, timeout): raise NotImplementedError diff --git a/lib/dns/_asyncio_backend.py b/lib/dns/_asyncio_backend.py index 9d9ed369..6ab168de 100644 --- a/lib/dns/_asyncio_backend.py +++ b/lib/dns/_asyncio_backend.py @@ -42,7 +42,7 @@ class _DatagramProtocol: if exc is None: # EOF we triggered. Is there a better way to do this? try: - raise EOFError + raise EOFError("EOF") except EOFError as e: self.recvfrom.set_exception(e) else: @@ -64,7 +64,7 @@ async def _maybe_wait_for(awaitable, timeout): class DatagramSocket(dns._asyncbackend.DatagramSocket): def __init__(self, family, transport, protocol): - super().__init__(family) + super().__init__(family, socket.SOCK_DGRAM) self.transport = transport self.protocol = protocol @@ -99,7 +99,7 @@ class DatagramSocket(dns._asyncbackend.DatagramSocket): class StreamSocket(dns._asyncbackend.StreamSocket): def __init__(self, af, reader, writer): - self.family = af + super().__init__(af, socket.SOCK_STREAM) self.reader = reader self.writer = writer @@ -197,7 +197,7 @@ if dns._features.have("doh"): family=socket.AF_UNSPEC, **kwargs, ): - if resolver is None: + if resolver is None and bootstrap_address is None: # pylint: disable=import-outside-toplevel,redefined-outer-name import dns.asyncresolver diff --git a/lib/dns/_features.py b/lib/dns/_features.py index 03ccaa77..fa6d4955 100644 --- a/lib/dns/_features.py +++ b/lib/dns/_features.py @@ -32,6 +32,9 @@ def _version_check( package, minimum = requirement.split(">=") try: version = importlib.metadata.version(package) + # This shouldn't happen, but it apparently can. + if version is None: + return False except Exception: return False t_version = _tuple_from_text(version) @@ -82,10 +85,10 @@ def force(feature: str, enabled: bool) -> None: _requirements: Dict[str, List[str]] = { ### BEGIN generated requirements - "dnssec": ["cryptography>=41"], + "dnssec": ["cryptography>=43"], "doh": ["httpcore>=1.0.0", "httpx>=0.26.0", "h2>=4.1.0"], - "doq": ["aioquic>=0.9.25"], - "idna": ["idna>=3.6"], + "doq": ["aioquic>=1.0.0"], + "idna": ["idna>=3.7"], "trio": ["trio>=0.23"], "wmi": ["wmi>=1.5.1"], ### END generated requirements diff --git a/lib/dns/_trio_backend.py b/lib/dns/_trio_backend.py index 398e3276..0ed904dd 100644 --- a/lib/dns/_trio_backend.py +++ b/lib/dns/_trio_backend.py @@ -30,13 +30,16 @@ _lltuple = dns.inet.low_level_address_tuple class DatagramSocket(dns._asyncbackend.DatagramSocket): - def __init__(self, socket): - super().__init__(socket.family) - self.socket = socket + def __init__(self, sock): + super().__init__(sock.family, socket.SOCK_DGRAM) + self.socket = sock async def sendto(self, what, destination, timeout): with _maybe_timeout(timeout): - return await self.socket.sendto(what, destination) + if destination is None: + return await self.socket.send(what) + else: + return await self.socket.sendto(what, destination) raise dns.exception.Timeout( timeout=timeout ) # pragma: no cover lgtm[py/unreachable-statement] @@ -61,7 +64,7 @@ class DatagramSocket(dns._asyncbackend.DatagramSocket): class StreamSocket(dns._asyncbackend.StreamSocket): def __init__(self, family, stream, tls=False): - self.family = family + super().__init__(family, socket.SOCK_STREAM) self.stream = stream self.tls = tls @@ -171,7 +174,7 @@ if dns._features.have("doh"): family=socket.AF_UNSPEC, **kwargs, ): - if resolver is None: + if resolver is None and bootstrap_address is None: # pylint: disable=import-outside-toplevel,redefined-outer-name import dns.asyncresolver @@ -205,7 +208,7 @@ class Backend(dns._asyncbackend.Backend): try: if source: await s.bind(_lltuple(source, af)) - if socktype == socket.SOCK_STREAM: + if socktype == socket.SOCK_STREAM or destination is not None: connected = False with _maybe_timeout(timeout): await s.connect(_lltuple(destination, af)) diff --git a/lib/dns/asyncquery.py b/lib/dns/asyncquery.py index 4d9ab9ae..efad0fd7 100644 --- a/lib/dns/asyncquery.py +++ b/lib/dns/asyncquery.py @@ -19,10 +19,12 @@ import base64 import contextlib +import random import socket import struct import time -from typing import Any, Dict, Optional, Tuple, Union +import urllib.parse +from typing import Any, Dict, Optional, Tuple, Union, cast import dns.asyncbackend import dns.exception @@ -37,9 +39,11 @@ import dns.transaction from dns._asyncbackend import NullContext from dns.query import ( BadResponse, + HTTPVersion, NoDOH, NoDOQ, UDPMode, + _check_status, _compute_times, _make_dot_ssl_context, _matches_destination, @@ -338,7 +342,7 @@ async def _read_exactly(sock, count, expiration): while count > 0: n = await sock.recv(count, _timeout(expiration)) if n == b"": - raise EOFError + raise EOFError("EOF") count = count - len(n) s = s + n return s @@ -500,6 +504,20 @@ async def tls( return response +def _maybe_get_resolver( + resolver: Optional["dns.asyncresolver.Resolver"], +) -> "dns.asyncresolver.Resolver": + # We need a separate method for this to avoid overriding the global + # variable "dns" with the as-yet undefined local variable "dns" + # in https(). + if resolver is None: + # pylint: disable=import-outside-toplevel,redefined-outer-name + import dns.asyncresolver + + resolver = dns.asyncresolver.Resolver() + return resolver + + async def https( q: dns.message.Message, where: str, @@ -515,7 +533,8 @@ async def https( verify: Union[bool, str] = True, bootstrap_address: Optional[str] = None, resolver: Optional["dns.asyncresolver.Resolver"] = None, - family: Optional[int] = socket.AF_UNSPEC, + family: int = socket.AF_UNSPEC, + http_version: HTTPVersion = HTTPVersion.DEFAULT, ) -> dns.message.Message: """Return the response obtained after sending a query via DNS-over-HTTPS. @@ -529,26 +548,65 @@ async def https( parameters, exceptions, and return type of this method. """ - if not have_doh: - raise NoDOH # pragma: no cover - if client and not isinstance(client, httpx.AsyncClient): - raise ValueError("session parameter must be an httpx.AsyncClient") - - wire = q.to_wire() try: af = dns.inet.af_for_address(where) except ValueError: af = None - transport = None - headers = {"accept": "application/dns-message"} if af is not None and dns.inet.is_address(where): if af == socket.AF_INET: - url = "https://{}:{}{}".format(where, port, path) + url = f"https://{where}:{port}{path}" elif af == socket.AF_INET6: - url = "https://[{}]:{}{}".format(where, port, path) + url = f"https://[{where}]:{port}{path}" else: url = where + extensions = {} + if bootstrap_address is None: + # pylint: disable=possibly-used-before-assignment + parsed = urllib.parse.urlparse(url) + if parsed.hostname is None: + raise ValueError("no hostname in URL") + if dns.inet.is_address(parsed.hostname): + bootstrap_address = parsed.hostname + extensions["sni_hostname"] = parsed.hostname + if parsed.port is not None: + port = parsed.port + + if http_version == HTTPVersion.H3 or ( + http_version == HTTPVersion.DEFAULT and not have_doh + ): + if bootstrap_address is None: + resolver = _maybe_get_resolver(resolver) + assert parsed.hostname is not None # for mypy + answers = await resolver.resolve_name(parsed.hostname, family) + bootstrap_address = random.choice(list(answers.addresses())) + return await _http3( + q, + bootstrap_address, + url, + timeout, + port, + source, + source_port, + one_rr_per_rrset, + ignore_trailing, + verify=verify, + post=post, + ) + + if not have_doh: + raise NoDOH # pragma: no cover + # pylint: disable=possibly-used-before-assignment + if client and not isinstance(client, httpx.AsyncClient): + raise ValueError("session parameter must be an httpx.AsyncClient") + # pylint: enable=possibly-used-before-assignment + + wire = q.to_wire() + headers = {"accept": "application/dns-message"} + + h1 = http_version in (HTTPVersion.H1, HTTPVersion.DEFAULT) + h2 = http_version in (HTTPVersion.H2, HTTPVersion.DEFAULT) + backend = dns.asyncbackend.get_default_backend() if source is None: @@ -557,24 +615,23 @@ async def https( else: local_address = source local_port = source_port - transport = backend.get_transport_class()( - local_address=local_address, - http1=True, - http2=True, - verify=verify, - local_port=local_port, - bootstrap_address=bootstrap_address, - resolver=resolver, - family=family, - ) if client: cm: contextlib.AbstractAsyncContextManager = NullContext(client) else: - cm = httpx.AsyncClient( - http1=True, http2=True, verify=verify, transport=transport + transport = backend.get_transport_class()( + local_address=local_address, + http1=h1, + http2=h2, + verify=verify, + local_port=local_port, + bootstrap_address=bootstrap_address, + resolver=resolver, + family=family, ) + cm = httpx.AsyncClient(http1=h1, http2=h2, verify=verify, transport=transport) + async with cm as the_client: # see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH # GET and POST examples @@ -586,23 +643,33 @@ async def https( } ) response = await backend.wait_for( - the_client.post(url, headers=headers, content=wire), timeout + the_client.post( + url, + headers=headers, + content=wire, + extensions=extensions, + ), + timeout, ) else: wire = base64.urlsafe_b64encode(wire).rstrip(b"=") twire = wire.decode() # httpx does a repr() if we give it bytes response = await backend.wait_for( - the_client.get(url, headers=headers, params={"dns": twire}), timeout + the_client.get( + url, + headers=headers, + params={"dns": twire}, + extensions=extensions, + ), + timeout, ) # see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH # status codes if response.status_code < 200 or response.status_code > 299: raise ValueError( - "{} responded with status code {}" - "\nResponse body: {!r}".format( - where, response.status_code, response.content - ) + f"{where} responded with status code {response.status_code}" + f"\nResponse body: {response.content!r}" ) r = dns.message.from_wire( response.content, @@ -617,6 +684,181 @@ async def https( return r +async def _http3( + q: dns.message.Message, + where: str, + url: str, + timeout: Optional[float] = None, + port: int = 853, + source: Optional[str] = None, + source_port: int = 0, + one_rr_per_rrset: bool = False, + ignore_trailing: bool = False, + verify: Union[bool, str] = True, + backend: Optional[dns.asyncbackend.Backend] = None, + hostname: Optional[str] = None, + post: bool = True, +) -> dns.message.Message: + if not dns.quic.have_quic: + raise NoDOH("DNS-over-HTTP3 is not available.") # pragma: no cover + + url_parts = urllib.parse.urlparse(url) + hostname = url_parts.hostname + if url_parts.port is not None: + port = url_parts.port + + q.id = 0 + wire = q.to_wire() + (cfactory, mfactory) = dns.quic.factories_for_backend(backend) + + async with cfactory() as context: + async with mfactory( + context, verify_mode=verify, server_name=hostname, h3=True + ) as the_manager: + the_connection = the_manager.connect(where, port, source, source_port) + (start, expiration) = _compute_times(timeout) + stream = await the_connection.make_stream(timeout) + async with stream: + # note that send_h3() does not need await + stream.send_h3(url, wire, post) + wire = await stream.receive(_remaining(expiration)) + _check_status(stream.headers(), where, wire) + finish = time.time() + r = dns.message.from_wire( + wire, + keyring=q.keyring, + request_mac=q.request_mac, + one_rr_per_rrset=one_rr_per_rrset, + ignore_trailing=ignore_trailing, + ) + r.time = max(finish - start, 0.0) + if not q.is_response(r): + raise BadResponse + return r + + +async def quic( + q: dns.message.Message, + where: str, + timeout: Optional[float] = None, + port: int = 853, + source: Optional[str] = None, + source_port: int = 0, + one_rr_per_rrset: bool = False, + ignore_trailing: bool = False, + connection: Optional[dns.quic.AsyncQuicConnection] = None, + verify: Union[bool, str] = True, + backend: Optional[dns.asyncbackend.Backend] = None, + hostname: Optional[str] = None, + server_hostname: Optional[str] = None, +) -> dns.message.Message: + """Return the response obtained after sending an asynchronous query via + DNS-over-QUIC. + + *backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``, + the default, then dnspython will use the default backend. + + See :py:func:`dns.query.quic()` for the documentation of the other + parameters, exceptions, and return type of this method. + """ + + if not dns.quic.have_quic: + raise NoDOQ("DNS-over-QUIC is not available.") # pragma: no cover + + if server_hostname is not None and hostname is None: + hostname = server_hostname + + q.id = 0 + wire = q.to_wire() + the_connection: dns.quic.AsyncQuicConnection + if connection: + cfactory = dns.quic.null_factory + mfactory = dns.quic.null_factory + the_connection = connection + else: + (cfactory, mfactory) = dns.quic.factories_for_backend(backend) + + async with cfactory() as context: + async with mfactory( + context, + verify_mode=verify, + server_name=server_hostname, + ) as the_manager: + if not connection: + the_connection = the_manager.connect(where, port, source, source_port) + (start, expiration) = _compute_times(timeout) + stream = await the_connection.make_stream(timeout) + async with stream: + await stream.send(wire, True) + wire = await stream.receive(_remaining(expiration)) + finish = time.time() + r = dns.message.from_wire( + wire, + keyring=q.keyring, + request_mac=q.request_mac, + one_rr_per_rrset=one_rr_per_rrset, + ignore_trailing=ignore_trailing, + ) + r.time = max(finish - start, 0.0) + if not q.is_response(r): + raise BadResponse + return r + + +async def _inbound_xfr( + txn_manager: dns.transaction.TransactionManager, + s: dns.asyncbackend.Socket, + query: dns.message.Message, + serial: Optional[int], + timeout: Optional[float], + expiration: float, +) -> Any: + """Given a socket, does the zone transfer.""" + rdtype = query.question[0].rdtype + is_ixfr = rdtype == dns.rdatatype.IXFR + origin = txn_manager.from_wire_origin() + wire = query.to_wire() + is_udp = s.type == socket.SOCK_DGRAM + if is_udp: + udp_sock = cast(dns.asyncbackend.DatagramSocket, s) + await udp_sock.sendto(wire, None, _timeout(expiration)) + else: + tcp_sock = cast(dns.asyncbackend.StreamSocket, s) + tcpmsg = struct.pack("!H", len(wire)) + wire + await tcp_sock.sendall(tcpmsg, expiration) + with dns.xfr.Inbound(txn_manager, rdtype, serial, is_udp) as inbound: + done = False + tsig_ctx = None + while not done: + (_, mexpiration) = _compute_times(timeout) + if mexpiration is None or ( + expiration is not None and mexpiration > expiration + ): + mexpiration = expiration + if is_udp: + timeout = _timeout(mexpiration) + (rwire, _) = await udp_sock.recvfrom(65535, timeout) + else: + ldata = await _read_exactly(tcp_sock, 2, mexpiration) + (l,) = struct.unpack("!H", ldata) + rwire = await _read_exactly(tcp_sock, l, mexpiration) + r = dns.message.from_wire( + rwire, + keyring=query.keyring, + request_mac=query.mac, + xfr=True, + origin=origin, + tsig_ctx=tsig_ctx, + multi=(not is_udp), + one_rr_per_rrset=is_ixfr, + ) + done = inbound.process_message(r) + yield r + tsig_ctx = r.tsig_ctx + if query.keyring and not r.had_tsig: + raise dns.exception.FormError("missing TSIG") + + async def inbound_xfr( where: str, txn_manager: dns.transaction.TransactionManager, @@ -642,139 +884,30 @@ async def inbound_xfr( (query, serial) = dns.xfr.make_query(txn_manager) else: serial = dns.xfr.extract_serial_from_query(query) - rdtype = query.question[0].rdtype - is_ixfr = rdtype == dns.rdatatype.IXFR - origin = txn_manager.from_wire_origin() - wire = query.to_wire() af = dns.inet.af_for_address(where) stuple = _source_tuple(af, source, source_port) dtuple = (where, port) + if not backend: + backend = dns.asyncbackend.get_default_backend() (_, expiration) = _compute_times(lifetime) - retry = True - while retry: - retry = False - if is_ixfr and udp_mode != UDPMode.NEVER: - sock_type = socket.SOCK_DGRAM - is_udp = True - else: - sock_type = socket.SOCK_STREAM - is_udp = False - if not backend: - backend = dns.asyncbackend.get_default_backend() + if query.question[0].rdtype == dns.rdatatype.IXFR and udp_mode != UDPMode.NEVER: s = await backend.make_socket( - af, sock_type, 0, stuple, dtuple, _timeout(expiration) + af, socket.SOCK_DGRAM, 0, stuple, dtuple, _timeout(expiration) ) async with s: - if is_udp: - await s.sendto(wire, dtuple, _timeout(expiration)) - else: - tcpmsg = struct.pack("!H", len(wire)) + wire - await s.sendall(tcpmsg, expiration) - with dns.xfr.Inbound(txn_manager, rdtype, serial, is_udp) as inbound: - done = False - tsig_ctx = None - while not done: - (_, mexpiration) = _compute_times(timeout) - if mexpiration is None or ( - expiration is not None and mexpiration > expiration - ): - mexpiration = expiration - if is_udp: - destination = _lltuple((where, port), af) - while True: - timeout = _timeout(mexpiration) - (rwire, from_address) = await s.recvfrom(65535, timeout) - if _matches_destination( - af, from_address, destination, True - ): - break - else: - ldata = await _read_exactly(s, 2, mexpiration) - (l,) = struct.unpack("!H", ldata) - rwire = await _read_exactly(s, l, mexpiration) - is_ixfr = rdtype == dns.rdatatype.IXFR - r = dns.message.from_wire( - rwire, - keyring=query.keyring, - request_mac=query.mac, - xfr=True, - origin=origin, - tsig_ctx=tsig_ctx, - multi=(not is_udp), - one_rr_per_rrset=is_ixfr, - ) - try: - done = inbound.process_message(r) - except dns.xfr.UseTCP: - assert is_udp # should not happen if we used TCP! - if udp_mode == UDPMode.ONLY: - raise - done = True - retry = True - udp_mode = UDPMode.NEVER - continue - tsig_ctx = r.tsig_ctx - if not retry and query.keyring and not r.had_tsig: - raise dns.exception.FormError("missing TSIG") + try: + async for _ in _inbound_xfr( + txn_manager, s, query, serial, timeout, expiration + ): + pass + return + except dns.xfr.UseTCP: + if udp_mode == UDPMode.ONLY: + raise - -async def quic( - q: dns.message.Message, - where: str, - timeout: Optional[float] = None, - port: int = 853, - source: Optional[str] = None, - source_port: int = 0, - one_rr_per_rrset: bool = False, - ignore_trailing: bool = False, - connection: Optional[dns.quic.AsyncQuicConnection] = None, - verify: Union[bool, str] = True, - backend: Optional[dns.asyncbackend.Backend] = None, - server_hostname: Optional[str] = None, -) -> dns.message.Message: - """Return the response obtained after sending an asynchronous query via - DNS-over-QUIC. - - *backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``, - the default, then dnspython will use the default backend. - - See :py:func:`dns.query.quic()` for the documentation of the other - parameters, exceptions, and return type of this method. - """ - - if not dns.quic.have_quic: - raise NoDOQ("DNS-over-QUIC is not available.") # pragma: no cover - - q.id = 0 - wire = q.to_wire() - the_connection: dns.quic.AsyncQuicConnection - if connection: - cfactory = dns.quic.null_factory - mfactory = dns.quic.null_factory - the_connection = connection - else: - (cfactory, mfactory) = dns.quic.factories_for_backend(backend) - - async with cfactory() as context: - async with mfactory( - context, verify_mode=verify, server_name=server_hostname - ) as the_manager: - if not connection: - the_connection = the_manager.connect(where, port, source, source_port) - (start, expiration) = _compute_times(timeout) - stream = await the_connection.make_stream(timeout) - async with stream: - await stream.send(wire, True) - wire = await stream.receive(_remaining(expiration)) - finish = time.time() - r = dns.message.from_wire( - wire, - keyring=q.keyring, - request_mac=q.request_mac, - one_rr_per_rrset=one_rr_per_rrset, - ignore_trailing=ignore_trailing, - ) - r.time = max(finish - start, 0.0) - if not q.is_response(r): - raise BadResponse - return r + s = await backend.make_socket( + af, socket.SOCK_STREAM, 0, stuple, dtuple, _timeout(expiration) + ) + async with s: + async for _ in _inbound_xfr(txn_manager, s, query, serial, timeout, expiration): + pass diff --git a/lib/dns/dnssec.py b/lib/dns/dnssec.py index e49c3b79..b69d0a12 100644 --- a/lib/dns/dnssec.py +++ b/lib/dns/dnssec.py @@ -118,6 +118,7 @@ def key_id(key: Union[DNSKEY, CDNSKEY]) -> int: """ rdata = key.to_wire() + assert rdata is not None # for mypy if key.algorithm == Algorithm.RSAMD5: return (rdata[-3] << 8) + rdata[-2] else: @@ -224,7 +225,7 @@ def make_ds( if isinstance(algorithm, str): algorithm = DSDigest[algorithm.upper()] except Exception: - raise UnsupportedAlgorithm('unsupported algorithm "%s"' % algorithm) + raise UnsupportedAlgorithm(f'unsupported algorithm "{algorithm}"') if validating: check = policy.ok_to_validate_ds else: @@ -240,14 +241,15 @@ def make_ds( elif algorithm == DSDigest.SHA384: dshash = hashlib.sha384() else: - raise UnsupportedAlgorithm('unsupported algorithm "%s"' % algorithm) + raise UnsupportedAlgorithm(f'unsupported algorithm "{algorithm}"') if isinstance(name, str): name = dns.name.from_text(name, origin) wire = name.canonicalize().to_wire() - assert wire is not None + kwire = key.to_wire(origin=origin) + assert wire is not None and kwire is not None # for mypy dshash.update(wire) - dshash.update(key.to_wire(origin=origin)) + dshash.update(kwire) digest = dshash.digest() dsrdata = struct.pack("!HBB", key_id(key), key.algorithm, algorithm) + digest @@ -323,6 +325,7 @@ def _get_rrname_rdataset( def _validate_signature(sig: bytes, data: bytes, key: DNSKEY) -> None: + # pylint: disable=possibly-used-before-assignment public_cls = get_algorithm_cls_from_dnskey(key).public_cls try: public_key = public_cls.from_dnskey(key) @@ -387,6 +390,7 @@ def _validate_rrsig( data = _make_rrsig_signature_data(rrset, rrsig, origin) + # pylint: disable=possibly-used-before-assignment for candidate_key in candidate_keys: if not policy.ok_to_validate(candidate_key): continue @@ -484,6 +488,7 @@ def _sign( verify: bool = False, policy: Optional[Policy] = None, origin: Optional[dns.name.Name] = None, + deterministic: bool = True, ) -> RRSIG: """Sign RRset using private key. @@ -523,6 +528,10 @@ def _sign( names in the rrset (including its owner name) must be absolute; otherwise the specified origin will be used to make names absolute when signing. + *deterministic*, a ``bool``. If ``True``, the default, use deterministic + (reproducible) signatures when supported by the algorithm used for signing. + Currently, this only affects ECDSA. + Raises ``DeniedByPolicy`` if the signature is denied by policy. """ @@ -580,6 +589,7 @@ def _sign( data = dns.dnssec._make_rrsig_signature_data(rrset, rrsig_template, origin) + # pylint: disable=possibly-used-before-assignment if isinstance(private_key, GenericPrivateKey): signing_key = private_key else: @@ -589,7 +599,7 @@ def _sign( except UnsupportedAlgorithm: raise TypeError("Unsupported key algorithm") - signature = signing_key.sign(data, verify) + signature = signing_key.sign(data, verify, deterministic) return cast(RRSIG, rrsig_template.replace(signature=signature)) @@ -629,7 +639,9 @@ def _make_rrsig_signature_data( rrname, rdataset = _get_rrname_rdataset(rrset) data = b"" - data += rrsig.to_wire(origin=signer)[:18] + wire = rrsig.to_wire(origin=signer) + assert wire is not None # for mypy + data += wire[:18] data += rrsig.signer.to_digestable(signer) # Derelativize the name before considering labels. @@ -686,6 +698,7 @@ def _make_dnskey( algorithm = Algorithm.make(algorithm) + # pylint: disable=possibly-used-before-assignment if isinstance(public_key, GenericPublicKey): return public_key.to_dnskey(flags=flags, protocol=protocol) else: @@ -832,7 +845,7 @@ def make_ds_rdataset( if isinstance(algorithm, str): algorithm = DSDigest[algorithm.upper()] except Exception: - raise UnsupportedAlgorithm('unsupported algorithm "%s"' % algorithm) + raise UnsupportedAlgorithm(f'unsupported algorithm "{algorithm}"') _algorithms.add(algorithm) if rdataset.rdtype == dns.rdatatype.CDS: @@ -950,6 +963,7 @@ def default_rrset_signer( lifetime: Optional[int] = None, policy: Optional[Policy] = None, origin: Optional[dns.name.Name] = None, + deterministic: bool = True, ) -> None: """Default RRset signer""" @@ -975,6 +989,7 @@ def default_rrset_signer( signer=signer, policy=policy, origin=origin, + deterministic=deterministic, ) txn.add(rrset.name, rrset.ttl, rrsig) @@ -991,6 +1006,7 @@ def sign_zone( nsec3: Optional[NSEC3PARAM] = None, rrset_signer: Optional[RRsetSigner] = None, policy: Optional[Policy] = None, + deterministic: bool = True, ) -> None: """Sign zone. @@ -1030,6 +1046,10 @@ def sign_zone( function requires two arguments: transaction and RRset. If the not specified, ``dns.dnssec.default_rrset_signer`` will be used. + *deterministic*, a ``bool``. If ``True``, the default, use deterministic + (reproducible) signatures when supported by the algorithm used for signing. + Currently, this only affects ECDSA. + Returns ``None``. """ @@ -1056,6 +1076,9 @@ def sign_zone( else: cm = zone.writer() + if zone.origin is None: + raise ValueError("no zone origin") + with cm as _txn: if add_dnskey: if dnskey_ttl is None: @@ -1081,6 +1104,7 @@ def sign_zone( lifetime=lifetime, policy=policy, origin=zone.origin, + deterministic=deterministic, ) return _sign_zone_nsec(zone, _txn, _rrset_signer) diff --git a/lib/dns/dnssecalgs/__init__.py b/lib/dns/dnssecalgs/__init__.py index 3d9181a7..602367e3 100644 --- a/lib/dns/dnssecalgs/__init__.py +++ b/lib/dns/dnssecalgs/__init__.py @@ -26,6 +26,7 @@ AlgorithmPrefix = Optional[Union[bytes, dns.name.Name]] algorithms: Dict[Tuple[Algorithm, AlgorithmPrefix], Type[GenericPrivateKey]] = {} if _have_cryptography: + # pylint: disable=possibly-used-before-assignment algorithms.update( { (Algorithm.RSAMD5, None): PrivateRSAMD5, @@ -59,7 +60,7 @@ def get_algorithm_cls( if cls: return cls raise UnsupportedAlgorithm( - 'algorithm "%s" not supported by dnspython' % Algorithm.to_text(algorithm) + f'algorithm "{Algorithm.to_text(algorithm)}" not supported by dnspython' ) diff --git a/lib/dns/dnssecalgs/base.py b/lib/dns/dnssecalgs/base.py index e990575a..752ee480 100644 --- a/lib/dns/dnssecalgs/base.py +++ b/lib/dns/dnssecalgs/base.py @@ -65,7 +65,12 @@ class GenericPrivateKey(ABC): pass @abstractmethod - def sign(self, data: bytes, verify: bool = False) -> bytes: + def sign( + self, + data: bytes, + verify: bool = False, + deterministic: bool = True, + ) -> bytes: """Sign DNSSEC data""" @abstractmethod diff --git a/lib/dns/dnssecalgs/dsa.py b/lib/dns/dnssecalgs/dsa.py index 0fe4690d..adca3def 100644 --- a/lib/dns/dnssecalgs/dsa.py +++ b/lib/dns/dnssecalgs/dsa.py @@ -68,7 +68,12 @@ class PrivateDSA(CryptographyPrivateKey): key_cls = dsa.DSAPrivateKey public_cls = PublicDSA - def sign(self, data: bytes, verify: bool = False) -> bytes: + def sign( + self, + data: bytes, + verify: bool = False, + deterministic: bool = True, + ) -> bytes: """Sign using a private key per RFC 2536, section 3.""" public_dsa_key = self.key.public_key() if public_dsa_key.key_size > 1024: diff --git a/lib/dns/dnssecalgs/ecdsa.py b/lib/dns/dnssecalgs/ecdsa.py index a31d79f2..86d5764c 100644 --- a/lib/dns/dnssecalgs/ecdsa.py +++ b/lib/dns/dnssecalgs/ecdsa.py @@ -47,9 +47,17 @@ class PrivateECDSA(CryptographyPrivateKey): key_cls = ec.EllipticCurvePrivateKey public_cls = PublicECDSA - def sign(self, data: bytes, verify: bool = False) -> bytes: + def sign( + self, + data: bytes, + verify: bool = False, + deterministic: bool = True, + ) -> bytes: """Sign using a private key per RFC 6605, section 4.""" - der_signature = self.key.sign(data, ec.ECDSA(self.public_cls.chosen_hash)) + algorithm = ec.ECDSA( + self.public_cls.chosen_hash, deterministic_signing=deterministic + ) + der_signature = self.key.sign(data, algorithm) dsa_r, dsa_s = utils.decode_dss_signature(der_signature) signature = int.to_bytes( dsa_r, length=self.public_cls.octets, byteorder="big" diff --git a/lib/dns/dnssecalgs/eddsa.py b/lib/dns/dnssecalgs/eddsa.py index 70505342..604bcbfe 100644 --- a/lib/dns/dnssecalgs/eddsa.py +++ b/lib/dns/dnssecalgs/eddsa.py @@ -29,7 +29,12 @@ class PublicEDDSA(CryptographyPublicKey): class PrivateEDDSA(CryptographyPrivateKey): public_cls: Type[PublicEDDSA] - def sign(self, data: bytes, verify: bool = False) -> bytes: + def sign( + self, + data: bytes, + verify: bool = False, + deterministic: bool = True, + ) -> bytes: """Sign using a private key per RFC 8080, section 4.""" signature = self.key.sign(data) if verify: diff --git a/lib/dns/dnssecalgs/rsa.py b/lib/dns/dnssecalgs/rsa.py index e95dcf1d..27537aad 100644 --- a/lib/dns/dnssecalgs/rsa.py +++ b/lib/dns/dnssecalgs/rsa.py @@ -56,7 +56,12 @@ class PrivateRSA(CryptographyPrivateKey): public_cls = PublicRSA default_public_exponent = 65537 - def sign(self, data: bytes, verify: bool = False) -> bytes: + def sign( + self, + data: bytes, + verify: bool = False, + deterministic: bool = True, + ) -> bytes: """Sign using a private key per RFC 3110, section 3.""" signature = self.key.sign(data, padding.PKCS1v15(), self.public_cls.chosen_hash) if verify: diff --git a/lib/dns/edns.py b/lib/dns/edns.py index 776e5eeb..f7d9ff99 100644 --- a/lib/dns/edns.py +++ b/lib/dns/edns.py @@ -52,6 +52,8 @@ class OptionType(dns.enum.IntEnum): CHAIN = 13 #: EDE (extended-dns-error) EDE = 15 + #: REPORTCHANNEL + REPORTCHANNEL = 18 @classmethod def _maximum(cls): @@ -222,7 +224,7 @@ class ECSOption(Option): # lgtm[py/missing-equals] self.addrdata = self.addrdata[:-1] + last def to_text(self) -> str: - return "ECS {}/{} scope/{}".format(self.address, self.srclen, self.scopelen) + return f"ECS {self.address}/{self.srclen} scope/{self.scopelen}" @staticmethod def from_text(text: str) -> Option: @@ -255,10 +257,10 @@ class ECSOption(Option): # lgtm[py/missing-equals] ecs_text = tokens[0] elif len(tokens) == 2: if tokens[0] != optional_prefix: - raise ValueError('could not parse ECS from "{}"'.format(text)) + raise ValueError(f'could not parse ECS from "{text}"') ecs_text = tokens[1] else: - raise ValueError('could not parse ECS from "{}"'.format(text)) + raise ValueError(f'could not parse ECS from "{text}"') n_slashes = ecs_text.count("/") if n_slashes == 1: address, tsrclen = ecs_text.split("/") @@ -266,18 +268,16 @@ class ECSOption(Option): # lgtm[py/missing-equals] elif n_slashes == 2: address, tsrclen, tscope = ecs_text.split("/") else: - raise ValueError('could not parse ECS from "{}"'.format(text)) + raise ValueError(f'could not parse ECS from "{text}"') try: scope = int(tscope) except ValueError: - raise ValueError( - "invalid scope " + '"{}": scope must be an integer'.format(tscope) - ) + raise ValueError("invalid scope " + f'"{tscope}": scope must be an integer') try: srclen = int(tsrclen) except ValueError: raise ValueError( - "invalid srclen " + '"{}": srclen must be an integer'.format(tsrclen) + "invalid srclen " + f'"{tsrclen}": srclen must be an integer' ) return ECSOption(address, srclen, scope) @@ -430,10 +430,65 @@ class NSIDOption(Option): return cls(parser.get_remaining()) +class CookieOption(Option): + def __init__(self, client: bytes, server: bytes): + super().__init__(dns.edns.OptionType.COOKIE) + self.client = client + self.server = server + if len(client) != 8: + raise ValueError("client cookie must be 8 bytes") + if len(server) != 0 and (len(server) < 8 or len(server) > 32): + raise ValueError("server cookie must be empty or between 8 and 32 bytes") + + def to_wire(self, file: Any = None) -> Optional[bytes]: + if file: + file.write(self.client) + if len(self.server) > 0: + file.write(self.server) + return None + else: + return self.client + self.server + + def to_text(self) -> str: + client = binascii.hexlify(self.client).decode() + if len(self.server) > 0: + server = binascii.hexlify(self.server).decode() + else: + server = "" + return f"COOKIE {client}{server}" + + @classmethod + def from_wire_parser( + cls, otype: Union[OptionType, str], parser: dns.wire.Parser + ) -> Option: + return cls(parser.get_bytes(8), parser.get_remaining()) + + +class ReportChannelOption(Option): + # RFC 9567 + def __init__(self, agent_domain: dns.name.Name): + super().__init__(OptionType.REPORTCHANNEL) + self.agent_domain = agent_domain + + def to_wire(self, file: Any = None) -> Optional[bytes]: + return self.agent_domain.to_wire(file) + + def to_text(self) -> str: + return "REPORTCHANNEL " + self.agent_domain.to_text() + + @classmethod + def from_wire_parser( + cls, otype: Union[OptionType, str], parser: dns.wire.Parser + ) -> Option: + return cls(parser.get_name()) + + _type_to_class: Dict[OptionType, Any] = { OptionType.ECS: ECSOption, OptionType.EDE: EDEOption, OptionType.NSID: NSIDOption, + OptionType.COOKIE: CookieOption, + OptionType.REPORTCHANNEL: ReportChannelOption, } @@ -512,5 +567,6 @@ KEEPALIVE = OptionType.KEEPALIVE PADDING = OptionType.PADDING CHAIN = OptionType.CHAIN EDE = OptionType.EDE +REPORTCHANNEL = OptionType.REPORTCHANNEL ### END generated OptionType constants diff --git a/lib/dns/exception.py b/lib/dns/exception.py index 6982373d..223f2d68 100644 --- a/lib/dns/exception.py +++ b/lib/dns/exception.py @@ -81,7 +81,7 @@ class DNSException(Exception): if kwargs: assert ( set(kwargs.keys()) == self.supp_kwargs - ), "following set of keyword args is required: %s" % (self.supp_kwargs) + ), f"following set of keyword args is required: {self.supp_kwargs}" return kwargs def _fmt_kwargs(self, **kwargs): diff --git a/lib/dns/grange.py b/lib/dns/grange.py index 3a52278f..a967ca41 100644 --- a/lib/dns/grange.py +++ b/lib/dns/grange.py @@ -54,7 +54,7 @@ def from_text(text: str) -> Tuple[int, int, int]: elif c.isdigit(): cur += c else: - raise dns.exception.SyntaxError("Could not parse %s" % (c)) + raise dns.exception.SyntaxError(f"Could not parse {c}") if state == 0: raise dns.exception.SyntaxError("no stop value specified") diff --git a/lib/dns/ipv6.py b/lib/dns/ipv6.py index 44a10639..4dd1d1ca 100644 --- a/lib/dns/ipv6.py +++ b/lib/dns/ipv6.py @@ -143,9 +143,7 @@ def inet_aton(text: Union[str, bytes], ignore_scope: bool = False) -> bytes: if m is not None: b = dns.ipv4.inet_aton(m.group(2)) btext = ( - "{}:{:02x}{:02x}:{:02x}{:02x}".format( - m.group(1).decode(), b[0], b[1], b[2], b[3] - ) + f"{m.group(1).decode()}:{b[0]:02x}{b[1]:02x}:{b[2]:02x}{b[3]:02x}" ).encode() # # Try to turn '::' into ':'; if no match try to diff --git a/lib/dns/message.py b/lib/dns/message.py index 44cacbd9..e978a0a2 100644 --- a/lib/dns/message.py +++ b/lib/dns/message.py @@ -18,9 +18,10 @@ """DNS Messages""" import contextlib +import enum import io import time -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union, cast import dns.edns import dns.entropy @@ -161,6 +162,7 @@ class Message: self.index: IndexType = {} self.errors: List[MessageError] = [] self.time = 0.0 + self.wire: Optional[bytes] = None @property def question(self) -> List[dns.rrset.RRset]: @@ -220,16 +222,16 @@ class Message: s = io.StringIO() s.write("id %d\n" % self.id) - s.write("opcode %s\n" % dns.opcode.to_text(self.opcode())) - s.write("rcode %s\n" % dns.rcode.to_text(self.rcode())) - s.write("flags %s\n" % dns.flags.to_text(self.flags)) + s.write(f"opcode {dns.opcode.to_text(self.opcode())}\n") + s.write(f"rcode {dns.rcode.to_text(self.rcode())}\n") + s.write(f"flags {dns.flags.to_text(self.flags)}\n") if self.edns >= 0: - s.write("edns %s\n" % self.edns) + s.write(f"edns {self.edns}\n") if self.ednsflags != 0: - s.write("eflags %s\n" % dns.flags.edns_to_text(self.ednsflags)) + s.write(f"eflags {dns.flags.edns_to_text(self.ednsflags)}\n") s.write("payload %d\n" % self.payload) for opt in self.options: - s.write("option %s\n" % opt.to_text()) + s.write(f"option {opt.to_text()}\n") for name, which in self._section_enum.__members__.items(): s.write(f";{name}\n") for rrset in self.section_from_number(which): @@ -645,6 +647,7 @@ class Message: if multi: self.tsig_ctx = ctx wire = r.get_wire() + self.wire = wire if prepend_length: wire = len(wire).to_bytes(2, "big") + wire return wire @@ -912,6 +915,14 @@ class Message: self.flags &= 0x87FF self.flags |= dns.opcode.to_flags(opcode) + def get_options(self, otype: dns.edns.OptionType) -> List[dns.edns.Option]: + """Return the list of options of the specified type.""" + return [option for option in self.options if option.otype == otype] + + def extended_errors(self) -> List[dns.edns.EDEOption]: + """Return the list of Extended DNS Error (EDE) options in the message""" + return cast(List[dns.edns.EDEOption], self.get_options(dns.edns.OptionType.EDE)) + def _get_one_rr_per_rrset(self, value): # What the caller picked is fine. return value @@ -1192,9 +1203,9 @@ class _WireReader: if rdtype == dns.rdatatype.OPT: self.message.opt = dns.rrset.from_rdata(name, ttl, rd) elif rdtype == dns.rdatatype.TSIG: - if self.keyring is None: + if self.keyring is None or self.keyring is True: raise UnknownTSIGKey("got signed message without keyring") - if isinstance(self.keyring, dict): + elif isinstance(self.keyring, dict): key = self.keyring.get(absolute_name) if isinstance(key, bytes): key = dns.tsig.Key(absolute_name, key, rd.algorithm) @@ -1203,19 +1214,20 @@ class _WireReader: else: key = self.keyring if key is None: - raise UnknownTSIGKey("key '%s' unknown" % name) - self.message.keyring = key - self.message.tsig_ctx = dns.tsig.validate( - self.parser.wire, - key, - absolute_name, - rd, - int(time.time()), - self.message.request_mac, - rr_start, - self.message.tsig_ctx, - self.multi, - ) + raise UnknownTSIGKey(f"key '{name}' unknown") + if key: + self.message.keyring = key + self.message.tsig_ctx = dns.tsig.validate( + self.parser.wire, + key, + absolute_name, + rd, + int(time.time()), + self.message.request_mac, + rr_start, + self.message.tsig_ctx, + self.multi, + ) self.message.tsig = dns.rrset.from_rdata(absolute_name, 0, rd) else: rrset = self.message.find_rrset( @@ -1251,6 +1263,7 @@ class _WireReader: factory = _message_factory_from_opcode(dns.opcode.from_flags(flags)) self.message = factory(id=id) self.message.flags = dns.flags.Flag(flags) + self.message.wire = self.parser.wire self.initialize_message(self.message) self.one_rr_per_rrset = self.message._get_one_rr_per_rrset( self.one_rr_per_rrset @@ -1290,8 +1303,10 @@ def from_wire( ) -> Message: """Convert a DNS wire format message into a message object. - *keyring*, a ``dns.tsig.Key`` or ``dict``, the key or keyring to use if the message - is signed. + *keyring*, a ``dns.tsig.Key``, ``dict``, ``bool``, or ``None``, the key or keyring + to use if the message is signed. If ``None`` or ``True``, then trying to decode + a message with a TSIG will fail as it cannot be validated. If ``False``, then + TSIG validation is disabled. *request_mac*, a ``bytes`` or ``None``. If the message is a response to a TSIG-signed request, *request_mac* should be set to the MAC of that request. @@ -1811,6 +1826,16 @@ def make_query( return m +class CopyMode(enum.Enum): + """ + How should sections be copied when making an update response? + """ + + NOTHING = 0 + QUESTION = 1 + EVERYTHING = 2 + + def make_response( query: Message, recursion_available: bool = False, @@ -1818,13 +1843,14 @@ def make_response( fudge: int = 300, tsig_error: int = 0, pad: Optional[int] = None, + copy_mode: Optional[CopyMode] = None, ) -> Message: """Make a message which is a response for the specified query. The message returned is really a response skeleton; it has all of the infrastructure required of a response, but none of the content. - The response's question section is a shallow copy of the query's question section, - so the query's question RRsets should not be changed. + Response section(s) which are copied are shallow copies of the matching section(s) + in the query, so the query's RRsets should not be changed. *query*, a ``dns.message.Message``, the query to respond to. @@ -1837,25 +1863,44 @@ def make_response( *tsig_error*, an ``int``, the TSIG error. *pad*, a non-negative ``int`` or ``None``. If 0, the default, do not pad; otherwise - if not ``None`` add padding bytes to make the message size a multiple of *pad*. - Note that if padding is non-zero, an EDNS PADDING option will always be added to the + if not ``None`` add padding bytes to make the message size a multiple of *pad*. Note + that if padding is non-zero, an EDNS PADDING option will always be added to the message. If ``None``, add padding following RFC 8467, namely if the request is padded, pad the response to 468 otherwise do not pad. + *copy_mode*, a ``dns.message.CopyMode`` or ``None``, determines how sections are + copied. The default, ``None`` copies sections according to the default for the + message's opcode, which is currently ``dns.message.CopyMode.QUESTION`` for all + opcodes. ``dns.message.CopyMode.QUESTION`` copies only the question section. + ``dns.message.CopyMode.EVERYTHING`` copies all sections other than OPT or TSIG + records, which are created appropriately if needed. ``dns.message.CopyMode.NOTHING`` + copies no sections; note that this mode is for server testing purposes and is + otherwise not recommended for use. In particular, ``dns.message.is_response()`` + will be ``False`` if you create a response this way and the rcode is not + ``FORMERR``, ``SERVFAIL``, ``NOTIMP``, or ``REFUSED``. + Returns a ``dns.message.Message`` object whose specific class is appropriate for the - query. For example, if query is a ``dns.update.UpdateMessage``, response will be - too. + query. For example, if query is a ``dns.update.UpdateMessage``, the response will + be one too. """ if query.flags & dns.flags.QR: raise dns.exception.FormError("specified query message is not a query") - factory = _message_factory_from_opcode(query.opcode()) + opcode = query.opcode() + factory = _message_factory_from_opcode(opcode) response = factory(id=query.id) response.flags = dns.flags.QR | (query.flags & dns.flags.RD) if recursion_available: response.flags |= dns.flags.RA - response.set_opcode(query.opcode()) - response.question = list(query.question) + response.set_opcode(opcode) + if copy_mode is None: + copy_mode = CopyMode.QUESTION + if copy_mode != CopyMode.NOTHING: + response.question = list(query.question) + if copy_mode == CopyMode.EVERYTHING: + response.answer = list(query.answer) + response.authority = list(query.authority) + response.additional = list(query.additional) if query.edns >= 0: if pad is None: # Set response padding per RFC 8467 diff --git a/lib/dns/name.py b/lib/dns/name.py index 22ccb392..f79f0d0f 100644 --- a/lib/dns/name.py +++ b/lib/dns/name.py @@ -59,11 +59,11 @@ class NameRelation(dns.enum.IntEnum): @classmethod def _maximum(cls): - return cls.COMMONANCESTOR + return cls.COMMONANCESTOR # pragma: no cover @classmethod def _short_name(cls): - return cls.__name__ + return cls.__name__ # pragma: no cover # Backwards compatibility @@ -277,6 +277,7 @@ class IDNA2008Codec(IDNACodec): raise NoIDNA2008 try: if self.uts_46: + # pylint: disable=possibly-used-before-assignment label = idna.uts46_remap(label, False, self.transitional) return idna.alabel(label) except idna.IDNAError as e: diff --git a/lib/dns/nameserver.py b/lib/dns/nameserver.py index 5dbb4e8b..b02a239b 100644 --- a/lib/dns/nameserver.py +++ b/lib/dns/nameserver.py @@ -168,12 +168,14 @@ class DoHNameserver(Nameserver): bootstrap_address: Optional[str] = None, verify: Union[bool, str] = True, want_get: bool = False, + http_version: dns.query.HTTPVersion = dns.query.HTTPVersion.DEFAULT, ): super().__init__() self.url = url self.bootstrap_address = bootstrap_address self.verify = verify self.want_get = want_get + self.http_version = http_version def kind(self): return "DoH" @@ -214,6 +216,7 @@ class DoHNameserver(Nameserver): ignore_trailing=ignore_trailing, verify=self.verify, post=(not self.want_get), + http_version=self.http_version, ) async def async_query( @@ -238,6 +241,7 @@ class DoHNameserver(Nameserver): ignore_trailing=ignore_trailing, verify=self.verify, post=(not self.want_get), + http_version=self.http_version, ) diff --git a/lib/dns/query.py b/lib/dns/query.py index f0ee9161..0d8a977a 100644 --- a/lib/dns/query.py +++ b/lib/dns/query.py @@ -23,11 +23,13 @@ import enum import errno import os import os.path +import random import selectors import socket import struct import time -from typing import Any, Dict, Optional, Tuple, Union +import urllib.parse +from typing import Any, Dict, Optional, Tuple, Union, cast import dns._features import dns.exception @@ -129,7 +131,7 @@ if _have_httpx: family=socket.AF_UNSPEC, **kwargs, ): - if resolver is None: + if resolver is None and bootstrap_address is None: # pylint: disable=import-outside-toplevel,redefined-outer-name import dns.resolver @@ -217,7 +219,7 @@ def _wait_for(fd, readable, writable, _, expiration): if readable and isinstance(fd, ssl.SSLSocket) and fd.pending() > 0: return True - sel = _selector_class() + sel = selectors.DefaultSelector() events = 0 if readable: events |= selectors.EVENT_READ @@ -235,26 +237,6 @@ def _wait_for(fd, readable, writable, _, expiration): raise dns.exception.Timeout -def _set_selector_class(selector_class): - # Internal API. Do not use. - - global _selector_class - - _selector_class = selector_class - - -if hasattr(selectors, "PollSelector"): - # Prefer poll() on platforms that support it because it has no - # limits on the maximum value of a file descriptor (plus it will - # be more efficient for high values). - # - # We ignore typing here as we can't say _selector_class is Any - # on python < 3.8 due to a bug. - _selector_class = selectors.PollSelector # type: ignore -else: - _selector_class = selectors.SelectSelector # type: ignore - - def _wait_for_readable(s, expiration): _wait_for(s, True, False, True, expiration) @@ -355,6 +337,36 @@ def _make_socket(af, type, source, ssl_context=None, server_hostname=None): raise +def _maybe_get_resolver( + resolver: Optional["dns.resolver.Resolver"], +) -> "dns.resolver.Resolver": + # We need a separate method for this to avoid overriding the global + # variable "dns" with the as-yet undefined local variable "dns" + # in https(). + if resolver is None: + # pylint: disable=import-outside-toplevel,redefined-outer-name + import dns.resolver + + resolver = dns.resolver.Resolver() + return resolver + + +class HTTPVersion(enum.IntEnum): + """Which version of HTTP should be used? + + DEFAULT will select the first version from the list [2, 1.1, 3] that + is available. + """ + + DEFAULT = 0 + HTTP_1 = 1 + H1 = 1 + HTTP_2 = 2 + H2 = 2 + HTTP_3 = 3 + H3 = 3 + + def https( q: dns.message.Message, where: str, @@ -370,7 +382,8 @@ def https( bootstrap_address: Optional[str] = None, verify: Union[bool, str] = True, resolver: Optional["dns.resolver.Resolver"] = None, - family: Optional[int] = socket.AF_UNSPEC, + family: int = socket.AF_UNSPEC, + http_version: HTTPVersion = HTTPVersion.DEFAULT, ) -> dns.message.Message: """Return the response obtained after sending a query via DNS-over-HTTPS. @@ -420,27 +433,66 @@ def https( *family*, an ``int``, the address family. If socket.AF_UNSPEC (the default), both A and AAAA records will be retrieved. + *http_version*, a ``dns.query.HTTPVersion``, indicating which HTTP version to use. + Returns a ``dns.message.Message``. """ + (af, _, the_source) = _destination_and_source( + where, port, source, source_port, False + ) + if af is not None and dns.inet.is_address(where): + if af == socket.AF_INET: + url = f"https://{where}:{port}{path}" + elif af == socket.AF_INET6: + url = f"https://[{where}]:{port}{path}" + else: + url = where + + extensions = {} + if bootstrap_address is None: + # pylint: disable=possibly-used-before-assignment + parsed = urllib.parse.urlparse(url) + if parsed.hostname is None: + raise ValueError("no hostname in URL") + if dns.inet.is_address(parsed.hostname): + bootstrap_address = parsed.hostname + extensions["sni_hostname"] = parsed.hostname + if parsed.port is not None: + port = parsed.port + + if http_version == HTTPVersion.H3 or ( + http_version == HTTPVersion.DEFAULT and not have_doh + ): + if bootstrap_address is None: + resolver = _maybe_get_resolver(resolver) + assert parsed.hostname is not None # for mypy + answers = resolver.resolve_name(parsed.hostname, family) + bootstrap_address = random.choice(list(answers.addresses())) + return _http3( + q, + bootstrap_address, + url, + timeout, + port, + source, + source_port, + one_rr_per_rrset, + ignore_trailing, + verify=verify, + post=post, + ) + if not have_doh: raise NoDOH # pragma: no cover if session and not isinstance(session, httpx.Client): raise ValueError("session parameter must be an httpx.Client") wire = q.to_wire() - (af, _, the_source) = _destination_and_source( - where, port, source, source_port, False - ) - transport = None headers = {"accept": "application/dns-message"} - if af is not None and dns.inet.is_address(where): - if af == socket.AF_INET: - url = "https://{}:{}{}".format(where, port, path) - elif af == socket.AF_INET6: - url = "https://[{}]:{}{}".format(where, port, path) - else: - url = where + + h1 = http_version in (HTTPVersion.H1, HTTPVersion.DEFAULT) + h2 = http_version in (HTTPVersion.H2, HTTPVersion.DEFAULT) # set source port and source address @@ -450,21 +502,22 @@ def https( else: local_address = the_source[0] local_port = the_source[1] - transport = _HTTPTransport( - local_address=local_address, - http1=True, - http2=True, - verify=verify, - local_port=local_port, - bootstrap_address=bootstrap_address, - resolver=resolver, - family=family, - ) if session: cm: contextlib.AbstractContextManager = contextlib.nullcontext(session) else: - cm = httpx.Client(http1=True, http2=True, verify=verify, transport=transport) + transport = _HTTPTransport( + local_address=local_address, + http1=h1, + http2=h2, + verify=verify, + local_port=local_port, + bootstrap_address=bootstrap_address, + resolver=resolver, + family=family, + ) + + cm = httpx.Client(http1=h1, http2=h2, verify=verify, transport=transport) with cm as session: # see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH # GET and POST examples @@ -475,20 +528,30 @@ def https( "content-length": str(len(wire)), } ) - response = session.post(url, headers=headers, content=wire, timeout=timeout) + response = session.post( + url, + headers=headers, + content=wire, + timeout=timeout, + extensions=extensions, + ) else: wire = base64.urlsafe_b64encode(wire).rstrip(b"=") twire = wire.decode() # httpx does a repr() if we give it bytes response = session.get( - url, headers=headers, timeout=timeout, params={"dns": twire} + url, + headers=headers, + timeout=timeout, + params={"dns": twire}, + extensions=extensions, ) # see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH # status codes if response.status_code < 200 or response.status_code > 299: raise ValueError( - "{} responded with status code {}" - "\nResponse body: {}".format(where, response.status_code, response.content) + f"{where} responded with status code {response.status_code}" + f"\nResponse body: {response.content}" ) r = dns.message.from_wire( response.content, @@ -503,6 +566,81 @@ def https( return r +def _find_header(headers: dns.quic.Headers, name: bytes) -> bytes: + if headers is None: + raise KeyError + for header, value in headers: + if header == name: + return value + raise KeyError + + +def _check_status(headers: dns.quic.Headers, peer: str, wire: bytes) -> None: + value = _find_header(headers, b":status") + if value is None: + raise SyntaxError("no :status header in response") + status = int(value) + if status < 0: + raise SyntaxError("status is negative") + if status < 200 or status > 299: + error = "" + if len(wire) > 0: + try: + error = ": " + wire.decode() + except Exception: + pass + raise ValueError(f"{peer} responded with status code {status}{error}") + + +def _http3( + q: dns.message.Message, + where: str, + url: str, + timeout: Optional[float] = None, + port: int = 853, + source: Optional[str] = None, + source_port: int = 0, + one_rr_per_rrset: bool = False, + ignore_trailing: bool = False, + verify: Union[bool, str] = True, + hostname: Optional[str] = None, + post: bool = True, +) -> dns.message.Message: + if not dns.quic.have_quic: + raise NoDOH("DNS-over-HTTP3 is not available.") # pragma: no cover + + url_parts = urllib.parse.urlparse(url) + hostname = url_parts.hostname + if url_parts.port is not None: + port = url_parts.port + + q.id = 0 + wire = q.to_wire() + manager = dns.quic.SyncQuicManager( + verify_mode=verify, server_name=hostname, h3=True + ) + + with manager: + connection = manager.connect(where, port, source, source_port) + (start, expiration) = _compute_times(timeout) + with connection.make_stream(timeout) as stream: + stream.send_h3(url, wire, post) + wire = stream.receive(_remaining(expiration)) + _check_status(stream.headers(), where, wire) + finish = time.time() + r = dns.message.from_wire( + wire, + keyring=q.keyring, + request_mac=q.request_mac, + one_rr_per_rrset=one_rr_per_rrset, + ignore_trailing=ignore_trailing, + ) + r.time = max(finish - start, 0.0) + if not q.is_response(r): + raise BadResponse + return r + + def _udp_recv(sock, max_size, expiration): """Reads a datagram from the socket. A Timeout exception will be raised if the operation is not completed @@ -855,7 +993,7 @@ def _net_read(sock, count, expiration): try: n = sock.recv(count) if n == b"": - raise EOFError + raise EOFError("EOF") count -= len(n) s += n except (BlockingIOError, ssl.SSLWantReadError): @@ -1023,6 +1161,7 @@ def tcp( cm = _make_socket(af, socket.SOCK_STREAM, source) with cm as s: if not sock: + # pylint: disable=possibly-used-before-assignment _connect(s, destination, expiration) send_tcp(s, wire, expiration) (r, received_time) = receive_tcp( @@ -1188,6 +1327,7 @@ def quic( ignore_trailing: bool = False, connection: Optional[dns.quic.SyncQuicConnection] = None, verify: Union[bool, str] = True, + hostname: Optional[str] = None, server_hostname: Optional[str] = None, ) -> dns.message.Message: """Return the response obtained after sending a query via DNS-over-QUIC. @@ -1212,17 +1352,21 @@ def quic( *ignore_trailing*, a ``bool``. If ``True``, ignore trailing junk at end of the received message. - *connection*, a ``dns.quic.SyncQuicConnection``. If provided, the - connection to use to send the query. + *connection*, a ``dns.quic.SyncQuicConnection``. If provided, the connection to use + to send the query. *verify*, a ``bool`` or ``str``. If a ``True``, then TLS certificate verification of the server is done using the default CA bundle; if ``False``, then no verification is done; if a `str` then it specifies the path to a certificate file or directory which will be used for verification. - *server_hostname*, a ``str`` containing the server's hostname. The - default is ``None``, which means that no hostname is known, and if an - SSL context is created, hostname checking will be disabled. + *hostname*, a ``str`` containing the server's hostname or ``None``. The default is + ``None``, which means that no hostname is known, and if an SSL context is created, + hostname checking will be disabled. This value is ignored if *url* is not + ``None``. + + *server_hostname*, a ``str`` or ``None``. This item is for backwards compatibility + only, and has the same meaning as *hostname*. Returns a ``dns.message.Message``. """ @@ -1230,6 +1374,9 @@ def quic( if not dns.quic.have_quic: raise NoDOQ("DNS-over-QUIC is not available.") # pragma: no cover + if server_hostname is not None and hostname is None: + hostname = server_hostname + q.id = 0 wire = q.to_wire() the_connection: dns.quic.SyncQuicConnection @@ -1238,9 +1385,7 @@ def quic( manager: contextlib.AbstractContextManager = contextlib.nullcontext(None) the_connection = connection else: - manager = dns.quic.SyncQuicManager( - verify_mode=verify, server_name=server_hostname - ) + manager = dns.quic.SyncQuicManager(verify_mode=verify, server_name=hostname) the_manager = manager # for type checking happiness with manager: @@ -1264,6 +1409,70 @@ def quic( return r +class UDPMode(enum.IntEnum): + """How should UDP be used in an IXFR from :py:func:`inbound_xfr()`? + + NEVER means "never use UDP; always use TCP" + TRY_FIRST means "try to use UDP but fall back to TCP if needed" + ONLY means "raise ``dns.xfr.UseTCP`` if trying UDP does not succeed" + """ + + NEVER = 0 + TRY_FIRST = 1 + ONLY = 2 + + +def _inbound_xfr( + txn_manager: dns.transaction.TransactionManager, + s: socket.socket, + query: dns.message.Message, + serial: Optional[int], + timeout: Optional[float], + expiration: float, +) -> Any: + """Given a socket, does the zone transfer.""" + rdtype = query.question[0].rdtype + is_ixfr = rdtype == dns.rdatatype.IXFR + origin = txn_manager.from_wire_origin() + wire = query.to_wire() + is_udp = s.type == socket.SOCK_DGRAM + if is_udp: + _udp_send(s, wire, None, expiration) + else: + tcpmsg = struct.pack("!H", len(wire)) + wire + _net_write(s, tcpmsg, expiration) + with dns.xfr.Inbound(txn_manager, rdtype, serial, is_udp) as inbound: + done = False + tsig_ctx = None + while not done: + (_, mexpiration) = _compute_times(timeout) + if mexpiration is None or ( + expiration is not None and mexpiration > expiration + ): + mexpiration = expiration + if is_udp: + (rwire, _) = _udp_recv(s, 65535, mexpiration) + else: + ldata = _net_read(s, 2, mexpiration) + (l,) = struct.unpack("!H", ldata) + rwire = _net_read(s, l, mexpiration) + r = dns.message.from_wire( + rwire, + keyring=query.keyring, + request_mac=query.mac, + xfr=True, + origin=origin, + tsig_ctx=tsig_ctx, + multi=(not is_udp), + one_rr_per_rrset=is_ixfr, + ) + done = inbound.process_message(r) + yield r + tsig_ctx = r.tsig_ctx + if query.keyring and not r.had_tsig: + raise dns.exception.FormError("missing TSIG") + + def xfr( where: str, zone: Union[dns.name.Name, str], @@ -1333,134 +1542,52 @@ def xfr( Returns a generator of ``dns.message.Message`` objects. """ + class DummyTransactionManager(dns.transaction.TransactionManager): + def __init__(self, origin, relativize): + self.info = (origin, relativize, dns.name.empty if relativize else origin) + + def origin_information(self): + return self.info + + def get_class(self) -> dns.rdataclass.RdataClass: + raise NotImplementedError # pragma: no cover + + def reader(self): + raise NotImplementedError # pragma: no cover + + def writer(self, replacement: bool = False) -> dns.transaction.Transaction: + class DummyTransaction: + def nop(self, *args, **kw): + pass + + def __getattr__(self, _): + return self.nop + + return cast(dns.transaction.Transaction, DummyTransaction()) + if isinstance(zone, str): zone = dns.name.from_text(zone) rdtype = dns.rdatatype.RdataType.make(rdtype) q = dns.message.make_query(zone, rdtype, rdclass) if rdtype == dns.rdatatype.IXFR: - rrset = dns.rrset.from_text(zone, 0, "IN", "SOA", ". . %u 0 0 0 0" % serial) - q.authority.append(rrset) + rrset = q.find_rrset( + q.authority, zone, dns.rdataclass.IN, dns.rdatatype.SOA, create=True + ) + soa = dns.rdata.from_text("IN", "SOA", ". . %u 0 0 0 0" % serial) + rrset.add(soa, 0) if keyring is not None: q.use_tsig(keyring, keyname, algorithm=keyalgorithm) - wire = q.to_wire() (af, destination, source) = _destination_and_source( where, port, source, source_port ) + (_, expiration) = _compute_times(lifetime) + tm = DummyTransactionManager(zone, relativize) if use_udp and rdtype != dns.rdatatype.IXFR: raise ValueError("cannot do a UDP AXFR") sock_type = socket.SOCK_DGRAM if use_udp else socket.SOCK_STREAM with _make_socket(af, sock_type, source) as s: - (_, expiration) = _compute_times(lifetime) _connect(s, destination, expiration) - l = len(wire) - if use_udp: - _udp_send(s, wire, None, expiration) - else: - tcpmsg = struct.pack("!H", l) + wire - _net_write(s, tcpmsg, expiration) - done = False - delete_mode = True - expecting_SOA = False - soa_rrset = None - if relativize: - origin = zone - oname = dns.name.empty - else: - origin = None - oname = zone - tsig_ctx = None - while not done: - (_, mexpiration) = _compute_times(timeout) - if mexpiration is None or ( - expiration is not None and mexpiration > expiration - ): - mexpiration = expiration - if use_udp: - (wire, _) = _udp_recv(s, 65535, mexpiration) - else: - ldata = _net_read(s, 2, mexpiration) - (l,) = struct.unpack("!H", ldata) - wire = _net_read(s, l, mexpiration) - is_ixfr = rdtype == dns.rdatatype.IXFR - r = dns.message.from_wire( - wire, - keyring=q.keyring, - request_mac=q.mac, - xfr=True, - origin=origin, - tsig_ctx=tsig_ctx, - multi=True, - one_rr_per_rrset=is_ixfr, - ) - rcode = r.rcode() - if rcode != dns.rcode.NOERROR: - raise TransferError(rcode) - tsig_ctx = r.tsig_ctx - answer_index = 0 - if soa_rrset is None: - if not r.answer or r.answer[0].name != oname: - raise dns.exception.FormError("No answer or RRset not for qname") - rrset = r.answer[0] - if rrset.rdtype != dns.rdatatype.SOA: - raise dns.exception.FormError("first RRset is not an SOA") - answer_index = 1 - soa_rrset = rrset.copy() - if rdtype == dns.rdatatype.IXFR: - if dns.serial.Serial(soa_rrset[0].serial) <= serial: - # - # We're already up-to-date. - # - done = True - else: - expecting_SOA = True - # - # Process SOAs in the answer section (other than the initial - # SOA in the first message). - # - for rrset in r.answer[answer_index:]: - if done: - raise dns.exception.FormError("answers after final SOA") - if rrset.rdtype == dns.rdatatype.SOA and rrset.name == oname: - if expecting_SOA: - if rrset[0].serial != serial: - raise dns.exception.FormError("IXFR base serial mismatch") - expecting_SOA = False - elif rdtype == dns.rdatatype.IXFR: - delete_mode = not delete_mode - # - # If this SOA RRset is equal to the first we saw then we're - # finished. If this is an IXFR we also check that we're - # seeing the record in the expected part of the response. - # - if rrset == soa_rrset and ( - rdtype == dns.rdatatype.AXFR - or (rdtype == dns.rdatatype.IXFR and delete_mode) - ): - done = True - elif expecting_SOA: - # - # We made an IXFR request and are expecting another - # SOA RR, but saw something else, so this must be an - # AXFR response. - # - rdtype = dns.rdatatype.AXFR - expecting_SOA = False - if done and q.keyring and not r.had_tsig: - raise dns.exception.FormError("missing TSIG") - yield r - - -class UDPMode(enum.IntEnum): - """How should UDP be used in an IXFR from :py:func:`inbound_xfr()`? - - NEVER means "never use UDP; always use TCP" - TRY_FIRST means "try to use UDP but fall back to TCP if needed" - ONLY means "raise ``dns.xfr.UseTCP`` if trying UDP does not succeed" - """ - - NEVER = 0 - TRY_FIRST = 1 - ONLY = 2 + yield from _inbound_xfr(tm, s, q, serial, timeout, expiration) def inbound_xfr( @@ -1514,65 +1641,25 @@ def inbound_xfr( (query, serial) = dns.xfr.make_query(txn_manager) else: serial = dns.xfr.extract_serial_from_query(query) - rdtype = query.question[0].rdtype - is_ixfr = rdtype == dns.rdatatype.IXFR - origin = txn_manager.from_wire_origin() - wire = query.to_wire() + (af, destination, source) = _destination_and_source( where, port, source, source_port ) (_, expiration) = _compute_times(lifetime) - retry = True - while retry: - retry = False - if is_ixfr and udp_mode != UDPMode.NEVER: - sock_type = socket.SOCK_DGRAM - is_udp = True - else: - sock_type = socket.SOCK_STREAM - is_udp = False - with _make_socket(af, sock_type, source) as s: + if query.question[0].rdtype == dns.rdatatype.IXFR and udp_mode != UDPMode.NEVER: + with _make_socket(af, socket.SOCK_DGRAM, source) as s: _connect(s, destination, expiration) - if is_udp: - _udp_send(s, wire, None, expiration) - else: - tcpmsg = struct.pack("!H", len(wire)) + wire - _net_write(s, tcpmsg, expiration) - with dns.xfr.Inbound(txn_manager, rdtype, serial, is_udp) as inbound: - done = False - tsig_ctx = None - while not done: - (_, mexpiration) = _compute_times(timeout) - if mexpiration is None or ( - expiration is not None and mexpiration > expiration - ): - mexpiration = expiration - if is_udp: - (rwire, _) = _udp_recv(s, 65535, mexpiration) - else: - ldata = _net_read(s, 2, mexpiration) - (l,) = struct.unpack("!H", ldata) - rwire = _net_read(s, l, mexpiration) - r = dns.message.from_wire( - rwire, - keyring=query.keyring, - request_mac=query.mac, - xfr=True, - origin=origin, - tsig_ctx=tsig_ctx, - multi=(not is_udp), - one_rr_per_rrset=is_ixfr, - ) - try: - done = inbound.process_message(r) - except dns.xfr.UseTCP: - assert is_udp # should not happen if we used TCP! - if udp_mode == UDPMode.ONLY: - raise - done = True - retry = True - udp_mode = UDPMode.NEVER - continue - tsig_ctx = r.tsig_ctx - if not retry and query.keyring and not r.had_tsig: - raise dns.exception.FormError("missing TSIG") + try: + for _ in _inbound_xfr( + txn_manager, s, query, serial, timeout, expiration + ): + pass + return + except dns.xfr.UseTCP: + if udp_mode == UDPMode.ONLY: + raise + + with _make_socket(af, socket.SOCK_STREAM, source) as s: + _connect(s, destination, expiration) + for _ in _inbound_xfr(txn_manager, s, query, serial, timeout, expiration): + pass diff --git a/lib/dns/quic/__init__.py b/lib/dns/quic/__init__.py index 20aff345..0750e729 100644 --- a/lib/dns/quic/__init__.py +++ b/lib/dns/quic/__init__.py @@ -1,5 +1,7 @@ # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license +from typing import List, Tuple + import dns._features import dns.asyncbackend @@ -73,3 +75,6 @@ else: # pragma: no cover class SyncQuicConnection: # type: ignore def make_stream(self) -> Any: raise NotImplementedError + + +Headers = List[Tuple[bytes, bytes]] diff --git a/lib/dns/quic/_asyncio.py b/lib/dns/quic/_asyncio.py index 0f44331f..f87515da 100644 --- a/lib/dns/quic/_asyncio.py +++ b/lib/dns/quic/_asyncio.py @@ -43,12 +43,26 @@ class AsyncioQuicStream(BaseQuicStream): raise dns.exception.Timeout self._expecting = 0 + async def wait_for_end(self, expiration): + while True: + timeout = self._timeout_from_expiration(expiration) + if self._buffer.seen_end(): + return + try: + await asyncio.wait_for(self._wait_for_wake_up(), timeout) + except TimeoutError: + raise dns.exception.Timeout + async def receive(self, timeout=None): expiration = self._expiration_from_timeout(timeout) - await self.wait_for(2, expiration) - (size,) = struct.unpack("!H", self._buffer.get(2)) - await self.wait_for(size, expiration) - return self._buffer.get(size) + if self._connection.is_h3(): + await self.wait_for_end(expiration) + return self._buffer.get_all() + else: + await self.wait_for(2, expiration) + (size,) = struct.unpack("!H", self._buffer.get(2)) + await self.wait_for(size, expiration) + return self._buffer.get(size) async def send(self, datagram, is_end=False): data = self._encapsulate(datagram) @@ -83,6 +97,7 @@ class AsyncioQuicConnection(AsyncQuicConnection): self._wake_timer = asyncio.Condition() self._receiver_task = None self._sender_task = None + self._wake_pending = False async def _receiver(self): try: @@ -104,19 +119,24 @@ class AsyncioQuicConnection(AsyncQuicConnection): self._connection.receive_datagram(datagram, address, time.time()) # Wake up the timer in case the sender is sleeping, as there may be # stuff to send now. - async with self._wake_timer: - self._wake_timer.notify_all() + await self._wakeup() except Exception: pass finally: self._done = True - async with self._wake_timer: - self._wake_timer.notify_all() + await self._wakeup() self._handshake_complete.set() + async def _wakeup(self): + self._wake_pending = True + async with self._wake_timer: + self._wake_timer.notify_all() + async def _wait_for_wake_timer(self): async with self._wake_timer: - await self._wake_timer.wait() + if not self._wake_pending: + await self._wake_timer.wait() + self._wake_pending = False async def _sender(self): await self._socket_created.wait() @@ -140,9 +160,28 @@ class AsyncioQuicConnection(AsyncQuicConnection): if event is None: return if isinstance(event, aioquic.quic.events.StreamDataReceived): - stream = self._streams.get(event.stream_id) - if stream: - await stream._add_input(event.data, event.end_stream) + if self.is_h3(): + h3_events = self._h3_conn.handle_event(event) + for h3_event in h3_events: + if isinstance(h3_event, aioquic.h3.events.HeadersReceived): + stream = self._streams.get(event.stream_id) + if stream: + if stream._headers is None: + stream._headers = h3_event.headers + elif stream._trailers is None: + stream._trailers = h3_event.headers + if h3_event.stream_ended: + await stream._add_input(b"", True) + elif isinstance(h3_event, aioquic.h3.events.DataReceived): + stream = self._streams.get(event.stream_id) + if stream: + await stream._add_input( + h3_event.data, h3_event.stream_ended + ) + else: + stream = self._streams.get(event.stream_id) + if stream: + await stream._add_input(event.data, event.end_stream) elif isinstance(event, aioquic.quic.events.HandshakeCompleted): self._handshake_complete.set() elif isinstance(event, aioquic.quic.events.ConnectionTerminated): @@ -161,8 +200,7 @@ class AsyncioQuicConnection(AsyncQuicConnection): async def write(self, stream, data, is_end=False): self._connection.send_stream_data(stream, data, is_end) - async with self._wake_timer: - self._wake_timer.notify_all() + await self._wakeup() def run(self): if self._closed: @@ -189,8 +227,7 @@ class AsyncioQuicConnection(AsyncQuicConnection): self._connection.close() # sender might be blocked on this, so set it self._socket_created.set() - async with self._wake_timer: - self._wake_timer.notify_all() + await self._wakeup() try: await self._receiver_task except asyncio.CancelledError: @@ -203,8 +240,10 @@ class AsyncioQuicConnection(AsyncQuicConnection): class AsyncioQuicManager(AsyncQuicManager): - def __init__(self, conf=None, verify_mode=ssl.CERT_REQUIRED, server_name=None): - super().__init__(conf, verify_mode, AsyncioQuicConnection, server_name) + def __init__( + self, conf=None, verify_mode=ssl.CERT_REQUIRED, server_name=None, h3=False + ): + super().__init__(conf, verify_mode, AsyncioQuicConnection, server_name, h3) def connect( self, address, port=853, source=None, source_port=0, want_session_ticket=True diff --git a/lib/dns/quic/_common.py b/lib/dns/quic/_common.py index 0eacc691..ce575b03 100644 --- a/lib/dns/quic/_common.py +++ b/lib/dns/quic/_common.py @@ -1,12 +1,16 @@ # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license +import base64 import copy import functools import socket import struct import time +import urllib from typing import Any, Optional +import aioquic.h3.connection # type: ignore +import aioquic.h3.events # type: ignore import aioquic.quic.configuration # type: ignore import aioquic.quic.connection # type: ignore @@ -51,6 +55,12 @@ class Buffer: self._buffer = self._buffer[amount:] return data + def get_all(self): + assert self.seen_end() + data = self._buffer + self._buffer = b"" + return data + class BaseQuicStream: def __init__(self, connection, stream_id): @@ -58,10 +68,18 @@ class BaseQuicStream: self._stream_id = stream_id self._buffer = Buffer() self._expecting = 0 + self._headers = None + self._trailers = None def id(self): return self._stream_id + def headers(self): + return self._headers + + def trailers(self): + return self._trailers + def _expiration_from_timeout(self, timeout): if timeout is not None: expiration = time.time() + timeout @@ -77,16 +95,51 @@ class BaseQuicStream: return timeout # Subclass must implement receive() as sync / async and which returns a message - # or raises UnexpectedEOF. + # or raises. + + # Subclass must implement send() as sync / async and which takes a message and + # an EOF indicator. + + def send_h3(self, url, datagram, post=True): + if not self._connection.is_h3(): + raise SyntaxError("cannot send H3 to a non-H3 connection") + url_parts = urllib.parse.urlparse(url) + path = url_parts.path.encode() + if post: + method = b"POST" + else: + method = b"GET" + path += b"?dns=" + base64.urlsafe_b64encode(datagram).rstrip(b"=") + headers = [ + (b":method", method), + (b":scheme", url_parts.scheme.encode()), + (b":authority", url_parts.netloc.encode()), + (b":path", path), + (b"accept", b"application/dns-message"), + ] + if post: + headers.extend( + [ + (b"content-type", b"application/dns-message"), + (b"content-length", str(len(datagram)).encode()), + ] + ) + self._connection.send_headers(self._stream_id, headers, not post) + if post: + self._connection.send_data(self._stream_id, datagram, True) def _encapsulate(self, datagram): + if self._connection.is_h3(): + return datagram l = len(datagram) return struct.pack("!H", l) + datagram def _common_add_input(self, data, is_end): self._buffer.put(data, is_end) try: - return self._expecting > 0 and self._buffer.have(self._expecting) + return ( + self._expecting > 0 and self._buffer.have(self._expecting) + ) or self._buffer.seen_end except UnexpectedEOF: return True @@ -97,7 +150,13 @@ class BaseQuicStream: class BaseQuicConnection: def __init__( - self, connection, address, port, source=None, source_port=0, manager=None + self, + connection, + address, + port, + source=None, + source_port=0, + manager=None, ): self._done = False self._connection = connection @@ -106,6 +165,10 @@ class BaseQuicConnection: self._closed = False self._manager = manager self._streams = {} + if manager.is_h3(): + self._h3_conn = aioquic.h3.connection.H3Connection(connection, False) + else: + self._h3_conn = None self._af = dns.inet.af_for_address(address) self._peer = dns.inet.low_level_address_tuple((address, port)) if source is None and source_port != 0: @@ -120,9 +183,18 @@ class BaseQuicConnection: else: self._source = None + def is_h3(self): + return self._h3_conn is not None + def close_stream(self, stream_id): del self._streams[stream_id] + def send_headers(self, stream_id, headers, is_end=False): + self._h3_conn.send_headers(stream_id, headers, is_end) + + def send_data(self, stream_id, data, is_end=False): + self._h3_conn.send_data(stream_id, data, is_end) + def _get_timer_values(self, closed_is_special=True): now = time.time() expiration = self._connection.get_timer() @@ -148,17 +220,25 @@ class AsyncQuicConnection(BaseQuicConnection): class BaseQuicManager: - def __init__(self, conf, verify_mode, connection_factory, server_name=None): + def __init__( + self, conf, verify_mode, connection_factory, server_name=None, h3=False + ): self._connections = {} self._connection_factory = connection_factory self._session_tickets = {} + self._tokens = {} + self._h3 = h3 if conf is None: verify_path = None if isinstance(verify_mode, str): verify_path = verify_mode verify_mode = True + if h3: + alpn_protocols = ["h3"] + else: + alpn_protocols = ["doq", "doq-i03"] conf = aioquic.quic.configuration.QuicConfiguration( - alpn_protocols=["doq", "doq-i03"], + alpn_protocols=alpn_protocols, verify_mode=verify_mode, server_name=server_name, ) @@ -167,7 +247,13 @@ class BaseQuicManager: self._conf = conf def _connect( - self, address, port=853, source=None, source_port=0, want_session_ticket=True + self, + address, + port=853, + source=None, + source_port=0, + want_session_ticket=True, + want_token=True, ): connection = self._connections.get((address, port)) if connection is not None: @@ -189,9 +275,24 @@ class BaseQuicManager: ) else: session_ticket_handler = None + if want_token: + try: + token = self._tokens.pop((address, port)) + # We found a token, so make a configuration that uses it. + conf = copy.copy(conf) + conf.token = token + except KeyError: + # No token + pass + # Whether or not we found a token, we want a handler to save # one. + token_handler = functools.partial(self.save_token, address, port) + else: + token_handler = None + qconn = aioquic.quic.connection.QuicConnection( configuration=conf, session_ticket_handler=session_ticket_handler, + token_handler=token_handler, ) lladdress = dns.inet.low_level_address_tuple((address, port)) qconn.connect(lladdress, time.time()) @@ -207,6 +308,9 @@ class BaseQuicManager: except KeyError: pass + def is_h3(self): + return self._h3 + def save_session_ticket(self, address, port, ticket): # We rely on dictionaries keys() being in insertion order here. We # can't just popitem() as that would be LIFO which is the opposite of @@ -218,6 +322,17 @@ class BaseQuicManager: del self._session_tickets[key] self._session_tickets[(address, port)] = ticket + def save_token(self, address, port, token): + # We rely on dictionaries keys() being in insertion order here. We + # can't just popitem() as that would be LIFO which is the opposite of + # what we want. + l = len(self._tokens) + if l >= MAX_SESSION_TICKETS: + keys_to_delete = list(self._tokens.keys())[0:SESSIONS_TO_DELETE] + for key in keys_to_delete: + del self._tokens[key] + self._tokens[(address, port)] = token + class AsyncQuicManager(BaseQuicManager): def connect(self, address, port=853, source=None, source_port=0): diff --git a/lib/dns/quic/_sync.py b/lib/dns/quic/_sync.py index 120cb5f3..473d1f48 100644 --- a/lib/dns/quic/_sync.py +++ b/lib/dns/quic/_sync.py @@ -21,11 +21,9 @@ from dns.quic._common import ( UnexpectedEOF, ) -# Avoid circularity with dns.query -if hasattr(selectors, "PollSelector"): - _selector_class = selectors.PollSelector # type: ignore -else: - _selector_class = selectors.SelectSelector # type: ignore +# Function used to create a socket. Can be overridden if needed in special +# situations. +socket_factory = socket.socket class SyncQuicStream(BaseQuicStream): @@ -46,14 +44,29 @@ class SyncQuicStream(BaseQuicStream): raise dns.exception.Timeout self._expecting = 0 + def wait_for_end(self, expiration): + while True: + timeout = self._timeout_from_expiration(expiration) + with self._lock: + if self._buffer.seen_end(): + return + with self._wake_up: + if not self._wake_up.wait(timeout): + raise dns.exception.Timeout + def receive(self, timeout=None): expiration = self._expiration_from_timeout(timeout) - self.wait_for(2, expiration) - with self._lock: - (size,) = struct.unpack("!H", self._buffer.get(2)) - self.wait_for(size, expiration) - with self._lock: - return self._buffer.get(size) + if self._connection.is_h3(): + self.wait_for_end(expiration) + with self._lock: + return self._buffer.get_all() + else: + self.wait_for(2, expiration) + with self._lock: + (size,) = struct.unpack("!H", self._buffer.get(2)) + self.wait_for(size, expiration) + with self._lock: + return self._buffer.get(size) def send(self, datagram, is_end=False): data = self._encapsulate(datagram) @@ -81,7 +94,7 @@ class SyncQuicStream(BaseQuicStream): class SyncQuicConnection(BaseQuicConnection): def __init__(self, connection, address, port, source, source_port, manager): super().__init__(connection, address, port, source, source_port, manager) - self._socket = socket.socket(self._af, socket.SOCK_DGRAM, 0) + self._socket = socket_factory(self._af, socket.SOCK_DGRAM, 0) if self._source is not None: try: self._socket.bind( @@ -118,7 +131,7 @@ class SyncQuicConnection(BaseQuicConnection): def _worker(self): try: - sel = _selector_class() + sel = selectors.DefaultSelector() sel.register(self._socket, selectors.EVENT_READ, self._read) sel.register(self._receive_wakeup, selectors.EVENT_READ, self._drain_wakeup) while not self._done: @@ -140,6 +153,7 @@ class SyncQuicConnection(BaseQuicConnection): finally: with self._lock: self._done = True + self._socket.close() # Ensure anyone waiting for this gets woken up. self._handshake_complete.set() @@ -150,10 +164,29 @@ class SyncQuicConnection(BaseQuicConnection): if event is None: return if isinstance(event, aioquic.quic.events.StreamDataReceived): - with self._lock: - stream = self._streams.get(event.stream_id) - if stream: - stream._add_input(event.data, event.end_stream) + if self.is_h3(): + h3_events = self._h3_conn.handle_event(event) + for h3_event in h3_events: + if isinstance(h3_event, aioquic.h3.events.HeadersReceived): + with self._lock: + stream = self._streams.get(event.stream_id) + if stream: + if stream._headers is None: + stream._headers = h3_event.headers + elif stream._trailers is None: + stream._trailers = h3_event.headers + if h3_event.stream_ended: + stream._add_input(b"", True) + elif isinstance(h3_event, aioquic.h3.events.DataReceived): + with self._lock: + stream = self._streams.get(event.stream_id) + if stream: + stream._add_input(h3_event.data, h3_event.stream_ended) + else: + with self._lock: + stream = self._streams.get(event.stream_id) + if stream: + stream._add_input(event.data, event.end_stream) elif isinstance(event, aioquic.quic.events.HandshakeCompleted): self._handshake_complete.set() elif isinstance(event, aioquic.quic.events.ConnectionTerminated): @@ -170,6 +203,18 @@ class SyncQuicConnection(BaseQuicConnection): self._connection.send_stream_data(stream, data, is_end) self._send_wakeup.send(b"\x01") + def send_headers(self, stream_id, headers, is_end=False): + with self._lock: + super().send_headers(stream_id, headers, is_end) + if is_end: + self._send_wakeup.send(b"\x01") + + def send_data(self, stream_id, data, is_end=False): + with self._lock: + super().send_data(stream_id, data, is_end) + if is_end: + self._send_wakeup.send(b"\x01") + def run(self): if self._closed: return @@ -203,16 +248,24 @@ class SyncQuicConnection(BaseQuicConnection): class SyncQuicManager(BaseQuicManager): - def __init__(self, conf=None, verify_mode=ssl.CERT_REQUIRED, server_name=None): - super().__init__(conf, verify_mode, SyncQuicConnection, server_name) + def __init__( + self, conf=None, verify_mode=ssl.CERT_REQUIRED, server_name=None, h3=False + ): + super().__init__(conf, verify_mode, SyncQuicConnection, server_name, h3) self._lock = threading.Lock() def connect( - self, address, port=853, source=None, source_port=0, want_session_ticket=True + self, + address, + port=853, + source=None, + source_port=0, + want_session_ticket=True, + want_token=True, ): with self._lock: (connection, start) = self._connect( - address, port, source, source_port, want_session_ticket + address, port, source, source_port, want_session_ticket, want_token ) if start: connection.run() @@ -226,6 +279,10 @@ class SyncQuicManager(BaseQuicManager): with self._lock: super().save_session_ticket(address, port, ticket) + def save_token(self, address, port, token): + with self._lock: + super().save_token(address, port, token) + def __enter__(self): return self diff --git a/lib/dns/quic/_trio.py b/lib/dns/quic/_trio.py index 35e36b98..ae79f369 100644 --- a/lib/dns/quic/_trio.py +++ b/lib/dns/quic/_trio.py @@ -36,16 +36,27 @@ class TrioQuicStream(BaseQuicStream): await self._wake_up.wait() self._expecting = 0 + async def wait_for_end(self): + while True: + if self._buffer.seen_end(): + return + async with self._wake_up: + await self._wake_up.wait() + async def receive(self, timeout=None): if timeout is None: context = NullContext(None) else: context = trio.move_on_after(timeout) with context: - await self.wait_for(2) - (size,) = struct.unpack("!H", self._buffer.get(2)) - await self.wait_for(size) - return self._buffer.get(size) + if self._connection.is_h3(): + await self.wait_for_end() + return self._buffer.get_all() + else: + await self.wait_for(2) + (size,) = struct.unpack("!H", self._buffer.get(2)) + await self.wait_for(size) + return self._buffer.get(size) raise dns.exception.Timeout async def send(self, datagram, is_end=False): @@ -115,6 +126,7 @@ class TrioQuicConnection(AsyncQuicConnection): await self._socket.send(datagram) finally: self._done = True + self._socket.close() self._handshake_complete.set() async def _handle_events(self): @@ -124,9 +136,28 @@ class TrioQuicConnection(AsyncQuicConnection): if event is None: return if isinstance(event, aioquic.quic.events.StreamDataReceived): - stream = self._streams.get(event.stream_id) - if stream: - await stream._add_input(event.data, event.end_stream) + if self.is_h3(): + h3_events = self._h3_conn.handle_event(event) + for h3_event in h3_events: + if isinstance(h3_event, aioquic.h3.events.HeadersReceived): + stream = self._streams.get(event.stream_id) + if stream: + if stream._headers is None: + stream._headers = h3_event.headers + elif stream._trailers is None: + stream._trailers = h3_event.headers + if h3_event.stream_ended: + await stream._add_input(b"", True) + elif isinstance(h3_event, aioquic.h3.events.DataReceived): + stream = self._streams.get(event.stream_id) + if stream: + await stream._add_input( + h3_event.data, h3_event.stream_ended + ) + else: + stream = self._streams.get(event.stream_id) + if stream: + await stream._add_input(event.data, event.end_stream) elif isinstance(event, aioquic.quic.events.HandshakeCompleted): self._handshake_complete.set() elif isinstance(event, aioquic.quic.events.ConnectionTerminated): @@ -183,9 +214,14 @@ class TrioQuicConnection(AsyncQuicConnection): class TrioQuicManager(AsyncQuicManager): def __init__( - self, nursery, conf=None, verify_mode=ssl.CERT_REQUIRED, server_name=None + self, + nursery, + conf=None, + verify_mode=ssl.CERT_REQUIRED, + server_name=None, + h3=False, ): - super().__init__(conf, verify_mode, TrioQuicConnection, server_name) + super().__init__(conf, verify_mode, TrioQuicConnection, server_name, h3) self._nursery = nursery def connect( diff --git a/lib/dns/rdata.py b/lib/dns/rdata.py index 024fd8f6..8099c26a 100644 --- a/lib/dns/rdata.py +++ b/lib/dns/rdata.py @@ -214,7 +214,7 @@ class Rdata: compress: Optional[dns.name.CompressType] = None, origin: Optional[dns.name.Name] = None, canonicalize: bool = False, - ) -> bytes: + ) -> None: raise NotImplementedError # pragma: no cover def to_wire( @@ -223,14 +223,19 @@ class Rdata: compress: Optional[dns.name.CompressType] = None, origin: Optional[dns.name.Name] = None, canonicalize: bool = False, - ) -> bytes: + ) -> Optional[bytes]: """Convert an rdata to wire format. - Returns a ``bytes`` or ``None``. + Returns a ``bytes`` if no output file was specified, or ``None`` otherwise. """ if file: - return self._to_wire(file, compress, origin, canonicalize) + # We call _to_wire() and then return None explicitly instead of + # of just returning the None from _to_wire() as mypy's func-returns-value + # unhelpfully errors out with "error: "_to_wire" of "Rdata" does not return + # a value (it only ever returns None)" + self._to_wire(file, compress, origin, canonicalize) + return None else: f = io.BytesIO() self._to_wire(f, compress, origin, canonicalize) @@ -253,8 +258,9 @@ class Rdata: Returns a ``bytes``. """ - - return self.to_wire(origin=origin, canonicalize=True) + wire = self.to_wire(origin=origin, canonicalize=True) + assert wire is not None # for mypy + return wire def __repr__(self): covers = self.covers() @@ -434,15 +440,11 @@ class Rdata: continue if key not in parameters: raise AttributeError( - "'{}' object has no attribute '{}'".format( - self.__class__.__name__, key - ) + f"'{self.__class__.__name__}' object has no attribute '{key}'" ) if key in ("rdclass", "rdtype"): raise AttributeError( - "Cannot overwrite '{}' attribute '{}'".format( - self.__class__.__name__, key - ) + f"Cannot overwrite '{self.__class__.__name__}' attribute '{key}'" ) # Construct the parameter list. For each field, use the value in @@ -646,13 +648,14 @@ _rdata_classes: Dict[Tuple[dns.rdataclass.RdataClass, dns.rdatatype.RdataType], {} ) _module_prefix = "dns.rdtypes" +_dynamic_load_allowed = True -def get_rdata_class(rdclass, rdtype): +def get_rdata_class(rdclass, rdtype, use_generic=True): cls = _rdata_classes.get((rdclass, rdtype)) if not cls: cls = _rdata_classes.get((dns.rdatatype.ANY, rdtype)) - if not cls: + if not cls and _dynamic_load_allowed: rdclass_text = dns.rdataclass.to_text(rdclass) rdtype_text = dns.rdatatype.to_text(rdtype) rdtype_text = rdtype_text.replace("-", "_") @@ -670,12 +673,36 @@ def get_rdata_class(rdclass, rdtype): _rdata_classes[(rdclass, rdtype)] = cls except ImportError: pass - if not cls: + if not cls and use_generic: cls = GenericRdata _rdata_classes[(rdclass, rdtype)] = cls return cls +def load_all_types(disable_dynamic_load=True): + """Load all rdata types for which dnspython has a non-generic implementation. + + Normally dnspython loads DNS rdatatype implementations on demand, but in some + specialized cases loading all types at an application-controlled time is preferred. + + If *disable_dynamic_load*, a ``bool``, is ``True`` then dnspython will not attempt + to use its dynamic loading mechanism if an unknown type is subsequently encountered, + and will simply use the ``GenericRdata`` class. + """ + # Load class IN and ANY types. + for rdtype in dns.rdatatype.RdataType: + get_rdata_class(dns.rdataclass.IN, rdtype, False) + # Load the one non-ANY implementation we have in CH. Everything + # else in CH is an ANY type, and we'll discover those on demand but won't + # have to import anything. + get_rdata_class(dns.rdataclass.CH, dns.rdatatype.A, False) + if disable_dynamic_load: + # Now disable dynamic loading so any subsequent unknown type immediately becomes + # GenericRdata without a load attempt. + global _dynamic_load_allowed + _dynamic_load_allowed = False + + def from_text( rdclass: Union[dns.rdataclass.RdataClass, str], rdtype: Union[dns.rdatatype.RdataType, str], diff --git a/lib/dns/rdataset.py b/lib/dns/rdataset.py index 8bff58d7..39cab236 100644 --- a/lib/dns/rdataset.py +++ b/lib/dns/rdataset.py @@ -160,7 +160,7 @@ class Rdataset(dns.set.Set): return s[:100] + "..." return s - return "[%s]" % ", ".join("<%s>" % maybe_truncate(str(rr)) for rr in self) + return "[" + ", ".join(f"<{maybe_truncate(str(rr))}>" for rr in self) + "]" def __repr__(self): if self.covers == 0: @@ -248,12 +248,8 @@ class Rdataset(dns.set.Set): # (which is meaningless anyway). # s.write( - "{}{}{} {}\n".format( - ntext, - pad, - dns.rdataclass.to_text(rdclass), - dns.rdatatype.to_text(self.rdtype), - ) + f"{ntext}{pad}{dns.rdataclass.to_text(rdclass)} " + f"{dns.rdatatype.to_text(self.rdtype)}\n" ) else: for rd in self: diff --git a/lib/dns/rdatatype.py b/lib/dns/rdatatype.py index e6c58186..aa9e561c 100644 --- a/lib/dns/rdatatype.py +++ b/lib/dns/rdatatype.py @@ -105,6 +105,8 @@ class RdataType(dns.enum.IntEnum): CAA = 257 AVC = 258 AMTRELAY = 260 + RESINFO = 261 + WALLET = 262 TA = 32768 DLV = 32769 @@ -125,7 +127,7 @@ class RdataType(dns.enum.IntEnum): if text.find("-") >= 0: try: return cls[text.replace("-", "_")] - except KeyError: + except KeyError: # pragma: no cover pass return _registered_by_text.get(text) @@ -326,6 +328,8 @@ URI = RdataType.URI CAA = RdataType.CAA AVC = RdataType.AVC AMTRELAY = RdataType.AMTRELAY +RESINFO = RdataType.RESINFO +WALLET = RdataType.WALLET TA = RdataType.TA DLV = RdataType.DLV diff --git a/lib/dns/rdtypes/ANY/GPOS.py b/lib/dns/rdtypes/ANY/GPOS.py index 312338f9..d79f4a06 100644 --- a/lib/dns/rdtypes/ANY/GPOS.py +++ b/lib/dns/rdtypes/ANY/GPOS.py @@ -75,8 +75,9 @@ class GPOS(dns.rdata.Rdata): raise dns.exception.FormError("bad longitude") def to_text(self, origin=None, relativize=True, **kw): - return "{} {} {}".format( - self.latitude.decode(), self.longitude.decode(), self.altitude.decode() + return ( + f"{self.latitude.decode()} {self.longitude.decode()} " + f"{self.altitude.decode()}" ) @classmethod diff --git a/lib/dns/rdtypes/ANY/HINFO.py b/lib/dns/rdtypes/ANY/HINFO.py index c2c45de0..06ad3487 100644 --- a/lib/dns/rdtypes/ANY/HINFO.py +++ b/lib/dns/rdtypes/ANY/HINFO.py @@ -37,9 +37,7 @@ class HINFO(dns.rdata.Rdata): self.os = self._as_bytes(os, True, 255) def to_text(self, origin=None, relativize=True, **kw): - return '"{}" "{}"'.format( - dns.rdata._escapify(self.cpu), dns.rdata._escapify(self.os) - ) + return f'"{dns.rdata._escapify(self.cpu)}" "{dns.rdata._escapify(self.os)}"' @classmethod def from_text( diff --git a/lib/dns/rdtypes/ANY/HIP.py b/lib/dns/rdtypes/ANY/HIP.py index 91669139..f3157da7 100644 --- a/lib/dns/rdtypes/ANY/HIP.py +++ b/lib/dns/rdtypes/ANY/HIP.py @@ -48,7 +48,7 @@ class HIP(dns.rdata.Rdata): for server in self.servers: servers.append(server.choose_relativity(origin, relativize)) if len(servers) > 0: - text += " " + " ".join((x.to_unicode() for x in servers)) + text += " " + " ".join(x.to_unicode() for x in servers) return "%u %s %s%s" % (self.algorithm, hit, key, text) @classmethod diff --git a/lib/dns/rdtypes/ANY/ISDN.py b/lib/dns/rdtypes/ANY/ISDN.py index fb01eab3..6428a0a8 100644 --- a/lib/dns/rdtypes/ANY/ISDN.py +++ b/lib/dns/rdtypes/ANY/ISDN.py @@ -38,11 +38,12 @@ class ISDN(dns.rdata.Rdata): def to_text(self, origin=None, relativize=True, **kw): if self.subaddress: - return '"{}" "{}"'.format( - dns.rdata._escapify(self.address), dns.rdata._escapify(self.subaddress) + return ( + f'"{dns.rdata._escapify(self.address)}" ' + f'"{dns.rdata._escapify(self.subaddress)}"' ) else: - return '"%s"' % dns.rdata._escapify(self.address) + return f'"{dns.rdata._escapify(self.address)}"' @classmethod def from_text( diff --git a/lib/dns/rdtypes/ANY/LOC.py b/lib/dns/rdtypes/ANY/LOC.py index a36a2c10..1153cf03 100644 --- a/lib/dns/rdtypes/ANY/LOC.py +++ b/lib/dns/rdtypes/ANY/LOC.py @@ -44,7 +44,7 @@ def _exponent_of(what, desc): exp = i - 1 break if exp is None or exp < 0: - raise dns.exception.SyntaxError("%s value out of bounds" % desc) + raise dns.exception.SyntaxError(f"{desc} value out of bounds") return exp @@ -83,10 +83,10 @@ def _encode_size(what, desc): def _decode_size(what, desc): exponent = what & 0x0F if exponent > 9: - raise dns.exception.FormError("bad %s exponent" % desc) + raise dns.exception.FormError(f"bad {desc} exponent") base = (what & 0xF0) >> 4 if base > 9: - raise dns.exception.FormError("bad %s base" % desc) + raise dns.exception.FormError(f"bad {desc} base") return base * pow(10, exponent) @@ -184,10 +184,9 @@ class LOC(dns.rdata.Rdata): or self.horizontal_precision != _default_hprec or self.vertical_precision != _default_vprec ): - text += " {:0.2f}m {:0.2f}m {:0.2f}m".format( - self.size / 100.0, - self.horizontal_precision / 100.0, - self.vertical_precision / 100.0, + text += ( + f" {self.size / 100.0:0.2f}m {self.horizontal_precision / 100.0:0.2f}m" + f" {self.vertical_precision / 100.0:0.2f}m" ) return text diff --git a/lib/dns/rdtypes/ANY/NSEC.py b/lib/dns/rdtypes/ANY/NSEC.py index 340525a6..3c78b722 100644 --- a/lib/dns/rdtypes/ANY/NSEC.py +++ b/lib/dns/rdtypes/ANY/NSEC.py @@ -44,7 +44,7 @@ class NSEC(dns.rdata.Rdata): def to_text(self, origin=None, relativize=True, **kw): next = self.next.choose_relativity(origin, relativize) text = Bitmap(self.windows).to_text() - return "{}{}".format(next, text) + return f"{next}{text}" @classmethod def from_text( diff --git a/lib/dns/rdtypes/ANY/RESINFO.py b/lib/dns/rdtypes/ANY/RESINFO.py new file mode 100644 index 00000000..76c8ea2a --- /dev/null +++ b/lib/dns/rdtypes/ANY/RESINFO.py @@ -0,0 +1,24 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +import dns.immutable +import dns.rdtypes.txtbase + + +@dns.immutable.immutable +class RESINFO(dns.rdtypes.txtbase.TXTBase): + """RESINFO record""" diff --git a/lib/dns/rdtypes/ANY/RP.py b/lib/dns/rdtypes/ANY/RP.py index 9b74549d..a66cfc50 100644 --- a/lib/dns/rdtypes/ANY/RP.py +++ b/lib/dns/rdtypes/ANY/RP.py @@ -37,7 +37,7 @@ class RP(dns.rdata.Rdata): def to_text(self, origin=None, relativize=True, **kw): mbox = self.mbox.choose_relativity(origin, relativize) txt = self.txt.choose_relativity(origin, relativize) - return "{} {}".format(str(mbox), str(txt)) + return f"{str(mbox)} {str(txt)}" @classmethod def from_text( diff --git a/lib/dns/rdtypes/ANY/TKEY.py b/lib/dns/rdtypes/ANY/TKEY.py index 5b490b82..75f62249 100644 --- a/lib/dns/rdtypes/ANY/TKEY.py +++ b/lib/dns/rdtypes/ANY/TKEY.py @@ -69,7 +69,7 @@ class TKEY(dns.rdata.Rdata): dns.rdata._base64ify(self.key, 0), ) if len(self.other) > 0: - text += " %s" % (dns.rdata._base64ify(self.other, 0)) + text += f" {dns.rdata._base64ify(self.other, 0)}" return text diff --git a/lib/dns/rdtypes/ANY/WALLET.py b/lib/dns/rdtypes/ANY/WALLET.py new file mode 100644 index 00000000..ff464763 --- /dev/null +++ b/lib/dns/rdtypes/ANY/WALLET.py @@ -0,0 +1,9 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +import dns.immutable +import dns.rdtypes.txtbase + + +@dns.immutable.immutable +class WALLET(dns.rdtypes.txtbase.TXTBase): + """WALLET record""" diff --git a/lib/dns/rdtypes/ANY/X25.py b/lib/dns/rdtypes/ANY/X25.py index 8375611d..2436ddb6 100644 --- a/lib/dns/rdtypes/ANY/X25.py +++ b/lib/dns/rdtypes/ANY/X25.py @@ -36,7 +36,7 @@ class X25(dns.rdata.Rdata): self.address = self._as_bytes(address, True, 255) def to_text(self, origin=None, relativize=True, **kw): - return '"%s"' % dns.rdata._escapify(self.address) + return f'"{dns.rdata._escapify(self.address)}"' @classmethod def from_text( diff --git a/lib/dns/rdtypes/ANY/__init__.py b/lib/dns/rdtypes/ANY/__init__.py index 3824a0a0..647b215b 100644 --- a/lib/dns/rdtypes/ANY/__init__.py +++ b/lib/dns/rdtypes/ANY/__init__.py @@ -51,6 +51,7 @@ __all__ = [ "OPENPGPKEY", "OPT", "PTR", + "RESINFO", "RP", "RRSIG", "RT", @@ -63,6 +64,7 @@ __all__ = [ "TSIG", "TXT", "URI", + "WALLET", "X25", "ZONEMD", ] diff --git a/lib/dns/rdtypes/CH/A.py b/lib/dns/rdtypes/CH/A.py index 583a88ac..832e8d3a 100644 --- a/lib/dns/rdtypes/CH/A.py +++ b/lib/dns/rdtypes/CH/A.py @@ -37,7 +37,7 @@ class A(dns.rdata.Rdata): def to_text(self, origin=None, relativize=True, **kw): domain = self.domain.choose_relativity(origin, relativize) - return "%s %o" % (domain, self.address) + return f"{domain} {self.address:o}" @classmethod def from_text( diff --git a/lib/dns/rdtypes/IN/NSAP.py b/lib/dns/rdtypes/IN/NSAP.py index a4854b3f..d55edb73 100644 --- a/lib/dns/rdtypes/IN/NSAP.py +++ b/lib/dns/rdtypes/IN/NSAP.py @@ -36,7 +36,7 @@ class NSAP(dns.rdata.Rdata): self.address = self._as_bytes(address) def to_text(self, origin=None, relativize=True, **kw): - return "0x%s" % binascii.hexlify(self.address).decode() + return f"0x{binascii.hexlify(self.address).decode()}" @classmethod def from_text( diff --git a/lib/dns/rdtypes/euibase.py b/lib/dns/rdtypes/euibase.py index 751087b4..a39c166b 100644 --- a/lib/dns/rdtypes/euibase.py +++ b/lib/dns/rdtypes/euibase.py @@ -36,7 +36,7 @@ class EUIBase(dns.rdata.Rdata): self.eui = self._as_bytes(eui) if len(self.eui) != self.byte_len: raise dns.exception.FormError( - "EUI%s rdata has to have %s bytes" % (self.byte_len * 8, self.byte_len) + f"EUI{self.byte_len * 8} rdata has to have {self.byte_len} bytes" ) def to_text(self, origin=None, relativize=True, **kw): @@ -49,16 +49,16 @@ class EUIBase(dns.rdata.Rdata): text = tok.get_string() if len(text) != cls.text_len: raise dns.exception.SyntaxError( - "Input text must have %s characters" % cls.text_len + f"Input text must have {cls.text_len} characters" ) for i in range(2, cls.byte_len * 3 - 1, 3): if text[i] != "-": - raise dns.exception.SyntaxError("Dash expected at position %s" % i) + raise dns.exception.SyntaxError(f"Dash expected at position {i}") text = text.replace("-", "") try: data = binascii.unhexlify(text.encode()) except (ValueError, TypeError) as ex: - raise dns.exception.SyntaxError("Hex decoding error: %s" % str(ex)) + raise dns.exception.SyntaxError(f"Hex decoding error: {str(ex)}") return cls(rdclass, rdtype, data) def _to_wire(self, file, compress=None, origin=None, canonicalize=False): diff --git a/lib/dns/rdtypes/svcbbase.py b/lib/dns/rdtypes/svcbbase.py index 05652413..a2b15b92 100644 --- a/lib/dns/rdtypes/svcbbase.py +++ b/lib/dns/rdtypes/svcbbase.py @@ -35,6 +35,7 @@ class ParamKey(dns.enum.IntEnum): ECH = 5 IPV6HINT = 6 DOHPATH = 7 + OHTTP = 8 @classmethod def _maximum(cls): @@ -396,6 +397,36 @@ class ECHParam(Param): file.write(self.ech) +@dns.immutable.immutable +class OHTTPParam(Param): + # We don't ever expect to instantiate this class, but we need + # a from_value() and a from_wire_parser(), so we just return None + # from the class methods when things are OK. + + @classmethod + def emptiness(cls): + return Emptiness.ALWAYS + + @classmethod + def from_value(cls, value): + if value is None or value == "": + return None + else: + raise ValueError("ohttp with non-empty value") + + def to_text(self): + raise NotImplementedError # pragma: no cover + + @classmethod + def from_wire_parser(cls, parser, origin=None): # pylint: disable=W0613 + if parser.remaining() != 0: + raise dns.exception.FormError + return None + + def to_wire(self, file, origin=None): # pylint: disable=W0613 + raise NotImplementedError # pragma: no cover + + _class_for_key = { ParamKey.MANDATORY: MandatoryParam, ParamKey.ALPN: ALPNParam, @@ -404,6 +435,7 @@ _class_for_key = { ParamKey.IPV4HINT: IPv4HintParam, ParamKey.ECH: ECHParam, ParamKey.IPV6HINT: IPv6HintParam, + ParamKey.OHTTP: OHTTPParam, } diff --git a/lib/dns/rdtypes/txtbase.py b/lib/dns/rdtypes/txtbase.py index 44d6df57..73db6d9e 100644 --- a/lib/dns/rdtypes/txtbase.py +++ b/lib/dns/rdtypes/txtbase.py @@ -50,6 +50,8 @@ class TXTBase(dns.rdata.Rdata): self.strings: Tuple[bytes] = self._as_tuple( strings, lambda x: self._as_bytes(x, True, 255) ) + if len(self.strings) == 0: + raise ValueError("the list of strings must not be empty") def to_text( self, @@ -60,7 +62,7 @@ class TXTBase(dns.rdata.Rdata): txt = "" prefix = "" for s in self.strings: - txt += '{}"{}"'.format(prefix, dns.rdata._escapify(s)) + txt += f'{prefix}"{dns.rdata._escapify(s)}"' prefix = " " return txt diff --git a/lib/dns/rdtypes/util.py b/lib/dns/rdtypes/util.py index 54908fdc..653a0bf2 100644 --- a/lib/dns/rdtypes/util.py +++ b/lib/dns/rdtypes/util.py @@ -231,7 +231,7 @@ def weighted_processing_order(iterable): total = sum(rdata._processing_weight() or _no_weight for rdata in rdatas) while len(rdatas) > 1: r = random.uniform(0, total) - for n, rdata in enumerate(rdatas): + for n, rdata in enumerate(rdatas): # noqa: B007 weight = rdata._processing_weight() or _no_weight if weight > r: break diff --git a/lib/dns/resolver.py b/lib/dns/resolver.py index f08f824d..3ba76e31 100644 --- a/lib/dns/resolver.py +++ b/lib/dns/resolver.py @@ -36,6 +36,7 @@ import dns.ipv4 import dns.ipv6 import dns.message import dns.name +import dns.rdata import dns.nameserver import dns.query import dns.rcode @@ -45,7 +46,7 @@ import dns.rdtypes.svcbbase import dns.reversename import dns.tsig -if sys.platform == "win32": +if sys.platform == "win32": # pragma: no cover import dns.win32util @@ -83,7 +84,7 @@ class NXDOMAIN(dns.exception.DNSException): else: msg = "The DNS query name does not exist" qnames = ", ".join(map(str, qnames)) - return "{}: {}".format(msg, qnames) + return f"{msg}: {qnames}" @property def canonical_name(self): @@ -96,7 +97,7 @@ class NXDOMAIN(dns.exception.DNSException): cname = response.canonical_name() if cname != qname: return cname - except Exception: + except Exception: # pragma: no cover # We can just eat this exception as it means there was # something wrong with the response. pass @@ -154,7 +155,7 @@ def _errors_to_text(errors: List[ErrorTuple]) -> List[str]: """Turn a resolution errors trace into a list of text.""" texts = [] for err in errors: - texts.append("Server {} answered {}".format(err[0], err[3])) + texts.append(f"Server {err[0]} answered {err[3]}") return texts @@ -162,7 +163,7 @@ class LifetimeTimeout(dns.exception.Timeout): """The resolution lifetime expired.""" msg = "The resolution lifetime expired." - fmt = "%s after {timeout:.3f} seconds: {errors}" % msg[:-1] + fmt = f"{msg[:-1]} after {{timeout:.3f}} seconds: {{errors}}" supp_kwargs = {"timeout", "errors"} # We do this as otherwise mypy complains about unexpected keyword argument @@ -211,7 +212,7 @@ class NoNameservers(dns.exception.DNSException): """ msg = "All nameservers failed to answer the query." - fmt = "%s {query}: {errors}" % msg[:-1] + fmt = f"{msg[:-1]} {{query}}: {{errors}}" supp_kwargs = {"request", "errors"} # We do this as otherwise mypy complains about unexpected keyword argument @@ -297,7 +298,7 @@ class Answer: def __len__(self) -> int: return self.rrset and len(self.rrset) or 0 - def __iter__(self): + def __iter__(self) -> Iterator[dns.rdata.Rdata]: return self.rrset and iter(self.rrset) or iter(tuple()) def __getitem__(self, i): @@ -334,7 +335,7 @@ class HostAnswers(Answers): answers[dns.rdatatype.A] = v4 return answers - # Returns pairs of (address, family) from this result, potentiallys + # Returns pairs of (address, family) from this result, potentially # filtering by address family. def addresses_and_families( self, family: int = socket.AF_UNSPEC @@ -347,7 +348,7 @@ class HostAnswers(Answers): answer = self.get(dns.rdatatype.AAAA) elif family == socket.AF_INET: answer = self.get(dns.rdatatype.A) - else: + else: # pragma: no cover raise NotImplementedError(f"unknown address family {family}") if answer: for rdata in answer: @@ -938,7 +939,7 @@ class BaseResolver: self.reset() if configure: - if sys.platform == "win32": + if sys.platform == "win32": # pragma: no cover self.read_registry() elif filename: self.read_resolv_conf(filename) @@ -947,7 +948,7 @@ class BaseResolver: """Reset all resolver configuration to the defaults.""" self.domain = dns.name.Name(dns.name.from_text(socket.gethostname())[1:]) - if len(self.domain) == 0: + if len(self.domain) == 0: # pragma: no cover self.domain = dns.name.root self._nameservers = [] self.nameserver_ports = {} @@ -1040,7 +1041,7 @@ class BaseResolver: # setter logic, with additonal checking and enrichment. self.nameservers = nameservers - def read_registry(self) -> None: + def read_registry(self) -> None: # pragma: no cover """Extract resolver configuration from the Windows registry.""" try: info = dns.win32util.get_dns_info() # type: ignore @@ -1205,9 +1206,7 @@ class BaseResolver: enriched_nameservers.append(enriched_nameserver) else: raise ValueError( - "nameservers must be a list or tuple (not a {})".format( - type(nameservers) - ) + f"nameservers must be a list or tuple (not a {type(nameservers)})" ) return enriched_nameservers @@ -1431,7 +1430,7 @@ class Resolver(BaseResolver): elif family == socket.AF_INET6: v6 = self.resolve(name, dns.rdatatype.AAAA, **modified_kwargs) return HostAnswers.make(v6=v6) - elif family != socket.AF_UNSPEC: + elif family != socket.AF_UNSPEC: # pragma: no cover raise NotImplementedError(f"unknown address family {family}") raise_on_no_answer = modified_kwargs.pop("raise_on_no_answer", True) @@ -1515,7 +1514,7 @@ class Resolver(BaseResolver): nameservers = dns._ddr._get_nameservers_sync(answer, timeout) if len(nameservers) > 0: self.nameservers = nameservers - except Exception: + except Exception: # pragma: no cover pass @@ -1640,7 +1639,7 @@ def canonical_name(name: Union[dns.name.Name, str]) -> dns.name.Name: return get_default_resolver().canonical_name(name) -def try_ddr(lifetime: float = 5.0) -> None: +def try_ddr(lifetime: float = 5.0) -> None: # pragma: no cover """Try to update the default resolver's nameservers using Discovery of Designated Resolvers (DDR). If successful, the resolver will subsequently use DNS-over-HTTPS or DNS-over-TLS for future queries. @@ -1926,7 +1925,7 @@ def _getnameinfo(sockaddr, flags=0): family = socket.AF_INET tuples = _getaddrinfo(host, port, family, socket.SOCK_STREAM, socket.SOL_TCP, 0) if len(tuples) > 1: - raise socket.error("sockaddr resolved to multiple addresses") + raise OSError("sockaddr resolved to multiple addresses") addr = tuples[0][4][0] if flags & socket.NI_DGRAM: pname = "udp" @@ -1961,7 +1960,7 @@ def _getfqdn(name=None): (name, _, _) = _gethostbyaddr(name) # Python's version checks aliases too, but our gethostbyname # ignores them, so we do so here as well. - except Exception: + except Exception: # pragma: no cover pass return name diff --git a/lib/dns/set.py b/lib/dns/set.py index f0fb0d50..ae8f0dd5 100644 --- a/lib/dns/set.py +++ b/lib/dns/set.py @@ -21,10 +21,11 @@ import itertools class Set: """A simple set class. - This class was originally used to deal with sets being missing in - ancient versions of python, but dnspython will continue to use it - as these sets are based on lists and are thus indexable, and this - ability is widely used in dnspython applications. + This class was originally used to deal with python not having a set class, and + originally the class used lists in its implementation. The ordered and indexable + nature of RRsets and Rdatasets is unfortunately widely used in dnspython + applications, so for backwards compatibility sets continue to be a custom class, now + based on an ordered dictionary. """ __slots__ = ["items"] @@ -43,7 +44,7 @@ class Set: self.add(item) # lgtm[py/init-calls-subclass] def __repr__(self): - return "dns.set.Set(%s)" % repr(list(self.items.keys())) + return f"dns.set.Set({repr(list(self.items.keys()))})" # pragma: no cover def add(self, item): """Add an item to the set.""" diff --git a/lib/dns/tokenizer.py b/lib/dns/tokenizer.py index 454cac4a..ab205bc3 100644 --- a/lib/dns/tokenizer.py +++ b/lib/dns/tokenizer.py @@ -528,7 +528,7 @@ class Tokenizer: if value < 0 or value > 65535: if base == 8: raise dns.exception.SyntaxError( - "%o is not an octal unsigned 16-bit integer" % value + f"{value:o} is not an octal unsigned 16-bit integer" ) else: raise dns.exception.SyntaxError( diff --git a/lib/dns/transaction.py b/lib/dns/transaction.py index 84e54f7d..aa2e1160 100644 --- a/lib/dns/transaction.py +++ b/lib/dns/transaction.py @@ -486,7 +486,7 @@ class Transaction: if exact: raise DeleteNotExact(f"{method}: missing rdataset") else: - self._delete_rdataset(name, rdtype, covers) + self._checked_delete_rdataset(name, rdtype, covers) return else: rdataset = self._rdataset_from_args(method, True, args) @@ -529,8 +529,6 @@ class Transaction: def _end(self, commit): self._check_ended() - if self._ended: - raise AlreadyEnded try: self._end_transaction(commit) finally: diff --git a/lib/dns/ttl.py b/lib/dns/ttl.py index 264b0338..b9a99fe3 100644 --- a/lib/dns/ttl.py +++ b/lib/dns/ttl.py @@ -73,7 +73,7 @@ def from_text(text: str) -> int: elif c == "s": total += current else: - raise BadTTL("unknown unit '%s'" % c) + raise BadTTL(f"unknown unit '{c}'") current = 0 need_digit = True if not current == 0: diff --git a/lib/dns/version.py b/lib/dns/version.py index 251f2583..9ed2ce19 100644 --- a/lib/dns/version.py +++ b/lib/dns/version.py @@ -20,9 +20,9 @@ #: MAJOR MAJOR = 2 #: MINOR -MINOR = 6 +MINOR = 7 #: MICRO -MICRO = 1 +MICRO = 0 #: RELEASELEVEL RELEASELEVEL = 0x0F #: SERIAL diff --git a/lib/dns/win32util.py b/lib/dns/win32util.py index aaa7e93e..9ed3f11b 100644 --- a/lib/dns/win32util.py +++ b/lib/dns/win32util.py @@ -13,8 +13,8 @@ if sys.platform == "win32": # Keep pylint quiet on non-windows. try: - WindowsError is None # pylint: disable=used-before-assignment - except KeyError: + _ = WindowsError # pylint: disable=used-before-assignment + except NameError: WindowsError = Exception if dns._features.have("wmi"): @@ -44,6 +44,7 @@ if sys.platform == "win32": if _have_wmi: class _WMIGetter(threading.Thread): + # pylint: disable=possibly-used-before-assignment def __init__(self): super().__init__() self.info = DnsInfo() @@ -82,32 +83,21 @@ if sys.platform == "win32": def __init__(self): self.info = DnsInfo() - def _determine_split_char(self, entry): - # - # The windows registry irritatingly changes the list element - # delimiter in between ' ' and ',' (and vice-versa) in various - # versions of windows. - # - if entry.find(" ") >= 0: - split_char = " " - elif entry.find(",") >= 0: - split_char = "," - else: - # probably a singleton; treat as a space-separated list. - split_char = " " - return split_char + def _split(self, text): + # The windows registry has used both " " and "," as a delimiter, and while + # it is currently using "," in Windows 10 and later, updates can seemingly + # leave a space in too, e.g. "a, b". So we just convert all commas to + # spaces, and use split() in its default configuration, which splits on + # all whitespace and ignores empty strings. + return text.replace(",", " ").split() def _config_nameservers(self, nameservers): - split_char = self._determine_split_char(nameservers) - ns_list = nameservers.split(split_char) - for ns in ns_list: + for ns in self._split(nameservers): if ns not in self.info.nameservers: self.info.nameservers.append(ns) def _config_search(self, search): - split_char = self._determine_split_char(search) - search_list = search.split(split_char) - for s in search_list: + for s in self._split(search): s = _config_domain(s) if s not in self.info.search: self.info.search.append(s) @@ -164,7 +154,7 @@ if sys.platform == "win32": lm, r"SYSTEM\CurrentControlSet\Control\Network" r"\{4D36E972-E325-11CE-BFC1-08002BE10318}" - r"\%s\Connection" % guid, + rf"\{guid}\Connection", ) try: @@ -177,7 +167,7 @@ if sys.platform == "win32": raise ValueError # pragma: no cover device_key = winreg.OpenKey( - lm, r"SYSTEM\CurrentControlSet\Enum\%s" % pnp_id + lm, rf"SYSTEM\CurrentControlSet\Enum\{pnp_id}" ) try: @@ -232,7 +222,7 @@ if sys.platform == "win32": self._config_fromkey(key, False) finally: key.Close() - except EnvironmentError: + except OSError: break finally: interfaces.Close() diff --git a/lib/dns/xfr.py b/lib/dns/xfr.py index dd247d33..520aa32d 100644 --- a/lib/dns/xfr.py +++ b/lib/dns/xfr.py @@ -33,7 +33,7 @@ class TransferError(dns.exception.DNSException): """A zone transfer response got a non-zero rcode.""" def __init__(self, rcode): - message = "Zone transfer error: %s" % dns.rcode.to_text(rcode) + message = f"Zone transfer error: {dns.rcode.to_text(rcode)}" super().__init__(message) self.rcode = rcode diff --git a/lib/dns/zonefile.py b/lib/dns/zonefile.py index af064e73..d74510b2 100644 --- a/lib/dns/zonefile.py +++ b/lib/dns/zonefile.py @@ -230,7 +230,7 @@ class Reader: try: rdtype = dns.rdatatype.from_text(token.value) except Exception: - raise dns.exception.SyntaxError("unknown rdatatype '%s'" % token.value) + raise dns.exception.SyntaxError(f"unknown rdatatype '{token.value}'") try: rd = dns.rdata.from_text( @@ -251,9 +251,7 @@ class Reader: # We convert them to syntax errors so that we can emit # helpful filename:line info. (ty, va) = sys.exc_info()[:2] - raise dns.exception.SyntaxError( - "caught exception {}: {}".format(str(ty), str(va)) - ) + raise dns.exception.SyntaxError(f"caught exception {str(ty)}: {str(va)}") if not self.default_ttl_known and rdtype == dns.rdatatype.SOA: # The pre-RFC2308 and pre-BIND9 behavior inherits the zone default @@ -281,41 +279,41 @@ class Reader: # Sometimes there are modifiers in the hostname. These come after # the dollar sign. They are in the form: ${offset[,width[,base]]}. # Make names + mod = "" + sign = "+" + offset = "0" + width = "0" + base = "d" g1 = is_generate1.match(side) if g1: mod, sign, offset, width, base = g1.groups() if sign == "": sign = "+" - g2 = is_generate2.match(side) - if g2: - mod, sign, offset = g2.groups() - if sign == "": - sign = "+" - width = 0 - base = "d" - g3 = is_generate3.match(side) - if g3: - mod, sign, offset, width = g3.groups() - if sign == "": - sign = "+" - base = "d" + else: + g2 = is_generate2.match(side) + if g2: + mod, sign, offset = g2.groups() + if sign == "": + sign = "+" + width = "0" + base = "d" + else: + g3 = is_generate3.match(side) + if g3: + mod, sign, offset, width = g3.groups() + if sign == "": + sign = "+" + base = "d" - if not (g1 or g2 or g3): - mod = "" - sign = "+" - offset = 0 - width = 0 - base = "d" - - offset = int(offset) - width = int(width) + ioffset = int(offset) + iwidth = int(width) if sign not in ["+", "-"]: - raise dns.exception.SyntaxError("invalid offset sign %s" % sign) + raise dns.exception.SyntaxError(f"invalid offset sign {sign}") if base not in ["d", "o", "x", "X", "n", "N"]: - raise dns.exception.SyntaxError("invalid type %s" % base) + raise dns.exception.SyntaxError(f"invalid type {base}") - return mod, sign, offset, width, base + return mod, sign, ioffset, iwidth, base def _generate_line(self): # range lhs [ttl] [class] type rhs [ comment ] @@ -377,7 +375,7 @@ class Reader: if not token.is_identifier(): raise dns.exception.SyntaxError except Exception: - raise dns.exception.SyntaxError("unknown rdatatype '%s'" % token.value) + raise dns.exception.SyntaxError(f"unknown rdatatype '{token.value}'") # rhs (required) rhs = token.value @@ -412,8 +410,8 @@ class Reader: lzfindex = _format_index(lindex, lbase, lwidth) rzfindex = _format_index(rindex, rbase, rwidth) - name = lhs.replace("$%s" % (lmod), lzfindex) - rdata = rhs.replace("$%s" % (rmod), rzfindex) + name = lhs.replace(f"${lmod}", lzfindex) + rdata = rhs.replace(f"${rmod}", rzfindex) self.last_name = dns.name.from_text( name, self.current_origin, self.tok.idna_codec @@ -445,7 +443,7 @@ class Reader: # helpful filename:line info. (ty, va) = sys.exc_info()[:2] raise dns.exception.SyntaxError( - "caught exception %s: %s" % (str(ty), str(va)) + f"caught exception {str(ty)}: {str(va)}" ) self.txn.add(name, ttl, rd) @@ -528,7 +526,7 @@ class Reader: self.default_ttl_known, ) ) - self.current_file = open(filename, "r") + self.current_file = open(filename) self.tok = dns.tokenizer.Tokenizer(self.current_file, filename) self.current_origin = new_origin elif c == "$GENERATE": diff --git a/requirements.txt b/requirements.txt index cfd8b7e5..9413a3f6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,7 +7,7 @@ cheroot==10.0.1 cherrypy==18.10.0 cloudinary==1.41.0 distro==1.9.0 -dnspython==2.6.1 +dnspython==2.7.0 facebook-sdk==3.1.0 future==1.0.0 ga4mp==2.0.4