mirror of
https://github.com/Tautulli/Tautulli.git
synced 2025-07-05 20:51:15 -07:00
Bump dnspython from 2.4.2 to 2.6.1 (#2264)
* Bump dnspython from 2.4.2 to 2.6.1 Bumps [dnspython](https://github.com/rthalley/dnspython) from 2.4.2 to 2.6.1. - [Release notes](https://github.com/rthalley/dnspython/releases) - [Changelog](https://github.com/rthalley/dnspython/blob/main/doc/whatsnew.rst) - [Commits](https://github.com/rthalley/dnspython/compare/v2.4.2...v2.6.1) --- updated-dependencies: - dependency-name: dnspython dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] <support@github.com> * Update dnspython==2.6.1 --------- Signed-off-by: dependabot[bot] <support@github.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: JonnyWong16 <9099342+JonnyWong16@users.noreply.github.com> [skip ci]
This commit is contained in:
parent
aca7e72715
commit
cfefa928be
101 changed files with 1052 additions and 459 deletions
|
@ -7,7 +7,9 @@ import socket
|
|||
import sys
|
||||
|
||||
import dns._asyncbackend
|
||||
import dns._features
|
||||
import dns.exception
|
||||
import dns.inet
|
||||
|
||||
_is_win32 = sys.platform == "win32"
|
||||
|
||||
|
@ -121,7 +123,7 @@ class StreamSocket(dns._asyncbackend.StreamSocket):
|
|||
return self.writer.get_extra_info("peercert")
|
||||
|
||||
|
||||
try:
|
||||
if dns._features.have("doh"):
|
||||
import anyio
|
||||
import httpcore
|
||||
import httpcore._backends.anyio
|
||||
|
@ -205,7 +207,7 @@ try:
|
|||
resolver, local_port, bootstrap_address, family
|
||||
)
|
||||
|
||||
except ImportError:
|
||||
else:
|
||||
_HTTPTransport = dns._asyncbackend.NullTransport # type: ignore
|
||||
|
||||
|
||||
|
@ -224,14 +226,12 @@ class Backend(dns._asyncbackend.Backend):
|
|||
ssl_context=None,
|
||||
server_hostname=None,
|
||||
):
|
||||
if destination is None and socktype == socket.SOCK_DGRAM and _is_win32:
|
||||
raise NotImplementedError(
|
||||
"destinationless datagram sockets "
|
||||
"are not supported by asyncio "
|
||||
"on Windows"
|
||||
)
|
||||
loop = _get_running_loop()
|
||||
if socktype == socket.SOCK_DGRAM:
|
||||
if _is_win32 and source is None:
|
||||
# Win32 wants explicit binding before recvfrom(). This is the
|
||||
# proper fix for [#637].
|
||||
source = (dns.inet.any_for_af(af), 0)
|
||||
transport, protocol = await loop.create_datagram_endpoint(
|
||||
_DatagramProtocol,
|
||||
source,
|
||||
|
@ -266,7 +266,7 @@ class Backend(dns._asyncbackend.Backend):
|
|||
await asyncio.sleep(interval)
|
||||
|
||||
def datagram_connection_required(self):
|
||||
return _is_win32
|
||||
return False
|
||||
|
||||
def get_transport_class(self):
|
||||
return _HTTPTransport
|
||||
|
|
92
lib/dns/_features.py
Normal file
92
lib/dns/_features.py
Normal file
|
@ -0,0 +1,92 @@
|
|||
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
|
||||
|
||||
import importlib.metadata
|
||||
import itertools
|
||||
import string
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
|
||||
def _tuple_from_text(version: str) -> Tuple:
|
||||
text_parts = version.split(".")
|
||||
int_parts = []
|
||||
for text_part in text_parts:
|
||||
digit_prefix = "".join(
|
||||
itertools.takewhile(lambda x: x in string.digits, text_part)
|
||||
)
|
||||
try:
|
||||
int_parts.append(int(digit_prefix))
|
||||
except Exception:
|
||||
break
|
||||
return tuple(int_parts)
|
||||
|
||||
|
||||
def _version_check(
|
||||
requirement: str,
|
||||
) -> bool:
|
||||
"""Is the requirement fulfilled?
|
||||
|
||||
The requirement must be of the form
|
||||
|
||||
package>=version
|
||||
"""
|
||||
package, minimum = requirement.split(">=")
|
||||
try:
|
||||
version = importlib.metadata.version(package)
|
||||
except Exception:
|
||||
return False
|
||||
t_version = _tuple_from_text(version)
|
||||
t_minimum = _tuple_from_text(minimum)
|
||||
if t_version < t_minimum:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
_cache: Dict[str, bool] = {}
|
||||
|
||||
|
||||
def have(feature: str) -> bool:
|
||||
"""Is *feature* available?
|
||||
|
||||
This tests if all optional packages needed for the
|
||||
feature are available and recent enough.
|
||||
|
||||
Returns ``True`` if the feature is available,
|
||||
and ``False`` if it is not or if metadata is
|
||||
missing.
|
||||
"""
|
||||
value = _cache.get(feature)
|
||||
if value is not None:
|
||||
return value
|
||||
requirements = _requirements.get(feature)
|
||||
if requirements is None:
|
||||
# we make a cache entry here for consistency not performance
|
||||
_cache[feature] = False
|
||||
return False
|
||||
ok = True
|
||||
for requirement in requirements:
|
||||
if not _version_check(requirement):
|
||||
ok = False
|
||||
break
|
||||
_cache[feature] = ok
|
||||
return ok
|
||||
|
||||
|
||||
def force(feature: str, enabled: bool) -> None:
|
||||
"""Force the status of *feature* to be *enabled*.
|
||||
|
||||
This method is provided as a workaround for any cases
|
||||
where importlib.metadata is ineffective, or for testing.
|
||||
"""
|
||||
_cache[feature] = enabled
|
||||
|
||||
|
||||
_requirements: Dict[str, List[str]] = {
|
||||
### BEGIN generated requirements
|
||||
"dnssec": ["cryptography>=41"],
|
||||
"doh": ["httpcore>=1.0.0", "httpx>=0.26.0", "h2>=4.1.0"],
|
||||
"doq": ["aioquic>=0.9.25"],
|
||||
"idna": ["idna>=3.6"],
|
||||
"trio": ["trio>=0.23"],
|
||||
"wmi": ["wmi>=1.5.1"],
|
||||
### END generated requirements
|
||||
}
|
|
@ -8,9 +8,13 @@ import trio
|
|||
import trio.socket # type: ignore
|
||||
|
||||
import dns._asyncbackend
|
||||
import dns._features
|
||||
import dns.exception
|
||||
import dns.inet
|
||||
|
||||
if not dns._features.have("trio"):
|
||||
raise ImportError("trio not found or too old")
|
||||
|
||||
|
||||
def _maybe_timeout(timeout):
|
||||
if timeout is not None:
|
||||
|
@ -95,7 +99,7 @@ class StreamSocket(dns._asyncbackend.StreamSocket):
|
|||
raise NotImplementedError
|
||||
|
||||
|
||||
try:
|
||||
if dns._features.have("doh"):
|
||||
import httpcore
|
||||
import httpcore._backends.trio
|
||||
import httpx
|
||||
|
@ -177,7 +181,7 @@ try:
|
|||
resolver, local_port, bootstrap_address, family
|
||||
)
|
||||
|
||||
except ImportError:
|
||||
else:
|
||||
_HTTPTransport = dns._asyncbackend.NullTransport # type: ignore
|
||||
|
||||
|
||||
|
|
|
@ -32,7 +32,7 @@ def get_backend(name: str) -> Backend:
|
|||
*name*, a ``str``, the name of the backend. Currently the "trio"
|
||||
and "asyncio" backends are available.
|
||||
|
||||
Raises NotImplementError if an unknown backend name is specified.
|
||||
Raises NotImplementedError if an unknown backend name is specified.
|
||||
"""
|
||||
# pylint: disable=import-outside-toplevel,redefined-outer-name
|
||||
backend = _backends.get(name)
|
||||
|
|
|
@ -41,7 +41,7 @@ from dns.query import (
|
|||
NoDOQ,
|
||||
UDPMode,
|
||||
_compute_times,
|
||||
_have_http2,
|
||||
_make_dot_ssl_context,
|
||||
_matches_destination,
|
||||
_remaining,
|
||||
have_doh,
|
||||
|
@ -120,6 +120,8 @@ async def receive_udp(
|
|||
request_mac: Optional[bytes] = b"",
|
||||
ignore_trailing: bool = False,
|
||||
raise_on_truncation: bool = False,
|
||||
ignore_errors: bool = False,
|
||||
query: Optional[dns.message.Message] = None,
|
||||
) -> Any:
|
||||
"""Read a DNS message from a UDP socket.
|
||||
|
||||
|
@ -133,22 +135,40 @@ async def receive_udp(
|
|||
"""
|
||||
|
||||
wire = b""
|
||||
while 1:
|
||||
while True:
|
||||
(wire, from_address) = await sock.recvfrom(65535, _timeout(expiration))
|
||||
if _matches_destination(
|
||||
if not _matches_destination(
|
||||
sock.family, from_address, destination, ignore_unexpected
|
||||
):
|
||||
break
|
||||
received_time = time.time()
|
||||
r = dns.message.from_wire(
|
||||
wire,
|
||||
keyring=keyring,
|
||||
request_mac=request_mac,
|
||||
one_rr_per_rrset=one_rr_per_rrset,
|
||||
ignore_trailing=ignore_trailing,
|
||||
raise_on_truncation=raise_on_truncation,
|
||||
)
|
||||
return (r, received_time, from_address)
|
||||
continue
|
||||
received_time = time.time()
|
||||
try:
|
||||
r = dns.message.from_wire(
|
||||
wire,
|
||||
keyring=keyring,
|
||||
request_mac=request_mac,
|
||||
one_rr_per_rrset=one_rr_per_rrset,
|
||||
ignore_trailing=ignore_trailing,
|
||||
raise_on_truncation=raise_on_truncation,
|
||||
)
|
||||
except dns.message.Truncated as e:
|
||||
# See the comment in query.py for details.
|
||||
if (
|
||||
ignore_errors
|
||||
and query is not None
|
||||
and not query.is_response(e.message())
|
||||
):
|
||||
continue
|
||||
else:
|
||||
raise
|
||||
except Exception:
|
||||
if ignore_errors:
|
||||
continue
|
||||
else:
|
||||
raise
|
||||
if ignore_errors and query is not None and not query.is_response(r):
|
||||
continue
|
||||
return (r, received_time, from_address)
|
||||
|
||||
|
||||
async def udp(
|
||||
|
@ -164,6 +184,7 @@ async def udp(
|
|||
raise_on_truncation: bool = False,
|
||||
sock: Optional[dns.asyncbackend.DatagramSocket] = None,
|
||||
backend: Optional[dns.asyncbackend.Backend] = None,
|
||||
ignore_errors: bool = False,
|
||||
) -> dns.message.Message:
|
||||
"""Return the response obtained after sending a query via UDP.
|
||||
|
||||
|
@ -205,9 +226,13 @@ async def udp(
|
|||
q.mac,
|
||||
ignore_trailing,
|
||||
raise_on_truncation,
|
||||
ignore_errors,
|
||||
q,
|
||||
)
|
||||
r.time = received_time - begin_time
|
||||
if not q.is_response(r):
|
||||
# We don't need to check q.is_response() if we are in ignore_errors mode
|
||||
# as receive_udp() will have checked it.
|
||||
if not (ignore_errors or q.is_response(r)):
|
||||
raise BadResponse
|
||||
return r
|
||||
|
||||
|
@ -225,6 +250,7 @@ async def udp_with_fallback(
|
|||
udp_sock: Optional[dns.asyncbackend.DatagramSocket] = None,
|
||||
tcp_sock: Optional[dns.asyncbackend.StreamSocket] = None,
|
||||
backend: Optional[dns.asyncbackend.Backend] = None,
|
||||
ignore_errors: bool = False,
|
||||
) -> Tuple[dns.message.Message, bool]:
|
||||
"""Return the response to the query, trying UDP first and falling back
|
||||
to TCP if UDP results in a truncated response.
|
||||
|
@ -260,6 +286,7 @@ async def udp_with_fallback(
|
|||
True,
|
||||
udp_sock,
|
||||
backend,
|
||||
ignore_errors,
|
||||
)
|
||||
return (response, False)
|
||||
except dns.message.Truncated:
|
||||
|
@ -292,14 +319,12 @@ async def send_tcp(
|
|||
"""
|
||||
|
||||
if isinstance(what, dns.message.Message):
|
||||
wire = what.to_wire()
|
||||
tcpmsg = what.to_wire(prepend_length=True)
|
||||
else:
|
||||
wire = what
|
||||
l = len(wire)
|
||||
# copying the wire into tcpmsg is inefficient, but lets us
|
||||
# avoid writev() or doing a short write that would get pushed
|
||||
# onto the net
|
||||
tcpmsg = struct.pack("!H", l) + wire
|
||||
# copying the wire into tcpmsg is inefficient, but lets us
|
||||
# avoid writev() or doing a short write that would get pushed
|
||||
# onto the net
|
||||
tcpmsg = len(what).to_bytes(2, "big") + what
|
||||
sent_time = time.time()
|
||||
await sock.sendall(tcpmsg, _timeout(expiration, sent_time))
|
||||
return (len(tcpmsg), sent_time)
|
||||
|
@ -418,6 +443,7 @@ async def tls(
|
|||
backend: Optional[dns.asyncbackend.Backend] = None,
|
||||
ssl_context: Optional[ssl.SSLContext] = None,
|
||||
server_hostname: Optional[str] = None,
|
||||
verify: Union[bool, str] = True,
|
||||
) -> dns.message.Message:
|
||||
"""Return the response obtained after sending a query via TLS.
|
||||
|
||||
|
@ -439,11 +465,7 @@ async def tls(
|
|||
cm: contextlib.AbstractAsyncContextManager = NullContext(sock)
|
||||
else:
|
||||
if ssl_context is None:
|
||||
# See the comment about ssl.create_default_context() in query.py
|
||||
ssl_context = ssl.create_default_context() # lgtm[py/insecure-protocol]
|
||||
ssl_context.minimum_version = ssl.TLSVersion.TLSv1_2
|
||||
if server_hostname is None:
|
||||
ssl_context.check_hostname = False
|
||||
ssl_context = _make_dot_ssl_context(server_hostname, verify)
|
||||
af = dns.inet.af_for_address(where)
|
||||
stuple = _source_tuple(af, source, source_port)
|
||||
dtuple = (where, port)
|
||||
|
@ -538,7 +560,7 @@ async def https(
|
|||
transport = backend.get_transport_class()(
|
||||
local_address=local_address,
|
||||
http1=True,
|
||||
http2=_have_http2,
|
||||
http2=True,
|
||||
verify=verify,
|
||||
local_port=local_port,
|
||||
bootstrap_address=bootstrap_address,
|
||||
|
@ -550,7 +572,7 @@ async def https(
|
|||
cm: contextlib.AbstractAsyncContextManager = NullContext(client)
|
||||
else:
|
||||
cm = httpx.AsyncClient(
|
||||
http1=True, http2=_have_http2, verify=verify, transport=transport
|
||||
http1=True, http2=True, verify=verify, transport=transport
|
||||
)
|
||||
|
||||
async with cm as the_client:
|
||||
|
|
|
@ -27,6 +27,7 @@ import time
|
|||
from datetime import datetime
|
||||
from typing import Callable, Dict, List, Optional, Set, Tuple, Union, cast
|
||||
|
||||
import dns._features
|
||||
import dns.exception
|
||||
import dns.name
|
||||
import dns.node
|
||||
|
@ -1169,7 +1170,7 @@ def _need_pyca(*args, **kwargs):
|
|||
) # pragma: no cover
|
||||
|
||||
|
||||
try:
|
||||
if dns._features.have("dnssec"):
|
||||
from cryptography.exceptions import InvalidSignature
|
||||
from cryptography.hazmat.primitives.asymmetric import dsa # pylint: disable=W0611
|
||||
from cryptography.hazmat.primitives.asymmetric import ec # pylint: disable=W0611
|
||||
|
@ -1184,20 +1185,20 @@ try:
|
|||
get_algorithm_cls_from_dnskey,
|
||||
)
|
||||
from dns.dnssecalgs.base import GenericPrivateKey, GenericPublicKey
|
||||
except ImportError: # pragma: no cover
|
||||
validate = _need_pyca
|
||||
validate_rrsig = _need_pyca
|
||||
sign = _need_pyca
|
||||
make_dnskey = _need_pyca
|
||||
make_cdnskey = _need_pyca
|
||||
_have_pyca = False
|
||||
else:
|
||||
|
||||
validate = _validate # type: ignore
|
||||
validate_rrsig = _validate_rrsig # type: ignore
|
||||
sign = _sign
|
||||
make_dnskey = _make_dnskey
|
||||
make_cdnskey = _make_cdnskey
|
||||
_have_pyca = True
|
||||
else: # pragma: no cover
|
||||
validate = _need_pyca
|
||||
validate_rrsig = _need_pyca
|
||||
sign = _need_pyca
|
||||
make_dnskey = _need_pyca
|
||||
make_cdnskey = _need_pyca
|
||||
_have_pyca = False
|
||||
|
||||
### BEGIN generated Algorithm constants
|
||||
|
||||
|
|
|
@ -1,9 +1,12 @@
|
|||
from typing import Dict, Optional, Tuple, Type, Union
|
||||
|
||||
import dns.name
|
||||
from dns.dnssecalgs.base import GenericPrivateKey
|
||||
from dns.dnssectypes import Algorithm
|
||||
from dns.exception import UnsupportedAlgorithm
|
||||
from dns.rdtypes.ANY.DNSKEY import DNSKEY
|
||||
|
||||
try:
|
||||
from dns.dnssecalgs.base import GenericPrivateKey
|
||||
if dns._features.have("dnssec"):
|
||||
from dns.dnssecalgs.dsa import PrivateDSA, PrivateDSANSEC3SHA1
|
||||
from dns.dnssecalgs.ecdsa import PrivateECDSAP256SHA256, PrivateECDSAP384SHA384
|
||||
from dns.dnssecalgs.eddsa import PrivateED448, PrivateED25519
|
||||
|
@ -16,13 +19,9 @@ try:
|
|||
)
|
||||
|
||||
_have_cryptography = True
|
||||
except ImportError:
|
||||
else:
|
||||
_have_cryptography = False
|
||||
|
||||
from dns.dnssectypes import Algorithm
|
||||
from dns.exception import UnsupportedAlgorithm
|
||||
from dns.rdtypes.ANY.DNSKEY import DNSKEY
|
||||
|
||||
AlgorithmPrefix = Optional[Union[bytes, dns.name.Name]]
|
||||
|
||||
algorithms: Dict[Tuple[Algorithm, AlgorithmPrefix], Type[GenericPrivateKey]] = {}
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
|
||||
"""EDNS Options"""
|
||||
|
||||
import binascii
|
||||
import math
|
||||
import socket
|
||||
import struct
|
||||
|
@ -58,7 +59,6 @@ class OptionType(dns.enum.IntEnum):
|
|||
|
||||
|
||||
class Option:
|
||||
|
||||
"""Base class for all EDNS option types."""
|
||||
|
||||
def __init__(self, otype: Union[OptionType, str]):
|
||||
|
@ -76,6 +76,9 @@ class Option:
|
|||
"""
|
||||
raise NotImplementedError # pragma: no cover
|
||||
|
||||
def to_text(self) -> str:
|
||||
raise NotImplementedError # pragma: no cover
|
||||
|
||||
@classmethod
|
||||
def from_wire_parser(cls, otype: OptionType, parser: "dns.wire.Parser") -> "Option":
|
||||
"""Build an EDNS option object from wire format.
|
||||
|
@ -141,7 +144,6 @@ class Option:
|
|||
|
||||
|
||||
class GenericOption(Option): # lgtm[py/missing-equals]
|
||||
|
||||
"""Generic Option Class
|
||||
|
||||
This class is used for EDNS option types for which we have no better
|
||||
|
@ -343,6 +345,8 @@ class EDECode(dns.enum.IntEnum):
|
|||
class EDEOption(Option): # lgtm[py/missing-equals]
|
||||
"""Extended DNS Error (EDE, RFC8914)"""
|
||||
|
||||
_preserve_case = {"DNSKEY", "DS", "DNSSEC", "RRSIGs", "NSEC", "NXDOMAIN"}
|
||||
|
||||
def __init__(self, code: Union[EDECode, str], text: Optional[str] = None):
|
||||
"""*code*, a ``dns.edns.EDECode`` or ``str``, the info code of the
|
||||
extended error.
|
||||
|
@ -360,6 +364,13 @@ class EDEOption(Option): # lgtm[py/missing-equals]
|
|||
|
||||
def to_text(self) -> str:
|
||||
output = f"EDE {self.code}"
|
||||
if self.code in EDECode:
|
||||
desc = EDECode.to_text(self.code)
|
||||
desc = " ".join(
|
||||
word if word in self._preserve_case else word.title()
|
||||
for word in desc.split("_")
|
||||
)
|
||||
output += f" ({desc})"
|
||||
if self.text is not None:
|
||||
output += f": {self.text}"
|
||||
return output
|
||||
|
@ -392,9 +403,37 @@ class EDEOption(Option): # lgtm[py/missing-equals]
|
|||
return cls(code, btext)
|
||||
|
||||
|
||||
class NSIDOption(Option):
|
||||
def __init__(self, nsid: bytes):
|
||||
super().__init__(OptionType.NSID)
|
||||
self.nsid = nsid
|
||||
|
||||
def to_wire(self, file: Any = None) -> Optional[bytes]:
|
||||
if file:
|
||||
file.write(self.nsid)
|
||||
return None
|
||||
else:
|
||||
return self.nsid
|
||||
|
||||
def to_text(self) -> str:
|
||||
if all(c >= 0x20 and c <= 0x7E for c in self.nsid):
|
||||
# All ASCII printable, so it's probably a string.
|
||||
value = self.nsid.decode()
|
||||
else:
|
||||
value = binascii.hexlify(self.nsid).decode()
|
||||
return f"NSID {value}"
|
||||
|
||||
@classmethod
|
||||
def from_wire_parser(
|
||||
cls, otype: Union[OptionType, str], parser: dns.wire.Parser
|
||||
) -> Option:
|
||||
return cls(parser.get_remaining())
|
||||
|
||||
|
||||
_type_to_class: Dict[OptionType, Any] = {
|
||||
OptionType.ECS: ECSOption,
|
||||
OptionType.EDE: EDEOption,
|
||||
OptionType.NSID: NSIDOption,
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -1,24 +1,30 @@
|
|||
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
|
||||
|
||||
import collections.abc
|
||||
from typing import Any
|
||||
from typing import Any, Callable
|
||||
|
||||
from dns._immutable_ctx import immutable
|
||||
|
||||
|
||||
@immutable
|
||||
class Dict(collections.abc.Mapping): # lgtm[py/missing-equals]
|
||||
def __init__(self, dictionary: Any, no_copy: bool = False):
|
||||
def __init__(
|
||||
self,
|
||||
dictionary: Any,
|
||||
no_copy: bool = False,
|
||||
map_factory: Callable[[], collections.abc.MutableMapping] = dict,
|
||||
):
|
||||
"""Make an immutable dictionary from the specified dictionary.
|
||||
|
||||
If *no_copy* is `True`, then *dictionary* will be wrapped instead
|
||||
of copied. Only set this if you are sure there will be no external
|
||||
references to the dictionary.
|
||||
"""
|
||||
if no_copy and isinstance(dictionary, dict):
|
||||
if no_copy and isinstance(dictionary, collections.abc.MutableMapping):
|
||||
self._odict = dictionary
|
||||
else:
|
||||
self._odict = dict(dictionary)
|
||||
self._odict = map_factory()
|
||||
self._odict.update(dictionary)
|
||||
self._hash = None
|
||||
|
||||
def __getitem__(self, key):
|
||||
|
|
|
@ -178,3 +178,20 @@ def any_for_af(af):
|
|||
elif af == socket.AF_INET6:
|
||||
return "::"
|
||||
raise NotImplementedError(f"unknown address family {af}")
|
||||
|
||||
|
||||
def canonicalize(text: str) -> str:
|
||||
"""Verify that *address* is a valid text form IPv4 or IPv6 address and return its
|
||||
canonical text form. IPv6 addresses with scopes are rejected.
|
||||
|
||||
*text*, a ``str``, the address in textual form.
|
||||
|
||||
Raises ``ValueError`` if the text is not valid.
|
||||
"""
|
||||
try:
|
||||
return dns.ipv6.canonicalize(text)
|
||||
except Exception:
|
||||
try:
|
||||
return dns.ipv4.canonicalize(text)
|
||||
except Exception:
|
||||
raise ValueError
|
||||
|
|
|
@ -62,3 +62,16 @@ def inet_aton(text: Union[str, bytes]) -> bytes:
|
|||
return struct.pack("BBBB", *b)
|
||||
except Exception:
|
||||
raise dns.exception.SyntaxError
|
||||
|
||||
|
||||
def canonicalize(text: Union[str, bytes]) -> str:
|
||||
"""Verify that *address* is a valid text form IPv4 address and return its
|
||||
canonical text form.
|
||||
|
||||
*text*, a ``str`` or ``bytes``, the IPv4 address in textual form.
|
||||
|
||||
Raises ``dns.exception.SyntaxError`` if the text is not valid.
|
||||
"""
|
||||
# Note that inet_aton() only accepts canonial form, but we still run through
|
||||
# inet_ntoa() to ensure the output is a str.
|
||||
return dns.ipv4.inet_ntoa(dns.ipv4.inet_aton(text))
|
||||
|
|
|
@ -104,7 +104,7 @@ _colon_colon_end = re.compile(rb".*::$")
|
|||
def inet_aton(text: Union[str, bytes], ignore_scope: bool = False) -> bytes:
|
||||
"""Convert an IPv6 address in text form to binary form.
|
||||
|
||||
*text*, a ``str``, the IPv6 address in textual form.
|
||||
*text*, a ``str`` or ``bytes``, the IPv6 address in textual form.
|
||||
|
||||
*ignore_scope*, a ``bool``. If ``True``, a scope will be ignored.
|
||||
If ``False``, the default, it is an error for a scope to be present.
|
||||
|
@ -206,3 +206,14 @@ def is_mapped(address: bytes) -> bool:
|
|||
"""
|
||||
|
||||
return address.startswith(_mapped_prefix)
|
||||
|
||||
|
||||
def canonicalize(text: Union[str, bytes]) -> str:
|
||||
"""Verify that *address* is a valid text form IPv6 address and return its
|
||||
canonical text form. Addresses with scopes are rejected.
|
||||
|
||||
*text*, a ``str`` or ``bytes``, the IPv6 address in textual form.
|
||||
|
||||
Raises ``dns.exception.SyntaxError`` if the text is not valid.
|
||||
"""
|
||||
return dns.ipv6.inet_ntoa(dns.ipv6.inet_aton(text))
|
||||
|
|
|
@ -393,7 +393,7 @@ class Message:
|
|||
section_number = section
|
||||
section = self.section_from_number(section_number)
|
||||
elif isinstance(section, str):
|
||||
section_number = MessageSection.from_text(section)
|
||||
section_number = self._section_enum.from_text(section)
|
||||
section = self.section_from_number(section_number)
|
||||
else:
|
||||
section_number = self.section_number(section)
|
||||
|
@ -489,6 +489,34 @@ class Message:
|
|||
rrset = None
|
||||
return rrset
|
||||
|
||||
def section_count(self, section: SectionType) -> int:
|
||||
"""Returns the number of records in the specified section.
|
||||
|
||||
*section*, an ``int`` section number, a ``str`` section name, or one of
|
||||
the section attributes of this message. This specifies the
|
||||
the section of the message to count. For example::
|
||||
|
||||
my_message.section_count(my_message.answer)
|
||||
my_message.section_count(dns.message.ANSWER)
|
||||
my_message.section_count("ANSWER")
|
||||
"""
|
||||
|
||||
if isinstance(section, int):
|
||||
section_number = section
|
||||
section = self.section_from_number(section_number)
|
||||
elif isinstance(section, str):
|
||||
section_number = self._section_enum.from_text(section)
|
||||
section = self.section_from_number(section_number)
|
||||
else:
|
||||
section_number = self.section_number(section)
|
||||
count = sum(max(1, len(rrs)) for rrs in section)
|
||||
if section_number == MessageSection.ADDITIONAL:
|
||||
if self.opt is not None:
|
||||
count += 1
|
||||
if self.tsig is not None:
|
||||
count += 1
|
||||
return count
|
||||
|
||||
def _compute_opt_reserve(self) -> int:
|
||||
"""Compute the size required for the OPT RR, padding excluded"""
|
||||
if not self.opt:
|
||||
|
@ -527,6 +555,8 @@ class Message:
|
|||
max_size: int = 0,
|
||||
multi: bool = False,
|
||||
tsig_ctx: Optional[Any] = None,
|
||||
prepend_length: bool = False,
|
||||
prefer_truncation: bool = False,
|
||||
**kw: Dict[str, Any],
|
||||
) -> bytes:
|
||||
"""Return a string containing the message in DNS compressed wire
|
||||
|
@ -549,6 +579,15 @@ class Message:
|
|||
*tsig_ctx*, a ``dns.tsig.HMACTSig`` or ``dns.tsig.GSSTSig`` object, the
|
||||
ongoing TSIG context, used when signing zone transfers.
|
||||
|
||||
*prepend_length*, a ``bool``, should be set to ``True`` if the caller
|
||||
wants the message length prepended to the message itself. This is
|
||||
useful for messages sent over TCP, TLS (DoT), or QUIC (DoQ).
|
||||
|
||||
*prefer_truncation*, a ``bool``, should be set to ``True`` if the caller
|
||||
wants the message to be truncated if it would otherwise exceed the
|
||||
maximum length. If the truncation occurs before the additional section,
|
||||
the TC bit will be set.
|
||||
|
||||
Raises ``dns.exception.TooBig`` if *max_size* was exceeded.
|
||||
|
||||
Returns a ``bytes``.
|
||||
|
@ -570,14 +609,21 @@ class Message:
|
|||
r.reserve(opt_reserve)
|
||||
tsig_reserve = self._compute_tsig_reserve()
|
||||
r.reserve(tsig_reserve)
|
||||
for rrset in self.question:
|
||||
r.add_question(rrset.name, rrset.rdtype, rrset.rdclass)
|
||||
for rrset in self.answer:
|
||||
r.add_rrset(dns.renderer.ANSWER, rrset, **kw)
|
||||
for rrset in self.authority:
|
||||
r.add_rrset(dns.renderer.AUTHORITY, rrset, **kw)
|
||||
for rrset in self.additional:
|
||||
r.add_rrset(dns.renderer.ADDITIONAL, rrset, **kw)
|
||||
try:
|
||||
for rrset in self.question:
|
||||
r.add_question(rrset.name, rrset.rdtype, rrset.rdclass)
|
||||
for rrset in self.answer:
|
||||
r.add_rrset(dns.renderer.ANSWER, rrset, **kw)
|
||||
for rrset in self.authority:
|
||||
r.add_rrset(dns.renderer.AUTHORITY, rrset, **kw)
|
||||
for rrset in self.additional:
|
||||
r.add_rrset(dns.renderer.ADDITIONAL, rrset, **kw)
|
||||
except dns.exception.TooBig:
|
||||
if prefer_truncation:
|
||||
if r.section < dns.renderer.ADDITIONAL:
|
||||
r.flags |= dns.flags.TC
|
||||
else:
|
||||
raise
|
||||
r.release_reserved()
|
||||
if self.opt is not None:
|
||||
r.add_opt(self.opt, self.pad, opt_reserve, tsig_reserve)
|
||||
|
@ -598,7 +644,10 @@ class Message:
|
|||
r.write_header()
|
||||
if multi:
|
||||
self.tsig_ctx = ctx
|
||||
return r.get_wire()
|
||||
wire = r.get_wire()
|
||||
if prepend_length:
|
||||
wire = len(wire).to_bytes(2, "big") + wire
|
||||
return wire
|
||||
|
||||
@staticmethod
|
||||
def _make_tsig(
|
||||
|
@ -777,6 +826,8 @@ class Message:
|
|||
if request_payload is None:
|
||||
request_payload = payload
|
||||
self.request_payload = request_payload
|
||||
if pad < 0:
|
||||
raise ValueError("pad must be non-negative")
|
||||
self.pad = pad
|
||||
|
||||
@property
|
||||
|
@ -826,7 +877,7 @@ class Message:
|
|||
if wanted:
|
||||
self.ednsflags |= dns.flags.DO
|
||||
elif self.opt:
|
||||
self.ednsflags &= ~dns.flags.DO
|
||||
self.ednsflags &= ~int(dns.flags.DO)
|
||||
|
||||
def rcode(self) -> dns.rcode.Rcode:
|
||||
"""Return the rcode.
|
||||
|
@ -1035,7 +1086,6 @@ def _message_factory_from_opcode(opcode):
|
|||
|
||||
|
||||
class _WireReader:
|
||||
|
||||
"""Wire format reader.
|
||||
|
||||
parser: the binary parser
|
||||
|
@ -1335,7 +1385,6 @@ def from_wire(
|
|||
|
||||
|
||||
class _TextReader:
|
||||
|
||||
"""Text format reader.
|
||||
|
||||
tok: the tokenizer.
|
||||
|
@ -1768,30 +1817,34 @@ def make_response(
|
|||
our_payload: int = 8192,
|
||||
fudge: int = 300,
|
||||
tsig_error: int = 0,
|
||||
pad: Optional[int] = None,
|
||||
) -> Message:
|
||||
"""Make a message which is a response for the specified query.
|
||||
The message returned is really a response skeleton; it has all
|
||||
of the infrastructure required of a response, but none of the
|
||||
content.
|
||||
The message returned is really a response skeleton; it has all of the infrastructure
|
||||
required of a response, but none of the content.
|
||||
|
||||
The response's question section is a shallow copy of the query's
|
||||
question section, so the query's question RRsets should not be
|
||||
changed.
|
||||
The response's question section is a shallow copy of the query's question section,
|
||||
so the query's question RRsets should not be changed.
|
||||
|
||||
*query*, a ``dns.message.Message``, the query to respond to.
|
||||
|
||||
*recursion_available*, a ``bool``, should RA be set in the response?
|
||||
|
||||
*our_payload*, an ``int``, the payload size to advertise in EDNS
|
||||
responses.
|
||||
*our_payload*, an ``int``, the payload size to advertise in EDNS responses.
|
||||
|
||||
*fudge*, an ``int``, the TSIG time fudge.
|
||||
|
||||
*tsig_error*, an ``int``, the TSIG error.
|
||||
|
||||
Returns a ``dns.message.Message`` object whose specific class is
|
||||
appropriate for the query. For example, if query is a
|
||||
``dns.update.UpdateMessage``, response will be too.
|
||||
*pad*, a non-negative ``int`` or ``None``. If 0, the default, do not pad; otherwise
|
||||
if not ``None`` add padding bytes to make the message size a multiple of *pad*.
|
||||
Note that if padding is non-zero, an EDNS PADDING option will always be added to the
|
||||
message. If ``None``, add padding following RFC 8467, namely if the request is
|
||||
padded, pad the response to 468 otherwise do not pad.
|
||||
|
||||
Returns a ``dns.message.Message`` object whose specific class is appropriate for the
|
||||
query. For example, if query is a ``dns.update.UpdateMessage``, response will be
|
||||
too.
|
||||
"""
|
||||
|
||||
if query.flags & dns.flags.QR:
|
||||
|
@ -1804,7 +1857,13 @@ def make_response(
|
|||
response.set_opcode(query.opcode())
|
||||
response.question = list(query.question)
|
||||
if query.edns >= 0:
|
||||
response.use_edns(0, 0, our_payload, query.payload)
|
||||
if pad is None:
|
||||
# Set response padding per RFC 8467
|
||||
pad = 0
|
||||
for option in query.options:
|
||||
if option.otype == dns.edns.OptionType.PADDING:
|
||||
pad = 468
|
||||
response.use_edns(0, 0, our_payload, query.payload, pad=pad)
|
||||
if query.had_tsig:
|
||||
response.use_tsig(
|
||||
query.keyring,
|
||||
|
|
217
lib/dns/name.py
217
lib/dns/name.py
|
@ -20,21 +20,23 @@
|
|||
|
||||
import copy
|
||||
import encodings.idna # type: ignore
|
||||
import functools
|
||||
import struct
|
||||
from typing import Any, Dict, Iterable, Optional, Tuple, Union
|
||||
|
||||
try:
|
||||
import idna # type: ignore
|
||||
|
||||
have_idna_2008 = True
|
||||
except ImportError: # pragma: no cover
|
||||
have_idna_2008 = False
|
||||
from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union
|
||||
|
||||
import dns._features
|
||||
import dns.enum
|
||||
import dns.exception
|
||||
import dns.immutable
|
||||
import dns.wire
|
||||
|
||||
if dns._features.have("idna"):
|
||||
import idna # type: ignore
|
||||
|
||||
have_idna_2008 = True
|
||||
else: # pragma: no cover
|
||||
have_idna_2008 = False
|
||||
|
||||
CompressType = Dict["Name", int]
|
||||
|
||||
|
||||
|
@ -128,6 +130,10 @@ class IDNAException(dns.exception.DNSException):
|
|||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
class NeedSubdomainOfOrigin(dns.exception.DNSException):
|
||||
"""An absolute name was provided that is not a subdomain of the specified origin."""
|
||||
|
||||
|
||||
_escaped = b'"().;\\@$'
|
||||
_escaped_text = '"().;\\@$'
|
||||
|
||||
|
@ -350,7 +356,6 @@ def _maybe_convert_to_binary(label: Union[bytes, str]) -> bytes:
|
|||
|
||||
@dns.immutable.immutable
|
||||
class Name:
|
||||
|
||||
"""A DNS name.
|
||||
|
||||
The dns.name.Name class represents a DNS name as a tuple of
|
||||
|
@ -843,6 +848,42 @@ class Name:
|
|||
raise NoParent
|
||||
return Name(self.labels[1:])
|
||||
|
||||
def predecessor(self, origin: "Name", prefix_ok: bool = True) -> "Name":
|
||||
"""Return the maximal predecessor of *name* in the DNSSEC ordering in the zone
|
||||
whose origin is *origin*, or return the longest name under *origin* if the
|
||||
name is origin (i.e. wrap around to the longest name, which may still be
|
||||
*origin* due to length considerations.
|
||||
|
||||
The relativity of the name is preserved, so if this name is relative
|
||||
then the method will return a relative name, and likewise if this name
|
||||
is absolute then the predecessor will be absolute.
|
||||
|
||||
*prefix_ok* indicates if prefixing labels is allowed, and
|
||||
defaults to ``True``. Normally it is good to allow this, but if computing
|
||||
a maximal predecessor at a zone cut point then ``False`` must be specified.
|
||||
"""
|
||||
return _handle_relativity_and_call(
|
||||
_absolute_predecessor, self, origin, prefix_ok
|
||||
)
|
||||
|
||||
def successor(self, origin: "Name", prefix_ok: bool = True) -> "Name":
|
||||
"""Return the minimal successor of *name* in the DNSSEC ordering in the zone
|
||||
whose origin is *origin*, or return *origin* if the successor cannot be
|
||||
computed due to name length limitations.
|
||||
|
||||
Note that *origin* is returned in the "too long" cases because wrapping
|
||||
around to the origin is how NSEC records express "end of the zone".
|
||||
|
||||
The relativity of the name is preserved, so if this name is relative
|
||||
then the method will return a relative name, and likewise if this name
|
||||
is absolute then the successor will be absolute.
|
||||
|
||||
*prefix_ok* indicates if prefixing a new minimal label is allowed, and
|
||||
defaults to ``True``. Normally it is good to allow this, but if computing
|
||||
a minimal successor at a zone cut point then ``False`` must be specified.
|
||||
"""
|
||||
return _handle_relativity_and_call(_absolute_successor, self, origin, prefix_ok)
|
||||
|
||||
|
||||
#: The root name, '.'
|
||||
root = Name([b""])
|
||||
|
@ -1082,3 +1123,161 @@ def from_wire(message: bytes, current: int) -> Tuple[Name, int]:
|
|||
parser = dns.wire.Parser(message, current)
|
||||
name = from_wire_parser(parser)
|
||||
return (name, parser.current - current)
|
||||
|
||||
|
||||
# RFC 4471 Support
|
||||
|
||||
_MINIMAL_OCTET = b"\x00"
|
||||
_MINIMAL_OCTET_VALUE = ord(_MINIMAL_OCTET)
|
||||
_SUCCESSOR_PREFIX = Name([_MINIMAL_OCTET])
|
||||
_MAXIMAL_OCTET = b"\xff"
|
||||
_MAXIMAL_OCTET_VALUE = ord(_MAXIMAL_OCTET)
|
||||
_AT_SIGN_VALUE = ord("@")
|
||||
_LEFT_SQUARE_BRACKET_VALUE = ord("[")
|
||||
|
||||
|
||||
def _wire_length(labels):
|
||||
return functools.reduce(lambda v, x: v + len(x) + 1, labels, 0)
|
||||
|
||||
|
||||
def _pad_to_max_name(name):
|
||||
needed = 255 - _wire_length(name.labels)
|
||||
new_labels = []
|
||||
while needed > 64:
|
||||
new_labels.append(_MAXIMAL_OCTET * 63)
|
||||
needed -= 64
|
||||
if needed >= 2:
|
||||
new_labels.append(_MAXIMAL_OCTET * (needed - 1))
|
||||
# Note we're already maximal in the needed == 1 case as while we'd like
|
||||
# to add one more byte as a new label, we can't, as adding a new non-empty
|
||||
# label requires at least 2 bytes.
|
||||
new_labels = list(reversed(new_labels))
|
||||
new_labels.extend(name.labels)
|
||||
return Name(new_labels)
|
||||
|
||||
|
||||
def _pad_to_max_label(label, suffix_labels):
|
||||
length = len(label)
|
||||
# We have to subtract one here to account for the length byte of label.
|
||||
remaining = 255 - _wire_length(suffix_labels) - length - 1
|
||||
if remaining <= 0:
|
||||
# Shouldn't happen!
|
||||
return label
|
||||
needed = min(63 - length, remaining)
|
||||
return label + _MAXIMAL_OCTET * needed
|
||||
|
||||
|
||||
def _absolute_predecessor(name: Name, origin: Name, prefix_ok: bool) -> Name:
|
||||
# This is the RFC 4471 predecessor algorithm using the "absolute method" of section
|
||||
# 3.1.1.
|
||||
#
|
||||
# Our caller must ensure that the name and origin are absolute, and that name is a
|
||||
# subdomain of origin.
|
||||
if name == origin:
|
||||
return _pad_to_max_name(name)
|
||||
least_significant_label = name[0]
|
||||
if least_significant_label == _MINIMAL_OCTET:
|
||||
return name.parent()
|
||||
least_octet = least_significant_label[-1]
|
||||
suffix_labels = name.labels[1:]
|
||||
if least_octet == _MINIMAL_OCTET_VALUE:
|
||||
new_labels = [least_significant_label[:-1]]
|
||||
else:
|
||||
octets = bytearray(least_significant_label)
|
||||
octet = octets[-1]
|
||||
if octet == _LEFT_SQUARE_BRACKET_VALUE:
|
||||
octet = _AT_SIGN_VALUE
|
||||
else:
|
||||
octet -= 1
|
||||
octets[-1] = octet
|
||||
least_significant_label = bytes(octets)
|
||||
new_labels = [_pad_to_max_label(least_significant_label, suffix_labels)]
|
||||
new_labels.extend(suffix_labels)
|
||||
name = Name(new_labels)
|
||||
if prefix_ok:
|
||||
return _pad_to_max_name(name)
|
||||
else:
|
||||
return name
|
||||
|
||||
|
||||
def _absolute_successor(name: Name, origin: Name, prefix_ok: bool) -> Name:
|
||||
# This is the RFC 4471 successor algorithm using the "absolute method" of section
|
||||
# 3.1.2.
|
||||
#
|
||||
# Our caller must ensure that the name and origin are absolute, and that name is a
|
||||
# subdomain of origin.
|
||||
if prefix_ok:
|
||||
# Try prefixing \000 as new label
|
||||
try:
|
||||
return _SUCCESSOR_PREFIX.concatenate(name)
|
||||
except NameTooLong:
|
||||
pass
|
||||
while name != origin:
|
||||
# Try extending the least significant label.
|
||||
least_significant_label = name[0]
|
||||
if len(least_significant_label) < 63:
|
||||
# We may be able to extend the least label with a minimal additional byte.
|
||||
# This is only "may" because we could have a maximal length name even though
|
||||
# the least significant label isn't maximally long.
|
||||
new_labels = [least_significant_label + _MINIMAL_OCTET]
|
||||
new_labels.extend(name.labels[1:])
|
||||
try:
|
||||
return dns.name.Name(new_labels)
|
||||
except dns.name.NameTooLong:
|
||||
pass
|
||||
# We can't extend the label either, so we'll try to increment the least
|
||||
# signficant non-maximal byte in it.
|
||||
octets = bytearray(least_significant_label)
|
||||
# We do this reversed iteration with an explicit indexing variable because
|
||||
# if we find something to increment, we're going to want to truncate everything
|
||||
# to the right of it.
|
||||
for i in range(len(octets) - 1, -1, -1):
|
||||
octet = octets[i]
|
||||
if octet == _MAXIMAL_OCTET_VALUE:
|
||||
# We can't increment this, so keep looking.
|
||||
continue
|
||||
# Finally, something we can increment. We have to apply a special rule for
|
||||
# incrementing "@", sending it to "[", because RFC 4034 6.1 says that when
|
||||
# comparing names, uppercase letters compare as if they were their
|
||||
# lower-case equivalents. If we increment "@" to "A", then it would compare
|
||||
# as "a", which is after "[", "\", "]", "^", "_", and "`", so we would have
|
||||
# skipped the most minimal successor, namely "[".
|
||||
if octet == _AT_SIGN_VALUE:
|
||||
octet = _LEFT_SQUARE_BRACKET_VALUE
|
||||
else:
|
||||
octet += 1
|
||||
octets[i] = octet
|
||||
# We can now truncate all of the maximal values we skipped (if any)
|
||||
new_labels = [bytes(octets[: i + 1])]
|
||||
new_labels.extend(name.labels[1:])
|
||||
# We haven't changed the length of the name, so the Name constructor will
|
||||
# always work.
|
||||
return Name(new_labels)
|
||||
# We couldn't increment, so chop off the least significant label and try
|
||||
# again.
|
||||
name = name.parent()
|
||||
|
||||
# We couldn't increment at all, so return the origin, as wrapping around is the
|
||||
# DNSSEC way.
|
||||
return origin
|
||||
|
||||
|
||||
def _handle_relativity_and_call(
|
||||
function: Callable[[Name, Name, bool], Name],
|
||||
name: Name,
|
||||
origin: Name,
|
||||
prefix_ok: bool,
|
||||
) -> Name:
|
||||
# Make "name" absolute if needed, ensure that the origin is absolute,
|
||||
# call function(), and then relativize the result if needed.
|
||||
if not origin.is_absolute():
|
||||
raise NeedAbsoluteNameOrOrigin
|
||||
relative = not name.is_absolute()
|
||||
if relative:
|
||||
name = name.derelativize(origin)
|
||||
elif not name.is_subdomain(origin):
|
||||
raise NeedSubdomainOfOrigin
|
||||
result_name = function(name, origin, prefix_ok)
|
||||
if relative:
|
||||
result_name = result_name.relativize(origin)
|
||||
return result_name
|
||||
|
|
|
@ -115,6 +115,8 @@ class Do53Nameserver(AddressAndPortNameserver):
|
|||
raise_on_truncation=True,
|
||||
one_rr_per_rrset=one_rr_per_rrset,
|
||||
ignore_trailing=ignore_trailing,
|
||||
ignore_errors=True,
|
||||
ignore_unexpected=True,
|
||||
)
|
||||
return response
|
||||
|
||||
|
@ -153,15 +155,25 @@ class Do53Nameserver(AddressAndPortNameserver):
|
|||
backend=backend,
|
||||
one_rr_per_rrset=one_rr_per_rrset,
|
||||
ignore_trailing=ignore_trailing,
|
||||
ignore_errors=True,
|
||||
ignore_unexpected=True,
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
class DoHNameserver(Nameserver):
|
||||
def __init__(self, url: str, bootstrap_address: Optional[str] = None):
|
||||
def __init__(
|
||||
self,
|
||||
url: str,
|
||||
bootstrap_address: Optional[str] = None,
|
||||
verify: Union[bool, str] = True,
|
||||
want_get: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.url = url
|
||||
self.bootstrap_address = bootstrap_address
|
||||
self.verify = verify
|
||||
self.want_get = want_get
|
||||
|
||||
def kind(self):
|
||||
return "DoH"
|
||||
|
@ -195,9 +207,13 @@ class DoHNameserver(Nameserver):
|
|||
request,
|
||||
self.url,
|
||||
timeout=timeout,
|
||||
source=source,
|
||||
source_port=source_port,
|
||||
bootstrap_address=self.bootstrap_address,
|
||||
one_rr_per_rrset=one_rr_per_rrset,
|
||||
ignore_trailing=ignore_trailing,
|
||||
verify=self.verify,
|
||||
post=(not self.want_get),
|
||||
)
|
||||
|
||||
async def async_query(
|
||||
|
@ -215,15 +231,27 @@ class DoHNameserver(Nameserver):
|
|||
request,
|
||||
self.url,
|
||||
timeout=timeout,
|
||||
source=source,
|
||||
source_port=source_port,
|
||||
bootstrap_address=self.bootstrap_address,
|
||||
one_rr_per_rrset=one_rr_per_rrset,
|
||||
ignore_trailing=ignore_trailing,
|
||||
verify=self.verify,
|
||||
post=(not self.want_get),
|
||||
)
|
||||
|
||||
|
||||
class DoTNameserver(AddressAndPortNameserver):
|
||||
def __init__(self, address: str, port: int = 853, hostname: Optional[str] = None):
|
||||
def __init__(
|
||||
self,
|
||||
address: str,
|
||||
port: int = 853,
|
||||
hostname: Optional[str] = None,
|
||||
verify: Union[bool, str] = True,
|
||||
):
|
||||
super().__init__(address, port)
|
||||
self.hostname = hostname
|
||||
self.verify = verify
|
||||
|
||||
def kind(self):
|
||||
return "DoT"
|
||||
|
@ -246,6 +274,7 @@ class DoTNameserver(AddressAndPortNameserver):
|
|||
one_rr_per_rrset=one_rr_per_rrset,
|
||||
ignore_trailing=ignore_trailing,
|
||||
server_hostname=self.hostname,
|
||||
verify=self.verify,
|
||||
)
|
||||
|
||||
async def async_query(
|
||||
|
@ -267,6 +296,7 @@ class DoTNameserver(AddressAndPortNameserver):
|
|||
one_rr_per_rrset=one_rr_per_rrset,
|
||||
ignore_trailing=ignore_trailing,
|
||||
server_hostname=self.hostname,
|
||||
verify=self.verify,
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -70,7 +70,6 @@ class NodeKind(enum.Enum):
|
|||
|
||||
|
||||
class Node:
|
||||
|
||||
"""A Node is a set of rdatasets.
|
||||
|
||||
A node is either a CNAME node or an "other data" node. A CNAME
|
||||
|
|
195
lib/dns/query.py
195
lib/dns/query.py
|
@ -22,12 +22,14 @@ import contextlib
|
|||
import enum
|
||||
import errno
|
||||
import os
|
||||
import os.path
|
||||
import selectors
|
||||
import socket
|
||||
import struct
|
||||
import time
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import dns._features
|
||||
import dns.exception
|
||||
import dns.inet
|
||||
import dns.message
|
||||
|
@ -57,24 +59,14 @@ def _expiration_for_this_attempt(timeout, expiration):
|
|||
return min(time.time() + timeout, expiration)
|
||||
|
||||
|
||||
_have_httpx = False
|
||||
_have_http2 = False
|
||||
try:
|
||||
import httpcore
|
||||
_have_httpx = dns._features.have("doh")
|
||||
if _have_httpx:
|
||||
import httpcore._backends.sync
|
||||
import httpx
|
||||
|
||||
_CoreNetworkBackend = httpcore.NetworkBackend
|
||||
_CoreSyncStream = httpcore._backends.sync.SyncStream
|
||||
|
||||
_have_httpx = True
|
||||
try:
|
||||
# See if http2 support is available.
|
||||
with httpx.Client(http2=True):
|
||||
_have_http2 = True
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
class _NetworkBackend(_CoreNetworkBackend):
|
||||
def __init__(self, resolver, local_port, bootstrap_address, family):
|
||||
super().__init__()
|
||||
|
@ -147,7 +139,7 @@ try:
|
|||
resolver, local_port, bootstrap_address, family
|
||||
)
|
||||
|
||||
except ImportError: # pragma: no cover
|
||||
else:
|
||||
|
||||
class _HTTPTransport: # type: ignore
|
||||
def connect_tcp(self, host, port, timeout, local_address):
|
||||
|
@ -161,6 +153,8 @@ try:
|
|||
except ImportError: # pragma: no cover
|
||||
|
||||
class ssl: # type: ignore
|
||||
CERT_NONE = 0
|
||||
|
||||
class WantReadException(Exception):
|
||||
pass
|
||||
|
||||
|
@ -459,7 +453,7 @@ def https(
|
|||
transport = _HTTPTransport(
|
||||
local_address=local_address,
|
||||
http1=True,
|
||||
http2=_have_http2,
|
||||
http2=True,
|
||||
verify=verify,
|
||||
local_port=local_port,
|
||||
bootstrap_address=bootstrap_address,
|
||||
|
@ -470,9 +464,7 @@ def https(
|
|||
if session:
|
||||
cm: contextlib.AbstractContextManager = contextlib.nullcontext(session)
|
||||
else:
|
||||
cm = httpx.Client(
|
||||
http1=True, http2=_have_http2, verify=verify, transport=transport
|
||||
)
|
||||
cm = httpx.Client(http1=True, http2=True, verify=verify, transport=transport)
|
||||
with cm as session:
|
||||
# see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH
|
||||
# GET and POST examples
|
||||
|
@ -577,6 +569,8 @@ def receive_udp(
|
|||
request_mac: Optional[bytes] = b"",
|
||||
ignore_trailing: bool = False,
|
||||
raise_on_truncation: bool = False,
|
||||
ignore_errors: bool = False,
|
||||
query: Optional[dns.message.Message] = None,
|
||||
) -> Any:
|
||||
"""Read a DNS message from a UDP socket.
|
||||
|
||||
|
@ -617,28 +611,58 @@ def receive_udp(
|
|||
``(dns.message.Message, float, tuple)``
|
||||
tuple of the received message, the received time, and the address where
|
||||
the message arrived from.
|
||||
|
||||
*ignore_errors*, a ``bool``. If various format errors or response
|
||||
mismatches occur, ignore them and keep listening for a valid response.
|
||||
The default is ``False``.
|
||||
|
||||
*query*, a ``dns.message.Message`` or ``None``. If not ``None`` and
|
||||
*ignore_errors* is ``True``, check that the received message is a response
|
||||
to this query, and if not keep listening for a valid response.
|
||||
"""
|
||||
|
||||
wire = b""
|
||||
while True:
|
||||
(wire, from_address) = _udp_recv(sock, 65535, expiration)
|
||||
if _matches_destination(
|
||||
if not _matches_destination(
|
||||
sock.family, from_address, destination, ignore_unexpected
|
||||
):
|
||||
break
|
||||
received_time = time.time()
|
||||
r = dns.message.from_wire(
|
||||
wire,
|
||||
keyring=keyring,
|
||||
request_mac=request_mac,
|
||||
one_rr_per_rrset=one_rr_per_rrset,
|
||||
ignore_trailing=ignore_trailing,
|
||||
raise_on_truncation=raise_on_truncation,
|
||||
)
|
||||
if destination:
|
||||
return (r, received_time)
|
||||
else:
|
||||
return (r, received_time, from_address)
|
||||
continue
|
||||
received_time = time.time()
|
||||
try:
|
||||
r = dns.message.from_wire(
|
||||
wire,
|
||||
keyring=keyring,
|
||||
request_mac=request_mac,
|
||||
one_rr_per_rrset=one_rr_per_rrset,
|
||||
ignore_trailing=ignore_trailing,
|
||||
raise_on_truncation=raise_on_truncation,
|
||||
)
|
||||
except dns.message.Truncated as e:
|
||||
# If we got Truncated and not FORMERR, we at least got the header with TC
|
||||
# set, and very likely the question section, so we'll re-raise if the
|
||||
# message seems to be a response as we need to know when truncation happens.
|
||||
# We need to check that it seems to be a response as we don't want a random
|
||||
# injected message with TC set to cause us to bail out.
|
||||
if (
|
||||
ignore_errors
|
||||
and query is not None
|
||||
and not query.is_response(e.message())
|
||||
):
|
||||
continue
|
||||
else:
|
||||
raise
|
||||
except Exception:
|
||||
if ignore_errors:
|
||||
continue
|
||||
else:
|
||||
raise
|
||||
if ignore_errors and query is not None and not query.is_response(r):
|
||||
continue
|
||||
if destination:
|
||||
return (r, received_time)
|
||||
else:
|
||||
return (r, received_time, from_address)
|
||||
|
||||
|
||||
def udp(
|
||||
|
@ -653,6 +677,7 @@ def udp(
|
|||
ignore_trailing: bool = False,
|
||||
raise_on_truncation: bool = False,
|
||||
sock: Optional[Any] = None,
|
||||
ignore_errors: bool = False,
|
||||
) -> dns.message.Message:
|
||||
"""Return the response obtained after sending a query via UDP.
|
||||
|
||||
|
@ -689,6 +714,10 @@ def udp(
|
|||
if a socket is provided, it must be a nonblocking datagram socket,
|
||||
and the *source* and *source_port* are ignored.
|
||||
|
||||
*ignore_errors*, a ``bool``. If various format errors or response
|
||||
mismatches occur, ignore them and keep listening for a valid response.
|
||||
The default is ``False``.
|
||||
|
||||
Returns a ``dns.message.Message``.
|
||||
"""
|
||||
|
||||
|
@ -713,9 +742,13 @@ def udp(
|
|||
q.mac,
|
||||
ignore_trailing,
|
||||
raise_on_truncation,
|
||||
ignore_errors,
|
||||
q,
|
||||
)
|
||||
r.time = received_time - begin_time
|
||||
if not q.is_response(r):
|
||||
# We don't need to check q.is_response() if we are in ignore_errors mode
|
||||
# as receive_udp() will have checked it.
|
||||
if not (ignore_errors or q.is_response(r)):
|
||||
raise BadResponse
|
||||
return r
|
||||
assert (
|
||||
|
@ -735,48 +768,50 @@ def udp_with_fallback(
|
|||
ignore_trailing: bool = False,
|
||||
udp_sock: Optional[Any] = None,
|
||||
tcp_sock: Optional[Any] = None,
|
||||
ignore_errors: bool = False,
|
||||
) -> Tuple[dns.message.Message, bool]:
|
||||
"""Return the response to the query, trying UDP first and falling back
|
||||
to TCP if UDP results in a truncated response.
|
||||
|
||||
*q*, a ``dns.message.Message``, the query to send
|
||||
|
||||
*where*, a ``str`` containing an IPv4 or IPv6 address, where
|
||||
to send the message.
|
||||
*where*, a ``str`` containing an IPv4 or IPv6 address, where to send the message.
|
||||
|
||||
*timeout*, a ``float`` or ``None``, the number of seconds to wait before the
|
||||
query times out. If ``None``, the default, wait forever.
|
||||
*timeout*, a ``float`` or ``None``, the number of seconds to wait before the query
|
||||
times out. If ``None``, the default, wait forever.
|
||||
|
||||
*port*, an ``int``, the port send the message to. The default is 53.
|
||||
|
||||
*source*, a ``str`` containing an IPv4 or IPv6 address, specifying
|
||||
the source address. The default is the wildcard address.
|
||||
*source*, a ``str`` containing an IPv4 or IPv6 address, specifying the source
|
||||
address. The default is the wildcard address.
|
||||
|
||||
*source_port*, an ``int``, the port from which to send the message.
|
||||
The default is 0.
|
||||
*source_port*, an ``int``, the port from which to send the message. The default is
|
||||
0.
|
||||
|
||||
*ignore_unexpected*, a ``bool``. If ``True``, ignore responses from
|
||||
unexpected sources.
|
||||
*ignore_unexpected*, a ``bool``. If ``True``, ignore responses from unexpected
|
||||
sources.
|
||||
|
||||
*one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own
|
||||
RRset.
|
||||
*one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own RRset.
|
||||
|
||||
*ignore_trailing*, a ``bool``. If ``True``, ignore trailing
|
||||
junk at end of the received message.
|
||||
*ignore_trailing*, a ``bool``. If ``True``, ignore trailing junk at end of the
|
||||
received message.
|
||||
|
||||
*udp_sock*, a ``socket.socket``, or ``None``, the socket to use for the
|
||||
UDP query. If ``None``, the default, a socket is created. Note that
|
||||
if a socket is provided, it must be a nonblocking datagram socket,
|
||||
and the *source* and *source_port* are ignored for the UDP query.
|
||||
*udp_sock*, a ``socket.socket``, or ``None``, the socket to use for the UDP query.
|
||||
If ``None``, the default, a socket is created. Note that if a socket is provided,
|
||||
it must be a nonblocking datagram socket, and the *source* and *source_port* are
|
||||
ignored for the UDP query.
|
||||
|
||||
*tcp_sock*, a ``socket.socket``, or ``None``, the connected socket to use for the
|
||||
TCP query. If ``None``, the default, a socket is created. Note that
|
||||
if a socket is provided, it must be a nonblocking connected stream
|
||||
socket, and *where*, *source* and *source_port* are ignored for the TCP
|
||||
query.
|
||||
TCP query. If ``None``, the default, a socket is created. Note that if a socket is
|
||||
provided, it must be a nonblocking connected stream socket, and *where*, *source*
|
||||
and *source_port* are ignored for the TCP query.
|
||||
|
||||
Returns a (``dns.message.Message``, tcp) tuple where tcp is ``True``
|
||||
if and only if TCP was used.
|
||||
*ignore_errors*, a ``bool``. If various format errors or response mismatches occur
|
||||
while listening for UDP, ignore them and keep listening for a valid response. The
|
||||
default is ``False``.
|
||||
|
||||
Returns a (``dns.message.Message``, tcp) tuple where tcp is ``True`` if and only if
|
||||
TCP was used.
|
||||
"""
|
||||
try:
|
||||
response = udp(
|
||||
|
@ -791,6 +826,7 @@ def udp_with_fallback(
|
|||
ignore_trailing,
|
||||
True,
|
||||
udp_sock,
|
||||
ignore_errors,
|
||||
)
|
||||
return (response, False)
|
||||
except dns.message.Truncated:
|
||||
|
@ -864,14 +900,12 @@ def send_tcp(
|
|||
"""
|
||||
|
||||
if isinstance(what, dns.message.Message):
|
||||
wire = what.to_wire()
|
||||
tcpmsg = what.to_wire(prepend_length=True)
|
||||
else:
|
||||
wire = what
|
||||
l = len(wire)
|
||||
# copying the wire into tcpmsg is inefficient, but lets us
|
||||
# avoid writev() or doing a short write that would get pushed
|
||||
# onto the net
|
||||
tcpmsg = struct.pack("!H", l) + wire
|
||||
# copying the wire into tcpmsg is inefficient, but lets us
|
||||
# avoid writev() or doing a short write that would get pushed
|
||||
# onto the net
|
||||
tcpmsg = len(what).to_bytes(2, "big") + what
|
||||
sent_time = time.time()
|
||||
_net_write(sock, tcpmsg, expiration)
|
||||
return (len(tcpmsg), sent_time)
|
||||
|
@ -1014,6 +1048,28 @@ def _tls_handshake(s, expiration):
|
|||
_wait_for_writable(s, expiration)
|
||||
|
||||
|
||||
def _make_dot_ssl_context(
|
||||
server_hostname: Optional[str], verify: Union[bool, str]
|
||||
) -> ssl.SSLContext:
|
||||
cafile: Optional[str] = None
|
||||
capath: Optional[str] = None
|
||||
if isinstance(verify, str):
|
||||
if os.path.isfile(verify):
|
||||
cafile = verify
|
||||
elif os.path.isdir(verify):
|
||||
capath = verify
|
||||
else:
|
||||
raise ValueError("invalid verify string")
|
||||
ssl_context = ssl.create_default_context(cafile=cafile, capath=capath)
|
||||
ssl_context.minimum_version = ssl.TLSVersion.TLSv1_2
|
||||
if server_hostname is None:
|
||||
ssl_context.check_hostname = False
|
||||
ssl_context.set_alpn_protocols(["dot"])
|
||||
if verify is False:
|
||||
ssl_context.verify_mode = ssl.CERT_NONE
|
||||
return ssl_context
|
||||
|
||||
|
||||
def tls(
|
||||
q: dns.message.Message,
|
||||
where: str,
|
||||
|
@ -1026,6 +1082,7 @@ def tls(
|
|||
sock: Optional[ssl.SSLSocket] = None,
|
||||
ssl_context: Optional[ssl.SSLContext] = None,
|
||||
server_hostname: Optional[str] = None,
|
||||
verify: Union[bool, str] = True,
|
||||
) -> dns.message.Message:
|
||||
"""Return the response obtained after sending a query via TLS.
|
||||
|
||||
|
@ -1065,6 +1122,11 @@ def tls(
|
|||
default is ``None``, which means that no hostname is known, and if an
|
||||
SSL context is created, hostname checking will be disabled.
|
||||
|
||||
*verify*, a ``bool`` or ``str``. If a ``True``, then TLS certificate verification
|
||||
of the server is done using the default CA bundle; if ``False``, then no
|
||||
verification is done; if a `str` then it specifies the path to a certificate file or
|
||||
directory which will be used for verification.
|
||||
|
||||
Returns a ``dns.message.Message``.
|
||||
|
||||
"""
|
||||
|
@ -1091,10 +1153,7 @@ def tls(
|
|||
where, port, source, source_port
|
||||
)
|
||||
if ssl_context is None and not sock:
|
||||
ssl_context = ssl.create_default_context()
|
||||
ssl_context.minimum_version = ssl.TLSVersion.TLSv1_2
|
||||
if server_hostname is None:
|
||||
ssl_context.check_hostname = False
|
||||
ssl_context = _make_dot_ssl_context(server_hostname, verify)
|
||||
|
||||
with _make_socket(
|
||||
af,
|
||||
|
|
|
@ -1,9 +1,11 @@
|
|||
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
|
||||
|
||||
try:
|
||||
import dns._features
|
||||
import dns.asyncbackend
|
||||
|
||||
if dns._features.have("doq"):
|
||||
import aioquic.quic.configuration # type: ignore
|
||||
|
||||
import dns.asyncbackend
|
||||
from dns._asyncbackend import NullContext
|
||||
from dns.quic._asyncio import (
|
||||
AsyncioQuicConnection,
|
||||
|
@ -17,7 +19,7 @@ try:
|
|||
|
||||
def null_factory(
|
||||
*args, # pylint: disable=unused-argument
|
||||
**kwargs # pylint: disable=unused-argument
|
||||
**kwargs, # pylint: disable=unused-argument
|
||||
):
|
||||
return NullContext(None)
|
||||
|
||||
|
@ -31,7 +33,7 @@ try:
|
|||
|
||||
_async_factories = {"asyncio": (null_factory, _asyncio_manager_factory)}
|
||||
|
||||
try:
|
||||
if dns._features.have("trio"):
|
||||
import trio
|
||||
|
||||
from dns.quic._trio import ( # pylint: disable=ungrouped-imports
|
||||
|
@ -47,15 +49,13 @@ try:
|
|||
return TrioQuicManager(context, *args, **kwargs)
|
||||
|
||||
_async_factories["trio"] = (_trio_context_factory, _trio_manager_factory)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
def factories_for_backend(backend=None):
|
||||
if backend is None:
|
||||
backend = dns.asyncbackend.get_default_backend()
|
||||
return _async_factories[backend.name()]
|
||||
|
||||
except ImportError:
|
||||
else: # pragma: no cover
|
||||
have_quic = False
|
||||
|
||||
from typing import Any
|
||||
|
|
|
@ -101,9 +101,7 @@ class AsyncioQuicConnection(AsyncQuicConnection):
|
|||
)
|
||||
if address[0] != self._peer[0] or address[1] != self._peer[1]:
|
||||
continue
|
||||
self._connection.receive_datagram(
|
||||
datagram, self._peer[0], time.time()
|
||||
)
|
||||
self._connection.receive_datagram(datagram, address, time.time())
|
||||
# Wake up the timer in case the sender is sleeping, as there may be
|
||||
# stuff to send now.
|
||||
async with self._wake_timer:
|
||||
|
@ -125,7 +123,7 @@ class AsyncioQuicConnection(AsyncQuicConnection):
|
|||
while not self._done:
|
||||
datagrams = self._connection.datagrams_to_send(time.time())
|
||||
for datagram, address in datagrams:
|
||||
assert address == self._peer[0]
|
||||
assert address == self._peer
|
||||
await self._socket.sendto(datagram, self._peer, None)
|
||||
(expiration, interval) = self._get_timer_values()
|
||||
try:
|
||||
|
@ -147,11 +145,14 @@ class AsyncioQuicConnection(AsyncQuicConnection):
|
|||
await stream._add_input(event.data, event.end_stream)
|
||||
elif isinstance(event, aioquic.quic.events.HandshakeCompleted):
|
||||
self._handshake_complete.set()
|
||||
elif isinstance(
|
||||
event, aioquic.quic.events.ConnectionTerminated
|
||||
) or isinstance(event, aioquic.quic.events.StreamReset):
|
||||
elif isinstance(event, aioquic.quic.events.ConnectionTerminated):
|
||||
self._done = True
|
||||
self._receiver_task.cancel()
|
||||
elif isinstance(event, aioquic.quic.events.StreamReset):
|
||||
stream = self._streams.get(event.stream_id)
|
||||
if stream:
|
||||
await stream._add_input(b"", True)
|
||||
|
||||
count += 1
|
||||
if count > 10:
|
||||
# yield
|
||||
|
@ -188,7 +189,6 @@ class AsyncioQuicConnection(AsyncQuicConnection):
|
|||
self._connection.close()
|
||||
# sender might be blocked on this, so set it
|
||||
self._socket_created.set()
|
||||
await self._socket.close()
|
||||
async with self._wake_timer:
|
||||
self._wake_timer.notify_all()
|
||||
try:
|
||||
|
@ -199,14 +199,19 @@ class AsyncioQuicConnection(AsyncQuicConnection):
|
|||
await self._sender_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
await self._socket.close()
|
||||
|
||||
|
||||
class AsyncioQuicManager(AsyncQuicManager):
|
||||
def __init__(self, conf=None, verify_mode=ssl.CERT_REQUIRED, server_name=None):
|
||||
super().__init__(conf, verify_mode, AsyncioQuicConnection, server_name)
|
||||
|
||||
def connect(self, address, port=853, source=None, source_port=0):
|
||||
(connection, start) = self._connect(address, port, source, source_port)
|
||||
def connect(
|
||||
self, address, port=853, source=None, source_port=0, want_session_ticket=True
|
||||
):
|
||||
(connection, start) = self._connect(
|
||||
address, port, source, source_port, want_session_ticket
|
||||
)
|
||||
if start:
|
||||
connection.run()
|
||||
return connection
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
|
||||
|
||||
import copy
|
||||
import functools
|
||||
import socket
|
||||
import struct
|
||||
import time
|
||||
|
@ -11,6 +13,10 @@ import aioquic.quic.connection # type: ignore
|
|||
import dns.inet
|
||||
|
||||
QUIC_MAX_DATAGRAM = 2048
|
||||
MAX_SESSION_TICKETS = 8
|
||||
# If we hit the max sessions limit we will delete this many of the oldest connections.
|
||||
# The value must be a integer > 0 and <= MAX_SESSION_TICKETS.
|
||||
SESSIONS_TO_DELETE = MAX_SESSION_TICKETS // 4
|
||||
|
||||
|
||||
class UnexpectedEOF(Exception):
|
||||
|
@ -79,7 +85,10 @@ class BaseQuicStream:
|
|||
|
||||
def _common_add_input(self, data, is_end):
|
||||
self._buffer.put(data, is_end)
|
||||
return self._expecting > 0 and self._buffer.have(self._expecting)
|
||||
try:
|
||||
return self._expecting > 0 and self._buffer.have(self._expecting)
|
||||
except UnexpectedEOF:
|
||||
return True
|
||||
|
||||
def _close(self):
|
||||
self._connection.close_stream(self._stream_id)
|
||||
|
@ -142,6 +151,7 @@ class BaseQuicManager:
|
|||
def __init__(self, conf, verify_mode, connection_factory, server_name=None):
|
||||
self._connections = {}
|
||||
self._connection_factory = connection_factory
|
||||
self._session_tickets = {}
|
||||
if conf is None:
|
||||
verify_path = None
|
||||
if isinstance(verify_mode, str):
|
||||
|
@ -156,12 +166,35 @@ class BaseQuicManager:
|
|||
conf.load_verify_locations(verify_path)
|
||||
self._conf = conf
|
||||
|
||||
def _connect(self, address, port=853, source=None, source_port=0):
|
||||
def _connect(
|
||||
self, address, port=853, source=None, source_port=0, want_session_ticket=True
|
||||
):
|
||||
connection = self._connections.get((address, port))
|
||||
if connection is not None:
|
||||
return (connection, False)
|
||||
qconn = aioquic.quic.connection.QuicConnection(configuration=self._conf)
|
||||
qconn.connect(address, time.time())
|
||||
conf = self._conf
|
||||
if want_session_ticket:
|
||||
try:
|
||||
session_ticket = self._session_tickets.pop((address, port))
|
||||
# We found a session ticket, so make a configuration that uses it.
|
||||
conf = copy.copy(conf)
|
||||
conf.session_ticket = session_ticket
|
||||
except KeyError:
|
||||
# No session ticket.
|
||||
pass
|
||||
# Whether or not we found a session ticket, we want a handler to save
|
||||
# one.
|
||||
session_ticket_handler = functools.partial(
|
||||
self.save_session_ticket, address, port
|
||||
)
|
||||
else:
|
||||
session_ticket_handler = None
|
||||
qconn = aioquic.quic.connection.QuicConnection(
|
||||
configuration=conf,
|
||||
session_ticket_handler=session_ticket_handler,
|
||||
)
|
||||
lladdress = dns.inet.low_level_address_tuple((address, port))
|
||||
qconn.connect(lladdress, time.time())
|
||||
connection = self._connection_factory(
|
||||
qconn, address, port, source, source_port, self
|
||||
)
|
||||
|
@ -174,6 +207,17 @@ class BaseQuicManager:
|
|||
except KeyError:
|
||||
pass
|
||||
|
||||
def save_session_ticket(self, address, port, ticket):
|
||||
# We rely on dictionaries keys() being in insertion order here. We
|
||||
# can't just popitem() as that would be LIFO which is the opposite of
|
||||
# what we want.
|
||||
l = len(self._session_tickets)
|
||||
if l >= MAX_SESSION_TICKETS:
|
||||
keys_to_delete = list(self._session_tickets.keys())[0:SESSIONS_TO_DELETE]
|
||||
for key in keys_to_delete:
|
||||
del self._session_tickets[key]
|
||||
self._session_tickets[(address, port)] = ticket
|
||||
|
||||
|
||||
class AsyncQuicManager(BaseQuicManager):
|
||||
def connect(self, address, port=853, source=None, source_port=0):
|
||||
|
|
|
@ -82,10 +82,6 @@ class SyncQuicConnection(BaseQuicConnection):
|
|||
def __init__(self, connection, address, port, source, source_port, manager):
|
||||
super().__init__(connection, address, port, source, source_port, manager)
|
||||
self._socket = socket.socket(self._af, socket.SOCK_DGRAM, 0)
|
||||
self._socket.connect(self._peer)
|
||||
(self._send_wakeup, self._receive_wakeup) = socket.socketpair()
|
||||
self._receive_wakeup.setblocking(False)
|
||||
self._socket.setblocking(False)
|
||||
if self._source is not None:
|
||||
try:
|
||||
self._socket.bind(
|
||||
|
@ -94,6 +90,10 @@ class SyncQuicConnection(BaseQuicConnection):
|
|||
except Exception:
|
||||
self._socket.close()
|
||||
raise
|
||||
self._socket.connect(self._peer)
|
||||
(self._send_wakeup, self._receive_wakeup) = socket.socketpair()
|
||||
self._receive_wakeup.setblocking(False)
|
||||
self._socket.setblocking(False)
|
||||
self._handshake_complete = threading.Event()
|
||||
self._worker_thread = None
|
||||
self._lock = threading.Lock()
|
||||
|
@ -107,7 +107,7 @@ class SyncQuicConnection(BaseQuicConnection):
|
|||
except BlockingIOError:
|
||||
return
|
||||
with self._lock:
|
||||
self._connection.receive_datagram(datagram, self._peer[0], time.time())
|
||||
self._connection.receive_datagram(datagram, self._peer, time.time())
|
||||
|
||||
def _drain_wakeup(self):
|
||||
while True:
|
||||
|
@ -128,6 +128,8 @@ class SyncQuicConnection(BaseQuicConnection):
|
|||
key.data()
|
||||
with self._lock:
|
||||
self._handle_timer(expiration)
|
||||
self._handle_events()
|
||||
with self._lock:
|
||||
datagrams = self._connection.datagrams_to_send(time.time())
|
||||
for datagram, _ in datagrams:
|
||||
try:
|
||||
|
@ -135,7 +137,6 @@ class SyncQuicConnection(BaseQuicConnection):
|
|||
except BlockingIOError:
|
||||
# we let QUIC handle any lossage
|
||||
pass
|
||||
self._handle_events()
|
||||
finally:
|
||||
with self._lock:
|
||||
self._done = True
|
||||
|
@ -155,11 +156,14 @@ class SyncQuicConnection(BaseQuicConnection):
|
|||
stream._add_input(event.data, event.end_stream)
|
||||
elif isinstance(event, aioquic.quic.events.HandshakeCompleted):
|
||||
self._handshake_complete.set()
|
||||
elif isinstance(
|
||||
event, aioquic.quic.events.ConnectionTerminated
|
||||
) or isinstance(event, aioquic.quic.events.StreamReset):
|
||||
elif isinstance(event, aioquic.quic.events.ConnectionTerminated):
|
||||
with self._lock:
|
||||
self._done = True
|
||||
elif isinstance(event, aioquic.quic.events.StreamReset):
|
||||
with self._lock:
|
||||
stream = self._streams.get(event.stream_id)
|
||||
if stream:
|
||||
stream._add_input(b"", True)
|
||||
|
||||
def write(self, stream, data, is_end=False):
|
||||
with self._lock:
|
||||
|
@ -203,9 +207,13 @@ class SyncQuicManager(BaseQuicManager):
|
|||
super().__init__(conf, verify_mode, SyncQuicConnection, server_name)
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def connect(self, address, port=853, source=None, source_port=0):
|
||||
def connect(
|
||||
self, address, port=853, source=None, source_port=0, want_session_ticket=True
|
||||
):
|
||||
with self._lock:
|
||||
(connection, start) = self._connect(address, port, source, source_port)
|
||||
(connection, start) = self._connect(
|
||||
address, port, source, source_port, want_session_ticket
|
||||
)
|
||||
if start:
|
||||
connection.run()
|
||||
return connection
|
||||
|
@ -214,6 +222,10 @@ class SyncQuicManager(BaseQuicManager):
|
|||
with self._lock:
|
||||
super().closed(address, port)
|
||||
|
||||
def save_session_ticket(self, address, port, ticket):
|
||||
with self._lock:
|
||||
super().save_session_ticket(address, port, ticket)
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
|
|
|
@ -76,30 +76,43 @@ class TrioQuicConnection(AsyncQuicConnection):
|
|||
def __init__(self, connection, address, port, source, source_port, manager=None):
|
||||
super().__init__(connection, address, port, source, source_port, manager)
|
||||
self._socket = trio.socket.socket(self._af, socket.SOCK_DGRAM, 0)
|
||||
if self._source:
|
||||
trio.socket.bind(dns.inet.low_level_address_tuple(self._source, self._af))
|
||||
self._handshake_complete = trio.Event()
|
||||
self._run_done = trio.Event()
|
||||
self._worker_scope = None
|
||||
self._send_pending = False
|
||||
|
||||
async def _worker(self):
|
||||
try:
|
||||
if self._source:
|
||||
await self._socket.bind(
|
||||
dns.inet.low_level_address_tuple(self._source, self._af)
|
||||
)
|
||||
await self._socket.connect(self._peer)
|
||||
while not self._done:
|
||||
(expiration, interval) = self._get_timer_values(False)
|
||||
if self._send_pending:
|
||||
# Do not block forever if sends are pending. Even though we
|
||||
# have a wake-up mechanism if we've already started the blocking
|
||||
# read, the possibility of context switching in send means that
|
||||
# more writes can happen while we have no wake up context, so
|
||||
# we need self._send_pending to avoid (effectively) a "lost wakeup"
|
||||
# race.
|
||||
interval = 0.0
|
||||
with trio.CancelScope(
|
||||
deadline=trio.current_time() + interval
|
||||
) as self._worker_scope:
|
||||
datagram = await self._socket.recv(QUIC_MAX_DATAGRAM)
|
||||
self._connection.receive_datagram(
|
||||
datagram, self._peer[0], time.time()
|
||||
)
|
||||
self._connection.receive_datagram(datagram, self._peer, time.time())
|
||||
self._worker_scope = None
|
||||
self._handle_timer(expiration)
|
||||
await self._handle_events()
|
||||
# We clear this now, before sending anything, as sending can cause
|
||||
# context switches that do more sends. We want to know if that
|
||||
# happens so we don't block a long time on the recv() above.
|
||||
self._send_pending = False
|
||||
datagrams = self._connection.datagrams_to_send(time.time())
|
||||
for datagram, _ in datagrams:
|
||||
await self._socket.send(datagram)
|
||||
await self._handle_events()
|
||||
finally:
|
||||
self._done = True
|
||||
self._handshake_complete.set()
|
||||
|
@ -116,11 +129,13 @@ class TrioQuicConnection(AsyncQuicConnection):
|
|||
await stream._add_input(event.data, event.end_stream)
|
||||
elif isinstance(event, aioquic.quic.events.HandshakeCompleted):
|
||||
self._handshake_complete.set()
|
||||
elif isinstance(
|
||||
event, aioquic.quic.events.ConnectionTerminated
|
||||
) or isinstance(event, aioquic.quic.events.StreamReset):
|
||||
elif isinstance(event, aioquic.quic.events.ConnectionTerminated):
|
||||
self._done = True
|
||||
self._socket.close()
|
||||
elif isinstance(event, aioquic.quic.events.StreamReset):
|
||||
stream = self._streams.get(event.stream_id)
|
||||
if stream:
|
||||
await stream._add_input(b"", True)
|
||||
count += 1
|
||||
if count > 10:
|
||||
# yield
|
||||
|
@ -129,6 +144,7 @@ class TrioQuicConnection(AsyncQuicConnection):
|
|||
|
||||
async def write(self, stream, data, is_end=False):
|
||||
self._connection.send_stream_data(stream, data, is_end)
|
||||
self._send_pending = True
|
||||
if self._worker_scope is not None:
|
||||
self._worker_scope.cancel()
|
||||
|
||||
|
@ -159,6 +175,7 @@ class TrioQuicConnection(AsyncQuicConnection):
|
|||
self._manager.closed(self._peer[0], self._peer[1])
|
||||
self._closed = True
|
||||
self._connection.close()
|
||||
self._send_pending = True
|
||||
if self._worker_scope is not None:
|
||||
self._worker_scope.cancel()
|
||||
await self._run_done.wait()
|
||||
|
@ -171,8 +188,12 @@ class TrioQuicManager(AsyncQuicManager):
|
|||
super().__init__(conf, verify_mode, TrioQuicConnection, server_name)
|
||||
self._nursery = nursery
|
||||
|
||||
def connect(self, address, port=853, source=None, source_port=0):
|
||||
(connection, start) = self._connect(address, port, source, source_port)
|
||||
def connect(
|
||||
self, address, port=853, source=None, source_port=0, want_session_ticket=True
|
||||
):
|
||||
(connection, start) = self._connect(
|
||||
address, port, source, source_port, want_session_ticket
|
||||
)
|
||||
if start:
|
||||
self._nursery.start_soon(connection.run)
|
||||
return connection
|
||||
|
|
|
@ -199,7 +199,7 @@ class Rdata:
|
|||
self,
|
||||
origin: Optional[dns.name.Name] = None,
|
||||
relativize: bool = True,
|
||||
**kw: Dict[str, Any]
|
||||
**kw: Dict[str, Any],
|
||||
) -> str:
|
||||
"""Convert an rdata to text format.
|
||||
|
||||
|
@ -547,9 +547,7 @@ class Rdata:
|
|||
@classmethod
|
||||
def _as_ipv4_address(cls, value):
|
||||
if isinstance(value, str):
|
||||
# call to check validity
|
||||
dns.ipv4.inet_aton(value)
|
||||
return value
|
||||
return dns.ipv4.canonicalize(value)
|
||||
elif isinstance(value, bytes):
|
||||
return dns.ipv4.inet_ntoa(value)
|
||||
else:
|
||||
|
@ -558,9 +556,7 @@ class Rdata:
|
|||
@classmethod
|
||||
def _as_ipv6_address(cls, value):
|
||||
if isinstance(value, str):
|
||||
# call to check validity
|
||||
dns.ipv6.inet_aton(value)
|
||||
return value
|
||||
return dns.ipv6.canonicalize(value)
|
||||
elif isinstance(value, bytes):
|
||||
return dns.ipv6.inet_ntoa(value)
|
||||
else:
|
||||
|
@ -604,7 +600,6 @@ class Rdata:
|
|||
|
||||
@dns.immutable.immutable
|
||||
class GenericRdata(Rdata):
|
||||
|
||||
"""Generic Rdata Class
|
||||
|
||||
This class is used for rdata types for which we have no better
|
||||
|
@ -621,7 +616,7 @@ class GenericRdata(Rdata):
|
|||
self,
|
||||
origin: Optional[dns.name.Name] = None,
|
||||
relativize: bool = True,
|
||||
**kw: Dict[str, Any]
|
||||
**kw: Dict[str, Any],
|
||||
) -> str:
|
||||
return r"\# %d " % len(self.data) + _hexify(self.data, **kw)
|
||||
|
||||
|
@ -647,9 +642,9 @@ class GenericRdata(Rdata):
|
|||
return cls(rdclass, rdtype, parser.get_remaining())
|
||||
|
||||
|
||||
_rdata_classes: Dict[
|
||||
Tuple[dns.rdataclass.RdataClass, dns.rdatatype.RdataType], Any
|
||||
] = {}
|
||||
_rdata_classes: Dict[Tuple[dns.rdataclass.RdataClass, dns.rdatatype.RdataType], Any] = (
|
||||
{}
|
||||
)
|
||||
_module_prefix = "dns.rdtypes"
|
||||
|
||||
|
||||
|
|
|
@ -28,6 +28,7 @@ import dns.name
|
|||
import dns.rdata
|
||||
import dns.rdataclass
|
||||
import dns.rdatatype
|
||||
import dns.renderer
|
||||
import dns.set
|
||||
import dns.ttl
|
||||
|
||||
|
@ -45,7 +46,6 @@ class IncompatibleTypes(dns.exception.DNSException):
|
|||
|
||||
|
||||
class Rdataset(dns.set.Set):
|
||||
|
||||
"""A DNS rdataset."""
|
||||
|
||||
__slots__ = ["rdclass", "rdtype", "covers", "ttl"]
|
||||
|
@ -316,11 +316,9 @@ class Rdataset(dns.set.Set):
|
|||
want_shuffle = False
|
||||
else:
|
||||
rdclass = self.rdclass
|
||||
file.seek(0, io.SEEK_END)
|
||||
if len(self) == 0:
|
||||
name.to_wire(file, compress, origin)
|
||||
stuff = struct.pack("!HHIH", self.rdtype, rdclass, 0, 0)
|
||||
file.write(stuff)
|
||||
file.write(struct.pack("!HHIH", self.rdtype, rdclass, 0, 0))
|
||||
return 1
|
||||
else:
|
||||
l: Union[Rdataset, List[dns.rdata.Rdata]]
|
||||
|
@ -331,16 +329,9 @@ class Rdataset(dns.set.Set):
|
|||
l = self
|
||||
for rd in l:
|
||||
name.to_wire(file, compress, origin)
|
||||
stuff = struct.pack("!HHIH", self.rdtype, rdclass, self.ttl, 0)
|
||||
file.write(stuff)
|
||||
start = file.tell()
|
||||
rd.to_wire(file, compress, origin)
|
||||
end = file.tell()
|
||||
assert end - start < 65536
|
||||
file.seek(start - 2)
|
||||
stuff = struct.pack("!H", end - start)
|
||||
file.write(stuff)
|
||||
file.seek(0, io.SEEK_END)
|
||||
file.write(struct.pack("!HHI", self.rdtype, rdclass, self.ttl))
|
||||
with dns.renderer.prefixed_length(file, 2):
|
||||
rd.to_wire(file, compress, origin)
|
||||
return len(self)
|
||||
|
||||
def match(
|
||||
|
@ -373,7 +364,6 @@ class Rdataset(dns.set.Set):
|
|||
|
||||
@dns.immutable.immutable
|
||||
class ImmutableRdataset(Rdataset): # lgtm[py/missing-equals]
|
||||
|
||||
"""An immutable DNS rdataset."""
|
||||
|
||||
_clone_class = Rdataset
|
||||
|
|
|
@ -21,7 +21,6 @@ import dns.rdtypes.mxbase
|
|||
|
||||
@dns.immutable.immutable
|
||||
class AFSDB(dns.rdtypes.mxbase.UncompressedDowncasingMX):
|
||||
|
||||
"""AFSDB record"""
|
||||
|
||||
# Use the property mechanism to make "subtype" an alias for the
|
||||
|
|
|
@ -32,7 +32,6 @@ class Relay(dns.rdtypes.util.Gateway):
|
|||
|
||||
@dns.immutable.immutable
|
||||
class AMTRELAY(dns.rdata.Rdata):
|
||||
|
||||
"""AMTRELAY record"""
|
||||
|
||||
# see: RFC 8777
|
||||
|
|
|
@ -21,7 +21,6 @@ import dns.rdtypes.txtbase
|
|||
|
||||
@dns.immutable.immutable
|
||||
class AVC(dns.rdtypes.txtbase.TXTBase):
|
||||
|
||||
"""AVC record"""
|
||||
|
||||
# See: IANA dns parameters for AVC
|
||||
|
|
|
@ -25,7 +25,6 @@ import dns.tokenizer
|
|||
|
||||
@dns.immutable.immutable
|
||||
class CAA(dns.rdata.Rdata):
|
||||
|
||||
"""CAA (Certification Authority Authorization) record"""
|
||||
|
||||
# see: RFC 6844
|
||||
|
|
|
@ -30,5 +30,4 @@ from dns.rdtypes.dnskeybase import ( # noqa: F401 lgtm[py/unused-import]
|
|||
|
||||
@dns.immutable.immutable
|
||||
class CDNSKEY(dns.rdtypes.dnskeybase.DNSKEYBase):
|
||||
|
||||
"""CDNSKEY record"""
|
||||
|
|
|
@ -21,7 +21,6 @@ import dns.rdtypes.dsbase
|
|||
|
||||
@dns.immutable.immutable
|
||||
class CDS(dns.rdtypes.dsbase.DSBase):
|
||||
|
||||
"""CDS record"""
|
||||
|
||||
_digest_length_by_type = {
|
||||
|
|
|
@ -67,7 +67,6 @@ def _ctype_to_text(what):
|
|||
|
||||
@dns.immutable.immutable
|
||||
class CERT(dns.rdata.Rdata):
|
||||
|
||||
"""CERT record"""
|
||||
|
||||
# see RFC 4398
|
||||
|
|
|
@ -21,7 +21,6 @@ import dns.rdtypes.nsbase
|
|||
|
||||
@dns.immutable.immutable
|
||||
class CNAME(dns.rdtypes.nsbase.NSBase):
|
||||
|
||||
"""CNAME record
|
||||
|
||||
Note: although CNAME is officially a singleton type, dnspython allows
|
||||
|
|
|
@ -32,7 +32,6 @@ class Bitmap(dns.rdtypes.util.Bitmap):
|
|||
|
||||
@dns.immutable.immutable
|
||||
class CSYNC(dns.rdata.Rdata):
|
||||
|
||||
"""CSYNC record"""
|
||||
|
||||
__slots__ = ["serial", "flags", "windows"]
|
||||
|
|
|
@ -21,5 +21,4 @@ import dns.rdtypes.dsbase
|
|||
|
||||
@dns.immutable.immutable
|
||||
class DLV(dns.rdtypes.dsbase.DSBase):
|
||||
|
||||
"""DLV record"""
|
||||
|
|
|
@ -21,7 +21,6 @@ import dns.rdtypes.nsbase
|
|||
|
||||
@dns.immutable.immutable
|
||||
class DNAME(dns.rdtypes.nsbase.UncompressedNS):
|
||||
|
||||
"""DNAME record"""
|
||||
|
||||
def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
|
||||
|
|
|
@ -30,5 +30,4 @@ from dns.rdtypes.dnskeybase import ( # noqa: F401 lgtm[py/unused-import]
|
|||
|
||||
@dns.immutable.immutable
|
||||
class DNSKEY(dns.rdtypes.dnskeybase.DNSKEYBase):
|
||||
|
||||
"""DNSKEY record"""
|
||||
|
|
|
@ -21,5 +21,4 @@ import dns.rdtypes.dsbase
|
|||
|
||||
@dns.immutable.immutable
|
||||
class DS(dns.rdtypes.dsbase.DSBase):
|
||||
|
||||
"""DS record"""
|
||||
|
|
|
@ -22,7 +22,6 @@ import dns.rdtypes.euibase
|
|||
|
||||
@dns.immutable.immutable
|
||||
class EUI48(dns.rdtypes.euibase.EUIBase):
|
||||
|
||||
"""EUI48 record"""
|
||||
|
||||
# see: rfc7043.txt
|
||||
|
|
|
@ -22,7 +22,6 @@ import dns.rdtypes.euibase
|
|||
|
||||
@dns.immutable.immutable
|
||||
class EUI64(dns.rdtypes.euibase.EUIBase):
|
||||
|
||||
"""EUI64 record"""
|
||||
|
||||
# see: rfc7043.txt
|
||||
|
|
|
@ -44,7 +44,6 @@ def _validate_float_string(what):
|
|||
|
||||
@dns.immutable.immutable
|
||||
class GPOS(dns.rdata.Rdata):
|
||||
|
||||
"""GPOS record"""
|
||||
|
||||
# see: RFC 1712
|
||||
|
|
|
@ -25,7 +25,6 @@ import dns.tokenizer
|
|||
|
||||
@dns.immutable.immutable
|
||||
class HINFO(dns.rdata.Rdata):
|
||||
|
||||
"""HINFO record"""
|
||||
|
||||
# see: RFC 1035
|
||||
|
|
|
@ -27,7 +27,6 @@ import dns.rdatatype
|
|||
|
||||
@dns.immutable.immutable
|
||||
class HIP(dns.rdata.Rdata):
|
||||
|
||||
"""HIP record"""
|
||||
|
||||
# see: RFC 5205
|
||||
|
|
|
@ -25,7 +25,6 @@ import dns.tokenizer
|
|||
|
||||
@dns.immutable.immutable
|
||||
class ISDN(dns.rdata.Rdata):
|
||||
|
||||
"""ISDN record"""
|
||||
|
||||
# see: RFC 1183
|
||||
|
|
|
@ -8,7 +8,6 @@ import dns.rdata
|
|||
|
||||
@dns.immutable.immutable
|
||||
class L32(dns.rdata.Rdata):
|
||||
|
||||
"""L32 record"""
|
||||
|
||||
# see: rfc6742.txt
|
||||
|
|
|
@ -8,7 +8,6 @@ import dns.rdtypes.util
|
|||
|
||||
@dns.immutable.immutable
|
||||
class L64(dns.rdata.Rdata):
|
||||
|
||||
"""L64 record"""
|
||||
|
||||
# see: rfc6742.txt
|
||||
|
|
|
@ -105,7 +105,6 @@ def _check_coordinate_list(value, low, high):
|
|||
|
||||
@dns.immutable.immutable
|
||||
class LOC(dns.rdata.Rdata):
|
||||
|
||||
"""LOC record"""
|
||||
|
||||
# see: RFC 1876
|
||||
|
|
|
@ -8,7 +8,6 @@ import dns.rdata
|
|||
|
||||
@dns.immutable.immutable
|
||||
class LP(dns.rdata.Rdata):
|
||||
|
||||
"""LP record"""
|
||||
|
||||
# see: rfc6742.txt
|
||||
|
|
|
@ -21,5 +21,4 @@ import dns.rdtypes.mxbase
|
|||
|
||||
@dns.immutable.immutable
|
||||
class MX(dns.rdtypes.mxbase.MXBase):
|
||||
|
||||
"""MX record"""
|
||||
|
|
|
@ -8,7 +8,6 @@ import dns.rdtypes.util
|
|||
|
||||
@dns.immutable.immutable
|
||||
class NID(dns.rdata.Rdata):
|
||||
|
||||
"""NID record"""
|
||||
|
||||
# see: rfc6742.txt
|
||||
|
|
|
@ -21,7 +21,6 @@ import dns.rdtypes.txtbase
|
|||
|
||||
@dns.immutable.immutable
|
||||
class NINFO(dns.rdtypes.txtbase.TXTBase):
|
||||
|
||||
"""NINFO record"""
|
||||
|
||||
# see: draft-reid-dnsext-zs-01
|
||||
|
|
|
@ -21,5 +21,4 @@ import dns.rdtypes.nsbase
|
|||
|
||||
@dns.immutable.immutable
|
||||
class NS(dns.rdtypes.nsbase.NSBase):
|
||||
|
||||
"""NS record"""
|
||||
|
|
|
@ -30,7 +30,6 @@ class Bitmap(dns.rdtypes.util.Bitmap):
|
|||
|
||||
@dns.immutable.immutable
|
||||
class NSEC(dns.rdata.Rdata):
|
||||
|
||||
"""NSEC record"""
|
||||
|
||||
__slots__ = ["next", "windows"]
|
||||
|
|
|
@ -46,7 +46,6 @@ class Bitmap(dns.rdtypes.util.Bitmap):
|
|||
|
||||
@dns.immutable.immutable
|
||||
class NSEC3(dns.rdata.Rdata):
|
||||
|
||||
"""NSEC3 record"""
|
||||
|
||||
__slots__ = ["algorithm", "flags", "iterations", "salt", "next", "windows"]
|
||||
|
@ -64,9 +63,13 @@ class NSEC3(dns.rdata.Rdata):
|
|||
windows = Bitmap(windows)
|
||||
self.windows = tuple(windows.windows)
|
||||
|
||||
def to_text(self, origin=None, relativize=True, **kw):
|
||||
def _next_text(self):
|
||||
next = base64.b32encode(self.next).translate(b32_normal_to_hex).lower().decode()
|
||||
next = next.rstrip("=")
|
||||
return next
|
||||
|
||||
def to_text(self, origin=None, relativize=True, **kw):
|
||||
next = self._next_text()
|
||||
if self.salt == b"":
|
||||
salt = "-"
|
||||
else:
|
||||
|
@ -118,3 +121,6 @@ class NSEC3(dns.rdata.Rdata):
|
|||
next = parser.get_counted_bytes()
|
||||
bitmap = Bitmap.from_wire_parser(parser)
|
||||
return cls(rdclass, rdtype, algorithm, flags, iterations, salt, next, bitmap)
|
||||
|
||||
def next_name(self, origin=None):
|
||||
return dns.name.from_text(self._next_text(), origin)
|
||||
|
|
|
@ -25,7 +25,6 @@ import dns.rdata
|
|||
|
||||
@dns.immutable.immutable
|
||||
class NSEC3PARAM(dns.rdata.Rdata):
|
||||
|
||||
"""NSEC3PARAM record"""
|
||||
|
||||
__slots__ = ["algorithm", "flags", "iterations", "salt"]
|
||||
|
|
|
@ -25,7 +25,6 @@ import dns.tokenizer
|
|||
|
||||
@dns.immutable.immutable
|
||||
class OPENPGPKEY(dns.rdata.Rdata):
|
||||
|
||||
"""OPENPGPKEY record"""
|
||||
|
||||
# see: RFC 7929
|
||||
|
|
|
@ -28,7 +28,6 @@ import dns.rdata
|
|||
|
||||
@dns.immutable.immutable
|
||||
class OPT(dns.rdata.Rdata):
|
||||
|
||||
"""OPT record"""
|
||||
|
||||
__slots__ = ["options"]
|
||||
|
|
|
@ -21,5 +21,4 @@ import dns.rdtypes.nsbase
|
|||
|
||||
@dns.immutable.immutable
|
||||
class PTR(dns.rdtypes.nsbase.NSBase):
|
||||
|
||||
"""PTR record"""
|
||||
|
|
|
@ -23,7 +23,6 @@ import dns.rdata
|
|||
|
||||
@dns.immutable.immutable
|
||||
class RP(dns.rdata.Rdata):
|
||||
|
||||
"""RP record"""
|
||||
|
||||
# see: RFC 1183
|
||||
|
|
|
@ -28,7 +28,6 @@ import dns.rdatatype
|
|||
|
||||
|
||||
class BadSigTime(dns.exception.DNSException):
|
||||
|
||||
"""Time in DNS SIG or RRSIG resource record cannot be parsed."""
|
||||
|
||||
|
||||
|
@ -52,7 +51,6 @@ def posixtime_to_sigtime(what):
|
|||
|
||||
@dns.immutable.immutable
|
||||
class RRSIG(dns.rdata.Rdata):
|
||||
|
||||
"""RRSIG record"""
|
||||
|
||||
__slots__ = [
|
||||
|
|
|
@ -21,5 +21,4 @@ import dns.rdtypes.mxbase
|
|||
|
||||
@dns.immutable.immutable
|
||||
class RT(dns.rdtypes.mxbase.UncompressedDowncasingMX):
|
||||
|
||||
"""RT record"""
|
||||
|
|
|
@ -25,7 +25,6 @@ import dns.rdata
|
|||
|
||||
@dns.immutable.immutable
|
||||
class SOA(dns.rdata.Rdata):
|
||||
|
||||
"""SOA record"""
|
||||
|
||||
# see: RFC 1035
|
||||
|
|
|
@ -21,7 +21,6 @@ import dns.rdtypes.txtbase
|
|||
|
||||
@dns.immutable.immutable
|
||||
class SPF(dns.rdtypes.txtbase.TXTBase):
|
||||
|
||||
"""SPF record"""
|
||||
|
||||
# see: RFC 4408
|
||||
|
|
|
@ -25,7 +25,6 @@ import dns.rdatatype
|
|||
|
||||
@dns.immutable.immutable
|
||||
class SSHFP(dns.rdata.Rdata):
|
||||
|
||||
"""SSHFP record"""
|
||||
|
||||
# See RFC 4255
|
||||
|
|
|
@ -25,7 +25,6 @@ import dns.rdata
|
|||
|
||||
@dns.immutable.immutable
|
||||
class TKEY(dns.rdata.Rdata):
|
||||
|
||||
"""TKEY Record"""
|
||||
|
||||
__slots__ = [
|
||||
|
|
|
@ -6,5 +6,4 @@ import dns.rdtypes.tlsabase
|
|||
|
||||
@dns.immutable.immutable
|
||||
class TLSA(dns.rdtypes.tlsabase.TLSABase):
|
||||
|
||||
"""TLSA record"""
|
||||
|
|
|
@ -26,7 +26,6 @@ import dns.rdata
|
|||
|
||||
@dns.immutable.immutable
|
||||
class TSIG(dns.rdata.Rdata):
|
||||
|
||||
"""TSIG record"""
|
||||
|
||||
__slots__ = [
|
||||
|
|
|
@ -21,5 +21,4 @@ import dns.rdtypes.txtbase
|
|||
|
||||
@dns.immutable.immutable
|
||||
class TXT(dns.rdtypes.txtbase.TXTBase):
|
||||
|
||||
"""TXT record"""
|
||||
|
|
|
@ -27,7 +27,6 @@ import dns.rdtypes.util
|
|||
|
||||
@dns.immutable.immutable
|
||||
class URI(dns.rdata.Rdata):
|
||||
|
||||
"""URI record"""
|
||||
|
||||
# see RFC 7553
|
||||
|
|
|
@ -25,7 +25,6 @@ import dns.tokenizer
|
|||
|
||||
@dns.immutable.immutable
|
||||
class X25(dns.rdata.Rdata):
|
||||
|
||||
"""X25 record"""
|
||||
|
||||
# see RFC 1183
|
||||
|
|
|
@ -11,7 +11,6 @@ import dns.zonetypes
|
|||
|
||||
@dns.immutable.immutable
|
||||
class ZONEMD(dns.rdata.Rdata):
|
||||
|
||||
"""ZONEMD record"""
|
||||
|
||||
# See RFC 8976
|
||||
|
|
|
@ -23,7 +23,6 @@ import dns.rdtypes.mxbase
|
|||
|
||||
@dns.immutable.immutable
|
||||
class A(dns.rdata.Rdata):
|
||||
|
||||
"""A record for Chaosnet"""
|
||||
|
||||
# domain: the domain of the address
|
||||
|
|
|
@ -24,7 +24,6 @@ import dns.tokenizer
|
|||
|
||||
@dns.immutable.immutable
|
||||
class A(dns.rdata.Rdata):
|
||||
|
||||
"""A record."""
|
||||
|
||||
__slots__ = ["address"]
|
||||
|
|
|
@ -24,7 +24,6 @@ import dns.tokenizer
|
|||
|
||||
@dns.immutable.immutable
|
||||
class AAAA(dns.rdata.Rdata):
|
||||
|
||||
"""AAAA record."""
|
||||
|
||||
__slots__ = ["address"]
|
||||
|
|
|
@ -29,7 +29,6 @@ import dns.tokenizer
|
|||
|
||||
@dns.immutable.immutable
|
||||
class APLItem:
|
||||
|
||||
"""An APL list item."""
|
||||
|
||||
__slots__ = ["family", "negation", "address", "prefix"]
|
||||
|
@ -80,7 +79,6 @@ class APLItem:
|
|||
|
||||
@dns.immutable.immutable
|
||||
class APL(dns.rdata.Rdata):
|
||||
|
||||
"""APL record."""
|
||||
|
||||
# see: RFC 3123
|
||||
|
|
|
@ -24,7 +24,6 @@ import dns.rdata
|
|||
|
||||
@dns.immutable.immutable
|
||||
class DHCID(dns.rdata.Rdata):
|
||||
|
||||
"""DHCID record"""
|
||||
|
||||
# see: RFC 4701
|
||||
|
|
|
@ -29,7 +29,6 @@ class Gateway(dns.rdtypes.util.Gateway):
|
|||
|
||||
@dns.immutable.immutable
|
||||
class IPSECKEY(dns.rdata.Rdata):
|
||||
|
||||
"""IPSECKEY record"""
|
||||
|
||||
# see: RFC 4025
|
||||
|
|
|
@ -21,5 +21,4 @@ import dns.rdtypes.mxbase
|
|||
|
||||
@dns.immutable.immutable
|
||||
class KX(dns.rdtypes.mxbase.UncompressedDowncasingMX):
|
||||
|
||||
"""KX record"""
|
||||
|
|
|
@ -33,7 +33,6 @@ def _write_string(file, s):
|
|||
|
||||
@dns.immutable.immutable
|
||||
class NAPTR(dns.rdata.Rdata):
|
||||
|
||||
"""NAPTR record"""
|
||||
|
||||
# see: RFC 3403
|
||||
|
|
|
@ -25,7 +25,6 @@ import dns.tokenizer
|
|||
|
||||
@dns.immutable.immutable
|
||||
class NSAP(dns.rdata.Rdata):
|
||||
|
||||
"""NSAP record."""
|
||||
|
||||
# see: RFC 1706
|
||||
|
|
|
@ -21,5 +21,4 @@ import dns.rdtypes.nsbase
|
|||
|
||||
@dns.immutable.immutable
|
||||
class NSAP_PTR(dns.rdtypes.nsbase.UncompressedNS):
|
||||
|
||||
"""NSAP-PTR record"""
|
||||
|
|
|
@ -26,7 +26,6 @@ import dns.rdtypes.util
|
|||
|
||||
@dns.immutable.immutable
|
||||
class PX(dns.rdata.Rdata):
|
||||
|
||||
"""PX record."""
|
||||
|
||||
# see: RFC 2163
|
||||
|
|
|
@ -26,7 +26,6 @@ import dns.rdtypes.util
|
|||
|
||||
@dns.immutable.immutable
|
||||
class SRV(dns.rdata.Rdata):
|
||||
|
||||
"""SRV record"""
|
||||
|
||||
# see: RFC 2782
|
||||
|
|
|
@ -33,7 +33,6 @@ except OSError:
|
|||
|
||||
@dns.immutable.immutable
|
||||
class WKS(dns.rdata.Rdata):
|
||||
|
||||
"""WKS record"""
|
||||
|
||||
# see: RFC 1035
|
||||
|
|
|
@ -36,7 +36,6 @@ class Flag(enum.IntFlag):
|
|||
|
||||
@dns.immutable.immutable
|
||||
class DNSKEYBase(dns.rdata.Rdata):
|
||||
|
||||
"""Base class for rdata that is like a DNSKEY record"""
|
||||
|
||||
__slots__ = ["flags", "protocol", "algorithm", "key"]
|
||||
|
|
|
@ -26,7 +26,6 @@ import dns.rdatatype
|
|||
|
||||
@dns.immutable.immutable
|
||||
class DSBase(dns.rdata.Rdata):
|
||||
|
||||
"""Base class for rdata that is like a DS record"""
|
||||
|
||||
__slots__ = ["key_tag", "algorithm", "digest_type", "digest"]
|
||||
|
|
|
@ -22,7 +22,6 @@ import dns.rdata
|
|||
|
||||
@dns.immutable.immutable
|
||||
class EUIBase(dns.rdata.Rdata):
|
||||
|
||||
"""EUIxx record"""
|
||||
|
||||
# see: rfc7043.txt
|
||||
|
|
|
@ -28,7 +28,6 @@ import dns.rdtypes.util
|
|||
|
||||
@dns.immutable.immutable
|
||||
class MXBase(dns.rdata.Rdata):
|
||||
|
||||
"""Base class for rdata that is like an MX record."""
|
||||
|
||||
__slots__ = ["preference", "exchange"]
|
||||
|
@ -71,7 +70,6 @@ class MXBase(dns.rdata.Rdata):
|
|||
|
||||
@dns.immutable.immutable
|
||||
class UncompressedMX(MXBase):
|
||||
|
||||
"""Base class for rdata that is like an MX record, but whose name
|
||||
is not compressed when converted to DNS wire format, and whose
|
||||
digestable form is not downcased."""
|
||||
|
@ -82,7 +80,6 @@ class UncompressedMX(MXBase):
|
|||
|
||||
@dns.immutable.immutable
|
||||
class UncompressedDowncasingMX(MXBase):
|
||||
|
||||
"""Base class for rdata that is like an MX record, but whose name
|
||||
is not compressed when convert to DNS wire format."""
|
||||
|
||||
|
|
|
@ -25,7 +25,6 @@ import dns.rdata
|
|||
|
||||
@dns.immutable.immutable
|
||||
class NSBase(dns.rdata.Rdata):
|
||||
|
||||
"""Base class for rdata that is like an NS record."""
|
||||
|
||||
__slots__ = ["target"]
|
||||
|
@ -56,7 +55,6 @@ class NSBase(dns.rdata.Rdata):
|
|||
|
||||
@dns.immutable.immutable
|
||||
class UncompressedNS(NSBase):
|
||||
|
||||
"""Base class for rdata that is like an NS record, but whose name
|
||||
is not compressed when convert to DNS wire format, and whose
|
||||
digestable form is not downcased."""
|
||||
|
|
|
@ -2,7 +2,6 @@
|
|||
|
||||
import base64
|
||||
import enum
|
||||
import io
|
||||
import struct
|
||||
|
||||
import dns.enum
|
||||
|
@ -13,6 +12,7 @@ import dns.ipv6
|
|||
import dns.name
|
||||
import dns.rdata
|
||||
import dns.rdtypes.util
|
||||
import dns.renderer
|
||||
import dns.tokenizer
|
||||
import dns.wire
|
||||
|
||||
|
@ -427,7 +427,6 @@ def _validate_and_define(params, key, value):
|
|||
|
||||
@dns.immutable.immutable
|
||||
class SVCBBase(dns.rdata.Rdata):
|
||||
|
||||
"""Base class for SVCB-like records"""
|
||||
|
||||
# see: draft-ietf-dnsop-svcb-https-11
|
||||
|
@ -521,19 +520,10 @@ class SVCBBase(dns.rdata.Rdata):
|
|||
for key in sorted(self.params):
|
||||
file.write(struct.pack("!H", key))
|
||||
value = self.params[key]
|
||||
# placeholder for length (or actual length of empty values)
|
||||
file.write(struct.pack("!H", 0))
|
||||
if value is None:
|
||||
continue
|
||||
else:
|
||||
start = file.tell()
|
||||
value.to_wire(file, origin)
|
||||
end = file.tell()
|
||||
assert end - start < 65536
|
||||
file.seek(start - 2)
|
||||
stuff = struct.pack("!H", end - start)
|
||||
file.write(stuff)
|
||||
file.seek(0, io.SEEK_END)
|
||||
with dns.renderer.prefixed_length(file, 2):
|
||||
# Note that we're still writing a length of zero if the value is None
|
||||
if value is not None:
|
||||
value.to_wire(file, origin)
|
||||
|
||||
@classmethod
|
||||
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
|
||||
|
|
|
@ -25,7 +25,6 @@ import dns.rdatatype
|
|||
|
||||
@dns.immutable.immutable
|
||||
class TLSABase(dns.rdata.Rdata):
|
||||
|
||||
"""Base class for TLSA and SMIMEA records"""
|
||||
|
||||
# see: RFC 6698
|
||||
|
|
|
@ -17,18 +17,17 @@
|
|||
|
||||
"""TXT-like base class."""
|
||||
|
||||
import struct
|
||||
from typing import Any, Dict, Iterable, Optional, Tuple, Union
|
||||
|
||||
import dns.exception
|
||||
import dns.immutable
|
||||
import dns.rdata
|
||||
import dns.renderer
|
||||
import dns.tokenizer
|
||||
|
||||
|
||||
@dns.immutable.immutable
|
||||
class TXTBase(dns.rdata.Rdata):
|
||||
|
||||
"""Base class for rdata that is like a TXT record (see RFC 1035)."""
|
||||
|
||||
__slots__ = ["strings"]
|
||||
|
@ -56,7 +55,7 @@ class TXTBase(dns.rdata.Rdata):
|
|||
self,
|
||||
origin: Optional[dns.name.Name] = None,
|
||||
relativize: bool = True,
|
||||
**kw: Dict[str, Any]
|
||||
**kw: Dict[str, Any],
|
||||
) -> str:
|
||||
txt = ""
|
||||
prefix = ""
|
||||
|
@ -93,10 +92,8 @@ class TXTBase(dns.rdata.Rdata):
|
|||
|
||||
def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
|
||||
for s in self.strings:
|
||||
l = len(s)
|
||||
assert l < 256
|
||||
file.write(struct.pack("!B", l))
|
||||
file.write(s)
|
||||
with dns.renderer.prefixed_length(file, 1):
|
||||
file.write(s)
|
||||
|
||||
@classmethod
|
||||
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
|
||||
|
|
|
@ -32,6 +32,24 @@ AUTHORITY = 2
|
|||
ADDITIONAL = 3
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def prefixed_length(output, length_length):
|
||||
output.write(b"\00" * length_length)
|
||||
start = output.tell()
|
||||
yield
|
||||
end = output.tell()
|
||||
length = end - start
|
||||
if length > 0:
|
||||
try:
|
||||
output.seek(start - length_length)
|
||||
try:
|
||||
output.write(length.to_bytes(length_length, "big"))
|
||||
except OverflowError:
|
||||
raise dns.exception.FormError
|
||||
finally:
|
||||
output.seek(end)
|
||||
|
||||
|
||||
class Renderer:
|
||||
"""Helper class for building DNS wire-format messages.
|
||||
|
||||
|
@ -134,6 +152,15 @@ class Renderer:
|
|||
self._rollback(start)
|
||||
raise dns.exception.TooBig
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _temporarily_seek_to(self, where):
|
||||
current = self.output.tell()
|
||||
try:
|
||||
self.output.seek(where)
|
||||
yield
|
||||
finally:
|
||||
self.output.seek(current)
|
||||
|
||||
def add_question(self, qname, rdtype, rdclass=dns.rdataclass.IN):
|
||||
"""Add a question to the message."""
|
||||
|
||||
|
@ -269,18 +296,14 @@ class Renderer:
|
|||
with self._track_size():
|
||||
keyname.to_wire(self.output, compress, self.origin)
|
||||
self.output.write(
|
||||
struct.pack("!HHIH", dns.rdatatype.TSIG, dns.rdataclass.ANY, 0, 0)
|
||||
struct.pack("!HHI", dns.rdatatype.TSIG, dns.rdataclass.ANY, 0)
|
||||
)
|
||||
rdata_start = self.output.tell()
|
||||
tsig.to_wire(self.output)
|
||||
with prefixed_length(self.output, 2):
|
||||
tsig.to_wire(self.output)
|
||||
|
||||
after = self.output.tell()
|
||||
self.output.seek(rdata_start - 2)
|
||||
self.output.write(struct.pack("!H", after - rdata_start))
|
||||
self.counts[ADDITIONAL] += 1
|
||||
self.output.seek(10)
|
||||
self.output.write(struct.pack("!H", self.counts[ADDITIONAL]))
|
||||
self.output.seek(0, io.SEEK_END)
|
||||
with self._temporarily_seek_to(10):
|
||||
self.output.write(struct.pack("!H", self.counts[ADDITIONAL]))
|
||||
|
||||
def write_header(self):
|
||||
"""Write the DNS message header.
|
||||
|
@ -290,19 +313,18 @@ class Renderer:
|
|||
is added.
|
||||
"""
|
||||
|
||||
self.output.seek(0)
|
||||
self.output.write(
|
||||
struct.pack(
|
||||
"!HHHHHH",
|
||||
self.id,
|
||||
self.flags,
|
||||
self.counts[0],
|
||||
self.counts[1],
|
||||
self.counts[2],
|
||||
self.counts[3],
|
||||
with self._temporarily_seek_to(0):
|
||||
self.output.write(
|
||||
struct.pack(
|
||||
"!HHHHHH",
|
||||
self.id,
|
||||
self.flags,
|
||||
self.counts[0],
|
||||
self.counts[1],
|
||||
self.counts[2],
|
||||
self.counts[3],
|
||||
)
|
||||
)
|
||||
)
|
||||
self.output.seek(0, io.SEEK_END)
|
||||
|
||||
def get_wire(self):
|
||||
"""Return the wire format message."""
|
||||
|
|
|
@ -26,7 +26,6 @@ import dns.renderer
|
|||
|
||||
|
||||
class RRset(dns.rdataset.Rdataset):
|
||||
|
||||
"""A DNS RRset (named rdataset).
|
||||
|
||||
RRset inherits from Rdataset, and RRsets can be treated as
|
||||
|
@ -132,7 +131,7 @@ class RRset(dns.rdataset.Rdataset):
|
|||
self,
|
||||
origin: Optional[dns.name.Name] = None,
|
||||
relativize: bool = True,
|
||||
**kw: Dict[str, Any]
|
||||
**kw: Dict[str, Any],
|
||||
) -> str:
|
||||
"""Convert the RRset into DNS zone file format.
|
||||
|
||||
|
@ -159,7 +158,7 @@ class RRset(dns.rdataset.Rdataset):
|
|||
file: Any,
|
||||
compress: Optional[dns.name.CompressType] = None, # type: ignore
|
||||
origin: Optional[dns.name.Name] = None,
|
||||
**kw: Dict[str, Any]
|
||||
**kw: Dict[str, Any],
|
||||
) -> int:
|
||||
"""Convert the RRset to wire format.
|
||||
|
||||
|
@ -231,7 +230,7 @@ def from_text(
|
|||
ttl: int,
|
||||
rdclass: Union[dns.rdataclass.RdataClass, str],
|
||||
rdtype: Union[dns.rdatatype.RdataType, str],
|
||||
*text_rdatas: Any
|
||||
*text_rdatas: Any,
|
||||
) -> RRset:
|
||||
"""Create an RRset with the specified name, TTL, class, and type and with
|
||||
the specified rdatas in text format.
|
||||
|
|
|
@ -19,7 +19,6 @@ import itertools
|
|||
|
||||
|
||||
class Set:
|
||||
|
||||
"""A simple set class.
|
||||
|
||||
This class was originally used to deal with sets being missing in
|
||||
|
|
|
@ -203,7 +203,7 @@ class Transaction:
|
|||
|
||||
- name
|
||||
|
||||
- name, rdataclass, rdatatype, [covers]
|
||||
- name, rdatatype, [covers]
|
||||
|
||||
- name, rdataset...
|
||||
|
||||
|
@ -222,7 +222,7 @@ class Transaction:
|
|||
|
||||
- name
|
||||
|
||||
- name, rdataclass, rdatatype, [covers]
|
||||
- name, rdatatype, [covers]
|
||||
|
||||
- name, rdataset...
|
||||
|
||||
|
|
|
@ -29,47 +29,38 @@ import dns.rdataclass
|
|||
|
||||
|
||||
class BadTime(dns.exception.DNSException):
|
||||
|
||||
"""The current time is not within the TSIG's validity time."""
|
||||
|
||||
|
||||
class BadSignature(dns.exception.DNSException):
|
||||
|
||||
"""The TSIG signature fails to verify."""
|
||||
|
||||
|
||||
class BadKey(dns.exception.DNSException):
|
||||
|
||||
"""The TSIG record owner name does not match the key."""
|
||||
|
||||
|
||||
class BadAlgorithm(dns.exception.DNSException):
|
||||
|
||||
"""The TSIG algorithm does not match the key."""
|
||||
|
||||
|
||||
class PeerError(dns.exception.DNSException):
|
||||
|
||||
"""Base class for all TSIG errors generated by the remote peer"""
|
||||
|
||||
|
||||
class PeerBadKey(PeerError):
|
||||
|
||||
"""The peer didn't know the key we used"""
|
||||
|
||||
|
||||
class PeerBadSignature(PeerError):
|
||||
|
||||
"""The peer didn't like the signature we sent"""
|
||||
|
||||
|
||||
class PeerBadTime(PeerError):
|
||||
|
||||
"""The peer didn't like the time we sent"""
|
||||
|
||||
|
||||
class PeerBadTruncation(PeerError):
|
||||
|
||||
"""The peer didn't like amount of truncation in the TSIG we sent"""
|
||||
|
||||
|
||||
|
|
|
@ -20,9 +20,9 @@
|
|||
#: MAJOR
|
||||
MAJOR = 2
|
||||
#: MINOR
|
||||
MINOR = 4
|
||||
MINOR = 6
|
||||
#: MICRO
|
||||
MICRO = 2
|
||||
MICRO = 1
|
||||
#: RELEASELEVEL
|
||||
RELEASELEVEL = 0x0F
|
||||
#: SERIAL
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
import sys
|
||||
|
||||
import dns._features
|
||||
|
||||
if sys.platform == "win32":
|
||||
from typing import Any
|
||||
|
||||
|
@ -15,14 +17,14 @@ if sys.platform == "win32":
|
|||
except KeyError:
|
||||
WindowsError = Exception
|
||||
|
||||
try:
|
||||
if dns._features.have("wmi"):
|
||||
import threading
|
||||
|
||||
import pythoncom # pylint: disable=import-error
|
||||
import wmi # pylint: disable=import-error
|
||||
|
||||
_have_wmi = True
|
||||
except Exception:
|
||||
else:
|
||||
_have_wmi = False
|
||||
|
||||
def _config_domain(domain):
|
||||
|
@ -51,9 +53,10 @@ if sys.platform == "win32":
|
|||
try:
|
||||
system = wmi.WMI()
|
||||
for interface in system.Win32_NetworkAdapterConfiguration():
|
||||
if interface.IPEnabled and interface.DNSDomain:
|
||||
self.info.domain = _config_domain(interface.DNSDomain)
|
||||
if interface.IPEnabled and interface.DNSServerSearchOrder:
|
||||
self.info.nameservers = list(interface.DNSServerSearchOrder)
|
||||
if interface.DNSDomain:
|
||||
self.info.domain = _config_domain(interface.DNSDomain)
|
||||
if interface.DNSDomainSuffixSearchOrder:
|
||||
self.info.search = [
|
||||
_config_domain(x)
|
||||
|
|
203
lib/dns/zone.py
203
lib/dns/zone.py
|
@ -21,7 +21,18 @@ import contextlib
|
|||
import io
|
||||
import os
|
||||
import struct
|
||||
from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Tuple, Union
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Iterable,
|
||||
Iterator,
|
||||
List,
|
||||
MutableMapping,
|
||||
Optional,
|
||||
Set,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
|
||||
import dns.exception
|
||||
import dns.grange
|
||||
|
@ -43,47 +54,70 @@ from dns.zonetypes import DigestHashAlgorithm, DigestScheme, _digest_hashers
|
|||
|
||||
|
||||
class BadZone(dns.exception.DNSException):
|
||||
|
||||
"""The DNS zone is malformed."""
|
||||
|
||||
|
||||
class NoSOA(BadZone):
|
||||
|
||||
"""The DNS zone has no SOA RR at its origin."""
|
||||
|
||||
|
||||
class NoNS(BadZone):
|
||||
|
||||
"""The DNS zone has no NS RRset at its origin."""
|
||||
|
||||
|
||||
class UnknownOrigin(BadZone):
|
||||
|
||||
"""The DNS zone's origin is unknown."""
|
||||
|
||||
|
||||
class UnsupportedDigestScheme(dns.exception.DNSException):
|
||||
|
||||
"""The zone digest's scheme is unsupported."""
|
||||
|
||||
|
||||
class UnsupportedDigestHashAlgorithm(dns.exception.DNSException):
|
||||
|
||||
"""The zone digest's origin is unsupported."""
|
||||
|
||||
|
||||
class NoDigest(dns.exception.DNSException):
|
||||
|
||||
"""The DNS zone has no ZONEMD RRset at its origin."""
|
||||
|
||||
|
||||
class DigestVerificationFailure(dns.exception.DNSException):
|
||||
|
||||
"""The ZONEMD digest failed to verify."""
|
||||
|
||||
|
||||
class Zone(dns.transaction.TransactionManager):
|
||||
def _validate_name(
|
||||
name: dns.name.Name,
|
||||
origin: Optional[dns.name.Name],
|
||||
relativize: bool,
|
||||
) -> dns.name.Name:
|
||||
# This name validation code is shared by Zone and Version
|
||||
if origin is None:
|
||||
# This should probably never happen as other code (e.g.
|
||||
# _rr_line) will notice the lack of an origin before us, but
|
||||
# we check just in case!
|
||||
raise KeyError("no zone origin is defined")
|
||||
if name.is_absolute():
|
||||
if not name.is_subdomain(origin):
|
||||
raise KeyError("name parameter must be a subdomain of the zone origin")
|
||||
if relativize:
|
||||
name = name.relativize(origin)
|
||||
else:
|
||||
# We have a relative name. Make sure that the derelativized name is
|
||||
# not too long.
|
||||
try:
|
||||
abs_name = name.derelativize(origin)
|
||||
except dns.name.NameTooLong:
|
||||
# We map dns.name.NameTooLong to KeyError to be consistent with
|
||||
# the other exceptions above.
|
||||
raise KeyError("relative name too long for zone")
|
||||
if not relativize:
|
||||
# We have a relative name in a non-relative zone, so use the
|
||||
# derelativized name.
|
||||
name = abs_name
|
||||
return name
|
||||
|
||||
|
||||
class Zone(dns.transaction.TransactionManager):
|
||||
"""A DNS zone.
|
||||
|
||||
A ``Zone`` is a mapping from names to nodes. The zone object may be
|
||||
|
@ -94,7 +128,10 @@ class Zone(dns.transaction.TransactionManager):
|
|||
the zone.
|
||||
"""
|
||||
|
||||
node_factory = dns.node.Node
|
||||
node_factory: Callable[[], dns.node.Node] = dns.node.Node
|
||||
map_factory: Callable[[], MutableMapping[dns.name.Name, dns.node.Node]] = dict
|
||||
writable_version_factory: Optional[Callable[[], "WritableVersion"]] = None
|
||||
immutable_version_factory: Optional[Callable[[], "ImmutableVersion"]] = None
|
||||
|
||||
__slots__ = ["rdclass", "origin", "nodes", "relativize"]
|
||||
|
||||
|
@ -125,7 +162,7 @@ class Zone(dns.transaction.TransactionManager):
|
|||
raise ValueError("origin parameter must be an absolute name")
|
||||
self.origin = origin
|
||||
self.rdclass = rdclass
|
||||
self.nodes: Dict[dns.name.Name, dns.node.Node] = {}
|
||||
self.nodes: MutableMapping[dns.name.Name, dns.node.Node] = self.map_factory()
|
||||
self.relativize = relativize
|
||||
|
||||
def __eq__(self, other):
|
||||
|
@ -154,26 +191,13 @@ class Zone(dns.transaction.TransactionManager):
|
|||
return not self.__eq__(other)
|
||||
|
||||
def _validate_name(self, name: Union[dns.name.Name, str]) -> dns.name.Name:
|
||||
# Note that any changes in this method should have corresponding changes
|
||||
# made in the Version _validate_name() method.
|
||||
if isinstance(name, str):
|
||||
name = dns.name.from_text(name, None)
|
||||
elif not isinstance(name, dns.name.Name):
|
||||
raise KeyError("name parameter must be convertible to a DNS name")
|
||||
if name.is_absolute():
|
||||
if self.origin is None:
|
||||
# This should probably never happen as other code (e.g.
|
||||
# _rr_line) will notice the lack of an origin before us, but
|
||||
# we check just in case!
|
||||
raise KeyError("no zone origin is defined")
|
||||
if not name.is_subdomain(self.origin):
|
||||
raise KeyError("name parameter must be a subdomain of the zone origin")
|
||||
if self.relativize:
|
||||
name = name.relativize(self.origin)
|
||||
elif not self.relativize:
|
||||
# We have a relative name in a non-relative zone, so derelativize.
|
||||
if self.origin is None:
|
||||
raise KeyError("no zone origin is defined")
|
||||
name = name.derelativize(self.origin)
|
||||
return name
|
||||
return _validate_name(name, self.origin, self.relativize)
|
||||
|
||||
def __getitem__(self, key):
|
||||
key = self._validate_name(key)
|
||||
|
@ -252,9 +276,6 @@ class Zone(dns.transaction.TransactionManager):
|
|||
*create*, a ``bool``. If true, the node will be created if it does
|
||||
not exist.
|
||||
|
||||
Raises ``KeyError`` if the name is not known and create was
|
||||
not specified, or if the name was not a subdomain of the origin.
|
||||
|
||||
Returns a ``dns.node.Node`` or ``None``.
|
||||
"""
|
||||
|
||||
|
@ -527,9 +548,6 @@ class Zone(dns.transaction.TransactionManager):
|
|||
*create*, a ``bool``. If true, the node will be created if it does
|
||||
not exist.
|
||||
|
||||
Raises ``KeyError`` if the name is not known and create was
|
||||
not specified, or if the name was not a subdomain of the origin.
|
||||
|
||||
Returns a ``dns.rrset.RRset`` or ``None``.
|
||||
"""
|
||||
|
||||
|
@ -952,7 +970,7 @@ class Version:
|
|||
self,
|
||||
zone: Zone,
|
||||
id: int,
|
||||
nodes: Optional[Dict[dns.name.Name, dns.node.Node]] = None,
|
||||
nodes: Optional[MutableMapping[dns.name.Name, dns.node.Node]] = None,
|
||||
origin: Optional[dns.name.Name] = None,
|
||||
):
|
||||
self.zone = zone
|
||||
|
@ -960,26 +978,11 @@ class Version:
|
|||
if nodes is not None:
|
||||
self.nodes = nodes
|
||||
else:
|
||||
self.nodes = {}
|
||||
self.nodes = zone.map_factory()
|
||||
self.origin = origin
|
||||
|
||||
def _validate_name(self, name: dns.name.Name) -> dns.name.Name:
|
||||
if name.is_absolute():
|
||||
if self.origin is None:
|
||||
# This should probably never happen as other code (e.g.
|
||||
# _rr_line) will notice the lack of an origin before us, but
|
||||
# we check just in case!
|
||||
raise KeyError("no zone origin is defined")
|
||||
if not name.is_subdomain(self.origin):
|
||||
raise KeyError("name is not a subdomain of the zone origin")
|
||||
if self.zone.relativize:
|
||||
name = name.relativize(self.origin)
|
||||
elif not self.zone.relativize:
|
||||
# We have a relative name in a non-relative zone, so derelativize.
|
||||
if self.origin is None:
|
||||
raise KeyError("no zone origin is defined")
|
||||
name = name.derelativize(self.origin)
|
||||
return name
|
||||
return _validate_name(name, self.origin, self.zone.relativize)
|
||||
|
||||
def get_node(self, name: dns.name.Name) -> Optional[dns.node.Node]:
|
||||
name = self._validate_name(name)
|
||||
|
@ -1085,7 +1088,9 @@ class ImmutableVersion(Version):
|
|||
version.nodes[name] = ImmutableVersionedNode(node)
|
||||
# We're changing the type of the nodes dictionary here on purpose, so
|
||||
# we ignore the mypy error.
|
||||
self.nodes = dns.immutable.Dict(version.nodes, True) # type: ignore
|
||||
self.nodes = dns.immutable.Dict(
|
||||
version.nodes, True, self.zone.map_factory
|
||||
) # type: ignore
|
||||
|
||||
|
||||
class Transaction(dns.transaction.Transaction):
|
||||
|
@ -1101,7 +1106,10 @@ class Transaction(dns.transaction.Transaction):
|
|||
|
||||
def _setup_version(self):
|
||||
assert self.version is None
|
||||
self.version = WritableVersion(self.zone, self.replacement)
|
||||
factory = self.manager.writable_version_factory
|
||||
if factory is None:
|
||||
factory = WritableVersion
|
||||
self.version = factory(self.zone, self.replacement)
|
||||
|
||||
def _get_rdataset(self, name, rdtype, covers):
|
||||
return self.version.get_rdataset(name, rdtype, covers)
|
||||
|
@ -1132,7 +1140,10 @@ class Transaction(dns.transaction.Transaction):
|
|||
self.zone._end_read(self)
|
||||
elif commit and len(self.version.changed) > 0:
|
||||
if self.make_immutable:
|
||||
version = ImmutableVersion(self.version)
|
||||
factory = self.manager.immutable_version_factory
|
||||
if factory is None:
|
||||
factory = ImmutableVersion
|
||||
version = factory(self.version)
|
||||
else:
|
||||
version = self.version
|
||||
self.zone._commit_version(self, version, self.version.origin)
|
||||
|
@ -1168,6 +1179,48 @@ class Transaction(dns.transaction.Transaction):
|
|||
return (absolute, relativize, effective)
|
||||
|
||||
|
||||
def _from_text(
|
||||
text: Any,
|
||||
origin: Optional[Union[dns.name.Name, str]] = None,
|
||||
rdclass: dns.rdataclass.RdataClass = dns.rdataclass.IN,
|
||||
relativize: bool = True,
|
||||
zone_factory: Any = Zone,
|
||||
filename: Optional[str] = None,
|
||||
allow_include: bool = False,
|
||||
check_origin: bool = True,
|
||||
idna_codec: Optional[dns.name.IDNACodec] = None,
|
||||
allow_directives: Union[bool, Iterable[str]] = True,
|
||||
) -> Zone:
|
||||
# See the comments for the public APIs from_text() and from_file() for
|
||||
# details.
|
||||
|
||||
# 'text' can also be a file, but we don't publish that fact
|
||||
# since it's an implementation detail. The official file
|
||||
# interface is from_file().
|
||||
|
||||
if filename is None:
|
||||
filename = "<string>"
|
||||
zone = zone_factory(origin, rdclass, relativize=relativize)
|
||||
with zone.writer(True) as txn:
|
||||
tok = dns.tokenizer.Tokenizer(text, filename, idna_codec=idna_codec)
|
||||
reader = dns.zonefile.Reader(
|
||||
tok,
|
||||
rdclass,
|
||||
txn,
|
||||
allow_include=allow_include,
|
||||
allow_directives=allow_directives,
|
||||
)
|
||||
try:
|
||||
reader.read()
|
||||
except dns.zonefile.UnknownOrigin:
|
||||
# for backwards compatibility
|
||||
raise dns.zone.UnknownOrigin
|
||||
# Now that we're done reading, do some basic checking of the zone.
|
||||
if check_origin:
|
||||
zone.check_origin()
|
||||
return zone
|
||||
|
||||
|
||||
def from_text(
|
||||
text: str,
|
||||
origin: Optional[Union[dns.name.Name, str]] = None,
|
||||
|
@ -1228,32 +1281,18 @@ def from_text(
|
|||
|
||||
Returns a subclass of ``dns.zone.Zone``.
|
||||
"""
|
||||
|
||||
# 'text' can also be a file, but we don't publish that fact
|
||||
# since it's an implementation detail. The official file
|
||||
# interface is from_file().
|
||||
|
||||
if filename is None:
|
||||
filename = "<string>"
|
||||
zone = zone_factory(origin, rdclass, relativize=relativize)
|
||||
with zone.writer(True) as txn:
|
||||
tok = dns.tokenizer.Tokenizer(text, filename, idna_codec=idna_codec)
|
||||
reader = dns.zonefile.Reader(
|
||||
tok,
|
||||
rdclass,
|
||||
txn,
|
||||
allow_include=allow_include,
|
||||
allow_directives=allow_directives,
|
||||
)
|
||||
try:
|
||||
reader.read()
|
||||
except dns.zonefile.UnknownOrigin:
|
||||
# for backwards compatibility
|
||||
raise dns.zone.UnknownOrigin
|
||||
# Now that we're done reading, do some basic checking of the zone.
|
||||
if check_origin:
|
||||
zone.check_origin()
|
||||
return zone
|
||||
return _from_text(
|
||||
text,
|
||||
origin,
|
||||
rdclass,
|
||||
relativize,
|
||||
zone_factory,
|
||||
filename,
|
||||
allow_include,
|
||||
check_origin,
|
||||
idna_codec,
|
||||
allow_directives,
|
||||
)
|
||||
|
||||
|
||||
def from_file(
|
||||
|
@ -1324,7 +1363,7 @@ def from_file(
|
|||
else:
|
||||
cm = contextlib.nullcontext(f)
|
||||
with cm as f:
|
||||
return from_text(
|
||||
return _from_text(
|
||||
f,
|
||||
origin,
|
||||
rdclass,
|
||||
|
|
|
@ -86,7 +86,6 @@ def _upper_dollarize(s):
|
|||
|
||||
|
||||
class Reader:
|
||||
|
||||
"""Read a DNS zone file into a transaction."""
|
||||
|
||||
def __init__(
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue