Bump dnspython from 2.4.2 to 2.6.1 (#2264)

* Bump dnspython from 2.4.2 to 2.6.1

Bumps [dnspython](https://github.com/rthalley/dnspython) from 2.4.2 to 2.6.1.
- [Release notes](https://github.com/rthalley/dnspython/releases)
- [Changelog](https://github.com/rthalley/dnspython/blob/main/doc/whatsnew.rst)
- [Commits](https://github.com/rthalley/dnspython/compare/v2.4.2...v2.6.1)

---
updated-dependencies:
- dependency-name: dnspython
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>

* Update dnspython==2.6.1

---------

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: JonnyWong16 <9099342+JonnyWong16@users.noreply.github.com>

[skip ci]
This commit is contained in:
dependabot[bot] 2024-03-24 15:25:23 -07:00 committed by GitHub
parent aca7e72715
commit cfefa928be
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
101 changed files with 1052 additions and 459 deletions

View file

@ -7,7 +7,9 @@ import socket
import sys import sys
import dns._asyncbackend import dns._asyncbackend
import dns._features
import dns.exception import dns.exception
import dns.inet
_is_win32 = sys.platform == "win32" _is_win32 = sys.platform == "win32"
@ -121,7 +123,7 @@ class StreamSocket(dns._asyncbackend.StreamSocket):
return self.writer.get_extra_info("peercert") return self.writer.get_extra_info("peercert")
try: if dns._features.have("doh"):
import anyio import anyio
import httpcore import httpcore
import httpcore._backends.anyio import httpcore._backends.anyio
@ -205,7 +207,7 @@ try:
resolver, local_port, bootstrap_address, family resolver, local_port, bootstrap_address, family
) )
except ImportError: else:
_HTTPTransport = dns._asyncbackend.NullTransport # type: ignore _HTTPTransport = dns._asyncbackend.NullTransport # type: ignore
@ -224,14 +226,12 @@ class Backend(dns._asyncbackend.Backend):
ssl_context=None, ssl_context=None,
server_hostname=None, server_hostname=None,
): ):
if destination is None and socktype == socket.SOCK_DGRAM and _is_win32:
raise NotImplementedError(
"destinationless datagram sockets "
"are not supported by asyncio "
"on Windows"
)
loop = _get_running_loop() loop = _get_running_loop()
if socktype == socket.SOCK_DGRAM: if socktype == socket.SOCK_DGRAM:
if _is_win32 and source is None:
# Win32 wants explicit binding before recvfrom(). This is the
# proper fix for [#637].
source = (dns.inet.any_for_af(af), 0)
transport, protocol = await loop.create_datagram_endpoint( transport, protocol = await loop.create_datagram_endpoint(
_DatagramProtocol, _DatagramProtocol,
source, source,
@ -266,7 +266,7 @@ class Backend(dns._asyncbackend.Backend):
await asyncio.sleep(interval) await asyncio.sleep(interval)
def datagram_connection_required(self): def datagram_connection_required(self):
return _is_win32 return False
def get_transport_class(self): def get_transport_class(self):
return _HTTPTransport return _HTTPTransport

92
lib/dns/_features.py Normal file
View file

@ -0,0 +1,92 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
import importlib.metadata
import itertools
import string
from typing import Dict, List, Tuple
def _tuple_from_text(version: str) -> Tuple:
text_parts = version.split(".")
int_parts = []
for text_part in text_parts:
digit_prefix = "".join(
itertools.takewhile(lambda x: x in string.digits, text_part)
)
try:
int_parts.append(int(digit_prefix))
except Exception:
break
return tuple(int_parts)
def _version_check(
requirement: str,
) -> bool:
"""Is the requirement fulfilled?
The requirement must be of the form
package>=version
"""
package, minimum = requirement.split(">=")
try:
version = importlib.metadata.version(package)
except Exception:
return False
t_version = _tuple_from_text(version)
t_minimum = _tuple_from_text(minimum)
if t_version < t_minimum:
return False
return True
_cache: Dict[str, bool] = {}
def have(feature: str) -> bool:
"""Is *feature* available?
This tests if all optional packages needed for the
feature are available and recent enough.
Returns ``True`` if the feature is available,
and ``False`` if it is not or if metadata is
missing.
"""
value = _cache.get(feature)
if value is not None:
return value
requirements = _requirements.get(feature)
if requirements is None:
# we make a cache entry here for consistency not performance
_cache[feature] = False
return False
ok = True
for requirement in requirements:
if not _version_check(requirement):
ok = False
break
_cache[feature] = ok
return ok
def force(feature: str, enabled: bool) -> None:
"""Force the status of *feature* to be *enabled*.
This method is provided as a workaround for any cases
where importlib.metadata is ineffective, or for testing.
"""
_cache[feature] = enabled
_requirements: Dict[str, List[str]] = {
### BEGIN generated requirements
"dnssec": ["cryptography>=41"],
"doh": ["httpcore>=1.0.0", "httpx>=0.26.0", "h2>=4.1.0"],
"doq": ["aioquic>=0.9.25"],
"idna": ["idna>=3.6"],
"trio": ["trio>=0.23"],
"wmi": ["wmi>=1.5.1"],
### END generated requirements
}

View file

@ -8,9 +8,13 @@ import trio
import trio.socket # type: ignore import trio.socket # type: ignore
import dns._asyncbackend import dns._asyncbackend
import dns._features
import dns.exception import dns.exception
import dns.inet import dns.inet
if not dns._features.have("trio"):
raise ImportError("trio not found or too old")
def _maybe_timeout(timeout): def _maybe_timeout(timeout):
if timeout is not None: if timeout is not None:
@ -95,7 +99,7 @@ class StreamSocket(dns._asyncbackend.StreamSocket):
raise NotImplementedError raise NotImplementedError
try: if dns._features.have("doh"):
import httpcore import httpcore
import httpcore._backends.trio import httpcore._backends.trio
import httpx import httpx
@ -177,7 +181,7 @@ try:
resolver, local_port, bootstrap_address, family resolver, local_port, bootstrap_address, family
) )
except ImportError: else:
_HTTPTransport = dns._asyncbackend.NullTransport # type: ignore _HTTPTransport = dns._asyncbackend.NullTransport # type: ignore

View file

@ -32,7 +32,7 @@ def get_backend(name: str) -> Backend:
*name*, a ``str``, the name of the backend. Currently the "trio" *name*, a ``str``, the name of the backend. Currently the "trio"
and "asyncio" backends are available. and "asyncio" backends are available.
Raises NotImplementError if an unknown backend name is specified. Raises NotImplementedError if an unknown backend name is specified.
""" """
# pylint: disable=import-outside-toplevel,redefined-outer-name # pylint: disable=import-outside-toplevel,redefined-outer-name
backend = _backends.get(name) backend = _backends.get(name)

View file

@ -41,7 +41,7 @@ from dns.query import (
NoDOQ, NoDOQ,
UDPMode, UDPMode,
_compute_times, _compute_times,
_have_http2, _make_dot_ssl_context,
_matches_destination, _matches_destination,
_remaining, _remaining,
have_doh, have_doh,
@ -120,6 +120,8 @@ async def receive_udp(
request_mac: Optional[bytes] = b"", request_mac: Optional[bytes] = b"",
ignore_trailing: bool = False, ignore_trailing: bool = False,
raise_on_truncation: bool = False, raise_on_truncation: bool = False,
ignore_errors: bool = False,
query: Optional[dns.message.Message] = None,
) -> Any: ) -> Any:
"""Read a DNS message from a UDP socket. """Read a DNS message from a UDP socket.
@ -133,13 +135,14 @@ async def receive_udp(
""" """
wire = b"" wire = b""
while 1: while True:
(wire, from_address) = await sock.recvfrom(65535, _timeout(expiration)) (wire, from_address) = await sock.recvfrom(65535, _timeout(expiration))
if _matches_destination( if not _matches_destination(
sock.family, from_address, destination, ignore_unexpected sock.family, from_address, destination, ignore_unexpected
): ):
break continue
received_time = time.time() received_time = time.time()
try:
r = dns.message.from_wire( r = dns.message.from_wire(
wire, wire,
keyring=keyring, keyring=keyring,
@ -148,6 +151,23 @@ async def receive_udp(
ignore_trailing=ignore_trailing, ignore_trailing=ignore_trailing,
raise_on_truncation=raise_on_truncation, raise_on_truncation=raise_on_truncation,
) )
except dns.message.Truncated as e:
# See the comment in query.py for details.
if (
ignore_errors
and query is not None
and not query.is_response(e.message())
):
continue
else:
raise
except Exception:
if ignore_errors:
continue
else:
raise
if ignore_errors and query is not None and not query.is_response(r):
continue
return (r, received_time, from_address) return (r, received_time, from_address)
@ -164,6 +184,7 @@ async def udp(
raise_on_truncation: bool = False, raise_on_truncation: bool = False,
sock: Optional[dns.asyncbackend.DatagramSocket] = None, sock: Optional[dns.asyncbackend.DatagramSocket] = None,
backend: Optional[dns.asyncbackend.Backend] = None, backend: Optional[dns.asyncbackend.Backend] = None,
ignore_errors: bool = False,
) -> dns.message.Message: ) -> dns.message.Message:
"""Return the response obtained after sending a query via UDP. """Return the response obtained after sending a query via UDP.
@ -205,9 +226,13 @@ async def udp(
q.mac, q.mac,
ignore_trailing, ignore_trailing,
raise_on_truncation, raise_on_truncation,
ignore_errors,
q,
) )
r.time = received_time - begin_time r.time = received_time - begin_time
if not q.is_response(r): # We don't need to check q.is_response() if we are in ignore_errors mode
# as receive_udp() will have checked it.
if not (ignore_errors or q.is_response(r)):
raise BadResponse raise BadResponse
return r return r
@ -225,6 +250,7 @@ async def udp_with_fallback(
udp_sock: Optional[dns.asyncbackend.DatagramSocket] = None, udp_sock: Optional[dns.asyncbackend.DatagramSocket] = None,
tcp_sock: Optional[dns.asyncbackend.StreamSocket] = None, tcp_sock: Optional[dns.asyncbackend.StreamSocket] = None,
backend: Optional[dns.asyncbackend.Backend] = None, backend: Optional[dns.asyncbackend.Backend] = None,
ignore_errors: bool = False,
) -> Tuple[dns.message.Message, bool]: ) -> Tuple[dns.message.Message, bool]:
"""Return the response to the query, trying UDP first and falling back """Return the response to the query, trying UDP first and falling back
to TCP if UDP results in a truncated response. to TCP if UDP results in a truncated response.
@ -260,6 +286,7 @@ async def udp_with_fallback(
True, True,
udp_sock, udp_sock,
backend, backend,
ignore_errors,
) )
return (response, False) return (response, False)
except dns.message.Truncated: except dns.message.Truncated:
@ -292,14 +319,12 @@ async def send_tcp(
""" """
if isinstance(what, dns.message.Message): if isinstance(what, dns.message.Message):
wire = what.to_wire() tcpmsg = what.to_wire(prepend_length=True)
else: else:
wire = what
l = len(wire)
# copying the wire into tcpmsg is inefficient, but lets us # copying the wire into tcpmsg is inefficient, but lets us
# avoid writev() or doing a short write that would get pushed # avoid writev() or doing a short write that would get pushed
# onto the net # onto the net
tcpmsg = struct.pack("!H", l) + wire tcpmsg = len(what).to_bytes(2, "big") + what
sent_time = time.time() sent_time = time.time()
await sock.sendall(tcpmsg, _timeout(expiration, sent_time)) await sock.sendall(tcpmsg, _timeout(expiration, sent_time))
return (len(tcpmsg), sent_time) return (len(tcpmsg), sent_time)
@ -418,6 +443,7 @@ async def tls(
backend: Optional[dns.asyncbackend.Backend] = None, backend: Optional[dns.asyncbackend.Backend] = None,
ssl_context: Optional[ssl.SSLContext] = None, ssl_context: Optional[ssl.SSLContext] = None,
server_hostname: Optional[str] = None, server_hostname: Optional[str] = None,
verify: Union[bool, str] = True,
) -> dns.message.Message: ) -> dns.message.Message:
"""Return the response obtained after sending a query via TLS. """Return the response obtained after sending a query via TLS.
@ -439,11 +465,7 @@ async def tls(
cm: contextlib.AbstractAsyncContextManager = NullContext(sock) cm: contextlib.AbstractAsyncContextManager = NullContext(sock)
else: else:
if ssl_context is None: if ssl_context is None:
# See the comment about ssl.create_default_context() in query.py ssl_context = _make_dot_ssl_context(server_hostname, verify)
ssl_context = ssl.create_default_context() # lgtm[py/insecure-protocol]
ssl_context.minimum_version = ssl.TLSVersion.TLSv1_2
if server_hostname is None:
ssl_context.check_hostname = False
af = dns.inet.af_for_address(where) af = dns.inet.af_for_address(where)
stuple = _source_tuple(af, source, source_port) stuple = _source_tuple(af, source, source_port)
dtuple = (where, port) dtuple = (where, port)
@ -538,7 +560,7 @@ async def https(
transport = backend.get_transport_class()( transport = backend.get_transport_class()(
local_address=local_address, local_address=local_address,
http1=True, http1=True,
http2=_have_http2, http2=True,
verify=verify, verify=verify,
local_port=local_port, local_port=local_port,
bootstrap_address=bootstrap_address, bootstrap_address=bootstrap_address,
@ -550,7 +572,7 @@ async def https(
cm: contextlib.AbstractAsyncContextManager = NullContext(client) cm: contextlib.AbstractAsyncContextManager = NullContext(client)
else: else:
cm = httpx.AsyncClient( cm = httpx.AsyncClient(
http1=True, http2=_have_http2, verify=verify, transport=transport http1=True, http2=True, verify=verify, transport=transport
) )
async with cm as the_client: async with cm as the_client:

View file

@ -27,6 +27,7 @@ import time
from datetime import datetime from datetime import datetime
from typing import Callable, Dict, List, Optional, Set, Tuple, Union, cast from typing import Callable, Dict, List, Optional, Set, Tuple, Union, cast
import dns._features
import dns.exception import dns.exception
import dns.name import dns.name
import dns.node import dns.node
@ -1169,7 +1170,7 @@ def _need_pyca(*args, **kwargs):
) # pragma: no cover ) # pragma: no cover
try: if dns._features.have("dnssec"):
from cryptography.exceptions import InvalidSignature from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.primitives.asymmetric import dsa # pylint: disable=W0611 from cryptography.hazmat.primitives.asymmetric import dsa # pylint: disable=W0611
from cryptography.hazmat.primitives.asymmetric import ec # pylint: disable=W0611 from cryptography.hazmat.primitives.asymmetric import ec # pylint: disable=W0611
@ -1184,20 +1185,20 @@ try:
get_algorithm_cls_from_dnskey, get_algorithm_cls_from_dnskey,
) )
from dns.dnssecalgs.base import GenericPrivateKey, GenericPublicKey from dns.dnssecalgs.base import GenericPrivateKey, GenericPublicKey
except ImportError: # pragma: no cover
validate = _need_pyca
validate_rrsig = _need_pyca
sign = _need_pyca
make_dnskey = _need_pyca
make_cdnskey = _need_pyca
_have_pyca = False
else:
validate = _validate # type: ignore validate = _validate # type: ignore
validate_rrsig = _validate_rrsig # type: ignore validate_rrsig = _validate_rrsig # type: ignore
sign = _sign sign = _sign
make_dnskey = _make_dnskey make_dnskey = _make_dnskey
make_cdnskey = _make_cdnskey make_cdnskey = _make_cdnskey
_have_pyca = True _have_pyca = True
else: # pragma: no cover
validate = _need_pyca
validate_rrsig = _need_pyca
sign = _need_pyca
make_dnskey = _need_pyca
make_cdnskey = _need_pyca
_have_pyca = False
### BEGIN generated Algorithm constants ### BEGIN generated Algorithm constants

View file

@ -1,9 +1,12 @@
from typing import Dict, Optional, Tuple, Type, Union from typing import Dict, Optional, Tuple, Type, Union
import dns.name import dns.name
from dns.dnssecalgs.base import GenericPrivateKey
from dns.dnssectypes import Algorithm
from dns.exception import UnsupportedAlgorithm
from dns.rdtypes.ANY.DNSKEY import DNSKEY
try: if dns._features.have("dnssec"):
from dns.dnssecalgs.base import GenericPrivateKey
from dns.dnssecalgs.dsa import PrivateDSA, PrivateDSANSEC3SHA1 from dns.dnssecalgs.dsa import PrivateDSA, PrivateDSANSEC3SHA1
from dns.dnssecalgs.ecdsa import PrivateECDSAP256SHA256, PrivateECDSAP384SHA384 from dns.dnssecalgs.ecdsa import PrivateECDSAP256SHA256, PrivateECDSAP384SHA384
from dns.dnssecalgs.eddsa import PrivateED448, PrivateED25519 from dns.dnssecalgs.eddsa import PrivateED448, PrivateED25519
@ -16,13 +19,9 @@ try:
) )
_have_cryptography = True _have_cryptography = True
except ImportError: else:
_have_cryptography = False _have_cryptography = False
from dns.dnssectypes import Algorithm
from dns.exception import UnsupportedAlgorithm
from dns.rdtypes.ANY.DNSKEY import DNSKEY
AlgorithmPrefix = Optional[Union[bytes, dns.name.Name]] AlgorithmPrefix = Optional[Union[bytes, dns.name.Name]]
algorithms: Dict[Tuple[Algorithm, AlgorithmPrefix], Type[GenericPrivateKey]] = {} algorithms: Dict[Tuple[Algorithm, AlgorithmPrefix], Type[GenericPrivateKey]] = {}

View file

@ -17,6 +17,7 @@
"""EDNS Options""" """EDNS Options"""
import binascii
import math import math
import socket import socket
import struct import struct
@ -58,7 +59,6 @@ class OptionType(dns.enum.IntEnum):
class Option: class Option:
"""Base class for all EDNS option types.""" """Base class for all EDNS option types."""
def __init__(self, otype: Union[OptionType, str]): def __init__(self, otype: Union[OptionType, str]):
@ -76,6 +76,9 @@ class Option:
""" """
raise NotImplementedError # pragma: no cover raise NotImplementedError # pragma: no cover
def to_text(self) -> str:
raise NotImplementedError # pragma: no cover
@classmethod @classmethod
def from_wire_parser(cls, otype: OptionType, parser: "dns.wire.Parser") -> "Option": def from_wire_parser(cls, otype: OptionType, parser: "dns.wire.Parser") -> "Option":
"""Build an EDNS option object from wire format. """Build an EDNS option object from wire format.
@ -141,7 +144,6 @@ class Option:
class GenericOption(Option): # lgtm[py/missing-equals] class GenericOption(Option): # lgtm[py/missing-equals]
"""Generic Option Class """Generic Option Class
This class is used for EDNS option types for which we have no better This class is used for EDNS option types for which we have no better
@ -343,6 +345,8 @@ class EDECode(dns.enum.IntEnum):
class EDEOption(Option): # lgtm[py/missing-equals] class EDEOption(Option): # lgtm[py/missing-equals]
"""Extended DNS Error (EDE, RFC8914)""" """Extended DNS Error (EDE, RFC8914)"""
_preserve_case = {"DNSKEY", "DS", "DNSSEC", "RRSIGs", "NSEC", "NXDOMAIN"}
def __init__(self, code: Union[EDECode, str], text: Optional[str] = None): def __init__(self, code: Union[EDECode, str], text: Optional[str] = None):
"""*code*, a ``dns.edns.EDECode`` or ``str``, the info code of the """*code*, a ``dns.edns.EDECode`` or ``str``, the info code of the
extended error. extended error.
@ -360,6 +364,13 @@ class EDEOption(Option): # lgtm[py/missing-equals]
def to_text(self) -> str: def to_text(self) -> str:
output = f"EDE {self.code}" output = f"EDE {self.code}"
if self.code in EDECode:
desc = EDECode.to_text(self.code)
desc = " ".join(
word if word in self._preserve_case else word.title()
for word in desc.split("_")
)
output += f" ({desc})"
if self.text is not None: if self.text is not None:
output += f": {self.text}" output += f": {self.text}"
return output return output
@ -392,9 +403,37 @@ class EDEOption(Option): # lgtm[py/missing-equals]
return cls(code, btext) return cls(code, btext)
class NSIDOption(Option):
def __init__(self, nsid: bytes):
super().__init__(OptionType.NSID)
self.nsid = nsid
def to_wire(self, file: Any = None) -> Optional[bytes]:
if file:
file.write(self.nsid)
return None
else:
return self.nsid
def to_text(self) -> str:
if all(c >= 0x20 and c <= 0x7E for c in self.nsid):
# All ASCII printable, so it's probably a string.
value = self.nsid.decode()
else:
value = binascii.hexlify(self.nsid).decode()
return f"NSID {value}"
@classmethod
def from_wire_parser(
cls, otype: Union[OptionType, str], parser: dns.wire.Parser
) -> Option:
return cls(parser.get_remaining())
_type_to_class: Dict[OptionType, Any] = { _type_to_class: Dict[OptionType, Any] = {
OptionType.ECS: ECSOption, OptionType.ECS: ECSOption,
OptionType.EDE: EDEOption, OptionType.EDE: EDEOption,
OptionType.NSID: NSIDOption,
} }

View file

@ -1,24 +1,30 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
import collections.abc import collections.abc
from typing import Any from typing import Any, Callable
from dns._immutable_ctx import immutable from dns._immutable_ctx import immutable
@immutable @immutable
class Dict(collections.abc.Mapping): # lgtm[py/missing-equals] class Dict(collections.abc.Mapping): # lgtm[py/missing-equals]
def __init__(self, dictionary: Any, no_copy: bool = False): def __init__(
self,
dictionary: Any,
no_copy: bool = False,
map_factory: Callable[[], collections.abc.MutableMapping] = dict,
):
"""Make an immutable dictionary from the specified dictionary. """Make an immutable dictionary from the specified dictionary.
If *no_copy* is `True`, then *dictionary* will be wrapped instead If *no_copy* is `True`, then *dictionary* will be wrapped instead
of copied. Only set this if you are sure there will be no external of copied. Only set this if you are sure there will be no external
references to the dictionary. references to the dictionary.
""" """
if no_copy and isinstance(dictionary, dict): if no_copy and isinstance(dictionary, collections.abc.MutableMapping):
self._odict = dictionary self._odict = dictionary
else: else:
self._odict = dict(dictionary) self._odict = map_factory()
self._odict.update(dictionary)
self._hash = None self._hash = None
def __getitem__(self, key): def __getitem__(self, key):

View file

@ -178,3 +178,20 @@ def any_for_af(af):
elif af == socket.AF_INET6: elif af == socket.AF_INET6:
return "::" return "::"
raise NotImplementedError(f"unknown address family {af}") raise NotImplementedError(f"unknown address family {af}")
def canonicalize(text: str) -> str:
"""Verify that *address* is a valid text form IPv4 or IPv6 address and return its
canonical text form. IPv6 addresses with scopes are rejected.
*text*, a ``str``, the address in textual form.
Raises ``ValueError`` if the text is not valid.
"""
try:
return dns.ipv6.canonicalize(text)
except Exception:
try:
return dns.ipv4.canonicalize(text)
except Exception:
raise ValueError

View file

@ -62,3 +62,16 @@ def inet_aton(text: Union[str, bytes]) -> bytes:
return struct.pack("BBBB", *b) return struct.pack("BBBB", *b)
except Exception: except Exception:
raise dns.exception.SyntaxError raise dns.exception.SyntaxError
def canonicalize(text: Union[str, bytes]) -> str:
"""Verify that *address* is a valid text form IPv4 address and return its
canonical text form.
*text*, a ``str`` or ``bytes``, the IPv4 address in textual form.
Raises ``dns.exception.SyntaxError`` if the text is not valid.
"""
# Note that inet_aton() only accepts canonial form, but we still run through
# inet_ntoa() to ensure the output is a str.
return dns.ipv4.inet_ntoa(dns.ipv4.inet_aton(text))

View file

@ -104,7 +104,7 @@ _colon_colon_end = re.compile(rb".*::$")
def inet_aton(text: Union[str, bytes], ignore_scope: bool = False) -> bytes: def inet_aton(text: Union[str, bytes], ignore_scope: bool = False) -> bytes:
"""Convert an IPv6 address in text form to binary form. """Convert an IPv6 address in text form to binary form.
*text*, a ``str``, the IPv6 address in textual form. *text*, a ``str`` or ``bytes``, the IPv6 address in textual form.
*ignore_scope*, a ``bool``. If ``True``, a scope will be ignored. *ignore_scope*, a ``bool``. If ``True``, a scope will be ignored.
If ``False``, the default, it is an error for a scope to be present. If ``False``, the default, it is an error for a scope to be present.
@ -206,3 +206,14 @@ def is_mapped(address: bytes) -> bool:
""" """
return address.startswith(_mapped_prefix) return address.startswith(_mapped_prefix)
def canonicalize(text: Union[str, bytes]) -> str:
"""Verify that *address* is a valid text form IPv6 address and return its
canonical text form. Addresses with scopes are rejected.
*text*, a ``str`` or ``bytes``, the IPv6 address in textual form.
Raises ``dns.exception.SyntaxError`` if the text is not valid.
"""
return dns.ipv6.inet_ntoa(dns.ipv6.inet_aton(text))

View file

@ -393,7 +393,7 @@ class Message:
section_number = section section_number = section
section = self.section_from_number(section_number) section = self.section_from_number(section_number)
elif isinstance(section, str): elif isinstance(section, str):
section_number = MessageSection.from_text(section) section_number = self._section_enum.from_text(section)
section = self.section_from_number(section_number) section = self.section_from_number(section_number)
else: else:
section_number = self.section_number(section) section_number = self.section_number(section)
@ -489,6 +489,34 @@ class Message:
rrset = None rrset = None
return rrset return rrset
def section_count(self, section: SectionType) -> int:
"""Returns the number of records in the specified section.
*section*, an ``int`` section number, a ``str`` section name, or one of
the section attributes of this message. This specifies the
the section of the message to count. For example::
my_message.section_count(my_message.answer)
my_message.section_count(dns.message.ANSWER)
my_message.section_count("ANSWER")
"""
if isinstance(section, int):
section_number = section
section = self.section_from_number(section_number)
elif isinstance(section, str):
section_number = self._section_enum.from_text(section)
section = self.section_from_number(section_number)
else:
section_number = self.section_number(section)
count = sum(max(1, len(rrs)) for rrs in section)
if section_number == MessageSection.ADDITIONAL:
if self.opt is not None:
count += 1
if self.tsig is not None:
count += 1
return count
def _compute_opt_reserve(self) -> int: def _compute_opt_reserve(self) -> int:
"""Compute the size required for the OPT RR, padding excluded""" """Compute the size required for the OPT RR, padding excluded"""
if not self.opt: if not self.opt:
@ -527,6 +555,8 @@ class Message:
max_size: int = 0, max_size: int = 0,
multi: bool = False, multi: bool = False,
tsig_ctx: Optional[Any] = None, tsig_ctx: Optional[Any] = None,
prepend_length: bool = False,
prefer_truncation: bool = False,
**kw: Dict[str, Any], **kw: Dict[str, Any],
) -> bytes: ) -> bytes:
"""Return a string containing the message in DNS compressed wire """Return a string containing the message in DNS compressed wire
@ -549,6 +579,15 @@ class Message:
*tsig_ctx*, a ``dns.tsig.HMACTSig`` or ``dns.tsig.GSSTSig`` object, the *tsig_ctx*, a ``dns.tsig.HMACTSig`` or ``dns.tsig.GSSTSig`` object, the
ongoing TSIG context, used when signing zone transfers. ongoing TSIG context, used when signing zone transfers.
*prepend_length*, a ``bool``, should be set to ``True`` if the caller
wants the message length prepended to the message itself. This is
useful for messages sent over TCP, TLS (DoT), or QUIC (DoQ).
*prefer_truncation*, a ``bool``, should be set to ``True`` if the caller
wants the message to be truncated if it would otherwise exceed the
maximum length. If the truncation occurs before the additional section,
the TC bit will be set.
Raises ``dns.exception.TooBig`` if *max_size* was exceeded. Raises ``dns.exception.TooBig`` if *max_size* was exceeded.
Returns a ``bytes``. Returns a ``bytes``.
@ -570,6 +609,7 @@ class Message:
r.reserve(opt_reserve) r.reserve(opt_reserve)
tsig_reserve = self._compute_tsig_reserve() tsig_reserve = self._compute_tsig_reserve()
r.reserve(tsig_reserve) r.reserve(tsig_reserve)
try:
for rrset in self.question: for rrset in self.question:
r.add_question(rrset.name, rrset.rdtype, rrset.rdclass) r.add_question(rrset.name, rrset.rdtype, rrset.rdclass)
for rrset in self.answer: for rrset in self.answer:
@ -578,6 +618,12 @@ class Message:
r.add_rrset(dns.renderer.AUTHORITY, rrset, **kw) r.add_rrset(dns.renderer.AUTHORITY, rrset, **kw)
for rrset in self.additional: for rrset in self.additional:
r.add_rrset(dns.renderer.ADDITIONAL, rrset, **kw) r.add_rrset(dns.renderer.ADDITIONAL, rrset, **kw)
except dns.exception.TooBig:
if prefer_truncation:
if r.section < dns.renderer.ADDITIONAL:
r.flags |= dns.flags.TC
else:
raise
r.release_reserved() r.release_reserved()
if self.opt is not None: if self.opt is not None:
r.add_opt(self.opt, self.pad, opt_reserve, tsig_reserve) r.add_opt(self.opt, self.pad, opt_reserve, tsig_reserve)
@ -598,7 +644,10 @@ class Message:
r.write_header() r.write_header()
if multi: if multi:
self.tsig_ctx = ctx self.tsig_ctx = ctx
return r.get_wire() wire = r.get_wire()
if prepend_length:
wire = len(wire).to_bytes(2, "big") + wire
return wire
@staticmethod @staticmethod
def _make_tsig( def _make_tsig(
@ -777,6 +826,8 @@ class Message:
if request_payload is None: if request_payload is None:
request_payload = payload request_payload = payload
self.request_payload = request_payload self.request_payload = request_payload
if pad < 0:
raise ValueError("pad must be non-negative")
self.pad = pad self.pad = pad
@property @property
@ -826,7 +877,7 @@ class Message:
if wanted: if wanted:
self.ednsflags |= dns.flags.DO self.ednsflags |= dns.flags.DO
elif self.opt: elif self.opt:
self.ednsflags &= ~dns.flags.DO self.ednsflags &= ~int(dns.flags.DO)
def rcode(self) -> dns.rcode.Rcode: def rcode(self) -> dns.rcode.Rcode:
"""Return the rcode. """Return the rcode.
@ -1035,7 +1086,6 @@ def _message_factory_from_opcode(opcode):
class _WireReader: class _WireReader:
"""Wire format reader. """Wire format reader.
parser: the binary parser parser: the binary parser
@ -1335,7 +1385,6 @@ def from_wire(
class _TextReader: class _TextReader:
"""Text format reader. """Text format reader.
tok: the tokenizer. tok: the tokenizer.
@ -1768,30 +1817,34 @@ def make_response(
our_payload: int = 8192, our_payload: int = 8192,
fudge: int = 300, fudge: int = 300,
tsig_error: int = 0, tsig_error: int = 0,
pad: Optional[int] = None,
) -> Message: ) -> Message:
"""Make a message which is a response for the specified query. """Make a message which is a response for the specified query.
The message returned is really a response skeleton; it has all The message returned is really a response skeleton; it has all of the infrastructure
of the infrastructure required of a response, but none of the required of a response, but none of the content.
content.
The response's question section is a shallow copy of the query's The response's question section is a shallow copy of the query's question section,
question section, so the query's question RRsets should not be so the query's question RRsets should not be changed.
changed.
*query*, a ``dns.message.Message``, the query to respond to. *query*, a ``dns.message.Message``, the query to respond to.
*recursion_available*, a ``bool``, should RA be set in the response? *recursion_available*, a ``bool``, should RA be set in the response?
*our_payload*, an ``int``, the payload size to advertise in EDNS *our_payload*, an ``int``, the payload size to advertise in EDNS responses.
responses.
*fudge*, an ``int``, the TSIG time fudge. *fudge*, an ``int``, the TSIG time fudge.
*tsig_error*, an ``int``, the TSIG error. *tsig_error*, an ``int``, the TSIG error.
Returns a ``dns.message.Message`` object whose specific class is *pad*, a non-negative ``int`` or ``None``. If 0, the default, do not pad; otherwise
appropriate for the query. For example, if query is a if not ``None`` add padding bytes to make the message size a multiple of *pad*.
``dns.update.UpdateMessage``, response will be too. Note that if padding is non-zero, an EDNS PADDING option will always be added to the
message. If ``None``, add padding following RFC 8467, namely if the request is
padded, pad the response to 468 otherwise do not pad.
Returns a ``dns.message.Message`` object whose specific class is appropriate for the
query. For example, if query is a ``dns.update.UpdateMessage``, response will be
too.
""" """
if query.flags & dns.flags.QR: if query.flags & dns.flags.QR:
@ -1804,7 +1857,13 @@ def make_response(
response.set_opcode(query.opcode()) response.set_opcode(query.opcode())
response.question = list(query.question) response.question = list(query.question)
if query.edns >= 0: if query.edns >= 0:
response.use_edns(0, 0, our_payload, query.payload) if pad is None:
# Set response padding per RFC 8467
pad = 0
for option in query.options:
if option.otype == dns.edns.OptionType.PADDING:
pad = 468
response.use_edns(0, 0, our_payload, query.payload, pad=pad)
if query.had_tsig: if query.had_tsig:
response.use_tsig( response.use_tsig(
query.keyring, query.keyring,

View file

@ -20,21 +20,23 @@
import copy import copy
import encodings.idna # type: ignore import encodings.idna # type: ignore
import functools
import struct import struct
from typing import Any, Dict, Iterable, Optional, Tuple, Union from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union
try:
import idna # type: ignore
have_idna_2008 = True
except ImportError: # pragma: no cover
have_idna_2008 = False
import dns._features
import dns.enum import dns.enum
import dns.exception import dns.exception
import dns.immutable import dns.immutable
import dns.wire import dns.wire
if dns._features.have("idna"):
import idna # type: ignore
have_idna_2008 = True
else: # pragma: no cover
have_idna_2008 = False
CompressType = Dict["Name", int] CompressType = Dict["Name", int]
@ -128,6 +130,10 @@ class IDNAException(dns.exception.DNSException):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
class NeedSubdomainOfOrigin(dns.exception.DNSException):
"""An absolute name was provided that is not a subdomain of the specified origin."""
_escaped = b'"().;\\@$' _escaped = b'"().;\\@$'
_escaped_text = '"().;\\@$' _escaped_text = '"().;\\@$'
@ -350,7 +356,6 @@ def _maybe_convert_to_binary(label: Union[bytes, str]) -> bytes:
@dns.immutable.immutable @dns.immutable.immutable
class Name: class Name:
"""A DNS name. """A DNS name.
The dns.name.Name class represents a DNS name as a tuple of The dns.name.Name class represents a DNS name as a tuple of
@ -843,6 +848,42 @@ class Name:
raise NoParent raise NoParent
return Name(self.labels[1:]) return Name(self.labels[1:])
def predecessor(self, origin: "Name", prefix_ok: bool = True) -> "Name":
"""Return the maximal predecessor of *name* in the DNSSEC ordering in the zone
whose origin is *origin*, or return the longest name under *origin* if the
name is origin (i.e. wrap around to the longest name, which may still be
*origin* due to length considerations.
The relativity of the name is preserved, so if this name is relative
then the method will return a relative name, and likewise if this name
is absolute then the predecessor will be absolute.
*prefix_ok* indicates if prefixing labels is allowed, and
defaults to ``True``. Normally it is good to allow this, but if computing
a maximal predecessor at a zone cut point then ``False`` must be specified.
"""
return _handle_relativity_and_call(
_absolute_predecessor, self, origin, prefix_ok
)
def successor(self, origin: "Name", prefix_ok: bool = True) -> "Name":
"""Return the minimal successor of *name* in the DNSSEC ordering in the zone
whose origin is *origin*, or return *origin* if the successor cannot be
computed due to name length limitations.
Note that *origin* is returned in the "too long" cases because wrapping
around to the origin is how NSEC records express "end of the zone".
The relativity of the name is preserved, so if this name is relative
then the method will return a relative name, and likewise if this name
is absolute then the successor will be absolute.
*prefix_ok* indicates if prefixing a new minimal label is allowed, and
defaults to ``True``. Normally it is good to allow this, but if computing
a minimal successor at a zone cut point then ``False`` must be specified.
"""
return _handle_relativity_and_call(_absolute_successor, self, origin, prefix_ok)
#: The root name, '.' #: The root name, '.'
root = Name([b""]) root = Name([b""])
@ -1082,3 +1123,161 @@ def from_wire(message: bytes, current: int) -> Tuple[Name, int]:
parser = dns.wire.Parser(message, current) parser = dns.wire.Parser(message, current)
name = from_wire_parser(parser) name = from_wire_parser(parser)
return (name, parser.current - current) return (name, parser.current - current)
# RFC 4471 Support
_MINIMAL_OCTET = b"\x00"
_MINIMAL_OCTET_VALUE = ord(_MINIMAL_OCTET)
_SUCCESSOR_PREFIX = Name([_MINIMAL_OCTET])
_MAXIMAL_OCTET = b"\xff"
_MAXIMAL_OCTET_VALUE = ord(_MAXIMAL_OCTET)
_AT_SIGN_VALUE = ord("@")
_LEFT_SQUARE_BRACKET_VALUE = ord("[")
def _wire_length(labels):
return functools.reduce(lambda v, x: v + len(x) + 1, labels, 0)
def _pad_to_max_name(name):
needed = 255 - _wire_length(name.labels)
new_labels = []
while needed > 64:
new_labels.append(_MAXIMAL_OCTET * 63)
needed -= 64
if needed >= 2:
new_labels.append(_MAXIMAL_OCTET * (needed - 1))
# Note we're already maximal in the needed == 1 case as while we'd like
# to add one more byte as a new label, we can't, as adding a new non-empty
# label requires at least 2 bytes.
new_labels = list(reversed(new_labels))
new_labels.extend(name.labels)
return Name(new_labels)
def _pad_to_max_label(label, suffix_labels):
length = len(label)
# We have to subtract one here to account for the length byte of label.
remaining = 255 - _wire_length(suffix_labels) - length - 1
if remaining <= 0:
# Shouldn't happen!
return label
needed = min(63 - length, remaining)
return label + _MAXIMAL_OCTET * needed
def _absolute_predecessor(name: Name, origin: Name, prefix_ok: bool) -> Name:
# This is the RFC 4471 predecessor algorithm using the "absolute method" of section
# 3.1.1.
#
# Our caller must ensure that the name and origin are absolute, and that name is a
# subdomain of origin.
if name == origin:
return _pad_to_max_name(name)
least_significant_label = name[0]
if least_significant_label == _MINIMAL_OCTET:
return name.parent()
least_octet = least_significant_label[-1]
suffix_labels = name.labels[1:]
if least_octet == _MINIMAL_OCTET_VALUE:
new_labels = [least_significant_label[:-1]]
else:
octets = bytearray(least_significant_label)
octet = octets[-1]
if octet == _LEFT_SQUARE_BRACKET_VALUE:
octet = _AT_SIGN_VALUE
else:
octet -= 1
octets[-1] = octet
least_significant_label = bytes(octets)
new_labels = [_pad_to_max_label(least_significant_label, suffix_labels)]
new_labels.extend(suffix_labels)
name = Name(new_labels)
if prefix_ok:
return _pad_to_max_name(name)
else:
return name
def _absolute_successor(name: Name, origin: Name, prefix_ok: bool) -> Name:
# This is the RFC 4471 successor algorithm using the "absolute method" of section
# 3.1.2.
#
# Our caller must ensure that the name and origin are absolute, and that name is a
# subdomain of origin.
if prefix_ok:
# Try prefixing \000 as new label
try:
return _SUCCESSOR_PREFIX.concatenate(name)
except NameTooLong:
pass
while name != origin:
# Try extending the least significant label.
least_significant_label = name[0]
if len(least_significant_label) < 63:
# We may be able to extend the least label with a minimal additional byte.
# This is only "may" because we could have a maximal length name even though
# the least significant label isn't maximally long.
new_labels = [least_significant_label + _MINIMAL_OCTET]
new_labels.extend(name.labels[1:])
try:
return dns.name.Name(new_labels)
except dns.name.NameTooLong:
pass
# We can't extend the label either, so we'll try to increment the least
# signficant non-maximal byte in it.
octets = bytearray(least_significant_label)
# We do this reversed iteration with an explicit indexing variable because
# if we find something to increment, we're going to want to truncate everything
# to the right of it.
for i in range(len(octets) - 1, -1, -1):
octet = octets[i]
if octet == _MAXIMAL_OCTET_VALUE:
# We can't increment this, so keep looking.
continue
# Finally, something we can increment. We have to apply a special rule for
# incrementing "@", sending it to "[", because RFC 4034 6.1 says that when
# comparing names, uppercase letters compare as if they were their
# lower-case equivalents. If we increment "@" to "A", then it would compare
# as "a", which is after "[", "\", "]", "^", "_", and "`", so we would have
# skipped the most minimal successor, namely "[".
if octet == _AT_SIGN_VALUE:
octet = _LEFT_SQUARE_BRACKET_VALUE
else:
octet += 1
octets[i] = octet
# We can now truncate all of the maximal values we skipped (if any)
new_labels = [bytes(octets[: i + 1])]
new_labels.extend(name.labels[1:])
# We haven't changed the length of the name, so the Name constructor will
# always work.
return Name(new_labels)
# We couldn't increment, so chop off the least significant label and try
# again.
name = name.parent()
# We couldn't increment at all, so return the origin, as wrapping around is the
# DNSSEC way.
return origin
def _handle_relativity_and_call(
function: Callable[[Name, Name, bool], Name],
name: Name,
origin: Name,
prefix_ok: bool,
) -> Name:
# Make "name" absolute if needed, ensure that the origin is absolute,
# call function(), and then relativize the result if needed.
if not origin.is_absolute():
raise NeedAbsoluteNameOrOrigin
relative = not name.is_absolute()
if relative:
name = name.derelativize(origin)
elif not name.is_subdomain(origin):
raise NeedSubdomainOfOrigin
result_name = function(name, origin, prefix_ok)
if relative:
result_name = result_name.relativize(origin)
return result_name

View file

@ -115,6 +115,8 @@ class Do53Nameserver(AddressAndPortNameserver):
raise_on_truncation=True, raise_on_truncation=True,
one_rr_per_rrset=one_rr_per_rrset, one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing, ignore_trailing=ignore_trailing,
ignore_errors=True,
ignore_unexpected=True,
) )
return response return response
@ -153,15 +155,25 @@ class Do53Nameserver(AddressAndPortNameserver):
backend=backend, backend=backend,
one_rr_per_rrset=one_rr_per_rrset, one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing, ignore_trailing=ignore_trailing,
ignore_errors=True,
ignore_unexpected=True,
) )
return response return response
class DoHNameserver(Nameserver): class DoHNameserver(Nameserver):
def __init__(self, url: str, bootstrap_address: Optional[str] = None): def __init__(
self,
url: str,
bootstrap_address: Optional[str] = None,
verify: Union[bool, str] = True,
want_get: bool = False,
):
super().__init__() super().__init__()
self.url = url self.url = url
self.bootstrap_address = bootstrap_address self.bootstrap_address = bootstrap_address
self.verify = verify
self.want_get = want_get
def kind(self): def kind(self):
return "DoH" return "DoH"
@ -195,9 +207,13 @@ class DoHNameserver(Nameserver):
request, request,
self.url, self.url,
timeout=timeout, timeout=timeout,
source=source,
source_port=source_port,
bootstrap_address=self.bootstrap_address, bootstrap_address=self.bootstrap_address,
one_rr_per_rrset=one_rr_per_rrset, one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing, ignore_trailing=ignore_trailing,
verify=self.verify,
post=(not self.want_get),
) )
async def async_query( async def async_query(
@ -215,15 +231,27 @@ class DoHNameserver(Nameserver):
request, request,
self.url, self.url,
timeout=timeout, timeout=timeout,
source=source,
source_port=source_port,
bootstrap_address=self.bootstrap_address,
one_rr_per_rrset=one_rr_per_rrset, one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing, ignore_trailing=ignore_trailing,
verify=self.verify,
post=(not self.want_get),
) )
class DoTNameserver(AddressAndPortNameserver): class DoTNameserver(AddressAndPortNameserver):
def __init__(self, address: str, port: int = 853, hostname: Optional[str] = None): def __init__(
self,
address: str,
port: int = 853,
hostname: Optional[str] = None,
verify: Union[bool, str] = True,
):
super().__init__(address, port) super().__init__(address, port)
self.hostname = hostname self.hostname = hostname
self.verify = verify
def kind(self): def kind(self):
return "DoT" return "DoT"
@ -246,6 +274,7 @@ class DoTNameserver(AddressAndPortNameserver):
one_rr_per_rrset=one_rr_per_rrset, one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing, ignore_trailing=ignore_trailing,
server_hostname=self.hostname, server_hostname=self.hostname,
verify=self.verify,
) )
async def async_query( async def async_query(
@ -267,6 +296,7 @@ class DoTNameserver(AddressAndPortNameserver):
one_rr_per_rrset=one_rr_per_rrset, one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing, ignore_trailing=ignore_trailing,
server_hostname=self.hostname, server_hostname=self.hostname,
verify=self.verify,
) )

View file

@ -70,7 +70,6 @@ class NodeKind(enum.Enum):
class Node: class Node:
"""A Node is a set of rdatasets. """A Node is a set of rdatasets.
A node is either a CNAME node or an "other data" node. A CNAME A node is either a CNAME node or an "other data" node. A CNAME

View file

@ -22,12 +22,14 @@ import contextlib
import enum import enum
import errno import errno
import os import os
import os.path
import selectors import selectors
import socket import socket
import struct import struct
import time import time
from typing import Any, Dict, Optional, Tuple, Union from typing import Any, Dict, Optional, Tuple, Union
import dns._features
import dns.exception import dns.exception
import dns.inet import dns.inet
import dns.message import dns.message
@ -57,24 +59,14 @@ def _expiration_for_this_attempt(timeout, expiration):
return min(time.time() + timeout, expiration) return min(time.time() + timeout, expiration)
_have_httpx = False _have_httpx = dns._features.have("doh")
_have_http2 = False if _have_httpx:
try:
import httpcore
import httpcore._backends.sync import httpcore._backends.sync
import httpx import httpx
_CoreNetworkBackend = httpcore.NetworkBackend _CoreNetworkBackend = httpcore.NetworkBackend
_CoreSyncStream = httpcore._backends.sync.SyncStream _CoreSyncStream = httpcore._backends.sync.SyncStream
_have_httpx = True
try:
# See if http2 support is available.
with httpx.Client(http2=True):
_have_http2 = True
except Exception:
pass
class _NetworkBackend(_CoreNetworkBackend): class _NetworkBackend(_CoreNetworkBackend):
def __init__(self, resolver, local_port, bootstrap_address, family): def __init__(self, resolver, local_port, bootstrap_address, family):
super().__init__() super().__init__()
@ -147,7 +139,7 @@ try:
resolver, local_port, bootstrap_address, family resolver, local_port, bootstrap_address, family
) )
except ImportError: # pragma: no cover else:
class _HTTPTransport: # type: ignore class _HTTPTransport: # type: ignore
def connect_tcp(self, host, port, timeout, local_address): def connect_tcp(self, host, port, timeout, local_address):
@ -161,6 +153,8 @@ try:
except ImportError: # pragma: no cover except ImportError: # pragma: no cover
class ssl: # type: ignore class ssl: # type: ignore
CERT_NONE = 0
class WantReadException(Exception): class WantReadException(Exception):
pass pass
@ -459,7 +453,7 @@ def https(
transport = _HTTPTransport( transport = _HTTPTransport(
local_address=local_address, local_address=local_address,
http1=True, http1=True,
http2=_have_http2, http2=True,
verify=verify, verify=verify,
local_port=local_port, local_port=local_port,
bootstrap_address=bootstrap_address, bootstrap_address=bootstrap_address,
@ -470,9 +464,7 @@ def https(
if session: if session:
cm: contextlib.AbstractContextManager = contextlib.nullcontext(session) cm: contextlib.AbstractContextManager = contextlib.nullcontext(session)
else: else:
cm = httpx.Client( cm = httpx.Client(http1=True, http2=True, verify=verify, transport=transport)
http1=True, http2=_have_http2, verify=verify, transport=transport
)
with cm as session: with cm as session:
# see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH # see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH
# GET and POST examples # GET and POST examples
@ -577,6 +569,8 @@ def receive_udp(
request_mac: Optional[bytes] = b"", request_mac: Optional[bytes] = b"",
ignore_trailing: bool = False, ignore_trailing: bool = False,
raise_on_truncation: bool = False, raise_on_truncation: bool = False,
ignore_errors: bool = False,
query: Optional[dns.message.Message] = None,
) -> Any: ) -> Any:
"""Read a DNS message from a UDP socket. """Read a DNS message from a UDP socket.
@ -617,16 +611,25 @@ def receive_udp(
``(dns.message.Message, float, tuple)`` ``(dns.message.Message, float, tuple)``
tuple of the received message, the received time, and the address where tuple of the received message, the received time, and the address where
the message arrived from. the message arrived from.
*ignore_errors*, a ``bool``. If various format errors or response
mismatches occur, ignore them and keep listening for a valid response.
The default is ``False``.
*query*, a ``dns.message.Message`` or ``None``. If not ``None`` and
*ignore_errors* is ``True``, check that the received message is a response
to this query, and if not keep listening for a valid response.
""" """
wire = b"" wire = b""
while True: while True:
(wire, from_address) = _udp_recv(sock, 65535, expiration) (wire, from_address) = _udp_recv(sock, 65535, expiration)
if _matches_destination( if not _matches_destination(
sock.family, from_address, destination, ignore_unexpected sock.family, from_address, destination, ignore_unexpected
): ):
break continue
received_time = time.time() received_time = time.time()
try:
r = dns.message.from_wire( r = dns.message.from_wire(
wire, wire,
keyring=keyring, keyring=keyring,
@ -635,6 +638,27 @@ def receive_udp(
ignore_trailing=ignore_trailing, ignore_trailing=ignore_trailing,
raise_on_truncation=raise_on_truncation, raise_on_truncation=raise_on_truncation,
) )
except dns.message.Truncated as e:
# If we got Truncated and not FORMERR, we at least got the header with TC
# set, and very likely the question section, so we'll re-raise if the
# message seems to be a response as we need to know when truncation happens.
# We need to check that it seems to be a response as we don't want a random
# injected message with TC set to cause us to bail out.
if (
ignore_errors
and query is not None
and not query.is_response(e.message())
):
continue
else:
raise
except Exception:
if ignore_errors:
continue
else:
raise
if ignore_errors and query is not None and not query.is_response(r):
continue
if destination: if destination:
return (r, received_time) return (r, received_time)
else: else:
@ -653,6 +677,7 @@ def udp(
ignore_trailing: bool = False, ignore_trailing: bool = False,
raise_on_truncation: bool = False, raise_on_truncation: bool = False,
sock: Optional[Any] = None, sock: Optional[Any] = None,
ignore_errors: bool = False,
) -> dns.message.Message: ) -> dns.message.Message:
"""Return the response obtained after sending a query via UDP. """Return the response obtained after sending a query via UDP.
@ -689,6 +714,10 @@ def udp(
if a socket is provided, it must be a nonblocking datagram socket, if a socket is provided, it must be a nonblocking datagram socket,
and the *source* and *source_port* are ignored. and the *source* and *source_port* are ignored.
*ignore_errors*, a ``bool``. If various format errors or response
mismatches occur, ignore them and keep listening for a valid response.
The default is ``False``.
Returns a ``dns.message.Message``. Returns a ``dns.message.Message``.
""" """
@ -713,9 +742,13 @@ def udp(
q.mac, q.mac,
ignore_trailing, ignore_trailing,
raise_on_truncation, raise_on_truncation,
ignore_errors,
q,
) )
r.time = received_time - begin_time r.time = received_time - begin_time
if not q.is_response(r): # We don't need to check q.is_response() if we are in ignore_errors mode
# as receive_udp() will have checked it.
if not (ignore_errors or q.is_response(r)):
raise BadResponse raise BadResponse
return r return r
assert ( assert (
@ -735,48 +768,50 @@ def udp_with_fallback(
ignore_trailing: bool = False, ignore_trailing: bool = False,
udp_sock: Optional[Any] = None, udp_sock: Optional[Any] = None,
tcp_sock: Optional[Any] = None, tcp_sock: Optional[Any] = None,
ignore_errors: bool = False,
) -> Tuple[dns.message.Message, bool]: ) -> Tuple[dns.message.Message, bool]:
"""Return the response to the query, trying UDP first and falling back """Return the response to the query, trying UDP first and falling back
to TCP if UDP results in a truncated response. to TCP if UDP results in a truncated response.
*q*, a ``dns.message.Message``, the query to send *q*, a ``dns.message.Message``, the query to send
*where*, a ``str`` containing an IPv4 or IPv6 address, where *where*, a ``str`` containing an IPv4 or IPv6 address, where to send the message.
to send the message.
*timeout*, a ``float`` or ``None``, the number of seconds to wait before the *timeout*, a ``float`` or ``None``, the number of seconds to wait before the query
query times out. If ``None``, the default, wait forever. times out. If ``None``, the default, wait forever.
*port*, an ``int``, the port send the message to. The default is 53. *port*, an ``int``, the port send the message to. The default is 53.
*source*, a ``str`` containing an IPv4 or IPv6 address, specifying *source*, a ``str`` containing an IPv4 or IPv6 address, specifying the source
the source address. The default is the wildcard address. address. The default is the wildcard address.
*source_port*, an ``int``, the port from which to send the message. *source_port*, an ``int``, the port from which to send the message. The default is
The default is 0. 0.
*ignore_unexpected*, a ``bool``. If ``True``, ignore responses from *ignore_unexpected*, a ``bool``. If ``True``, ignore responses from unexpected
unexpected sources. sources.
*one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own RRset.
RRset.
*ignore_trailing*, a ``bool``. If ``True``, ignore trailing *ignore_trailing*, a ``bool``. If ``True``, ignore trailing junk at end of the
junk at end of the received message. received message.
*udp_sock*, a ``socket.socket``, or ``None``, the socket to use for the *udp_sock*, a ``socket.socket``, or ``None``, the socket to use for the UDP query.
UDP query. If ``None``, the default, a socket is created. Note that If ``None``, the default, a socket is created. Note that if a socket is provided,
if a socket is provided, it must be a nonblocking datagram socket, it must be a nonblocking datagram socket, and the *source* and *source_port* are
and the *source* and *source_port* are ignored for the UDP query. ignored for the UDP query.
*tcp_sock*, a ``socket.socket``, or ``None``, the connected socket to use for the *tcp_sock*, a ``socket.socket``, or ``None``, the connected socket to use for the
TCP query. If ``None``, the default, a socket is created. Note that TCP query. If ``None``, the default, a socket is created. Note that if a socket is
if a socket is provided, it must be a nonblocking connected stream provided, it must be a nonblocking connected stream socket, and *where*, *source*
socket, and *where*, *source* and *source_port* are ignored for the TCP and *source_port* are ignored for the TCP query.
query.
Returns a (``dns.message.Message``, tcp) tuple where tcp is ``True`` *ignore_errors*, a ``bool``. If various format errors or response mismatches occur
if and only if TCP was used. while listening for UDP, ignore them and keep listening for a valid response. The
default is ``False``.
Returns a (``dns.message.Message``, tcp) tuple where tcp is ``True`` if and only if
TCP was used.
""" """
try: try:
response = udp( response = udp(
@ -791,6 +826,7 @@ def udp_with_fallback(
ignore_trailing, ignore_trailing,
True, True,
udp_sock, udp_sock,
ignore_errors,
) )
return (response, False) return (response, False)
except dns.message.Truncated: except dns.message.Truncated:
@ -864,14 +900,12 @@ def send_tcp(
""" """
if isinstance(what, dns.message.Message): if isinstance(what, dns.message.Message):
wire = what.to_wire() tcpmsg = what.to_wire(prepend_length=True)
else: else:
wire = what
l = len(wire)
# copying the wire into tcpmsg is inefficient, but lets us # copying the wire into tcpmsg is inefficient, but lets us
# avoid writev() or doing a short write that would get pushed # avoid writev() or doing a short write that would get pushed
# onto the net # onto the net
tcpmsg = struct.pack("!H", l) + wire tcpmsg = len(what).to_bytes(2, "big") + what
sent_time = time.time() sent_time = time.time()
_net_write(sock, tcpmsg, expiration) _net_write(sock, tcpmsg, expiration)
return (len(tcpmsg), sent_time) return (len(tcpmsg), sent_time)
@ -1014,6 +1048,28 @@ def _tls_handshake(s, expiration):
_wait_for_writable(s, expiration) _wait_for_writable(s, expiration)
def _make_dot_ssl_context(
server_hostname: Optional[str], verify: Union[bool, str]
) -> ssl.SSLContext:
cafile: Optional[str] = None
capath: Optional[str] = None
if isinstance(verify, str):
if os.path.isfile(verify):
cafile = verify
elif os.path.isdir(verify):
capath = verify
else:
raise ValueError("invalid verify string")
ssl_context = ssl.create_default_context(cafile=cafile, capath=capath)
ssl_context.minimum_version = ssl.TLSVersion.TLSv1_2
if server_hostname is None:
ssl_context.check_hostname = False
ssl_context.set_alpn_protocols(["dot"])
if verify is False:
ssl_context.verify_mode = ssl.CERT_NONE
return ssl_context
def tls( def tls(
q: dns.message.Message, q: dns.message.Message,
where: str, where: str,
@ -1026,6 +1082,7 @@ def tls(
sock: Optional[ssl.SSLSocket] = None, sock: Optional[ssl.SSLSocket] = None,
ssl_context: Optional[ssl.SSLContext] = None, ssl_context: Optional[ssl.SSLContext] = None,
server_hostname: Optional[str] = None, server_hostname: Optional[str] = None,
verify: Union[bool, str] = True,
) -> dns.message.Message: ) -> dns.message.Message:
"""Return the response obtained after sending a query via TLS. """Return the response obtained after sending a query via TLS.
@ -1065,6 +1122,11 @@ def tls(
default is ``None``, which means that no hostname is known, and if an default is ``None``, which means that no hostname is known, and if an
SSL context is created, hostname checking will be disabled. SSL context is created, hostname checking will be disabled.
*verify*, a ``bool`` or ``str``. If a ``True``, then TLS certificate verification
of the server is done using the default CA bundle; if ``False``, then no
verification is done; if a `str` then it specifies the path to a certificate file or
directory which will be used for verification.
Returns a ``dns.message.Message``. Returns a ``dns.message.Message``.
""" """
@ -1091,10 +1153,7 @@ def tls(
where, port, source, source_port where, port, source, source_port
) )
if ssl_context is None and not sock: if ssl_context is None and not sock:
ssl_context = ssl.create_default_context() ssl_context = _make_dot_ssl_context(server_hostname, verify)
ssl_context.minimum_version = ssl.TLSVersion.TLSv1_2
if server_hostname is None:
ssl_context.check_hostname = False
with _make_socket( with _make_socket(
af, af,

View file

@ -1,9 +1,11 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
try: import dns._features
import dns.asyncbackend
if dns._features.have("doq"):
import aioquic.quic.configuration # type: ignore import aioquic.quic.configuration # type: ignore
import dns.asyncbackend
from dns._asyncbackend import NullContext from dns._asyncbackend import NullContext
from dns.quic._asyncio import ( from dns.quic._asyncio import (
AsyncioQuicConnection, AsyncioQuicConnection,
@ -17,7 +19,7 @@ try:
def null_factory( def null_factory(
*args, # pylint: disable=unused-argument *args, # pylint: disable=unused-argument
**kwargs # pylint: disable=unused-argument **kwargs, # pylint: disable=unused-argument
): ):
return NullContext(None) return NullContext(None)
@ -31,7 +33,7 @@ try:
_async_factories = {"asyncio": (null_factory, _asyncio_manager_factory)} _async_factories = {"asyncio": (null_factory, _asyncio_manager_factory)}
try: if dns._features.have("trio"):
import trio import trio
from dns.quic._trio import ( # pylint: disable=ungrouped-imports from dns.quic._trio import ( # pylint: disable=ungrouped-imports
@ -47,15 +49,13 @@ try:
return TrioQuicManager(context, *args, **kwargs) return TrioQuicManager(context, *args, **kwargs)
_async_factories["trio"] = (_trio_context_factory, _trio_manager_factory) _async_factories["trio"] = (_trio_context_factory, _trio_manager_factory)
except ImportError:
pass
def factories_for_backend(backend=None): def factories_for_backend(backend=None):
if backend is None: if backend is None:
backend = dns.asyncbackend.get_default_backend() backend = dns.asyncbackend.get_default_backend()
return _async_factories[backend.name()] return _async_factories[backend.name()]
except ImportError: else: # pragma: no cover
have_quic = False have_quic = False
from typing import Any from typing import Any

View file

@ -101,9 +101,7 @@ class AsyncioQuicConnection(AsyncQuicConnection):
) )
if address[0] != self._peer[0] or address[1] != self._peer[1]: if address[0] != self._peer[0] or address[1] != self._peer[1]:
continue continue
self._connection.receive_datagram( self._connection.receive_datagram(datagram, address, time.time())
datagram, self._peer[0], time.time()
)
# Wake up the timer in case the sender is sleeping, as there may be # Wake up the timer in case the sender is sleeping, as there may be
# stuff to send now. # stuff to send now.
async with self._wake_timer: async with self._wake_timer:
@ -125,7 +123,7 @@ class AsyncioQuicConnection(AsyncQuicConnection):
while not self._done: while not self._done:
datagrams = self._connection.datagrams_to_send(time.time()) datagrams = self._connection.datagrams_to_send(time.time())
for datagram, address in datagrams: for datagram, address in datagrams:
assert address == self._peer[0] assert address == self._peer
await self._socket.sendto(datagram, self._peer, None) await self._socket.sendto(datagram, self._peer, None)
(expiration, interval) = self._get_timer_values() (expiration, interval) = self._get_timer_values()
try: try:
@ -147,11 +145,14 @@ class AsyncioQuicConnection(AsyncQuicConnection):
await stream._add_input(event.data, event.end_stream) await stream._add_input(event.data, event.end_stream)
elif isinstance(event, aioquic.quic.events.HandshakeCompleted): elif isinstance(event, aioquic.quic.events.HandshakeCompleted):
self._handshake_complete.set() self._handshake_complete.set()
elif isinstance( elif isinstance(event, aioquic.quic.events.ConnectionTerminated):
event, aioquic.quic.events.ConnectionTerminated
) or isinstance(event, aioquic.quic.events.StreamReset):
self._done = True self._done = True
self._receiver_task.cancel() self._receiver_task.cancel()
elif isinstance(event, aioquic.quic.events.StreamReset):
stream = self._streams.get(event.stream_id)
if stream:
await stream._add_input(b"", True)
count += 1 count += 1
if count > 10: if count > 10:
# yield # yield
@ -188,7 +189,6 @@ class AsyncioQuicConnection(AsyncQuicConnection):
self._connection.close() self._connection.close()
# sender might be blocked on this, so set it # sender might be blocked on this, so set it
self._socket_created.set() self._socket_created.set()
await self._socket.close()
async with self._wake_timer: async with self._wake_timer:
self._wake_timer.notify_all() self._wake_timer.notify_all()
try: try:
@ -199,14 +199,19 @@ class AsyncioQuicConnection(AsyncQuicConnection):
await self._sender_task await self._sender_task
except asyncio.CancelledError: except asyncio.CancelledError:
pass pass
await self._socket.close()
class AsyncioQuicManager(AsyncQuicManager): class AsyncioQuicManager(AsyncQuicManager):
def __init__(self, conf=None, verify_mode=ssl.CERT_REQUIRED, server_name=None): def __init__(self, conf=None, verify_mode=ssl.CERT_REQUIRED, server_name=None):
super().__init__(conf, verify_mode, AsyncioQuicConnection, server_name) super().__init__(conf, verify_mode, AsyncioQuicConnection, server_name)
def connect(self, address, port=853, source=None, source_port=0): def connect(
(connection, start) = self._connect(address, port, source, source_port) self, address, port=853, source=None, source_port=0, want_session_ticket=True
):
(connection, start) = self._connect(
address, port, source, source_port, want_session_ticket
)
if start: if start:
connection.run() connection.run()
return connection return connection

View file

@ -1,5 +1,7 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
import copy
import functools
import socket import socket
import struct import struct
import time import time
@ -11,6 +13,10 @@ import aioquic.quic.connection # type: ignore
import dns.inet import dns.inet
QUIC_MAX_DATAGRAM = 2048 QUIC_MAX_DATAGRAM = 2048
MAX_SESSION_TICKETS = 8
# If we hit the max sessions limit we will delete this many of the oldest connections.
# The value must be a integer > 0 and <= MAX_SESSION_TICKETS.
SESSIONS_TO_DELETE = MAX_SESSION_TICKETS // 4
class UnexpectedEOF(Exception): class UnexpectedEOF(Exception):
@ -79,7 +85,10 @@ class BaseQuicStream:
def _common_add_input(self, data, is_end): def _common_add_input(self, data, is_end):
self._buffer.put(data, is_end) self._buffer.put(data, is_end)
try:
return self._expecting > 0 and self._buffer.have(self._expecting) return self._expecting > 0 and self._buffer.have(self._expecting)
except UnexpectedEOF:
return True
def _close(self): def _close(self):
self._connection.close_stream(self._stream_id) self._connection.close_stream(self._stream_id)
@ -142,6 +151,7 @@ class BaseQuicManager:
def __init__(self, conf, verify_mode, connection_factory, server_name=None): def __init__(self, conf, verify_mode, connection_factory, server_name=None):
self._connections = {} self._connections = {}
self._connection_factory = connection_factory self._connection_factory = connection_factory
self._session_tickets = {}
if conf is None: if conf is None:
verify_path = None verify_path = None
if isinstance(verify_mode, str): if isinstance(verify_mode, str):
@ -156,12 +166,35 @@ class BaseQuicManager:
conf.load_verify_locations(verify_path) conf.load_verify_locations(verify_path)
self._conf = conf self._conf = conf
def _connect(self, address, port=853, source=None, source_port=0): def _connect(
self, address, port=853, source=None, source_port=0, want_session_ticket=True
):
connection = self._connections.get((address, port)) connection = self._connections.get((address, port))
if connection is not None: if connection is not None:
return (connection, False) return (connection, False)
qconn = aioquic.quic.connection.QuicConnection(configuration=self._conf) conf = self._conf
qconn.connect(address, time.time()) if want_session_ticket:
try:
session_ticket = self._session_tickets.pop((address, port))
# We found a session ticket, so make a configuration that uses it.
conf = copy.copy(conf)
conf.session_ticket = session_ticket
except KeyError:
# No session ticket.
pass
# Whether or not we found a session ticket, we want a handler to save
# one.
session_ticket_handler = functools.partial(
self.save_session_ticket, address, port
)
else:
session_ticket_handler = None
qconn = aioquic.quic.connection.QuicConnection(
configuration=conf,
session_ticket_handler=session_ticket_handler,
)
lladdress = dns.inet.low_level_address_tuple((address, port))
qconn.connect(lladdress, time.time())
connection = self._connection_factory( connection = self._connection_factory(
qconn, address, port, source, source_port, self qconn, address, port, source, source_port, self
) )
@ -174,6 +207,17 @@ class BaseQuicManager:
except KeyError: except KeyError:
pass pass
def save_session_ticket(self, address, port, ticket):
# We rely on dictionaries keys() being in insertion order here. We
# can't just popitem() as that would be LIFO which is the opposite of
# what we want.
l = len(self._session_tickets)
if l >= MAX_SESSION_TICKETS:
keys_to_delete = list(self._session_tickets.keys())[0:SESSIONS_TO_DELETE]
for key in keys_to_delete:
del self._session_tickets[key]
self._session_tickets[(address, port)] = ticket
class AsyncQuicManager(BaseQuicManager): class AsyncQuicManager(BaseQuicManager):
def connect(self, address, port=853, source=None, source_port=0): def connect(self, address, port=853, source=None, source_port=0):

View file

@ -82,10 +82,6 @@ class SyncQuicConnection(BaseQuicConnection):
def __init__(self, connection, address, port, source, source_port, manager): def __init__(self, connection, address, port, source, source_port, manager):
super().__init__(connection, address, port, source, source_port, manager) super().__init__(connection, address, port, source, source_port, manager)
self._socket = socket.socket(self._af, socket.SOCK_DGRAM, 0) self._socket = socket.socket(self._af, socket.SOCK_DGRAM, 0)
self._socket.connect(self._peer)
(self._send_wakeup, self._receive_wakeup) = socket.socketpair()
self._receive_wakeup.setblocking(False)
self._socket.setblocking(False)
if self._source is not None: if self._source is not None:
try: try:
self._socket.bind( self._socket.bind(
@ -94,6 +90,10 @@ class SyncQuicConnection(BaseQuicConnection):
except Exception: except Exception:
self._socket.close() self._socket.close()
raise raise
self._socket.connect(self._peer)
(self._send_wakeup, self._receive_wakeup) = socket.socketpair()
self._receive_wakeup.setblocking(False)
self._socket.setblocking(False)
self._handshake_complete = threading.Event() self._handshake_complete = threading.Event()
self._worker_thread = None self._worker_thread = None
self._lock = threading.Lock() self._lock = threading.Lock()
@ -107,7 +107,7 @@ class SyncQuicConnection(BaseQuicConnection):
except BlockingIOError: except BlockingIOError:
return return
with self._lock: with self._lock:
self._connection.receive_datagram(datagram, self._peer[0], time.time()) self._connection.receive_datagram(datagram, self._peer, time.time())
def _drain_wakeup(self): def _drain_wakeup(self):
while True: while True:
@ -128,6 +128,8 @@ class SyncQuicConnection(BaseQuicConnection):
key.data() key.data()
with self._lock: with self._lock:
self._handle_timer(expiration) self._handle_timer(expiration)
self._handle_events()
with self._lock:
datagrams = self._connection.datagrams_to_send(time.time()) datagrams = self._connection.datagrams_to_send(time.time())
for datagram, _ in datagrams: for datagram, _ in datagrams:
try: try:
@ -135,7 +137,6 @@ class SyncQuicConnection(BaseQuicConnection):
except BlockingIOError: except BlockingIOError:
# we let QUIC handle any lossage # we let QUIC handle any lossage
pass pass
self._handle_events()
finally: finally:
with self._lock: with self._lock:
self._done = True self._done = True
@ -155,11 +156,14 @@ class SyncQuicConnection(BaseQuicConnection):
stream._add_input(event.data, event.end_stream) stream._add_input(event.data, event.end_stream)
elif isinstance(event, aioquic.quic.events.HandshakeCompleted): elif isinstance(event, aioquic.quic.events.HandshakeCompleted):
self._handshake_complete.set() self._handshake_complete.set()
elif isinstance( elif isinstance(event, aioquic.quic.events.ConnectionTerminated):
event, aioquic.quic.events.ConnectionTerminated
) or isinstance(event, aioquic.quic.events.StreamReset):
with self._lock: with self._lock:
self._done = True self._done = True
elif isinstance(event, aioquic.quic.events.StreamReset):
with self._lock:
stream = self._streams.get(event.stream_id)
if stream:
stream._add_input(b"", True)
def write(self, stream, data, is_end=False): def write(self, stream, data, is_end=False):
with self._lock: with self._lock:
@ -203,9 +207,13 @@ class SyncQuicManager(BaseQuicManager):
super().__init__(conf, verify_mode, SyncQuicConnection, server_name) super().__init__(conf, verify_mode, SyncQuicConnection, server_name)
self._lock = threading.Lock() self._lock = threading.Lock()
def connect(self, address, port=853, source=None, source_port=0): def connect(
self, address, port=853, source=None, source_port=0, want_session_ticket=True
):
with self._lock: with self._lock:
(connection, start) = self._connect(address, port, source, source_port) (connection, start) = self._connect(
address, port, source, source_port, want_session_ticket
)
if start: if start:
connection.run() connection.run()
return connection return connection
@ -214,6 +222,10 @@ class SyncQuicManager(BaseQuicManager):
with self._lock: with self._lock:
super().closed(address, port) super().closed(address, port)
def save_session_ticket(self, address, port, ticket):
with self._lock:
super().save_session_ticket(address, port, ticket)
def __enter__(self): def __enter__(self):
return self return self

View file

@ -76,30 +76,43 @@ class TrioQuicConnection(AsyncQuicConnection):
def __init__(self, connection, address, port, source, source_port, manager=None): def __init__(self, connection, address, port, source, source_port, manager=None):
super().__init__(connection, address, port, source, source_port, manager) super().__init__(connection, address, port, source, source_port, manager)
self._socket = trio.socket.socket(self._af, socket.SOCK_DGRAM, 0) self._socket = trio.socket.socket(self._af, socket.SOCK_DGRAM, 0)
if self._source:
trio.socket.bind(dns.inet.low_level_address_tuple(self._source, self._af))
self._handshake_complete = trio.Event() self._handshake_complete = trio.Event()
self._run_done = trio.Event() self._run_done = trio.Event()
self._worker_scope = None self._worker_scope = None
self._send_pending = False
async def _worker(self): async def _worker(self):
try: try:
if self._source:
await self._socket.bind(
dns.inet.low_level_address_tuple(self._source, self._af)
)
await self._socket.connect(self._peer) await self._socket.connect(self._peer)
while not self._done: while not self._done:
(expiration, interval) = self._get_timer_values(False) (expiration, interval) = self._get_timer_values(False)
if self._send_pending:
# Do not block forever if sends are pending. Even though we
# have a wake-up mechanism if we've already started the blocking
# read, the possibility of context switching in send means that
# more writes can happen while we have no wake up context, so
# we need self._send_pending to avoid (effectively) a "lost wakeup"
# race.
interval = 0.0
with trio.CancelScope( with trio.CancelScope(
deadline=trio.current_time() + interval deadline=trio.current_time() + interval
) as self._worker_scope: ) as self._worker_scope:
datagram = await self._socket.recv(QUIC_MAX_DATAGRAM) datagram = await self._socket.recv(QUIC_MAX_DATAGRAM)
self._connection.receive_datagram( self._connection.receive_datagram(datagram, self._peer, time.time())
datagram, self._peer[0], time.time()
)
self._worker_scope = None self._worker_scope = None
self._handle_timer(expiration) self._handle_timer(expiration)
await self._handle_events()
# We clear this now, before sending anything, as sending can cause
# context switches that do more sends. We want to know if that
# happens so we don't block a long time on the recv() above.
self._send_pending = False
datagrams = self._connection.datagrams_to_send(time.time()) datagrams = self._connection.datagrams_to_send(time.time())
for datagram, _ in datagrams: for datagram, _ in datagrams:
await self._socket.send(datagram) await self._socket.send(datagram)
await self._handle_events()
finally: finally:
self._done = True self._done = True
self._handshake_complete.set() self._handshake_complete.set()
@ -116,11 +129,13 @@ class TrioQuicConnection(AsyncQuicConnection):
await stream._add_input(event.data, event.end_stream) await stream._add_input(event.data, event.end_stream)
elif isinstance(event, aioquic.quic.events.HandshakeCompleted): elif isinstance(event, aioquic.quic.events.HandshakeCompleted):
self._handshake_complete.set() self._handshake_complete.set()
elif isinstance( elif isinstance(event, aioquic.quic.events.ConnectionTerminated):
event, aioquic.quic.events.ConnectionTerminated
) or isinstance(event, aioquic.quic.events.StreamReset):
self._done = True self._done = True
self._socket.close() self._socket.close()
elif isinstance(event, aioquic.quic.events.StreamReset):
stream = self._streams.get(event.stream_id)
if stream:
await stream._add_input(b"", True)
count += 1 count += 1
if count > 10: if count > 10:
# yield # yield
@ -129,6 +144,7 @@ class TrioQuicConnection(AsyncQuicConnection):
async def write(self, stream, data, is_end=False): async def write(self, stream, data, is_end=False):
self._connection.send_stream_data(stream, data, is_end) self._connection.send_stream_data(stream, data, is_end)
self._send_pending = True
if self._worker_scope is not None: if self._worker_scope is not None:
self._worker_scope.cancel() self._worker_scope.cancel()
@ -159,6 +175,7 @@ class TrioQuicConnection(AsyncQuicConnection):
self._manager.closed(self._peer[0], self._peer[1]) self._manager.closed(self._peer[0], self._peer[1])
self._closed = True self._closed = True
self._connection.close() self._connection.close()
self._send_pending = True
if self._worker_scope is not None: if self._worker_scope is not None:
self._worker_scope.cancel() self._worker_scope.cancel()
await self._run_done.wait() await self._run_done.wait()
@ -171,8 +188,12 @@ class TrioQuicManager(AsyncQuicManager):
super().__init__(conf, verify_mode, TrioQuicConnection, server_name) super().__init__(conf, verify_mode, TrioQuicConnection, server_name)
self._nursery = nursery self._nursery = nursery
def connect(self, address, port=853, source=None, source_port=0): def connect(
(connection, start) = self._connect(address, port, source, source_port) self, address, port=853, source=None, source_port=0, want_session_ticket=True
):
(connection, start) = self._connect(
address, port, source, source_port, want_session_ticket
)
if start: if start:
self._nursery.start_soon(connection.run) self._nursery.start_soon(connection.run)
return connection return connection

View file

@ -199,7 +199,7 @@ class Rdata:
self, self,
origin: Optional[dns.name.Name] = None, origin: Optional[dns.name.Name] = None,
relativize: bool = True, relativize: bool = True,
**kw: Dict[str, Any] **kw: Dict[str, Any],
) -> str: ) -> str:
"""Convert an rdata to text format. """Convert an rdata to text format.
@ -547,9 +547,7 @@ class Rdata:
@classmethod @classmethod
def _as_ipv4_address(cls, value): def _as_ipv4_address(cls, value):
if isinstance(value, str): if isinstance(value, str):
# call to check validity return dns.ipv4.canonicalize(value)
dns.ipv4.inet_aton(value)
return value
elif isinstance(value, bytes): elif isinstance(value, bytes):
return dns.ipv4.inet_ntoa(value) return dns.ipv4.inet_ntoa(value)
else: else:
@ -558,9 +556,7 @@ class Rdata:
@classmethod @classmethod
def _as_ipv6_address(cls, value): def _as_ipv6_address(cls, value):
if isinstance(value, str): if isinstance(value, str):
# call to check validity return dns.ipv6.canonicalize(value)
dns.ipv6.inet_aton(value)
return value
elif isinstance(value, bytes): elif isinstance(value, bytes):
return dns.ipv6.inet_ntoa(value) return dns.ipv6.inet_ntoa(value)
else: else:
@ -604,7 +600,6 @@ class Rdata:
@dns.immutable.immutable @dns.immutable.immutable
class GenericRdata(Rdata): class GenericRdata(Rdata):
"""Generic Rdata Class """Generic Rdata Class
This class is used for rdata types for which we have no better This class is used for rdata types for which we have no better
@ -621,7 +616,7 @@ class GenericRdata(Rdata):
self, self,
origin: Optional[dns.name.Name] = None, origin: Optional[dns.name.Name] = None,
relativize: bool = True, relativize: bool = True,
**kw: Dict[str, Any] **kw: Dict[str, Any],
) -> str: ) -> str:
return r"\# %d " % len(self.data) + _hexify(self.data, **kw) return r"\# %d " % len(self.data) + _hexify(self.data, **kw)
@ -647,9 +642,9 @@ class GenericRdata(Rdata):
return cls(rdclass, rdtype, parser.get_remaining()) return cls(rdclass, rdtype, parser.get_remaining())
_rdata_classes: Dict[ _rdata_classes: Dict[Tuple[dns.rdataclass.RdataClass, dns.rdatatype.RdataType], Any] = (
Tuple[dns.rdataclass.RdataClass, dns.rdatatype.RdataType], Any {}
] = {} )
_module_prefix = "dns.rdtypes" _module_prefix = "dns.rdtypes"

View file

@ -28,6 +28,7 @@ import dns.name
import dns.rdata import dns.rdata
import dns.rdataclass import dns.rdataclass
import dns.rdatatype import dns.rdatatype
import dns.renderer
import dns.set import dns.set
import dns.ttl import dns.ttl
@ -45,7 +46,6 @@ class IncompatibleTypes(dns.exception.DNSException):
class Rdataset(dns.set.Set): class Rdataset(dns.set.Set):
"""A DNS rdataset.""" """A DNS rdataset."""
__slots__ = ["rdclass", "rdtype", "covers", "ttl"] __slots__ = ["rdclass", "rdtype", "covers", "ttl"]
@ -316,11 +316,9 @@ class Rdataset(dns.set.Set):
want_shuffle = False want_shuffle = False
else: else:
rdclass = self.rdclass rdclass = self.rdclass
file.seek(0, io.SEEK_END)
if len(self) == 0: if len(self) == 0:
name.to_wire(file, compress, origin) name.to_wire(file, compress, origin)
stuff = struct.pack("!HHIH", self.rdtype, rdclass, 0, 0) file.write(struct.pack("!HHIH", self.rdtype, rdclass, 0, 0))
file.write(stuff)
return 1 return 1
else: else:
l: Union[Rdataset, List[dns.rdata.Rdata]] l: Union[Rdataset, List[dns.rdata.Rdata]]
@ -331,16 +329,9 @@ class Rdataset(dns.set.Set):
l = self l = self
for rd in l: for rd in l:
name.to_wire(file, compress, origin) name.to_wire(file, compress, origin)
stuff = struct.pack("!HHIH", self.rdtype, rdclass, self.ttl, 0) file.write(struct.pack("!HHI", self.rdtype, rdclass, self.ttl))
file.write(stuff) with dns.renderer.prefixed_length(file, 2):
start = file.tell()
rd.to_wire(file, compress, origin) rd.to_wire(file, compress, origin)
end = file.tell()
assert end - start < 65536
file.seek(start - 2)
stuff = struct.pack("!H", end - start)
file.write(stuff)
file.seek(0, io.SEEK_END)
return len(self) return len(self)
def match( def match(
@ -373,7 +364,6 @@ class Rdataset(dns.set.Set):
@dns.immutable.immutable @dns.immutable.immutable
class ImmutableRdataset(Rdataset): # lgtm[py/missing-equals] class ImmutableRdataset(Rdataset): # lgtm[py/missing-equals]
"""An immutable DNS rdataset.""" """An immutable DNS rdataset."""
_clone_class = Rdataset _clone_class = Rdataset

View file

@ -21,7 +21,6 @@ import dns.rdtypes.mxbase
@dns.immutable.immutable @dns.immutable.immutable
class AFSDB(dns.rdtypes.mxbase.UncompressedDowncasingMX): class AFSDB(dns.rdtypes.mxbase.UncompressedDowncasingMX):
"""AFSDB record""" """AFSDB record"""
# Use the property mechanism to make "subtype" an alias for the # Use the property mechanism to make "subtype" an alias for the

View file

@ -32,7 +32,6 @@ class Relay(dns.rdtypes.util.Gateway):
@dns.immutable.immutable @dns.immutable.immutable
class AMTRELAY(dns.rdata.Rdata): class AMTRELAY(dns.rdata.Rdata):
"""AMTRELAY record""" """AMTRELAY record"""
# see: RFC 8777 # see: RFC 8777

View file

@ -21,7 +21,6 @@ import dns.rdtypes.txtbase
@dns.immutable.immutable @dns.immutable.immutable
class AVC(dns.rdtypes.txtbase.TXTBase): class AVC(dns.rdtypes.txtbase.TXTBase):
"""AVC record""" """AVC record"""
# See: IANA dns parameters for AVC # See: IANA dns parameters for AVC

View file

@ -25,7 +25,6 @@ import dns.tokenizer
@dns.immutable.immutable @dns.immutable.immutable
class CAA(dns.rdata.Rdata): class CAA(dns.rdata.Rdata):
"""CAA (Certification Authority Authorization) record""" """CAA (Certification Authority Authorization) record"""
# see: RFC 6844 # see: RFC 6844

View file

@ -30,5 +30,4 @@ from dns.rdtypes.dnskeybase import ( # noqa: F401 lgtm[py/unused-import]
@dns.immutable.immutable @dns.immutable.immutable
class CDNSKEY(dns.rdtypes.dnskeybase.DNSKEYBase): class CDNSKEY(dns.rdtypes.dnskeybase.DNSKEYBase):
"""CDNSKEY record""" """CDNSKEY record"""

View file

@ -21,7 +21,6 @@ import dns.rdtypes.dsbase
@dns.immutable.immutable @dns.immutable.immutable
class CDS(dns.rdtypes.dsbase.DSBase): class CDS(dns.rdtypes.dsbase.DSBase):
"""CDS record""" """CDS record"""
_digest_length_by_type = { _digest_length_by_type = {

View file

@ -67,7 +67,6 @@ def _ctype_to_text(what):
@dns.immutable.immutable @dns.immutable.immutable
class CERT(dns.rdata.Rdata): class CERT(dns.rdata.Rdata):
"""CERT record""" """CERT record"""
# see RFC 4398 # see RFC 4398

View file

@ -21,7 +21,6 @@ import dns.rdtypes.nsbase
@dns.immutable.immutable @dns.immutable.immutable
class CNAME(dns.rdtypes.nsbase.NSBase): class CNAME(dns.rdtypes.nsbase.NSBase):
"""CNAME record """CNAME record
Note: although CNAME is officially a singleton type, dnspython allows Note: although CNAME is officially a singleton type, dnspython allows

View file

@ -32,7 +32,6 @@ class Bitmap(dns.rdtypes.util.Bitmap):
@dns.immutable.immutable @dns.immutable.immutable
class CSYNC(dns.rdata.Rdata): class CSYNC(dns.rdata.Rdata):
"""CSYNC record""" """CSYNC record"""
__slots__ = ["serial", "flags", "windows"] __slots__ = ["serial", "flags", "windows"]

View file

@ -21,5 +21,4 @@ import dns.rdtypes.dsbase
@dns.immutable.immutable @dns.immutable.immutable
class DLV(dns.rdtypes.dsbase.DSBase): class DLV(dns.rdtypes.dsbase.DSBase):
"""DLV record""" """DLV record"""

View file

@ -21,7 +21,6 @@ import dns.rdtypes.nsbase
@dns.immutable.immutable @dns.immutable.immutable
class DNAME(dns.rdtypes.nsbase.UncompressedNS): class DNAME(dns.rdtypes.nsbase.UncompressedNS):
"""DNAME record""" """DNAME record"""
def _to_wire(self, file, compress=None, origin=None, canonicalize=False): def _to_wire(self, file, compress=None, origin=None, canonicalize=False):

View file

@ -30,5 +30,4 @@ from dns.rdtypes.dnskeybase import ( # noqa: F401 lgtm[py/unused-import]
@dns.immutable.immutable @dns.immutable.immutable
class DNSKEY(dns.rdtypes.dnskeybase.DNSKEYBase): class DNSKEY(dns.rdtypes.dnskeybase.DNSKEYBase):
"""DNSKEY record""" """DNSKEY record"""

View file

@ -21,5 +21,4 @@ import dns.rdtypes.dsbase
@dns.immutable.immutable @dns.immutable.immutable
class DS(dns.rdtypes.dsbase.DSBase): class DS(dns.rdtypes.dsbase.DSBase):
"""DS record""" """DS record"""

View file

@ -22,7 +22,6 @@ import dns.rdtypes.euibase
@dns.immutable.immutable @dns.immutable.immutable
class EUI48(dns.rdtypes.euibase.EUIBase): class EUI48(dns.rdtypes.euibase.EUIBase):
"""EUI48 record""" """EUI48 record"""
# see: rfc7043.txt # see: rfc7043.txt

View file

@ -22,7 +22,6 @@ import dns.rdtypes.euibase
@dns.immutable.immutable @dns.immutable.immutable
class EUI64(dns.rdtypes.euibase.EUIBase): class EUI64(dns.rdtypes.euibase.EUIBase):
"""EUI64 record""" """EUI64 record"""
# see: rfc7043.txt # see: rfc7043.txt

View file

@ -44,7 +44,6 @@ def _validate_float_string(what):
@dns.immutable.immutable @dns.immutable.immutable
class GPOS(dns.rdata.Rdata): class GPOS(dns.rdata.Rdata):
"""GPOS record""" """GPOS record"""
# see: RFC 1712 # see: RFC 1712

View file

@ -25,7 +25,6 @@ import dns.tokenizer
@dns.immutable.immutable @dns.immutable.immutable
class HINFO(dns.rdata.Rdata): class HINFO(dns.rdata.Rdata):
"""HINFO record""" """HINFO record"""
# see: RFC 1035 # see: RFC 1035

View file

@ -27,7 +27,6 @@ import dns.rdatatype
@dns.immutable.immutable @dns.immutable.immutable
class HIP(dns.rdata.Rdata): class HIP(dns.rdata.Rdata):
"""HIP record""" """HIP record"""
# see: RFC 5205 # see: RFC 5205

View file

@ -25,7 +25,6 @@ import dns.tokenizer
@dns.immutable.immutable @dns.immutable.immutable
class ISDN(dns.rdata.Rdata): class ISDN(dns.rdata.Rdata):
"""ISDN record""" """ISDN record"""
# see: RFC 1183 # see: RFC 1183

View file

@ -8,7 +8,6 @@ import dns.rdata
@dns.immutable.immutable @dns.immutable.immutable
class L32(dns.rdata.Rdata): class L32(dns.rdata.Rdata):
"""L32 record""" """L32 record"""
# see: rfc6742.txt # see: rfc6742.txt

View file

@ -8,7 +8,6 @@ import dns.rdtypes.util
@dns.immutable.immutable @dns.immutable.immutable
class L64(dns.rdata.Rdata): class L64(dns.rdata.Rdata):
"""L64 record""" """L64 record"""
# see: rfc6742.txt # see: rfc6742.txt

View file

@ -105,7 +105,6 @@ def _check_coordinate_list(value, low, high):
@dns.immutable.immutable @dns.immutable.immutable
class LOC(dns.rdata.Rdata): class LOC(dns.rdata.Rdata):
"""LOC record""" """LOC record"""
# see: RFC 1876 # see: RFC 1876

View file

@ -8,7 +8,6 @@ import dns.rdata
@dns.immutable.immutable @dns.immutable.immutable
class LP(dns.rdata.Rdata): class LP(dns.rdata.Rdata):
"""LP record""" """LP record"""
# see: rfc6742.txt # see: rfc6742.txt

View file

@ -21,5 +21,4 @@ import dns.rdtypes.mxbase
@dns.immutable.immutable @dns.immutable.immutable
class MX(dns.rdtypes.mxbase.MXBase): class MX(dns.rdtypes.mxbase.MXBase):
"""MX record""" """MX record"""

View file

@ -8,7 +8,6 @@ import dns.rdtypes.util
@dns.immutable.immutable @dns.immutable.immutable
class NID(dns.rdata.Rdata): class NID(dns.rdata.Rdata):
"""NID record""" """NID record"""
# see: rfc6742.txt # see: rfc6742.txt

View file

@ -21,7 +21,6 @@ import dns.rdtypes.txtbase
@dns.immutable.immutable @dns.immutable.immutable
class NINFO(dns.rdtypes.txtbase.TXTBase): class NINFO(dns.rdtypes.txtbase.TXTBase):
"""NINFO record""" """NINFO record"""
# see: draft-reid-dnsext-zs-01 # see: draft-reid-dnsext-zs-01

View file

@ -21,5 +21,4 @@ import dns.rdtypes.nsbase
@dns.immutable.immutable @dns.immutable.immutable
class NS(dns.rdtypes.nsbase.NSBase): class NS(dns.rdtypes.nsbase.NSBase):
"""NS record""" """NS record"""

View file

@ -30,7 +30,6 @@ class Bitmap(dns.rdtypes.util.Bitmap):
@dns.immutable.immutable @dns.immutable.immutable
class NSEC(dns.rdata.Rdata): class NSEC(dns.rdata.Rdata):
"""NSEC record""" """NSEC record"""
__slots__ = ["next", "windows"] __slots__ = ["next", "windows"]

View file

@ -46,7 +46,6 @@ class Bitmap(dns.rdtypes.util.Bitmap):
@dns.immutable.immutable @dns.immutable.immutable
class NSEC3(dns.rdata.Rdata): class NSEC3(dns.rdata.Rdata):
"""NSEC3 record""" """NSEC3 record"""
__slots__ = ["algorithm", "flags", "iterations", "salt", "next", "windows"] __slots__ = ["algorithm", "flags", "iterations", "salt", "next", "windows"]
@ -64,9 +63,13 @@ class NSEC3(dns.rdata.Rdata):
windows = Bitmap(windows) windows = Bitmap(windows)
self.windows = tuple(windows.windows) self.windows = tuple(windows.windows)
def to_text(self, origin=None, relativize=True, **kw): def _next_text(self):
next = base64.b32encode(self.next).translate(b32_normal_to_hex).lower().decode() next = base64.b32encode(self.next).translate(b32_normal_to_hex).lower().decode()
next = next.rstrip("=") next = next.rstrip("=")
return next
def to_text(self, origin=None, relativize=True, **kw):
next = self._next_text()
if self.salt == b"": if self.salt == b"":
salt = "-" salt = "-"
else: else:
@ -118,3 +121,6 @@ class NSEC3(dns.rdata.Rdata):
next = parser.get_counted_bytes() next = parser.get_counted_bytes()
bitmap = Bitmap.from_wire_parser(parser) bitmap = Bitmap.from_wire_parser(parser)
return cls(rdclass, rdtype, algorithm, flags, iterations, salt, next, bitmap) return cls(rdclass, rdtype, algorithm, flags, iterations, salt, next, bitmap)
def next_name(self, origin=None):
return dns.name.from_text(self._next_text(), origin)

View file

@ -25,7 +25,6 @@ import dns.rdata
@dns.immutable.immutable @dns.immutable.immutable
class NSEC3PARAM(dns.rdata.Rdata): class NSEC3PARAM(dns.rdata.Rdata):
"""NSEC3PARAM record""" """NSEC3PARAM record"""
__slots__ = ["algorithm", "flags", "iterations", "salt"] __slots__ = ["algorithm", "flags", "iterations", "salt"]

View file

@ -25,7 +25,6 @@ import dns.tokenizer
@dns.immutable.immutable @dns.immutable.immutable
class OPENPGPKEY(dns.rdata.Rdata): class OPENPGPKEY(dns.rdata.Rdata):
"""OPENPGPKEY record""" """OPENPGPKEY record"""
# see: RFC 7929 # see: RFC 7929

View file

@ -28,7 +28,6 @@ import dns.rdata
@dns.immutable.immutable @dns.immutable.immutable
class OPT(dns.rdata.Rdata): class OPT(dns.rdata.Rdata):
"""OPT record""" """OPT record"""
__slots__ = ["options"] __slots__ = ["options"]

View file

@ -21,5 +21,4 @@ import dns.rdtypes.nsbase
@dns.immutable.immutable @dns.immutable.immutable
class PTR(dns.rdtypes.nsbase.NSBase): class PTR(dns.rdtypes.nsbase.NSBase):
"""PTR record""" """PTR record"""

View file

@ -23,7 +23,6 @@ import dns.rdata
@dns.immutable.immutable @dns.immutable.immutable
class RP(dns.rdata.Rdata): class RP(dns.rdata.Rdata):
"""RP record""" """RP record"""
# see: RFC 1183 # see: RFC 1183

View file

@ -28,7 +28,6 @@ import dns.rdatatype
class BadSigTime(dns.exception.DNSException): class BadSigTime(dns.exception.DNSException):
"""Time in DNS SIG or RRSIG resource record cannot be parsed.""" """Time in DNS SIG or RRSIG resource record cannot be parsed."""
@ -52,7 +51,6 @@ def posixtime_to_sigtime(what):
@dns.immutable.immutable @dns.immutable.immutable
class RRSIG(dns.rdata.Rdata): class RRSIG(dns.rdata.Rdata):
"""RRSIG record""" """RRSIG record"""
__slots__ = [ __slots__ = [

View file

@ -21,5 +21,4 @@ import dns.rdtypes.mxbase
@dns.immutable.immutable @dns.immutable.immutable
class RT(dns.rdtypes.mxbase.UncompressedDowncasingMX): class RT(dns.rdtypes.mxbase.UncompressedDowncasingMX):
"""RT record""" """RT record"""

View file

@ -25,7 +25,6 @@ import dns.rdata
@dns.immutable.immutable @dns.immutable.immutable
class SOA(dns.rdata.Rdata): class SOA(dns.rdata.Rdata):
"""SOA record""" """SOA record"""
# see: RFC 1035 # see: RFC 1035

View file

@ -21,7 +21,6 @@ import dns.rdtypes.txtbase
@dns.immutable.immutable @dns.immutable.immutable
class SPF(dns.rdtypes.txtbase.TXTBase): class SPF(dns.rdtypes.txtbase.TXTBase):
"""SPF record""" """SPF record"""
# see: RFC 4408 # see: RFC 4408

View file

@ -25,7 +25,6 @@ import dns.rdatatype
@dns.immutable.immutable @dns.immutable.immutable
class SSHFP(dns.rdata.Rdata): class SSHFP(dns.rdata.Rdata):
"""SSHFP record""" """SSHFP record"""
# See RFC 4255 # See RFC 4255

View file

@ -25,7 +25,6 @@ import dns.rdata
@dns.immutable.immutable @dns.immutable.immutable
class TKEY(dns.rdata.Rdata): class TKEY(dns.rdata.Rdata):
"""TKEY Record""" """TKEY Record"""
__slots__ = [ __slots__ = [

View file

@ -6,5 +6,4 @@ import dns.rdtypes.tlsabase
@dns.immutable.immutable @dns.immutable.immutable
class TLSA(dns.rdtypes.tlsabase.TLSABase): class TLSA(dns.rdtypes.tlsabase.TLSABase):
"""TLSA record""" """TLSA record"""

View file

@ -26,7 +26,6 @@ import dns.rdata
@dns.immutable.immutable @dns.immutable.immutable
class TSIG(dns.rdata.Rdata): class TSIG(dns.rdata.Rdata):
"""TSIG record""" """TSIG record"""
__slots__ = [ __slots__ = [

View file

@ -21,5 +21,4 @@ import dns.rdtypes.txtbase
@dns.immutable.immutable @dns.immutable.immutable
class TXT(dns.rdtypes.txtbase.TXTBase): class TXT(dns.rdtypes.txtbase.TXTBase):
"""TXT record""" """TXT record"""

View file

@ -27,7 +27,6 @@ import dns.rdtypes.util
@dns.immutable.immutable @dns.immutable.immutable
class URI(dns.rdata.Rdata): class URI(dns.rdata.Rdata):
"""URI record""" """URI record"""
# see RFC 7553 # see RFC 7553

View file

@ -25,7 +25,6 @@ import dns.tokenizer
@dns.immutable.immutable @dns.immutable.immutable
class X25(dns.rdata.Rdata): class X25(dns.rdata.Rdata):
"""X25 record""" """X25 record"""
# see RFC 1183 # see RFC 1183

View file

@ -11,7 +11,6 @@ import dns.zonetypes
@dns.immutable.immutable @dns.immutable.immutable
class ZONEMD(dns.rdata.Rdata): class ZONEMD(dns.rdata.Rdata):
"""ZONEMD record""" """ZONEMD record"""
# See RFC 8976 # See RFC 8976

View file

@ -23,7 +23,6 @@ import dns.rdtypes.mxbase
@dns.immutable.immutable @dns.immutable.immutable
class A(dns.rdata.Rdata): class A(dns.rdata.Rdata):
"""A record for Chaosnet""" """A record for Chaosnet"""
# domain: the domain of the address # domain: the domain of the address

View file

@ -24,7 +24,6 @@ import dns.tokenizer
@dns.immutable.immutable @dns.immutable.immutable
class A(dns.rdata.Rdata): class A(dns.rdata.Rdata):
"""A record.""" """A record."""
__slots__ = ["address"] __slots__ = ["address"]

View file

@ -24,7 +24,6 @@ import dns.tokenizer
@dns.immutable.immutable @dns.immutable.immutable
class AAAA(dns.rdata.Rdata): class AAAA(dns.rdata.Rdata):
"""AAAA record.""" """AAAA record."""
__slots__ = ["address"] __slots__ = ["address"]

View file

@ -29,7 +29,6 @@ import dns.tokenizer
@dns.immutable.immutable @dns.immutable.immutable
class APLItem: class APLItem:
"""An APL list item.""" """An APL list item."""
__slots__ = ["family", "negation", "address", "prefix"] __slots__ = ["family", "negation", "address", "prefix"]
@ -80,7 +79,6 @@ class APLItem:
@dns.immutable.immutable @dns.immutable.immutable
class APL(dns.rdata.Rdata): class APL(dns.rdata.Rdata):
"""APL record.""" """APL record."""
# see: RFC 3123 # see: RFC 3123

View file

@ -24,7 +24,6 @@ import dns.rdata
@dns.immutable.immutable @dns.immutable.immutable
class DHCID(dns.rdata.Rdata): class DHCID(dns.rdata.Rdata):
"""DHCID record""" """DHCID record"""
# see: RFC 4701 # see: RFC 4701

View file

@ -29,7 +29,6 @@ class Gateway(dns.rdtypes.util.Gateway):
@dns.immutable.immutable @dns.immutable.immutable
class IPSECKEY(dns.rdata.Rdata): class IPSECKEY(dns.rdata.Rdata):
"""IPSECKEY record""" """IPSECKEY record"""
# see: RFC 4025 # see: RFC 4025

View file

@ -21,5 +21,4 @@ import dns.rdtypes.mxbase
@dns.immutable.immutable @dns.immutable.immutable
class KX(dns.rdtypes.mxbase.UncompressedDowncasingMX): class KX(dns.rdtypes.mxbase.UncompressedDowncasingMX):
"""KX record""" """KX record"""

View file

@ -33,7 +33,6 @@ def _write_string(file, s):
@dns.immutable.immutable @dns.immutable.immutable
class NAPTR(dns.rdata.Rdata): class NAPTR(dns.rdata.Rdata):
"""NAPTR record""" """NAPTR record"""
# see: RFC 3403 # see: RFC 3403

View file

@ -25,7 +25,6 @@ import dns.tokenizer
@dns.immutable.immutable @dns.immutable.immutable
class NSAP(dns.rdata.Rdata): class NSAP(dns.rdata.Rdata):
"""NSAP record.""" """NSAP record."""
# see: RFC 1706 # see: RFC 1706

View file

@ -21,5 +21,4 @@ import dns.rdtypes.nsbase
@dns.immutable.immutable @dns.immutable.immutable
class NSAP_PTR(dns.rdtypes.nsbase.UncompressedNS): class NSAP_PTR(dns.rdtypes.nsbase.UncompressedNS):
"""NSAP-PTR record""" """NSAP-PTR record"""

View file

@ -26,7 +26,6 @@ import dns.rdtypes.util
@dns.immutable.immutable @dns.immutable.immutable
class PX(dns.rdata.Rdata): class PX(dns.rdata.Rdata):
"""PX record.""" """PX record."""
# see: RFC 2163 # see: RFC 2163

View file

@ -26,7 +26,6 @@ import dns.rdtypes.util
@dns.immutable.immutable @dns.immutable.immutable
class SRV(dns.rdata.Rdata): class SRV(dns.rdata.Rdata):
"""SRV record""" """SRV record"""
# see: RFC 2782 # see: RFC 2782

View file

@ -33,7 +33,6 @@ except OSError:
@dns.immutable.immutable @dns.immutable.immutable
class WKS(dns.rdata.Rdata): class WKS(dns.rdata.Rdata):
"""WKS record""" """WKS record"""
# see: RFC 1035 # see: RFC 1035

View file

@ -36,7 +36,6 @@ class Flag(enum.IntFlag):
@dns.immutable.immutable @dns.immutable.immutable
class DNSKEYBase(dns.rdata.Rdata): class DNSKEYBase(dns.rdata.Rdata):
"""Base class for rdata that is like a DNSKEY record""" """Base class for rdata that is like a DNSKEY record"""
__slots__ = ["flags", "protocol", "algorithm", "key"] __slots__ = ["flags", "protocol", "algorithm", "key"]

View file

@ -26,7 +26,6 @@ import dns.rdatatype
@dns.immutable.immutable @dns.immutable.immutable
class DSBase(dns.rdata.Rdata): class DSBase(dns.rdata.Rdata):
"""Base class for rdata that is like a DS record""" """Base class for rdata that is like a DS record"""
__slots__ = ["key_tag", "algorithm", "digest_type", "digest"] __slots__ = ["key_tag", "algorithm", "digest_type", "digest"]

View file

@ -22,7 +22,6 @@ import dns.rdata
@dns.immutable.immutable @dns.immutable.immutable
class EUIBase(dns.rdata.Rdata): class EUIBase(dns.rdata.Rdata):
"""EUIxx record""" """EUIxx record"""
# see: rfc7043.txt # see: rfc7043.txt

View file

@ -28,7 +28,6 @@ import dns.rdtypes.util
@dns.immutable.immutable @dns.immutable.immutable
class MXBase(dns.rdata.Rdata): class MXBase(dns.rdata.Rdata):
"""Base class for rdata that is like an MX record.""" """Base class for rdata that is like an MX record."""
__slots__ = ["preference", "exchange"] __slots__ = ["preference", "exchange"]
@ -71,7 +70,6 @@ class MXBase(dns.rdata.Rdata):
@dns.immutable.immutable @dns.immutable.immutable
class UncompressedMX(MXBase): class UncompressedMX(MXBase):
"""Base class for rdata that is like an MX record, but whose name """Base class for rdata that is like an MX record, but whose name
is not compressed when converted to DNS wire format, and whose is not compressed when converted to DNS wire format, and whose
digestable form is not downcased.""" digestable form is not downcased."""
@ -82,7 +80,6 @@ class UncompressedMX(MXBase):
@dns.immutable.immutable @dns.immutable.immutable
class UncompressedDowncasingMX(MXBase): class UncompressedDowncasingMX(MXBase):
"""Base class for rdata that is like an MX record, but whose name """Base class for rdata that is like an MX record, but whose name
is not compressed when convert to DNS wire format.""" is not compressed when convert to DNS wire format."""

View file

@ -25,7 +25,6 @@ import dns.rdata
@dns.immutable.immutable @dns.immutable.immutable
class NSBase(dns.rdata.Rdata): class NSBase(dns.rdata.Rdata):
"""Base class for rdata that is like an NS record.""" """Base class for rdata that is like an NS record."""
__slots__ = ["target"] __slots__ = ["target"]
@ -56,7 +55,6 @@ class NSBase(dns.rdata.Rdata):
@dns.immutable.immutable @dns.immutable.immutable
class UncompressedNS(NSBase): class UncompressedNS(NSBase):
"""Base class for rdata that is like an NS record, but whose name """Base class for rdata that is like an NS record, but whose name
is not compressed when convert to DNS wire format, and whose is not compressed when convert to DNS wire format, and whose
digestable form is not downcased.""" digestable form is not downcased."""

View file

@ -2,7 +2,6 @@
import base64 import base64
import enum import enum
import io
import struct import struct
import dns.enum import dns.enum
@ -13,6 +12,7 @@ import dns.ipv6
import dns.name import dns.name
import dns.rdata import dns.rdata
import dns.rdtypes.util import dns.rdtypes.util
import dns.renderer
import dns.tokenizer import dns.tokenizer
import dns.wire import dns.wire
@ -427,7 +427,6 @@ def _validate_and_define(params, key, value):
@dns.immutable.immutable @dns.immutable.immutable
class SVCBBase(dns.rdata.Rdata): class SVCBBase(dns.rdata.Rdata):
"""Base class for SVCB-like records""" """Base class for SVCB-like records"""
# see: draft-ietf-dnsop-svcb-https-11 # see: draft-ietf-dnsop-svcb-https-11
@ -521,19 +520,10 @@ class SVCBBase(dns.rdata.Rdata):
for key in sorted(self.params): for key in sorted(self.params):
file.write(struct.pack("!H", key)) file.write(struct.pack("!H", key))
value = self.params[key] value = self.params[key]
# placeholder for length (or actual length of empty values) with dns.renderer.prefixed_length(file, 2):
file.write(struct.pack("!H", 0)) # Note that we're still writing a length of zero if the value is None
if value is None: if value is not None:
continue
else:
start = file.tell()
value.to_wire(file, origin) value.to_wire(file, origin)
end = file.tell()
assert end - start < 65536
file.seek(start - 2)
stuff = struct.pack("!H", end - start)
file.write(stuff)
file.seek(0, io.SEEK_END)
@classmethod @classmethod
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):

View file

@ -25,7 +25,6 @@ import dns.rdatatype
@dns.immutable.immutable @dns.immutable.immutable
class TLSABase(dns.rdata.Rdata): class TLSABase(dns.rdata.Rdata):
"""Base class for TLSA and SMIMEA records""" """Base class for TLSA and SMIMEA records"""
# see: RFC 6698 # see: RFC 6698

View file

@ -17,18 +17,17 @@
"""TXT-like base class.""" """TXT-like base class."""
import struct
from typing import Any, Dict, Iterable, Optional, Tuple, Union from typing import Any, Dict, Iterable, Optional, Tuple, Union
import dns.exception import dns.exception
import dns.immutable import dns.immutable
import dns.rdata import dns.rdata
import dns.renderer
import dns.tokenizer import dns.tokenizer
@dns.immutable.immutable @dns.immutable.immutable
class TXTBase(dns.rdata.Rdata): class TXTBase(dns.rdata.Rdata):
"""Base class for rdata that is like a TXT record (see RFC 1035).""" """Base class for rdata that is like a TXT record (see RFC 1035)."""
__slots__ = ["strings"] __slots__ = ["strings"]
@ -56,7 +55,7 @@ class TXTBase(dns.rdata.Rdata):
self, self,
origin: Optional[dns.name.Name] = None, origin: Optional[dns.name.Name] = None,
relativize: bool = True, relativize: bool = True,
**kw: Dict[str, Any] **kw: Dict[str, Any],
) -> str: ) -> str:
txt = "" txt = ""
prefix = "" prefix = ""
@ -93,9 +92,7 @@ class TXTBase(dns.rdata.Rdata):
def _to_wire(self, file, compress=None, origin=None, canonicalize=False): def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
for s in self.strings: for s in self.strings:
l = len(s) with dns.renderer.prefixed_length(file, 1):
assert l < 256
file.write(struct.pack("!B", l))
file.write(s) file.write(s)
@classmethod @classmethod

View file

@ -32,6 +32,24 @@ AUTHORITY = 2
ADDITIONAL = 3 ADDITIONAL = 3
@contextlib.contextmanager
def prefixed_length(output, length_length):
output.write(b"\00" * length_length)
start = output.tell()
yield
end = output.tell()
length = end - start
if length > 0:
try:
output.seek(start - length_length)
try:
output.write(length.to_bytes(length_length, "big"))
except OverflowError:
raise dns.exception.FormError
finally:
output.seek(end)
class Renderer: class Renderer:
"""Helper class for building DNS wire-format messages. """Helper class for building DNS wire-format messages.
@ -134,6 +152,15 @@ class Renderer:
self._rollback(start) self._rollback(start)
raise dns.exception.TooBig raise dns.exception.TooBig
@contextlib.contextmanager
def _temporarily_seek_to(self, where):
current = self.output.tell()
try:
self.output.seek(where)
yield
finally:
self.output.seek(current)
def add_question(self, qname, rdtype, rdclass=dns.rdataclass.IN): def add_question(self, qname, rdtype, rdclass=dns.rdataclass.IN):
"""Add a question to the message.""" """Add a question to the message."""
@ -269,18 +296,14 @@ class Renderer:
with self._track_size(): with self._track_size():
keyname.to_wire(self.output, compress, self.origin) keyname.to_wire(self.output, compress, self.origin)
self.output.write( self.output.write(
struct.pack("!HHIH", dns.rdatatype.TSIG, dns.rdataclass.ANY, 0, 0) struct.pack("!HHI", dns.rdatatype.TSIG, dns.rdataclass.ANY, 0)
) )
rdata_start = self.output.tell() with prefixed_length(self.output, 2):
tsig.to_wire(self.output) tsig.to_wire(self.output)
after = self.output.tell()
self.output.seek(rdata_start - 2)
self.output.write(struct.pack("!H", after - rdata_start))
self.counts[ADDITIONAL] += 1 self.counts[ADDITIONAL] += 1
self.output.seek(10) with self._temporarily_seek_to(10):
self.output.write(struct.pack("!H", self.counts[ADDITIONAL])) self.output.write(struct.pack("!H", self.counts[ADDITIONAL]))
self.output.seek(0, io.SEEK_END)
def write_header(self): def write_header(self):
"""Write the DNS message header. """Write the DNS message header.
@ -290,7 +313,7 @@ class Renderer:
is added. is added.
""" """
self.output.seek(0) with self._temporarily_seek_to(0):
self.output.write( self.output.write(
struct.pack( struct.pack(
"!HHHHHH", "!HHHHHH",
@ -302,7 +325,6 @@ class Renderer:
self.counts[3], self.counts[3],
) )
) )
self.output.seek(0, io.SEEK_END)
def get_wire(self): def get_wire(self):
"""Return the wire format message.""" """Return the wire format message."""

View file

@ -26,7 +26,6 @@ import dns.renderer
class RRset(dns.rdataset.Rdataset): class RRset(dns.rdataset.Rdataset):
"""A DNS RRset (named rdataset). """A DNS RRset (named rdataset).
RRset inherits from Rdataset, and RRsets can be treated as RRset inherits from Rdataset, and RRsets can be treated as
@ -132,7 +131,7 @@ class RRset(dns.rdataset.Rdataset):
self, self,
origin: Optional[dns.name.Name] = None, origin: Optional[dns.name.Name] = None,
relativize: bool = True, relativize: bool = True,
**kw: Dict[str, Any] **kw: Dict[str, Any],
) -> str: ) -> str:
"""Convert the RRset into DNS zone file format. """Convert the RRset into DNS zone file format.
@ -159,7 +158,7 @@ class RRset(dns.rdataset.Rdataset):
file: Any, file: Any,
compress: Optional[dns.name.CompressType] = None, # type: ignore compress: Optional[dns.name.CompressType] = None, # type: ignore
origin: Optional[dns.name.Name] = None, origin: Optional[dns.name.Name] = None,
**kw: Dict[str, Any] **kw: Dict[str, Any],
) -> int: ) -> int:
"""Convert the RRset to wire format. """Convert the RRset to wire format.
@ -231,7 +230,7 @@ def from_text(
ttl: int, ttl: int,
rdclass: Union[dns.rdataclass.RdataClass, str], rdclass: Union[dns.rdataclass.RdataClass, str],
rdtype: Union[dns.rdatatype.RdataType, str], rdtype: Union[dns.rdatatype.RdataType, str],
*text_rdatas: Any *text_rdatas: Any,
) -> RRset: ) -> RRset:
"""Create an RRset with the specified name, TTL, class, and type and with """Create an RRset with the specified name, TTL, class, and type and with
the specified rdatas in text format. the specified rdatas in text format.

View file

@ -19,7 +19,6 @@ import itertools
class Set: class Set:
"""A simple set class. """A simple set class.
This class was originally used to deal with sets being missing in This class was originally used to deal with sets being missing in

View file

@ -203,7 +203,7 @@ class Transaction:
- name - name
- name, rdataclass, rdatatype, [covers] - name, rdatatype, [covers]
- name, rdataset... - name, rdataset...
@ -222,7 +222,7 @@ class Transaction:
- name - name
- name, rdataclass, rdatatype, [covers] - name, rdatatype, [covers]
- name, rdataset... - name, rdataset...

View file

@ -29,47 +29,38 @@ import dns.rdataclass
class BadTime(dns.exception.DNSException): class BadTime(dns.exception.DNSException):
"""The current time is not within the TSIG's validity time.""" """The current time is not within the TSIG's validity time."""
class BadSignature(dns.exception.DNSException): class BadSignature(dns.exception.DNSException):
"""The TSIG signature fails to verify.""" """The TSIG signature fails to verify."""
class BadKey(dns.exception.DNSException): class BadKey(dns.exception.DNSException):
"""The TSIG record owner name does not match the key.""" """The TSIG record owner name does not match the key."""
class BadAlgorithm(dns.exception.DNSException): class BadAlgorithm(dns.exception.DNSException):
"""The TSIG algorithm does not match the key.""" """The TSIG algorithm does not match the key."""
class PeerError(dns.exception.DNSException): class PeerError(dns.exception.DNSException):
"""Base class for all TSIG errors generated by the remote peer""" """Base class for all TSIG errors generated by the remote peer"""
class PeerBadKey(PeerError): class PeerBadKey(PeerError):
"""The peer didn't know the key we used""" """The peer didn't know the key we used"""
class PeerBadSignature(PeerError): class PeerBadSignature(PeerError):
"""The peer didn't like the signature we sent""" """The peer didn't like the signature we sent"""
class PeerBadTime(PeerError): class PeerBadTime(PeerError):
"""The peer didn't like the time we sent""" """The peer didn't like the time we sent"""
class PeerBadTruncation(PeerError): class PeerBadTruncation(PeerError):
"""The peer didn't like amount of truncation in the TSIG we sent""" """The peer didn't like amount of truncation in the TSIG we sent"""

View file

@ -20,9 +20,9 @@
#: MAJOR #: MAJOR
MAJOR = 2 MAJOR = 2
#: MINOR #: MINOR
MINOR = 4 MINOR = 6
#: MICRO #: MICRO
MICRO = 2 MICRO = 1
#: RELEASELEVEL #: RELEASELEVEL
RELEASELEVEL = 0x0F RELEASELEVEL = 0x0F
#: SERIAL #: SERIAL

View file

@ -1,5 +1,7 @@
import sys import sys
import dns._features
if sys.platform == "win32": if sys.platform == "win32":
from typing import Any from typing import Any
@ -15,14 +17,14 @@ if sys.platform == "win32":
except KeyError: except KeyError:
WindowsError = Exception WindowsError = Exception
try: if dns._features.have("wmi"):
import threading import threading
import pythoncom # pylint: disable=import-error import pythoncom # pylint: disable=import-error
import wmi # pylint: disable=import-error import wmi # pylint: disable=import-error
_have_wmi = True _have_wmi = True
except Exception: else:
_have_wmi = False _have_wmi = False
def _config_domain(domain): def _config_domain(domain):
@ -51,9 +53,10 @@ if sys.platform == "win32":
try: try:
system = wmi.WMI() system = wmi.WMI()
for interface in system.Win32_NetworkAdapterConfiguration(): for interface in system.Win32_NetworkAdapterConfiguration():
if interface.IPEnabled and interface.DNSDomain: if interface.IPEnabled and interface.DNSServerSearchOrder:
self.info.domain = _config_domain(interface.DNSDomain)
self.info.nameservers = list(interface.DNSServerSearchOrder) self.info.nameservers = list(interface.DNSServerSearchOrder)
if interface.DNSDomain:
self.info.domain = _config_domain(interface.DNSDomain)
if interface.DNSDomainSuffixSearchOrder: if interface.DNSDomainSuffixSearchOrder:
self.info.search = [ self.info.search = [
_config_domain(x) _config_domain(x)

View file

@ -21,7 +21,18 @@ import contextlib
import io import io
import os import os
import struct import struct
from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Tuple, Union from typing import (
Any,
Callable,
Iterable,
Iterator,
List,
MutableMapping,
Optional,
Set,
Tuple,
Union,
)
import dns.exception import dns.exception
import dns.grange import dns.grange
@ -43,47 +54,70 @@ from dns.zonetypes import DigestHashAlgorithm, DigestScheme, _digest_hashers
class BadZone(dns.exception.DNSException): class BadZone(dns.exception.DNSException):
"""The DNS zone is malformed.""" """The DNS zone is malformed."""
class NoSOA(BadZone): class NoSOA(BadZone):
"""The DNS zone has no SOA RR at its origin.""" """The DNS zone has no SOA RR at its origin."""
class NoNS(BadZone): class NoNS(BadZone):
"""The DNS zone has no NS RRset at its origin.""" """The DNS zone has no NS RRset at its origin."""
class UnknownOrigin(BadZone): class UnknownOrigin(BadZone):
"""The DNS zone's origin is unknown.""" """The DNS zone's origin is unknown."""
class UnsupportedDigestScheme(dns.exception.DNSException): class UnsupportedDigestScheme(dns.exception.DNSException):
"""The zone digest's scheme is unsupported.""" """The zone digest's scheme is unsupported."""
class UnsupportedDigestHashAlgorithm(dns.exception.DNSException): class UnsupportedDigestHashAlgorithm(dns.exception.DNSException):
"""The zone digest's origin is unsupported.""" """The zone digest's origin is unsupported."""
class NoDigest(dns.exception.DNSException): class NoDigest(dns.exception.DNSException):
"""The DNS zone has no ZONEMD RRset at its origin.""" """The DNS zone has no ZONEMD RRset at its origin."""
class DigestVerificationFailure(dns.exception.DNSException): class DigestVerificationFailure(dns.exception.DNSException):
"""The ZONEMD digest failed to verify.""" """The ZONEMD digest failed to verify."""
class Zone(dns.transaction.TransactionManager): def _validate_name(
name: dns.name.Name,
origin: Optional[dns.name.Name],
relativize: bool,
) -> dns.name.Name:
# This name validation code is shared by Zone and Version
if origin is None:
# This should probably never happen as other code (e.g.
# _rr_line) will notice the lack of an origin before us, but
# we check just in case!
raise KeyError("no zone origin is defined")
if name.is_absolute():
if not name.is_subdomain(origin):
raise KeyError("name parameter must be a subdomain of the zone origin")
if relativize:
name = name.relativize(origin)
else:
# We have a relative name. Make sure that the derelativized name is
# not too long.
try:
abs_name = name.derelativize(origin)
except dns.name.NameTooLong:
# We map dns.name.NameTooLong to KeyError to be consistent with
# the other exceptions above.
raise KeyError("relative name too long for zone")
if not relativize:
# We have a relative name in a non-relative zone, so use the
# derelativized name.
name = abs_name
return name
class Zone(dns.transaction.TransactionManager):
"""A DNS zone. """A DNS zone.
A ``Zone`` is a mapping from names to nodes. The zone object may be A ``Zone`` is a mapping from names to nodes. The zone object may be
@ -94,7 +128,10 @@ class Zone(dns.transaction.TransactionManager):
the zone. the zone.
""" """
node_factory = dns.node.Node node_factory: Callable[[], dns.node.Node] = dns.node.Node
map_factory: Callable[[], MutableMapping[dns.name.Name, dns.node.Node]] = dict
writable_version_factory: Optional[Callable[[], "WritableVersion"]] = None
immutable_version_factory: Optional[Callable[[], "ImmutableVersion"]] = None
__slots__ = ["rdclass", "origin", "nodes", "relativize"] __slots__ = ["rdclass", "origin", "nodes", "relativize"]
@ -125,7 +162,7 @@ class Zone(dns.transaction.TransactionManager):
raise ValueError("origin parameter must be an absolute name") raise ValueError("origin parameter must be an absolute name")
self.origin = origin self.origin = origin
self.rdclass = rdclass self.rdclass = rdclass
self.nodes: Dict[dns.name.Name, dns.node.Node] = {} self.nodes: MutableMapping[dns.name.Name, dns.node.Node] = self.map_factory()
self.relativize = relativize self.relativize = relativize
def __eq__(self, other): def __eq__(self, other):
@ -154,26 +191,13 @@ class Zone(dns.transaction.TransactionManager):
return not self.__eq__(other) return not self.__eq__(other)
def _validate_name(self, name: Union[dns.name.Name, str]) -> dns.name.Name: def _validate_name(self, name: Union[dns.name.Name, str]) -> dns.name.Name:
# Note that any changes in this method should have corresponding changes
# made in the Version _validate_name() method.
if isinstance(name, str): if isinstance(name, str):
name = dns.name.from_text(name, None) name = dns.name.from_text(name, None)
elif not isinstance(name, dns.name.Name): elif not isinstance(name, dns.name.Name):
raise KeyError("name parameter must be convertible to a DNS name") raise KeyError("name parameter must be convertible to a DNS name")
if name.is_absolute(): return _validate_name(name, self.origin, self.relativize)
if self.origin is None:
# This should probably never happen as other code (e.g.
# _rr_line) will notice the lack of an origin before us, but
# we check just in case!
raise KeyError("no zone origin is defined")
if not name.is_subdomain(self.origin):
raise KeyError("name parameter must be a subdomain of the zone origin")
if self.relativize:
name = name.relativize(self.origin)
elif not self.relativize:
# We have a relative name in a non-relative zone, so derelativize.
if self.origin is None:
raise KeyError("no zone origin is defined")
name = name.derelativize(self.origin)
return name
def __getitem__(self, key): def __getitem__(self, key):
key = self._validate_name(key) key = self._validate_name(key)
@ -252,9 +276,6 @@ class Zone(dns.transaction.TransactionManager):
*create*, a ``bool``. If true, the node will be created if it does *create*, a ``bool``. If true, the node will be created if it does
not exist. not exist.
Raises ``KeyError`` if the name is not known and create was
not specified, or if the name was not a subdomain of the origin.
Returns a ``dns.node.Node`` or ``None``. Returns a ``dns.node.Node`` or ``None``.
""" """
@ -527,9 +548,6 @@ class Zone(dns.transaction.TransactionManager):
*create*, a ``bool``. If true, the node will be created if it does *create*, a ``bool``. If true, the node will be created if it does
not exist. not exist.
Raises ``KeyError`` if the name is not known and create was
not specified, or if the name was not a subdomain of the origin.
Returns a ``dns.rrset.RRset`` or ``None``. Returns a ``dns.rrset.RRset`` or ``None``.
""" """
@ -952,7 +970,7 @@ class Version:
self, self,
zone: Zone, zone: Zone,
id: int, id: int,
nodes: Optional[Dict[dns.name.Name, dns.node.Node]] = None, nodes: Optional[MutableMapping[dns.name.Name, dns.node.Node]] = None,
origin: Optional[dns.name.Name] = None, origin: Optional[dns.name.Name] = None,
): ):
self.zone = zone self.zone = zone
@ -960,26 +978,11 @@ class Version:
if nodes is not None: if nodes is not None:
self.nodes = nodes self.nodes = nodes
else: else:
self.nodes = {} self.nodes = zone.map_factory()
self.origin = origin self.origin = origin
def _validate_name(self, name: dns.name.Name) -> dns.name.Name: def _validate_name(self, name: dns.name.Name) -> dns.name.Name:
if name.is_absolute(): return _validate_name(name, self.origin, self.zone.relativize)
if self.origin is None:
# This should probably never happen as other code (e.g.
# _rr_line) will notice the lack of an origin before us, but
# we check just in case!
raise KeyError("no zone origin is defined")
if not name.is_subdomain(self.origin):
raise KeyError("name is not a subdomain of the zone origin")
if self.zone.relativize:
name = name.relativize(self.origin)
elif not self.zone.relativize:
# We have a relative name in a non-relative zone, so derelativize.
if self.origin is None:
raise KeyError("no zone origin is defined")
name = name.derelativize(self.origin)
return name
def get_node(self, name: dns.name.Name) -> Optional[dns.node.Node]: def get_node(self, name: dns.name.Name) -> Optional[dns.node.Node]:
name = self._validate_name(name) name = self._validate_name(name)
@ -1085,7 +1088,9 @@ class ImmutableVersion(Version):
version.nodes[name] = ImmutableVersionedNode(node) version.nodes[name] = ImmutableVersionedNode(node)
# We're changing the type of the nodes dictionary here on purpose, so # We're changing the type of the nodes dictionary here on purpose, so
# we ignore the mypy error. # we ignore the mypy error.
self.nodes = dns.immutable.Dict(version.nodes, True) # type: ignore self.nodes = dns.immutable.Dict(
version.nodes, True, self.zone.map_factory
) # type: ignore
class Transaction(dns.transaction.Transaction): class Transaction(dns.transaction.Transaction):
@ -1101,7 +1106,10 @@ class Transaction(dns.transaction.Transaction):
def _setup_version(self): def _setup_version(self):
assert self.version is None assert self.version is None
self.version = WritableVersion(self.zone, self.replacement) factory = self.manager.writable_version_factory
if factory is None:
factory = WritableVersion
self.version = factory(self.zone, self.replacement)
def _get_rdataset(self, name, rdtype, covers): def _get_rdataset(self, name, rdtype, covers):
return self.version.get_rdataset(name, rdtype, covers) return self.version.get_rdataset(name, rdtype, covers)
@ -1132,7 +1140,10 @@ class Transaction(dns.transaction.Transaction):
self.zone._end_read(self) self.zone._end_read(self)
elif commit and len(self.version.changed) > 0: elif commit and len(self.version.changed) > 0:
if self.make_immutable: if self.make_immutable:
version = ImmutableVersion(self.version) factory = self.manager.immutable_version_factory
if factory is None:
factory = ImmutableVersion
version = factory(self.version)
else: else:
version = self.version version = self.version
self.zone._commit_version(self, version, self.version.origin) self.zone._commit_version(self, version, self.version.origin)
@ -1168,6 +1179,48 @@ class Transaction(dns.transaction.Transaction):
return (absolute, relativize, effective) return (absolute, relativize, effective)
def _from_text(
text: Any,
origin: Optional[Union[dns.name.Name, str]] = None,
rdclass: dns.rdataclass.RdataClass = dns.rdataclass.IN,
relativize: bool = True,
zone_factory: Any = Zone,
filename: Optional[str] = None,
allow_include: bool = False,
check_origin: bool = True,
idna_codec: Optional[dns.name.IDNACodec] = None,
allow_directives: Union[bool, Iterable[str]] = True,
) -> Zone:
# See the comments for the public APIs from_text() and from_file() for
# details.
# 'text' can also be a file, but we don't publish that fact
# since it's an implementation detail. The official file
# interface is from_file().
if filename is None:
filename = "<string>"
zone = zone_factory(origin, rdclass, relativize=relativize)
with zone.writer(True) as txn:
tok = dns.tokenizer.Tokenizer(text, filename, idna_codec=idna_codec)
reader = dns.zonefile.Reader(
tok,
rdclass,
txn,
allow_include=allow_include,
allow_directives=allow_directives,
)
try:
reader.read()
except dns.zonefile.UnknownOrigin:
# for backwards compatibility
raise dns.zone.UnknownOrigin
# Now that we're done reading, do some basic checking of the zone.
if check_origin:
zone.check_origin()
return zone
def from_text( def from_text(
text: str, text: str,
origin: Optional[Union[dns.name.Name, str]] = None, origin: Optional[Union[dns.name.Name, str]] = None,
@ -1228,32 +1281,18 @@ def from_text(
Returns a subclass of ``dns.zone.Zone``. Returns a subclass of ``dns.zone.Zone``.
""" """
return _from_text(
# 'text' can also be a file, but we don't publish that fact text,
# since it's an implementation detail. The official file origin,
# interface is from_file().
if filename is None:
filename = "<string>"
zone = zone_factory(origin, rdclass, relativize=relativize)
with zone.writer(True) as txn:
tok = dns.tokenizer.Tokenizer(text, filename, idna_codec=idna_codec)
reader = dns.zonefile.Reader(
tok,
rdclass, rdclass,
txn, relativize,
allow_include=allow_include, zone_factory,
allow_directives=allow_directives, filename,
allow_include,
check_origin,
idna_codec,
allow_directives,
) )
try:
reader.read()
except dns.zonefile.UnknownOrigin:
# for backwards compatibility
raise dns.zone.UnknownOrigin
# Now that we're done reading, do some basic checking of the zone.
if check_origin:
zone.check_origin()
return zone
def from_file( def from_file(
@ -1324,7 +1363,7 @@ def from_file(
else: else:
cm = contextlib.nullcontext(f) cm = contextlib.nullcontext(f)
with cm as f: with cm as f:
return from_text( return _from_text(
f, f,
origin, origin,
rdclass, rdclass,

View file

@ -86,7 +86,6 @@ def _upper_dollarize(s):
class Reader: class Reader:
"""Read a DNS zone file into a transaction.""" """Read a DNS zone file into a transaction."""
def __init__( def __init__(

Some files were not shown because too many files have changed in this diff Show more