diff --git a/lib/dns/_asyncio_backend.py b/lib/dns/_asyncio_backend.py index 2631228e..9d9ed369 100644 --- a/lib/dns/_asyncio_backend.py +++ b/lib/dns/_asyncio_backend.py @@ -7,7 +7,9 @@ import socket import sys import dns._asyncbackend +import dns._features import dns.exception +import dns.inet _is_win32 = sys.platform == "win32" @@ -121,7 +123,7 @@ class StreamSocket(dns._asyncbackend.StreamSocket): return self.writer.get_extra_info("peercert") -try: +if dns._features.have("doh"): import anyio import httpcore import httpcore._backends.anyio @@ -205,7 +207,7 @@ try: resolver, local_port, bootstrap_address, family ) -except ImportError: +else: _HTTPTransport = dns._asyncbackend.NullTransport # type: ignore @@ -224,14 +226,12 @@ class Backend(dns._asyncbackend.Backend): ssl_context=None, server_hostname=None, ): - if destination is None and socktype == socket.SOCK_DGRAM and _is_win32: - raise NotImplementedError( - "destinationless datagram sockets " - "are not supported by asyncio " - "on Windows" - ) loop = _get_running_loop() if socktype == socket.SOCK_DGRAM: + if _is_win32 and source is None: + # Win32 wants explicit binding before recvfrom(). This is the + # proper fix for [#637]. + source = (dns.inet.any_for_af(af), 0) transport, protocol = await loop.create_datagram_endpoint( _DatagramProtocol, source, @@ -266,7 +266,7 @@ class Backend(dns._asyncbackend.Backend): await asyncio.sleep(interval) def datagram_connection_required(self): - return _is_win32 + return False def get_transport_class(self): return _HTTPTransport diff --git a/lib/dns/_features.py b/lib/dns/_features.py new file mode 100644 index 00000000..03ccaa77 --- /dev/null +++ b/lib/dns/_features.py @@ -0,0 +1,92 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +import importlib.metadata +import itertools +import string +from typing import Dict, List, Tuple + + +def _tuple_from_text(version: str) -> Tuple: + text_parts = version.split(".") + int_parts = [] + for text_part in text_parts: + digit_prefix = "".join( + itertools.takewhile(lambda x: x in string.digits, text_part) + ) + try: + int_parts.append(int(digit_prefix)) + except Exception: + break + return tuple(int_parts) + + +def _version_check( + requirement: str, +) -> bool: + """Is the requirement fulfilled? + + The requirement must be of the form + + package>=version + """ + package, minimum = requirement.split(">=") + try: + version = importlib.metadata.version(package) + except Exception: + return False + t_version = _tuple_from_text(version) + t_minimum = _tuple_from_text(minimum) + if t_version < t_minimum: + return False + return True + + +_cache: Dict[str, bool] = {} + + +def have(feature: str) -> bool: + """Is *feature* available? + + This tests if all optional packages needed for the + feature are available and recent enough. + + Returns ``True`` if the feature is available, + and ``False`` if it is not or if metadata is + missing. + """ + value = _cache.get(feature) + if value is not None: + return value + requirements = _requirements.get(feature) + if requirements is None: + # we make a cache entry here for consistency not performance + _cache[feature] = False + return False + ok = True + for requirement in requirements: + if not _version_check(requirement): + ok = False + break + _cache[feature] = ok + return ok + + +def force(feature: str, enabled: bool) -> None: + """Force the status of *feature* to be *enabled*. + + This method is provided as a workaround for any cases + where importlib.metadata is ineffective, or for testing. + """ + _cache[feature] = enabled + + +_requirements: Dict[str, List[str]] = { + ### BEGIN generated requirements + "dnssec": ["cryptography>=41"], + "doh": ["httpcore>=1.0.0", "httpx>=0.26.0", "h2>=4.1.0"], + "doq": ["aioquic>=0.9.25"], + "idna": ["idna>=3.6"], + "trio": ["trio>=0.23"], + "wmi": ["wmi>=1.5.1"], + ### END generated requirements +} diff --git a/lib/dns/_trio_backend.py b/lib/dns/_trio_backend.py index 4d9fb820..398e3276 100644 --- a/lib/dns/_trio_backend.py +++ b/lib/dns/_trio_backend.py @@ -8,9 +8,13 @@ import trio import trio.socket # type: ignore import dns._asyncbackend +import dns._features import dns.exception import dns.inet +if not dns._features.have("trio"): + raise ImportError("trio not found or too old") + def _maybe_timeout(timeout): if timeout is not None: @@ -95,7 +99,7 @@ class StreamSocket(dns._asyncbackend.StreamSocket): raise NotImplementedError -try: +if dns._features.have("doh"): import httpcore import httpcore._backends.trio import httpx @@ -177,7 +181,7 @@ try: resolver, local_port, bootstrap_address, family ) -except ImportError: +else: _HTTPTransport = dns._asyncbackend.NullTransport # type: ignore diff --git a/lib/dns/asyncbackend.py b/lib/dns/asyncbackend.py index 07d50e1e..0ec58b06 100644 --- a/lib/dns/asyncbackend.py +++ b/lib/dns/asyncbackend.py @@ -32,7 +32,7 @@ def get_backend(name: str) -> Backend: *name*, a ``str``, the name of the backend. Currently the "trio" and "asyncio" backends are available. - Raises NotImplementError if an unknown backend name is specified. + Raises NotImplementedError if an unknown backend name is specified. """ # pylint: disable=import-outside-toplevel,redefined-outer-name backend = _backends.get(name) diff --git a/lib/dns/asyncquery.py b/lib/dns/asyncquery.py index ecf9c1a5..4d9ab9ae 100644 --- a/lib/dns/asyncquery.py +++ b/lib/dns/asyncquery.py @@ -41,7 +41,7 @@ from dns.query import ( NoDOQ, UDPMode, _compute_times, - _have_http2, + _make_dot_ssl_context, _matches_destination, _remaining, have_doh, @@ -120,6 +120,8 @@ async def receive_udp( request_mac: Optional[bytes] = b"", ignore_trailing: bool = False, raise_on_truncation: bool = False, + ignore_errors: bool = False, + query: Optional[dns.message.Message] = None, ) -> Any: """Read a DNS message from a UDP socket. @@ -133,22 +135,40 @@ async def receive_udp( """ wire = b"" - while 1: + while True: (wire, from_address) = await sock.recvfrom(65535, _timeout(expiration)) - if _matches_destination( + if not _matches_destination( sock.family, from_address, destination, ignore_unexpected ): - break - received_time = time.time() - r = dns.message.from_wire( - wire, - keyring=keyring, - request_mac=request_mac, - one_rr_per_rrset=one_rr_per_rrset, - ignore_trailing=ignore_trailing, - raise_on_truncation=raise_on_truncation, - ) - return (r, received_time, from_address) + continue + received_time = time.time() + try: + r = dns.message.from_wire( + wire, + keyring=keyring, + request_mac=request_mac, + one_rr_per_rrset=one_rr_per_rrset, + ignore_trailing=ignore_trailing, + raise_on_truncation=raise_on_truncation, + ) + except dns.message.Truncated as e: + # See the comment in query.py for details. + if ( + ignore_errors + and query is not None + and not query.is_response(e.message()) + ): + continue + else: + raise + except Exception: + if ignore_errors: + continue + else: + raise + if ignore_errors and query is not None and not query.is_response(r): + continue + return (r, received_time, from_address) async def udp( @@ -164,6 +184,7 @@ async def udp( raise_on_truncation: bool = False, sock: Optional[dns.asyncbackend.DatagramSocket] = None, backend: Optional[dns.asyncbackend.Backend] = None, + ignore_errors: bool = False, ) -> dns.message.Message: """Return the response obtained after sending a query via UDP. @@ -205,9 +226,13 @@ async def udp( q.mac, ignore_trailing, raise_on_truncation, + ignore_errors, + q, ) r.time = received_time - begin_time - if not q.is_response(r): + # We don't need to check q.is_response() if we are in ignore_errors mode + # as receive_udp() will have checked it. + if not (ignore_errors or q.is_response(r)): raise BadResponse return r @@ -225,6 +250,7 @@ async def udp_with_fallback( udp_sock: Optional[dns.asyncbackend.DatagramSocket] = None, tcp_sock: Optional[dns.asyncbackend.StreamSocket] = None, backend: Optional[dns.asyncbackend.Backend] = None, + ignore_errors: bool = False, ) -> Tuple[dns.message.Message, bool]: """Return the response to the query, trying UDP first and falling back to TCP if UDP results in a truncated response. @@ -260,6 +286,7 @@ async def udp_with_fallback( True, udp_sock, backend, + ignore_errors, ) return (response, False) except dns.message.Truncated: @@ -292,14 +319,12 @@ async def send_tcp( """ if isinstance(what, dns.message.Message): - wire = what.to_wire() + tcpmsg = what.to_wire(prepend_length=True) else: - wire = what - l = len(wire) - # copying the wire into tcpmsg is inefficient, but lets us - # avoid writev() or doing a short write that would get pushed - # onto the net - tcpmsg = struct.pack("!H", l) + wire + # copying the wire into tcpmsg is inefficient, but lets us + # avoid writev() or doing a short write that would get pushed + # onto the net + tcpmsg = len(what).to_bytes(2, "big") + what sent_time = time.time() await sock.sendall(tcpmsg, _timeout(expiration, sent_time)) return (len(tcpmsg), sent_time) @@ -418,6 +443,7 @@ async def tls( backend: Optional[dns.asyncbackend.Backend] = None, ssl_context: Optional[ssl.SSLContext] = None, server_hostname: Optional[str] = None, + verify: Union[bool, str] = True, ) -> dns.message.Message: """Return the response obtained after sending a query via TLS. @@ -439,11 +465,7 @@ async def tls( cm: contextlib.AbstractAsyncContextManager = NullContext(sock) else: if ssl_context is None: - # See the comment about ssl.create_default_context() in query.py - ssl_context = ssl.create_default_context() # lgtm[py/insecure-protocol] - ssl_context.minimum_version = ssl.TLSVersion.TLSv1_2 - if server_hostname is None: - ssl_context.check_hostname = False + ssl_context = _make_dot_ssl_context(server_hostname, verify) af = dns.inet.af_for_address(where) stuple = _source_tuple(af, source, source_port) dtuple = (where, port) @@ -538,7 +560,7 @@ async def https( transport = backend.get_transport_class()( local_address=local_address, http1=True, - http2=_have_http2, + http2=True, verify=verify, local_port=local_port, bootstrap_address=bootstrap_address, @@ -550,7 +572,7 @@ async def https( cm: contextlib.AbstractAsyncContextManager = NullContext(client) else: cm = httpx.AsyncClient( - http1=True, http2=_have_http2, verify=verify, transport=transport + http1=True, http2=True, verify=verify, transport=transport ) async with cm as the_client: diff --git a/lib/dns/dnssec.py b/lib/dns/dnssec.py index 2949f619..e49c3b79 100644 --- a/lib/dns/dnssec.py +++ b/lib/dns/dnssec.py @@ -27,6 +27,7 @@ import time from datetime import datetime from typing import Callable, Dict, List, Optional, Set, Tuple, Union, cast +import dns._features import dns.exception import dns.name import dns.node @@ -1169,7 +1170,7 @@ def _need_pyca(*args, **kwargs): ) # pragma: no cover -try: +if dns._features.have("dnssec"): from cryptography.exceptions import InvalidSignature from cryptography.hazmat.primitives.asymmetric import dsa # pylint: disable=W0611 from cryptography.hazmat.primitives.asymmetric import ec # pylint: disable=W0611 @@ -1184,20 +1185,20 @@ try: get_algorithm_cls_from_dnskey, ) from dns.dnssecalgs.base import GenericPrivateKey, GenericPublicKey -except ImportError: # pragma: no cover - validate = _need_pyca - validate_rrsig = _need_pyca - sign = _need_pyca - make_dnskey = _need_pyca - make_cdnskey = _need_pyca - _have_pyca = False -else: + validate = _validate # type: ignore validate_rrsig = _validate_rrsig # type: ignore sign = _sign make_dnskey = _make_dnskey make_cdnskey = _make_cdnskey _have_pyca = True +else: # pragma: no cover + validate = _need_pyca + validate_rrsig = _need_pyca + sign = _need_pyca + make_dnskey = _need_pyca + make_cdnskey = _need_pyca + _have_pyca = False ### BEGIN generated Algorithm constants diff --git a/lib/dns/dnssecalgs/__init__.py b/lib/dns/dnssecalgs/__init__.py index d1ffd519..3d9181a7 100644 --- a/lib/dns/dnssecalgs/__init__.py +++ b/lib/dns/dnssecalgs/__init__.py @@ -1,9 +1,12 @@ from typing import Dict, Optional, Tuple, Type, Union import dns.name +from dns.dnssecalgs.base import GenericPrivateKey +from dns.dnssectypes import Algorithm +from dns.exception import UnsupportedAlgorithm +from dns.rdtypes.ANY.DNSKEY import DNSKEY -try: - from dns.dnssecalgs.base import GenericPrivateKey +if dns._features.have("dnssec"): from dns.dnssecalgs.dsa import PrivateDSA, PrivateDSANSEC3SHA1 from dns.dnssecalgs.ecdsa import PrivateECDSAP256SHA256, PrivateECDSAP384SHA384 from dns.dnssecalgs.eddsa import PrivateED448, PrivateED25519 @@ -16,13 +19,9 @@ try: ) _have_cryptography = True -except ImportError: +else: _have_cryptography = False -from dns.dnssectypes import Algorithm -from dns.exception import UnsupportedAlgorithm -from dns.rdtypes.ANY.DNSKEY import DNSKEY - AlgorithmPrefix = Optional[Union[bytes, dns.name.Name]] algorithms: Dict[Tuple[Algorithm, AlgorithmPrefix], Type[GenericPrivateKey]] = {} diff --git a/lib/dns/edns.py b/lib/dns/edns.py index f05baac4..776e5eeb 100644 --- a/lib/dns/edns.py +++ b/lib/dns/edns.py @@ -17,6 +17,7 @@ """EDNS Options""" +import binascii import math import socket import struct @@ -58,7 +59,6 @@ class OptionType(dns.enum.IntEnum): class Option: - """Base class for all EDNS option types.""" def __init__(self, otype: Union[OptionType, str]): @@ -76,6 +76,9 @@ class Option: """ raise NotImplementedError # pragma: no cover + def to_text(self) -> str: + raise NotImplementedError # pragma: no cover + @classmethod def from_wire_parser(cls, otype: OptionType, parser: "dns.wire.Parser") -> "Option": """Build an EDNS option object from wire format. @@ -141,7 +144,6 @@ class Option: class GenericOption(Option): # lgtm[py/missing-equals] - """Generic Option Class This class is used for EDNS option types for which we have no better @@ -343,6 +345,8 @@ class EDECode(dns.enum.IntEnum): class EDEOption(Option): # lgtm[py/missing-equals] """Extended DNS Error (EDE, RFC8914)""" + _preserve_case = {"DNSKEY", "DS", "DNSSEC", "RRSIGs", "NSEC", "NXDOMAIN"} + def __init__(self, code: Union[EDECode, str], text: Optional[str] = None): """*code*, a ``dns.edns.EDECode`` or ``str``, the info code of the extended error. @@ -360,6 +364,13 @@ class EDEOption(Option): # lgtm[py/missing-equals] def to_text(self) -> str: output = f"EDE {self.code}" + if self.code in EDECode: + desc = EDECode.to_text(self.code) + desc = " ".join( + word if word in self._preserve_case else word.title() + for word in desc.split("_") + ) + output += f" ({desc})" if self.text is not None: output += f": {self.text}" return output @@ -392,9 +403,37 @@ class EDEOption(Option): # lgtm[py/missing-equals] return cls(code, btext) +class NSIDOption(Option): + def __init__(self, nsid: bytes): + super().__init__(OptionType.NSID) + self.nsid = nsid + + def to_wire(self, file: Any = None) -> Optional[bytes]: + if file: + file.write(self.nsid) + return None + else: + return self.nsid + + def to_text(self) -> str: + if all(c >= 0x20 and c <= 0x7E for c in self.nsid): + # All ASCII printable, so it's probably a string. + value = self.nsid.decode() + else: + value = binascii.hexlify(self.nsid).decode() + return f"NSID {value}" + + @classmethod + def from_wire_parser( + cls, otype: Union[OptionType, str], parser: dns.wire.Parser + ) -> Option: + return cls(parser.get_remaining()) + + _type_to_class: Dict[OptionType, Any] = { OptionType.ECS: ECSOption, OptionType.EDE: EDEOption, + OptionType.NSID: NSIDOption, } diff --git a/lib/dns/immutable.py b/lib/dns/immutable.py index cab8d6fb..36b0362c 100644 --- a/lib/dns/immutable.py +++ b/lib/dns/immutable.py @@ -1,24 +1,30 @@ # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license import collections.abc -from typing import Any +from typing import Any, Callable from dns._immutable_ctx import immutable @immutable class Dict(collections.abc.Mapping): # lgtm[py/missing-equals] - def __init__(self, dictionary: Any, no_copy: bool = False): + def __init__( + self, + dictionary: Any, + no_copy: bool = False, + map_factory: Callable[[], collections.abc.MutableMapping] = dict, + ): """Make an immutable dictionary from the specified dictionary. If *no_copy* is `True`, then *dictionary* will be wrapped instead of copied. Only set this if you are sure there will be no external references to the dictionary. """ - if no_copy and isinstance(dictionary, dict): + if no_copy and isinstance(dictionary, collections.abc.MutableMapping): self._odict = dictionary else: - self._odict = dict(dictionary) + self._odict = map_factory() + self._odict.update(dictionary) self._hash = None def __getitem__(self, key): diff --git a/lib/dns/inet.py b/lib/dns/inet.py index 02e925c6..4a03f996 100644 --- a/lib/dns/inet.py +++ b/lib/dns/inet.py @@ -178,3 +178,20 @@ def any_for_af(af): elif af == socket.AF_INET6: return "::" raise NotImplementedError(f"unknown address family {af}") + + +def canonicalize(text: str) -> str: + """Verify that *address* is a valid text form IPv4 or IPv6 address and return its + canonical text form. IPv6 addresses with scopes are rejected. + + *text*, a ``str``, the address in textual form. + + Raises ``ValueError`` if the text is not valid. + """ + try: + return dns.ipv6.canonicalize(text) + except Exception: + try: + return dns.ipv4.canonicalize(text) + except Exception: + raise ValueError diff --git a/lib/dns/ipv4.py b/lib/dns/ipv4.py index f549150a..65ee69c0 100644 --- a/lib/dns/ipv4.py +++ b/lib/dns/ipv4.py @@ -62,3 +62,16 @@ def inet_aton(text: Union[str, bytes]) -> bytes: return struct.pack("BBBB", *b) except Exception: raise dns.exception.SyntaxError + + +def canonicalize(text: Union[str, bytes]) -> str: + """Verify that *address* is a valid text form IPv4 address and return its + canonical text form. + + *text*, a ``str`` or ``bytes``, the IPv4 address in textual form. + + Raises ``dns.exception.SyntaxError`` if the text is not valid. + """ + # Note that inet_aton() only accepts canonial form, but we still run through + # inet_ntoa() to ensure the output is a str. + return dns.ipv4.inet_ntoa(dns.ipv4.inet_aton(text)) diff --git a/lib/dns/ipv6.py b/lib/dns/ipv6.py index 0cc3d868..44a10639 100644 --- a/lib/dns/ipv6.py +++ b/lib/dns/ipv6.py @@ -104,7 +104,7 @@ _colon_colon_end = re.compile(rb".*::$") def inet_aton(text: Union[str, bytes], ignore_scope: bool = False) -> bytes: """Convert an IPv6 address in text form to binary form. - *text*, a ``str``, the IPv6 address in textual form. + *text*, a ``str`` or ``bytes``, the IPv6 address in textual form. *ignore_scope*, a ``bool``. If ``True``, a scope will be ignored. If ``False``, the default, it is an error for a scope to be present. @@ -206,3 +206,14 @@ def is_mapped(address: bytes) -> bool: """ return address.startswith(_mapped_prefix) + + +def canonicalize(text: Union[str, bytes]) -> str: + """Verify that *address* is a valid text form IPv6 address and return its + canonical text form. Addresses with scopes are rejected. + + *text*, a ``str`` or ``bytes``, the IPv6 address in textual form. + + Raises ``dns.exception.SyntaxError`` if the text is not valid. + """ + return dns.ipv6.inet_ntoa(dns.ipv6.inet_aton(text)) diff --git a/lib/dns/message.py b/lib/dns/message.py index daae6363..44cacbd9 100644 --- a/lib/dns/message.py +++ b/lib/dns/message.py @@ -393,7 +393,7 @@ class Message: section_number = section section = self.section_from_number(section_number) elif isinstance(section, str): - section_number = MessageSection.from_text(section) + section_number = self._section_enum.from_text(section) section = self.section_from_number(section_number) else: section_number = self.section_number(section) @@ -489,6 +489,34 @@ class Message: rrset = None return rrset + def section_count(self, section: SectionType) -> int: + """Returns the number of records in the specified section. + + *section*, an ``int`` section number, a ``str`` section name, or one of + the section attributes of this message. This specifies the + the section of the message to count. For example:: + + my_message.section_count(my_message.answer) + my_message.section_count(dns.message.ANSWER) + my_message.section_count("ANSWER") + """ + + if isinstance(section, int): + section_number = section + section = self.section_from_number(section_number) + elif isinstance(section, str): + section_number = self._section_enum.from_text(section) + section = self.section_from_number(section_number) + else: + section_number = self.section_number(section) + count = sum(max(1, len(rrs)) for rrs in section) + if section_number == MessageSection.ADDITIONAL: + if self.opt is not None: + count += 1 + if self.tsig is not None: + count += 1 + return count + def _compute_opt_reserve(self) -> int: """Compute the size required for the OPT RR, padding excluded""" if not self.opt: @@ -527,6 +555,8 @@ class Message: max_size: int = 0, multi: bool = False, tsig_ctx: Optional[Any] = None, + prepend_length: bool = False, + prefer_truncation: bool = False, **kw: Dict[str, Any], ) -> bytes: """Return a string containing the message in DNS compressed wire @@ -549,6 +579,15 @@ class Message: *tsig_ctx*, a ``dns.tsig.HMACTSig`` or ``dns.tsig.GSSTSig`` object, the ongoing TSIG context, used when signing zone transfers. + *prepend_length*, a ``bool``, should be set to ``True`` if the caller + wants the message length prepended to the message itself. This is + useful for messages sent over TCP, TLS (DoT), or QUIC (DoQ). + + *prefer_truncation*, a ``bool``, should be set to ``True`` if the caller + wants the message to be truncated if it would otherwise exceed the + maximum length. If the truncation occurs before the additional section, + the TC bit will be set. + Raises ``dns.exception.TooBig`` if *max_size* was exceeded. Returns a ``bytes``. @@ -570,14 +609,21 @@ class Message: r.reserve(opt_reserve) tsig_reserve = self._compute_tsig_reserve() r.reserve(tsig_reserve) - for rrset in self.question: - r.add_question(rrset.name, rrset.rdtype, rrset.rdclass) - for rrset in self.answer: - r.add_rrset(dns.renderer.ANSWER, rrset, **kw) - for rrset in self.authority: - r.add_rrset(dns.renderer.AUTHORITY, rrset, **kw) - for rrset in self.additional: - r.add_rrset(dns.renderer.ADDITIONAL, rrset, **kw) + try: + for rrset in self.question: + r.add_question(rrset.name, rrset.rdtype, rrset.rdclass) + for rrset in self.answer: + r.add_rrset(dns.renderer.ANSWER, rrset, **kw) + for rrset in self.authority: + r.add_rrset(dns.renderer.AUTHORITY, rrset, **kw) + for rrset in self.additional: + r.add_rrset(dns.renderer.ADDITIONAL, rrset, **kw) + except dns.exception.TooBig: + if prefer_truncation: + if r.section < dns.renderer.ADDITIONAL: + r.flags |= dns.flags.TC + else: + raise r.release_reserved() if self.opt is not None: r.add_opt(self.opt, self.pad, opt_reserve, tsig_reserve) @@ -598,7 +644,10 @@ class Message: r.write_header() if multi: self.tsig_ctx = ctx - return r.get_wire() + wire = r.get_wire() + if prepend_length: + wire = len(wire).to_bytes(2, "big") + wire + return wire @staticmethod def _make_tsig( @@ -777,6 +826,8 @@ class Message: if request_payload is None: request_payload = payload self.request_payload = request_payload + if pad < 0: + raise ValueError("pad must be non-negative") self.pad = pad @property @@ -826,7 +877,7 @@ class Message: if wanted: self.ednsflags |= dns.flags.DO elif self.opt: - self.ednsflags &= ~dns.flags.DO + self.ednsflags &= ~int(dns.flags.DO) def rcode(self) -> dns.rcode.Rcode: """Return the rcode. @@ -1035,7 +1086,6 @@ def _message_factory_from_opcode(opcode): class _WireReader: - """Wire format reader. parser: the binary parser @@ -1335,7 +1385,6 @@ def from_wire( class _TextReader: - """Text format reader. tok: the tokenizer. @@ -1768,30 +1817,34 @@ def make_response( our_payload: int = 8192, fudge: int = 300, tsig_error: int = 0, + pad: Optional[int] = None, ) -> Message: """Make a message which is a response for the specified query. - The message returned is really a response skeleton; it has all - of the infrastructure required of a response, but none of the - content. + The message returned is really a response skeleton; it has all of the infrastructure + required of a response, but none of the content. - The response's question section is a shallow copy of the query's - question section, so the query's question RRsets should not be - changed. + The response's question section is a shallow copy of the query's question section, + so the query's question RRsets should not be changed. *query*, a ``dns.message.Message``, the query to respond to. *recursion_available*, a ``bool``, should RA be set in the response? - *our_payload*, an ``int``, the payload size to advertise in EDNS - responses. + *our_payload*, an ``int``, the payload size to advertise in EDNS responses. *fudge*, an ``int``, the TSIG time fudge. *tsig_error*, an ``int``, the TSIG error. - Returns a ``dns.message.Message`` object whose specific class is - appropriate for the query. For example, if query is a - ``dns.update.UpdateMessage``, response will be too. + *pad*, a non-negative ``int`` or ``None``. If 0, the default, do not pad; otherwise + if not ``None`` add padding bytes to make the message size a multiple of *pad*. + Note that if padding is non-zero, an EDNS PADDING option will always be added to the + message. If ``None``, add padding following RFC 8467, namely if the request is + padded, pad the response to 468 otherwise do not pad. + + Returns a ``dns.message.Message`` object whose specific class is appropriate for the + query. For example, if query is a ``dns.update.UpdateMessage``, response will be + too. """ if query.flags & dns.flags.QR: @@ -1804,7 +1857,13 @@ def make_response( response.set_opcode(query.opcode()) response.question = list(query.question) if query.edns >= 0: - response.use_edns(0, 0, our_payload, query.payload) + if pad is None: + # Set response padding per RFC 8467 + pad = 0 + for option in query.options: + if option.otype == dns.edns.OptionType.PADDING: + pad = 468 + response.use_edns(0, 0, our_payload, query.payload, pad=pad) if query.had_tsig: response.use_tsig( query.keyring, diff --git a/lib/dns/name.py b/lib/dns/name.py index f452bfed..22ccb392 100644 --- a/lib/dns/name.py +++ b/lib/dns/name.py @@ -20,21 +20,23 @@ import copy import encodings.idna # type: ignore +import functools import struct -from typing import Any, Dict, Iterable, Optional, Tuple, Union - -try: - import idna # type: ignore - - have_idna_2008 = True -except ImportError: # pragma: no cover - have_idna_2008 = False +from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union +import dns._features import dns.enum import dns.exception import dns.immutable import dns.wire +if dns._features.have("idna"): + import idna # type: ignore + + have_idna_2008 = True +else: # pragma: no cover + have_idna_2008 = False + CompressType = Dict["Name", int] @@ -128,6 +130,10 @@ class IDNAException(dns.exception.DNSException): super().__init__(*args, **kwargs) +class NeedSubdomainOfOrigin(dns.exception.DNSException): + """An absolute name was provided that is not a subdomain of the specified origin.""" + + _escaped = b'"().;\\@$' _escaped_text = '"().;\\@$' @@ -350,7 +356,6 @@ def _maybe_convert_to_binary(label: Union[bytes, str]) -> bytes: @dns.immutable.immutable class Name: - """A DNS name. The dns.name.Name class represents a DNS name as a tuple of @@ -843,6 +848,42 @@ class Name: raise NoParent return Name(self.labels[1:]) + def predecessor(self, origin: "Name", prefix_ok: bool = True) -> "Name": + """Return the maximal predecessor of *name* in the DNSSEC ordering in the zone + whose origin is *origin*, or return the longest name under *origin* if the + name is origin (i.e. wrap around to the longest name, which may still be + *origin* due to length considerations. + + The relativity of the name is preserved, so if this name is relative + then the method will return a relative name, and likewise if this name + is absolute then the predecessor will be absolute. + + *prefix_ok* indicates if prefixing labels is allowed, and + defaults to ``True``. Normally it is good to allow this, but if computing + a maximal predecessor at a zone cut point then ``False`` must be specified. + """ + return _handle_relativity_and_call( + _absolute_predecessor, self, origin, prefix_ok + ) + + def successor(self, origin: "Name", prefix_ok: bool = True) -> "Name": + """Return the minimal successor of *name* in the DNSSEC ordering in the zone + whose origin is *origin*, or return *origin* if the successor cannot be + computed due to name length limitations. + + Note that *origin* is returned in the "too long" cases because wrapping + around to the origin is how NSEC records express "end of the zone". + + The relativity of the name is preserved, so if this name is relative + then the method will return a relative name, and likewise if this name + is absolute then the successor will be absolute. + + *prefix_ok* indicates if prefixing a new minimal label is allowed, and + defaults to ``True``. Normally it is good to allow this, but if computing + a minimal successor at a zone cut point then ``False`` must be specified. + """ + return _handle_relativity_and_call(_absolute_successor, self, origin, prefix_ok) + #: The root name, '.' root = Name([b""]) @@ -1082,3 +1123,161 @@ def from_wire(message: bytes, current: int) -> Tuple[Name, int]: parser = dns.wire.Parser(message, current) name = from_wire_parser(parser) return (name, parser.current - current) + + +# RFC 4471 Support + +_MINIMAL_OCTET = b"\x00" +_MINIMAL_OCTET_VALUE = ord(_MINIMAL_OCTET) +_SUCCESSOR_PREFIX = Name([_MINIMAL_OCTET]) +_MAXIMAL_OCTET = b"\xff" +_MAXIMAL_OCTET_VALUE = ord(_MAXIMAL_OCTET) +_AT_SIGN_VALUE = ord("@") +_LEFT_SQUARE_BRACKET_VALUE = ord("[") + + +def _wire_length(labels): + return functools.reduce(lambda v, x: v + len(x) + 1, labels, 0) + + +def _pad_to_max_name(name): + needed = 255 - _wire_length(name.labels) + new_labels = [] + while needed > 64: + new_labels.append(_MAXIMAL_OCTET * 63) + needed -= 64 + if needed >= 2: + new_labels.append(_MAXIMAL_OCTET * (needed - 1)) + # Note we're already maximal in the needed == 1 case as while we'd like + # to add one more byte as a new label, we can't, as adding a new non-empty + # label requires at least 2 bytes. + new_labels = list(reversed(new_labels)) + new_labels.extend(name.labels) + return Name(new_labels) + + +def _pad_to_max_label(label, suffix_labels): + length = len(label) + # We have to subtract one here to account for the length byte of label. + remaining = 255 - _wire_length(suffix_labels) - length - 1 + if remaining <= 0: + # Shouldn't happen! + return label + needed = min(63 - length, remaining) + return label + _MAXIMAL_OCTET * needed + + +def _absolute_predecessor(name: Name, origin: Name, prefix_ok: bool) -> Name: + # This is the RFC 4471 predecessor algorithm using the "absolute method" of section + # 3.1.1. + # + # Our caller must ensure that the name and origin are absolute, and that name is a + # subdomain of origin. + if name == origin: + return _pad_to_max_name(name) + least_significant_label = name[0] + if least_significant_label == _MINIMAL_OCTET: + return name.parent() + least_octet = least_significant_label[-1] + suffix_labels = name.labels[1:] + if least_octet == _MINIMAL_OCTET_VALUE: + new_labels = [least_significant_label[:-1]] + else: + octets = bytearray(least_significant_label) + octet = octets[-1] + if octet == _LEFT_SQUARE_BRACKET_VALUE: + octet = _AT_SIGN_VALUE + else: + octet -= 1 + octets[-1] = octet + least_significant_label = bytes(octets) + new_labels = [_pad_to_max_label(least_significant_label, suffix_labels)] + new_labels.extend(suffix_labels) + name = Name(new_labels) + if prefix_ok: + return _pad_to_max_name(name) + else: + return name + + +def _absolute_successor(name: Name, origin: Name, prefix_ok: bool) -> Name: + # This is the RFC 4471 successor algorithm using the "absolute method" of section + # 3.1.2. + # + # Our caller must ensure that the name and origin are absolute, and that name is a + # subdomain of origin. + if prefix_ok: + # Try prefixing \000 as new label + try: + return _SUCCESSOR_PREFIX.concatenate(name) + except NameTooLong: + pass + while name != origin: + # Try extending the least significant label. + least_significant_label = name[0] + if len(least_significant_label) < 63: + # We may be able to extend the least label with a minimal additional byte. + # This is only "may" because we could have a maximal length name even though + # the least significant label isn't maximally long. + new_labels = [least_significant_label + _MINIMAL_OCTET] + new_labels.extend(name.labels[1:]) + try: + return dns.name.Name(new_labels) + except dns.name.NameTooLong: + pass + # We can't extend the label either, so we'll try to increment the least + # signficant non-maximal byte in it. + octets = bytearray(least_significant_label) + # We do this reversed iteration with an explicit indexing variable because + # if we find something to increment, we're going to want to truncate everything + # to the right of it. + for i in range(len(octets) - 1, -1, -1): + octet = octets[i] + if octet == _MAXIMAL_OCTET_VALUE: + # We can't increment this, so keep looking. + continue + # Finally, something we can increment. We have to apply a special rule for + # incrementing "@", sending it to "[", because RFC 4034 6.1 says that when + # comparing names, uppercase letters compare as if they were their + # lower-case equivalents. If we increment "@" to "A", then it would compare + # as "a", which is after "[", "\", "]", "^", "_", and "`", so we would have + # skipped the most minimal successor, namely "[". + if octet == _AT_SIGN_VALUE: + octet = _LEFT_SQUARE_BRACKET_VALUE + else: + octet += 1 + octets[i] = octet + # We can now truncate all of the maximal values we skipped (if any) + new_labels = [bytes(octets[: i + 1])] + new_labels.extend(name.labels[1:]) + # We haven't changed the length of the name, so the Name constructor will + # always work. + return Name(new_labels) + # We couldn't increment, so chop off the least significant label and try + # again. + name = name.parent() + + # We couldn't increment at all, so return the origin, as wrapping around is the + # DNSSEC way. + return origin + + +def _handle_relativity_and_call( + function: Callable[[Name, Name, bool], Name], + name: Name, + origin: Name, + prefix_ok: bool, +) -> Name: + # Make "name" absolute if needed, ensure that the origin is absolute, + # call function(), and then relativize the result if needed. + if not origin.is_absolute(): + raise NeedAbsoluteNameOrOrigin + relative = not name.is_absolute() + if relative: + name = name.derelativize(origin) + elif not name.is_subdomain(origin): + raise NeedSubdomainOfOrigin + result_name = function(name, origin, prefix_ok) + if relative: + result_name = result_name.relativize(origin) + return result_name diff --git a/lib/dns/nameserver.py b/lib/dns/nameserver.py index 5910139e..5dbb4e8b 100644 --- a/lib/dns/nameserver.py +++ b/lib/dns/nameserver.py @@ -115,6 +115,8 @@ class Do53Nameserver(AddressAndPortNameserver): raise_on_truncation=True, one_rr_per_rrset=one_rr_per_rrset, ignore_trailing=ignore_trailing, + ignore_errors=True, + ignore_unexpected=True, ) return response @@ -153,15 +155,25 @@ class Do53Nameserver(AddressAndPortNameserver): backend=backend, one_rr_per_rrset=one_rr_per_rrset, ignore_trailing=ignore_trailing, + ignore_errors=True, + ignore_unexpected=True, ) return response class DoHNameserver(Nameserver): - def __init__(self, url: str, bootstrap_address: Optional[str] = None): + def __init__( + self, + url: str, + bootstrap_address: Optional[str] = None, + verify: Union[bool, str] = True, + want_get: bool = False, + ): super().__init__() self.url = url self.bootstrap_address = bootstrap_address + self.verify = verify + self.want_get = want_get def kind(self): return "DoH" @@ -195,9 +207,13 @@ class DoHNameserver(Nameserver): request, self.url, timeout=timeout, + source=source, + source_port=source_port, bootstrap_address=self.bootstrap_address, one_rr_per_rrset=one_rr_per_rrset, ignore_trailing=ignore_trailing, + verify=self.verify, + post=(not self.want_get), ) async def async_query( @@ -215,15 +231,27 @@ class DoHNameserver(Nameserver): request, self.url, timeout=timeout, + source=source, + source_port=source_port, + bootstrap_address=self.bootstrap_address, one_rr_per_rrset=one_rr_per_rrset, ignore_trailing=ignore_trailing, + verify=self.verify, + post=(not self.want_get), ) class DoTNameserver(AddressAndPortNameserver): - def __init__(self, address: str, port: int = 853, hostname: Optional[str] = None): + def __init__( + self, + address: str, + port: int = 853, + hostname: Optional[str] = None, + verify: Union[bool, str] = True, + ): super().__init__(address, port) self.hostname = hostname + self.verify = verify def kind(self): return "DoT" @@ -246,6 +274,7 @@ class DoTNameserver(AddressAndPortNameserver): one_rr_per_rrset=one_rr_per_rrset, ignore_trailing=ignore_trailing, server_hostname=self.hostname, + verify=self.verify, ) async def async_query( @@ -267,6 +296,7 @@ class DoTNameserver(AddressAndPortNameserver): one_rr_per_rrset=one_rr_per_rrset, ignore_trailing=ignore_trailing, server_hostname=self.hostname, + verify=self.verify, ) diff --git a/lib/dns/node.py b/lib/dns/node.py index c670243c..de85a82d 100644 --- a/lib/dns/node.py +++ b/lib/dns/node.py @@ -70,7 +70,6 @@ class NodeKind(enum.Enum): class Node: - """A Node is a set of rdatasets. A node is either a CNAME node or an "other data" node. A CNAME diff --git a/lib/dns/query.py b/lib/dns/query.py index 0d711251..f0ee9161 100644 --- a/lib/dns/query.py +++ b/lib/dns/query.py @@ -22,12 +22,14 @@ import contextlib import enum import errno import os +import os.path import selectors import socket import struct import time from typing import Any, Dict, Optional, Tuple, Union +import dns._features import dns.exception import dns.inet import dns.message @@ -57,24 +59,14 @@ def _expiration_for_this_attempt(timeout, expiration): return min(time.time() + timeout, expiration) -_have_httpx = False -_have_http2 = False -try: - import httpcore +_have_httpx = dns._features.have("doh") +if _have_httpx: import httpcore._backends.sync import httpx _CoreNetworkBackend = httpcore.NetworkBackend _CoreSyncStream = httpcore._backends.sync.SyncStream - _have_httpx = True - try: - # See if http2 support is available. - with httpx.Client(http2=True): - _have_http2 = True - except Exception: - pass - class _NetworkBackend(_CoreNetworkBackend): def __init__(self, resolver, local_port, bootstrap_address, family): super().__init__() @@ -147,7 +139,7 @@ try: resolver, local_port, bootstrap_address, family ) -except ImportError: # pragma: no cover +else: class _HTTPTransport: # type: ignore def connect_tcp(self, host, port, timeout, local_address): @@ -161,6 +153,8 @@ try: except ImportError: # pragma: no cover class ssl: # type: ignore + CERT_NONE = 0 + class WantReadException(Exception): pass @@ -459,7 +453,7 @@ def https( transport = _HTTPTransport( local_address=local_address, http1=True, - http2=_have_http2, + http2=True, verify=verify, local_port=local_port, bootstrap_address=bootstrap_address, @@ -470,9 +464,7 @@ def https( if session: cm: contextlib.AbstractContextManager = contextlib.nullcontext(session) else: - cm = httpx.Client( - http1=True, http2=_have_http2, verify=verify, transport=transport - ) + cm = httpx.Client(http1=True, http2=True, verify=verify, transport=transport) with cm as session: # see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH # GET and POST examples @@ -577,6 +569,8 @@ def receive_udp( request_mac: Optional[bytes] = b"", ignore_trailing: bool = False, raise_on_truncation: bool = False, + ignore_errors: bool = False, + query: Optional[dns.message.Message] = None, ) -> Any: """Read a DNS message from a UDP socket. @@ -617,28 +611,58 @@ def receive_udp( ``(dns.message.Message, float, tuple)`` tuple of the received message, the received time, and the address where the message arrived from. + + *ignore_errors*, a ``bool``. If various format errors or response + mismatches occur, ignore them and keep listening for a valid response. + The default is ``False``. + + *query*, a ``dns.message.Message`` or ``None``. If not ``None`` and + *ignore_errors* is ``True``, check that the received message is a response + to this query, and if not keep listening for a valid response. """ wire = b"" while True: (wire, from_address) = _udp_recv(sock, 65535, expiration) - if _matches_destination( + if not _matches_destination( sock.family, from_address, destination, ignore_unexpected ): - break - received_time = time.time() - r = dns.message.from_wire( - wire, - keyring=keyring, - request_mac=request_mac, - one_rr_per_rrset=one_rr_per_rrset, - ignore_trailing=ignore_trailing, - raise_on_truncation=raise_on_truncation, - ) - if destination: - return (r, received_time) - else: - return (r, received_time, from_address) + continue + received_time = time.time() + try: + r = dns.message.from_wire( + wire, + keyring=keyring, + request_mac=request_mac, + one_rr_per_rrset=one_rr_per_rrset, + ignore_trailing=ignore_trailing, + raise_on_truncation=raise_on_truncation, + ) + except dns.message.Truncated as e: + # If we got Truncated and not FORMERR, we at least got the header with TC + # set, and very likely the question section, so we'll re-raise if the + # message seems to be a response as we need to know when truncation happens. + # We need to check that it seems to be a response as we don't want a random + # injected message with TC set to cause us to bail out. + if ( + ignore_errors + and query is not None + and not query.is_response(e.message()) + ): + continue + else: + raise + except Exception: + if ignore_errors: + continue + else: + raise + if ignore_errors and query is not None and not query.is_response(r): + continue + if destination: + return (r, received_time) + else: + return (r, received_time, from_address) def udp( @@ -653,6 +677,7 @@ def udp( ignore_trailing: bool = False, raise_on_truncation: bool = False, sock: Optional[Any] = None, + ignore_errors: bool = False, ) -> dns.message.Message: """Return the response obtained after sending a query via UDP. @@ -689,6 +714,10 @@ def udp( if a socket is provided, it must be a nonblocking datagram socket, and the *source* and *source_port* are ignored. + *ignore_errors*, a ``bool``. If various format errors or response + mismatches occur, ignore them and keep listening for a valid response. + The default is ``False``. + Returns a ``dns.message.Message``. """ @@ -713,9 +742,13 @@ def udp( q.mac, ignore_trailing, raise_on_truncation, + ignore_errors, + q, ) r.time = received_time - begin_time - if not q.is_response(r): + # We don't need to check q.is_response() if we are in ignore_errors mode + # as receive_udp() will have checked it. + if not (ignore_errors or q.is_response(r)): raise BadResponse return r assert ( @@ -735,48 +768,50 @@ def udp_with_fallback( ignore_trailing: bool = False, udp_sock: Optional[Any] = None, tcp_sock: Optional[Any] = None, + ignore_errors: bool = False, ) -> Tuple[dns.message.Message, bool]: """Return the response to the query, trying UDP first and falling back to TCP if UDP results in a truncated response. *q*, a ``dns.message.Message``, the query to send - *where*, a ``str`` containing an IPv4 or IPv6 address, where - to send the message. + *where*, a ``str`` containing an IPv4 or IPv6 address, where to send the message. - *timeout*, a ``float`` or ``None``, the number of seconds to wait before the - query times out. If ``None``, the default, wait forever. + *timeout*, a ``float`` or ``None``, the number of seconds to wait before the query + times out. If ``None``, the default, wait forever. *port*, an ``int``, the port send the message to. The default is 53. - *source*, a ``str`` containing an IPv4 or IPv6 address, specifying - the source address. The default is the wildcard address. + *source*, a ``str`` containing an IPv4 or IPv6 address, specifying the source + address. The default is the wildcard address. - *source_port*, an ``int``, the port from which to send the message. - The default is 0. + *source_port*, an ``int``, the port from which to send the message. The default is + 0. - *ignore_unexpected*, a ``bool``. If ``True``, ignore responses from - unexpected sources. + *ignore_unexpected*, a ``bool``. If ``True``, ignore responses from unexpected + sources. - *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own - RRset. + *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own RRset. - *ignore_trailing*, a ``bool``. If ``True``, ignore trailing - junk at end of the received message. + *ignore_trailing*, a ``bool``. If ``True``, ignore trailing junk at end of the + received message. - *udp_sock*, a ``socket.socket``, or ``None``, the socket to use for the - UDP query. If ``None``, the default, a socket is created. Note that - if a socket is provided, it must be a nonblocking datagram socket, - and the *source* and *source_port* are ignored for the UDP query. + *udp_sock*, a ``socket.socket``, or ``None``, the socket to use for the UDP query. + If ``None``, the default, a socket is created. Note that if a socket is provided, + it must be a nonblocking datagram socket, and the *source* and *source_port* are + ignored for the UDP query. *tcp_sock*, a ``socket.socket``, or ``None``, the connected socket to use for the - TCP query. If ``None``, the default, a socket is created. Note that - if a socket is provided, it must be a nonblocking connected stream - socket, and *where*, *source* and *source_port* are ignored for the TCP - query. + TCP query. If ``None``, the default, a socket is created. Note that if a socket is + provided, it must be a nonblocking connected stream socket, and *where*, *source* + and *source_port* are ignored for the TCP query. - Returns a (``dns.message.Message``, tcp) tuple where tcp is ``True`` - if and only if TCP was used. + *ignore_errors*, a ``bool``. If various format errors or response mismatches occur + while listening for UDP, ignore them and keep listening for a valid response. The + default is ``False``. + + Returns a (``dns.message.Message``, tcp) tuple where tcp is ``True`` if and only if + TCP was used. """ try: response = udp( @@ -791,6 +826,7 @@ def udp_with_fallback( ignore_trailing, True, udp_sock, + ignore_errors, ) return (response, False) except dns.message.Truncated: @@ -864,14 +900,12 @@ def send_tcp( """ if isinstance(what, dns.message.Message): - wire = what.to_wire() + tcpmsg = what.to_wire(prepend_length=True) else: - wire = what - l = len(wire) - # copying the wire into tcpmsg is inefficient, but lets us - # avoid writev() or doing a short write that would get pushed - # onto the net - tcpmsg = struct.pack("!H", l) + wire + # copying the wire into tcpmsg is inefficient, but lets us + # avoid writev() or doing a short write that would get pushed + # onto the net + tcpmsg = len(what).to_bytes(2, "big") + what sent_time = time.time() _net_write(sock, tcpmsg, expiration) return (len(tcpmsg), sent_time) @@ -1014,6 +1048,28 @@ def _tls_handshake(s, expiration): _wait_for_writable(s, expiration) +def _make_dot_ssl_context( + server_hostname: Optional[str], verify: Union[bool, str] +) -> ssl.SSLContext: + cafile: Optional[str] = None + capath: Optional[str] = None + if isinstance(verify, str): + if os.path.isfile(verify): + cafile = verify + elif os.path.isdir(verify): + capath = verify + else: + raise ValueError("invalid verify string") + ssl_context = ssl.create_default_context(cafile=cafile, capath=capath) + ssl_context.minimum_version = ssl.TLSVersion.TLSv1_2 + if server_hostname is None: + ssl_context.check_hostname = False + ssl_context.set_alpn_protocols(["dot"]) + if verify is False: + ssl_context.verify_mode = ssl.CERT_NONE + return ssl_context + + def tls( q: dns.message.Message, where: str, @@ -1026,6 +1082,7 @@ def tls( sock: Optional[ssl.SSLSocket] = None, ssl_context: Optional[ssl.SSLContext] = None, server_hostname: Optional[str] = None, + verify: Union[bool, str] = True, ) -> dns.message.Message: """Return the response obtained after sending a query via TLS. @@ -1065,6 +1122,11 @@ def tls( default is ``None``, which means that no hostname is known, and if an SSL context is created, hostname checking will be disabled. + *verify*, a ``bool`` or ``str``. If a ``True``, then TLS certificate verification + of the server is done using the default CA bundle; if ``False``, then no + verification is done; if a `str` then it specifies the path to a certificate file or + directory which will be used for verification. + Returns a ``dns.message.Message``. """ @@ -1091,10 +1153,7 @@ def tls( where, port, source, source_port ) if ssl_context is None and not sock: - ssl_context = ssl.create_default_context() - ssl_context.minimum_version = ssl.TLSVersion.TLSv1_2 - if server_hostname is None: - ssl_context.check_hostname = False + ssl_context = _make_dot_ssl_context(server_hostname, verify) with _make_socket( af, diff --git a/lib/dns/quic/__init__.py b/lib/dns/quic/__init__.py index 69813f9f..20aff345 100644 --- a/lib/dns/quic/__init__.py +++ b/lib/dns/quic/__init__.py @@ -1,9 +1,11 @@ # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license -try: +import dns._features +import dns.asyncbackend + +if dns._features.have("doq"): import aioquic.quic.configuration # type: ignore - import dns.asyncbackend from dns._asyncbackend import NullContext from dns.quic._asyncio import ( AsyncioQuicConnection, @@ -17,7 +19,7 @@ try: def null_factory( *args, # pylint: disable=unused-argument - **kwargs # pylint: disable=unused-argument + **kwargs, # pylint: disable=unused-argument ): return NullContext(None) @@ -31,7 +33,7 @@ try: _async_factories = {"asyncio": (null_factory, _asyncio_manager_factory)} - try: + if dns._features.have("trio"): import trio from dns.quic._trio import ( # pylint: disable=ungrouped-imports @@ -47,15 +49,13 @@ try: return TrioQuicManager(context, *args, **kwargs) _async_factories["trio"] = (_trio_context_factory, _trio_manager_factory) - except ImportError: - pass def factories_for_backend(backend=None): if backend is None: backend = dns.asyncbackend.get_default_backend() return _async_factories[backend.name()] -except ImportError: +else: # pragma: no cover have_quic = False from typing import Any diff --git a/lib/dns/quic/_asyncio.py b/lib/dns/quic/_asyncio.py index e1c52339..0f44331f 100644 --- a/lib/dns/quic/_asyncio.py +++ b/lib/dns/quic/_asyncio.py @@ -101,9 +101,7 @@ class AsyncioQuicConnection(AsyncQuicConnection): ) if address[0] != self._peer[0] or address[1] != self._peer[1]: continue - self._connection.receive_datagram( - datagram, self._peer[0], time.time() - ) + self._connection.receive_datagram(datagram, address, time.time()) # Wake up the timer in case the sender is sleeping, as there may be # stuff to send now. async with self._wake_timer: @@ -125,7 +123,7 @@ class AsyncioQuicConnection(AsyncQuicConnection): while not self._done: datagrams = self._connection.datagrams_to_send(time.time()) for datagram, address in datagrams: - assert address == self._peer[0] + assert address == self._peer await self._socket.sendto(datagram, self._peer, None) (expiration, interval) = self._get_timer_values() try: @@ -147,11 +145,14 @@ class AsyncioQuicConnection(AsyncQuicConnection): await stream._add_input(event.data, event.end_stream) elif isinstance(event, aioquic.quic.events.HandshakeCompleted): self._handshake_complete.set() - elif isinstance( - event, aioquic.quic.events.ConnectionTerminated - ) or isinstance(event, aioquic.quic.events.StreamReset): + elif isinstance(event, aioquic.quic.events.ConnectionTerminated): self._done = True self._receiver_task.cancel() + elif isinstance(event, aioquic.quic.events.StreamReset): + stream = self._streams.get(event.stream_id) + if stream: + await stream._add_input(b"", True) + count += 1 if count > 10: # yield @@ -188,7 +189,6 @@ class AsyncioQuicConnection(AsyncQuicConnection): self._connection.close() # sender might be blocked on this, so set it self._socket_created.set() - await self._socket.close() async with self._wake_timer: self._wake_timer.notify_all() try: @@ -199,14 +199,19 @@ class AsyncioQuicConnection(AsyncQuicConnection): await self._sender_task except asyncio.CancelledError: pass + await self._socket.close() class AsyncioQuicManager(AsyncQuicManager): def __init__(self, conf=None, verify_mode=ssl.CERT_REQUIRED, server_name=None): super().__init__(conf, verify_mode, AsyncioQuicConnection, server_name) - def connect(self, address, port=853, source=None, source_port=0): - (connection, start) = self._connect(address, port, source, source_port) + def connect( + self, address, port=853, source=None, source_port=0, want_session_ticket=True + ): + (connection, start) = self._connect( + address, port, source, source_port, want_session_ticket + ) if start: connection.run() return connection diff --git a/lib/dns/quic/_common.py b/lib/dns/quic/_common.py index 38ec103f..0eacc691 100644 --- a/lib/dns/quic/_common.py +++ b/lib/dns/quic/_common.py @@ -1,5 +1,7 @@ # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license +import copy +import functools import socket import struct import time @@ -11,6 +13,10 @@ import aioquic.quic.connection # type: ignore import dns.inet QUIC_MAX_DATAGRAM = 2048 +MAX_SESSION_TICKETS = 8 +# If we hit the max sessions limit we will delete this many of the oldest connections. +# The value must be a integer > 0 and <= MAX_SESSION_TICKETS. +SESSIONS_TO_DELETE = MAX_SESSION_TICKETS // 4 class UnexpectedEOF(Exception): @@ -79,7 +85,10 @@ class BaseQuicStream: def _common_add_input(self, data, is_end): self._buffer.put(data, is_end) - return self._expecting > 0 and self._buffer.have(self._expecting) + try: + return self._expecting > 0 and self._buffer.have(self._expecting) + except UnexpectedEOF: + return True def _close(self): self._connection.close_stream(self._stream_id) @@ -142,6 +151,7 @@ class BaseQuicManager: def __init__(self, conf, verify_mode, connection_factory, server_name=None): self._connections = {} self._connection_factory = connection_factory + self._session_tickets = {} if conf is None: verify_path = None if isinstance(verify_mode, str): @@ -156,12 +166,35 @@ class BaseQuicManager: conf.load_verify_locations(verify_path) self._conf = conf - def _connect(self, address, port=853, source=None, source_port=0): + def _connect( + self, address, port=853, source=None, source_port=0, want_session_ticket=True + ): connection = self._connections.get((address, port)) if connection is not None: return (connection, False) - qconn = aioquic.quic.connection.QuicConnection(configuration=self._conf) - qconn.connect(address, time.time()) + conf = self._conf + if want_session_ticket: + try: + session_ticket = self._session_tickets.pop((address, port)) + # We found a session ticket, so make a configuration that uses it. + conf = copy.copy(conf) + conf.session_ticket = session_ticket + except KeyError: + # No session ticket. + pass + # Whether or not we found a session ticket, we want a handler to save + # one. + session_ticket_handler = functools.partial( + self.save_session_ticket, address, port + ) + else: + session_ticket_handler = None + qconn = aioquic.quic.connection.QuicConnection( + configuration=conf, + session_ticket_handler=session_ticket_handler, + ) + lladdress = dns.inet.low_level_address_tuple((address, port)) + qconn.connect(lladdress, time.time()) connection = self._connection_factory( qconn, address, port, source, source_port, self ) @@ -174,6 +207,17 @@ class BaseQuicManager: except KeyError: pass + def save_session_ticket(self, address, port, ticket): + # We rely on dictionaries keys() being in insertion order here. We + # can't just popitem() as that would be LIFO which is the opposite of + # what we want. + l = len(self._session_tickets) + if l >= MAX_SESSION_TICKETS: + keys_to_delete = list(self._session_tickets.keys())[0:SESSIONS_TO_DELETE] + for key in keys_to_delete: + del self._session_tickets[key] + self._session_tickets[(address, port)] = ticket + class AsyncQuicManager(BaseQuicManager): def connect(self, address, port=853, source=None, source_port=0): diff --git a/lib/dns/quic/_sync.py b/lib/dns/quic/_sync.py index e944784d..120cb5f3 100644 --- a/lib/dns/quic/_sync.py +++ b/lib/dns/quic/_sync.py @@ -82,10 +82,6 @@ class SyncQuicConnection(BaseQuicConnection): def __init__(self, connection, address, port, source, source_port, manager): super().__init__(connection, address, port, source, source_port, manager) self._socket = socket.socket(self._af, socket.SOCK_DGRAM, 0) - self._socket.connect(self._peer) - (self._send_wakeup, self._receive_wakeup) = socket.socketpair() - self._receive_wakeup.setblocking(False) - self._socket.setblocking(False) if self._source is not None: try: self._socket.bind( @@ -94,6 +90,10 @@ class SyncQuicConnection(BaseQuicConnection): except Exception: self._socket.close() raise + self._socket.connect(self._peer) + (self._send_wakeup, self._receive_wakeup) = socket.socketpair() + self._receive_wakeup.setblocking(False) + self._socket.setblocking(False) self._handshake_complete = threading.Event() self._worker_thread = None self._lock = threading.Lock() @@ -107,7 +107,7 @@ class SyncQuicConnection(BaseQuicConnection): except BlockingIOError: return with self._lock: - self._connection.receive_datagram(datagram, self._peer[0], time.time()) + self._connection.receive_datagram(datagram, self._peer, time.time()) def _drain_wakeup(self): while True: @@ -128,6 +128,8 @@ class SyncQuicConnection(BaseQuicConnection): key.data() with self._lock: self._handle_timer(expiration) + self._handle_events() + with self._lock: datagrams = self._connection.datagrams_to_send(time.time()) for datagram, _ in datagrams: try: @@ -135,7 +137,6 @@ class SyncQuicConnection(BaseQuicConnection): except BlockingIOError: # we let QUIC handle any lossage pass - self._handle_events() finally: with self._lock: self._done = True @@ -155,11 +156,14 @@ class SyncQuicConnection(BaseQuicConnection): stream._add_input(event.data, event.end_stream) elif isinstance(event, aioquic.quic.events.HandshakeCompleted): self._handshake_complete.set() - elif isinstance( - event, aioquic.quic.events.ConnectionTerminated - ) or isinstance(event, aioquic.quic.events.StreamReset): + elif isinstance(event, aioquic.quic.events.ConnectionTerminated): with self._lock: self._done = True + elif isinstance(event, aioquic.quic.events.StreamReset): + with self._lock: + stream = self._streams.get(event.stream_id) + if stream: + stream._add_input(b"", True) def write(self, stream, data, is_end=False): with self._lock: @@ -203,9 +207,13 @@ class SyncQuicManager(BaseQuicManager): super().__init__(conf, verify_mode, SyncQuicConnection, server_name) self._lock = threading.Lock() - def connect(self, address, port=853, source=None, source_port=0): + def connect( + self, address, port=853, source=None, source_port=0, want_session_ticket=True + ): with self._lock: - (connection, start) = self._connect(address, port, source, source_port) + (connection, start) = self._connect( + address, port, source, source_port, want_session_ticket + ) if start: connection.run() return connection @@ -214,6 +222,10 @@ class SyncQuicManager(BaseQuicManager): with self._lock: super().closed(address, port) + def save_session_ticket(self, address, port, ticket): + with self._lock: + super().save_session_ticket(address, port, ticket) + def __enter__(self): return self diff --git a/lib/dns/quic/_trio.py b/lib/dns/quic/_trio.py index ee07e4f6..35e36b98 100644 --- a/lib/dns/quic/_trio.py +++ b/lib/dns/quic/_trio.py @@ -76,30 +76,43 @@ class TrioQuicConnection(AsyncQuicConnection): def __init__(self, connection, address, port, source, source_port, manager=None): super().__init__(connection, address, port, source, source_port, manager) self._socket = trio.socket.socket(self._af, socket.SOCK_DGRAM, 0) - if self._source: - trio.socket.bind(dns.inet.low_level_address_tuple(self._source, self._af)) self._handshake_complete = trio.Event() self._run_done = trio.Event() self._worker_scope = None + self._send_pending = False async def _worker(self): try: + if self._source: + await self._socket.bind( + dns.inet.low_level_address_tuple(self._source, self._af) + ) await self._socket.connect(self._peer) while not self._done: (expiration, interval) = self._get_timer_values(False) + if self._send_pending: + # Do not block forever if sends are pending. Even though we + # have a wake-up mechanism if we've already started the blocking + # read, the possibility of context switching in send means that + # more writes can happen while we have no wake up context, so + # we need self._send_pending to avoid (effectively) a "lost wakeup" + # race. + interval = 0.0 with trio.CancelScope( deadline=trio.current_time() + interval ) as self._worker_scope: datagram = await self._socket.recv(QUIC_MAX_DATAGRAM) - self._connection.receive_datagram( - datagram, self._peer[0], time.time() - ) + self._connection.receive_datagram(datagram, self._peer, time.time()) self._worker_scope = None self._handle_timer(expiration) + await self._handle_events() + # We clear this now, before sending anything, as sending can cause + # context switches that do more sends. We want to know if that + # happens so we don't block a long time on the recv() above. + self._send_pending = False datagrams = self._connection.datagrams_to_send(time.time()) for datagram, _ in datagrams: await self._socket.send(datagram) - await self._handle_events() finally: self._done = True self._handshake_complete.set() @@ -116,11 +129,13 @@ class TrioQuicConnection(AsyncQuicConnection): await stream._add_input(event.data, event.end_stream) elif isinstance(event, aioquic.quic.events.HandshakeCompleted): self._handshake_complete.set() - elif isinstance( - event, aioquic.quic.events.ConnectionTerminated - ) or isinstance(event, aioquic.quic.events.StreamReset): + elif isinstance(event, aioquic.quic.events.ConnectionTerminated): self._done = True self._socket.close() + elif isinstance(event, aioquic.quic.events.StreamReset): + stream = self._streams.get(event.stream_id) + if stream: + await stream._add_input(b"", True) count += 1 if count > 10: # yield @@ -129,6 +144,7 @@ class TrioQuicConnection(AsyncQuicConnection): async def write(self, stream, data, is_end=False): self._connection.send_stream_data(stream, data, is_end) + self._send_pending = True if self._worker_scope is not None: self._worker_scope.cancel() @@ -159,6 +175,7 @@ class TrioQuicConnection(AsyncQuicConnection): self._manager.closed(self._peer[0], self._peer[1]) self._closed = True self._connection.close() + self._send_pending = True if self._worker_scope is not None: self._worker_scope.cancel() await self._run_done.wait() @@ -171,8 +188,12 @@ class TrioQuicManager(AsyncQuicManager): super().__init__(conf, verify_mode, TrioQuicConnection, server_name) self._nursery = nursery - def connect(self, address, port=853, source=None, source_port=0): - (connection, start) = self._connect(address, port, source, source_port) + def connect( + self, address, port=853, source=None, source_port=0, want_session_ticket=True + ): + (connection, start) = self._connect( + address, port, source, source_port, want_session_ticket + ) if start: self._nursery.start_soon(connection.run) return connection diff --git a/lib/dns/rdata.py b/lib/dns/rdata.py index 0d262e8d..024fd8f6 100644 --- a/lib/dns/rdata.py +++ b/lib/dns/rdata.py @@ -199,7 +199,7 @@ class Rdata: self, origin: Optional[dns.name.Name] = None, relativize: bool = True, - **kw: Dict[str, Any] + **kw: Dict[str, Any], ) -> str: """Convert an rdata to text format. @@ -547,9 +547,7 @@ class Rdata: @classmethod def _as_ipv4_address(cls, value): if isinstance(value, str): - # call to check validity - dns.ipv4.inet_aton(value) - return value + return dns.ipv4.canonicalize(value) elif isinstance(value, bytes): return dns.ipv4.inet_ntoa(value) else: @@ -558,9 +556,7 @@ class Rdata: @classmethod def _as_ipv6_address(cls, value): if isinstance(value, str): - # call to check validity - dns.ipv6.inet_aton(value) - return value + return dns.ipv6.canonicalize(value) elif isinstance(value, bytes): return dns.ipv6.inet_ntoa(value) else: @@ -604,7 +600,6 @@ class Rdata: @dns.immutable.immutable class GenericRdata(Rdata): - """Generic Rdata Class This class is used for rdata types for which we have no better @@ -621,7 +616,7 @@ class GenericRdata(Rdata): self, origin: Optional[dns.name.Name] = None, relativize: bool = True, - **kw: Dict[str, Any] + **kw: Dict[str, Any], ) -> str: return r"\# %d " % len(self.data) + _hexify(self.data, **kw) @@ -647,9 +642,9 @@ class GenericRdata(Rdata): return cls(rdclass, rdtype, parser.get_remaining()) -_rdata_classes: Dict[ - Tuple[dns.rdataclass.RdataClass, dns.rdatatype.RdataType], Any -] = {} +_rdata_classes: Dict[Tuple[dns.rdataclass.RdataClass, dns.rdatatype.RdataType], Any] = ( + {} +) _module_prefix = "dns.rdtypes" diff --git a/lib/dns/rdataset.py b/lib/dns/rdataset.py index 31124afc..8bff58d7 100644 --- a/lib/dns/rdataset.py +++ b/lib/dns/rdataset.py @@ -28,6 +28,7 @@ import dns.name import dns.rdata import dns.rdataclass import dns.rdatatype +import dns.renderer import dns.set import dns.ttl @@ -45,7 +46,6 @@ class IncompatibleTypes(dns.exception.DNSException): class Rdataset(dns.set.Set): - """A DNS rdataset.""" __slots__ = ["rdclass", "rdtype", "covers", "ttl"] @@ -316,11 +316,9 @@ class Rdataset(dns.set.Set): want_shuffle = False else: rdclass = self.rdclass - file.seek(0, io.SEEK_END) if len(self) == 0: name.to_wire(file, compress, origin) - stuff = struct.pack("!HHIH", self.rdtype, rdclass, 0, 0) - file.write(stuff) + file.write(struct.pack("!HHIH", self.rdtype, rdclass, 0, 0)) return 1 else: l: Union[Rdataset, List[dns.rdata.Rdata]] @@ -331,16 +329,9 @@ class Rdataset(dns.set.Set): l = self for rd in l: name.to_wire(file, compress, origin) - stuff = struct.pack("!HHIH", self.rdtype, rdclass, self.ttl, 0) - file.write(stuff) - start = file.tell() - rd.to_wire(file, compress, origin) - end = file.tell() - assert end - start < 65536 - file.seek(start - 2) - stuff = struct.pack("!H", end - start) - file.write(stuff) - file.seek(0, io.SEEK_END) + file.write(struct.pack("!HHI", self.rdtype, rdclass, self.ttl)) + with dns.renderer.prefixed_length(file, 2): + rd.to_wire(file, compress, origin) return len(self) def match( @@ -373,7 +364,6 @@ class Rdataset(dns.set.Set): @dns.immutable.immutable class ImmutableRdataset(Rdataset): # lgtm[py/missing-equals] - """An immutable DNS rdataset.""" _clone_class = Rdataset diff --git a/lib/dns/rdtypes/ANY/AFSDB.py b/lib/dns/rdtypes/ANY/AFSDB.py index 3d287f6e..06a3b970 100644 --- a/lib/dns/rdtypes/ANY/AFSDB.py +++ b/lib/dns/rdtypes/ANY/AFSDB.py @@ -21,7 +21,6 @@ import dns.rdtypes.mxbase @dns.immutable.immutable class AFSDB(dns.rdtypes.mxbase.UncompressedDowncasingMX): - """AFSDB record""" # Use the property mechanism to make "subtype" an alias for the diff --git a/lib/dns/rdtypes/ANY/AMTRELAY.py b/lib/dns/rdtypes/ANY/AMTRELAY.py index dfe7abc3..ed2b072b 100644 --- a/lib/dns/rdtypes/ANY/AMTRELAY.py +++ b/lib/dns/rdtypes/ANY/AMTRELAY.py @@ -32,7 +32,6 @@ class Relay(dns.rdtypes.util.Gateway): @dns.immutable.immutable class AMTRELAY(dns.rdata.Rdata): - """AMTRELAY record""" # see: RFC 8777 diff --git a/lib/dns/rdtypes/ANY/AVC.py b/lib/dns/rdtypes/ANY/AVC.py index 766d5e2d..a27ae2d6 100644 --- a/lib/dns/rdtypes/ANY/AVC.py +++ b/lib/dns/rdtypes/ANY/AVC.py @@ -21,7 +21,6 @@ import dns.rdtypes.txtbase @dns.immutable.immutable class AVC(dns.rdtypes.txtbase.TXTBase): - """AVC record""" # See: IANA dns parameters for AVC diff --git a/lib/dns/rdtypes/ANY/CAA.py b/lib/dns/rdtypes/ANY/CAA.py index 8afb538c..2e6a7e7e 100644 --- a/lib/dns/rdtypes/ANY/CAA.py +++ b/lib/dns/rdtypes/ANY/CAA.py @@ -25,7 +25,6 @@ import dns.tokenizer @dns.immutable.immutable class CAA(dns.rdata.Rdata): - """CAA (Certification Authority Authorization) record""" # see: RFC 6844 diff --git a/lib/dns/rdtypes/ANY/CDNSKEY.py b/lib/dns/rdtypes/ANY/CDNSKEY.py index 38b8a8da..b613409f 100644 --- a/lib/dns/rdtypes/ANY/CDNSKEY.py +++ b/lib/dns/rdtypes/ANY/CDNSKEY.py @@ -30,5 +30,4 @@ from dns.rdtypes.dnskeybase import ( # noqa: F401 lgtm[py/unused-import] @dns.immutable.immutable class CDNSKEY(dns.rdtypes.dnskeybase.DNSKEYBase): - """CDNSKEY record""" diff --git a/lib/dns/rdtypes/ANY/CDS.py b/lib/dns/rdtypes/ANY/CDS.py index 2ff42d9a..8312b972 100644 --- a/lib/dns/rdtypes/ANY/CDS.py +++ b/lib/dns/rdtypes/ANY/CDS.py @@ -21,7 +21,6 @@ import dns.rdtypes.dsbase @dns.immutable.immutable class CDS(dns.rdtypes.dsbase.DSBase): - """CDS record""" _digest_length_by_type = { diff --git a/lib/dns/rdtypes/ANY/CERT.py b/lib/dns/rdtypes/ANY/CERT.py index 30fe863f..f369cc85 100644 --- a/lib/dns/rdtypes/ANY/CERT.py +++ b/lib/dns/rdtypes/ANY/CERT.py @@ -67,7 +67,6 @@ def _ctype_to_text(what): @dns.immutable.immutable class CERT(dns.rdata.Rdata): - """CERT record""" # see RFC 4398 diff --git a/lib/dns/rdtypes/ANY/CNAME.py b/lib/dns/rdtypes/ANY/CNAME.py index 759adb90..665e407c 100644 --- a/lib/dns/rdtypes/ANY/CNAME.py +++ b/lib/dns/rdtypes/ANY/CNAME.py @@ -21,7 +21,6 @@ import dns.rdtypes.nsbase @dns.immutable.immutable class CNAME(dns.rdtypes.nsbase.NSBase): - """CNAME record Note: although CNAME is officially a singleton type, dnspython allows diff --git a/lib/dns/rdtypes/ANY/CSYNC.py b/lib/dns/rdtypes/ANY/CSYNC.py index 315da9ff..2f972f6e 100644 --- a/lib/dns/rdtypes/ANY/CSYNC.py +++ b/lib/dns/rdtypes/ANY/CSYNC.py @@ -32,7 +32,6 @@ class Bitmap(dns.rdtypes.util.Bitmap): @dns.immutable.immutable class CSYNC(dns.rdata.Rdata): - """CSYNC record""" __slots__ = ["serial", "flags", "windows"] diff --git a/lib/dns/rdtypes/ANY/DLV.py b/lib/dns/rdtypes/ANY/DLV.py index 632e90f8..6c134f18 100644 --- a/lib/dns/rdtypes/ANY/DLV.py +++ b/lib/dns/rdtypes/ANY/DLV.py @@ -21,5 +21,4 @@ import dns.rdtypes.dsbase @dns.immutable.immutable class DLV(dns.rdtypes.dsbase.DSBase): - """DLV record""" diff --git a/lib/dns/rdtypes/ANY/DNAME.py b/lib/dns/rdtypes/ANY/DNAME.py index 556bff59..bbf9186c 100644 --- a/lib/dns/rdtypes/ANY/DNAME.py +++ b/lib/dns/rdtypes/ANY/DNAME.py @@ -21,7 +21,6 @@ import dns.rdtypes.nsbase @dns.immutable.immutable class DNAME(dns.rdtypes.nsbase.UncompressedNS): - """DNAME record""" def _to_wire(self, file, compress=None, origin=None, canonicalize=False): diff --git a/lib/dns/rdtypes/ANY/DNSKEY.py b/lib/dns/rdtypes/ANY/DNSKEY.py index f1a63062..6d961a9f 100644 --- a/lib/dns/rdtypes/ANY/DNSKEY.py +++ b/lib/dns/rdtypes/ANY/DNSKEY.py @@ -30,5 +30,4 @@ from dns.rdtypes.dnskeybase import ( # noqa: F401 lgtm[py/unused-import] @dns.immutable.immutable class DNSKEY(dns.rdtypes.dnskeybase.DNSKEYBase): - """DNSKEY record""" diff --git a/lib/dns/rdtypes/ANY/DS.py b/lib/dns/rdtypes/ANY/DS.py index 097ecfa0..58b3108d 100644 --- a/lib/dns/rdtypes/ANY/DS.py +++ b/lib/dns/rdtypes/ANY/DS.py @@ -21,5 +21,4 @@ import dns.rdtypes.dsbase @dns.immutable.immutable class DS(dns.rdtypes.dsbase.DSBase): - """DS record""" diff --git a/lib/dns/rdtypes/ANY/EUI48.py b/lib/dns/rdtypes/ANY/EUI48.py index 7e4e1ff3..c843be50 100644 --- a/lib/dns/rdtypes/ANY/EUI48.py +++ b/lib/dns/rdtypes/ANY/EUI48.py @@ -22,7 +22,6 @@ import dns.rdtypes.euibase @dns.immutable.immutable class EUI48(dns.rdtypes.euibase.EUIBase): - """EUI48 record""" # see: rfc7043.txt diff --git a/lib/dns/rdtypes/ANY/EUI64.py b/lib/dns/rdtypes/ANY/EUI64.py index 68b5820f..f6d7e257 100644 --- a/lib/dns/rdtypes/ANY/EUI64.py +++ b/lib/dns/rdtypes/ANY/EUI64.py @@ -22,7 +22,6 @@ import dns.rdtypes.euibase @dns.immutable.immutable class EUI64(dns.rdtypes.euibase.EUIBase): - """EUI64 record""" # see: rfc7043.txt diff --git a/lib/dns/rdtypes/ANY/GPOS.py b/lib/dns/rdtypes/ANY/GPOS.py index 30aab321..312338f9 100644 --- a/lib/dns/rdtypes/ANY/GPOS.py +++ b/lib/dns/rdtypes/ANY/GPOS.py @@ -44,7 +44,6 @@ def _validate_float_string(what): @dns.immutable.immutable class GPOS(dns.rdata.Rdata): - """GPOS record""" # see: RFC 1712 diff --git a/lib/dns/rdtypes/ANY/HINFO.py b/lib/dns/rdtypes/ANY/HINFO.py index 513c155a..c2c45de0 100644 --- a/lib/dns/rdtypes/ANY/HINFO.py +++ b/lib/dns/rdtypes/ANY/HINFO.py @@ -25,7 +25,6 @@ import dns.tokenizer @dns.immutable.immutable class HINFO(dns.rdata.Rdata): - """HINFO record""" # see: RFC 1035 diff --git a/lib/dns/rdtypes/ANY/HIP.py b/lib/dns/rdtypes/ANY/HIP.py index a20aa1e5..91669139 100644 --- a/lib/dns/rdtypes/ANY/HIP.py +++ b/lib/dns/rdtypes/ANY/HIP.py @@ -27,7 +27,6 @@ import dns.rdatatype @dns.immutable.immutable class HIP(dns.rdata.Rdata): - """HIP record""" # see: RFC 5205 diff --git a/lib/dns/rdtypes/ANY/ISDN.py b/lib/dns/rdtypes/ANY/ISDN.py index 536a35d6..fb01eab3 100644 --- a/lib/dns/rdtypes/ANY/ISDN.py +++ b/lib/dns/rdtypes/ANY/ISDN.py @@ -25,7 +25,6 @@ import dns.tokenizer @dns.immutable.immutable class ISDN(dns.rdata.Rdata): - """ISDN record""" # see: RFC 1183 diff --git a/lib/dns/rdtypes/ANY/L32.py b/lib/dns/rdtypes/ANY/L32.py index 14be01f9..09804c2d 100644 --- a/lib/dns/rdtypes/ANY/L32.py +++ b/lib/dns/rdtypes/ANY/L32.py @@ -8,7 +8,6 @@ import dns.rdata @dns.immutable.immutable class L32(dns.rdata.Rdata): - """L32 record""" # see: rfc6742.txt diff --git a/lib/dns/rdtypes/ANY/L64.py b/lib/dns/rdtypes/ANY/L64.py index d083d403..fb76808e 100644 --- a/lib/dns/rdtypes/ANY/L64.py +++ b/lib/dns/rdtypes/ANY/L64.py @@ -8,7 +8,6 @@ import dns.rdtypes.util @dns.immutable.immutable class L64(dns.rdata.Rdata): - """L64 record""" # see: rfc6742.txt diff --git a/lib/dns/rdtypes/ANY/LOC.py b/lib/dns/rdtypes/ANY/LOC.py index 783d54af..a36a2c10 100644 --- a/lib/dns/rdtypes/ANY/LOC.py +++ b/lib/dns/rdtypes/ANY/LOC.py @@ -105,7 +105,6 @@ def _check_coordinate_list(value, low, high): @dns.immutable.immutable class LOC(dns.rdata.Rdata): - """LOC record""" # see: RFC 1876 diff --git a/lib/dns/rdtypes/ANY/LP.py b/lib/dns/rdtypes/ANY/LP.py index 8a7c5125..312663f1 100644 --- a/lib/dns/rdtypes/ANY/LP.py +++ b/lib/dns/rdtypes/ANY/LP.py @@ -8,7 +8,6 @@ import dns.rdata @dns.immutable.immutable class LP(dns.rdata.Rdata): - """LP record""" # see: rfc6742.txt diff --git a/lib/dns/rdtypes/ANY/MX.py b/lib/dns/rdtypes/ANY/MX.py index 1f9df21f..0c300c5a 100644 --- a/lib/dns/rdtypes/ANY/MX.py +++ b/lib/dns/rdtypes/ANY/MX.py @@ -21,5 +21,4 @@ import dns.rdtypes.mxbase @dns.immutable.immutable class MX(dns.rdtypes.mxbase.MXBase): - """MX record""" diff --git a/lib/dns/rdtypes/ANY/NID.py b/lib/dns/rdtypes/ANY/NID.py index ad54aca3..2f649178 100644 --- a/lib/dns/rdtypes/ANY/NID.py +++ b/lib/dns/rdtypes/ANY/NID.py @@ -8,7 +8,6 @@ import dns.rdtypes.util @dns.immutable.immutable class NID(dns.rdata.Rdata): - """NID record""" # see: rfc6742.txt diff --git a/lib/dns/rdtypes/ANY/NINFO.py b/lib/dns/rdtypes/ANY/NINFO.py index 55bc5614..b177bddb 100644 --- a/lib/dns/rdtypes/ANY/NINFO.py +++ b/lib/dns/rdtypes/ANY/NINFO.py @@ -21,7 +21,6 @@ import dns.rdtypes.txtbase @dns.immutable.immutable class NINFO(dns.rdtypes.txtbase.TXTBase): - """NINFO record""" # see: draft-reid-dnsext-zs-01 diff --git a/lib/dns/rdtypes/ANY/NS.py b/lib/dns/rdtypes/ANY/NS.py index fe453f0d..c3f34ce9 100644 --- a/lib/dns/rdtypes/ANY/NS.py +++ b/lib/dns/rdtypes/ANY/NS.py @@ -21,5 +21,4 @@ import dns.rdtypes.nsbase @dns.immutable.immutable class NS(dns.rdtypes.nsbase.NSBase): - """NS record""" diff --git a/lib/dns/rdtypes/ANY/NSEC.py b/lib/dns/rdtypes/ANY/NSEC.py index a2d98fa7..340525a6 100644 --- a/lib/dns/rdtypes/ANY/NSEC.py +++ b/lib/dns/rdtypes/ANY/NSEC.py @@ -30,7 +30,6 @@ class Bitmap(dns.rdtypes.util.Bitmap): @dns.immutable.immutable class NSEC(dns.rdata.Rdata): - """NSEC record""" __slots__ = ["next", "windows"] diff --git a/lib/dns/rdtypes/ANY/NSEC3.py b/lib/dns/rdtypes/ANY/NSEC3.py index d32fe169..d71302b7 100644 --- a/lib/dns/rdtypes/ANY/NSEC3.py +++ b/lib/dns/rdtypes/ANY/NSEC3.py @@ -46,7 +46,6 @@ class Bitmap(dns.rdtypes.util.Bitmap): @dns.immutable.immutable class NSEC3(dns.rdata.Rdata): - """NSEC3 record""" __slots__ = ["algorithm", "flags", "iterations", "salt", "next", "windows"] @@ -64,9 +63,13 @@ class NSEC3(dns.rdata.Rdata): windows = Bitmap(windows) self.windows = tuple(windows.windows) - def to_text(self, origin=None, relativize=True, **kw): + def _next_text(self): next = base64.b32encode(self.next).translate(b32_normal_to_hex).lower().decode() next = next.rstrip("=") + return next + + def to_text(self, origin=None, relativize=True, **kw): + next = self._next_text() if self.salt == b"": salt = "-" else: @@ -118,3 +121,6 @@ class NSEC3(dns.rdata.Rdata): next = parser.get_counted_bytes() bitmap = Bitmap.from_wire_parser(parser) return cls(rdclass, rdtype, algorithm, flags, iterations, salt, next, bitmap) + + def next_name(self, origin=None): + return dns.name.from_text(self._next_text(), origin) diff --git a/lib/dns/rdtypes/ANY/NSEC3PARAM.py b/lib/dns/rdtypes/ANY/NSEC3PARAM.py index 1a0c0e08..d1e62ebc 100644 --- a/lib/dns/rdtypes/ANY/NSEC3PARAM.py +++ b/lib/dns/rdtypes/ANY/NSEC3PARAM.py @@ -25,7 +25,6 @@ import dns.rdata @dns.immutable.immutable class NSEC3PARAM(dns.rdata.Rdata): - """NSEC3PARAM record""" __slots__ = ["algorithm", "flags", "iterations", "salt"] diff --git a/lib/dns/rdtypes/ANY/OPENPGPKEY.py b/lib/dns/rdtypes/ANY/OPENPGPKEY.py index e5e25727..4d7a4b6c 100644 --- a/lib/dns/rdtypes/ANY/OPENPGPKEY.py +++ b/lib/dns/rdtypes/ANY/OPENPGPKEY.py @@ -25,7 +25,6 @@ import dns.tokenizer @dns.immutable.immutable class OPENPGPKEY(dns.rdata.Rdata): - """OPENPGPKEY record""" # see: RFC 7929 diff --git a/lib/dns/rdtypes/ANY/OPT.py b/lib/dns/rdtypes/ANY/OPT.py index d70e5373..d343dfa5 100644 --- a/lib/dns/rdtypes/ANY/OPT.py +++ b/lib/dns/rdtypes/ANY/OPT.py @@ -28,7 +28,6 @@ import dns.rdata @dns.immutable.immutable class OPT(dns.rdata.Rdata): - """OPT record""" __slots__ = ["options"] diff --git a/lib/dns/rdtypes/ANY/PTR.py b/lib/dns/rdtypes/ANY/PTR.py index 7fd5547d..98c36167 100644 --- a/lib/dns/rdtypes/ANY/PTR.py +++ b/lib/dns/rdtypes/ANY/PTR.py @@ -21,5 +21,4 @@ import dns.rdtypes.nsbase @dns.immutable.immutable class PTR(dns.rdtypes.nsbase.NSBase): - """PTR record""" diff --git a/lib/dns/rdtypes/ANY/RP.py b/lib/dns/rdtypes/ANY/RP.py index 9c64c6e2..9b74549d 100644 --- a/lib/dns/rdtypes/ANY/RP.py +++ b/lib/dns/rdtypes/ANY/RP.py @@ -23,7 +23,6 @@ import dns.rdata @dns.immutable.immutable class RP(dns.rdata.Rdata): - """RP record""" # see: RFC 1183 diff --git a/lib/dns/rdtypes/ANY/RRSIG.py b/lib/dns/rdtypes/ANY/RRSIG.py index 11605026..8beb4237 100644 --- a/lib/dns/rdtypes/ANY/RRSIG.py +++ b/lib/dns/rdtypes/ANY/RRSIG.py @@ -28,7 +28,6 @@ import dns.rdatatype class BadSigTime(dns.exception.DNSException): - """Time in DNS SIG or RRSIG resource record cannot be parsed.""" @@ -52,7 +51,6 @@ def posixtime_to_sigtime(what): @dns.immutable.immutable class RRSIG(dns.rdata.Rdata): - """RRSIG record""" __slots__ = [ diff --git a/lib/dns/rdtypes/ANY/RT.py b/lib/dns/rdtypes/ANY/RT.py index 950f2a06..5a4d45cf 100644 --- a/lib/dns/rdtypes/ANY/RT.py +++ b/lib/dns/rdtypes/ANY/RT.py @@ -21,5 +21,4 @@ import dns.rdtypes.mxbase @dns.immutable.immutable class RT(dns.rdtypes.mxbase.UncompressedDowncasingMX): - """RT record""" diff --git a/lib/dns/rdtypes/ANY/SOA.py b/lib/dns/rdtypes/ANY/SOA.py index bde55e15..09aa8321 100644 --- a/lib/dns/rdtypes/ANY/SOA.py +++ b/lib/dns/rdtypes/ANY/SOA.py @@ -25,7 +25,6 @@ import dns.rdata @dns.immutable.immutable class SOA(dns.rdata.Rdata): - """SOA record""" # see: RFC 1035 diff --git a/lib/dns/rdtypes/ANY/SPF.py b/lib/dns/rdtypes/ANY/SPF.py index c403589a..1df3b705 100644 --- a/lib/dns/rdtypes/ANY/SPF.py +++ b/lib/dns/rdtypes/ANY/SPF.py @@ -21,7 +21,6 @@ import dns.rdtypes.txtbase @dns.immutable.immutable class SPF(dns.rdtypes.txtbase.TXTBase): - """SPF record""" # see: RFC 4408 diff --git a/lib/dns/rdtypes/ANY/SSHFP.py b/lib/dns/rdtypes/ANY/SSHFP.py index 67805452..d2c4b073 100644 --- a/lib/dns/rdtypes/ANY/SSHFP.py +++ b/lib/dns/rdtypes/ANY/SSHFP.py @@ -25,7 +25,6 @@ import dns.rdatatype @dns.immutable.immutable class SSHFP(dns.rdata.Rdata): - """SSHFP record""" # See RFC 4255 diff --git a/lib/dns/rdtypes/ANY/TKEY.py b/lib/dns/rdtypes/ANY/TKEY.py index d5f5fc45..5b490b82 100644 --- a/lib/dns/rdtypes/ANY/TKEY.py +++ b/lib/dns/rdtypes/ANY/TKEY.py @@ -25,7 +25,6 @@ import dns.rdata @dns.immutable.immutable class TKEY(dns.rdata.Rdata): - """TKEY Record""" __slots__ = [ diff --git a/lib/dns/rdtypes/ANY/TLSA.py b/lib/dns/rdtypes/ANY/TLSA.py index c9ba1991..4dffc553 100644 --- a/lib/dns/rdtypes/ANY/TLSA.py +++ b/lib/dns/rdtypes/ANY/TLSA.py @@ -6,5 +6,4 @@ import dns.rdtypes.tlsabase @dns.immutable.immutable class TLSA(dns.rdtypes.tlsabase.TLSABase): - """TLSA record""" diff --git a/lib/dns/rdtypes/ANY/TSIG.py b/lib/dns/rdtypes/ANY/TSIG.py index 1ae87ebe..79423826 100644 --- a/lib/dns/rdtypes/ANY/TSIG.py +++ b/lib/dns/rdtypes/ANY/TSIG.py @@ -26,7 +26,6 @@ import dns.rdata @dns.immutable.immutable class TSIG(dns.rdata.Rdata): - """TSIG record""" __slots__ = [ diff --git a/lib/dns/rdtypes/ANY/TXT.py b/lib/dns/rdtypes/ANY/TXT.py index f4e61930..6d4dae27 100644 --- a/lib/dns/rdtypes/ANY/TXT.py +++ b/lib/dns/rdtypes/ANY/TXT.py @@ -21,5 +21,4 @@ import dns.rdtypes.txtbase @dns.immutable.immutable class TXT(dns.rdtypes.txtbase.TXTBase): - """TXT record""" diff --git a/lib/dns/rdtypes/ANY/URI.py b/lib/dns/rdtypes/ANY/URI.py index 7463e277..2efbb305 100644 --- a/lib/dns/rdtypes/ANY/URI.py +++ b/lib/dns/rdtypes/ANY/URI.py @@ -27,7 +27,6 @@ import dns.rdtypes.util @dns.immutable.immutable class URI(dns.rdata.Rdata): - """URI record""" # see RFC 7553 diff --git a/lib/dns/rdtypes/ANY/X25.py b/lib/dns/rdtypes/ANY/X25.py index 06c14534..8375611d 100644 --- a/lib/dns/rdtypes/ANY/X25.py +++ b/lib/dns/rdtypes/ANY/X25.py @@ -25,7 +25,6 @@ import dns.tokenizer @dns.immutable.immutable class X25(dns.rdata.Rdata): - """X25 record""" # see RFC 1183 diff --git a/lib/dns/rdtypes/ANY/ZONEMD.py b/lib/dns/rdtypes/ANY/ZONEMD.py index 3062843b..c90e3ee1 100644 --- a/lib/dns/rdtypes/ANY/ZONEMD.py +++ b/lib/dns/rdtypes/ANY/ZONEMD.py @@ -11,7 +11,6 @@ import dns.zonetypes @dns.immutable.immutable class ZONEMD(dns.rdata.Rdata): - """ZONEMD record""" # See RFC 8976 diff --git a/lib/dns/rdtypes/CH/A.py b/lib/dns/rdtypes/CH/A.py index e457f38a..583a88ac 100644 --- a/lib/dns/rdtypes/CH/A.py +++ b/lib/dns/rdtypes/CH/A.py @@ -23,7 +23,6 @@ import dns.rdtypes.mxbase @dns.immutable.immutable class A(dns.rdata.Rdata): - """A record for Chaosnet""" # domain: the domain of the address diff --git a/lib/dns/rdtypes/IN/A.py b/lib/dns/rdtypes/IN/A.py index 713d5eea..e09d6110 100644 --- a/lib/dns/rdtypes/IN/A.py +++ b/lib/dns/rdtypes/IN/A.py @@ -24,7 +24,6 @@ import dns.tokenizer @dns.immutable.immutable class A(dns.rdata.Rdata): - """A record.""" __slots__ = ["address"] diff --git a/lib/dns/rdtypes/IN/AAAA.py b/lib/dns/rdtypes/IN/AAAA.py index f8237b44..0cd139e7 100644 --- a/lib/dns/rdtypes/IN/AAAA.py +++ b/lib/dns/rdtypes/IN/AAAA.py @@ -24,7 +24,6 @@ import dns.tokenizer @dns.immutable.immutable class AAAA(dns.rdata.Rdata): - """AAAA record.""" __slots__ = ["address"] diff --git a/lib/dns/rdtypes/IN/APL.py b/lib/dns/rdtypes/IN/APL.py index f1bb01db..44cb3fef 100644 --- a/lib/dns/rdtypes/IN/APL.py +++ b/lib/dns/rdtypes/IN/APL.py @@ -29,7 +29,6 @@ import dns.tokenizer @dns.immutable.immutable class APLItem: - """An APL list item.""" __slots__ = ["family", "negation", "address", "prefix"] @@ -80,7 +79,6 @@ class APLItem: @dns.immutable.immutable class APL(dns.rdata.Rdata): - """APL record.""" # see: RFC 3123 diff --git a/lib/dns/rdtypes/IN/DHCID.py b/lib/dns/rdtypes/IN/DHCID.py index 65f85897..723492fa 100644 --- a/lib/dns/rdtypes/IN/DHCID.py +++ b/lib/dns/rdtypes/IN/DHCID.py @@ -24,7 +24,6 @@ import dns.rdata @dns.immutable.immutable class DHCID(dns.rdata.Rdata): - """DHCID record""" # see: RFC 4701 diff --git a/lib/dns/rdtypes/IN/IPSECKEY.py b/lib/dns/rdtypes/IN/IPSECKEY.py index 8bb2bcb6..e3a66157 100644 --- a/lib/dns/rdtypes/IN/IPSECKEY.py +++ b/lib/dns/rdtypes/IN/IPSECKEY.py @@ -29,7 +29,6 @@ class Gateway(dns.rdtypes.util.Gateway): @dns.immutable.immutable class IPSECKEY(dns.rdata.Rdata): - """IPSECKEY record""" # see: RFC 4025 diff --git a/lib/dns/rdtypes/IN/KX.py b/lib/dns/rdtypes/IN/KX.py index a03d1d51..6073df47 100644 --- a/lib/dns/rdtypes/IN/KX.py +++ b/lib/dns/rdtypes/IN/KX.py @@ -21,5 +21,4 @@ import dns.rdtypes.mxbase @dns.immutable.immutable class KX(dns.rdtypes.mxbase.UncompressedDowncasingMX): - """KX record""" diff --git a/lib/dns/rdtypes/IN/NAPTR.py b/lib/dns/rdtypes/IN/NAPTR.py index 1f1f5a12..195d1cba 100644 --- a/lib/dns/rdtypes/IN/NAPTR.py +++ b/lib/dns/rdtypes/IN/NAPTR.py @@ -33,7 +33,6 @@ def _write_string(file, s): @dns.immutable.immutable class NAPTR(dns.rdata.Rdata): - """NAPTR record""" # see: RFC 3403 diff --git a/lib/dns/rdtypes/IN/NSAP.py b/lib/dns/rdtypes/IN/NSAP.py index be8581e6..a4854b3f 100644 --- a/lib/dns/rdtypes/IN/NSAP.py +++ b/lib/dns/rdtypes/IN/NSAP.py @@ -25,7 +25,6 @@ import dns.tokenizer @dns.immutable.immutable class NSAP(dns.rdata.Rdata): - """NSAP record.""" # see: RFC 1706 diff --git a/lib/dns/rdtypes/IN/NSAP_PTR.py b/lib/dns/rdtypes/IN/NSAP_PTR.py index 0a18fdce..ce1c6632 100644 --- a/lib/dns/rdtypes/IN/NSAP_PTR.py +++ b/lib/dns/rdtypes/IN/NSAP_PTR.py @@ -21,5 +21,4 @@ import dns.rdtypes.nsbase @dns.immutable.immutable class NSAP_PTR(dns.rdtypes.nsbase.UncompressedNS): - """NSAP-PTR record""" diff --git a/lib/dns/rdtypes/IN/PX.py b/lib/dns/rdtypes/IN/PX.py index 5c0aa81e..cdca1532 100644 --- a/lib/dns/rdtypes/IN/PX.py +++ b/lib/dns/rdtypes/IN/PX.py @@ -26,7 +26,6 @@ import dns.rdtypes.util @dns.immutable.immutable class PX(dns.rdata.Rdata): - """PX record.""" # see: RFC 2163 diff --git a/lib/dns/rdtypes/IN/SRV.py b/lib/dns/rdtypes/IN/SRV.py index 84c54007..5adef98f 100644 --- a/lib/dns/rdtypes/IN/SRV.py +++ b/lib/dns/rdtypes/IN/SRV.py @@ -26,7 +26,6 @@ import dns.rdtypes.util @dns.immutable.immutable class SRV(dns.rdata.Rdata): - """SRV record""" # see: RFC 2782 diff --git a/lib/dns/rdtypes/IN/WKS.py b/lib/dns/rdtypes/IN/WKS.py index 26d287a3..881a7849 100644 --- a/lib/dns/rdtypes/IN/WKS.py +++ b/lib/dns/rdtypes/IN/WKS.py @@ -33,7 +33,6 @@ except OSError: @dns.immutable.immutable class WKS(dns.rdata.Rdata): - """WKS record""" # see: RFC 1035 diff --git a/lib/dns/rdtypes/dnskeybase.py b/lib/dns/rdtypes/dnskeybase.py index 3bfcf860..db300f8b 100644 --- a/lib/dns/rdtypes/dnskeybase.py +++ b/lib/dns/rdtypes/dnskeybase.py @@ -36,7 +36,6 @@ class Flag(enum.IntFlag): @dns.immutable.immutable class DNSKEYBase(dns.rdata.Rdata): - """Base class for rdata that is like a DNSKEY record""" __slots__ = ["flags", "protocol", "algorithm", "key"] diff --git a/lib/dns/rdtypes/dsbase.py b/lib/dns/rdtypes/dsbase.py index 1ad0b7a5..cd21f026 100644 --- a/lib/dns/rdtypes/dsbase.py +++ b/lib/dns/rdtypes/dsbase.py @@ -26,7 +26,6 @@ import dns.rdatatype @dns.immutable.immutable class DSBase(dns.rdata.Rdata): - """Base class for rdata that is like a DS record""" __slots__ = ["key_tag", "algorithm", "digest_type", "digest"] diff --git a/lib/dns/rdtypes/euibase.py b/lib/dns/rdtypes/euibase.py index 4c4068b2..751087b4 100644 --- a/lib/dns/rdtypes/euibase.py +++ b/lib/dns/rdtypes/euibase.py @@ -22,7 +22,6 @@ import dns.rdata @dns.immutable.immutable class EUIBase(dns.rdata.Rdata): - """EUIxx record""" # see: rfc7043.txt diff --git a/lib/dns/rdtypes/mxbase.py b/lib/dns/rdtypes/mxbase.py index a6bae078..6d5e3d87 100644 --- a/lib/dns/rdtypes/mxbase.py +++ b/lib/dns/rdtypes/mxbase.py @@ -28,7 +28,6 @@ import dns.rdtypes.util @dns.immutable.immutable class MXBase(dns.rdata.Rdata): - """Base class for rdata that is like an MX record.""" __slots__ = ["preference", "exchange"] @@ -71,7 +70,6 @@ class MXBase(dns.rdata.Rdata): @dns.immutable.immutable class UncompressedMX(MXBase): - """Base class for rdata that is like an MX record, but whose name is not compressed when converted to DNS wire format, and whose digestable form is not downcased.""" @@ -82,7 +80,6 @@ class UncompressedMX(MXBase): @dns.immutable.immutable class UncompressedDowncasingMX(MXBase): - """Base class for rdata that is like an MX record, but whose name is not compressed when convert to DNS wire format.""" diff --git a/lib/dns/rdtypes/nsbase.py b/lib/dns/rdtypes/nsbase.py index 56d94235..904224f0 100644 --- a/lib/dns/rdtypes/nsbase.py +++ b/lib/dns/rdtypes/nsbase.py @@ -25,7 +25,6 @@ import dns.rdata @dns.immutable.immutable class NSBase(dns.rdata.Rdata): - """Base class for rdata that is like an NS record.""" __slots__ = ["target"] @@ -56,7 +55,6 @@ class NSBase(dns.rdata.Rdata): @dns.immutable.immutable class UncompressedNS(NSBase): - """Base class for rdata that is like an NS record, but whose name is not compressed when convert to DNS wire format, and whose digestable form is not downcased.""" diff --git a/lib/dns/rdtypes/svcbbase.py b/lib/dns/rdtypes/svcbbase.py index ba5b53d2..05652413 100644 --- a/lib/dns/rdtypes/svcbbase.py +++ b/lib/dns/rdtypes/svcbbase.py @@ -2,7 +2,6 @@ import base64 import enum -import io import struct import dns.enum @@ -13,6 +12,7 @@ import dns.ipv6 import dns.name import dns.rdata import dns.rdtypes.util +import dns.renderer import dns.tokenizer import dns.wire @@ -427,7 +427,6 @@ def _validate_and_define(params, key, value): @dns.immutable.immutable class SVCBBase(dns.rdata.Rdata): - """Base class for SVCB-like records""" # see: draft-ietf-dnsop-svcb-https-11 @@ -521,19 +520,10 @@ class SVCBBase(dns.rdata.Rdata): for key in sorted(self.params): file.write(struct.pack("!H", key)) value = self.params[key] - # placeholder for length (or actual length of empty values) - file.write(struct.pack("!H", 0)) - if value is None: - continue - else: - start = file.tell() - value.to_wire(file, origin) - end = file.tell() - assert end - start < 65536 - file.seek(start - 2) - stuff = struct.pack("!H", end - start) - file.write(stuff) - file.seek(0, io.SEEK_END) + with dns.renderer.prefixed_length(file, 2): + # Note that we're still writing a length of zero if the value is None + if value is not None: + value.to_wire(file, origin) @classmethod def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): diff --git a/lib/dns/rdtypes/tlsabase.py b/lib/dns/rdtypes/tlsabase.py index 4cdb7ab3..a059d2c4 100644 --- a/lib/dns/rdtypes/tlsabase.py +++ b/lib/dns/rdtypes/tlsabase.py @@ -25,7 +25,6 @@ import dns.rdatatype @dns.immutable.immutable class TLSABase(dns.rdata.Rdata): - """Base class for TLSA and SMIMEA records""" # see: RFC 6698 diff --git a/lib/dns/rdtypes/txtbase.py b/lib/dns/rdtypes/txtbase.py index fdbfb646..44d6df57 100644 --- a/lib/dns/rdtypes/txtbase.py +++ b/lib/dns/rdtypes/txtbase.py @@ -17,18 +17,17 @@ """TXT-like base class.""" -import struct from typing import Any, Dict, Iterable, Optional, Tuple, Union import dns.exception import dns.immutable import dns.rdata +import dns.renderer import dns.tokenizer @dns.immutable.immutable class TXTBase(dns.rdata.Rdata): - """Base class for rdata that is like a TXT record (see RFC 1035).""" __slots__ = ["strings"] @@ -56,7 +55,7 @@ class TXTBase(dns.rdata.Rdata): self, origin: Optional[dns.name.Name] = None, relativize: bool = True, - **kw: Dict[str, Any] + **kw: Dict[str, Any], ) -> str: txt = "" prefix = "" @@ -93,10 +92,8 @@ class TXTBase(dns.rdata.Rdata): def _to_wire(self, file, compress=None, origin=None, canonicalize=False): for s in self.strings: - l = len(s) - assert l < 256 - file.write(struct.pack("!B", l)) - file.write(s) + with dns.renderer.prefixed_length(file, 1): + file.write(s) @classmethod def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): diff --git a/lib/dns/renderer.py b/lib/dns/renderer.py index 53e7c0f6..a77481f6 100644 --- a/lib/dns/renderer.py +++ b/lib/dns/renderer.py @@ -32,6 +32,24 @@ AUTHORITY = 2 ADDITIONAL = 3 +@contextlib.contextmanager +def prefixed_length(output, length_length): + output.write(b"\00" * length_length) + start = output.tell() + yield + end = output.tell() + length = end - start + if length > 0: + try: + output.seek(start - length_length) + try: + output.write(length.to_bytes(length_length, "big")) + except OverflowError: + raise dns.exception.FormError + finally: + output.seek(end) + + class Renderer: """Helper class for building DNS wire-format messages. @@ -134,6 +152,15 @@ class Renderer: self._rollback(start) raise dns.exception.TooBig + @contextlib.contextmanager + def _temporarily_seek_to(self, where): + current = self.output.tell() + try: + self.output.seek(where) + yield + finally: + self.output.seek(current) + def add_question(self, qname, rdtype, rdclass=dns.rdataclass.IN): """Add a question to the message.""" @@ -269,18 +296,14 @@ class Renderer: with self._track_size(): keyname.to_wire(self.output, compress, self.origin) self.output.write( - struct.pack("!HHIH", dns.rdatatype.TSIG, dns.rdataclass.ANY, 0, 0) + struct.pack("!HHI", dns.rdatatype.TSIG, dns.rdataclass.ANY, 0) ) - rdata_start = self.output.tell() - tsig.to_wire(self.output) + with prefixed_length(self.output, 2): + tsig.to_wire(self.output) - after = self.output.tell() - self.output.seek(rdata_start - 2) - self.output.write(struct.pack("!H", after - rdata_start)) self.counts[ADDITIONAL] += 1 - self.output.seek(10) - self.output.write(struct.pack("!H", self.counts[ADDITIONAL])) - self.output.seek(0, io.SEEK_END) + with self._temporarily_seek_to(10): + self.output.write(struct.pack("!H", self.counts[ADDITIONAL])) def write_header(self): """Write the DNS message header. @@ -290,19 +313,18 @@ class Renderer: is added. """ - self.output.seek(0) - self.output.write( - struct.pack( - "!HHHHHH", - self.id, - self.flags, - self.counts[0], - self.counts[1], - self.counts[2], - self.counts[3], + with self._temporarily_seek_to(0): + self.output.write( + struct.pack( + "!HHHHHH", + self.id, + self.flags, + self.counts[0], + self.counts[1], + self.counts[2], + self.counts[3], + ) ) - ) - self.output.seek(0, io.SEEK_END) def get_wire(self): """Return the wire format message.""" diff --git a/lib/dns/rrset.py b/lib/dns/rrset.py index 350de13e..6f39b108 100644 --- a/lib/dns/rrset.py +++ b/lib/dns/rrset.py @@ -26,7 +26,6 @@ import dns.renderer class RRset(dns.rdataset.Rdataset): - """A DNS RRset (named rdataset). RRset inherits from Rdataset, and RRsets can be treated as @@ -132,7 +131,7 @@ class RRset(dns.rdataset.Rdataset): self, origin: Optional[dns.name.Name] = None, relativize: bool = True, - **kw: Dict[str, Any] + **kw: Dict[str, Any], ) -> str: """Convert the RRset into DNS zone file format. @@ -159,7 +158,7 @@ class RRset(dns.rdataset.Rdataset): file: Any, compress: Optional[dns.name.CompressType] = None, # type: ignore origin: Optional[dns.name.Name] = None, - **kw: Dict[str, Any] + **kw: Dict[str, Any], ) -> int: """Convert the RRset to wire format. @@ -231,7 +230,7 @@ def from_text( ttl: int, rdclass: Union[dns.rdataclass.RdataClass, str], rdtype: Union[dns.rdatatype.RdataType, str], - *text_rdatas: Any + *text_rdatas: Any, ) -> RRset: """Create an RRset with the specified name, TTL, class, and type and with the specified rdatas in text format. diff --git a/lib/dns/set.py b/lib/dns/set.py index fa50ed97..f0fb0d50 100644 --- a/lib/dns/set.py +++ b/lib/dns/set.py @@ -19,7 +19,6 @@ import itertools class Set: - """A simple set class. This class was originally used to deal with sets being missing in diff --git a/lib/dns/transaction.py b/lib/dns/transaction.py index 21dea775..84e54f7d 100644 --- a/lib/dns/transaction.py +++ b/lib/dns/transaction.py @@ -203,7 +203,7 @@ class Transaction: - name - - name, rdataclass, rdatatype, [covers] + - name, rdatatype, [covers] - name, rdataset... @@ -222,7 +222,7 @@ class Transaction: - name - - name, rdataclass, rdatatype, [covers] + - name, rdatatype, [covers] - name, rdataset... diff --git a/lib/dns/tsig.py b/lib/dns/tsig.py index 58760f5f..780852e8 100644 --- a/lib/dns/tsig.py +++ b/lib/dns/tsig.py @@ -29,47 +29,38 @@ import dns.rdataclass class BadTime(dns.exception.DNSException): - """The current time is not within the TSIG's validity time.""" class BadSignature(dns.exception.DNSException): - """The TSIG signature fails to verify.""" class BadKey(dns.exception.DNSException): - """The TSIG record owner name does not match the key.""" class BadAlgorithm(dns.exception.DNSException): - """The TSIG algorithm does not match the key.""" class PeerError(dns.exception.DNSException): - """Base class for all TSIG errors generated by the remote peer""" class PeerBadKey(PeerError): - """The peer didn't know the key we used""" class PeerBadSignature(PeerError): - """The peer didn't like the signature we sent""" class PeerBadTime(PeerError): - """The peer didn't like the time we sent""" class PeerBadTruncation(PeerError): - """The peer didn't like amount of truncation in the TSIG we sent""" diff --git a/lib/dns/version.py b/lib/dns/version.py index 1f1fbf2d..251f2583 100644 --- a/lib/dns/version.py +++ b/lib/dns/version.py @@ -20,9 +20,9 @@ #: MAJOR MAJOR = 2 #: MINOR -MINOR = 4 +MINOR = 6 #: MICRO -MICRO = 2 +MICRO = 1 #: RELEASELEVEL RELEASELEVEL = 0x0F #: SERIAL diff --git a/lib/dns/win32util.py b/lib/dns/win32util.py index b2ca61da..aaa7e93e 100644 --- a/lib/dns/win32util.py +++ b/lib/dns/win32util.py @@ -1,5 +1,7 @@ import sys +import dns._features + if sys.platform == "win32": from typing import Any @@ -15,14 +17,14 @@ if sys.platform == "win32": except KeyError: WindowsError = Exception - try: + if dns._features.have("wmi"): import threading import pythoncom # pylint: disable=import-error import wmi # pylint: disable=import-error _have_wmi = True - except Exception: + else: _have_wmi = False def _config_domain(domain): @@ -51,9 +53,10 @@ if sys.platform == "win32": try: system = wmi.WMI() for interface in system.Win32_NetworkAdapterConfiguration(): - if interface.IPEnabled and interface.DNSDomain: - self.info.domain = _config_domain(interface.DNSDomain) + if interface.IPEnabled and interface.DNSServerSearchOrder: self.info.nameservers = list(interface.DNSServerSearchOrder) + if interface.DNSDomain: + self.info.domain = _config_domain(interface.DNSDomain) if interface.DNSDomainSuffixSearchOrder: self.info.search = [ _config_domain(x) diff --git a/lib/dns/zone.py b/lib/dns/zone.py index 9e763f5f..844919e4 100644 --- a/lib/dns/zone.py +++ b/lib/dns/zone.py @@ -21,7 +21,18 @@ import contextlib import io import os import struct -from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Tuple, Union +from typing import ( + Any, + Callable, + Iterable, + Iterator, + List, + MutableMapping, + Optional, + Set, + Tuple, + Union, +) import dns.exception import dns.grange @@ -43,47 +54,70 @@ from dns.zonetypes import DigestHashAlgorithm, DigestScheme, _digest_hashers class BadZone(dns.exception.DNSException): - """The DNS zone is malformed.""" class NoSOA(BadZone): - """The DNS zone has no SOA RR at its origin.""" class NoNS(BadZone): - """The DNS zone has no NS RRset at its origin.""" class UnknownOrigin(BadZone): - """The DNS zone's origin is unknown.""" class UnsupportedDigestScheme(dns.exception.DNSException): - """The zone digest's scheme is unsupported.""" class UnsupportedDigestHashAlgorithm(dns.exception.DNSException): - """The zone digest's origin is unsupported.""" class NoDigest(dns.exception.DNSException): - """The DNS zone has no ZONEMD RRset at its origin.""" class DigestVerificationFailure(dns.exception.DNSException): - """The ZONEMD digest failed to verify.""" -class Zone(dns.transaction.TransactionManager): +def _validate_name( + name: dns.name.Name, + origin: Optional[dns.name.Name], + relativize: bool, +) -> dns.name.Name: + # This name validation code is shared by Zone and Version + if origin is None: + # This should probably never happen as other code (e.g. + # _rr_line) will notice the lack of an origin before us, but + # we check just in case! + raise KeyError("no zone origin is defined") + if name.is_absolute(): + if not name.is_subdomain(origin): + raise KeyError("name parameter must be a subdomain of the zone origin") + if relativize: + name = name.relativize(origin) + else: + # We have a relative name. Make sure that the derelativized name is + # not too long. + try: + abs_name = name.derelativize(origin) + except dns.name.NameTooLong: + # We map dns.name.NameTooLong to KeyError to be consistent with + # the other exceptions above. + raise KeyError("relative name too long for zone") + if not relativize: + # We have a relative name in a non-relative zone, so use the + # derelativized name. + name = abs_name + return name + +class Zone(dns.transaction.TransactionManager): """A DNS zone. A ``Zone`` is a mapping from names to nodes. The zone object may be @@ -94,7 +128,10 @@ class Zone(dns.transaction.TransactionManager): the zone. """ - node_factory = dns.node.Node + node_factory: Callable[[], dns.node.Node] = dns.node.Node + map_factory: Callable[[], MutableMapping[dns.name.Name, dns.node.Node]] = dict + writable_version_factory: Optional[Callable[[], "WritableVersion"]] = None + immutable_version_factory: Optional[Callable[[], "ImmutableVersion"]] = None __slots__ = ["rdclass", "origin", "nodes", "relativize"] @@ -125,7 +162,7 @@ class Zone(dns.transaction.TransactionManager): raise ValueError("origin parameter must be an absolute name") self.origin = origin self.rdclass = rdclass - self.nodes: Dict[dns.name.Name, dns.node.Node] = {} + self.nodes: MutableMapping[dns.name.Name, dns.node.Node] = self.map_factory() self.relativize = relativize def __eq__(self, other): @@ -154,26 +191,13 @@ class Zone(dns.transaction.TransactionManager): return not self.__eq__(other) def _validate_name(self, name: Union[dns.name.Name, str]) -> dns.name.Name: + # Note that any changes in this method should have corresponding changes + # made in the Version _validate_name() method. if isinstance(name, str): name = dns.name.from_text(name, None) elif not isinstance(name, dns.name.Name): raise KeyError("name parameter must be convertible to a DNS name") - if name.is_absolute(): - if self.origin is None: - # This should probably never happen as other code (e.g. - # _rr_line) will notice the lack of an origin before us, but - # we check just in case! - raise KeyError("no zone origin is defined") - if not name.is_subdomain(self.origin): - raise KeyError("name parameter must be a subdomain of the zone origin") - if self.relativize: - name = name.relativize(self.origin) - elif not self.relativize: - # We have a relative name in a non-relative zone, so derelativize. - if self.origin is None: - raise KeyError("no zone origin is defined") - name = name.derelativize(self.origin) - return name + return _validate_name(name, self.origin, self.relativize) def __getitem__(self, key): key = self._validate_name(key) @@ -252,9 +276,6 @@ class Zone(dns.transaction.TransactionManager): *create*, a ``bool``. If true, the node will be created if it does not exist. - Raises ``KeyError`` if the name is not known and create was - not specified, or if the name was not a subdomain of the origin. - Returns a ``dns.node.Node`` or ``None``. """ @@ -527,9 +548,6 @@ class Zone(dns.transaction.TransactionManager): *create*, a ``bool``. If true, the node will be created if it does not exist. - Raises ``KeyError`` if the name is not known and create was - not specified, or if the name was not a subdomain of the origin. - Returns a ``dns.rrset.RRset`` or ``None``. """ @@ -952,7 +970,7 @@ class Version: self, zone: Zone, id: int, - nodes: Optional[Dict[dns.name.Name, dns.node.Node]] = None, + nodes: Optional[MutableMapping[dns.name.Name, dns.node.Node]] = None, origin: Optional[dns.name.Name] = None, ): self.zone = zone @@ -960,26 +978,11 @@ class Version: if nodes is not None: self.nodes = nodes else: - self.nodes = {} + self.nodes = zone.map_factory() self.origin = origin def _validate_name(self, name: dns.name.Name) -> dns.name.Name: - if name.is_absolute(): - if self.origin is None: - # This should probably never happen as other code (e.g. - # _rr_line) will notice the lack of an origin before us, but - # we check just in case! - raise KeyError("no zone origin is defined") - if not name.is_subdomain(self.origin): - raise KeyError("name is not a subdomain of the zone origin") - if self.zone.relativize: - name = name.relativize(self.origin) - elif not self.zone.relativize: - # We have a relative name in a non-relative zone, so derelativize. - if self.origin is None: - raise KeyError("no zone origin is defined") - name = name.derelativize(self.origin) - return name + return _validate_name(name, self.origin, self.zone.relativize) def get_node(self, name: dns.name.Name) -> Optional[dns.node.Node]: name = self._validate_name(name) @@ -1085,7 +1088,9 @@ class ImmutableVersion(Version): version.nodes[name] = ImmutableVersionedNode(node) # We're changing the type of the nodes dictionary here on purpose, so # we ignore the mypy error. - self.nodes = dns.immutable.Dict(version.nodes, True) # type: ignore + self.nodes = dns.immutable.Dict( + version.nodes, True, self.zone.map_factory + ) # type: ignore class Transaction(dns.transaction.Transaction): @@ -1101,7 +1106,10 @@ class Transaction(dns.transaction.Transaction): def _setup_version(self): assert self.version is None - self.version = WritableVersion(self.zone, self.replacement) + factory = self.manager.writable_version_factory + if factory is None: + factory = WritableVersion + self.version = factory(self.zone, self.replacement) def _get_rdataset(self, name, rdtype, covers): return self.version.get_rdataset(name, rdtype, covers) @@ -1132,7 +1140,10 @@ class Transaction(dns.transaction.Transaction): self.zone._end_read(self) elif commit and len(self.version.changed) > 0: if self.make_immutable: - version = ImmutableVersion(self.version) + factory = self.manager.immutable_version_factory + if factory is None: + factory = ImmutableVersion + version = factory(self.version) else: version = self.version self.zone._commit_version(self, version, self.version.origin) @@ -1168,6 +1179,48 @@ class Transaction(dns.transaction.Transaction): return (absolute, relativize, effective) +def _from_text( + text: Any, + origin: Optional[Union[dns.name.Name, str]] = None, + rdclass: dns.rdataclass.RdataClass = dns.rdataclass.IN, + relativize: bool = True, + zone_factory: Any = Zone, + filename: Optional[str] = None, + allow_include: bool = False, + check_origin: bool = True, + idna_codec: Optional[dns.name.IDNACodec] = None, + allow_directives: Union[bool, Iterable[str]] = True, +) -> Zone: + # See the comments for the public APIs from_text() and from_file() for + # details. + + # 'text' can also be a file, but we don't publish that fact + # since it's an implementation detail. The official file + # interface is from_file(). + + if filename is None: + filename = "" + zone = zone_factory(origin, rdclass, relativize=relativize) + with zone.writer(True) as txn: + tok = dns.tokenizer.Tokenizer(text, filename, idna_codec=idna_codec) + reader = dns.zonefile.Reader( + tok, + rdclass, + txn, + allow_include=allow_include, + allow_directives=allow_directives, + ) + try: + reader.read() + except dns.zonefile.UnknownOrigin: + # for backwards compatibility + raise dns.zone.UnknownOrigin + # Now that we're done reading, do some basic checking of the zone. + if check_origin: + zone.check_origin() + return zone + + def from_text( text: str, origin: Optional[Union[dns.name.Name, str]] = None, @@ -1228,32 +1281,18 @@ def from_text( Returns a subclass of ``dns.zone.Zone``. """ - - # 'text' can also be a file, but we don't publish that fact - # since it's an implementation detail. The official file - # interface is from_file(). - - if filename is None: - filename = "" - zone = zone_factory(origin, rdclass, relativize=relativize) - with zone.writer(True) as txn: - tok = dns.tokenizer.Tokenizer(text, filename, idna_codec=idna_codec) - reader = dns.zonefile.Reader( - tok, - rdclass, - txn, - allow_include=allow_include, - allow_directives=allow_directives, - ) - try: - reader.read() - except dns.zonefile.UnknownOrigin: - # for backwards compatibility - raise dns.zone.UnknownOrigin - # Now that we're done reading, do some basic checking of the zone. - if check_origin: - zone.check_origin() - return zone + return _from_text( + text, + origin, + rdclass, + relativize, + zone_factory, + filename, + allow_include, + check_origin, + idna_codec, + allow_directives, + ) def from_file( @@ -1324,7 +1363,7 @@ def from_file( else: cm = contextlib.nullcontext(f) with cm as f: - return from_text( + return _from_text( f, origin, rdclass, diff --git a/lib/dns/zonefile.py b/lib/dns/zonefile.py index 27f04924..af064e73 100644 --- a/lib/dns/zonefile.py +++ b/lib/dns/zonefile.py @@ -86,7 +86,6 @@ def _upper_dollarize(s): class Reader: - """Read a DNS zone file into a transaction.""" def __init__( diff --git a/requirements.txt b/requirements.txt index fe84492e..64458e26 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,7 +10,7 @@ cheroot==10.0.0 cherrypy==18.8.0 cloudinary==1.34.0 distro==1.9.0 -dnspython==2.4.2 +dnspython==2.6.1 facebook-sdk==3.1.0 future==0.18.3 ga4mp==2.0.4