mirror of
https://github.com/Tautulli/Tautulli.git
synced 2025-07-06 13:11:15 -07:00
Bump dnspython from 2.6.1 to 2.7.0 (#2440)
* Bump dnspython from 2.6.1 to 2.7.0 Bumps [dnspython](https://github.com/rthalley/dnspython) from 2.6.1 to 2.7.0. - [Release notes](https://github.com/rthalley/dnspython/releases) - [Changelog](https://github.com/rthalley/dnspython/blob/main/doc/whatsnew.rst) - [Commits](https://github.com/rthalley/dnspython/compare/v2.6.1...v2.7.0) --- updated-dependencies: - dependency-name: dnspython dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] <support@github.com> * Update dnspython==2.7.0 --------- Signed-off-by: dependabot[bot] <support@github.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: JonnyWong16 <9099342+JonnyWong16@users.noreply.github.com> [skip ci]
This commit is contained in:
parent
0836fb902c
commit
feca713b76
56 changed files with 1382 additions and 665 deletions
545
lib/dns/query.py
545
lib/dns/query.py
|
@ -23,11 +23,13 @@ import enum
|
|||
import errno
|
||||
import os
|
||||
import os.path
|
||||
import random
|
||||
import selectors
|
||||
import socket
|
||||
import struct
|
||||
import time
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
import urllib.parse
|
||||
from typing import Any, Dict, Optional, Tuple, Union, cast
|
||||
|
||||
import dns._features
|
||||
import dns.exception
|
||||
|
@ -129,7 +131,7 @@ if _have_httpx:
|
|||
family=socket.AF_UNSPEC,
|
||||
**kwargs,
|
||||
):
|
||||
if resolver is None:
|
||||
if resolver is None and bootstrap_address is None:
|
||||
# pylint: disable=import-outside-toplevel,redefined-outer-name
|
||||
import dns.resolver
|
||||
|
||||
|
@ -217,7 +219,7 @@ def _wait_for(fd, readable, writable, _, expiration):
|
|||
|
||||
if readable and isinstance(fd, ssl.SSLSocket) and fd.pending() > 0:
|
||||
return True
|
||||
sel = _selector_class()
|
||||
sel = selectors.DefaultSelector()
|
||||
events = 0
|
||||
if readable:
|
||||
events |= selectors.EVENT_READ
|
||||
|
@ -235,26 +237,6 @@ def _wait_for(fd, readable, writable, _, expiration):
|
|||
raise dns.exception.Timeout
|
||||
|
||||
|
||||
def _set_selector_class(selector_class):
|
||||
# Internal API. Do not use.
|
||||
|
||||
global _selector_class
|
||||
|
||||
_selector_class = selector_class
|
||||
|
||||
|
||||
if hasattr(selectors, "PollSelector"):
|
||||
# Prefer poll() on platforms that support it because it has no
|
||||
# limits on the maximum value of a file descriptor (plus it will
|
||||
# be more efficient for high values).
|
||||
#
|
||||
# We ignore typing here as we can't say _selector_class is Any
|
||||
# on python < 3.8 due to a bug.
|
||||
_selector_class = selectors.PollSelector # type: ignore
|
||||
else:
|
||||
_selector_class = selectors.SelectSelector # type: ignore
|
||||
|
||||
|
||||
def _wait_for_readable(s, expiration):
|
||||
_wait_for(s, True, False, True, expiration)
|
||||
|
||||
|
@ -355,6 +337,36 @@ def _make_socket(af, type, source, ssl_context=None, server_hostname=None):
|
|||
raise
|
||||
|
||||
|
||||
def _maybe_get_resolver(
|
||||
resolver: Optional["dns.resolver.Resolver"],
|
||||
) -> "dns.resolver.Resolver":
|
||||
# We need a separate method for this to avoid overriding the global
|
||||
# variable "dns" with the as-yet undefined local variable "dns"
|
||||
# in https().
|
||||
if resolver is None:
|
||||
# pylint: disable=import-outside-toplevel,redefined-outer-name
|
||||
import dns.resolver
|
||||
|
||||
resolver = dns.resolver.Resolver()
|
||||
return resolver
|
||||
|
||||
|
||||
class HTTPVersion(enum.IntEnum):
|
||||
"""Which version of HTTP should be used?
|
||||
|
||||
DEFAULT will select the first version from the list [2, 1.1, 3] that
|
||||
is available.
|
||||
"""
|
||||
|
||||
DEFAULT = 0
|
||||
HTTP_1 = 1
|
||||
H1 = 1
|
||||
HTTP_2 = 2
|
||||
H2 = 2
|
||||
HTTP_3 = 3
|
||||
H3 = 3
|
||||
|
||||
|
||||
def https(
|
||||
q: dns.message.Message,
|
||||
where: str,
|
||||
|
@ -370,7 +382,8 @@ def https(
|
|||
bootstrap_address: Optional[str] = None,
|
||||
verify: Union[bool, str] = True,
|
||||
resolver: Optional["dns.resolver.Resolver"] = None,
|
||||
family: Optional[int] = socket.AF_UNSPEC,
|
||||
family: int = socket.AF_UNSPEC,
|
||||
http_version: HTTPVersion = HTTPVersion.DEFAULT,
|
||||
) -> dns.message.Message:
|
||||
"""Return the response obtained after sending a query via DNS-over-HTTPS.
|
||||
|
||||
|
@ -420,27 +433,66 @@ def https(
|
|||
*family*, an ``int``, the address family. If socket.AF_UNSPEC (the default), both A
|
||||
and AAAA records will be retrieved.
|
||||
|
||||
*http_version*, a ``dns.query.HTTPVersion``, indicating which HTTP version to use.
|
||||
|
||||
Returns a ``dns.message.Message``.
|
||||
"""
|
||||
|
||||
(af, _, the_source) = _destination_and_source(
|
||||
where, port, source, source_port, False
|
||||
)
|
||||
if af is not None and dns.inet.is_address(where):
|
||||
if af == socket.AF_INET:
|
||||
url = f"https://{where}:{port}{path}"
|
||||
elif af == socket.AF_INET6:
|
||||
url = f"https://[{where}]:{port}{path}"
|
||||
else:
|
||||
url = where
|
||||
|
||||
extensions = {}
|
||||
if bootstrap_address is None:
|
||||
# pylint: disable=possibly-used-before-assignment
|
||||
parsed = urllib.parse.urlparse(url)
|
||||
if parsed.hostname is None:
|
||||
raise ValueError("no hostname in URL")
|
||||
if dns.inet.is_address(parsed.hostname):
|
||||
bootstrap_address = parsed.hostname
|
||||
extensions["sni_hostname"] = parsed.hostname
|
||||
if parsed.port is not None:
|
||||
port = parsed.port
|
||||
|
||||
if http_version == HTTPVersion.H3 or (
|
||||
http_version == HTTPVersion.DEFAULT and not have_doh
|
||||
):
|
||||
if bootstrap_address is None:
|
||||
resolver = _maybe_get_resolver(resolver)
|
||||
assert parsed.hostname is not None # for mypy
|
||||
answers = resolver.resolve_name(parsed.hostname, family)
|
||||
bootstrap_address = random.choice(list(answers.addresses()))
|
||||
return _http3(
|
||||
q,
|
||||
bootstrap_address,
|
||||
url,
|
||||
timeout,
|
||||
port,
|
||||
source,
|
||||
source_port,
|
||||
one_rr_per_rrset,
|
||||
ignore_trailing,
|
||||
verify=verify,
|
||||
post=post,
|
||||
)
|
||||
|
||||
if not have_doh:
|
||||
raise NoDOH # pragma: no cover
|
||||
if session and not isinstance(session, httpx.Client):
|
||||
raise ValueError("session parameter must be an httpx.Client")
|
||||
|
||||
wire = q.to_wire()
|
||||
(af, _, the_source) = _destination_and_source(
|
||||
where, port, source, source_port, False
|
||||
)
|
||||
transport = None
|
||||
headers = {"accept": "application/dns-message"}
|
||||
if af is not None and dns.inet.is_address(where):
|
||||
if af == socket.AF_INET:
|
||||
url = "https://{}:{}{}".format(where, port, path)
|
||||
elif af == socket.AF_INET6:
|
||||
url = "https://[{}]:{}{}".format(where, port, path)
|
||||
else:
|
||||
url = where
|
||||
|
||||
h1 = http_version in (HTTPVersion.H1, HTTPVersion.DEFAULT)
|
||||
h2 = http_version in (HTTPVersion.H2, HTTPVersion.DEFAULT)
|
||||
|
||||
# set source port and source address
|
||||
|
||||
|
@ -450,21 +502,22 @@ def https(
|
|||
else:
|
||||
local_address = the_source[0]
|
||||
local_port = the_source[1]
|
||||
transport = _HTTPTransport(
|
||||
local_address=local_address,
|
||||
http1=True,
|
||||
http2=True,
|
||||
verify=verify,
|
||||
local_port=local_port,
|
||||
bootstrap_address=bootstrap_address,
|
||||
resolver=resolver,
|
||||
family=family,
|
||||
)
|
||||
|
||||
if session:
|
||||
cm: contextlib.AbstractContextManager = contextlib.nullcontext(session)
|
||||
else:
|
||||
cm = httpx.Client(http1=True, http2=True, verify=verify, transport=transport)
|
||||
transport = _HTTPTransport(
|
||||
local_address=local_address,
|
||||
http1=h1,
|
||||
http2=h2,
|
||||
verify=verify,
|
||||
local_port=local_port,
|
||||
bootstrap_address=bootstrap_address,
|
||||
resolver=resolver,
|
||||
family=family,
|
||||
)
|
||||
|
||||
cm = httpx.Client(http1=h1, http2=h2, verify=verify, transport=transport)
|
||||
with cm as session:
|
||||
# see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH
|
||||
# GET and POST examples
|
||||
|
@ -475,20 +528,30 @@ def https(
|
|||
"content-length": str(len(wire)),
|
||||
}
|
||||
)
|
||||
response = session.post(url, headers=headers, content=wire, timeout=timeout)
|
||||
response = session.post(
|
||||
url,
|
||||
headers=headers,
|
||||
content=wire,
|
||||
timeout=timeout,
|
||||
extensions=extensions,
|
||||
)
|
||||
else:
|
||||
wire = base64.urlsafe_b64encode(wire).rstrip(b"=")
|
||||
twire = wire.decode() # httpx does a repr() if we give it bytes
|
||||
response = session.get(
|
||||
url, headers=headers, timeout=timeout, params={"dns": twire}
|
||||
url,
|
||||
headers=headers,
|
||||
timeout=timeout,
|
||||
params={"dns": twire},
|
||||
extensions=extensions,
|
||||
)
|
||||
|
||||
# see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH
|
||||
# status codes
|
||||
if response.status_code < 200 or response.status_code > 299:
|
||||
raise ValueError(
|
||||
"{} responded with status code {}"
|
||||
"\nResponse body: {}".format(where, response.status_code, response.content)
|
||||
f"{where} responded with status code {response.status_code}"
|
||||
f"\nResponse body: {response.content}"
|
||||
)
|
||||
r = dns.message.from_wire(
|
||||
response.content,
|
||||
|
@ -503,6 +566,81 @@ def https(
|
|||
return r
|
||||
|
||||
|
||||
def _find_header(headers: dns.quic.Headers, name: bytes) -> bytes:
|
||||
if headers is None:
|
||||
raise KeyError
|
||||
for header, value in headers:
|
||||
if header == name:
|
||||
return value
|
||||
raise KeyError
|
||||
|
||||
|
||||
def _check_status(headers: dns.quic.Headers, peer: str, wire: bytes) -> None:
|
||||
value = _find_header(headers, b":status")
|
||||
if value is None:
|
||||
raise SyntaxError("no :status header in response")
|
||||
status = int(value)
|
||||
if status < 0:
|
||||
raise SyntaxError("status is negative")
|
||||
if status < 200 or status > 299:
|
||||
error = ""
|
||||
if len(wire) > 0:
|
||||
try:
|
||||
error = ": " + wire.decode()
|
||||
except Exception:
|
||||
pass
|
||||
raise ValueError(f"{peer} responded with status code {status}{error}")
|
||||
|
||||
|
||||
def _http3(
|
||||
q: dns.message.Message,
|
||||
where: str,
|
||||
url: str,
|
||||
timeout: Optional[float] = None,
|
||||
port: int = 853,
|
||||
source: Optional[str] = None,
|
||||
source_port: int = 0,
|
||||
one_rr_per_rrset: bool = False,
|
||||
ignore_trailing: bool = False,
|
||||
verify: Union[bool, str] = True,
|
||||
hostname: Optional[str] = None,
|
||||
post: bool = True,
|
||||
) -> dns.message.Message:
|
||||
if not dns.quic.have_quic:
|
||||
raise NoDOH("DNS-over-HTTP3 is not available.") # pragma: no cover
|
||||
|
||||
url_parts = urllib.parse.urlparse(url)
|
||||
hostname = url_parts.hostname
|
||||
if url_parts.port is not None:
|
||||
port = url_parts.port
|
||||
|
||||
q.id = 0
|
||||
wire = q.to_wire()
|
||||
manager = dns.quic.SyncQuicManager(
|
||||
verify_mode=verify, server_name=hostname, h3=True
|
||||
)
|
||||
|
||||
with manager:
|
||||
connection = manager.connect(where, port, source, source_port)
|
||||
(start, expiration) = _compute_times(timeout)
|
||||
with connection.make_stream(timeout) as stream:
|
||||
stream.send_h3(url, wire, post)
|
||||
wire = stream.receive(_remaining(expiration))
|
||||
_check_status(stream.headers(), where, wire)
|
||||
finish = time.time()
|
||||
r = dns.message.from_wire(
|
||||
wire,
|
||||
keyring=q.keyring,
|
||||
request_mac=q.request_mac,
|
||||
one_rr_per_rrset=one_rr_per_rrset,
|
||||
ignore_trailing=ignore_trailing,
|
||||
)
|
||||
r.time = max(finish - start, 0.0)
|
||||
if not q.is_response(r):
|
||||
raise BadResponse
|
||||
return r
|
||||
|
||||
|
||||
def _udp_recv(sock, max_size, expiration):
|
||||
"""Reads a datagram from the socket.
|
||||
A Timeout exception will be raised if the operation is not completed
|
||||
|
@ -855,7 +993,7 @@ def _net_read(sock, count, expiration):
|
|||
try:
|
||||
n = sock.recv(count)
|
||||
if n == b"":
|
||||
raise EOFError
|
||||
raise EOFError("EOF")
|
||||
count -= len(n)
|
||||
s += n
|
||||
except (BlockingIOError, ssl.SSLWantReadError):
|
||||
|
@ -1023,6 +1161,7 @@ def tcp(
|
|||
cm = _make_socket(af, socket.SOCK_STREAM, source)
|
||||
with cm as s:
|
||||
if not sock:
|
||||
# pylint: disable=possibly-used-before-assignment
|
||||
_connect(s, destination, expiration)
|
||||
send_tcp(s, wire, expiration)
|
||||
(r, received_time) = receive_tcp(
|
||||
|
@ -1188,6 +1327,7 @@ def quic(
|
|||
ignore_trailing: bool = False,
|
||||
connection: Optional[dns.quic.SyncQuicConnection] = None,
|
||||
verify: Union[bool, str] = True,
|
||||
hostname: Optional[str] = None,
|
||||
server_hostname: Optional[str] = None,
|
||||
) -> dns.message.Message:
|
||||
"""Return the response obtained after sending a query via DNS-over-QUIC.
|
||||
|
@ -1212,17 +1352,21 @@ def quic(
|
|||
*ignore_trailing*, a ``bool``. If ``True``, ignore trailing junk at end of the
|
||||
received message.
|
||||
|
||||
*connection*, a ``dns.quic.SyncQuicConnection``. If provided, the
|
||||
connection to use to send the query.
|
||||
*connection*, a ``dns.quic.SyncQuicConnection``. If provided, the connection to use
|
||||
to send the query.
|
||||
|
||||
*verify*, a ``bool`` or ``str``. If a ``True``, then TLS certificate verification
|
||||
of the server is done using the default CA bundle; if ``False``, then no
|
||||
verification is done; if a `str` then it specifies the path to a certificate file or
|
||||
directory which will be used for verification.
|
||||
|
||||
*server_hostname*, a ``str`` containing the server's hostname. The
|
||||
default is ``None``, which means that no hostname is known, and if an
|
||||
SSL context is created, hostname checking will be disabled.
|
||||
*hostname*, a ``str`` containing the server's hostname or ``None``. The default is
|
||||
``None``, which means that no hostname is known, and if an SSL context is created,
|
||||
hostname checking will be disabled. This value is ignored if *url* is not
|
||||
``None``.
|
||||
|
||||
*server_hostname*, a ``str`` or ``None``. This item is for backwards compatibility
|
||||
only, and has the same meaning as *hostname*.
|
||||
|
||||
Returns a ``dns.message.Message``.
|
||||
"""
|
||||
|
@ -1230,6 +1374,9 @@ def quic(
|
|||
if not dns.quic.have_quic:
|
||||
raise NoDOQ("DNS-over-QUIC is not available.") # pragma: no cover
|
||||
|
||||
if server_hostname is not None and hostname is None:
|
||||
hostname = server_hostname
|
||||
|
||||
q.id = 0
|
||||
wire = q.to_wire()
|
||||
the_connection: dns.quic.SyncQuicConnection
|
||||
|
@ -1238,9 +1385,7 @@ def quic(
|
|||
manager: contextlib.AbstractContextManager = contextlib.nullcontext(None)
|
||||
the_connection = connection
|
||||
else:
|
||||
manager = dns.quic.SyncQuicManager(
|
||||
verify_mode=verify, server_name=server_hostname
|
||||
)
|
||||
manager = dns.quic.SyncQuicManager(verify_mode=verify, server_name=hostname)
|
||||
the_manager = manager # for type checking happiness
|
||||
|
||||
with manager:
|
||||
|
@ -1264,6 +1409,70 @@ def quic(
|
|||
return r
|
||||
|
||||
|
||||
class UDPMode(enum.IntEnum):
|
||||
"""How should UDP be used in an IXFR from :py:func:`inbound_xfr()`?
|
||||
|
||||
NEVER means "never use UDP; always use TCP"
|
||||
TRY_FIRST means "try to use UDP but fall back to TCP if needed"
|
||||
ONLY means "raise ``dns.xfr.UseTCP`` if trying UDP does not succeed"
|
||||
"""
|
||||
|
||||
NEVER = 0
|
||||
TRY_FIRST = 1
|
||||
ONLY = 2
|
||||
|
||||
|
||||
def _inbound_xfr(
|
||||
txn_manager: dns.transaction.TransactionManager,
|
||||
s: socket.socket,
|
||||
query: dns.message.Message,
|
||||
serial: Optional[int],
|
||||
timeout: Optional[float],
|
||||
expiration: float,
|
||||
) -> Any:
|
||||
"""Given a socket, does the zone transfer."""
|
||||
rdtype = query.question[0].rdtype
|
||||
is_ixfr = rdtype == dns.rdatatype.IXFR
|
||||
origin = txn_manager.from_wire_origin()
|
||||
wire = query.to_wire()
|
||||
is_udp = s.type == socket.SOCK_DGRAM
|
||||
if is_udp:
|
||||
_udp_send(s, wire, None, expiration)
|
||||
else:
|
||||
tcpmsg = struct.pack("!H", len(wire)) + wire
|
||||
_net_write(s, tcpmsg, expiration)
|
||||
with dns.xfr.Inbound(txn_manager, rdtype, serial, is_udp) as inbound:
|
||||
done = False
|
||||
tsig_ctx = None
|
||||
while not done:
|
||||
(_, mexpiration) = _compute_times(timeout)
|
||||
if mexpiration is None or (
|
||||
expiration is not None and mexpiration > expiration
|
||||
):
|
||||
mexpiration = expiration
|
||||
if is_udp:
|
||||
(rwire, _) = _udp_recv(s, 65535, mexpiration)
|
||||
else:
|
||||
ldata = _net_read(s, 2, mexpiration)
|
||||
(l,) = struct.unpack("!H", ldata)
|
||||
rwire = _net_read(s, l, mexpiration)
|
||||
r = dns.message.from_wire(
|
||||
rwire,
|
||||
keyring=query.keyring,
|
||||
request_mac=query.mac,
|
||||
xfr=True,
|
||||
origin=origin,
|
||||
tsig_ctx=tsig_ctx,
|
||||
multi=(not is_udp),
|
||||
one_rr_per_rrset=is_ixfr,
|
||||
)
|
||||
done = inbound.process_message(r)
|
||||
yield r
|
||||
tsig_ctx = r.tsig_ctx
|
||||
if query.keyring and not r.had_tsig:
|
||||
raise dns.exception.FormError("missing TSIG")
|
||||
|
||||
|
||||
def xfr(
|
||||
where: str,
|
||||
zone: Union[dns.name.Name, str],
|
||||
|
@ -1333,134 +1542,52 @@ def xfr(
|
|||
Returns a generator of ``dns.message.Message`` objects.
|
||||
"""
|
||||
|
||||
class DummyTransactionManager(dns.transaction.TransactionManager):
|
||||
def __init__(self, origin, relativize):
|
||||
self.info = (origin, relativize, dns.name.empty if relativize else origin)
|
||||
|
||||
def origin_information(self):
|
||||
return self.info
|
||||
|
||||
def get_class(self) -> dns.rdataclass.RdataClass:
|
||||
raise NotImplementedError # pragma: no cover
|
||||
|
||||
def reader(self):
|
||||
raise NotImplementedError # pragma: no cover
|
||||
|
||||
def writer(self, replacement: bool = False) -> dns.transaction.Transaction:
|
||||
class DummyTransaction:
|
||||
def nop(self, *args, **kw):
|
||||
pass
|
||||
|
||||
def __getattr__(self, _):
|
||||
return self.nop
|
||||
|
||||
return cast(dns.transaction.Transaction, DummyTransaction())
|
||||
|
||||
if isinstance(zone, str):
|
||||
zone = dns.name.from_text(zone)
|
||||
rdtype = dns.rdatatype.RdataType.make(rdtype)
|
||||
q = dns.message.make_query(zone, rdtype, rdclass)
|
||||
if rdtype == dns.rdatatype.IXFR:
|
||||
rrset = dns.rrset.from_text(zone, 0, "IN", "SOA", ". . %u 0 0 0 0" % serial)
|
||||
q.authority.append(rrset)
|
||||
rrset = q.find_rrset(
|
||||
q.authority, zone, dns.rdataclass.IN, dns.rdatatype.SOA, create=True
|
||||
)
|
||||
soa = dns.rdata.from_text("IN", "SOA", ". . %u 0 0 0 0" % serial)
|
||||
rrset.add(soa, 0)
|
||||
if keyring is not None:
|
||||
q.use_tsig(keyring, keyname, algorithm=keyalgorithm)
|
||||
wire = q.to_wire()
|
||||
(af, destination, source) = _destination_and_source(
|
||||
where, port, source, source_port
|
||||
)
|
||||
(_, expiration) = _compute_times(lifetime)
|
||||
tm = DummyTransactionManager(zone, relativize)
|
||||
if use_udp and rdtype != dns.rdatatype.IXFR:
|
||||
raise ValueError("cannot do a UDP AXFR")
|
||||
sock_type = socket.SOCK_DGRAM if use_udp else socket.SOCK_STREAM
|
||||
with _make_socket(af, sock_type, source) as s:
|
||||
(_, expiration) = _compute_times(lifetime)
|
||||
_connect(s, destination, expiration)
|
||||
l = len(wire)
|
||||
if use_udp:
|
||||
_udp_send(s, wire, None, expiration)
|
||||
else:
|
||||
tcpmsg = struct.pack("!H", l) + wire
|
||||
_net_write(s, tcpmsg, expiration)
|
||||
done = False
|
||||
delete_mode = True
|
||||
expecting_SOA = False
|
||||
soa_rrset = None
|
||||
if relativize:
|
||||
origin = zone
|
||||
oname = dns.name.empty
|
||||
else:
|
||||
origin = None
|
||||
oname = zone
|
||||
tsig_ctx = None
|
||||
while not done:
|
||||
(_, mexpiration) = _compute_times(timeout)
|
||||
if mexpiration is None or (
|
||||
expiration is not None and mexpiration > expiration
|
||||
):
|
||||
mexpiration = expiration
|
||||
if use_udp:
|
||||
(wire, _) = _udp_recv(s, 65535, mexpiration)
|
||||
else:
|
||||
ldata = _net_read(s, 2, mexpiration)
|
||||
(l,) = struct.unpack("!H", ldata)
|
||||
wire = _net_read(s, l, mexpiration)
|
||||
is_ixfr = rdtype == dns.rdatatype.IXFR
|
||||
r = dns.message.from_wire(
|
||||
wire,
|
||||
keyring=q.keyring,
|
||||
request_mac=q.mac,
|
||||
xfr=True,
|
||||
origin=origin,
|
||||
tsig_ctx=tsig_ctx,
|
||||
multi=True,
|
||||
one_rr_per_rrset=is_ixfr,
|
||||
)
|
||||
rcode = r.rcode()
|
||||
if rcode != dns.rcode.NOERROR:
|
||||
raise TransferError(rcode)
|
||||
tsig_ctx = r.tsig_ctx
|
||||
answer_index = 0
|
||||
if soa_rrset is None:
|
||||
if not r.answer or r.answer[0].name != oname:
|
||||
raise dns.exception.FormError("No answer or RRset not for qname")
|
||||
rrset = r.answer[0]
|
||||
if rrset.rdtype != dns.rdatatype.SOA:
|
||||
raise dns.exception.FormError("first RRset is not an SOA")
|
||||
answer_index = 1
|
||||
soa_rrset = rrset.copy()
|
||||
if rdtype == dns.rdatatype.IXFR:
|
||||
if dns.serial.Serial(soa_rrset[0].serial) <= serial:
|
||||
#
|
||||
# We're already up-to-date.
|
||||
#
|
||||
done = True
|
||||
else:
|
||||
expecting_SOA = True
|
||||
#
|
||||
# Process SOAs in the answer section (other than the initial
|
||||
# SOA in the first message).
|
||||
#
|
||||
for rrset in r.answer[answer_index:]:
|
||||
if done:
|
||||
raise dns.exception.FormError("answers after final SOA")
|
||||
if rrset.rdtype == dns.rdatatype.SOA and rrset.name == oname:
|
||||
if expecting_SOA:
|
||||
if rrset[0].serial != serial:
|
||||
raise dns.exception.FormError("IXFR base serial mismatch")
|
||||
expecting_SOA = False
|
||||
elif rdtype == dns.rdatatype.IXFR:
|
||||
delete_mode = not delete_mode
|
||||
#
|
||||
# If this SOA RRset is equal to the first we saw then we're
|
||||
# finished. If this is an IXFR we also check that we're
|
||||
# seeing the record in the expected part of the response.
|
||||
#
|
||||
if rrset == soa_rrset and (
|
||||
rdtype == dns.rdatatype.AXFR
|
||||
or (rdtype == dns.rdatatype.IXFR and delete_mode)
|
||||
):
|
||||
done = True
|
||||
elif expecting_SOA:
|
||||
#
|
||||
# We made an IXFR request and are expecting another
|
||||
# SOA RR, but saw something else, so this must be an
|
||||
# AXFR response.
|
||||
#
|
||||
rdtype = dns.rdatatype.AXFR
|
||||
expecting_SOA = False
|
||||
if done and q.keyring and not r.had_tsig:
|
||||
raise dns.exception.FormError("missing TSIG")
|
||||
yield r
|
||||
|
||||
|
||||
class UDPMode(enum.IntEnum):
|
||||
"""How should UDP be used in an IXFR from :py:func:`inbound_xfr()`?
|
||||
|
||||
NEVER means "never use UDP; always use TCP"
|
||||
TRY_FIRST means "try to use UDP but fall back to TCP if needed"
|
||||
ONLY means "raise ``dns.xfr.UseTCP`` if trying UDP does not succeed"
|
||||
"""
|
||||
|
||||
NEVER = 0
|
||||
TRY_FIRST = 1
|
||||
ONLY = 2
|
||||
yield from _inbound_xfr(tm, s, q, serial, timeout, expiration)
|
||||
|
||||
|
||||
def inbound_xfr(
|
||||
|
@ -1514,65 +1641,25 @@ def inbound_xfr(
|
|||
(query, serial) = dns.xfr.make_query(txn_manager)
|
||||
else:
|
||||
serial = dns.xfr.extract_serial_from_query(query)
|
||||
rdtype = query.question[0].rdtype
|
||||
is_ixfr = rdtype == dns.rdatatype.IXFR
|
||||
origin = txn_manager.from_wire_origin()
|
||||
wire = query.to_wire()
|
||||
|
||||
(af, destination, source) = _destination_and_source(
|
||||
where, port, source, source_port
|
||||
)
|
||||
(_, expiration) = _compute_times(lifetime)
|
||||
retry = True
|
||||
while retry:
|
||||
retry = False
|
||||
if is_ixfr and udp_mode != UDPMode.NEVER:
|
||||
sock_type = socket.SOCK_DGRAM
|
||||
is_udp = True
|
||||
else:
|
||||
sock_type = socket.SOCK_STREAM
|
||||
is_udp = False
|
||||
with _make_socket(af, sock_type, source) as s:
|
||||
if query.question[0].rdtype == dns.rdatatype.IXFR and udp_mode != UDPMode.NEVER:
|
||||
with _make_socket(af, socket.SOCK_DGRAM, source) as s:
|
||||
_connect(s, destination, expiration)
|
||||
if is_udp:
|
||||
_udp_send(s, wire, None, expiration)
|
||||
else:
|
||||
tcpmsg = struct.pack("!H", len(wire)) + wire
|
||||
_net_write(s, tcpmsg, expiration)
|
||||
with dns.xfr.Inbound(txn_manager, rdtype, serial, is_udp) as inbound:
|
||||
done = False
|
||||
tsig_ctx = None
|
||||
while not done:
|
||||
(_, mexpiration) = _compute_times(timeout)
|
||||
if mexpiration is None or (
|
||||
expiration is not None and mexpiration > expiration
|
||||
):
|
||||
mexpiration = expiration
|
||||
if is_udp:
|
||||
(rwire, _) = _udp_recv(s, 65535, mexpiration)
|
||||
else:
|
||||
ldata = _net_read(s, 2, mexpiration)
|
||||
(l,) = struct.unpack("!H", ldata)
|
||||
rwire = _net_read(s, l, mexpiration)
|
||||
r = dns.message.from_wire(
|
||||
rwire,
|
||||
keyring=query.keyring,
|
||||
request_mac=query.mac,
|
||||
xfr=True,
|
||||
origin=origin,
|
||||
tsig_ctx=tsig_ctx,
|
||||
multi=(not is_udp),
|
||||
one_rr_per_rrset=is_ixfr,
|
||||
)
|
||||
try:
|
||||
done = inbound.process_message(r)
|
||||
except dns.xfr.UseTCP:
|
||||
assert is_udp # should not happen if we used TCP!
|
||||
if udp_mode == UDPMode.ONLY:
|
||||
raise
|
||||
done = True
|
||||
retry = True
|
||||
udp_mode = UDPMode.NEVER
|
||||
continue
|
||||
tsig_ctx = r.tsig_ctx
|
||||
if not retry and query.keyring and not r.had_tsig:
|
||||
raise dns.exception.FormError("missing TSIG")
|
||||
try:
|
||||
for _ in _inbound_xfr(
|
||||
txn_manager, s, query, serial, timeout, expiration
|
||||
):
|
||||
pass
|
||||
return
|
||||
except dns.xfr.UseTCP:
|
||||
if udp_mode == UDPMode.ONLY:
|
||||
raise
|
||||
|
||||
with _make_socket(af, socket.SOCK_STREAM, source) as s:
|
||||
_connect(s, destination, expiration)
|
||||
for _ in _inbound_xfr(txn_manager, s, query, serial, timeout, expiration):
|
||||
pass
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue