Bump dnspython from 2.4.2 to 2.6.1 (#2264)

* Bump dnspython from 2.4.2 to 2.6.1

Bumps [dnspython](https://github.com/rthalley/dnspython) from 2.4.2 to 2.6.1.
- [Release notes](https://github.com/rthalley/dnspython/releases)
- [Changelog](https://github.com/rthalley/dnspython/blob/main/doc/whatsnew.rst)
- [Commits](https://github.com/rthalley/dnspython/compare/v2.4.2...v2.6.1)

---
updated-dependencies:
- dependency-name: dnspython
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>

* Update dnspython==2.6.1

---------

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: JonnyWong16 <9099342+JonnyWong16@users.noreply.github.com>

[skip ci]
dependabot[bot] 2024-03-24 15:25:23 -07:00 committed by GitHub
parent aca7e72715
commit cfefa928be
101 changed files with 1052 additions and 459 deletions


@ -7,7 +7,9 @@ import socket
import sys
import dns._asyncbackend
import dns._features
import dns.exception
import dns.inet
_is_win32 = sys.platform == "win32"
@ -121,7 +123,7 @@ class StreamSocket(dns._asyncbackend.StreamSocket):
return self.writer.get_extra_info("peercert")
try:
if dns._features.have("doh"):
import anyio
import httpcore
import httpcore._backends.anyio
@ -205,7 +207,7 @@ try:
resolver, local_port, bootstrap_address, family
)
except ImportError:
else:
_HTTPTransport = dns._asyncbackend.NullTransport # type: ignore
@ -224,14 +226,12 @@ class Backend(dns._asyncbackend.Backend):
ssl_context=None,
server_hostname=None,
):
if destination is None and socktype == socket.SOCK_DGRAM and _is_win32:
raise NotImplementedError(
"destinationless datagram sockets "
"are not supported by asyncio "
"on Windows"
)
loop = _get_running_loop()
if socktype == socket.SOCK_DGRAM:
if _is_win32 and source is None:
# Win32 wants explicit binding before recvfrom(). This is the
# proper fix for [#637].
source = (dns.inet.any_for_af(af), 0)
transport, protocol = await loop.create_datagram_endpoint(
_DatagramProtocol,
source,
@ -266,7 +266,7 @@ class Backend(dns._asyncbackend.Backend):
await asyncio.sleep(interval)
def datagram_connection_required(self):
return _is_win32
return False
def get_transport_class(self):
return _HTTPTransport

lib/dns/_features.py (new file, 92 lines)

@ -0,0 +1,92 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
import importlib.metadata
import itertools
import string
from typing import Dict, List, Tuple
def _tuple_from_text(version: str) -> Tuple:
text_parts = version.split(".")
int_parts = []
for text_part in text_parts:
digit_prefix = "".join(
itertools.takewhile(lambda x: x in string.digits, text_part)
)
try:
int_parts.append(int(digit_prefix))
except Exception:
break
return tuple(int_parts)
def _version_check(
requirement: str,
) -> bool:
"""Is the requirement fulfilled?
The requirement must be of the form
package>=version
"""
package, minimum = requirement.split(">=")
try:
version = importlib.metadata.version(package)
except Exception:
return False
t_version = _tuple_from_text(version)
t_minimum = _tuple_from_text(minimum)
if t_version < t_minimum:
return False
return True
_cache: Dict[str, bool] = {}
def have(feature: str) -> bool:
"""Is *feature* available?
This tests if all optional packages needed for the
feature are available and recent enough.
Returns ``True`` if the feature is available,
and ``False`` if it is not or if metadata is
missing.
"""
value = _cache.get(feature)
if value is not None:
return value
requirements = _requirements.get(feature)
if requirements is None:
# we make a cache entry here for consistency, not performance
_cache[feature] = False
return False
ok = True
for requirement in requirements:
if not _version_check(requirement):
ok = False
break
_cache[feature] = ok
return ok
def force(feature: str, enabled: bool) -> None:
"""Force the status of *feature* to be *enabled*.
This method is provided as a workaround for any cases
where importlib.metadata is ineffective, or for testing.
"""
_cache[feature] = enabled
_requirements: Dict[str, List[str]] = {
### BEGIN generated requirements
"dnssec": ["cryptography>=41"],
"doh": ["httpcore>=1.0.0", "httpx>=0.26.0", "h2>=4.1.0"],
"doq": ["aioquic>=0.9.25"],
"idna": ["idna>=3.6"],
"trio": ["trio>=0.23"],
"wmi": ["wmi>=1.5.1"],
### END generated requirements
}
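A minimal usage sketch of the new feature gate (illustrative only; dns._features is an internal module and not part of this diff's public API):

import dns._features

if dns._features.have("doh"):
    # True only if httpcore, httpx, and h2 are installed and new enough.
    print("DNS-over-HTTPS dependencies are available")
dns._features.force("doh", False)    # override the cached answer, e.g. for testing
assert not dns._features.have("doh")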


@ -8,9 +8,13 @@ import trio
import trio.socket # type: ignore
import dns._asyncbackend
import dns._features
import dns.exception
import dns.inet
if not dns._features.have("trio"):
raise ImportError("trio not found or too old")
def _maybe_timeout(timeout):
if timeout is not None:
@ -95,7 +99,7 @@ class StreamSocket(dns._asyncbackend.StreamSocket):
raise NotImplementedError
try:
if dns._features.have("doh"):
import httpcore
import httpcore._backends.trio
import httpx
@ -177,7 +181,7 @@ try:
resolver, local_port, bootstrap_address, family
)
except ImportError:
else:
_HTTPTransport = dns._asyncbackend.NullTransport # type: ignore


@ -32,7 +32,7 @@ def get_backend(name: str) -> Backend:
*name*, a ``str``, the name of the backend. Currently the "trio"
and "asyncio" backends are available.
Raises NotImplementError if an unknown backend name is specified.
Raises NotImplementedError if an unknown backend name is specified.
"""
# pylint: disable=import-outside-toplevel,redefined-outer-name
backend = _backends.get(name)


@ -41,7 +41,7 @@ from dns.query import (
NoDOQ,
UDPMode,
_compute_times,
_have_http2,
_make_dot_ssl_context,
_matches_destination,
_remaining,
have_doh,
@ -120,6 +120,8 @@ async def receive_udp(
request_mac: Optional[bytes] = b"",
ignore_trailing: bool = False,
raise_on_truncation: bool = False,
ignore_errors: bool = False,
query: Optional[dns.message.Message] = None,
) -> Any:
"""Read a DNS message from a UDP socket.
@ -133,22 +135,40 @@ async def receive_udp(
"""
wire = b""
while 1:
while True:
(wire, from_address) = await sock.recvfrom(65535, _timeout(expiration))
if _matches_destination(
if not _matches_destination(
sock.family, from_address, destination, ignore_unexpected
):
break
received_time = time.time()
r = dns.message.from_wire(
wire,
keyring=keyring,
request_mac=request_mac,
one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing,
raise_on_truncation=raise_on_truncation,
)
return (r, received_time, from_address)
continue
received_time = time.time()
try:
r = dns.message.from_wire(
wire,
keyring=keyring,
request_mac=request_mac,
one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing,
raise_on_truncation=raise_on_truncation,
)
except dns.message.Truncated as e:
# See the comment in query.py for details.
if (
ignore_errors
and query is not None
and not query.is_response(e.message())
):
continue
else:
raise
except Exception:
if ignore_errors:
continue
else:
raise
if ignore_errors and query is not None and not query.is_response(r):
continue
return (r, received_time, from_address)
async def udp(
@ -164,6 +184,7 @@ async def udp(
raise_on_truncation: bool = False,
sock: Optional[dns.asyncbackend.DatagramSocket] = None,
backend: Optional[dns.asyncbackend.Backend] = None,
ignore_errors: bool = False,
) -> dns.message.Message:
"""Return the response obtained after sending a query via UDP.
@ -205,9 +226,13 @@ async def udp(
q.mac,
ignore_trailing,
raise_on_truncation,
ignore_errors,
q,
)
r.time = received_time - begin_time
if not q.is_response(r):
# We don't need to check q.is_response() if we are in ignore_errors mode
# as receive_udp() will have checked it.
if not (ignore_errors or q.is_response(r)):
raise BadResponse
return r
@ -225,6 +250,7 @@ async def udp_with_fallback(
udp_sock: Optional[dns.asyncbackend.DatagramSocket] = None,
tcp_sock: Optional[dns.asyncbackend.StreamSocket] = None,
backend: Optional[dns.asyncbackend.Backend] = None,
ignore_errors: bool = False,
) -> Tuple[dns.message.Message, bool]:
"""Return the response to the query, trying UDP first and falling back
to TCP if UDP results in a truncated response.
@ -260,6 +286,7 @@ async def udp_with_fallback(
True,
udp_sock,
backend,
ignore_errors,
)
return (response, False)
except dns.message.Truncated:
@ -292,14 +319,12 @@ async def send_tcp(
"""
if isinstance(what, dns.message.Message):
wire = what.to_wire()
tcpmsg = what.to_wire(prepend_length=True)
else:
wire = what
l = len(wire)
# copying the wire into tcpmsg is inefficient, but lets us
# avoid writev() or doing a short write that would get pushed
# onto the net
tcpmsg = struct.pack("!H", l) + wire
# copying the wire into tcpmsg is inefficient, but lets us
# avoid writev() or doing a short write that would get pushed
# onto the net
tcpmsg = len(what).to_bytes(2, "big") + what
sent_time = time.time()
await sock.sendall(tcpmsg, _timeout(expiration, sent_time))
return (len(tcpmsg), sent_time)
@ -418,6 +443,7 @@ async def tls(
backend: Optional[dns.asyncbackend.Backend] = None,
ssl_context: Optional[ssl.SSLContext] = None,
server_hostname: Optional[str] = None,
verify: Union[bool, str] = True,
) -> dns.message.Message:
"""Return the response obtained after sending a query via TLS.
@ -439,11 +465,7 @@ async def tls(
cm: contextlib.AbstractAsyncContextManager = NullContext(sock)
else:
if ssl_context is None:
# See the comment about ssl.create_default_context() in query.py
ssl_context = ssl.create_default_context() # lgtm[py/insecure-protocol]
ssl_context.minimum_version = ssl.TLSVersion.TLSv1_2
if server_hostname is None:
ssl_context.check_hostname = False
ssl_context = _make_dot_ssl_context(server_hostname, verify)
af = dns.inet.af_for_address(where)
stuple = _source_tuple(af, source, source_port)
dtuple = (where, port)
@ -538,7 +560,7 @@ async def https(
transport = backend.get_transport_class()(
local_address=local_address,
http1=True,
http2=_have_http2,
http2=True,
verify=verify,
local_port=local_port,
bootstrap_address=bootstrap_address,
@ -550,7 +572,7 @@ async def https(
cm: contextlib.AbstractAsyncContextManager = NullContext(client)
else:
cm = httpx.AsyncClient(
http1=True, http2=_have_http2, verify=verify, transport=transport
http1=True, http2=True, verify=verify, transport=transport
)
async with cm as the_client:


@ -27,6 +27,7 @@ import time
from datetime import datetime
from typing import Callable, Dict, List, Optional, Set, Tuple, Union, cast
import dns._features
import dns.exception
import dns.name
import dns.node
@ -1169,7 +1170,7 @@ def _need_pyca(*args, **kwargs):
) # pragma: no cover
try:
if dns._features.have("dnssec"):
from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.primitives.asymmetric import dsa # pylint: disable=W0611
from cryptography.hazmat.primitives.asymmetric import ec # pylint: disable=W0611
@ -1184,20 +1185,20 @@ try:
get_algorithm_cls_from_dnskey,
)
from dns.dnssecalgs.base import GenericPrivateKey, GenericPublicKey
except ImportError: # pragma: no cover
validate = _need_pyca
validate_rrsig = _need_pyca
sign = _need_pyca
make_dnskey = _need_pyca
make_cdnskey = _need_pyca
_have_pyca = False
else:
validate = _validate # type: ignore
validate_rrsig = _validate_rrsig # type: ignore
sign = _sign
make_dnskey = _make_dnskey
make_cdnskey = _make_cdnskey
_have_pyca = True
else: # pragma: no cover
validate = _need_pyca
validate_rrsig = _need_pyca
sign = _need_pyca
make_dnskey = _need_pyca
make_cdnskey = _need_pyca
_have_pyca = False
### BEGIN generated Algorithm constants


@ -1,9 +1,12 @@
from typing import Dict, Optional, Tuple, Type, Union
import dns.name
from dns.dnssecalgs.base import GenericPrivateKey
from dns.dnssectypes import Algorithm
from dns.exception import UnsupportedAlgorithm
from dns.rdtypes.ANY.DNSKEY import DNSKEY
try:
from dns.dnssecalgs.base import GenericPrivateKey
if dns._features.have("dnssec"):
from dns.dnssecalgs.dsa import PrivateDSA, PrivateDSANSEC3SHA1
from dns.dnssecalgs.ecdsa import PrivateECDSAP256SHA256, PrivateECDSAP384SHA384
from dns.dnssecalgs.eddsa import PrivateED448, PrivateED25519
@ -16,13 +19,9 @@ try:
)
_have_cryptography = True
except ImportError:
else:
_have_cryptography = False
from dns.dnssectypes import Algorithm
from dns.exception import UnsupportedAlgorithm
from dns.rdtypes.ANY.DNSKEY import DNSKEY
AlgorithmPrefix = Optional[Union[bytes, dns.name.Name]]
algorithms: Dict[Tuple[Algorithm, AlgorithmPrefix], Type[GenericPrivateKey]] = {}


@ -17,6 +17,7 @@
"""EDNS Options"""
import binascii
import math
import socket
import struct
@ -58,7 +59,6 @@ class OptionType(dns.enum.IntEnum):
class Option:
"""Base class for all EDNS option types."""
def __init__(self, otype: Union[OptionType, str]):
@ -76,6 +76,9 @@ class Option:
"""
raise NotImplementedError # pragma: no cover
def to_text(self) -> str:
raise NotImplementedError # pragma: no cover
@classmethod
def from_wire_parser(cls, otype: OptionType, parser: "dns.wire.Parser") -> "Option":
"""Build an EDNS option object from wire format.
@ -141,7 +144,6 @@ class Option:
class GenericOption(Option): # lgtm[py/missing-equals]
"""Generic Option Class
This class is used for EDNS option types for which we have no better
@ -343,6 +345,8 @@ class EDECode(dns.enum.IntEnum):
class EDEOption(Option): # lgtm[py/missing-equals]
"""Extended DNS Error (EDE, RFC8914)"""
_preserve_case = {"DNSKEY", "DS", "DNSSEC", "RRSIGs", "NSEC", "NXDOMAIN"}
def __init__(self, code: Union[EDECode, str], text: Optional[str] = None):
"""*code*, a ``dns.edns.EDECode`` or ``str``, the info code of the
extended error.
@ -360,6 +364,13 @@ class EDEOption(Option): # lgtm[py/missing-equals]
def to_text(self) -> str:
output = f"EDE {self.code}"
if self.code in EDECode:
desc = EDECode.to_text(self.code)
desc = " ".join(
word if word in self._preserve_case else word.title()
for word in desc.split("_")
)
output += f" ({desc})"
if self.text is not None:
output += f": {self.text}"
return output
@ -392,9 +403,37 @@ class EDEOption(Option): # lgtm[py/missing-equals]
return cls(code, btext)
class NSIDOption(Option):
def __init__(self, nsid: bytes):
super().__init__(OptionType.NSID)
self.nsid = nsid
def to_wire(self, file: Any = None) -> Optional[bytes]:
if file:
file.write(self.nsid)
return None
else:
return self.nsid
def to_text(self) -> str:
if all(c >= 0x20 and c <= 0x7E for c in self.nsid):
# All ASCII printable, so it's probably a string.
value = self.nsid.decode()
else:
value = binascii.hexlify(self.nsid).decode()
return f"NSID {value}"
@classmethod
def from_wire_parser(
cls, otype: Union[OptionType, str], parser: dns.wire.Parser
) -> Option:
return cls(parser.get_remaining())
_type_to_class: Dict[OptionType, Any] = {
OptionType.ECS: ECSOption,
OptionType.EDE: EDEOption,
OptionType.NSID: NSIDOption,
}
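A usage sketch for the new NSID option (illustrative; the query name is a placeholder):

import dns.edns
import dns.message

# Request the server's NSID by attaching an empty NSID option (RFC 5001).
q = dns.message.make_query("example.com", "A", use_edns=0,
                           options=[dns.edns.NSIDOption(b"")])
print(dns.edns.NSIDOption(b"ns1").to_text())   # printable payload -> "NSID ns1"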


@ -1,24 +1,30 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
import collections.abc
from typing import Any
from typing import Any, Callable
from dns._immutable_ctx import immutable
@immutable
class Dict(collections.abc.Mapping): # lgtm[py/missing-equals]
def __init__(self, dictionary: Any, no_copy: bool = False):
def __init__(
self,
dictionary: Any,
no_copy: bool = False,
map_factory: Callable[[], collections.abc.MutableMapping] = dict,
):
"""Make an immutable dictionary from the specified dictionary.
If *no_copy* is `True`, then *dictionary* will be wrapped instead
of copied. Only set this if you are sure there will be no external
references to the dictionary.
"""
if no_copy and isinstance(dictionary, dict):
if no_copy and isinstance(dictionary, collections.abc.MutableMapping):
self._odict = dictionary
else:
self._odict = dict(dictionary)
self._odict = map_factory()
self._odict.update(dictionary)
self._hash = None
def __getitem__(self, key):


@ -178,3 +178,20 @@ def any_for_af(af):
elif af == socket.AF_INET6:
return "::"
raise NotImplementedError(f"unknown address family {af}")
def canonicalize(text: str) -> str:
"""Verify that *address* is a valid text form IPv4 or IPv6 address and return its
canonical text form. IPv6 addresses with scopes are rejected.
*text*, a ``str``, the address in textual form.
Raises ``ValueError`` if the text is not valid.
"""
try:
return dns.ipv6.canonicalize(text)
except Exception:
try:
return dns.ipv4.canonicalize(text)
except Exception:
raise ValueError


@ -62,3 +62,16 @@ def inet_aton(text: Union[str, bytes]) -> bytes:
return struct.pack("BBBB", *b)
except Exception:
raise dns.exception.SyntaxError
def canonicalize(text: Union[str, bytes]) -> str:
"""Verify that *address* is a valid text form IPv4 address and return its
canonical text form.
*text*, a ``str`` or ``bytes``, the IPv4 address in textual form.
Raises ``dns.exception.SyntaxError`` if the text is not valid.
"""
# Note that inet_aton() only accepts canonical form, but we still run through
# inet_ntoa() to ensure the output is a str.
return dns.ipv4.inet_ntoa(dns.ipv4.inet_aton(text))


@ -104,7 +104,7 @@ _colon_colon_end = re.compile(rb".*::$")
def inet_aton(text: Union[str, bytes], ignore_scope: bool = False) -> bytes:
"""Convert an IPv6 address in text form to binary form.
*text*, a ``str``, the IPv6 address in textual form.
*text*, a ``str`` or ``bytes``, the IPv6 address in textual form.
*ignore_scope*, a ``bool``. If ``True``, a scope will be ignored.
If ``False``, the default, it is an error for a scope to be present.
@ -206,3 +206,14 @@ def is_mapped(address: bytes) -> bool:
"""
return address.startswith(_mapped_prefix)
def canonicalize(text: Union[str, bytes]) -> str:
"""Verify that *address* is a valid text form IPv6 address and return its
canonical text form. Addresses with scopes are rejected.
*text*, a ``str`` or ``bytes``, the IPv6 address in textual form.
Raises ``dns.exception.SyntaxError`` if the text is not valid.
"""
return dns.ipv6.inet_ntoa(dns.ipv6.inet_aton(text))
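A brief sketch of the new canonicalize() helpers (illustrative):

import dns.inet
import dns.ipv4
import dns.ipv6

# canonicalize() validates a textual address and returns its canonical form.
print(dns.ipv6.canonicalize("2001:DB8::1"))   # -> 2001:db8::1
print(dns.ipv4.canonicalize("192.0.2.1"))     # -> 192.0.2.1
print(dns.inet.canonicalize("192.0.2.1"))     # tries IPv6 first, then IPv4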


@ -393,7 +393,7 @@ class Message:
section_number = section
section = self.section_from_number(section_number)
elif isinstance(section, str):
section_number = MessageSection.from_text(section)
section_number = self._section_enum.from_text(section)
section = self.section_from_number(section_number)
else:
section_number = self.section_number(section)
@ -489,6 +489,34 @@ class Message:
rrset = None
return rrset
def section_count(self, section: SectionType) -> int:
"""Returns the number of records in the specified section.
*section*, an ``int`` section number, a ``str`` section name, or one of
the section attributes of this message. This specifies the
section of the message to count. For example::
my_message.section_count(my_message.answer)
my_message.section_count(dns.message.ANSWER)
my_message.section_count("ANSWER")
"""
if isinstance(section, int):
section_number = section
section = self.section_from_number(section_number)
elif isinstance(section, str):
section_number = self._section_enum.from_text(section)
section = self.section_from_number(section_number)
else:
section_number = self.section_number(section)
count = sum(max(1, len(rrs)) for rrs in section)
if section_number == MessageSection.ADDITIONAL:
if self.opt is not None:
count += 1
if self.tsig is not None:
count += 1
return count
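A minimal sketch of the new section_count() method (illustrative):

import dns.message

msg = dns.message.make_query("example.com", "A")
assert msg.section_count("QUESTION") == 1
assert msg.section_count(dns.message.ANSWER) == 0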
def _compute_opt_reserve(self) -> int:
"""Compute the size required for the OPT RR, padding excluded"""
if not self.opt:
@ -527,6 +555,8 @@ class Message:
max_size: int = 0,
multi: bool = False,
tsig_ctx: Optional[Any] = None,
prepend_length: bool = False,
prefer_truncation: bool = False,
**kw: Dict[str, Any],
) -> bytes:
"""Return a string containing the message in DNS compressed wire
@ -549,6 +579,15 @@ class Message:
*tsig_ctx*, a ``dns.tsig.HMACTSig`` or ``dns.tsig.GSSTSig`` object, the
ongoing TSIG context, used when signing zone transfers.
*prepend_length*, a ``bool``, should be set to ``True`` if the caller
wants the message length prepended to the message itself. This is
useful for messages sent over TCP, TLS (DoT), or QUIC (DoQ).
*prefer_truncation*, a ``bool``, should be set to ``True`` if the caller
wants the message to be truncated if it would otherwise exceed the
maximum length. If the truncation occurs before the additional section,
the TC bit will be set.
Raises ``dns.exception.TooBig`` if *max_size* was exceeded.
Returns a ``bytes``.
@ -570,14 +609,21 @@ class Message:
r.reserve(opt_reserve)
tsig_reserve = self._compute_tsig_reserve()
r.reserve(tsig_reserve)
for rrset in self.question:
r.add_question(rrset.name, rrset.rdtype, rrset.rdclass)
for rrset in self.answer:
r.add_rrset(dns.renderer.ANSWER, rrset, **kw)
for rrset in self.authority:
r.add_rrset(dns.renderer.AUTHORITY, rrset, **kw)
for rrset in self.additional:
r.add_rrset(dns.renderer.ADDITIONAL, rrset, **kw)
try:
for rrset in self.question:
r.add_question(rrset.name, rrset.rdtype, rrset.rdclass)
for rrset in self.answer:
r.add_rrset(dns.renderer.ANSWER, rrset, **kw)
for rrset in self.authority:
r.add_rrset(dns.renderer.AUTHORITY, rrset, **kw)
for rrset in self.additional:
r.add_rrset(dns.renderer.ADDITIONAL, rrset, **kw)
except dns.exception.TooBig:
if prefer_truncation:
if r.section < dns.renderer.ADDITIONAL:
r.flags |= dns.flags.TC
else:
raise
r.release_reserved()
if self.opt is not None:
r.add_opt(self.opt, self.pad, opt_reserve, tsig_reserve)
@ -598,7 +644,10 @@ class Message:
r.write_header()
if multi:
self.tsig_ctx = ctx
return r.get_wire()
wire = r.get_wire()
if prepend_length:
wire = len(wire).to_bytes(2, "big") + wire
return wire
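Illustrative use of the new *prepend_length* flag for stream transports (TCP, DoT, DoQ):

import dns.message

q = dns.message.make_query("example.com", "A")
tcp_wire = q.to_wire(prepend_length=True)
# The first two bytes are the big-endian length of the DNS message that follows.
assert int.from_bytes(tcp_wire[:2], "big") == len(tcp_wire) - 2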
@staticmethod
def _make_tsig(
@ -777,6 +826,8 @@ class Message:
if request_payload is None:
request_payload = payload
self.request_payload = request_payload
if pad < 0:
raise ValueError("pad must be non-negative")
self.pad = pad
@property
@ -826,7 +877,7 @@ class Message:
if wanted:
self.ednsflags |= dns.flags.DO
elif self.opt:
self.ednsflags &= ~dns.flags.DO
self.ednsflags &= ~int(dns.flags.DO)
def rcode(self) -> dns.rcode.Rcode:
"""Return the rcode.
@ -1035,7 +1086,6 @@ def _message_factory_from_opcode(opcode):
class _WireReader:
"""Wire format reader.
parser: the binary parser
@ -1335,7 +1385,6 @@ def from_wire(
class _TextReader:
"""Text format reader.
tok: the tokenizer.
@ -1768,30 +1817,34 @@ def make_response(
our_payload: int = 8192,
fudge: int = 300,
tsig_error: int = 0,
pad: Optional[int] = None,
) -> Message:
"""Make a message which is a response for the specified query.
The message returned is really a response skeleton; it has all
of the infrastructure required of a response, but none of the
content.
The message returned is really a response skeleton; it has all of the infrastructure
required of a response, but none of the content.
The response's question section is a shallow copy of the query's
question section, so the query's question RRsets should not be
changed.
The response's question section is a shallow copy of the query's question section,
so the query's question RRsets should not be changed.
*query*, a ``dns.message.Message``, the query to respond to.
*recursion_available*, a ``bool``, should RA be set in the response?
*our_payload*, an ``int``, the payload size to advertise in EDNS
responses.
*our_payload*, an ``int``, the payload size to advertise in EDNS responses.
*fudge*, an ``int``, the TSIG time fudge.
*tsig_error*, an ``int``, the TSIG error.
Returns a ``dns.message.Message`` object whose specific class is
appropriate for the query. For example, if query is a
``dns.update.UpdateMessage``, response will be too.
*pad*, a non-negative ``int`` or ``None``. If 0, do not pad; if a positive ``int``,
add padding bytes to make the message size a multiple of *pad*. Note that if padding
is non-zero, an EDNS PADDING option will always be added to the message. If ``None``,
the default, pad following RFC 8467: if the request was padded, pad the response to a
multiple of 468, otherwise do not pad.
Returns a ``dns.message.Message`` object whose specific class is appropriate for the
query. For example, if query is a ``dns.update.UpdateMessage``, response will be
too.
"""
if query.flags & dns.flags.QR:
@ -1804,7 +1857,13 @@ def make_response(
response.set_opcode(query.opcode())
response.question = list(query.question)
if query.edns >= 0:
response.use_edns(0, 0, our_payload, query.payload)
if pad is None:
# Set response padding per RFC 8467
pad = 0
for option in query.options:
if option.otype == dns.edns.OptionType.PADDING:
pad = 468
response.use_edns(0, 0, our_payload, query.payload, pad=pad)
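A sketch of the new *pad* parameter (illustrative; the query name is a placeholder):

import dns.message

query = dns.message.make_query("example.com", "A", use_edns=0)
padded = dns.message.make_response(query, pad=468)   # force RFC 8467 block-size padding
default = dns.message.make_response(query)           # pads only if the query carried a PADDING option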
if query.had_tsig:
response.use_tsig(
query.keyring,


@ -20,21 +20,23 @@
import copy
import encodings.idna # type: ignore
import functools
import struct
from typing import Any, Dict, Iterable, Optional, Tuple, Union
try:
import idna # type: ignore
have_idna_2008 = True
except ImportError: # pragma: no cover
have_idna_2008 = False
from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union
import dns._features
import dns.enum
import dns.exception
import dns.immutable
import dns.wire
if dns._features.have("idna"):
import idna # type: ignore
have_idna_2008 = True
else: # pragma: no cover
have_idna_2008 = False
CompressType = Dict["Name", int]
@ -128,6 +130,10 @@ class IDNAException(dns.exception.DNSException):
super().__init__(*args, **kwargs)
class NeedSubdomainOfOrigin(dns.exception.DNSException):
"""An absolute name was provided that is not a subdomain of the specified origin."""
_escaped = b'"().;\\@$'
_escaped_text = '"().;\\@$'
@ -350,7 +356,6 @@ def _maybe_convert_to_binary(label: Union[bytes, str]) -> bytes:
@dns.immutable.immutable
class Name:
"""A DNS name.
The dns.name.Name class represents a DNS name as a tuple of
@ -843,6 +848,42 @@ class Name:
raise NoParent
return Name(self.labels[1:])
def predecessor(self, origin: "Name", prefix_ok: bool = True) -> "Name":
"""Return the maximal predecessor of *name* in the DNSSEC ordering in the zone
whose origin is *origin*, or return the longest name under *origin* if the
name is *origin* (i.e. wrap around to the longest name, which may still be
*origin* due to length considerations).
The relativity of the name is preserved, so if this name is relative
then the method will return a relative name, and likewise if this name
is absolute then the predecessor will be absolute.
*prefix_ok* indicates if prefixing labels is allowed, and
defaults to ``True``. Normally it is good to allow this, but if computing
a maximal predecessor at a zone cut point then ``False`` must be specified.
"""
return _handle_relativity_and_call(
_absolute_predecessor, self, origin, prefix_ok
)
def successor(self, origin: "Name", prefix_ok: bool = True) -> "Name":
"""Return the minimal successor of *name* in the DNSSEC ordering in the zone
whose origin is *origin*, or return *origin* if the successor cannot be
computed due to name length limitations.
Note that *origin* is returned in the "too long" cases because wrapping
around to the origin is how NSEC records express "end of the zone".
The relativity of the name is preserved, so if this name is relative
then the method will return a relative name, and likewise if this name
is absolute then the successor will be absolute.
*prefix_ok* indicates if prefixing a new minimal label is allowed, and
defaults to ``True``. Normally it is good to allow this, but if computing
a minimal successor at a zone cut point then ``False`` must be specified.
"""
return _handle_relativity_and_call(_absolute_successor, self, origin, prefix_ok)
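Illustrative use of the new RFC 4471 ordering helpers (names below are placeholders):

import dns.name

origin = dns.name.from_text("example.com.")
name = dns.name.from_text("a.example.com.")
print(name.successor(origin))     # minimal name after a.example.com. in DNSSEC order
print(name.predecessor(origin))   # maximal name before a.example.com. in the zone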
#: The root name, '.'
root = Name([b""])
@ -1082,3 +1123,161 @@ def from_wire(message: bytes, current: int) -> Tuple[Name, int]:
parser = dns.wire.Parser(message, current)
name = from_wire_parser(parser)
return (name, parser.current - current)
# RFC 4471 Support
_MINIMAL_OCTET = b"\x00"
_MINIMAL_OCTET_VALUE = ord(_MINIMAL_OCTET)
_SUCCESSOR_PREFIX = Name([_MINIMAL_OCTET])
_MAXIMAL_OCTET = b"\xff"
_MAXIMAL_OCTET_VALUE = ord(_MAXIMAL_OCTET)
_AT_SIGN_VALUE = ord("@")
_LEFT_SQUARE_BRACKET_VALUE = ord("[")
def _wire_length(labels):
return functools.reduce(lambda v, x: v + len(x) + 1, labels, 0)
def _pad_to_max_name(name):
needed = 255 - _wire_length(name.labels)
new_labels = []
while needed > 64:
new_labels.append(_MAXIMAL_OCTET * 63)
needed -= 64
if needed >= 2:
new_labels.append(_MAXIMAL_OCTET * (needed - 1))
# Note we're already maximal in the needed == 1 case as while we'd like
# to add one more byte as a new label, we can't, as adding a new non-empty
# label requires at least 2 bytes.
new_labels = list(reversed(new_labels))
new_labels.extend(name.labels)
return Name(new_labels)
def _pad_to_max_label(label, suffix_labels):
length = len(label)
# We have to subtract one here to account for the length byte of label.
remaining = 255 - _wire_length(suffix_labels) - length - 1
if remaining <= 0:
# Shouldn't happen!
return label
needed = min(63 - length, remaining)
return label + _MAXIMAL_OCTET * needed
def _absolute_predecessor(name: Name, origin: Name, prefix_ok: bool) -> Name:
# This is the RFC 4471 predecessor algorithm using the "absolute method" of section
# 3.1.1.
#
# Our caller must ensure that the name and origin are absolute, and that name is a
# subdomain of origin.
if name == origin:
return _pad_to_max_name(name)
least_significant_label = name[0]
if least_significant_label == _MINIMAL_OCTET:
return name.parent()
least_octet = least_significant_label[-1]
suffix_labels = name.labels[1:]
if least_octet == _MINIMAL_OCTET_VALUE:
new_labels = [least_significant_label[:-1]]
else:
octets = bytearray(least_significant_label)
octet = octets[-1]
if octet == _LEFT_SQUARE_BRACKET_VALUE:
octet = _AT_SIGN_VALUE
else:
octet -= 1
octets[-1] = octet
least_significant_label = bytes(octets)
new_labels = [_pad_to_max_label(least_significant_label, suffix_labels)]
new_labels.extend(suffix_labels)
name = Name(new_labels)
if prefix_ok:
return _pad_to_max_name(name)
else:
return name
def _absolute_successor(name: Name, origin: Name, prefix_ok: bool) -> Name:
# This is the RFC 4471 successor algorithm using the "absolute method" of section
# 3.1.2.
#
# Our caller must ensure that the name and origin are absolute, and that name is a
# subdomain of origin.
if prefix_ok:
# Try prefixing \000 as new label
try:
return _SUCCESSOR_PREFIX.concatenate(name)
except NameTooLong:
pass
while name != origin:
# Try extending the least significant label.
least_significant_label = name[0]
if len(least_significant_label) < 63:
# We may be able to extend the least label with a minimal additional byte.
# This is only "may" because we could have a maximal length name even though
# the least significant label isn't maximally long.
new_labels = [least_significant_label + _MINIMAL_OCTET]
new_labels.extend(name.labels[1:])
try:
return dns.name.Name(new_labels)
except dns.name.NameTooLong:
pass
# We can't extend the label either, so we'll try to increment the least
# significant non-maximal byte in it.
octets = bytearray(least_significant_label)
# We do this reversed iteration with an explicit indexing variable because
# if we find something to increment, we're going to want to truncate everything
# to the right of it.
for i in range(len(octets) - 1, -1, -1):
octet = octets[i]
if octet == _MAXIMAL_OCTET_VALUE:
# We can't increment this, so keep looking.
continue
# Finally, something we can increment. We have to apply a special rule for
# incrementing "@", sending it to "[", because RFC 4034 6.1 says that when
# comparing names, uppercase letters compare as if they were their
# lower-case equivalents. If we increment "@" to "A", then it would compare
# as "a", which is after "[", "\", "]", "^", "_", and "`", so we would have
# skipped the most minimal successor, namely "[".
if octet == _AT_SIGN_VALUE:
octet = _LEFT_SQUARE_BRACKET_VALUE
else:
octet += 1
octets[i] = octet
# We can now truncate all of the maximal values we skipped (if any)
new_labels = [bytes(octets[: i + 1])]
new_labels.extend(name.labels[1:])
# We haven't changed the length of the name, so the Name constructor will
# always work.
return Name(new_labels)
# We couldn't increment, so chop off the least significant label and try
# again.
name = name.parent()
# We couldn't increment at all, so return the origin, as wrapping around is the
# DNSSEC way.
return origin
def _handle_relativity_and_call(
function: Callable[[Name, Name, bool], Name],
name: Name,
origin: Name,
prefix_ok: bool,
) -> Name:
# Make "name" absolute if needed, ensure that the origin is absolute,
# call function(), and then relativize the result if needed.
if not origin.is_absolute():
raise NeedAbsoluteNameOrOrigin
relative = not name.is_absolute()
if relative:
name = name.derelativize(origin)
elif not name.is_subdomain(origin):
raise NeedSubdomainOfOrigin
result_name = function(name, origin, prefix_ok)
if relative:
result_name = result_name.relativize(origin)
return result_name


@ -115,6 +115,8 @@ class Do53Nameserver(AddressAndPortNameserver):
raise_on_truncation=True,
one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing,
ignore_errors=True,
ignore_unexpected=True,
)
return response
@ -153,15 +155,25 @@ class Do53Nameserver(AddressAndPortNameserver):
backend=backend,
one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing,
ignore_errors=True,
ignore_unexpected=True,
)
return response
class DoHNameserver(Nameserver):
def __init__(self, url: str, bootstrap_address: Optional[str] = None):
def __init__(
self,
url: str,
bootstrap_address: Optional[str] = None,
verify: Union[bool, str] = True,
want_get: bool = False,
):
super().__init__()
self.url = url
self.bootstrap_address = bootstrap_address
self.verify = verify
self.want_get = want_get
def kind(self):
return "DoH"
@ -195,9 +207,13 @@ class DoHNameserver(Nameserver):
request,
self.url,
timeout=timeout,
source=source,
source_port=source_port,
bootstrap_address=self.bootstrap_address,
one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing,
verify=self.verify,
post=(not self.want_get),
)
async def async_query(
@ -215,15 +231,27 @@ class DoHNameserver(Nameserver):
request,
self.url,
timeout=timeout,
source=source,
source_port=source_port,
bootstrap_address=self.bootstrap_address,
one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing,
verify=self.verify,
post=(not self.want_get),
)
class DoTNameserver(AddressAndPortNameserver):
def __init__(self, address: str, port: int = 853, hostname: Optional[str] = None):
def __init__(
self,
address: str,
port: int = 853,
hostname: Optional[str] = None,
verify: Union[bool, str] = True,
):
super().__init__(address, port)
self.hostname = hostname
self.verify = verify
def kind(self):
return "DoT"
@ -246,6 +274,7 @@ class DoTNameserver(AddressAndPortNameserver):
one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing,
server_hostname=self.hostname,
verify=self.verify,
)
async def async_query(
@ -267,6 +296,7 @@ class DoTNameserver(AddressAndPortNameserver):
one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing,
server_hostname=self.hostname,
verify=self.verify,
)
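A resolver-level sketch of the new verification options (server addresses and URLs below are example values, not part of this diff):

import dns.nameserver
import dns.resolver

res = dns.resolver.Resolver(configure=False)
res.nameservers = [
    dns.nameserver.DoTNameserver("9.9.9.9", hostname="dns.quad9.net", verify=True),
    dns.nameserver.DoHNameserver("https://dns.quad9.net/dns-query", want_get=False),
]
answer = res.resolve("example.com", "A")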


@ -70,7 +70,6 @@ class NodeKind(enum.Enum):
class Node:
"""A Node is a set of rdatasets.
A node is either a CNAME node or an "other data" node. A CNAME


@ -22,12 +22,14 @@ import contextlib
import enum
import errno
import os
import os.path
import selectors
import socket
import struct
import time
from typing import Any, Dict, Optional, Tuple, Union
import dns._features
import dns.exception
import dns.inet
import dns.message
@ -57,24 +59,14 @@ def _expiration_for_this_attempt(timeout, expiration):
return min(time.time() + timeout, expiration)
_have_httpx = False
_have_http2 = False
try:
import httpcore
_have_httpx = dns._features.have("doh")
if _have_httpx:
import httpcore._backends.sync
import httpx
_CoreNetworkBackend = httpcore.NetworkBackend
_CoreSyncStream = httpcore._backends.sync.SyncStream
_have_httpx = True
try:
# See if http2 support is available.
with httpx.Client(http2=True):
_have_http2 = True
except Exception:
pass
class _NetworkBackend(_CoreNetworkBackend):
def __init__(self, resolver, local_port, bootstrap_address, family):
super().__init__()
@ -147,7 +139,7 @@ try:
resolver, local_port, bootstrap_address, family
)
except ImportError: # pragma: no cover
else:
class _HTTPTransport: # type: ignore
def connect_tcp(self, host, port, timeout, local_address):
@ -161,6 +153,8 @@ try:
except ImportError: # pragma: no cover
class ssl: # type: ignore
CERT_NONE = 0
class WantReadException(Exception):
pass
@ -459,7 +453,7 @@ def https(
transport = _HTTPTransport(
local_address=local_address,
http1=True,
http2=_have_http2,
http2=True,
verify=verify,
local_port=local_port,
bootstrap_address=bootstrap_address,
@ -470,9 +464,7 @@ def https(
if session:
cm: contextlib.AbstractContextManager = contextlib.nullcontext(session)
else:
cm = httpx.Client(
http1=True, http2=_have_http2, verify=verify, transport=transport
)
cm = httpx.Client(http1=True, http2=True, verify=verify, transport=transport)
with cm as session:
# see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH
# GET and POST examples
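An illustrative DoH query using the synchronous API (the resolver URL is an example value; HTTP/2 is now attempted whenever httpx is available):

import dns.message
import dns.query

q = dns.message.make_query("example.com", "AAAA")
r = dns.query.https(q, "https://dns.google/dns-query", post=True)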
@ -577,6 +569,8 @@ def receive_udp(
request_mac: Optional[bytes] = b"",
ignore_trailing: bool = False,
raise_on_truncation: bool = False,
ignore_errors: bool = False,
query: Optional[dns.message.Message] = None,
) -> Any:
"""Read a DNS message from a UDP socket.
@ -617,28 +611,58 @@ def receive_udp(
``(dns.message.Message, float, tuple)``
tuple of the received message, the received time, and the address where
the message arrived from.
*ignore_errors*, a ``bool``. If various format errors or response
mismatches occur, ignore them and keep listening for a valid response.
The default is ``False``.
*query*, a ``dns.message.Message`` or ``None``. If not ``None`` and
*ignore_errors* is ``True``, check that the received message is a response
to this query, and if not keep listening for a valid response.
"""
wire = b""
while True:
(wire, from_address) = _udp_recv(sock, 65535, expiration)
if _matches_destination(
if not _matches_destination(
sock.family, from_address, destination, ignore_unexpected
):
break
received_time = time.time()
r = dns.message.from_wire(
wire,
keyring=keyring,
request_mac=request_mac,
one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing,
raise_on_truncation=raise_on_truncation,
)
if destination:
return (r, received_time)
else:
return (r, received_time, from_address)
continue
received_time = time.time()
try:
r = dns.message.from_wire(
wire,
keyring=keyring,
request_mac=request_mac,
one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing,
raise_on_truncation=raise_on_truncation,
)
except dns.message.Truncated as e:
# If we got Truncated and not FORMERR, we at least got the header with TC
# set, and very likely the question section, so we'll re-raise if the
# message seems to be a response as we need to know when truncation happens.
# We need to check that it seems to be a response as we don't want a random
# injected message with TC set to cause us to bail out.
if (
ignore_errors
and query is not None
and not query.is_response(e.message())
):
continue
else:
raise
except Exception:
if ignore_errors:
continue
else:
raise
if ignore_errors and query is not None and not query.is_response(r):
continue
if destination:
return (r, received_time)
else:
return (r, received_time, from_address)
def udp(
@ -653,6 +677,7 @@ def udp(
ignore_trailing: bool = False,
raise_on_truncation: bool = False,
sock: Optional[Any] = None,
ignore_errors: bool = False,
) -> dns.message.Message:
"""Return the response obtained after sending a query via UDP.
@ -689,6 +714,10 @@ def udp(
if a socket is provided, it must be a nonblocking datagram socket,
and the *source* and *source_port* are ignored.
*ignore_errors*, a ``bool``. If various format errors or response
mismatches occur, ignore them and keep listening for a valid response.
The default is ``False``.
Returns a ``dns.message.Message``.
"""
@ -713,9 +742,13 @@ def udp(
q.mac,
ignore_trailing,
raise_on_truncation,
ignore_errors,
q,
)
r.time = received_time - begin_time
if not q.is_response(r):
# We don't need to check q.is_response() if we are in ignore_errors mode
# as receive_udp() will have checked it.
if not (ignore_errors or q.is_response(r)):
raise BadResponse
return r
assert (
@ -735,48 +768,50 @@ def udp_with_fallback(
ignore_trailing: bool = False,
udp_sock: Optional[Any] = None,
tcp_sock: Optional[Any] = None,
ignore_errors: bool = False,
) -> Tuple[dns.message.Message, bool]:
"""Return the response to the query, trying UDP first and falling back
to TCP if UDP results in a truncated response.
*q*, a ``dns.message.Message``, the query to send
*where*, a ``str`` containing an IPv4 or IPv6 address, where
to send the message.
*where*, a ``str`` containing an IPv4 or IPv6 address, where to send the message.
*timeout*, a ``float`` or ``None``, the number of seconds to wait before the
query times out. If ``None``, the default, wait forever.
*timeout*, a ``float`` or ``None``, the number of seconds to wait before the query
times out. If ``None``, the default, wait forever.
*port*, an ``int``, the port send the message to. The default is 53.
*source*, a ``str`` containing an IPv4 or IPv6 address, specifying
the source address. The default is the wildcard address.
*source*, a ``str`` containing an IPv4 or IPv6 address, specifying the source
address. The default is the wildcard address.
*source_port*, an ``int``, the port from which to send the message.
The default is 0.
*source_port*, an ``int``, the port from which to send the message. The default is
0.
*ignore_unexpected*, a ``bool``. If ``True``, ignore responses from
unexpected sources.
*ignore_unexpected*, a ``bool``. If ``True``, ignore responses from unexpected
sources.
*one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own
RRset.
*one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own RRset.
*ignore_trailing*, a ``bool``. If ``True``, ignore trailing
junk at end of the received message.
*ignore_trailing*, a ``bool``. If ``True``, ignore trailing junk at end of the
received message.
*udp_sock*, a ``socket.socket``, or ``None``, the socket to use for the
UDP query. If ``None``, the default, a socket is created. Note that
if a socket is provided, it must be a nonblocking datagram socket,
and the *source* and *source_port* are ignored for the UDP query.
*udp_sock*, a ``socket.socket``, or ``None``, the socket to use for the UDP query.
If ``None``, the default, a socket is created. Note that if a socket is provided,
it must be a nonblocking datagram socket, and the *source* and *source_port* are
ignored for the UDP query.
*tcp_sock*, a ``socket.socket``, or ``None``, the connected socket to use for the
TCP query. If ``None``, the default, a socket is created. Note that
if a socket is provided, it must be a nonblocking connected stream
socket, and *where*, *source* and *source_port* are ignored for the TCP
query.
TCP query. If ``None``, the default, a socket is created. Note that if a socket is
provided, it must be a nonblocking connected stream socket, and *where*, *source*
and *source_port* are ignored for the TCP query.
Returns a (``dns.message.Message``, tcp) tuple where tcp is ``True``
if and only if TCP was used.
*ignore_errors*, a ``bool``. If various format errors or response mismatches occur
while listening for UDP, ignore them and keep listening for a valid response. The
default is ``False``.
Returns a (``dns.message.Message``, tcp) tuple where tcp is ``True`` if and only if
TCP was used.
"""
try:
response = udp(
@ -791,6 +826,7 @@ def udp_with_fallback(
ignore_trailing,
True,
udp_sock,
ignore_errors,
)
return (response, False)
except dns.message.Truncated:
@ -864,14 +900,12 @@ def send_tcp(
"""
if isinstance(what, dns.message.Message):
wire = what.to_wire()
tcpmsg = what.to_wire(prepend_length=True)
else:
wire = what
l = len(wire)
# copying the wire into tcpmsg is inefficient, but lets us
# avoid writev() or doing a short write that would get pushed
# onto the net
tcpmsg = struct.pack("!H", l) + wire
# copying the wire into tcpmsg is inefficient, but lets us
# avoid writev() or doing a short write that would get pushed
# onto the net
tcpmsg = len(what).to_bytes(2, "big") + what
sent_time = time.time()
_net_write(sock, tcpmsg, expiration)
return (len(tcpmsg), sent_time)
@ -1014,6 +1048,28 @@ def _tls_handshake(s, expiration):
_wait_for_writable(s, expiration)
def _make_dot_ssl_context(
server_hostname: Optional[str], verify: Union[bool, str]
) -> ssl.SSLContext:
cafile: Optional[str] = None
capath: Optional[str] = None
if isinstance(verify, str):
if os.path.isfile(verify):
cafile = verify
elif os.path.isdir(verify):
capath = verify
else:
raise ValueError("invalid verify string")
ssl_context = ssl.create_default_context(cafile=cafile, capath=capath)
ssl_context.minimum_version = ssl.TLSVersion.TLSv1_2
if server_hostname is None:
ssl_context.check_hostname = False
ssl_context.set_alpn_protocols(["dot"])
if verify is False:
ssl_context.verify_mode = ssl.CERT_NONE
return ssl_context
def tls(
q: dns.message.Message,
where: str,
@ -1026,6 +1082,7 @@ def tls(
sock: Optional[ssl.SSLSocket] = None,
ssl_context: Optional[ssl.SSLContext] = None,
server_hostname: Optional[str] = None,
verify: Union[bool, str] = True,
) -> dns.message.Message:
"""Return the response obtained after sending a query via TLS.
@ -1065,6 +1122,11 @@ def tls(
default is ``None``, which means that no hostname is known, and if an
SSL context is created, hostname checking will be disabled.
*verify*, a ``bool`` or ``str``. If ``True``, TLS certificate verification
of the server is done using the default CA bundle; if ``False``, no
verification is done; if a ``str``, it specifies the path to a certificate file or
directory which will be used for verification.
Returns a ``dns.message.Message``.
"""
@ -1091,10 +1153,7 @@ def tls(
where, port, source, source_port
)
if ssl_context is None and not sock:
ssl_context = ssl.create_default_context()
ssl_context.minimum_version = ssl.TLSVersion.TLSv1_2
if server_hostname is None:
ssl_context.check_hostname = False
ssl_context = _make_dot_ssl_context(server_hostname, verify)
with _make_socket(
af,


@ -1,9 +1,11 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
try:
import dns._features
import dns.asyncbackend
if dns._features.have("doq"):
import aioquic.quic.configuration # type: ignore
import dns.asyncbackend
from dns._asyncbackend import NullContext
from dns.quic._asyncio import (
AsyncioQuicConnection,
@ -17,7 +19,7 @@ try:
def null_factory(
*args, # pylint: disable=unused-argument
**kwargs # pylint: disable=unused-argument
**kwargs, # pylint: disable=unused-argument
):
return NullContext(None)
@ -31,7 +33,7 @@ try:
_async_factories = {"asyncio": (null_factory, _asyncio_manager_factory)}
try:
if dns._features.have("trio"):
import trio
from dns.quic._trio import ( # pylint: disable=ungrouped-imports
@ -47,15 +49,13 @@ try:
return TrioQuicManager(context, *args, **kwargs)
_async_factories["trio"] = (_trio_context_factory, _trio_manager_factory)
except ImportError:
pass
def factories_for_backend(backend=None):
if backend is None:
backend = dns.asyncbackend.get_default_backend()
return _async_factories[backend.name()]
except ImportError:
else: # pragma: no cover
have_quic = False
from typing import Any


@ -101,9 +101,7 @@ class AsyncioQuicConnection(AsyncQuicConnection):
)
if address[0] != self._peer[0] or address[1] != self._peer[1]:
continue
self._connection.receive_datagram(
datagram, self._peer[0], time.time()
)
self._connection.receive_datagram(datagram, address, time.time())
# Wake up the timer in case the sender is sleeping, as there may be
# stuff to send now.
async with self._wake_timer:
@ -125,7 +123,7 @@ class AsyncioQuicConnection(AsyncQuicConnection):
while not self._done:
datagrams = self._connection.datagrams_to_send(time.time())
for datagram, address in datagrams:
assert address == self._peer[0]
assert address == self._peer
await self._socket.sendto(datagram, self._peer, None)
(expiration, interval) = self._get_timer_values()
try:
@ -147,11 +145,14 @@ class AsyncioQuicConnection(AsyncQuicConnection):
await stream._add_input(event.data, event.end_stream)
elif isinstance(event, aioquic.quic.events.HandshakeCompleted):
self._handshake_complete.set()
elif isinstance(
event, aioquic.quic.events.ConnectionTerminated
) or isinstance(event, aioquic.quic.events.StreamReset):
elif isinstance(event, aioquic.quic.events.ConnectionTerminated):
self._done = True
self._receiver_task.cancel()
elif isinstance(event, aioquic.quic.events.StreamReset):
stream = self._streams.get(event.stream_id)
if stream:
await stream._add_input(b"", True)
count += 1
if count > 10:
# yield
@ -188,7 +189,6 @@ class AsyncioQuicConnection(AsyncQuicConnection):
self._connection.close()
# sender might be blocked on this, so set it
self._socket_created.set()
await self._socket.close()
async with self._wake_timer:
self._wake_timer.notify_all()
try:
@ -199,14 +199,19 @@ class AsyncioQuicConnection(AsyncQuicConnection):
await self._sender_task
except asyncio.CancelledError:
pass
await self._socket.close()
class AsyncioQuicManager(AsyncQuicManager):
def __init__(self, conf=None, verify_mode=ssl.CERT_REQUIRED, server_name=None):
super().__init__(conf, verify_mode, AsyncioQuicConnection, server_name)
def connect(self, address, port=853, source=None, source_port=0):
(connection, start) = self._connect(address, port, source, source_port)
def connect(
self, address, port=853, source=None, source_port=0, want_session_ticket=True
):
(connection, start) = self._connect(
address, port, source, source_port, want_session_ticket
)
if start:
connection.run()
return connection


@ -1,5 +1,7 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
import copy
import functools
import socket
import struct
import time
@ -11,6 +13,10 @@ import aioquic.quic.connection # type: ignore
import dns.inet
QUIC_MAX_DATAGRAM = 2048
MAX_SESSION_TICKETS = 8
# If we hit the max sessions limit we will delete this many of the oldest connections.
# The value must be an integer > 0 and <= MAX_SESSION_TICKETS.
SESSIONS_TO_DELETE = MAX_SESSION_TICKETS // 4
class UnexpectedEOF(Exception):
@ -79,7 +85,10 @@ class BaseQuicStream:
def _common_add_input(self, data, is_end):
self._buffer.put(data, is_end)
return self._expecting > 0 and self._buffer.have(self._expecting)
try:
return self._expecting > 0 and self._buffer.have(self._expecting)
except UnexpectedEOF:
return True
def _close(self):
self._connection.close_stream(self._stream_id)
@ -142,6 +151,7 @@ class BaseQuicManager:
def __init__(self, conf, verify_mode, connection_factory, server_name=None):
self._connections = {}
self._connection_factory = connection_factory
self._session_tickets = {}
if conf is None:
verify_path = None
if isinstance(verify_mode, str):
@ -156,12 +166,35 @@ class BaseQuicManager:
conf.load_verify_locations(verify_path)
self._conf = conf
def _connect(self, address, port=853, source=None, source_port=0):
def _connect(
self, address, port=853, source=None, source_port=0, want_session_ticket=True
):
connection = self._connections.get((address, port))
if connection is not None:
return (connection, False)
qconn = aioquic.quic.connection.QuicConnection(configuration=self._conf)
qconn.connect(address, time.time())
conf = self._conf
if want_session_ticket:
try:
session_ticket = self._session_tickets.pop((address, port))
# We found a session ticket, so make a configuration that uses it.
conf = copy.copy(conf)
conf.session_ticket = session_ticket
except KeyError:
# No session ticket.
pass
# Whether or not we found a session ticket, we want a handler to save
# one.
session_ticket_handler = functools.partial(
self.save_session_ticket, address, port
)
else:
session_ticket_handler = None
qconn = aioquic.quic.connection.QuicConnection(
configuration=conf,
session_ticket_handler=session_ticket_handler,
)
lladdress = dns.inet.low_level_address_tuple((address, port))
qconn.connect(lladdress, time.time())
connection = self._connection_factory(
qconn, address, port, source, source_port, self
)
@ -174,6 +207,17 @@ class BaseQuicManager:
except KeyError:
pass
def save_session_ticket(self, address, port, ticket):
# We rely on dictionary keys() being in insertion order here. We
# can't just popitem() as that would be LIFO, which is the opposite of
# what we want.
l = len(self._session_tickets)
if l >= MAX_SESSION_TICKETS:
keys_to_delete = list(self._session_tickets.keys())[0:SESSIONS_TO_DELETE]
for key in keys_to_delete:
del self._session_tickets[key]
self._session_tickets[(address, port)] = ticket
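A standalone sketch of the eviction rule used above (illustrative; dict keys preserve insertion order, so the oldest tickets are dropped first):

tickets = {("a.example", 853): b"t1", ("b.example", 853): b"t2", ("c.example", 853): b"t3"}
oldest = list(tickets.keys())[0:2]   # FIFO: the first-inserted keys
for key in oldest:
    del tickets[key]
assert list(tickets) == [("c.example", 853)]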
class AsyncQuicManager(BaseQuicManager):
def connect(self, address, port=853, source=None, source_port=0):


@ -82,10 +82,6 @@ class SyncQuicConnection(BaseQuicConnection):
def __init__(self, connection, address, port, source, source_port, manager):
super().__init__(connection, address, port, source, source_port, manager)
self._socket = socket.socket(self._af, socket.SOCK_DGRAM, 0)
self._socket.connect(self._peer)
(self._send_wakeup, self._receive_wakeup) = socket.socketpair()
self._receive_wakeup.setblocking(False)
self._socket.setblocking(False)
if self._source is not None:
try:
self._socket.bind(
@ -94,6 +90,10 @@ class SyncQuicConnection(BaseQuicConnection):
except Exception:
self._socket.close()
raise
self._socket.connect(self._peer)
(self._send_wakeup, self._receive_wakeup) = socket.socketpair()
self._receive_wakeup.setblocking(False)
self._socket.setblocking(False)
self._handshake_complete = threading.Event()
self._worker_thread = None
self._lock = threading.Lock()
@ -107,7 +107,7 @@ class SyncQuicConnection(BaseQuicConnection):
except BlockingIOError:
return
with self._lock:
self._connection.receive_datagram(datagram, self._peer[0], time.time())
self._connection.receive_datagram(datagram, self._peer, time.time())
def _drain_wakeup(self):
while True:
@ -128,6 +128,8 @@ class SyncQuicConnection(BaseQuicConnection):
key.data()
with self._lock:
self._handle_timer(expiration)
self._handle_events()
with self._lock:
datagrams = self._connection.datagrams_to_send(time.time())
for datagram, _ in datagrams:
try:
@ -135,7 +137,6 @@ class SyncQuicConnection(BaseQuicConnection):
except BlockingIOError:
# we let QUIC handle any lossage
pass
self._handle_events()
finally:
with self._lock:
self._done = True
@ -155,11 +156,14 @@ class SyncQuicConnection(BaseQuicConnection):
stream._add_input(event.data, event.end_stream)
elif isinstance(event, aioquic.quic.events.HandshakeCompleted):
self._handshake_complete.set()
elif isinstance(
event, aioquic.quic.events.ConnectionTerminated
) or isinstance(event, aioquic.quic.events.StreamReset):
elif isinstance(event, aioquic.quic.events.ConnectionTerminated):
with self._lock:
self._done = True
elif isinstance(event, aioquic.quic.events.StreamReset):
with self._lock:
stream = self._streams.get(event.stream_id)
if stream:
stream._add_input(b"", True)
def write(self, stream, data, is_end=False):
with self._lock:
@ -203,9 +207,13 @@ class SyncQuicManager(BaseQuicManager):
super().__init__(conf, verify_mode, SyncQuicConnection, server_name)
self._lock = threading.Lock()
def connect(self, address, port=853, source=None, source_port=0):
def connect(
self, address, port=853, source=None, source_port=0, want_session_ticket=True
):
with self._lock:
(connection, start) = self._connect(address, port, source, source_port)
(connection, start) = self._connect(
address, port, source, source_port, want_session_ticket
)
if start:
connection.run()
return connection
@ -214,6 +222,10 @@ class SyncQuicManager(BaseQuicManager):
with self._lock:
super().closed(address, port)
def save_session_ticket(self, address, port, ticket):
with self._lock:
super().save_session_ticket(address, port, ticket)
def __enter__(self):
return self


@ -76,30 +76,43 @@ class TrioQuicConnection(AsyncQuicConnection):
def __init__(self, connection, address, port, source, source_port, manager=None):
super().__init__(connection, address, port, source, source_port, manager)
self._socket = trio.socket.socket(self._af, socket.SOCK_DGRAM, 0)
if self._source:
trio.socket.bind(dns.inet.low_level_address_tuple(self._source, self._af))
self._handshake_complete = trio.Event()
self._run_done = trio.Event()
self._worker_scope = None
self._send_pending = False
async def _worker(self):
try:
if self._source:
await self._socket.bind(
dns.inet.low_level_address_tuple(self._source, self._af)
)
await self._socket.connect(self._peer)
while not self._done:
(expiration, interval) = self._get_timer_values(False)
if self._send_pending:
# Do not block forever if sends are pending. Even though we
# have a wake-up mechanism if we've already started the blocking
# read, the possibility of context switching in send means that
# more writes can happen while we have no wake up context, so
# we need self._send_pending to avoid (effectively) a "lost wakeup"
# race.
interval = 0.0
with trio.CancelScope(
deadline=trio.current_time() + interval
) as self._worker_scope:
datagram = await self._socket.recv(QUIC_MAX_DATAGRAM)
self._connection.receive_datagram(
datagram, self._peer[0], time.time()
)
self._connection.receive_datagram(datagram, self._peer, time.time())
self._worker_scope = None
self._handle_timer(expiration)
await self._handle_events()
# We clear this now, before sending anything, as sending can cause
# context switches that do more sends. We want to know if that
# happens so we don't block a long time on the recv() above.
self._send_pending = False
datagrams = self._connection.datagrams_to_send(time.time())
for datagram, _ in datagrams:
await self._socket.send(datagram)
await self._handle_events()
finally:
self._done = True
self._handshake_complete.set()
@ -116,11 +129,13 @@ class TrioQuicConnection(AsyncQuicConnection):
await stream._add_input(event.data, event.end_stream)
elif isinstance(event, aioquic.quic.events.HandshakeCompleted):
self._handshake_complete.set()
elif isinstance(
event, aioquic.quic.events.ConnectionTerminated
) or isinstance(event, aioquic.quic.events.StreamReset):
elif isinstance(event, aioquic.quic.events.ConnectionTerminated):
self._done = True
self._socket.close()
elif isinstance(event, aioquic.quic.events.StreamReset):
stream = self._streams.get(event.stream_id)
if stream:
await stream._add_input(b"", True)
count += 1
if count > 10:
# yield
@ -129,6 +144,7 @@ class TrioQuicConnection(AsyncQuicConnection):
async def write(self, stream, data, is_end=False):
self._connection.send_stream_data(stream, data, is_end)
self._send_pending = True
if self._worker_scope is not None:
self._worker_scope.cancel()
@ -159,6 +175,7 @@ class TrioQuicConnection(AsyncQuicConnection):
self._manager.closed(self._peer[0], self._peer[1])
self._closed = True
self._connection.close()
self._send_pending = True
if self._worker_scope is not None:
self._worker_scope.cancel()
await self._run_done.wait()
@ -171,8 +188,12 @@ class TrioQuicManager(AsyncQuicManager):
super().__init__(conf, verify_mode, TrioQuicConnection, server_name)
self._nursery = nursery
def connect(self, address, port=853, source=None, source_port=0):
(connection, start) = self._connect(address, port, source, source_port)
def connect(
self, address, port=853, source=None, source_port=0, want_session_ticket=True
):
(connection, start) = self._connect(
address, port, source, source_port, want_session_ticket
)
if start:
self._nursery.start_soon(connection.run)
return connection
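The _send_pending bookkeeping above guards against a lost wakeup: a write() can land while the worker is outside its cancel scope, so cancelling alone would not be noticed. A much-simplified sketch of the same pattern, for illustration only (the Worker class and its names are not dnspython API):

import trio

class Worker:
    def __init__(self):
        self._send_pending = False
        self._scope = None

    async def run(self):
        while True:
            # If a send is already pending, do not block at all.
            interval = 0.0 if self._send_pending else 5.0
            with trio.CancelScope(deadline=trio.current_time() + interval) as scope:
                self._scope = scope
                await trio.sleep_forever()  # stands in for the blocking recv()
            self._scope = None
            # Clear the flag before flushing so writes that race with the
            # flush are still seen on the next iteration.
            self._send_pending = False
            # ... flush queued datagrams here ...

    def write(self):
        self._send_pending = True
        if self._scope is not None:
            self._scope.cancel()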

View file

@ -199,7 +199,7 @@ class Rdata:
self,
origin: Optional[dns.name.Name] = None,
relativize: bool = True,
**kw: Dict[str, Any]
**kw: Dict[str, Any],
) -> str:
"""Convert an rdata to text format.
@ -547,9 +547,7 @@ class Rdata:
@classmethod
def _as_ipv4_address(cls, value):
if isinstance(value, str):
# call to check validity
dns.ipv4.inet_aton(value)
return value
return dns.ipv4.canonicalize(value)
elif isinstance(value, bytes):
return dns.ipv4.inet_ntoa(value)
else:
@ -558,9 +556,7 @@ class Rdata:
@classmethod
def _as_ipv6_address(cls, value):
if isinstance(value, str):
# call to check validity
dns.ipv6.inet_aton(value)
return value
return dns.ipv6.canonicalize(value)
elif isinstance(value, bytes):
return dns.ipv6.inet_ntoa(value)
else:
@ -604,7 +600,6 @@ class Rdata:
@dns.immutable.immutable
class GenericRdata(Rdata):
"""Generic Rdata Class
This class is used for rdata types for which we have no better
@ -621,7 +616,7 @@ class GenericRdata(Rdata):
self,
origin: Optional[dns.name.Name] = None,
relativize: bool = True,
**kw: Dict[str, Any]
**kw: Dict[str, Any],
) -> str:
return r"\# %d " % len(self.data) + _hexify(self.data, **kw)
@ -647,9 +642,9 @@ class GenericRdata(Rdata):
return cls(rdclass, rdtype, parser.get_remaining())
_rdata_classes: Dict[
Tuple[dns.rdataclass.RdataClass, dns.rdatatype.RdataType], Any
] = {}
_rdata_classes: Dict[Tuple[dns.rdataclass.RdataClass, dns.rdatatype.RdataType], Any] = (
{}
)
_module_prefix = "dns.rdtypes"
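The _as_ipv4_address/_as_ipv6_address changes above mean that string addresses passed to rdata constructors are now canonicalized rather than returned verbatim after a validity check. A minimal sketch of the effect, assuming dnspython 2.6.1 (the sample addresses are illustrative):

import dns.ipv4
import dns.ipv6

# Canonicalization round-trips the text through binary form, so equivalent
# spellings collapse to one representation.
print(dns.ipv4.canonicalize("1.2.3.4"))               # 1.2.3.4
print(dns.ipv6.canonicalize("2001:db8:0:0:0:0:0:1"))  # 2001:db8::1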

View file

@ -28,6 +28,7 @@ import dns.name
import dns.rdata
import dns.rdataclass
import dns.rdatatype
import dns.renderer
import dns.set
import dns.ttl
@ -45,7 +46,6 @@ class IncompatibleTypes(dns.exception.DNSException):
class Rdataset(dns.set.Set):
"""A DNS rdataset."""
__slots__ = ["rdclass", "rdtype", "covers", "ttl"]
@ -316,11 +316,9 @@ class Rdataset(dns.set.Set):
want_shuffle = False
else:
rdclass = self.rdclass
file.seek(0, io.SEEK_END)
if len(self) == 0:
name.to_wire(file, compress, origin)
stuff = struct.pack("!HHIH", self.rdtype, rdclass, 0, 0)
file.write(stuff)
file.write(struct.pack("!HHIH", self.rdtype, rdclass, 0, 0))
return 1
else:
l: Union[Rdataset, List[dns.rdata.Rdata]]
@ -331,16 +329,9 @@ class Rdataset(dns.set.Set):
l = self
for rd in l:
name.to_wire(file, compress, origin)
stuff = struct.pack("!HHIH", self.rdtype, rdclass, self.ttl, 0)
file.write(stuff)
start = file.tell()
rd.to_wire(file, compress, origin)
end = file.tell()
assert end - start < 65536
file.seek(start - 2)
stuff = struct.pack("!H", end - start)
file.write(stuff)
file.seek(0, io.SEEK_END)
file.write(struct.pack("!HHI", self.rdtype, rdclass, self.ttl))
with dns.renderer.prefixed_length(file, 2):
rd.to_wire(file, compress, origin)
return len(self)
def match(
@ -373,7 +364,6 @@ class Rdataset(dns.set.Set):
@dns.immutable.immutable
class ImmutableRdataset(Rdataset): # lgtm[py/missing-equals]
"""An immutable DNS rdataset."""
_clone_class = Rdataset

View file

@ -21,7 +21,6 @@ import dns.rdtypes.mxbase
@dns.immutable.immutable
class AFSDB(dns.rdtypes.mxbase.UncompressedDowncasingMX):
"""AFSDB record"""
# Use the property mechanism to make "subtype" an alias for the

View file

@ -32,7 +32,6 @@ class Relay(dns.rdtypes.util.Gateway):
@dns.immutable.immutable
class AMTRELAY(dns.rdata.Rdata):
"""AMTRELAY record"""
# see: RFC 8777

View file

@ -21,7 +21,6 @@ import dns.rdtypes.txtbase
@dns.immutable.immutable
class AVC(dns.rdtypes.txtbase.TXTBase):
"""AVC record"""
# See: IANA dns parameters for AVC

View file

@ -25,7 +25,6 @@ import dns.tokenizer
@dns.immutable.immutable
class CAA(dns.rdata.Rdata):
"""CAA (Certification Authority Authorization) record"""
# see: RFC 6844

View file

@ -30,5 +30,4 @@ from dns.rdtypes.dnskeybase import ( # noqa: F401 lgtm[py/unused-import]
@dns.immutable.immutable
class CDNSKEY(dns.rdtypes.dnskeybase.DNSKEYBase):
"""CDNSKEY record"""

View file

@ -21,7 +21,6 @@ import dns.rdtypes.dsbase
@dns.immutable.immutable
class CDS(dns.rdtypes.dsbase.DSBase):
"""CDS record"""
_digest_length_by_type = {

View file

@ -67,7 +67,6 @@ def _ctype_to_text(what):
@dns.immutable.immutable
class CERT(dns.rdata.Rdata):
"""CERT record"""
# see RFC 4398

View file

@ -21,7 +21,6 @@ import dns.rdtypes.nsbase
@dns.immutable.immutable
class CNAME(dns.rdtypes.nsbase.NSBase):
"""CNAME record
Note: although CNAME is officially a singleton type, dnspython allows

View file

@ -32,7 +32,6 @@ class Bitmap(dns.rdtypes.util.Bitmap):
@dns.immutable.immutable
class CSYNC(dns.rdata.Rdata):
"""CSYNC record"""
__slots__ = ["serial", "flags", "windows"]

View file

@ -21,5 +21,4 @@ import dns.rdtypes.dsbase
@dns.immutable.immutable
class DLV(dns.rdtypes.dsbase.DSBase):
"""DLV record"""

View file

@ -21,7 +21,6 @@ import dns.rdtypes.nsbase
@dns.immutable.immutable
class DNAME(dns.rdtypes.nsbase.UncompressedNS):
"""DNAME record"""
def _to_wire(self, file, compress=None, origin=None, canonicalize=False):

View file

@ -30,5 +30,4 @@ from dns.rdtypes.dnskeybase import ( # noqa: F401 lgtm[py/unused-import]
@dns.immutable.immutable
class DNSKEY(dns.rdtypes.dnskeybase.DNSKEYBase):
"""DNSKEY record"""

View file

@ -21,5 +21,4 @@ import dns.rdtypes.dsbase
@dns.immutable.immutable
class DS(dns.rdtypes.dsbase.DSBase):
"""DS record"""

View file

@ -22,7 +22,6 @@ import dns.rdtypes.euibase
@dns.immutable.immutable
class EUI48(dns.rdtypes.euibase.EUIBase):
"""EUI48 record"""
# see: rfc7043.txt

View file

@ -22,7 +22,6 @@ import dns.rdtypes.euibase
@dns.immutable.immutable
class EUI64(dns.rdtypes.euibase.EUIBase):
"""EUI64 record"""
# see: rfc7043.txt

View file

@ -44,7 +44,6 @@ def _validate_float_string(what):
@dns.immutable.immutable
class GPOS(dns.rdata.Rdata):
"""GPOS record"""
# see: RFC 1712

View file

@ -25,7 +25,6 @@ import dns.tokenizer
@dns.immutable.immutable
class HINFO(dns.rdata.Rdata):
"""HINFO record"""
# see: RFC 1035

View file

@ -27,7 +27,6 @@ import dns.rdatatype
@dns.immutable.immutable
class HIP(dns.rdata.Rdata):
"""HIP record"""
# see: RFC 5205

View file

@ -25,7 +25,6 @@ import dns.tokenizer
@dns.immutable.immutable
class ISDN(dns.rdata.Rdata):
"""ISDN record"""
# see: RFC 1183

View file

@ -8,7 +8,6 @@ import dns.rdata
@dns.immutable.immutable
class L32(dns.rdata.Rdata):
"""L32 record"""
# see: rfc6742.txt

View file

@ -8,7 +8,6 @@ import dns.rdtypes.util
@dns.immutable.immutable
class L64(dns.rdata.Rdata):
"""L64 record"""
# see: rfc6742.txt

View file

@ -105,7 +105,6 @@ def _check_coordinate_list(value, low, high):
@dns.immutable.immutable
class LOC(dns.rdata.Rdata):
"""LOC record"""
# see: RFC 1876

View file

@ -8,7 +8,6 @@ import dns.rdata
@dns.immutable.immutable
class LP(dns.rdata.Rdata):
"""LP record"""
# see: rfc6742.txt

View file

@ -21,5 +21,4 @@ import dns.rdtypes.mxbase
@dns.immutable.immutable
class MX(dns.rdtypes.mxbase.MXBase):
"""MX record"""

View file

@ -8,7 +8,6 @@ import dns.rdtypes.util
@dns.immutable.immutable
class NID(dns.rdata.Rdata):
"""NID record"""
# see: rfc6742.txt

View file

@ -21,7 +21,6 @@ import dns.rdtypes.txtbase
@dns.immutable.immutable
class NINFO(dns.rdtypes.txtbase.TXTBase):
"""NINFO record"""
# see: draft-reid-dnsext-zs-01

View file

@ -21,5 +21,4 @@ import dns.rdtypes.nsbase
@dns.immutable.immutable
class NS(dns.rdtypes.nsbase.NSBase):
"""NS record"""

View file

@ -30,7 +30,6 @@ class Bitmap(dns.rdtypes.util.Bitmap):
@dns.immutable.immutable
class NSEC(dns.rdata.Rdata):
"""NSEC record"""
__slots__ = ["next", "windows"]

View file

@ -46,7 +46,6 @@ class Bitmap(dns.rdtypes.util.Bitmap):
@dns.immutable.immutable
class NSEC3(dns.rdata.Rdata):
"""NSEC3 record"""
__slots__ = ["algorithm", "flags", "iterations", "salt", "next", "windows"]
@ -64,9 +63,13 @@ class NSEC3(dns.rdata.Rdata):
windows = Bitmap(windows)
self.windows = tuple(windows.windows)
def to_text(self, origin=None, relativize=True, **kw):
def _next_text(self):
next = base64.b32encode(self.next).translate(b32_normal_to_hex).lower().decode()
next = next.rstrip("=")
return next
def to_text(self, origin=None, relativize=True, **kw):
next = self._next_text()
if self.salt == b"":
salt = "-"
else:
@ -118,3 +121,6 @@ class NSEC3(dns.rdata.Rdata):
next = parser.get_counted_bytes()
bitmap = Bitmap.from_wire_parser(parser)
return cls(rdclass, rdtype, algorithm, flags, iterations, salt, next, bitmap)
def next_name(self, origin=None):
return dns.name.from_text(self._next_text(), origin)
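The new _next_text() helper and next_name() method expose the NSEC3 "next hashed owner name" field as a DNS name instead of raw bytes. A hedged usage sketch (the record text and zone name are made up for illustration):

import dns.name
import dns.rdata

# Illustrative NSEC3 rdata: algorithm 1 (SHA-1), flags 0, 10 iterations,
# a hex salt, the base32hex-encoded next hashed owner name, and a bitmap.
rd = dns.rdata.from_text(
    "IN", "NSEC3",
    "1 0 10 f00dcafe 0p9mhaveqvm6t7vbl5lop2u3t2rp3tom A RRSIG",
)
origin = dns.name.from_text("example.com.")
# next_name() re-encodes the next field as a label under the given origin.
print(rd.next_name(origin))  # 0p9mhaveqvm6t7vbl5lop2u3t2rp3tom.example.com.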

View file

@ -25,7 +25,6 @@ import dns.rdata
@dns.immutable.immutable
class NSEC3PARAM(dns.rdata.Rdata):
"""NSEC3PARAM record"""
__slots__ = ["algorithm", "flags", "iterations", "salt"]

View file

@ -25,7 +25,6 @@ import dns.tokenizer
@dns.immutable.immutable
class OPENPGPKEY(dns.rdata.Rdata):
"""OPENPGPKEY record"""
# see: RFC 7929

View file

@ -28,7 +28,6 @@ import dns.rdata
@dns.immutable.immutable
class OPT(dns.rdata.Rdata):
"""OPT record"""
__slots__ = ["options"]

View file

@ -21,5 +21,4 @@ import dns.rdtypes.nsbase
@dns.immutable.immutable
class PTR(dns.rdtypes.nsbase.NSBase):
"""PTR record"""

View file

@ -23,7 +23,6 @@ import dns.rdata
@dns.immutable.immutable
class RP(dns.rdata.Rdata):
"""RP record"""
# see: RFC 1183

View file

@ -28,7 +28,6 @@ import dns.rdatatype
class BadSigTime(dns.exception.DNSException):
"""Time in DNS SIG or RRSIG resource record cannot be parsed."""
@ -52,7 +51,6 @@ def posixtime_to_sigtime(what):
@dns.immutable.immutable
class RRSIG(dns.rdata.Rdata):
"""RRSIG record"""
__slots__ = [

View file

@ -21,5 +21,4 @@ import dns.rdtypes.mxbase
@dns.immutable.immutable
class RT(dns.rdtypes.mxbase.UncompressedDowncasingMX):
"""RT record"""

View file

@ -25,7 +25,6 @@ import dns.rdata
@dns.immutable.immutable
class SOA(dns.rdata.Rdata):
"""SOA record"""
# see: RFC 1035

View file

@ -21,7 +21,6 @@ import dns.rdtypes.txtbase
@dns.immutable.immutable
class SPF(dns.rdtypes.txtbase.TXTBase):
"""SPF record"""
# see: RFC 4408

View file

@ -25,7 +25,6 @@ import dns.rdatatype
@dns.immutable.immutable
class SSHFP(dns.rdata.Rdata):
"""SSHFP record"""
# See RFC 4255

View file

@ -25,7 +25,6 @@ import dns.rdata
@dns.immutable.immutable
class TKEY(dns.rdata.Rdata):
"""TKEY Record"""
__slots__ = [

View file

@ -6,5 +6,4 @@ import dns.rdtypes.tlsabase
@dns.immutable.immutable
class TLSA(dns.rdtypes.tlsabase.TLSABase):
"""TLSA record"""

View file

@ -26,7 +26,6 @@ import dns.rdata
@dns.immutable.immutable
class TSIG(dns.rdata.Rdata):
"""TSIG record"""
__slots__ = [

View file

@ -21,5 +21,4 @@ import dns.rdtypes.txtbase
@dns.immutable.immutable
class TXT(dns.rdtypes.txtbase.TXTBase):
"""TXT record"""

View file

@ -27,7 +27,6 @@ import dns.rdtypes.util
@dns.immutable.immutable
class URI(dns.rdata.Rdata):
"""URI record"""
# see RFC 7553

View file

@ -25,7 +25,6 @@ import dns.tokenizer
@dns.immutable.immutable
class X25(dns.rdata.Rdata):
"""X25 record"""
# see RFC 1183

View file

@ -11,7 +11,6 @@ import dns.zonetypes
@dns.immutable.immutable
class ZONEMD(dns.rdata.Rdata):
"""ZONEMD record"""
# See RFC 8976

View file

@ -23,7 +23,6 @@ import dns.rdtypes.mxbase
@dns.immutable.immutable
class A(dns.rdata.Rdata):
"""A record for Chaosnet"""
# domain: the domain of the address

View file

@ -24,7 +24,6 @@ import dns.tokenizer
@dns.immutable.immutable
class A(dns.rdata.Rdata):
"""A record."""
__slots__ = ["address"]

View file

@ -24,7 +24,6 @@ import dns.tokenizer
@dns.immutable.immutable
class AAAA(dns.rdata.Rdata):
"""AAAA record."""
__slots__ = ["address"]

View file

@ -29,7 +29,6 @@ import dns.tokenizer
@dns.immutable.immutable
class APLItem:
"""An APL list item."""
__slots__ = ["family", "negation", "address", "prefix"]
@ -80,7 +79,6 @@ class APLItem:
@dns.immutable.immutable
class APL(dns.rdata.Rdata):
"""APL record."""
# see: RFC 3123

View file

@ -24,7 +24,6 @@ import dns.rdata
@dns.immutable.immutable
class DHCID(dns.rdata.Rdata):
"""DHCID record"""
# see: RFC 4701

View file

@ -29,7 +29,6 @@ class Gateway(dns.rdtypes.util.Gateway):
@dns.immutable.immutable
class IPSECKEY(dns.rdata.Rdata):
"""IPSECKEY record"""
# see: RFC 4025

View file

@ -21,5 +21,4 @@ import dns.rdtypes.mxbase
@dns.immutable.immutable
class KX(dns.rdtypes.mxbase.UncompressedDowncasingMX):
"""KX record"""

View file

@ -33,7 +33,6 @@ def _write_string(file, s):
@dns.immutable.immutable
class NAPTR(dns.rdata.Rdata):
"""NAPTR record"""
# see: RFC 3403

View file

@ -25,7 +25,6 @@ import dns.tokenizer
@dns.immutable.immutable
class NSAP(dns.rdata.Rdata):
"""NSAP record."""
# see: RFC 1706

View file

@ -21,5 +21,4 @@ import dns.rdtypes.nsbase
@dns.immutable.immutable
class NSAP_PTR(dns.rdtypes.nsbase.UncompressedNS):
"""NSAP-PTR record"""

View file

@ -26,7 +26,6 @@ import dns.rdtypes.util
@dns.immutable.immutable
class PX(dns.rdata.Rdata):
"""PX record."""
# see: RFC 2163

View file

@ -26,7 +26,6 @@ import dns.rdtypes.util
@dns.immutable.immutable
class SRV(dns.rdata.Rdata):
"""SRV record"""
# see: RFC 2782

View file

@ -33,7 +33,6 @@ except OSError:
@dns.immutable.immutable
class WKS(dns.rdata.Rdata):
"""WKS record"""
# see: RFC 1035

View file

@ -36,7 +36,6 @@ class Flag(enum.IntFlag):
@dns.immutable.immutable
class DNSKEYBase(dns.rdata.Rdata):
"""Base class for rdata that is like a DNSKEY record"""
__slots__ = ["flags", "protocol", "algorithm", "key"]

View file

@ -26,7 +26,6 @@ import dns.rdatatype
@dns.immutable.immutable
class DSBase(dns.rdata.Rdata):
"""Base class for rdata that is like a DS record"""
__slots__ = ["key_tag", "algorithm", "digest_type", "digest"]

View file

@ -22,7 +22,6 @@ import dns.rdata
@dns.immutable.immutable
class EUIBase(dns.rdata.Rdata):
"""EUIxx record"""
# see: rfc7043.txt

View file

@ -28,7 +28,6 @@ import dns.rdtypes.util
@dns.immutable.immutable
class MXBase(dns.rdata.Rdata):
"""Base class for rdata that is like an MX record."""
__slots__ = ["preference", "exchange"]
@ -71,7 +70,6 @@ class MXBase(dns.rdata.Rdata):
@dns.immutable.immutable
class UncompressedMX(MXBase):
"""Base class for rdata that is like an MX record, but whose name
is not compressed when converted to DNS wire format, and whose
digestable form is not downcased."""
@ -82,7 +80,6 @@ class UncompressedMX(MXBase):
@dns.immutable.immutable
class UncompressedDowncasingMX(MXBase):
"""Base class for rdata that is like an MX record, but whose name
is not compressed when converted to DNS wire format."""

View file

@ -25,7 +25,6 @@ import dns.rdata
@dns.immutable.immutable
class NSBase(dns.rdata.Rdata):
"""Base class for rdata that is like an NS record."""
__slots__ = ["target"]
@ -56,7 +55,6 @@ class NSBase(dns.rdata.Rdata):
@dns.immutable.immutable
class UncompressedNS(NSBase):
"""Base class for rdata that is like an NS record, but whose name
is not compressed when converted to DNS wire format, and whose
digestable form is not downcased."""

View file

@ -2,7 +2,6 @@
import base64
import enum
import io
import struct
import dns.enum
@ -13,6 +12,7 @@ import dns.ipv6
import dns.name
import dns.rdata
import dns.rdtypes.util
import dns.renderer
import dns.tokenizer
import dns.wire
@ -427,7 +427,6 @@ def _validate_and_define(params, key, value):
@dns.immutable.immutable
class SVCBBase(dns.rdata.Rdata):
"""Base class for SVCB-like records"""
# see: draft-ietf-dnsop-svcb-https-11
@ -521,19 +520,10 @@ class SVCBBase(dns.rdata.Rdata):
for key in sorted(self.params):
file.write(struct.pack("!H", key))
value = self.params[key]
# placeholder for length (or actual length of empty values)
file.write(struct.pack("!H", 0))
if value is None:
continue
else:
start = file.tell()
value.to_wire(file, origin)
end = file.tell()
assert end - start < 65536
file.seek(start - 2)
stuff = struct.pack("!H", end - start)
file.write(stuff)
file.seek(0, io.SEEK_END)
with dns.renderer.prefixed_length(file, 2):
# Note that we're still writing a length of zero if the value is None
if value is not None:
value.to_wire(file, origin)
@classmethod
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):

View file

@ -25,7 +25,6 @@ import dns.rdatatype
@dns.immutable.immutable
class TLSABase(dns.rdata.Rdata):
"""Base class for TLSA and SMIMEA records"""
# see: RFC 6698

View file

@ -17,18 +17,17 @@
"""TXT-like base class."""
import struct
from typing import Any, Dict, Iterable, Optional, Tuple, Union
import dns.exception
import dns.immutable
import dns.rdata
import dns.renderer
import dns.tokenizer
@dns.immutable.immutable
class TXTBase(dns.rdata.Rdata):
"""Base class for rdata that is like a TXT record (see RFC 1035)."""
__slots__ = ["strings"]
@ -56,7 +55,7 @@ class TXTBase(dns.rdata.Rdata):
self,
origin: Optional[dns.name.Name] = None,
relativize: bool = True,
**kw: Dict[str, Any]
**kw: Dict[str, Any],
) -> str:
txt = ""
prefix = ""
@ -93,10 +92,8 @@ class TXTBase(dns.rdata.Rdata):
def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
for s in self.strings:
l = len(s)
assert l < 256
file.write(struct.pack("!B", l))
file.write(s)
with dns.renderer.prefixed_length(file, 1):
file.write(s)
@classmethod
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):

View file

@ -32,6 +32,24 @@ AUTHORITY = 2
ADDITIONAL = 3
@contextlib.contextmanager
def prefixed_length(output, length_length):
output.write(b"\00" * length_length)
start = output.tell()
yield
end = output.tell()
length = end - start
if length > 0:
try:
output.seek(start - length_length)
try:
output.write(length.to_bytes(length_length, "big"))
except OverflowError:
raise dns.exception.FormError
finally:
output.seek(end)
class Renderer:
"""Helper class for building DNS wire-format messages.
@ -134,6 +152,15 @@ class Renderer:
self._rollback(start)
raise dns.exception.TooBig
@contextlib.contextmanager
def _temporarily_seek_to(self, where):
current = self.output.tell()
try:
self.output.seek(where)
yield
finally:
self.output.seek(current)
def add_question(self, qname, rdtype, rdclass=dns.rdataclass.IN):
"""Add a question to the message."""
@ -269,18 +296,14 @@ class Renderer:
with self._track_size():
keyname.to_wire(self.output, compress, self.origin)
self.output.write(
struct.pack("!HHIH", dns.rdatatype.TSIG, dns.rdataclass.ANY, 0, 0)
struct.pack("!HHI", dns.rdatatype.TSIG, dns.rdataclass.ANY, 0)
)
rdata_start = self.output.tell()
tsig.to_wire(self.output)
with prefixed_length(self.output, 2):
tsig.to_wire(self.output)
after = self.output.tell()
self.output.seek(rdata_start - 2)
self.output.write(struct.pack("!H", after - rdata_start))
self.counts[ADDITIONAL] += 1
self.output.seek(10)
self.output.write(struct.pack("!H", self.counts[ADDITIONAL]))
self.output.seek(0, io.SEEK_END)
with self._temporarily_seek_to(10):
self.output.write(struct.pack("!H", self.counts[ADDITIONAL]))
def write_header(self):
"""Write the DNS message header.
@ -290,19 +313,18 @@ class Renderer:
is added.
"""
self.output.seek(0)
self.output.write(
struct.pack(
"!HHHHHH",
self.id,
self.flags,
self.counts[0],
self.counts[1],
self.counts[2],
self.counts[3],
with self._temporarily_seek_to(0):
self.output.write(
struct.pack(
"!HHHHHH",
self.id,
self.flags,
self.counts[0],
self.counts[1],
self.counts[2],
self.counts[3],
)
)
)
self.output.seek(0, io.SEEK_END)
def get_wire(self):
"""Return the wire format message."""

View file

@ -26,7 +26,6 @@ import dns.renderer
class RRset(dns.rdataset.Rdataset):
"""A DNS RRset (named rdataset).
RRset inherits from Rdataset, and RRsets can be treated as
@ -132,7 +131,7 @@ class RRset(dns.rdataset.Rdataset):
self,
origin: Optional[dns.name.Name] = None,
relativize: bool = True,
**kw: Dict[str, Any]
**kw: Dict[str, Any],
) -> str:
"""Convert the RRset into DNS zone file format.
@ -159,7 +158,7 @@ class RRset(dns.rdataset.Rdataset):
file: Any,
compress: Optional[dns.name.CompressType] = None, # type: ignore
origin: Optional[dns.name.Name] = None,
**kw: Dict[str, Any]
**kw: Dict[str, Any],
) -> int:
"""Convert the RRset to wire format.
@ -231,7 +230,7 @@ def from_text(
ttl: int,
rdclass: Union[dns.rdataclass.RdataClass, str],
rdtype: Union[dns.rdatatype.RdataType, str],
*text_rdatas: Any
*text_rdatas: Any,
) -> RRset:
"""Create an RRset with the specified name, TTL, class, and type and with
the specified rdatas in text format.

View file

@ -19,7 +19,6 @@ import itertools
class Set:
"""A simple set class.
This class was originally used to deal with sets being missing in

View file

@ -203,7 +203,7 @@ class Transaction:
- name
- name, rdataclass, rdatatype, [covers]
- name, rdatatype, [covers]
- name, rdataset...
@ -222,7 +222,7 @@ class Transaction:
- name
- name, rdataclass, rdatatype, [covers]
- name, rdatatype, [covers]
- name, rdataset...

View file

@ -29,47 +29,38 @@ import dns.rdataclass
class BadTime(dns.exception.DNSException):
"""The current time is not within the TSIG's validity time."""
class BadSignature(dns.exception.DNSException):
"""The TSIG signature fails to verify."""
class BadKey(dns.exception.DNSException):
"""The TSIG record owner name does not match the key."""
class BadAlgorithm(dns.exception.DNSException):
"""The TSIG algorithm does not match the key."""
class PeerError(dns.exception.DNSException):
"""Base class for all TSIG errors generated by the remote peer"""
class PeerBadKey(PeerError):
"""The peer didn't know the key we used"""
class PeerBadSignature(PeerError):
"""The peer didn't like the signature we sent"""
class PeerBadTime(PeerError):
"""The peer didn't like the time we sent"""
class PeerBadTruncation(PeerError):
"""The peer didn't like amount of truncation in the TSIG we sent"""

View file

@ -20,9 +20,9 @@
#: MAJOR
MAJOR = 2
#: MINOR
MINOR = 4
MINOR = 6
#: MICRO
MICRO = 2
MICRO = 1
#: RELEASELEVEL
RELEASELEVEL = 0x0F
#: SERIAL

View file

@ -1,5 +1,7 @@
import sys
import dns._features
if sys.platform == "win32":
from typing import Any
@ -15,14 +17,14 @@ if sys.platform == "win32":
except KeyError:
WindowsError = Exception
try:
if dns._features.have("wmi"):
import threading
import pythoncom # pylint: disable=import-error
import wmi # pylint: disable=import-error
_have_wmi = True
except Exception:
else:
_have_wmi = False
def _config_domain(domain):
@ -51,9 +53,10 @@ if sys.platform == "win32":
try:
system = wmi.WMI()
for interface in system.Win32_NetworkAdapterConfiguration():
if interface.IPEnabled and interface.DNSDomain:
self.info.domain = _config_domain(interface.DNSDomain)
if interface.IPEnabled and interface.DNSServerSearchOrder:
self.info.nameservers = list(interface.DNSServerSearchOrder)
if interface.DNSDomain:
self.info.domain = _config_domain(interface.DNSDomain)
if interface.DNSDomainSuffixSearchOrder:
self.info.search = [
_config_domain(x)

View file

@ -21,7 +21,18 @@ import contextlib
import io
import os
import struct
from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Tuple, Union
from typing import (
Any,
Callable,
Iterable,
Iterator,
List,
MutableMapping,
Optional,
Set,
Tuple,
Union,
)
import dns.exception
import dns.grange
@ -43,47 +54,70 @@ from dns.zonetypes import DigestHashAlgorithm, DigestScheme, _digest_hashers
class BadZone(dns.exception.DNSException):
"""The DNS zone is malformed."""
class NoSOA(BadZone):
"""The DNS zone has no SOA RR at its origin."""
class NoNS(BadZone):
"""The DNS zone has no NS RRset at its origin."""
class UnknownOrigin(BadZone):
"""The DNS zone's origin is unknown."""
class UnsupportedDigestScheme(dns.exception.DNSException):
"""The zone digest's scheme is unsupported."""
class UnsupportedDigestHashAlgorithm(dns.exception.DNSException):
"""The zone digest's origin is unsupported."""
class NoDigest(dns.exception.DNSException):
"""The DNS zone has no ZONEMD RRset at its origin."""
class DigestVerificationFailure(dns.exception.DNSException):
"""The ZONEMD digest failed to verify."""
class Zone(dns.transaction.TransactionManager):
def _validate_name(
name: dns.name.Name,
origin: Optional[dns.name.Name],
relativize: bool,
) -> dns.name.Name:
# This name validation code is shared by Zone and Version
if origin is None:
# This should probably never happen as other code (e.g.
# _rr_line) will notice the lack of an origin before us, but
# we check just in case!
raise KeyError("no zone origin is defined")
if name.is_absolute():
if not name.is_subdomain(origin):
raise KeyError("name parameter must be a subdomain of the zone origin")
if relativize:
name = name.relativize(origin)
else:
# We have a relative name. Make sure that the derelativized name is
# not too long.
try:
abs_name = name.derelativize(origin)
except dns.name.NameTooLong:
# We map dns.name.NameTooLong to KeyError to be consistent with
# the other exceptions above.
raise KeyError("relative name too long for zone")
if not relativize:
# We have a relative name in a non-relative zone, so use the
# derelativized name.
name = abs_name
return name
class Zone(dns.transaction.TransactionManager):
"""A DNS zone.
A ``Zone`` is a mapping from names to nodes. The zone object may be
@ -94,7 +128,10 @@ class Zone(dns.transaction.TransactionManager):
the zone.
"""
node_factory = dns.node.Node
node_factory: Callable[[], dns.node.Node] = dns.node.Node
map_factory: Callable[[], MutableMapping[dns.name.Name, dns.node.Node]] = dict
writable_version_factory: Optional[Callable[[], "WritableVersion"]] = None
immutable_version_factory: Optional[Callable[[], "ImmutableVersion"]] = None
__slots__ = ["rdclass", "origin", "nodes", "relativize"]
@ -125,7 +162,7 @@ class Zone(dns.transaction.TransactionManager):
raise ValueError("origin parameter must be an absolute name")
self.origin = origin
self.rdclass = rdclass
self.nodes: Dict[dns.name.Name, dns.node.Node] = {}
self.nodes: MutableMapping[dns.name.Name, dns.node.Node] = self.map_factory()
self.relativize = relativize
def __eq__(self, other):
@ -154,26 +191,13 @@ class Zone(dns.transaction.TransactionManager):
return not self.__eq__(other)
def _validate_name(self, name: Union[dns.name.Name, str]) -> dns.name.Name:
# Note that any changes in this method should have corresponding changes
# made in the Version _validate_name() method.
if isinstance(name, str):
name = dns.name.from_text(name, None)
elif not isinstance(name, dns.name.Name):
raise KeyError("name parameter must be convertible to a DNS name")
if name.is_absolute():
if self.origin is None:
# This should probably never happen as other code (e.g.
# _rr_line) will notice the lack of an origin before us, but
# we check just in case!
raise KeyError("no zone origin is defined")
if not name.is_subdomain(self.origin):
raise KeyError("name parameter must be a subdomain of the zone origin")
if self.relativize:
name = name.relativize(self.origin)
elif not self.relativize:
# We have a relative name in a non-relative zone, so derelativize.
if self.origin is None:
raise KeyError("no zone origin is defined")
name = name.derelativize(self.origin)
return name
return _validate_name(name, self.origin, self.relativize)
def __getitem__(self, key):
key = self._validate_name(key)
@ -252,9 +276,6 @@ class Zone(dns.transaction.TransactionManager):
*create*, a ``bool``. If true, the node will be created if it does
not exist.
Raises ``KeyError`` if the name is not known and create was
not specified, or if the name was not a subdomain of the origin.
Returns a ``dns.node.Node`` or ``None``.
"""
@ -527,9 +548,6 @@ class Zone(dns.transaction.TransactionManager):
*create*, a ``bool``. If true, the node will be created if it does
not exist.
Raises ``KeyError`` if the name is not known and create was
not specified, or if the name was not a subdomain of the origin.
Returns a ``dns.rrset.RRset`` or ``None``.
"""
@ -952,7 +970,7 @@ class Version:
self,
zone: Zone,
id: int,
nodes: Optional[Dict[dns.name.Name, dns.node.Node]] = None,
nodes: Optional[MutableMapping[dns.name.Name, dns.node.Node]] = None,
origin: Optional[dns.name.Name] = None,
):
self.zone = zone
@ -960,26 +978,11 @@ class Version:
if nodes is not None:
self.nodes = nodes
else:
self.nodes = {}
self.nodes = zone.map_factory()
self.origin = origin
def _validate_name(self, name: dns.name.Name) -> dns.name.Name:
if name.is_absolute():
if self.origin is None:
# This should probably never happen as other code (e.g.
# _rr_line) will notice the lack of an origin before us, but
# we check just in case!
raise KeyError("no zone origin is defined")
if not name.is_subdomain(self.origin):
raise KeyError("name is not a subdomain of the zone origin")
if self.zone.relativize:
name = name.relativize(self.origin)
elif not self.zone.relativize:
# We have a relative name in a non-relative zone, so derelativize.
if self.origin is None:
raise KeyError("no zone origin is defined")
name = name.derelativize(self.origin)
return name
return _validate_name(name, self.origin, self.zone.relativize)
def get_node(self, name: dns.name.Name) -> Optional[dns.node.Node]:
name = self._validate_name(name)
@ -1085,7 +1088,9 @@ class ImmutableVersion(Version):
version.nodes[name] = ImmutableVersionedNode(node)
# We're changing the type of the nodes dictionary here on purpose, so
# we ignore the mypy error.
self.nodes = dns.immutable.Dict(version.nodes, True) # type: ignore
self.nodes = dns.immutable.Dict(
version.nodes, True, self.zone.map_factory
) # type: ignore
class Transaction(dns.transaction.Transaction):
@ -1101,7 +1106,10 @@ class Transaction(dns.transaction.Transaction):
def _setup_version(self):
assert self.version is None
self.version = WritableVersion(self.zone, self.replacement)
factory = self.manager.writable_version_factory
if factory is None:
factory = WritableVersion
self.version = factory(self.zone, self.replacement)
def _get_rdataset(self, name, rdtype, covers):
return self.version.get_rdataset(name, rdtype, covers)
@ -1132,7 +1140,10 @@ class Transaction(dns.transaction.Transaction):
self.zone._end_read(self)
elif commit and len(self.version.changed) > 0:
if self.make_immutable:
version = ImmutableVersion(self.version)
factory = self.manager.immutable_version_factory
if factory is None:
factory = ImmutableVersion
version = factory(self.version)
else:
version = self.version
self.zone._commit_version(self, version, self.version.origin)
@ -1168,6 +1179,48 @@ class Transaction(dns.transaction.Transaction):
return (absolute, relativize, effective)
def _from_text(
text: Any,
origin: Optional[Union[dns.name.Name, str]] = None,
rdclass: dns.rdataclass.RdataClass = dns.rdataclass.IN,
relativize: bool = True,
zone_factory: Any = Zone,
filename: Optional[str] = None,
allow_include: bool = False,
check_origin: bool = True,
idna_codec: Optional[dns.name.IDNACodec] = None,
allow_directives: Union[bool, Iterable[str]] = True,
) -> Zone:
# See the comments for the public APIs from_text() and from_file() for
# details.
# 'text' can also be a file, but we don't publish that fact
# since it's an implementation detail. The official file
# interface is from_file().
if filename is None:
filename = "<string>"
zone = zone_factory(origin, rdclass, relativize=relativize)
with zone.writer(True) as txn:
tok = dns.tokenizer.Tokenizer(text, filename, idna_codec=idna_codec)
reader = dns.zonefile.Reader(
tok,
rdclass,
txn,
allow_include=allow_include,
allow_directives=allow_directives,
)
try:
reader.read()
except dns.zonefile.UnknownOrigin:
# for backwards compatibility
raise dns.zone.UnknownOrigin
# Now that we're done reading, do some basic checking of the zone.
if check_origin:
zone.check_origin()
return zone
def from_text(
text: str,
origin: Optional[Union[dns.name.Name, str]] = None,
@ -1228,32 +1281,18 @@ def from_text(
Returns a subclass of ``dns.zone.Zone``.
"""
# 'text' can also be a file, but we don't publish that fact
# since it's an implementation detail. The official file
# interface is from_file().
if filename is None:
filename = "<string>"
zone = zone_factory(origin, rdclass, relativize=relativize)
with zone.writer(True) as txn:
tok = dns.tokenizer.Tokenizer(text, filename, idna_codec=idna_codec)
reader = dns.zonefile.Reader(
tok,
rdclass,
txn,
allow_include=allow_include,
allow_directives=allow_directives,
)
try:
reader.read()
except dns.zonefile.UnknownOrigin:
# for backwards compatibility
raise dns.zone.UnknownOrigin
# Now that we're done reading, do some basic checking of the zone.
if check_origin:
zone.check_origin()
return zone
return _from_text(
text,
origin,
rdclass,
relativize,
zone_factory,
filename,
allow_include,
check_origin,
idna_codec,
allow_directives,
)
def from_file(
@ -1324,7 +1363,7 @@ def from_file(
else:
cm = contextlib.nullcontext(f)
with cm as f:
return from_text(
return _from_text(
f,
origin,
rdclass,

View file

@ -86,7 +86,6 @@ def _upper_dollarize(s):
class Reader:
"""Read a DNS zone file into a transaction."""
def __init__(

Some files were not shown because too many files have changed in this diff.