From 3c93b5600fc765501d03c7bc6afec9fb7130a7e9 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 25 Jan 2022 11:08:24 -0800 Subject: [PATCH] Bump dnspython from 2.0.0 to 2.2.0 (#1618) * Bump dnspython from 2.0.0 to 2.2.0 Bumps [dnspython]() from 2.0.0 to 2.2.0. --- updated-dependencies: - dependency-name: dnspython dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] * Update dnspython==2.2.0 Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: JonnyWong16 <9099342+JonnyWong16@users.noreply.github.com> [skip ci] --- lib/dns/__init__.py | 5 + lib/dns/_asyncbackend.py | 19 +- lib/dns/_asyncio_backend.py | 29 +- lib/dns/_curio_backend.py | 4 +- lib/dns/_immutable_attr.py | 84 +++ lib/dns/_immutable_ctx.py | 75 +++ lib/dns/_trio_backend.py | 4 +- lib/dns/asyncbackend.py | 7 +- lib/dns/asyncbackend.pyi | 13 + lib/dns/asyncquery.py | 365 +++++++------ lib/dns/asyncquery.pyi | 43 ++ lib/dns/asyncresolver.py | 125 ++--- lib/dns/asyncresolver.pyi | 26 + lib/dns/dnssec.py | 315 +++++------ lib/dns/dnssec.pyi | 21 + lib/dns/e164.pyi | 10 + lib/dns/edns.py | 132 ++++- lib/dns/entropy.pyi | 10 + lib/dns/enum.py | 2 +- lib/dns/exception.py | 14 + lib/dns/exception.pyi | 10 + lib/dns/flags.py | 23 +- lib/dns/grange.py | 24 +- lib/dns/immutable.py | 70 +++ lib/dns/inet.py | 2 +- lib/dns/inet.pyi | 4 + lib/dns/ipv6.py | 16 +- lib/dns/message.py | 481 ++++++++++++----- lib/dns/message.pyi | 47 ++ lib/dns/name.py | 21 +- lib/dns/name.pyi | 40 ++ lib/dns/namedict.py | 2 +- lib/dns/node.py | 143 ++++- lib/dns/node.pyi | 17 + lib/dns/opcode.py | 12 +- lib/dns/query.py | 408 ++++++++++----- lib/dns/query.pyi | 64 +++ lib/dns/rcode.py | 29 +- lib/dns/rdata.py | 364 ++++++++++--- lib/dns/rdata.pyi | 19 + lib/dns/rdataclass.py | 15 +- lib/dns/rdataset.py | 121 ++++- lib/dns/rdataset.pyi | 58 ++ lib/dns/rdatatype.py | 98 +++- lib/dns/rdtypes/ANY/AFSDB.py | 2 + lib/dns/rdtypes/ANY/AMTRELAY.py | 27 +- lib/dns/rdtypes/ANY/AVC.py | 2 + lib/dns/rdtypes/ANY/CAA.py | 14 +- lib/dns/rdtypes/ANY/CDNSKEY.py | 6 +- lib/dns/rdtypes/ANY/CDS.py | 7 + lib/dns/rdtypes/ANY/CERT.py | 26 +- lib/dns/rdtypes/ANY/CNAME.py | 2 + lib/dns/rdtypes/ANY/CSYNC.py | 19 +- lib/dns/rdtypes/ANY/DLV.py | 2 + lib/dns/rdtypes/ANY/DNAME.py | 2 + lib/dns/rdtypes/ANY/DNSKEY.py | 6 +- lib/dns/rdtypes/ANY/DS.py | 2 + lib/dns/rdtypes/ANY/EUI48.py | 2 + lib/dns/rdtypes/ANY/EUI64.py | 2 + lib/dns/rdtypes/ANY/GPOS.py | 21 +- lib/dns/rdtypes/ANY/HINFO.py | 13 +- lib/dns/rdtypes/ANY/HIP.py | 17 +- lib/dns/rdtypes/ANY/ISDN.py | 21 +- lib/dns/rdtypes/ANY/L32.py | 40 ++ lib/dns/rdtypes/ANY/L64.py | 48 ++ lib/dns/rdtypes/ANY/LOC.py | 68 +-- lib/dns/rdtypes/ANY/LP.py | 41 ++ lib/dns/rdtypes/ANY/MX.py | 2 + lib/dns/rdtypes/ANY/NID.py | 47 ++ lib/dns/rdtypes/ANY/NINFO.py | 2 + lib/dns/rdtypes/ANY/NS.py | 2 + lib/dns/rdtypes/ANY/NSEC.py | 17 +- lib/dns/rdtypes/ANY/NSEC3.py | 28 +- lib/dns/rdtypes/ANY/NSEC3PARAM.py | 14 +- lib/dns/rdtypes/ANY/OPENPGPKEY.py | 6 +- lib/dns/rdtypes/ANY/OPT.py | 11 +- lib/dns/rdtypes/ANY/PTR.py | 2 + lib/dns/rdtypes/ANY/RP.py | 7 +- lib/dns/rdtypes/ANY/RRSIG.py | 22 +- lib/dns/rdtypes/ANY/RT.py | 2 + lib/dns/rdtypes/ANY/SMIMEA.py | 9 + lib/dns/rdtypes/ANY/SOA.py | 17 +- lib/dns/rdtypes/ANY/SPF.py | 2 + lib/dns/rdtypes/ANY/SSHFP.py | 13 +- lib/dns/rdtypes/ANY/TKEY.py | 118 +++++ lib/dns/rdtypes/ANY/TLSA.py | 65 +-- lib/dns/rdtypes/ANY/TSIG.py | 55 +- lib/dns/rdtypes/ANY/TXT.py | 2 + lib/dns/rdtypes/ANY/URI.py | 25 +- lib/dns/rdtypes/ANY/X25.py | 8 +- lib/dns/rdtypes/ANY/ZONEMD.py | 65 +++ lib/dns/rdtypes/ANY/__init__.py | 5 + lib/dns/rdtypes/CH/A.py | 10 +- lib/dns/rdtypes/IN/A.py | 9 +- lib/dns/rdtypes/IN/AAAA.py | 9 +- lib/dns/rdtypes/IN/APL.py | 32 +- lib/dns/rdtypes/IN/DHCID.py | 6 +- lib/dns/rdtypes/IN/HTTPS.py | 8 + lib/dns/rdtypes/IN/IPSECKEY.py | 26 +- lib/dns/rdtypes/IN/KX.py | 2 + lib/dns/rdtypes/IN/NAPTR.py | 31 +- lib/dns/rdtypes/IN/NSAP.py | 5 +- lib/dns/rdtypes/IN/NSAP_PTR.py | 2 + lib/dns/rdtypes/IN/PX.py | 17 +- lib/dns/rdtypes/IN/SRV.py | 22 +- lib/dns/rdtypes/IN/SVCB.py | 8 + lib/dns/rdtypes/IN/WKS.py | 37 +- lib/dns/rdtypes/IN/__init__.py | 2 + lib/dns/rdtypes/__init__.py | 5 + lib/dns/rdtypes/dnskeybase.py | 24 +- lib/dns/rdtypes/dnskeybase.pyi | 38 ++ lib/dns/rdtypes/dsbase.py | 31 +- lib/dns/rdtypes/euibase.py | 9 +- lib/dns/rdtypes/mxbase.py | 17 +- lib/dns/rdtypes/nsbase.py | 6 +- lib/dns/rdtypes/svcbbase.py | 555 ++++++++++++++++++++ lib/dns/rdtypes/tlsabase.py | 72 +++ lib/dns/rdtypes/txtbase.py | 25 +- lib/dns/rdtypes/txtbase.pyi | 6 + lib/dns/rdtypes/util.py | 162 ++++-- lib/dns/resolver.py | 653 ++++++++++++----------- lib/dns/resolver.pyi | 61 +++ lib/dns/reversename.pyi | 6 + lib/dns/rrset.py | 55 +- lib/dns/rrset.pyi | 10 + lib/dns/set.py | 8 +- lib/dns/tokenizer.py | 71 ++- lib/dns/transaction.py | 587 +++++++++++++++++++++ lib/dns/tsig.py | 170 +++++- lib/dns/tsigkeyring.py | 7 +- lib/dns/tsigkeyring.pyi | 7 + lib/dns/ttl.py | 29 +- lib/dns/update.py | 11 +- lib/dns/update.pyi | 21 + lib/dns/version.py | 2 +- lib/dns/versioned.py | 274 ++++++++++ lib/dns/win32util.py | 235 +++++++++ lib/dns/wire.py | 3 + lib/dns/xfr.py | 313 +++++++++++ lib/dns/zone.py | 845 +++++++++++++++--------------- lib/dns/zone.pyi | 55 ++ lib/dns/zonefile.py | 624 ++++++++++++++++++++++ requirements.txt | 2 +- 143 files changed, 7498 insertions(+), 2054 deletions(-) create mode 100644 lib/dns/_immutable_attr.py create mode 100644 lib/dns/_immutable_ctx.py create mode 100644 lib/dns/asyncbackend.pyi create mode 100644 lib/dns/asyncquery.pyi create mode 100644 lib/dns/asyncresolver.pyi create mode 100644 lib/dns/dnssec.pyi create mode 100644 lib/dns/e164.pyi create mode 100644 lib/dns/entropy.pyi create mode 100644 lib/dns/exception.pyi create mode 100644 lib/dns/immutable.py create mode 100644 lib/dns/inet.pyi create mode 100644 lib/dns/message.pyi create mode 100644 lib/dns/name.pyi create mode 100644 lib/dns/node.pyi create mode 100644 lib/dns/query.pyi create mode 100644 lib/dns/rdata.pyi create mode 100644 lib/dns/rdataset.pyi create mode 100644 lib/dns/rdtypes/ANY/L32.py create mode 100644 lib/dns/rdtypes/ANY/L64.py create mode 100644 lib/dns/rdtypes/ANY/LP.py create mode 100644 lib/dns/rdtypes/ANY/NID.py create mode 100644 lib/dns/rdtypes/ANY/SMIMEA.py create mode 100644 lib/dns/rdtypes/ANY/TKEY.py create mode 100644 lib/dns/rdtypes/ANY/ZONEMD.py create mode 100644 lib/dns/rdtypes/IN/HTTPS.py create mode 100644 lib/dns/rdtypes/IN/SVCB.py create mode 100644 lib/dns/rdtypes/dnskeybase.pyi create mode 100644 lib/dns/rdtypes/svcbbase.py create mode 100644 lib/dns/rdtypes/tlsabase.py create mode 100644 lib/dns/rdtypes/txtbase.pyi create mode 100644 lib/dns/resolver.pyi create mode 100644 lib/dns/reversename.pyi create mode 100644 lib/dns/rrset.pyi create mode 100644 lib/dns/transaction.py create mode 100644 lib/dns/tsigkeyring.pyi create mode 100644 lib/dns/update.pyi create mode 100644 lib/dns/versioned.py create mode 100644 lib/dns/win32util.py create mode 100644 lib/dns/xfr.py create mode 100644 lib/dns/zone.pyi create mode 100644 lib/dns/zonefile.py diff --git a/lib/dns/__init__.py b/lib/dns/__init__.py index b944701d..0473ca17 100644 --- a/lib/dns/__init__.py +++ b/lib/dns/__init__.py @@ -27,6 +27,7 @@ __all__ = [ 'entropy', 'exception', 'flags', + 'immutable', 'inet', 'ipv4', 'ipv6', @@ -48,14 +49,18 @@ __all__ = [ 'serial', 'set', 'tokenizer', + 'transaction', 'tsig', 'tsigkeyring', 'ttl', 'rdtypes', 'update', 'version', + 'versioned', 'wire', + 'xfr', 'zone', + 'zonefile', ] from dns.version import version as __version__ # noqa diff --git a/lib/dns/_asyncbackend.py b/lib/dns/_asyncbackend.py index c7ecfada..1f3a8287 100644 --- a/lib/dns/_asyncbackend.py +++ b/lib/dns/_asyncbackend.py @@ -27,6 +27,12 @@ class Socket: # pragma: no cover async def close(self): pass + async def getpeername(self): + raise NotImplementedError + + async def getsockname(self): + raise NotImplementedError + async def __aenter__(self): return self @@ -36,18 +42,18 @@ class Socket: # pragma: no cover class DatagramSocket(Socket): # pragma: no cover async def sendto(self, what, destination, timeout): - pass + raise NotImplementedError async def recvfrom(self, size, timeout): - pass + raise NotImplementedError class StreamSocket(Socket): # pragma: no cover - async def sendall(self, what, destination, timeout): - pass + async def sendall(self, what, timeout): + raise NotImplementedError async def recv(self, size, timeout): - pass + raise NotImplementedError class Backend: # pragma: no cover @@ -58,3 +64,6 @@ class Backend: # pragma: no cover source=None, destination=None, timeout=None, ssl_context=None, server_hostname=None): raise NotImplementedError + + def datagram_connection_required(self): + return False diff --git a/lib/dns/_asyncio_backend.py b/lib/dns/_asyncio_backend.py index 3af34ff8..d737d13c 100644 --- a/lib/dns/_asyncio_backend.py +++ b/lib/dns/_asyncio_backend.py @@ -4,11 +4,14 @@ import socket import asyncio +import sys import dns._asyncbackend import dns.exception +_is_win32 = sys.platform == 'win32' + def _get_running_loop(): try: return asyncio.get_running_loop() @@ -25,16 +28,16 @@ class _DatagramProtocol: self.transport = transport def datagram_received(self, data, addr): - if self.recvfrom: + if self.recvfrom and not self.recvfrom.done(): self.recvfrom.set_result((data, addr)) self.recvfrom = None def error_received(self, exc): # pragma: no cover - if self.recvfrom: + if self.recvfrom and not self.recvfrom.done(): self.recvfrom.set_exception(exc) def connection_lost(self, exc): - if self.recvfrom: + if self.recvfrom and not self.recvfrom.done(): self.recvfrom.set_exception(exc) def close(self): @@ -79,21 +82,19 @@ class DatagramSocket(dns._asyncbackend.DatagramSocket): return self.transport.get_extra_info('sockname') -class StreamSocket(dns._asyncbackend.DatagramSocket): +class StreamSocket(dns._asyncbackend.StreamSocket): def __init__(self, af, reader, writer): self.family = af self.reader = reader self.writer = writer async def sendall(self, what, timeout): - self.writer.write(what), + self.writer.write(what) return await _maybe_wait_for(self.writer.drain(), timeout) - raise dns.exception.Timeout(timeout=timeout) - async def recv(self, count, timeout): - return await _maybe_wait_for(self.reader.read(count), + async def recv(self, size, timeout): + return await _maybe_wait_for(self.reader.read(size), timeout) - raise dns.exception.Timeout(timeout=timeout) async def close(self): self.writer.close() @@ -116,11 +117,16 @@ class Backend(dns._asyncbackend.Backend): async def make_socket(self, af, socktype, proto=0, source=None, destination=None, timeout=None, ssl_context=None, server_hostname=None): + if destination is None and socktype == socket.SOCK_DGRAM and \ + _is_win32: + raise NotImplementedError('destinationless datagram sockets ' + 'are not supported by asyncio ' + 'on Windows') loop = _get_running_loop() if socktype == socket.SOCK_DGRAM: transport, protocol = await loop.create_datagram_endpoint( _DatagramProtocol, source, family=af, - proto=proto) + proto=proto, remote_addr=destination) return DatagramSocket(af, transport, protocol) elif socktype == socket.SOCK_STREAM: (r, w) = await _maybe_wait_for( @@ -138,3 +144,6 @@ class Backend(dns._asyncbackend.Backend): async def sleep(self, interval): await asyncio.sleep(interval) + + def datagram_connection_required(self): + return _is_win32 diff --git a/lib/dns/_curio_backend.py b/lib/dns/_curio_backend.py index 300e1b89..6fa7b3a1 100644 --- a/lib/dns/_curio_backend.py +++ b/lib/dns/_curio_backend.py @@ -21,6 +21,8 @@ def _maybe_timeout(timeout): # for brevity _lltuple = dns.inet.low_level_address_tuple +# pylint: disable=redefined-outer-name + class DatagramSocket(dns._asyncbackend.DatagramSocket): def __init__(self, socket): @@ -47,7 +49,7 @@ class DatagramSocket(dns._asyncbackend.DatagramSocket): return self.socket.getsockname() -class StreamSocket(dns._asyncbackend.DatagramSocket): +class StreamSocket(dns._asyncbackend.StreamSocket): def __init__(self, socket): self.socket = socket self.family = socket.family diff --git a/lib/dns/_immutable_attr.py b/lib/dns/_immutable_attr.py new file mode 100644 index 00000000..f7b9f8b0 --- /dev/null +++ b/lib/dns/_immutable_attr.py @@ -0,0 +1,84 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# This implementation of the immutable decorator is for python 3.6, +# which doesn't have Context Variables. This implementation is somewhat +# costly for classes with slots, as it adds a __dict__ to them. + + +import inspect + + +class _Immutable: + """Immutable mixin class""" + + # Note we MUST NOT have __slots__ as that causes + # + # TypeError: multiple bases have instance lay-out conflict + # + # when we get mixed in with another class with slots. When we + # get mixed into something with slots, it effectively adds __dict__ to + # the slots of the other class, which allows attribute setting to work, + # albeit at the cost of the dictionary. + + def __setattr__(self, name, value): + if not hasattr(self, '_immutable_init') or \ + self._immutable_init is not self: + raise TypeError("object doesn't support attribute assignment") + else: + super().__setattr__(name, value) + + def __delattr__(self, name): + if not hasattr(self, '_immutable_init') or \ + self._immutable_init is not self: + raise TypeError("object doesn't support attribute assignment") + else: + super().__delattr__(name) + + +def _immutable_init(f): + def nf(*args, **kwargs): + try: + # Are we already initializing an immutable class? + previous = args[0]._immutable_init + except AttributeError: + # We are the first! + previous = None + object.__setattr__(args[0], '_immutable_init', args[0]) + try: + # call the actual __init__ + f(*args, **kwargs) + finally: + if not previous: + # If we started the initialzation, establish immutability + # by removing the attribute that allows mutation + object.__delattr__(args[0], '_immutable_init') + nf.__signature__ = inspect.signature(f) + return nf + + +def immutable(cls): + if _Immutable in cls.__mro__: + # Some ancestor already has the mixin, so just make sure we keep + # following the __init__ protocol. + cls.__init__ = _immutable_init(cls.__init__) + if hasattr(cls, '__setstate__'): + cls.__setstate__ = _immutable_init(cls.__setstate__) + ncls = cls + else: + # Mixin the Immutable class and follow the __init__ protocol. + class ncls(_Immutable, cls): + + @_immutable_init + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + if hasattr(cls, '__setstate__'): + @_immutable_init + def __setstate__(self, *args, **kwargs): + super().__setstate__(*args, **kwargs) + + # make ncls have the same name and module as cls + ncls.__name__ = cls.__name__ + ncls.__qualname__ = cls.__qualname__ + ncls.__module__ = cls.__module__ + return ncls diff --git a/lib/dns/_immutable_ctx.py b/lib/dns/_immutable_ctx.py new file mode 100644 index 00000000..ececdbeb --- /dev/null +++ b/lib/dns/_immutable_ctx.py @@ -0,0 +1,75 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# This implementation of the immutable decorator requires python >= +# 3.7, and is significantly more storage efficient when making classes +# with slots immutable. It's also faster. + +import contextvars +import inspect + + +_in__init__ = contextvars.ContextVar('_immutable_in__init__', default=False) + + +class _Immutable: + """Immutable mixin class""" + + # We set slots to the empty list to say "we don't have any attributes". + # We do this so that if we're mixed in with a class with __slots__, we + # don't cause a __dict__ to be added which would waste space. + + __slots__ = () + + def __setattr__(self, name, value): + if _in__init__.get() is not self: + raise TypeError("object doesn't support attribute assignment") + else: + super().__setattr__(name, value) + + def __delattr__(self, name): + if _in__init__.get() is not self: + raise TypeError("object doesn't support attribute assignment") + else: + super().__delattr__(name) + + +def _immutable_init(f): + def nf(*args, **kwargs): + previous = _in__init__.set(args[0]) + try: + # call the actual __init__ + f(*args, **kwargs) + finally: + _in__init__.reset(previous) + nf.__signature__ = inspect.signature(f) + return nf + + +def immutable(cls): + if _Immutable in cls.__mro__: + # Some ancestor already has the mixin, so just make sure we keep + # following the __init__ protocol. + cls.__init__ = _immutable_init(cls.__init__) + if hasattr(cls, '__setstate__'): + cls.__setstate__ = _immutable_init(cls.__setstate__) + ncls = cls + else: + # Mixin the Immutable class and follow the __init__ protocol. + class ncls(_Immutable, cls): + # We have to do the __slots__ declaration here too! + __slots__ = () + + @_immutable_init + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + if hasattr(cls, '__setstate__'): + @_immutable_init + def __setstate__(self, *args, **kwargs): + super().__setstate__(*args, **kwargs) + + # make ncls have the same name and module as cls + ncls.__name__ = cls.__name__ + ncls.__qualname__ = cls.__qualname__ + ncls.__module__ = cls.__module__ + return ncls diff --git a/lib/dns/_trio_backend.py b/lib/dns/_trio_backend.py index 92ea8796..a00d4a4e 100644 --- a/lib/dns/_trio_backend.py +++ b/lib/dns/_trio_backend.py @@ -21,6 +21,8 @@ def _maybe_timeout(timeout): # for brevity _lltuple = dns.inet.low_level_address_tuple +# pylint: disable=redefined-outer-name + class DatagramSocket(dns._asyncbackend.DatagramSocket): def __init__(self, socket): @@ -47,7 +49,7 @@ class DatagramSocket(dns._asyncbackend.DatagramSocket): return self.socket.getsockname() -class StreamSocket(dns._asyncbackend.DatagramSocket): +class StreamSocket(dns._asyncbackend.StreamSocket): def __init__(self, family, stream, tls=False): self.family = family self.stream = stream diff --git a/lib/dns/asyncbackend.py b/lib/dns/asyncbackend.py index 9582a6f8..089d3d35 100644 --- a/lib/dns/asyncbackend.py +++ b/lib/dns/asyncbackend.py @@ -2,9 +2,12 @@ import dns.exception +# pylint: disable=unused-import + from dns._asyncbackend import Socket, DatagramSocket, \ StreamSocket, Backend # noqa: +# pylint: enable=unused-import _default_backend = None @@ -18,13 +21,14 @@ class AsyncLibraryNotFoundError(dns.exception.DNSException): def get_backend(name): - """Get the specified asychronous backend. + """Get the specified asynchronous backend. *name*, a ``str``, the name of the backend. Currently the "trio", "curio", and "asyncio" backends are available. Raises NotImplementError if an unknown backend name is specified. """ + # pylint: disable=import-outside-toplevel,redefined-outer-name backend = _backends.get(name) if backend: return backend @@ -50,6 +54,7 @@ def sniff(): Returns the name of the library, or raises AsyncLibraryNotFoundError if the library cannot be determined. """ + # pylint: disable=import-outside-toplevel try: if _no_sniffio: raise ImportError diff --git a/lib/dns/asyncbackend.pyi b/lib/dns/asyncbackend.pyi new file mode 100644 index 00000000..1ec9d32b --- /dev/null +++ b/lib/dns/asyncbackend.pyi @@ -0,0 +1,13 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +class Backend: + ... + +def get_backend(name: str) -> Backend: + ... +def sniff() -> str: + ... +def get_default_backend() -> Backend: + ... +def set_default_backend(name: str) -> Backend: + ... diff --git a/lib/dns/asyncquery.py b/lib/dns/asyncquery.py index b7926480..4ec97fb7 100644 --- a/lib/dns/asyncquery.py +++ b/lib/dns/asyncquery.py @@ -17,6 +17,7 @@ """Talk to a DNS server.""" +import base64 import socket import struct import time @@ -30,8 +31,11 @@ import dns.rcode import dns.rdataclass import dns.rdatatype -from dns.query import _compute_times, _matches_destination, BadResponse, ssl +from dns.query import _compute_times, _matches_destination, BadResponse, ssl, \ + UDPMode, _have_httpx, _have_http2, NoDOH +if _have_httpx: + import httpx # for brevity _lltuple = dns.inet.low_level_address_tuple @@ -94,36 +98,8 @@ async def receive_udp(sock, destination=None, expiration=None, *sock*, a ``dns.asyncbackend.DatagramSocket``. - *destination*, a destination tuple appropriate for the address family - of the socket, specifying where the message is expected to arrive from. - When receiving a response, this would be where the associated query was - sent. - - *expiration*, a ``float`` or ``None``, the absolute time at which - a timeout exception should be raised. If ``None``, no timeout will - occur. - - *ignore_unexpected*, a ``bool``. If ``True``, ignore responses from - unexpected sources. - - *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own - RRset. - - *keyring*, a ``dict``, the keyring to use for TSIG. - - *request_mac*, a ``bytes``, the MAC of the request (for TSIG). - - *ignore_trailing*, a ``bool``. If ``True``, ignore trailing - junk at end of the received message. - - *raise_on_truncation*, a ``bool``. If ``True``, raise an exception if - the TC bit is set. - - Raises if the message is malformed, if network errors occur, of if - there is a timeout. - - Returns a ``(dns.message.Message, float, tuple)`` tuple of the received - message, the received time, and the address where the message arrived from. + See :py:func:`dns.query.receive_udp()` for the documentation of the other + parameters, exceptions, and return type of this method. """ wire = b'' @@ -145,34 +121,6 @@ async def udp(q, where, timeout=None, port=53, source=None, source_port=0, backend=None): """Return the response obtained after sending a query via UDP. - *q*, a ``dns.message.Message``, the query to send - - *where*, a ``str`` containing an IPv4 or IPv6 address, where - to send the message. - - *timeout*, a ``float`` or ``None``, the number of seconds to wait before the - query times out. If ``None``, the default, wait forever. - - *port*, an ``int``, the port send the message to. The default is 53. - - *source*, a ``str`` containing an IPv4 or IPv6 address, specifying - the source address. The default is the wildcard address. - - *source_port*, an ``int``, the port from which to send the message. - The default is 0. - - *ignore_unexpected*, a ``bool``. If ``True``, ignore responses from - unexpected sources. - - *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own - RRset. - - *ignore_trailing*, a ``bool``. If ``True``, ignore trailing - junk at end of the received message. - - *raise_on_truncation*, a ``bool``. If ``True``, raise an exception if - the TC bit is set. - *sock*, a ``dns.asyncbackend.DatagramSocket``, or ``None``, the socket to use for the query. If ``None``, the default, a socket is created. Note that if a socket is provided, the @@ -181,7 +129,8 @@ async def udp(q, where, timeout=None, port=53, source=None, source_port=0, *backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``, the default, then dnspython will use the default backend. - Returns a ``dns.message.Message``. + See :py:func:`dns.query.udp()` for the documentation of the other + parameters, exceptions, and return type of this method. """ wire = q.to_wire() (begin_time, expiration) = _compute_times(timeout) @@ -196,7 +145,12 @@ async def udp(q, where, timeout=None, port=53, source=None, source_port=0, if not backend: backend = dns.asyncbackend.get_default_backend() stuple = _source_tuple(af, source, source_port) - s = await backend.make_socket(af, socket.SOCK_DGRAM, 0, stuple) + if backend.datagram_connection_required(): + dtuple = (where, port) + else: + dtuple = None + s = await backend.make_socket(af, socket.SOCK_DGRAM, 0, stuple, + dtuple) await send_udp(s, wire, destination, expiration) (r, received_time, _) = await receive_udp(s, destination, expiration, ignore_unexpected, @@ -219,31 +173,6 @@ async def udp_with_fallback(q, where, timeout=None, port=53, source=None, """Return the response to the query, trying UDP first and falling back to TCP if UDP results in a truncated response. - *q*, a ``dns.message.Message``, the query to send - - *where*, a ``str`` containing an IPv4 or IPv6 address, where - to send the message. - - *timeout*, a ``float`` or ``None``, the number of seconds to wait before the - query times out. If ``None``, the default, wait forever. - - *port*, an ``int``, the port send the message to. The default is 53. - - *source*, a ``str`` containing an IPv4 or IPv6 address, specifying - the source address. The default is the wildcard address. - - *source_port*, an ``int``, the port from which to send the message. - The default is 0. - - *ignore_unexpected*, a ``bool``. If ``True``, ignore responses from - unexpected sources. - - *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own - RRset. - - *ignore_trailing*, a ``bool``. If ``True``, ignore trailing - junk at end of the received message. - *udp_sock*, a ``dns.asyncbackend.DatagramSocket``, or ``None``, the socket to use for the UDP query. If ``None``, the default, a socket is created. Note that if a socket is provided the *source*, @@ -257,8 +186,9 @@ async def udp_with_fallback(q, where, timeout=None, port=53, source=None, *backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``, the default, then dnspython will use the default backend. - Returns a (``dns.message.Message``, tcp) tuple where tcp is ``True`` - if and only if TCP was used. + See :py:func:`dns.query.udp_with_fallback()` for the documentation + of the other parameters, exceptions, and return type of this + method. """ try: response = await udp(q, where, timeout, port, source, source_port, @@ -275,15 +205,10 @@ async def udp_with_fallback(q, where, timeout=None, port=53, source=None, async def send_tcp(sock, what, expiration=None): """Send a DNS message to the specified TCP socket. - *sock*, a ``socket``. + *sock*, a ``dns.asyncbackend.StreamSocket``. - *what*, a ``bytes`` or ``dns.message.Message``, the message to send. - - *expiration*, a ``float`` or ``None``, the absolute time at which - a timeout exception should be raised. If ``None``, no timeout will - occur. - - Returns an ``(int, float)`` tuple of bytes sent and the sent time. + See :py:func:`dns.query.send_tcp()` for the documentation of the other + parameters, exceptions, and return type of this method. """ if isinstance(what, dns.message.Message): @@ -294,7 +219,7 @@ async def send_tcp(sock, what, expiration=None): # onto the net tcpmsg = struct.pack("!H", l) + what sent_time = time.time() - await sock.sendall(tcpmsg, expiration) + await sock.sendall(tcpmsg, _timeout(expiration, sent_time)) return (len(tcpmsg), sent_time) @@ -316,27 +241,10 @@ async def receive_tcp(sock, expiration=None, one_rr_per_rrset=False, keyring=None, request_mac=b'', ignore_trailing=False): """Read a DNS message from a TCP socket. - *sock*, a ``socket``. + *sock*, a ``dns.asyncbackend.StreamSocket``. - *expiration*, a ``float`` or ``None``, the absolute time at which - a timeout exception should be raised. If ``None``, no timeout will - occur. - - *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own - RRset. - - *keyring*, a ``dict``, the keyring to use for TSIG. - - *request_mac*, a ``bytes``, the MAC of the request (for TSIG). - - *ignore_trailing*, a ``bool``. If ``True``, ignore trailing - junk at end of the received message. - - Raises if the message is malformed, if network errors occur, of if - there is a timeout. - - Returns a ``(dns.message.Message, float)`` tuple of the received message - and the received time. + See :py:func:`dns.query.receive_tcp()` for the documentation of the other + parameters, exceptions, and return type of this method. """ ldata = await _read_exactly(sock, 2, expiration) @@ -354,28 +262,6 @@ async def tcp(q, where, timeout=None, port=53, source=None, source_port=0, backend=None): """Return the response obtained after sending a query via TCP. - *q*, a ``dns.message.Message``, the query to send - - *where*, a ``str`` containing an IPv4 or IPv6 address, where - to send the message. - - *timeout*, a ``float`` or ``None``, the number of seconds to wait before the - query times out. If ``None``, the default, wait forever. - - *port*, an ``int``, the port send the message to. The default is 53. - - *source*, a ``str`` containing an IPv4 or IPv6 address, specifying - the source address. The default is the wildcard address. - - *source_port*, an ``int``, the port from which to send the message. - The default is 0. - - *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own - RRset. - - *ignore_trailing*, a ``bool``. If ``True``, ignore trailing - junk at end of the received message. - *sock*, a ``dns.asyncbacket.StreamSocket``, or ``None``, the socket to use for the query. If ``None``, the default, a socket is created. Note that if a socket is provided @@ -384,7 +270,8 @@ async def tcp(q, where, timeout=None, port=53, source=None, source_port=0, *backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``, the default, then dnspython will use the default backend. - Returns a ``dns.message.Message``. + See :py:func:`dns.query.tcp()` for the documentation of the other + parameters, exceptions, and return type of this method. """ wire = q.to_wire() @@ -426,28 +313,6 @@ async def tls(q, where, timeout=None, port=853, source=None, source_port=0, backend=None, ssl_context=None, server_hostname=None): """Return the response obtained after sending a query via TLS. - *q*, a ``dns.message.Message``, the query to send - - *where*, a ``str`` containing an IPv4 or IPv6 address, where - to send the message. - - *timeout*, a ``float`` or ``None``, the number of seconds to wait before the - query times out. If ``None``, the default, wait forever. - - *port*, an ``int``, the port send the message to. The default is 853. - - *source*, a ``str`` containing an IPv4 or IPv6 address, specifying - the source address. The default is the wildcard address. - - *source_port*, an ``int``, the port from which to send the message. - The default is 0. - - *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own - RRset. - - *ignore_trailing*, a ``bool``. If ``True``, ignore trailing - junk at end of the received message. - *sock*, an ``asyncbackend.StreamSocket``, or ``None``, the socket to use for the query. If ``None``, the default, a socket is created. Note that if a socket is provided, it must be a @@ -458,15 +323,8 @@ async def tls(q, where, timeout=None, port=853, source=None, source_port=0, *backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``, the default, then dnspython will use the default backend. - *ssl_context*, an ``ssl.SSLContext``, the context to use when establishing - a TLS connection. If ``None``, the default, creates one with the default - configuration. - - *server_hostname*, a ``str`` containing the server's hostname. The - default is ``None``, which means that no hostname is known, and if an - SSL context is created, hostname checking will be disabled. - - Returns a ``dns.message.Message``. + See :py:func:`dns.query.tls()` for the documentation of the other + parameters, exceptions, and return type of this method. """ # After 3.6 is no longer supported, this can use an AsyncExitStack. (begin_time, expiration) = _compute_times(timeout) @@ -498,3 +356,168 @@ async def tls(q, where, timeout=None, port=853, source=None, source_port=0, finally: if not sock and s: await s.close() + +async def https(q, where, timeout=None, port=443, source=None, source_port=0, + one_rr_per_rrset=False, ignore_trailing=False, client=None, + path='/dns-query', post=True, verify=True): + """Return the response obtained after sending a query via DNS-over-HTTPS. + + *client*, a ``httpx.AsyncClient``. If provided, the client to use for + the query. + + Unlike the other dnspython async functions, a backend cannot be provided + in this function because httpx always auto-detects the async backend. + + See :py:func:`dns.query.https()` for the documentation of the other + parameters, exceptions, and return type of this method. + """ + + if not _have_httpx: + raise NoDOH('httpx is not available.') # pragma: no cover + + wire = q.to_wire() + try: + af = dns.inet.af_for_address(where) + except ValueError: + af = None + transport = None + headers = { + "accept": "application/dns-message" + } + if af is not None: + if af == socket.AF_INET: + url = 'https://{}:{}{}'.format(where, port, path) + elif af == socket.AF_INET6: + url = 'https://[{}]:{}{}'.format(where, port, path) + else: + url = where + if source is not None: + transport = httpx.AsyncHTTPTransport(local_address=source[0]) + + # After 3.6 is no longer supported, this can use an AsyncExitStack + client_to_close = None + try: + if not client: + client = httpx.AsyncClient(http1=True, http2=_have_http2, + verify=verify, transport=transport) + client_to_close = client + + # see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH + # GET and POST examples + if post: + headers.update({ + "content-type": "application/dns-message", + "content-length": str(len(wire)) + }) + response = await client.post(url, headers=headers, content=wire, + timeout=timeout) + else: + wire = base64.urlsafe_b64encode(wire).rstrip(b"=") + wire = wire.decode() # httpx does a repr() if we give it bytes + response = await client.get(url, headers=headers, timeout=timeout, + params={"dns": wire}) + finally: + if client_to_close: + await client.aclose() + + # see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH + # status codes + if response.status_code < 200 or response.status_code > 299: + raise ValueError('{} responded with status code {}' + '\nResponse body: {}'.format(where, + response.status_code, + response.content)) + r = dns.message.from_wire(response.content, + keyring=q.keyring, + request_mac=q.request_mac, + one_rr_per_rrset=one_rr_per_rrset, + ignore_trailing=ignore_trailing) + r.time = response.elapsed + if not q.is_response(r): + raise BadResponse + return r + +async def inbound_xfr(where, txn_manager, query=None, + port=53, timeout=None, lifetime=None, source=None, + source_port=0, udp_mode=UDPMode.NEVER, backend=None): + """Conduct an inbound transfer and apply it via a transaction from the + txn_manager. + + *backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``, + the default, then dnspython will use the default backend. + + See :py:func:`dns.query.inbound_xfr()` for the documentation of + the other parameters, exceptions, and return type of this method. + """ + if query is None: + (query, serial) = dns.xfr.make_query(txn_manager) + else: + serial = dns.xfr.extract_serial_from_query(query) + rdtype = query.question[0].rdtype + is_ixfr = rdtype == dns.rdatatype.IXFR + origin = txn_manager.from_wire_origin() + wire = query.to_wire() + af = dns.inet.af_for_address(where) + stuple = _source_tuple(af, source, source_port) + dtuple = (where, port) + (_, expiration) = _compute_times(lifetime) + retry = True + while retry: + retry = False + if is_ixfr and udp_mode != UDPMode.NEVER: + sock_type = socket.SOCK_DGRAM + is_udp = True + else: + sock_type = socket.SOCK_STREAM + is_udp = False + if not backend: + backend = dns.asyncbackend.get_default_backend() + s = await backend.make_socket(af, sock_type, 0, stuple, dtuple, + _timeout(expiration)) + async with s: + if is_udp: + await s.sendto(wire, dtuple, _timeout(expiration)) + else: + tcpmsg = struct.pack("!H", len(wire)) + wire + await s.sendall(tcpmsg, expiration) + with dns.xfr.Inbound(txn_manager, rdtype, serial, + is_udp) as inbound: + done = False + tsig_ctx = None + while not done: + (_, mexpiration) = _compute_times(timeout) + if mexpiration is None or \ + (expiration is not None and mexpiration > expiration): + mexpiration = expiration + if is_udp: + destination = _lltuple((where, port), af) + while True: + timeout = _timeout(mexpiration) + (rwire, from_address) = await s.recvfrom(65535, + timeout) + if _matches_destination(af, from_address, + destination, True): + break + else: + ldata = await _read_exactly(s, 2, mexpiration) + (l,) = struct.unpack("!H", ldata) + rwire = await _read_exactly(s, l, mexpiration) + is_ixfr = (rdtype == dns.rdatatype.IXFR) + r = dns.message.from_wire(rwire, keyring=query.keyring, + request_mac=query.mac, xfr=True, + origin=origin, tsig_ctx=tsig_ctx, + multi=(not is_udp), + one_rr_per_rrset=is_ixfr) + try: + done = inbound.process_message(r) + except dns.xfr.UseTCP: + assert is_udp # should not happen if we used TCP! + if udp_mode == UDPMode.ONLY: + raise + done = True + retry = True + udp_mode = UDPMode.NEVER + continue + tsig_ctx = r.tsig_ctx + if not retry and query.keyring and not r.had_tsig: + raise dns.exception.FormError("missing TSIG") diff --git a/lib/dns/asyncquery.pyi b/lib/dns/asyncquery.pyi new file mode 100644 index 00000000..21ef60dd --- /dev/null +++ b/lib/dns/asyncquery.pyi @@ -0,0 +1,43 @@ +from typing import Optional, Union, Dict, Generator, Any +from . import tsig, rdatatype, rdataclass, name, message, asyncbackend + +# If the ssl import works, then +# +# error: Name 'ssl' already defined (by an import) +# +# is expected and can be ignored. +try: + import ssl +except ImportError: + class ssl: # type: ignore + SSLContext : Dict = {} + +async def udp(q : message.Message, where : str, + timeout : Optional[float] = None, port=53, + source : Optional[str] = None, source_port : Optional[int] = 0, + ignore_unexpected : Optional[bool] = False, + one_rr_per_rrset : Optional[bool] = False, + ignore_trailing : Optional[bool] = False, + sock : Optional[asyncbackend.DatagramSocket] = None, + backend : Optional[asyncbackend.Backend]) -> message.Message: + pass + +async def tcp(q : message.Message, where : str, timeout : float = None, port=53, + af : Optional[int] = None, source : Optional[str] = None, + source_port : Optional[int] = 0, + one_rr_per_rrset : Optional[bool] = False, + ignore_trailing : Optional[bool] = False, + sock : Optional[asyncbackend.StreamSocket] = None, + backend : Optional[asyncbackend.Backend]) -> message.Message: + pass + +async def tls(q : message.Message, where : str, + timeout : Optional[float] = None, port=53, + source : Optional[str] = None, source_port : Optional[int] = 0, + one_rr_per_rrset : Optional[bool] = False, + ignore_trailing : Optional[bool] = False, + sock : Optional[asyncbackend.StreamSocket] = None, + backend : Optional[asyncbackend.Backend], + ssl_context: Optional[ssl.SSLContext] = None, + server_hostname: Optional[str] = None) -> message.Message: + pass diff --git a/lib/dns/asyncresolver.py b/lib/dns/asyncresolver.py index 3ac334f5..ed29deed 100644 --- a/lib/dns/asyncresolver.py +++ b/lib/dns/asyncresolver.py @@ -34,7 +34,8 @@ _udp = dns.asyncquery.udp _tcp = dns.asyncquery.tcp -class Resolver(dns.resolver.Resolver): +class Resolver(dns.resolver.BaseResolver): + """Asynchronous DNS stub resolver.""" async def resolve(self, qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN, @@ -43,53 +44,12 @@ class Resolver(dns.resolver.Resolver): backend=None): """Query nameservers asynchronously to find the answer to the question. - The *qname*, *rdtype*, and *rdclass* parameters may be objects - of the appropriate type, or strings that can be converted into objects - of the appropriate type. - - *qname*, a ``dns.name.Name`` or ``str``, the query name. - - *rdtype*, an ``int`` or ``str``, the query type. - - *rdclass*, an ``int`` or ``str``, the query class. - - *tcp*, a ``bool``. If ``True``, use TCP to make the query. - - *source*, a ``str`` or ``None``. If not ``None``, bind to this IP - address when making queries. - - *raise_on_no_answer*, a ``bool``. If ``True``, raise - ``dns.resolver.NoAnswer`` if there's no answer to the question. - - *source_port*, an ``int``, the port from which to send the message. - - *lifetime*, a ``float``, how many seconds a query should run - before timing out. - - *search*, a ``bool`` or ``None``, determines whether the - search list configured in the system's resolver configuration - are used for relative names, and whether the resolver's domain - may be added to relative names. The default is ``None``, - which causes the value of the resolver's - ``use_search_by_default`` attribute to be used. - *backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``, the default, then dnspython will use the default backend. - Raises ``dns.resolver.NXDOMAIN`` if the query name does not exist. - - Raises ``dns.resolver.YXDOMAIN`` if the query name is too long after - DNAME substitution. - - Raises ``dns.resolver.NoAnswer`` if *raise_on_no_answer* is - ``True`` and the query name exists but has no RRset of the - desired type and class. - - Raises ``dns.resolver.NoNameservers`` if no non-broken - nameservers are available to answer the question. - - Returns a ``dns.resolver.Answer`` instance. - + See :py:func:`dns.resolver.Resolver.resolve()` for the + documentation of the other parameters, exceptions, and return + type of this method. """ resolution = dns.resolver._Resolution(self, qname, rdtype, rdclass, tcp, @@ -111,7 +71,8 @@ class Resolver(dns.resolver.Resolver): (nameserver, port, tcp, backoff) = resolution.next_nameserver() if backoff: await backend.sleep(backoff) - timeout = self._compute_timeout(start, lifetime) + timeout = self._compute_timeout(start, lifetime, + resolution.errors) try: if dns.inet.is_address(nameserver): if tcp: @@ -126,8 +87,9 @@ class Resolver(dns.resolver.Resolver): raise_on_truncation=True, backend=backend) else: - # We don't do DoH yet. - raise NotImplementedError + response = await dns.asyncquery.https(request, + nameserver, + timeout=timeout) except Exception as ex: (_, done) = resolution.query_result(None, ex) continue @@ -139,11 +101,6 @@ class Resolver(dns.resolver.Resolver): if answer is not None: return answer - async def query(self, *args, **kwargs): - # We have to define something here as we don't want to inherit the - # parent's query(). - raise NotImplementedError - async def resolve_address(self, ipaddr, *args, **kwargs): """Use an asynchronous resolver to run a reverse query for PTR records. @@ -165,6 +122,30 @@ class Resolver(dns.resolver.Resolver): rdclass=dns.rdataclass.IN, *args, **kwargs) + # pylint: disable=redefined-outer-name + + async def canonical_name(self, name): + """Determine the canonical name of *name*. + + The canonical name is the name the resolver uses for queries + after all CNAME and DNAME renamings have been applied. + + *name*, a ``dns.name.Name`` or ``str``, the query name. + + This method can raise any exception that ``resolve()`` can + raise, other than ``dns.resolver.NoAnswer`` and + ``dns.resolver.NXDOMAIN``. + + Returns a ``dns.name.Name``. + """ + try: + answer = await self.resolve(name, raise_on_no_answer=False) + canonical_name = answer.canonical_name + except dns.resolver.NXDOMAIN as e: + canonical_name = e.canonical_name + return canonical_name + + default_resolver = None @@ -188,52 +169,46 @@ def reset_default_resolver(): async def resolve(qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN, tcp=False, source=None, raise_on_no_answer=True, - source_port=0, search=None, backend=None): + source_port=0, lifetime=None, search=None, backend=None): """Query nameservers asynchronously to find the answer to the question. This is a convenience function that uses the default resolver object to make the query. - See ``dns.asyncresolver.Resolver.resolve`` for more information on the - parameters. + See :py:func:`dns.asyncresolver.Resolver.resolve` for more + information on the parameters. """ return await get_default_resolver().resolve(qname, rdtype, rdclass, tcp, source, raise_on_no_answer, - source_port, search, backend) + source_port, lifetime, search, + backend) async def resolve_address(ipaddr, *args, **kwargs): """Use a resolver to run a reverse query for PTR records. - See ``dns.asyncresolver.Resolver.resolve_address`` for more + See :py:func:`dns.asyncresolver.Resolver.resolve_address` for more information on the parameters. """ return await get_default_resolver().resolve_address(ipaddr, *args, **kwargs) +async def canonical_name(name): + """Determine the canonical name of *name*. + + See :py:func:`dns.resolver.Resolver.canonical_name` for more + information on the parameters and possible exceptions. + """ + + return await get_default_resolver().canonical_name(name) async def zone_for_name(name, rdclass=dns.rdataclass.IN, tcp=False, resolver=None, backend=None): """Find the name of the zone which contains the specified name. - *name*, an absolute ``dns.name.Name`` or ``str``, the query name. - - *rdclass*, an ``int``, the query class. - - *tcp*, a ``bool``. If ``True``, use TCP to make the query. - - *resolver*, a ``dns.asyncresolver.Resolver`` or ``None``, the - resolver to use. If ``None``, the default resolver is used. - - *backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``, - the default, then dnspython will use the default backend. - - Raises ``dns.resolver.NoRootSOA`` if there is no SOA RR at the DNS - root. (This is only likely to happen if you're using non-default - root servers in your network and they are misconfigured.) - - Returns a ``dns.name.Name``. + See :py:func:`dns.resolver.Resolver.zone_for_name` for more + information on the parameters and possible exceptions. """ if isinstance(name, str): diff --git a/lib/dns/asyncresolver.pyi b/lib/dns/asyncresolver.pyi new file mode 100644 index 00000000..92759d29 --- /dev/null +++ b/lib/dns/asyncresolver.pyi @@ -0,0 +1,26 @@ +from typing import Union, Optional, List, Any, Dict +from . import exception, rdataclass, name, rdatatype, asyncbackend + +async def resolve(qname : str, rdtype : Union[int,str] = 0, + rdclass : Union[int,str] = 0, + tcp=False, source=None, raise_on_no_answer=True, + source_port=0, lifetime : Optional[float]=None, + search : Optional[bool]=None, + backend : Optional[asyncbackend.Backend]=None): + ... +async def resolve_address(self, ipaddr: str, + *args: Any, **kwargs: Optional[Dict]): + ... + +class Resolver: + def __init__(self, filename : Optional[str] = '/etc/resolv.conf', + configure : Optional[bool] = True): + self.nameservers : List[str] + async def resolve(self, qname : str, rdtype : Union[int,str] = rdatatype.A, + rdclass : Union[int,str] = rdataclass.IN, + tcp : bool = False, source : Optional[str] = None, + raise_on_no_answer=True, source_port : int = 0, + lifetime : Optional[float]=None, + search : Optional[bool]=None, + backend : Optional[asyncbackend.Backend]=None): + ... diff --git a/lib/dns/dnssec.py b/lib/dns/dnssec.py index c50abf8d..6e9946f4 100644 --- a/lib/dns/dnssec.py +++ b/lib/dns/dnssec.py @@ -64,9 +64,6 @@ class Algorithm(dns.enum.IntEnum): return 255 -globals().update(Algorithm.__members__) - - def algorithm_from_text(text): """Convert text into a DNSSEC algorithm value. @@ -169,23 +166,15 @@ def make_ds(name, key, algorithm, origin=None): def _find_candidate_keys(keys, rrsig): - candidate_keys = [] value = keys.get(rrsig.signer) - if value is None: - return None if isinstance(value, dns.node.Node): - try: - rdataset = value.find_rdataset(dns.rdataclass.IN, - dns.rdatatype.DNSKEY) - except KeyError: - return None + rdataset = value.get_rdataset(dns.rdataclass.IN, dns.rdatatype.DNSKEY) else: rdataset = value - for rdata in rdataset: - if rdata.algorithm == rrsig.algorithm and \ - key_id(rdata) == rrsig.key_tag: - candidate_keys.append(rdata) - return candidate_keys + if rdataset is None: + return None + return [rd for rd in rdataset if + rd.algorithm == rrsig.algorithm and key_id(rd) == rrsig.key_tag] def _is_rsa(algorithm): @@ -254,6 +243,82 @@ def _bytes_to_long(b): return int.from_bytes(b, 'big') +def _validate_signature(sig, data, key, chosen_hash): + if _is_rsa(key.algorithm): + keyptr = key.key + (bytes_,) = struct.unpack('!B', keyptr[0:1]) + keyptr = keyptr[1:] + if bytes_ == 0: + (bytes_,) = struct.unpack('!H', keyptr[0:2]) + keyptr = keyptr[2:] + rsa_e = keyptr[0:bytes_] + rsa_n = keyptr[bytes_:] + try: + public_key = rsa.RSAPublicNumbers( + _bytes_to_long(rsa_e), + _bytes_to_long(rsa_n)).public_key(default_backend()) + except ValueError: + raise ValidationFailure('invalid public key') + public_key.verify(sig, data, padding.PKCS1v15(), chosen_hash) + elif _is_dsa(key.algorithm): + keyptr = key.key + (t,) = struct.unpack('!B', keyptr[0:1]) + keyptr = keyptr[1:] + octets = 64 + t * 8 + dsa_q = keyptr[0:20] + keyptr = keyptr[20:] + dsa_p = keyptr[0:octets] + keyptr = keyptr[octets:] + dsa_g = keyptr[0:octets] + keyptr = keyptr[octets:] + dsa_y = keyptr[0:octets] + try: + public_key = dsa.DSAPublicNumbers( + _bytes_to_long(dsa_y), + dsa.DSAParameterNumbers( + _bytes_to_long(dsa_p), + _bytes_to_long(dsa_q), + _bytes_to_long(dsa_g))).public_key(default_backend()) + except ValueError: + raise ValidationFailure('invalid public key') + public_key.verify(sig, data, chosen_hash) + elif _is_ecdsa(key.algorithm): + keyptr = key.key + if key.algorithm == Algorithm.ECDSAP256SHA256: + curve = ec.SECP256R1() + octets = 32 + else: + curve = ec.SECP384R1() + octets = 48 + ecdsa_x = keyptr[0:octets] + ecdsa_y = keyptr[octets:octets * 2] + try: + public_key = ec.EllipticCurvePublicNumbers( + curve=curve, + x=_bytes_to_long(ecdsa_x), + y=_bytes_to_long(ecdsa_y)).public_key(default_backend()) + except ValueError: + raise ValidationFailure('invalid public key') + public_key.verify(sig, data, ec.ECDSA(chosen_hash)) + elif _is_eddsa(key.algorithm): + keyptr = key.key + if key.algorithm == Algorithm.ED25519: + loader = ed25519.Ed25519PublicKey + else: + loader = ed448.Ed448PublicKey + try: + public_key = loader.from_public_bytes(keyptr) + except ValueError: + raise ValidationFailure('invalid public key') + public_key.verify(sig, data) + elif _is_gost(key.algorithm): + raise UnsupportedAlgorithm( + 'algorithm "%s" not supported by dnspython' % + algorithm_to_text(key.algorithm)) + else: + raise ValidationFailure('unknown algorithm %u' % key.algorithm) + + def _validate_rrsig(rrset, rrsig, keys, origin=None, now=None): """Validate an RRset against a single signature rdata, throwing an exception if validation is not successful. @@ -291,143 +356,69 @@ def _validate_rrsig(rrset, rrsig, keys, origin=None, now=None): if candidate_keys is None: raise ValidationFailure('unknown key') + # For convenience, allow the rrset to be specified as a (name, + # rdataset) tuple as well as a proper rrset + if isinstance(rrset, tuple): + rrname = rrset[0] + rdataset = rrset[1] + else: + rrname = rrset.name + rdataset = rrset + + if now is None: + now = time.time() + if rrsig.expiration < now: + raise ValidationFailure('expired') + if rrsig.inception > now: + raise ValidationFailure('not yet valid') + + if _is_dsa(rrsig.algorithm): + sig_r = rrsig.signature[1:21] + sig_s = rrsig.signature[21:] + sig = utils.encode_dss_signature(_bytes_to_long(sig_r), + _bytes_to_long(sig_s)) + elif _is_ecdsa(rrsig.algorithm): + if rrsig.algorithm == Algorithm.ECDSAP256SHA256: + octets = 32 + else: + octets = 48 + sig_r = rrsig.signature[0:octets] + sig_s = rrsig.signature[octets:] + sig = utils.encode_dss_signature(_bytes_to_long(sig_r), + _bytes_to_long(sig_s)) + else: + sig = rrsig.signature + + data = b'' + data += rrsig.to_wire(origin=origin)[:18] + data += rrsig.signer.to_digestable(origin) + + # Derelativize the name before considering labels. + rrname = rrname.derelativize(origin) + + if len(rrname) - 1 < rrsig.labels: + raise ValidationFailure('owner name longer than RRSIG labels') + elif rrsig.labels < len(rrname) - 1: + suffix = rrname.split(rrsig.labels + 1)[1] + rrname = dns.name.from_text('*', suffix) + rrnamebuf = rrname.to_digestable() + rrfixed = struct.pack('!HHI', rdataset.rdtype, rdataset.rdclass, + rrsig.original_ttl) + rdatas = [rdata.to_digestable(origin) for rdata in rdataset] + for rdata in sorted(rdatas): + data += rrnamebuf + data += rrfixed + rrlen = struct.pack('!H', len(rdata)) + data += rrlen + data += rdata + + chosen_hash = _make_hash(rrsig.algorithm) + for candidate_key in candidate_keys: - # For convenience, allow the rrset to be specified as a (name, - # rdataset) tuple as well as a proper rrset - if isinstance(rrset, tuple): - rrname = rrset[0] - rdataset = rrset[1] - else: - rrname = rrset.name - rdataset = rrset - - if now is None: - now = time.time() - if rrsig.expiration < now: - raise ValidationFailure('expired') - if rrsig.inception > now: - raise ValidationFailure('not yet valid') - - if _is_rsa(rrsig.algorithm): - keyptr = candidate_key.key - (bytes_,) = struct.unpack('!B', keyptr[0:1]) - keyptr = keyptr[1:] - if bytes_ == 0: - (bytes_,) = struct.unpack('!H', keyptr[0:2]) - keyptr = keyptr[2:] - rsa_e = keyptr[0:bytes_] - rsa_n = keyptr[bytes_:] - try: - public_key = rsa.RSAPublicNumbers( - _bytes_to_long(rsa_e), - _bytes_to_long(rsa_n)).public_key(default_backend()) - except ValueError: - raise ValidationFailure('invalid public key') - sig = rrsig.signature - elif _is_dsa(rrsig.algorithm): - keyptr = candidate_key.key - (t,) = struct.unpack('!B', keyptr[0:1]) - keyptr = keyptr[1:] - octets = 64 + t * 8 - dsa_q = keyptr[0:20] - keyptr = keyptr[20:] - dsa_p = keyptr[0:octets] - keyptr = keyptr[octets:] - dsa_g = keyptr[0:octets] - keyptr = keyptr[octets:] - dsa_y = keyptr[0:octets] - try: - public_key = dsa.DSAPublicNumbers( - _bytes_to_long(dsa_y), - dsa.DSAParameterNumbers( - _bytes_to_long(dsa_p), - _bytes_to_long(dsa_q), - _bytes_to_long(dsa_g))).public_key(default_backend()) - except ValueError: - raise ValidationFailure('invalid public key') - sig_r = rrsig.signature[1:21] - sig_s = rrsig.signature[21:] - sig = utils.encode_dss_signature(_bytes_to_long(sig_r), - _bytes_to_long(sig_s)) - elif _is_ecdsa(rrsig.algorithm): - keyptr = candidate_key.key - if rrsig.algorithm == Algorithm.ECDSAP256SHA256: - curve = ec.SECP256R1() - octets = 32 - else: - curve = ec.SECP384R1() - octets = 48 - ecdsa_x = keyptr[0:octets] - ecdsa_y = keyptr[octets:octets * 2] - try: - public_key = ec.EllipticCurvePublicNumbers( - curve=curve, - x=_bytes_to_long(ecdsa_x), - y=_bytes_to_long(ecdsa_y)).public_key(default_backend()) - except ValueError: - raise ValidationFailure('invalid public key') - sig_r = rrsig.signature[0:octets] - sig_s = rrsig.signature[octets:] - sig = utils.encode_dss_signature(_bytes_to_long(sig_r), - _bytes_to_long(sig_s)) - - elif _is_eddsa(rrsig.algorithm): - keyptr = candidate_key.key - if rrsig.algorithm == Algorithm.ED25519: - loader = ed25519.Ed25519PublicKey - else: - loader = ed448.Ed448PublicKey - try: - public_key = loader.from_public_bytes(keyptr) - except ValueError: - raise ValidationFailure('invalid public key') - sig = rrsig.signature - elif _is_gost(rrsig.algorithm): - raise UnsupportedAlgorithm( - 'algorithm "%s" not supported by dnspython' % - algorithm_to_text(rrsig.algorithm)) - else: - raise ValidationFailure('unknown algorithm %u' % rrsig.algorithm) - - data = b'' - data += rrsig.to_wire(origin=origin)[:18] - data += rrsig.signer.to_digestable(origin) - - if rrsig.labels < len(rrname) - 1: - suffix = rrname.split(rrsig.labels + 1)[1] - rrname = dns.name.from_text('*', suffix) - rrnamebuf = rrname.to_digestable(origin) - rrfixed = struct.pack('!HHI', rdataset.rdtype, rdataset.rdclass, - rrsig.original_ttl) - rrlist = sorted(rdataset) - for rr in rrlist: - data += rrnamebuf - data += rrfixed - rrdata = rr.to_digestable(origin) - rrlen = struct.pack('!H', len(rrdata)) - data += rrlen - data += rrdata - - chosen_hash = _make_hash(rrsig.algorithm) try: - if _is_rsa(rrsig.algorithm): - public_key.verify(sig, data, padding.PKCS1v15(), chosen_hash) - elif _is_dsa(rrsig.algorithm): - public_key.verify(sig, data, chosen_hash) - elif _is_ecdsa(rrsig.algorithm): - public_key.verify(sig, data, ec.ECDSA(chosen_hash)) - elif _is_eddsa(rrsig.algorithm): - public_key.verify(sig, data) - else: - # Raise here for code clarity; this won't actually ever happen - # since if the algorithm is really unknown we'd already have - # raised an exception above - raise ValidationFailure('unknown algorithm %u' % - rrsig.algorithm) # pragma: no cover - # If we got here, we successfully verified so we can return - # without error + _validate_signature(sig, data, candidate_key, chosen_hash) return - except InvalidSignature: + except (InvalidSignature, ValidationFailure): # this happens on an individual validation failure continue # nothing verified -- raise failure: @@ -546,7 +537,7 @@ def nsec3_hash(domain, salt, iterations, algorithm): domain_encoded = domain.canonicalize().to_wire() digest = hashlib.sha1(domain_encoded + salt_encoded).digest() - for i in range(iterations): + for _ in range(iterations): digest = hashlib.sha1(digest + salt_encoded).digest() output = base64.b32encode(digest).decode("utf-8") @@ -579,3 +570,25 @@ else: validate = _validate # type: ignore validate_rrsig = _validate_rrsig # type: ignore _have_pyca = True + +### BEGIN generated Algorithm constants + +RSAMD5 = Algorithm.RSAMD5 +DH = Algorithm.DH +DSA = Algorithm.DSA +ECC = Algorithm.ECC +RSASHA1 = Algorithm.RSASHA1 +DSANSEC3SHA1 = Algorithm.DSANSEC3SHA1 +RSASHA1NSEC3SHA1 = Algorithm.RSASHA1NSEC3SHA1 +RSASHA256 = Algorithm.RSASHA256 +RSASHA512 = Algorithm.RSASHA512 +ECCGOST = Algorithm.ECCGOST +ECDSAP256SHA256 = Algorithm.ECDSAP256SHA256 +ECDSAP384SHA384 = Algorithm.ECDSAP384SHA384 +ED25519 = Algorithm.ED25519 +ED448 = Algorithm.ED448 +INDIRECT = Algorithm.INDIRECT +PRIVATEDNS = Algorithm.PRIVATEDNS +PRIVATEOID = Algorithm.PRIVATEOID + +### END generated Algorithm constants diff --git a/lib/dns/dnssec.pyi b/lib/dns/dnssec.pyi new file mode 100644 index 00000000..e126f9b8 --- /dev/null +++ b/lib/dns/dnssec.pyi @@ -0,0 +1,21 @@ +from typing import Union, Dict, Tuple, Optional +from . import rdataset, rrset, exception, name, rdtypes, rdata, node +import dns.rdtypes.ANY.DS as DS +import dns.rdtypes.ANY.DNSKEY as DNSKEY + +_have_pyca : bool + +def validate_rrsig(rrset : Union[Tuple[name.Name, rdataset.Rdataset], rrset.RRset], rrsig : rdata.Rdata, keys : Dict[name.Name, Union[node.Node, rdataset.Rdataset]], origin : Optional[name.Name] = None, now : Optional[int] = None) -> None: + ... + +def validate(rrset: Union[Tuple[name.Name, rdataset.Rdataset], rrset.RRset], rrsigset : Union[Tuple[name.Name, rdataset.Rdataset], rrset.RRset], keys : Dict[name.Name, Union[node.Node, rdataset.Rdataset]], origin=None, now=None) -> None: + ... + +class ValidationFailure(exception.DNSException): + ... + +def make_ds(name : name.Name, key : DNSKEY.DNSKEY, algorithm : str, origin : Optional[name.Name] = None) -> DS.DS: + ... + +def nsec3_hash(domain: str, salt: Optional[Union[str, bytes]], iterations: int, algo: int) -> str: + ... diff --git a/lib/dns/e164.pyi b/lib/dns/e164.pyi new file mode 100644 index 00000000..37a99fed --- /dev/null +++ b/lib/dns/e164.pyi @@ -0,0 +1,10 @@ +from typing import Optional, Iterable +from . import name, resolver +def from_e164(text : str, origin=name.Name(".")) -> name.Name: + ... + +def to_e164(name : name.Name, origin : Optional[name.Name] = None, want_plus_prefix=True) -> str: + ... + +def query(number : str, domains : Iterable[str], resolver : Optional[resolver.Resolver] = None) -> resolver.Answer: + ... diff --git a/lib/dns/edns.py b/lib/dns/edns.py index 28718d52..9d7e909d 100644 --- a/lib/dns/edns.py +++ b/lib/dns/edns.py @@ -23,6 +23,8 @@ import struct import dns.enum import dns.inet +import dns.rdata + class OptionType(dns.enum.IntEnum): #: NSID @@ -45,12 +47,13 @@ class OptionType(dns.enum.IntEnum): PADDING = 12 #: CHAIN CHAIN = 13 + #: EDE (extended-dns-error) + EDE = 15 @classmethod def _maximum(cls): return 65535 -globals().update(OptionType.__members__) class Option: @@ -61,7 +64,7 @@ class Option: *otype*, an ``int``, is the option type. """ - self.otype = otype + self.otype = OptionType.make(otype) def to_wire(self, file=None): """Convert an option to wire format. @@ -149,7 +152,7 @@ class GenericOption(Option): def __init__(self, otype, data): super().__init__(otype) - self.data = data + self.data = dns.rdata.Rdata._as_bytes(data, True) def to_wire(self, file=None): if file: @@ -186,12 +189,18 @@ class ECSOption(Option): self.family = 2 if srclen is None: srclen = 56 + address = dns.rdata.Rdata._as_ipv6_address(address) + srclen = dns.rdata.Rdata._as_int(srclen, 0, 128) + scopelen = dns.rdata.Rdata._as_int(scopelen, 0, 128) elif af == socket.AF_INET: self.family = 1 if srclen is None: srclen = 24 - else: - raise ValueError('Bad ip family') + address = dns.rdata.Rdata._as_ipv4_address(address) + srclen = dns.rdata.Rdata._as_int(srclen, 0, 32) + scopelen = dns.rdata.Rdata._as_int(scopelen, 0, 32) + else: # pragma: no cover (this will never happen) + raise ValueError('Bad address family') self.address = address self.srclen = srclen @@ -293,10 +302,95 @@ class ECSOption(Option): return cls(addr, src, scope) +class EDECode(dns.enum.IntEnum): + OTHER = 0 + UNSUPPORTED_DNSKEY_ALGORITHM = 1 + UNSUPPORTED_DS_DIGEST_TYPE = 2 + STALE_ANSWER = 3 + FORGED_ANSWER = 4 + DNSSEC_INDETERMINATE = 5 + DNSSEC_BOGUS = 6 + SIGNATURE_EXPIRED = 7 + SIGNATURE_NOT_YET_VALID = 8 + DNSKEY_MISSING = 9 + RRSIGS_MISSING = 10 + NO_ZONE_KEY_BIT_SET = 11 + NSEC_MISSING = 12 + CACHED_ERROR = 13 + NOT_READY = 14 + BLOCKED = 15 + CENSORED = 16 + FILTERED = 17 + PROHIBITED = 18 + STALE_NXDOMAIN_ANSWER = 19 + NOT_AUTHORITATIVE = 20 + NOT_SUPPORTED = 21 + NO_REACHABLE_AUTHORITY = 22 + NETWORK_ERROR = 23 + INVALID_DATA = 24 + + @classmethod + def _maximum(cls): + return 65535 + + +class EDEOption(Option): + """Extended DNS Error (EDE, RFC8914)""" + + def __init__(self, code, text=None): + """*code*, a ``dns.edns.EDECode`` or ``str``, the info code of the + extended error. + + *text*, a ``str`` or ``None``, specifying additional information about + the error. + """ + + super().__init__(OptionType.EDE) + + self.code = EDECode.make(code) + if text is not None and not isinstance(text, str): + raise ValueError('text must be string or None') + + self.code = code + self.text = text + + def to_text(self): + output = f'EDE {self.code}' + if self.text is not None: + output += f': {self.text}' + return output + + def to_wire(self, file=None): + value = struct.pack('!H', self.code) + if self.text is not None: + value += self.text.encode('utf8') + + if file: + file.write(value) + else: + return value + + @classmethod + def from_wire_parser(cls, otype, parser): + code = parser.get_uint16() + text = parser.get_remaining() + + if text: + if text[-1] == 0: # text MAY be null-terminated + text = text[:-1] + text = text.decode('utf8') + else: + text = None + + return cls(code, text) + + _type_to_class = { - OptionType.ECS: ECSOption + OptionType.ECS: ECSOption, + OptionType.EDE: EDEOption, } + def get_option_class(otype): """Return the class for the specified option type. @@ -342,3 +436,29 @@ def option_from_wire(otype, wire, current, olen): parser = dns.wire.Parser(wire, current) with parser.restrict_to(olen): return option_from_wire_parser(otype, parser) + +def register_type(implementation, otype): + """Register the implementation of an option type. + + *implementation*, a ``class``, is a subclass of ``dns.edns.Option``. + + *otype*, an ``int``, is the option type. + """ + + _type_to_class[otype] = implementation + +### BEGIN generated OptionType constants + +NSID = OptionType.NSID +DAU = OptionType.DAU +DHU = OptionType.DHU +N3U = OptionType.N3U +ECS = OptionType.ECS +EXPIRE = OptionType.EXPIRE +COOKIE = OptionType.COOKIE +KEEPALIVE = OptionType.KEEPALIVE +PADDING = OptionType.PADDING +CHAIN = OptionType.CHAIN +EDE = OptionType.EDE + +### END generated OptionType constants diff --git a/lib/dns/entropy.pyi b/lib/dns/entropy.pyi new file mode 100644 index 00000000..818f805a --- /dev/null +++ b/lib/dns/entropy.pyi @@ -0,0 +1,10 @@ +from typing import Optional +from random import SystemRandom + +system_random : Optional[SystemRandom] + +def random_16() -> int: + pass + +def between(first: int, last: int) -> int: + pass diff --git a/lib/dns/enum.py b/lib/dns/enum.py index 11536f2b..b822dd51 100644 --- a/lib/dns/enum.py +++ b/lib/dns/enum.py @@ -75,7 +75,7 @@ class IntEnum(enum.IntEnum): @classmethod def _maximum(cls): - raise NotImplementedError + raise NotImplementedError # pragma: no cover @classmethod def _short_name(cls): diff --git a/lib/dns/exception.py b/lib/dns/exception.py index 8f1d4888..93923734 100644 --- a/lib/dns/exception.py +++ b/lib/dns/exception.py @@ -126,3 +126,17 @@ class Timeout(DNSException): """The DNS operation timed out.""" supp_kwargs = {'timeout'} fmt = "The DNS operation timed out after {timeout} seconds" + + +class ExceptionWrapper: + def __init__(self, exception_class): + self.exception_class = exception_class + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if exc_type is not None and not isinstance(exc_val, + self.exception_class): + raise self.exception_class(str(exc_val)) from exc_val + return False diff --git a/lib/dns/exception.pyi b/lib/dns/exception.pyi new file mode 100644 index 00000000..b29bfbea --- /dev/null +++ b/lib/dns/exception.pyi @@ -0,0 +1,10 @@ +from typing import Set, Optional, Dict + +class DNSException(Exception): + supp_kwargs : Set[str] + kwargs : Optional[Dict] + fmt : Optional[str] + +class SyntaxError(DNSException): ... +class FormError(DNSException): ... +class Timeout(DNSException): ... diff --git a/lib/dns/flags.py b/lib/dns/flags.py index 4eb6d90c..96522879 100644 --- a/lib/dns/flags.py +++ b/lib/dns/flags.py @@ -37,8 +37,6 @@ class Flag(enum.IntFlag): #: Checking Disabled CD = 0x0010 -globals().update(Flag.__members__) - # EDNS flags @@ -47,9 +45,6 @@ class EDNSFlag(enum.IntFlag): DO = 0x8000 -globals().update(EDNSFlag.__members__) - - def _from_text(text, enum_class): flags = 0 tokens = text.split() @@ -104,3 +99,21 @@ def edns_to_text(flags): """ return _to_text(flags, EDNSFlag) + +### BEGIN generated Flag constants + +QR = Flag.QR +AA = Flag.AA +TC = Flag.TC +RD = Flag.RD +RA = Flag.RA +AD = Flag.AD +CD = Flag.CD + +### END generated Flag constants + +### BEGIN generated EDNSFlag constants + +DO = EDNSFlag.DO + +### END generated EDNSFlag constants diff --git a/lib/dns/grange.py b/lib/dns/grange.py index ffe8be7c..112ede47 100644 --- a/lib/dns/grange.py +++ b/lib/dns/grange.py @@ -28,11 +28,12 @@ def from_text(text): Returns a tuple of three ``int`` values ``(start, stop, step)``. """ - # TODO, figure out the bounds on start, stop and step. + start = -1 + stop = -1 step = 1 cur = '' state = 0 - # state 0 1 2 3 4 + # state 0 1 2 # x - y / z if text and text[0] == '-': @@ -42,28 +43,27 @@ def from_text(text): if c == '-' and state == 0: start = int(cur) cur = '' - state = 2 + state = 1 elif c == '/': stop = int(cur) cur = '' - state = 4 + state = 2 elif c.isdigit(): cur += c else: raise dns.exception.SyntaxError("Could not parse %s" % (c)) - if state in (1, 3): - raise dns.exception.SyntaxError() - - if state == 2: + if state == 0: + raise dns.exception.SyntaxError("no stop value specified") + elif state == 1: stop = int(cur) - - if state == 4: + else: + assert state == 2 step = int(cur) assert step >= 1 assert start >= 0 - assert start <= stop - # TODO, can start == stop? + if start > stop: + raise dns.exception.SyntaxError('start must be <= stop') return (start, stop, step) diff --git a/lib/dns/immutable.py b/lib/dns/immutable.py new file mode 100644 index 00000000..db7abbcc --- /dev/null +++ b/lib/dns/immutable.py @@ -0,0 +1,70 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +import collections.abc +import sys + +# pylint: disable=unused-import +if sys.version_info >= (3, 7): + odict = dict + from dns._immutable_ctx import immutable +else: + # pragma: no cover + from collections import OrderedDict as odict + from dns._immutable_attr import immutable # noqa +# pylint: enable=unused-import + + +@immutable +class Dict(collections.abc.Mapping): + def __init__(self, dictionary, no_copy=False): + """Make an immutable dictionary from the specified dictionary. + + If *no_copy* is `True`, then *dictionary* will be wrapped instead + of copied. Only set this if you are sure there will be no external + references to the dictionary. + """ + if no_copy and isinstance(dictionary, odict): + self._odict = dictionary + else: + self._odict = odict(dictionary) + self._hash = None + + def __getitem__(self, key): + return self._odict.__getitem__(key) + + def __hash__(self): # pylint: disable=invalid-hash-returned + if self._hash is None: + h = 0 + for key in sorted(self._odict.keys()): + h ^= hash(key) + object.__setattr__(self, '_hash', h) + # this does return an int, but pylint doesn't figure that out + return self._hash + + def __len__(self): + return len(self._odict) + + def __iter__(self): + return iter(self._odict) + + +def constify(o): + """ + Convert mutable types to immutable types. + """ + if isinstance(o, bytearray): + return bytes(o) + if isinstance(o, tuple): + try: + hash(o) + return o + except Exception: + return tuple(constify(elt) for elt in o) + if isinstance(o, list): + return tuple(constify(elt) for elt in o) + if isinstance(o, dict): + cdict = odict() + for k, v in o.items(): + cdict[k] = constify(v) + return Dict(cdict, True) + return o diff --git a/lib/dns/inet.py b/lib/dns/inet.py index 25d99c2c..d3bdc64c 100644 --- a/lib/dns/inet.py +++ b/lib/dns/inet.py @@ -162,7 +162,7 @@ def low_level_address_tuple(high_tuple, af=None): return (addrpart, port, 0, int(scope)) try: return (addrpart, port, 0, socket.if_nametoindex(scope)) - except AttributeError: + except AttributeError: # pragma: no cover (we can't really test this) ai_flags = socket.AI_NUMERICHOST ((*_, tup), *_) = socket.getaddrinfo(address, port, flags=ai_flags) return tup diff --git a/lib/dns/inet.pyi b/lib/dns/inet.pyi new file mode 100644 index 00000000..6d9dcc70 --- /dev/null +++ b/lib/dns/inet.pyi @@ -0,0 +1,4 @@ +from typing import Union +from socket import AddressFamily + +AF_INET6 : Union[int, AddressFamily] diff --git a/lib/dns/ipv6.py b/lib/dns/ipv6.py index 5424fcea..0db6fcfa 100644 --- a/lib/dns/ipv6.py +++ b/lib/dns/ipv6.py @@ -121,7 +121,13 @@ def inet_aton(text, ignore_scope=False): elif l > 2: raise dns.exception.SyntaxError - if text == b'::': + if text == b'': + raise dns.exception.SyntaxError + elif text.endswith(b':') and not text.endswith(b'::'): + raise dns.exception.SyntaxError + elif text.startswith(b':') and not text.startswith(b'::'): + raise dns.exception.SyntaxError + elif text == b'::': text = b'0::' # # Get rid of the icky dot-quad syntax if we have it. @@ -129,9 +135,9 @@ def inet_aton(text, ignore_scope=False): m = _v4_ending.match(text) if m is not None: b = dns.ipv4.inet_aton(m.group(2)) - text = (u"{}:{:02x}{:02x}:{:02x}{:02x}".format(m.group(1).decode(), - b[0], b[1], b[2], - b[3])).encode() + text = ("{}:{:02x}{:02x}:{:02x}{:02x}".format(m.group(1).decode(), + b[0], b[1], b[2], + b[3])).encode() # # Try to turn '::' into ':'; if no match try to # turn '::' into ':' @@ -157,7 +163,7 @@ def inet_aton(text, ignore_scope=False): if seen_empty: raise dns.exception.SyntaxError seen_empty = True - for i in range(0, 8 - l + 1): + for _ in range(0, 8 - l + 1): canonical.append(b'0000') else: lc = len(c) diff --git a/lib/dns/message.py b/lib/dns/message.py index 60b74c19..1e67a17b 100644 --- a/lib/dns/message.py +++ b/lib/dns/message.py @@ -35,6 +35,7 @@ import dns.rdataclass import dns.rdatatype import dns.rrset import dns.renderer +import dns.ttl import dns.tsig import dns.rdtypes.ANY.OPT import dns.rdtypes.ANY.TSIG @@ -80,6 +81,21 @@ class Truncated(dns.exception.DNSException): return self.kwargs['message'] +class NotQueryResponse(dns.exception.DNSException): + """Message is not a response to a query.""" + + +class ChainTooLong(dns.exception.DNSException): + """The CNAME chain is too long.""" + + +class AnswerForNXDOMAIN(dns.exception.DNSException): + """The rcode is NXDOMAIN but an answer was found.""" + +class NoPreviousName(dns.exception.SyntaxError): + """No previous name was known.""" + + class MessageSection(dns.enum.IntEnum): """Message sections""" QUESTION = 0 @@ -91,8 +107,15 @@ class MessageSection(dns.enum.IntEnum): def _maximum(cls): return 3 -globals().update(MessageSection.__members__) +class MessageError: + def __init__(self, exception, offset): + self.exception = exception + self.offset = offset + + +DEFAULT_EDNS_PAYLOAD = 1232 +MAX_CHAIN = 16 class Message: """A DNS message.""" @@ -115,6 +138,7 @@ class Message: self.origin = None self.tsig_ctx = None self.index = {} + self.errors = [] @property def question(self): @@ -169,10 +193,8 @@ class Message: s = io.StringIO() s.write('id %d\n' % self.id) - s.write('opcode %s\n' % - dns.opcode.to_text(dns.opcode.from_flags(self.flags))) - rc = dns.rcode.from_flags(self.flags, self.ednsflags) - s.write('rcode %s\n' % dns.rcode.to_text(rc)) + s.write('opcode %s\n' % dns.opcode.to_text(self.opcode())) + s.write('rcode %s\n' % dns.rcode.to_text(self.rcode())) s.write('flags %s\n' % dns.flags.to_text(self.flags)) if self.edns >= 0: s.write('edns %s\n' % self.edns) @@ -221,7 +243,8 @@ class Message: return not self.__eq__(other) def is_response(self, other): - """Is *other* a response this message? + """Is *other*, also a ``dns.message.Message``, a response to this + message? Returns a ``bool``. """ @@ -231,9 +254,13 @@ class Message: dns.opcode.from_flags(self.flags) != \ dns.opcode.from_flags(other.flags): return False - if dns.rcode.from_flags(other.flags, other.ednsflags) != \ - dns.rcode.NOERROR: - return True + if other.rcode() in {dns.rcode.FORMERR, dns.rcode.SERVFAIL, + dns.rcode.NOTIMP, dns.rcode.REFUSED}: + # We don't check the question section in these cases if + # the other question section is empty, even though they + # still really ought to have a question section. + if len(other.question) == 0: + return True if dns.opcode.is_update(self.flags): # This is assuming the "sender doesn't include anything # from the update", but we don't care to check the other @@ -330,7 +357,8 @@ class Message: return rrset else: for rrset in section: - if rrset.match(name, rdclass, rdtype, covers, deleting): + if rrset.full_match(name, rdclass, rdtype, covers, + deleting): return rrset if not create: raise KeyError @@ -403,8 +431,8 @@ class Message: *multi*, a ``bool``, should be set to ``True`` if this message is part of a multiple message sequence. - *tsig_ctx*, a ``hmac.HMAC`` object, the ongoing TSIG context, used - when signing zone transfers. + *tsig_ctx*, a ``dns.tsig.HMACTSig`` or ``dns.tsig.GSSTSig`` object, the + ongoing TSIG context, used when signing zone transfers. Raises ``dns.exception.TooBig`` if *max_size* was exceeded. @@ -467,8 +495,8 @@ class Message: *key*, a ``dns.tsig.Key`` is the key to use. If a key is specified, the *keyring* and *algorithm* fields are not used. - *keyring*, a ``dict`` or ``dns.tsig.Key``, is either the TSIG - keyring or key to use. + *keyring*, a ``dict``, ``callable`` or ``dns.tsig.Key``, is either + the TSIG keyring or key to use. The format of a keyring dict is a mapping from TSIG key name, as ``dns.name.Name`` to ``dns.tsig.Key`` or a TSIG secret, a ``bytes``. @@ -476,7 +504,9 @@ class Message: used will be the first key in the *keyring*. Note that the order of keys in a dictionary is not defined, so applications should supply a keyname when a ``dict`` keyring is used, unless they know the keyring - contains only one key. + contains only one key. If a ``callable`` keyring is specified, the + callable will be called with the message and the keyname, and is + expected to return a key. *keyname*, a ``dns.name.Name``, ``str`` or ``None``, the name of thes TSIG key to use; defaults to ``None``. If *keyring* is a @@ -497,7 +527,10 @@ class Message: """ if isinstance(keyring, dns.tsig.Key): - self.keyring = keyring + key = keyring + keyname = key.name + elif callable(keyring): + key = keyring(self, keyname) else: if isinstance(keyname, str): keyname = dns.name.from_text(keyname) @@ -506,7 +539,7 @@ class Message: key = keyring[keyname] if isinstance(key, bytes): key = dns.tsig.Key(keyname, key, algorithm) - self.keyring = key + self.keyring = key if original_id is None: original_id = self.id self.tsig = self._make_tsig(keyname, self.keyring.algorithm, 0, fudge, @@ -545,13 +578,13 @@ class Message: return bool(self.tsig) @staticmethod - def _make_opt(flags=0, payload=1280, options=None): + def _make_opt(flags=0, payload=DEFAULT_EDNS_PAYLOAD, options=None): opt = dns.rdtypes.ANY.OPT.OPT(payload, dns.rdatatype.OPT, options or ()) return dns.rrset.from_rdata(dns.name.root, int(flags), opt) - def use_edns(self, edns=0, ednsflags=0, payload=1280, request_payload=None, - options=None): + def use_edns(self, edns=0, ednsflags=0, payload=DEFAULT_EDNS_PAYLOAD, + request_payload=None, options=None): """Configure EDNS behavior. *edns*, an ``int``, is the EDNS level to use. Specifying @@ -575,26 +608,21 @@ class Message: if edns is None or edns is False: edns = -1 - if edns is True: + elif edns is True: edns = 0 - if request_payload is None: - request_payload = payload if edns < 0: - ednsflags = 0 - payload = 0 - request_payload = 0 - options = [] + self.opt = None + self.request_payload = 0 else: # make sure the EDNS version in ednsflags agrees with edns ednsflags &= 0xFF00FFFF ednsflags |= (edns << 16) if options is None: options = [] - if edns >= 0: self.opt = self._make_opt(ednsflags, payload, options) - else: - self.opt = None - self.request_payload = request_payload + if request_payload is None: + request_payload = payload + self.request_payload = request_payload @property def edns(self): @@ -650,7 +678,7 @@ class Message: Returns an ``int``. """ - return dns.rcode.from_flags(self.flags, self.ednsflags) + return dns.rcode.from_flags(int(self.flags), int(self.ednsflags)) def set_rcode(self, rcode): """Set the rcode. @@ -668,7 +696,7 @@ class Message: Returns an ``int``. """ - return dns.opcode.from_flags(self.flags) + return dns.opcode.from_flags(int(self.flags)) def set_opcode(self, opcode): """Set the opcode. @@ -682,9 +710,13 @@ class Message: # What the caller picked is fine. return value + # pylint: disable=unused-argument + def _parse_rr_header(self, section, name, rdclass, rdtype): return (rdclass, rdtype, None, False) + # pylint: enable=unused-argument + def _parse_special_rr_header(self, section, count, position, name, rdclass, rdtype): if rdtype == dns.rdatatype.OPT: @@ -699,14 +731,129 @@ class Message: return (rdclass, rdtype, None, False) +class ChainingResult: + """The result of a call to dns.message.QueryMessage.resolve_chaining(). + + The ``answer`` attribute is the answer RRSet, or ``None`` if it doesn't + exist. + + The ``canonical_name`` attribute is the canonical name after all + chaining has been applied (this is the name as ``rrset.name`` in cases + where rrset is not ``None``). + + The ``minimum_ttl`` attribute is the minimum TTL, i.e. the TTL to + use if caching the data. It is the smallest of all the CNAME TTLs + and either the answer TTL if it exists or the SOA TTL and SOA + minimum values for negative answers. + + The ``cnames`` attribute is a list of all the CNAME RRSets followed to + get to the canonical name. + """ + def __init__(self, canonical_name, answer, minimum_ttl, cnames): + self.canonical_name = canonical_name + self.answer = answer + self.minimum_ttl = minimum_ttl + self.cnames = cnames + + class QueryMessage(Message): - pass + def resolve_chaining(self): + """Follow the CNAME chain in the response to determine the answer + RRset. + + Raises ``dns.message.NotQueryResponse`` if the message is not + a response. + + Raises ``dns.message.ChainTooLong`` if the CNAME chain is too long. + + Raises ``dns.message.AnswerForNXDOMAIN`` if the rcode is NXDOMAIN + but an answer was found. + + Raises ``dns.exception.FormError`` if the question count is not 1. + + Returns a ChainingResult object. + """ + if self.flags & dns.flags.QR == 0: + raise NotQueryResponse + if len(self.question) != 1: + raise dns.exception.FormError + question = self.question[0] + qname = question.name + min_ttl = dns.ttl.MAX_TTL + answer = None + count = 0 + cnames = [] + while count < MAX_CHAIN: + try: + answer = self.find_rrset(self.answer, qname, question.rdclass, + question.rdtype) + min_ttl = min(min_ttl, answer.ttl) + break + except KeyError: + if question.rdtype != dns.rdatatype.CNAME: + try: + crrset = self.find_rrset(self.answer, qname, + question.rdclass, + dns.rdatatype.CNAME) + cnames.append(crrset) + min_ttl = min(min_ttl, crrset.ttl) + for rd in crrset: + qname = rd.target + break + count += 1 + continue + except KeyError: + # Exit the chaining loop + break + else: + # Exit the chaining loop + break + if count >= MAX_CHAIN: + raise ChainTooLong + if self.rcode() == dns.rcode.NXDOMAIN and answer is not None: + raise AnswerForNXDOMAIN + if answer is None: + # Further minimize the TTL with NCACHE. + auname = qname + while True: + # Look for an SOA RR whose owner name is a superdomain + # of qname. + try: + srrset = self.find_rrset(self.authority, auname, + question.rdclass, + dns.rdatatype.SOA) + min_ttl = min(min_ttl, srrset.ttl, srrset[0].minimum) + break + except KeyError: + try: + auname = auname.parent() + except dns.name.NoParent: + break + return ChainingResult(qname, answer, min_ttl, cnames) + + def canonical_name(self): + """Return the canonical name of the first name in the question + section. + + Raises ``dns.message.NotQueryResponse`` if the message is not + a response. + + Raises ``dns.message.ChainTooLong`` if the CNAME chain is too long. + + Raises ``dns.message.AnswerForNXDOMAIN`` if the rcode is NXDOMAIN + but an answer was found. + + Raises ``dns.exception.FormError`` if the question count is not 1. + """ + return self.resolve_chaining().canonical_name def _maybe_import_update(): # We avoid circular imports by doing this here. We do it in another # function as doing it in _message_factory_from_opcode() makes "dns" # a local symbol, and the first line fails :) + + # pylint: disable=redefined-outer-name,import-outside-toplevel,unused-import import dns.update # noqa: F401 @@ -733,11 +880,14 @@ class _WireReader: ignore_trailing: Ignore trailing junk at end of request? multi: Is this message part of a multi-message sequence? DNS dynamic updates. + continue_on_error: try to extract as much information as possible from + the message, accumulating MessageErrors in the *errors* attribute instead of + raising them. """ def __init__(self, wire, initialize_message, question_only=False, one_rr_per_rrset=False, ignore_trailing=False, - keyring=None, multi=False): + keyring=None, multi=False, continue_on_error=False): self.parser = dns.wire.Parser(wire) self.message = None self.initialize_message = initialize_message @@ -746,6 +896,8 @@ class _WireReader: self.ignore_trailing = ignore_trailing self.keyring = keyring self.multi = multi + self.continue_on_error = continue_on_error + self.errors = [] def _get_question(self, section_number, qcount): """Read the next *qcount* records from the wire data and add them to @@ -753,7 +905,7 @@ class _WireReader: """ section = self.message.sections[section_number] - for i in range(qcount): + for _ in range(qcount): qname = self.parser.get_name(self.message.origin) (rdtype, rdclass) = self.parser.get_struct('!HH') (rdclass, rdtype, _, _) = \ @@ -762,11 +914,14 @@ class _WireReader: self.message.find_rrset(section, qname, rdclass, rdtype, create=True, force_unique=True) + def _add_error(self, e): + self.errors.append(MessageError(e, self.parser.current)) + def _get_section(self, section_number, count): """Read the next I{count} records from the wire data and add them to the specified section. - section: the section of the message to which to add records + section_number: the section of the message to which to add records count: the number of records to read """ @@ -789,53 +944,65 @@ class _WireReader: (rdclass, rdtype, deleting, empty) = \ self.message._parse_rr_header(section_number, name, rdclass, rdtype) - if empty: - if rdlen > 0: - raise dns.exception.FormError - rd = None - covers = dns.rdatatype.NONE - else: - with self.parser.restrict_to(rdlen): - rd = dns.rdata.from_wire_parser(rdclass, rdtype, - self.parser, - self.message.origin) - covers = rd.covers() - if self.message.xfr and rdtype == dns.rdatatype.SOA: - force_unique = True - if rdtype == dns.rdatatype.OPT: - self.message.opt = dns.rrset.from_rdata(name, ttl, rd) - elif rdtype == dns.rdatatype.TSIG: - if self.keyring is None: - raise UnknownTSIGKey('got signed message without keyring') - if isinstance(self.keyring, dict): - key = self.keyring.get(absolute_name) - if isinstance(key, bytes): - key = dns.tsig.Key(absolute_name, key, rd.algorithm) + try: + rdata_start = self.parser.current + if empty: + if rdlen > 0: + raise dns.exception.FormError + rd = None + covers = dns.rdatatype.NONE else: - key = self.keyring - if key is None: - raise UnknownTSIGKey("key '%s' unknown" % name) - self.message.keyring = key - self.message.tsig_ctx = \ - dns.tsig.validate(self.parser.wire, - key, - absolute_name, - rd, - int(time.time()), - self.message.request_mac, - rr_start, - self.message.tsig_ctx, - self.multi) - self.message.tsig = dns.rrset.from_rdata(absolute_name, 0, rd) - else: - rrset = self.message.find_rrset(section, name, - rdclass, rdtype, covers, - deleting, True, - force_unique) - if rd is not None: - if ttl > 0x7fffffff: - ttl = 0 - rrset.add(rd, ttl) + with self.parser.restrict_to(rdlen): + rd = dns.rdata.from_wire_parser(rdclass, rdtype, + self.parser, + self.message.origin) + covers = rd.covers() + if self.message.xfr and rdtype == dns.rdatatype.SOA: + force_unique = True + if rdtype == dns.rdatatype.OPT: + self.message.opt = dns.rrset.from_rdata(name, ttl, rd) + elif rdtype == dns.rdatatype.TSIG: + if self.keyring is None: + raise UnknownTSIGKey('got signed message without ' + 'keyring') + if isinstance(self.keyring, dict): + key = self.keyring.get(absolute_name) + if isinstance(key, bytes): + key = dns.tsig.Key(absolute_name, key, rd.algorithm) + elif callable(self.keyring): + key = self.keyring(self.message, absolute_name) + else: + key = self.keyring + if key is None: + raise UnknownTSIGKey("key '%s' unknown" % name) + self.message.keyring = key + self.message.tsig_ctx = \ + dns.tsig.validate(self.parser.wire, + key, + absolute_name, + rd, + int(time.time()), + self.message.request_mac, + rr_start, + self.message.tsig_ctx, + self.multi) + self.message.tsig = dns.rrset.from_rdata(absolute_name, 0, + rd) + else: + rrset = self.message.find_rrset(section, name, + rdclass, rdtype, covers, + deleting, True, + force_unique) + if rd is not None: + if ttl > 0x7fffffff: + ttl = 0 + rrset.add(rd, ttl) + except Exception as e: + if self.continue_on_error: + self._add_error(e) + self.parser.seek(rdata_start + rdlen) + else: + raise def read(self): """Read a wire format DNS message and build a dns.message.Message @@ -847,73 +1014,86 @@ class _WireReader: self.parser.get_struct('!HHHHHH') factory = _message_factory_from_opcode(dns.opcode.from_flags(flags)) self.message = factory(id=id) - self.message.flags = flags + self.message.flags = dns.flags.Flag(flags) self.initialize_message(self.message) self.one_rr_per_rrset = \ self.message._get_one_rr_per_rrset(self.one_rr_per_rrset) - self._get_question(MessageSection.QUESTION, qcount) - if self.question_only: - return - self._get_section(MessageSection.ANSWER, ancount) - self._get_section(MessageSection.AUTHORITY, aucount) - self._get_section(MessageSection.ADDITIONAL, adcount) - if not self.ignore_trailing and self.parser.remaining() != 0: - raise TrailingJunk - if self.multi and self.message.tsig_ctx and not self.message.had_tsig: - self.message.tsig_ctx.update(self.parser.wire) + try: + self._get_question(MessageSection.QUESTION, qcount) + if self.question_only: + return self.message + self._get_section(MessageSection.ANSWER, ancount) + self._get_section(MessageSection.AUTHORITY, aucount) + self._get_section(MessageSection.ADDITIONAL, adcount) + if not self.ignore_trailing and self.parser.remaining() != 0: + raise TrailingJunk + if self.multi and self.message.tsig_ctx and \ + not self.message.had_tsig: + self.message.tsig_ctx.update(self.parser.wire) + except Exception as e: + if self.continue_on_error: + self._add_error(e) + else: + raise return self.message def from_wire(wire, keyring=None, request_mac=b'', xfr=False, origin=None, tsig_ctx=None, multi=False, question_only=False, one_rr_per_rrset=False, - ignore_trailing=False, raise_on_truncation=False): - """Convert a DNS wire format message into a message - object. + ignore_trailing=False, raise_on_truncation=False, + continue_on_error=False): + """Convert a DNS wire format message into a message object. - *keyring*, a ``dns.tsig.Key`` or ``dict``, the key or keyring to use - if the message is signed. + *keyring*, a ``dns.tsig.Key`` or ``dict``, the key or keyring to use if the + message is signed. - *request_mac*, a ``bytes``. If the message is a response to a - TSIG-signed request, *request_mac* should be set to the MAC of - that request. + *request_mac*, a ``bytes``. If the message is a response to a TSIG-signed + request, *request_mac* should be set to the MAC of that request. - *xfr*, a ``bool``, should be set to ``True`` if this message is part of - a zone transfer. + *xfr*, a ``bool``, should be set to ``True`` if this message is part of a + zone transfer. - *origin*, a ``dns.name.Name`` or ``None``. If the message is part - of a zone transfer, *origin* should be the origin name of the - zone. If not ``None``, names will be relativized to the origin. + *origin*, a ``dns.name.Name`` or ``None``. If the message is part of a zone + transfer, *origin* should be the origin name of the zone. If not ``None``, + names will be relativized to the origin. - *tsig_ctx*, a ``hmac.HMAC`` object, the ongoing TSIG context, used - when validating zone transfers. + *tsig_ctx*, a ``dns.tsig.HMACTSig`` or ``dns.tsig.GSSTSig`` object, the + ongoing TSIG context, used when validating zone transfers. - *multi*, a ``bool``, should be set to ``True`` if this message is - part of a multiple message sequence. + *multi*, a ``bool``, should be set to ``True`` if this message is part of a + multiple message sequence. - *question_only*, a ``bool``. If ``True``, read only up to - the end of the question section. + *question_only*, a ``bool``. If ``True``, read only up to the end of the + question section. - *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its - own RRset. + *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own + RRset. - *ignore_trailing*, a ``bool``. If ``True``, ignore trailing - junk at end of the message. + *ignore_trailing*, a ``bool``. If ``True``, ignore trailing junk at end of + the message. - *raise_on_truncation*, a ``bool``. If ``True``, raise an exception if - the TC bit is set. + *raise_on_truncation*, a ``bool``. If ``True``, raise an exception if the + TC bit is set. + + *continue_on_error*, a ``bool``. If ``True``, try to continue parsing even + if errors occur. Erroneous rdata will be ignored. Errors will be + accumulated as a list of MessageError objects in the message's ``errors`` + attribute. This option is recommended only for DNS analysis tools, or for + use in a server as part of an error handling path. The default is + ``False``. Raises ``dns.message.ShortHeader`` if the message is less than 12 octets long. - Raises ``dns.message.TrailingJunk`` if there were octets in the message - past the end of the proper DNS message, and *ignore_trailing* is ``False``. + Raises ``dns.message.TrailingJunk`` if there were octets in the message past + the end of the proper DNS message, and *ignore_trailing* is ``False``. - Raises ``dns.message.BadEDNS`` if an OPT record was in the - wrong section, or occurred more than once. + Raises ``dns.message.BadEDNS`` if an OPT record was in the wrong section, or + occurred more than once. - Raises ``dns.message.BadTSIG`` if a TSIG record was not the last - record of the additional data section. + Raises ``dns.message.BadTSIG`` if a TSIG record was not the last record of + the additional data section. Raises ``dns.message.Truncated`` if the TC flag is set and *raise_on_truncation* is ``True``. @@ -928,7 +1108,8 @@ def from_wire(wire, keyring=None, request_mac=b'', xfr=False, origin=None, message.tsig_ctx = tsig_ctx reader = _WireReader(wire, initialize_message, question_only, - one_rr_per_rrset, ignore_trailing, keyring, multi) + one_rr_per_rrset, ignore_trailing, keyring, multi, + continue_on_error) try: m = reader.read() except dns.exception.FormError: @@ -941,6 +1122,8 @@ def from_wire(wire, keyring=None, request_mac=b'', xfr=False, origin=None, # have to do this check here too. if m.flags & dns.flags.TC and raise_on_truncation: raise Truncated(message=m) + if continue_on_error: + m.errors = reader.errors return m @@ -971,12 +1154,12 @@ class _TextReader: self.id = None self.edns = -1 self.ednsflags = 0 - self.payload = None + self.payload = DEFAULT_EDNS_PAYLOAD self.rcode = None self.opcode = dns.opcode.QUERY self.flags = 0 - def _header_line(self, section): + def _header_line(self, _): """Process one line from the text format header section.""" token = self.tok.get() @@ -1028,6 +1211,8 @@ class _TextReader: self.relativize, self.relativize_to) name = self.last_name + if name is None: + raise NoPreviousName token = self.tok.get() if not token.is_identifier(): raise dns.exception.SyntaxError @@ -1062,6 +1247,8 @@ class _TextReader: self.relativize, self.relativize_to) name = self.last_name + if name is None: + raise NoPreviousName token = self.tok.get() if not token.is_identifier(): raise dns.exception.SyntaxError @@ -1092,6 +1279,8 @@ class _TextReader: token = self.tok.get() if empty and not token.is_eol_or_eof(): raise dns.exception.SyntaxError + if not empty and token.is_eol_or_eof(): + raise dns.exception.UnexpectedEnd if not token.is_eol_or_eof(): self.tok.unget(token) rd = dns.rdata.from_text(rdclass, rdtype, self.tok, @@ -1235,7 +1424,8 @@ def from_file(f, idna_codec=None, one_rr_per_rrset=False): def make_query(qname, rdtype, rdclass=dns.rdataclass.IN, use_edns=None, want_dnssec=False, ednsflags=None, payload=None, - request_payload=None, options=None, idna_codec=None): + request_payload=None, options=None, idna_codec=None, + id=None, flags=dns.flags.RD): """Make a query message. The query name, type, and class may all be specified either @@ -1252,7 +1442,9 @@ def make_query(qname, rdtype, rdclass=dns.rdataclass.IN, use_edns=None, is class IN. *use_edns*, an ``int``, ``bool`` or ``None``. The EDNS level to use; the - default is None (no EDNS). + default is ``None``. If ``None``, EDNS will be enabled only if other + parameters (*ednsflags*, *payload*, *request_payload*, or *options*) are + set. See the description of dns.message.Message.use_edns() for the possible values for use_edns and their meanings. @@ -1275,6 +1467,12 @@ def make_query(qname, rdtype, rdclass=dns.rdataclass.IN, use_edns=None, encoder/decoder. If ``None``, the default IDNA 2003 encoder/decoder is used. + *id*, an ``int`` or ``None``, the desired query id. The default is + ``None``, which generates a random query id. + + *flags*, an ``int``, the desired query flags. The default is + ``dns.flags.RD``. + Returns a ``dns.message.QueryMessage`` """ @@ -1282,8 +1480,8 @@ def make_query(qname, rdtype, rdclass=dns.rdataclass.IN, use_edns=None, qname = dns.name.from_text(qname, idna_codec=idna_codec) rdtype = dns.rdatatype.RdataType.make(rdtype) rdclass = dns.rdataclass.RdataClass.make(rdclass) - m = QueryMessage() - m.flags |= dns.flags.RD + m = QueryMessage(id=id) + m.flags = dns.flags.Flag(flags) m.find_rrset(m.question, qname, rdclass, rdtype, create=True, force_unique=True) # only pass keywords on to use_edns if they have been set to a @@ -1292,20 +1490,14 @@ def make_query(qname, rdtype, rdclass=dns.rdataclass.IN, use_edns=None, kwargs = {} if ednsflags is not None: kwargs['ednsflags'] = ednsflags - if use_edns is None: - use_edns = 0 if payload is not None: kwargs['payload'] = payload - if use_edns is None: - use_edns = 0 if request_payload is not None: kwargs['request_payload'] = request_payload - if use_edns is None: - use_edns = 0 if options is not None: kwargs['options'] = options - if use_edns is None: - use_edns = 0 + if kwargs and use_edns is None: + use_edns = 0 kwargs['edns'] = use_edns m.use_edns(**kwargs) m.want_dnssec(want_dnssec) @@ -1355,3 +1547,12 @@ def make_response(query, recursion_available=False, our_payload=8192, tsig_error, b'', query.keyalgorithm) response.request_mac = query.mac return response + +### BEGIN generated MessageSection constants + +QUESTION = MessageSection.QUESTION +ANSWER = MessageSection.ANSWER +AUTHORITY = MessageSection.AUTHORITY +ADDITIONAL = MessageSection.ADDITIONAL + +### END generated MessageSection constants diff --git a/lib/dns/message.pyi b/lib/dns/message.pyi new file mode 100644 index 00000000..252a4118 --- /dev/null +++ b/lib/dns/message.pyi @@ -0,0 +1,47 @@ +from typing import Optional, Dict, List, Tuple, Union +from . import name, rrset, tsig, rdatatype, entropy, edns, rdataclass, rcode +import hmac + +class Message: + def to_wire(self, origin : Optional[name.Name]=None, max_size=0, **kw) -> bytes: + ... + def find_rrset(self, section : List[rrset.RRset], name : name.Name, rdclass : int, rdtype : int, + covers=rdatatype.NONE, deleting : Optional[int]=None, create=False, + force_unique=False) -> rrset.RRset: + ... + def __init__(self, id : Optional[int] =None) -> None: + self.id : int + self.flags = 0 + self.sections : List[List[rrset.RRset]] = [[], [], [], []] + self.opt : rrset.RRset = None + self.request_payload = 0 + self.keyring = None + self.tsig : rrset.RRset = None + self.request_mac = b'' + self.xfr = False + self.origin = None + self.tsig_ctx = None + self.index : Dict[Tuple[rrset.RRset, name.Name, int, int, Union[int,str], int], rrset.RRset] = {} + + def is_response(self, other : Message) -> bool: + ... + + def set_rcode(self, rcode : rcode.Rcode): + ... + +def from_text(a : str, idna_codec : Optional[name.IDNACodec] = None) -> Message: + ... + +def from_wire(wire, keyring : Optional[Dict[name.Name,bytes]] = None, request_mac = b'', xfr=False, origin=None, + tsig_ctx : Optional[Union[dns.tsig.HMACTSig, dns.tsig.GSSTSig]] = None, multi=False, + question_only=False, one_rr_per_rrset=False, + ignore_trailing=False) -> Message: + ... +def make_response(query : Message, recursion_available=False, our_payload=8192, + fudge=300) -> Message: + ... + +def make_query(qname : Union[name.Name,str], rdtype : Union[str,int], rdclass : Union[int,str] =rdataclass.IN, use_edns : Optional[bool] = None, + want_dnssec=False, ednsflags : Optional[int] = None, payload : Optional[int] = None, + request_payload : Optional[int] = None, options : Optional[List[edns.Option]] = None) -> Message: + ... diff --git a/lib/dns/name.py b/lib/dns/name.py index 529ae7f9..8905d70f 100644 --- a/lib/dns/name.py +++ b/lib/dns/name.py @@ -30,6 +30,7 @@ except ImportError: # pragma: no cover import dns.wire import dns.exception +import dns.immutable # fullcompare() result values @@ -215,9 +216,10 @@ class IDNA2008Codec(IDNACodec): if not have_idna_2008: raise NoIDNA2008 try: + ulabel = idna.ulabel(label) if self.uts_46: - label = idna.uts46_remap(label, False, False) - return _escapify(idna.ulabel(label)) + ulabel = idna.uts46_remap(ulabel, False, self.transitional) + return _escapify(ulabel) except (idna.IDNAError, UnicodeError) as e: raise IDNAException(idna_exception=e) @@ -304,6 +306,7 @@ def _maybe_convert_to_binary(label): raise ValueError # pragma: no cover +@dns.immutable.immutable class Name: """A DNS name. @@ -320,17 +323,9 @@ class Name: """ labels = [_maybe_convert_to_binary(x) for x in labels] - super().__setattr__('labels', tuple(labels)) + self.labels = tuple(labels) _validate_labels(self.labels) - def __setattr__(self, name, value): - # Names are immutable - raise TypeError("object doesn't support attribute assignment") - - def __delattr__(self, name): - # Names are immutable - raise TypeError("object doesn't support attribute deletion") - def __copy__(self): return Name(self.labels) @@ -458,7 +453,7 @@ class Name: Returns a ``bool``. """ - (nr, o, nl) = self.fullcompare(other) + (nr, _, _) = self.fullcompare(other) if nr == NAMERELN_SUBDOMAIN or nr == NAMERELN_EQUAL: return True return False @@ -472,7 +467,7 @@ class Name: Returns a ``bool``. """ - (nr, o, nl) = self.fullcompare(other) + (nr, _, _) = self.fullcompare(other) if nr == NAMERELN_SUPERDOMAIN or nr == NAMERELN_EQUAL: return True return False diff --git a/lib/dns/name.pyi b/lib/dns/name.pyi new file mode 100644 index 00000000..c48d4bd1 --- /dev/null +++ b/lib/dns/name.pyi @@ -0,0 +1,40 @@ +from typing import Optional, Union, Tuple, Iterable, List + +have_idna_2008: bool + +class Name: + def is_subdomain(self, o : Name) -> bool: ... + def is_superdomain(self, o : Name) -> bool: ... + def __init__(self, labels : Iterable[Union[bytes,str]]) -> None: + self.labels : List[bytes] + def is_absolute(self) -> bool: ... + def is_wild(self) -> bool: ... + def fullcompare(self, other) -> Tuple[int,int,int]: ... + def canonicalize(self) -> Name: ... + def __eq__(self, other) -> bool: ... + def __ne__(self, other) -> bool: ... + def __lt__(self, other : Name) -> bool: ... + def __le__(self, other : Name) -> bool: ... + def __ge__(self, other : Name) -> bool: ... + def __gt__(self, other : Name) -> bool: ... + def to_text(self, omit_final_dot=False) -> str: ... + def to_unicode(self, omit_final_dot=False, idna_codec=None) -> str: ... + def to_digestable(self, origin=None) -> bytes: ... + def to_wire(self, file=None, compress=None, origin=None, + canonicalize=False) -> Optional[bytes]: ... + def __add__(self, other : Name) -> Name: ... + def __sub__(self, other : Name) -> Name: ... + def split(self, depth) -> List[Tuple[str,str]]: ... + def concatenate(self, other : Name) -> Name: ... + def relativize(self, origin) -> Name: ... + def derelativize(self, origin) -> Name: ... + def choose_relativity(self, origin : Optional[Name] = None, relativize=True) -> Name: ... + def parent(self) -> Name: ... + +class IDNACodec: + pass + +def from_text(text, origin : Optional[Name] = Name('.'), idna_codec : Optional[IDNACodec] = None) -> Name: + ... + +empty : Name diff --git a/lib/dns/namedict.py b/lib/dns/namedict.py index 4c8f9abd..ec0750ce 100644 --- a/lib/dns/namedict.py +++ b/lib/dns/namedict.py @@ -85,7 +85,7 @@ class NameDict(MutableMapping): return key in self.__store def get_deepest_match(self, name): - """Find the deepest match to *fname* in the dictionary. + """Find the deepest match to *name* in the dictionary. The deepest match is the longest name in the dictionary which is a superdomain of *name*. Note that *superdomain* includes matching diff --git a/lib/dns/node.py b/lib/dns/node.py index b7e21b54..63ce008b 100644 --- a/lib/dns/node.py +++ b/lib/dns/node.py @@ -17,16 +17,69 @@ """DNS nodes. A node is a set of rdatasets.""" +import enum import io +import dns.immutable import dns.rdataset import dns.rdatatype import dns.renderer +_cname_types = { + dns.rdatatype.CNAME, +} + +# "neutral" types can coexist with a CNAME and thus are not "other data" +_neutral_types = { + dns.rdatatype.NSEC, # RFC 4035 section 2.5 + dns.rdatatype.NSEC3, # This is not likely to happen, but not impossible! + dns.rdatatype.KEY, # RFC 4035 section 2.5, RFC 3007 +} + +def _matches_type_or_its_signature(rdtypes, rdtype, covers): + return rdtype in rdtypes or \ + (rdtype == dns.rdatatype.RRSIG and covers in rdtypes) + + +@enum.unique +class NodeKind(enum.Enum): + """Rdatasets in nodes + """ + REGULAR = 0 # a.k.a "other data" + NEUTRAL = 1 + CNAME = 2 + + @classmethod + def classify(cls, rdtype, covers): + if _matches_type_or_its_signature(_cname_types, rdtype, covers): + return NodeKind.CNAME + elif _matches_type_or_its_signature(_neutral_types, rdtype, covers): + return NodeKind.NEUTRAL + else: + return NodeKind.REGULAR + + @classmethod + def classify_rdataset(cls, rdataset): + return cls.classify(rdataset.rdtype, rdataset.covers) + + class Node: - """A Node is a set of rdatasets.""" + """A Node is a set of rdatasets. + + A node is either a CNAME node or an "other data" node. A CNAME + node contains only CNAME, KEY, NSEC, and NSEC3 rdatasets along with their + covering RRSIG rdatasets. An "other data" node contains any + rdataset other than a CNAME or RRSIG(CNAME) rdataset. When + changes are made to a node, the CNAME or "other data" state is + always consistent with the update, i.e. the most recent change + wins. For example, if you have a node which contains a CNAME + rdataset, and then add an MX rdataset to it, then the CNAME + rdataset will be deleted. Likewise if you have a node containing + an MX rdataset and add a CNAME rdataset, the MX rdataset will be + deleted. + """ __slots__ = ['rdatasets'] @@ -78,6 +131,30 @@ class Node: def __iter__(self): return iter(self.rdatasets) + def _append_rdataset(self, rdataset): + """Append rdataset to the node with special handling for CNAME and + other data conditions. + + Specifically, if the rdataset being appended has ``NodeKind.CNAME``, + then all rdatasets other than KEY, NSEC, NSEC3, and their covering + RRSIGs are deleted. If the rdataset being appended has + ``NodeKind.REGULAR`` then CNAME and RRSIG(CNAME) are deleted. + """ + # Make having just one rdataset at the node fast. + if len(self.rdatasets) > 0: + kind = NodeKind.classify_rdataset(rdataset) + if kind == NodeKind.CNAME: + self.rdatasets = [rds for rds in self.rdatasets if + NodeKind.classify_rdataset(rds) != + NodeKind.REGULAR] + elif kind == NodeKind.REGULAR: + self.rdatasets = [rds for rds in self.rdatasets if + NodeKind.classify_rdataset(rds) != + NodeKind.CNAME] + # Otherwise the rdataset is NodeKind.NEUTRAL and we do not need to + # edit self.rdatasets. + self.rdatasets.append(rdataset) + def find_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE, create=False): """Find an rdataset matching the specified properties in the @@ -110,8 +187,8 @@ class Node: return rds if not create: raise KeyError - rds = dns.rdataset.Rdataset(rdclass, rdtype) - self.rdatasets.append(rds) + rds = dns.rdataset.Rdataset(rdclass, rdtype, covers) + self._append_rdataset(rds) return rds def get_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE, @@ -180,6 +257,64 @@ class Node: if not isinstance(replacement, dns.rdataset.Rdataset): raise ValueError('replacement is not an rdataset') + if isinstance(replacement, dns.rrset.RRset): + # RRsets are not good replacements as the match() method + # is not compatible. + replacement = replacement.to_rdataset() self.delete_rdataset(replacement.rdclass, replacement.rdtype, replacement.covers) - self.rdatasets.append(replacement) + self._append_rdataset(replacement) + + def classify(self): + """Classify a node. + + A node which contains a CNAME or RRSIG(CNAME) is a + ``NodeKind.CNAME`` node. + + A node which contains only "neutral" types, i.e. types allowed to + co-exist with a CNAME, is a ``NodeKind.NEUTRAL`` node. The neutral + types are NSEC, NSEC3, KEY, and their associated RRSIGS. An empty node + is also considered neutral. + + A node which contains some rdataset which is not a CNAME, RRSIG(CNAME), + or a neutral type is a a ``NodeKind.REGULAR`` node. Regular nodes are + also commonly referred to as "other data". + """ + for rdataset in self.rdatasets: + kind = NodeKind.classify(rdataset.rdtype, rdataset.covers) + if kind != NodeKind.NEUTRAL: + return kind + return NodeKind.NEUTRAL + + def is_immutable(self): + return False + + +@dns.immutable.immutable +class ImmutableNode(Node): + def __init__(self, node): + super().__init__() + self.rdatasets = tuple( + [dns.rdataset.ImmutableRdataset(rds) for rds in node.rdatasets] + ) + + def find_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE, + create=False): + if create: + raise TypeError("immutable") + return super().find_rdataset(rdclass, rdtype, covers, False) + + def get_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE, + create=False): + if create: + raise TypeError("immutable") + return super().get_rdataset(rdclass, rdtype, covers, False) + + def delete_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE): + raise TypeError("immutable") + + def replace_rdataset(self, replacement): + raise TypeError("immutable") + + def is_immutable(self): + return True diff --git a/lib/dns/node.pyi b/lib/dns/node.pyi new file mode 100644 index 00000000..0997edf9 --- /dev/null +++ b/lib/dns/node.pyi @@ -0,0 +1,17 @@ +from typing import List, Optional, Union +from . import rdataset, rdatatype, name +class Node: + def __init__(self): + self.rdatasets : List[rdataset.Rdataset] + def to_text(self, name : Union[str,name.Name], **kw) -> str: + ... + def find_rdataset(self, rdclass : int, rdtype : int, covers=rdatatype.NONE, + create=False) -> rdataset.Rdataset: + ... + def get_rdataset(self, rdclass : int, rdtype : int, covers=rdatatype.NONE, + create=False) -> Optional[rdataset.Rdataset]: + ... + def delete_rdataset(self, rdclass : int, rdtype : int, covers=rdatatype.NONE): + ... + def replace_rdataset(self, replacement : rdataset.Rdataset) -> None: + ... diff --git a/lib/dns/opcode.py b/lib/dns/opcode.py index 5a76326a..5cf6143c 100644 --- a/lib/dns/opcode.py +++ b/lib/dns/opcode.py @@ -40,8 +40,6 @@ class Opcode(dns.enum.IntEnum): def _unknown_exception_class(cls): return UnknownOpcode -globals().update(Opcode.__members__) - class UnknownOpcode(dns.exception.DNSException): """An DNS opcode is unknown.""" @@ -105,3 +103,13 @@ def is_update(flags): """ return from_flags(flags) == Opcode.UPDATE + +### BEGIN generated Opcode constants + +QUERY = Opcode.QUERY +IQUERY = Opcode.IQUERY +STATUS = Opcode.STATUS +NOTIFY = Opcode.NOTIFY +UPDATE = Opcode.UPDATE + +### END generated Opcode constants diff --git a/lib/dns/query.py b/lib/dns/query.py index 7df565d8..fbf76d8b 100644 --- a/lib/dns/query.py +++ b/lib/dns/query.py @@ -18,9 +18,10 @@ """Talk to a DNS server.""" import contextlib +import enum import errno import os -import select +import selectors import socket import struct import time @@ -35,14 +36,31 @@ import dns.rcode import dns.rdataclass import dns.rdatatype import dns.serial +import dns.xfr try: import requests from requests_toolbelt.adapters.source import SourceAddressAdapter from requests_toolbelt.adapters.host_header_ssl import HostHeaderSSLAdapter - have_doh = True + _have_requests = True except ImportError: # pragma: no cover - have_doh = False + _have_requests = False + +_have_httpx = False +_have_http2 = False +try: + import httpx + _have_httpx = True + try: + # See if http2 support is available. + with httpx.Client(http2=True): + _have_http2 = True + except Exception: + pass +except ImportError: # pragma: no cover + pass + +have_doh = _have_requests or _have_httpx try: import ssl @@ -73,20 +91,15 @@ class BadResponse(dns.exception.FormError): """A DNS query response does not respond to the question asked.""" -class TransferError(dns.exception.DNSException): - """A zone transfer response got a non-zero rcode.""" - - def __init__(self, rcode): - message = 'Zone transfer error: %s' % dns.rcode.to_text(rcode) - super().__init__(message) - self.rcode = rcode - - class NoDOH(dns.exception.DNSException): """DNS over HTTPS (DOH) was requested but the requests module is not available.""" +# for backwards compatibility +TransferError = dns.xfr.TransferError + + def _compute_times(timeout): now = time.time() if timeout is None: @@ -94,91 +107,49 @@ def _compute_times(timeout): else: return (now, now + timeout) -# This module can use either poll() or select() as the "polling backend". -# -# A backend function takes an fd, bools for readability, writablity, and -# error detection, and a timeout. -def _poll_for(fd, readable, writable, error, timeout): - """Poll polling backend.""" - - event_mask = 0 - if readable: - event_mask |= select.POLLIN - if writable: - event_mask |= select.POLLOUT - if error: - event_mask |= select.POLLERR - - pollable = select.poll() - pollable.register(fd, event_mask) - - if timeout: - event_list = pollable.poll(timeout * 1000) - else: - event_list = pollable.poll() - - return bool(event_list) - - -def _select_for(fd, readable, writable, error, timeout): - """Select polling backend.""" - - rset, wset, xset = [], [], [] - - if readable: - rset = [fd] - if writable: - wset = [fd] - if error: - xset = [fd] - - if timeout is None: - (rcount, wcount, xcount) = select.select(rset, wset, xset) - else: - (rcount, wcount, xcount) = select.select(rset, wset, xset, timeout) - - return bool((rcount or wcount or xcount)) - - -def _wait_for(fd, readable, writable, error, expiration): - # Use the selected polling backend to wait for any of the specified +def _wait_for(fd, readable, writable, _, expiration): + # Use the selected selector class to wait for any of the specified # events. An "expiration" absolute time is converted into a relative # timeout. + # + # The unused parameter is 'error', which is always set when + # selecting for read or write, and we have no error-only selects. - done = False - while not done: - if expiration is None: - timeout = None - else: - timeout = expiration - time.time() - if timeout <= 0.0: - raise dns.exception.Timeout - try: - if isinstance(fd, ssl.SSLSocket) and readable and fd.pending() > 0: - return True - if not _polling_backend(fd, readable, writable, error, timeout): - raise dns.exception.Timeout - except OSError as e: # pragma: no cover - if e.args[0] != errno.EINTR: - raise e - done = True + if readable and isinstance(fd, ssl.SSLSocket) and fd.pending() > 0: + return True + sel = _selector_class() + events = 0 + if readable: + events |= selectors.EVENT_READ + if writable: + events |= selectors.EVENT_WRITE + if events: + sel.register(fd, events) + if expiration is None: + timeout = None + else: + timeout = expiration - time.time() + if timeout <= 0.0: + raise dns.exception.Timeout + if not sel.select(timeout): + raise dns.exception.Timeout -def _set_polling_backend(fn): +def _set_selector_class(selector_class): # Internal API. Do not use. - global _polling_backend + global _selector_class - _polling_backend = fn + _selector_class = selector_class -if hasattr(select, 'poll'): +if hasattr(selectors, 'PollSelector'): # Prefer poll() on platforms that support it because it has no # limits on the maximum value of a file descriptor (plus it will # be more efficient for high values). - _polling_backend = _poll_for + _selector_class = selectors.PollSelector else: - _polling_backend = _select_for # pragma: no cover + _selector_class = selectors.SelectSelector # pragma: no cover def _wait_for_readable(s, expiration): @@ -303,8 +274,8 @@ def https(q, where, timeout=None, port=443, source=None, source_port=0, *ignore_trailing*, a ``bool``. If ``True``, ignore trailing junk at end of the received message. - *session*, a ``requests.session.Session``. If provided, the session to use - to send the queries. + *session*, an ``httpx.Client`` or ``requests.session.Session``. If + provided, the client/session to use to send the queries. *path*, a ``str``. If *where* is an IP address, then *path* will be used to construct the URL to send the DNS query to. @@ -320,37 +291,66 @@ def https(q, where, timeout=None, port=443, source=None, source_port=0, """ if not have_doh: - raise NoDOH # pragma: no cover + raise NoDOH('Neither httpx nor requests is available.') # pragma: no cover + + _httpx_ok = _have_httpx wire = q.to_wire() - (af, destination, source) = _destination_and_source(where, port, - source, source_port, - False) + (af, _, source) = _destination_and_source(where, port, source, source_port, + False) transport_adapter = None + transport = None headers = { "accept": "application/dns-message" } - try: - where_af = dns.inet.af_for_address(where) - if where_af == socket.AF_INET: + if af is not None: + if af == socket.AF_INET: url = 'https://{}:{}{}'.format(where, port, path) - elif where_af == socket.AF_INET6: + elif af == socket.AF_INET6: url = 'https://[{}]:{}{}'.format(where, port, path) - except ValueError: - if bootstrap_address is not None: - split_url = urllib.parse.urlsplit(where) - headers['Host'] = split_url.hostname - url = where.replace(split_url.hostname, bootstrap_address) + elif bootstrap_address is not None: + _httpx_ok = False + split_url = urllib.parse.urlsplit(where) + headers['Host'] = split_url.hostname + url = where.replace(split_url.hostname, bootstrap_address) + if _have_requests: transport_adapter = HostHeaderSSLAdapter() - else: - url = where + else: + url = where if source is not None: # set source port and source address - transport_adapter = SourceAddressAdapter(source) + if _have_httpx: + if source_port == 0: + transport = httpx.HTTPTransport(local_address=source[0]) + else: + _httpx_ok = False + if _have_requests: + transport_adapter = SourceAddressAdapter(source) + + if session: + if _have_httpx: + _is_httpx = isinstance(session, httpx.Client) + else: + _is_httpx = False + if _is_httpx and not _httpx_ok: + raise NoDOH('Session is httpx, but httpx cannot be used for ' + 'the requested operation.') + else: + _is_httpx = _httpx_ok + + if not _httpx_ok and not _have_requests: + raise NoDOH('Cannot use httpx for this operation, and ' + 'requests is not available.') with contextlib.ExitStack() as stack: if not session: - session = stack.enter_context(requests.sessions.Session()) + if _is_httpx: + session = stack.enter_context(httpx.Client(http1=True, + http2=_have_http2, + verify=verify, + transport=transport)) + else: + session = stack.enter_context(requests.sessions.Session()) if transport_adapter: session.mount(url, transport_adapter) @@ -362,13 +362,23 @@ def https(q, where, timeout=None, port=443, source=None, source_port=0, "content-type": "application/dns-message", "content-length": str(len(wire)) }) - response = session.post(url, headers=headers, data=wire, - timeout=timeout, verify=verify) + if _is_httpx: + response = session.post(url, headers=headers, content=wire, + timeout=timeout) + else: + response = session.post(url, headers=headers, data=wire, + timeout=timeout, verify=verify) else: wire = base64.urlsafe_b64encode(wire).rstrip(b"=") - response = session.get(url, headers=headers, - timeout=timeout, verify=verify, - params={"dns": wire}) + if _is_httpx: + wire = wire.decode() # httpx does a repr() if we give it bytes + response = session.get(url, headers=headers, + timeout=timeout, + params={"dns": wire}) + else: + response = session.get(url, headers=headers, + timeout=timeout, verify=verify, + params={"dns": wire}) # see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH # status codes @@ -387,6 +397,33 @@ def https(q, where, timeout=None, port=443, source=None, source_port=0, raise BadResponse return r +def _udp_recv(sock, max_size, expiration): + """Reads a datagram from the socket. + A Timeout exception will be raised if the operation is not completed + by the expiration time. + """ + while True: + try: + return sock.recvfrom(max_size) + except BlockingIOError: + _wait_for_readable(sock, expiration) + + +def _udp_send(sock, data, destination, expiration): + """Sends the specified datagram to destination over the socket. + A Timeout exception will be raised if the operation is not completed + by the expiration time. + """ + while True: + try: + if destination: + return sock.sendto(data, destination) + else: + return sock.send(data) + except BlockingIOError: # pragma: no cover + _wait_for_writable(sock, expiration) + + def send_udp(sock, what, destination, expiration=None): """Send a DNS message to the specified UDP socket. @@ -406,9 +443,8 @@ def send_udp(sock, what, destination, expiration=None): if isinstance(what, dns.message.Message): what = what.to_wire() - _wait_for_writable(sock, expiration) sent_time = time.time() - n = sock.sendto(what, destination) + n = _udp_send(sock, what, destination, expiration) return (n, sent_time) @@ -458,9 +494,8 @@ def receive_udp(sock, destination=None, expiration=None, """ wire = b'' - while 1: - _wait_for_readable(sock, expiration) - (wire, from_address) = sock.recvfrom(65535) + while True: + (wire, from_address) = _udp_recv(sock, 65535, expiration) if _matches_destination(sock.family, from_address, destination, ignore_unexpected): break @@ -571,7 +606,7 @@ def udp_with_fallback(q, where, timeout=None, port=53, source=None, if a socket is provided, it must be a nonblocking datagram socket, and the *source* and *source_port* are ignored for the UDP query. - *tcp_sock*, a ``socket.socket``, or ``None``, the socket to use for the + *tcp_sock*, a ``socket.socket``, or ``None``, the connected socket to use for the TCP query. If ``None``, the default, a socket is created. Note that if a socket is provided, it must be a nonblocking connected stream socket, and *where*, *source* and *source_port* are ignored for the TCP @@ -598,18 +633,16 @@ def _net_read(sock, count, expiration): """ s = b'' while count > 0: - _wait_for_readable(sock, expiration) try: n = sock.recv(count) - except ssl.SSLWantReadError: # pragma: no cover - continue + if n == b'': + raise EOFError + count -= len(n) + s += n + except (BlockingIOError, ssl.SSLWantReadError): + _wait_for_readable(sock, expiration) except ssl.SSLWantWriteError: # pragma: no cover _wait_for_writable(sock, expiration) - continue - if n == b'': - raise EOFError - count = count - len(n) - s = s + n return s @@ -621,14 +654,12 @@ def _net_write(sock, data, expiration): current = 0 l = len(data) while current < l: - _wait_for_writable(sock, expiration) try: current += sock.send(data[current:]) + except (BlockingIOError, ssl.SSLWantWriteError): + _wait_for_writable(sock, expiration) except ssl.SSLWantReadError: # pragma: no cover _wait_for_readable(sock, expiration) - continue - except ssl.SSLWantWriteError: # pragma: no cover - continue def send_tcp(sock, what, expiration=None): @@ -652,7 +683,6 @@ def send_tcp(sock, what, expiration=None): # avoid writev() or doing a short write that would get pushed # onto the net tcpmsg = struct.pack("!H", l) + what - _wait_for_writable(sock, expiration) sent_time = time.time() _net_write(sock, tcpmsg, expiration) return (len(tcpmsg), sent_time) @@ -730,7 +760,7 @@ def tcp(q, where, timeout=None, port=53, source=None, source_port=0, *ignore_trailing*, a ``bool``. If ``True``, ignore trailing junk at end of the received message. - *sock*, a ``socket.socket``, or ``None``, the socket to use for the + *sock*, a ``socket.socket``, or ``None``, the connected socket to use for the query. If ``None``, the default, a socket is created. Note that if a socket is provided, it must be a nonblocking connected stream socket, and *where*, *port*, *source* and *source_port* are ignored. @@ -742,11 +772,6 @@ def tcp(q, where, timeout=None, port=53, source=None, source_port=0, (begin_time, expiration) = _compute_times(timeout) with contextlib.ExitStack() as stack: if sock: - # - # Verify that the socket is connected, as if it's not connected, - # it's not writable, and the polling in send_tcp() will time out or - # hang forever. - sock.getpeername() s = sock else: (af, destination, source) = _destination_and_source(where, port, @@ -926,8 +951,7 @@ def xfr(where, zone, rdtype=dns.rdatatype.AXFR, rdclass=dns.rdataclass.IN, _connect(s, destination, expiration) l = len(wire) if use_udp: - _wait_for_writable(s, expiration) - s.send(wire) + _udp_send(s, wire, None, expiration) else: tcpmsg = struct.pack("!H", l) + wire _net_write(s, tcpmsg, expiration) @@ -948,8 +972,7 @@ def xfr(where, zone, rdtype=dns.rdatatype.AXFR, rdclass=dns.rdataclass.IN, (expiration is not None and mexpiration > expiration): mexpiration = expiration if use_udp: - _wait_for_readable(s, expiration) - (wire, from_address) = s.recvfrom(65535) + (wire, _) = _udp_recv(s, 65535, mexpiration) else: ldata = _net_read(s, 2, mexpiration) (l,) = struct.unpack("!H", ldata) @@ -1016,3 +1039,116 @@ def xfr(where, zone, rdtype=dns.rdatatype.AXFR, rdclass=dns.rdataclass.IN, if done and q.keyring and not r.had_tsig: raise dns.exception.FormError("missing TSIG") yield r + + +class UDPMode(enum.IntEnum): + """How should UDP be used in an IXFR from :py:func:`inbound_xfr()`? + + NEVER means "never use UDP; always use TCP" + TRY_FIRST means "try to use UDP but fall back to TCP if needed" + ONLY means "raise ``dns.xfr.UseTCP`` if trying UDP does not succeed" + """ + NEVER = 0 + TRY_FIRST = 1 + ONLY = 2 + + +def inbound_xfr(where, txn_manager, query=None, + port=53, timeout=None, lifetime=None, source=None, + source_port=0, udp_mode=UDPMode.NEVER): + """Conduct an inbound transfer and apply it via a transaction from the + txn_manager. + + *where*, a ``str`` containing an IPv4 or IPv6 address, where + to send the message. + + *txn_manager*, a ``dns.transaction.TransactionManager``, the txn_manager + for this transfer (typically a ``dns.zone.Zone``). + + *query*, the query to send. If not supplied, a default query is + constructed using information from the *txn_manager*. + + *port*, an ``int``, the port send the message to. The default is 53. + + *timeout*, a ``float``, the number of seconds to wait for each + response message. If None, the default, wait forever. + + *lifetime*, a ``float``, the total number of seconds to spend + doing the transfer. If ``None``, the default, then there is no + limit on the time the transfer may take. + + *source*, a ``str`` containing an IPv4 or IPv6 address, specifying + the source address. The default is the wildcard address. + + *source_port*, an ``int``, the port from which to send the message. + The default is 0. + + *udp_mode*, a ``dns.query.UDPMode``, determines how UDP is used + for IXFRs. The default is ``dns.UDPMode.NEVER``, i.e. only use + TCP. Other possibilites are ``dns.UDPMode.TRY_FIRST``, which + means "try UDP but fallback to TCP if needed", and + ``dns.UDPMode.ONLY``, which means "try UDP and raise + ``dns.xfr.UseTCP`` if it does not succeeed. + + Raises on errors. + """ + if query is None: + (query, serial) = dns.xfr.make_query(txn_manager) + else: + serial = dns.xfr.extract_serial_from_query(query) + rdtype = query.question[0].rdtype + is_ixfr = rdtype == dns.rdatatype.IXFR + origin = txn_manager.from_wire_origin() + wire = query.to_wire() + (af, destination, source) = _destination_and_source(where, port, + source, source_port) + (_, expiration) = _compute_times(lifetime) + retry = True + while retry: + retry = False + if is_ixfr and udp_mode != UDPMode.NEVER: + sock_type = socket.SOCK_DGRAM + is_udp = True + else: + sock_type = socket.SOCK_STREAM + is_udp = False + with _make_socket(af, sock_type, source) as s: + _connect(s, destination, expiration) + if is_udp: + _udp_send(s, wire, None, expiration) + else: + tcpmsg = struct.pack("!H", len(wire)) + wire + _net_write(s, tcpmsg, expiration) + with dns.xfr.Inbound(txn_manager, rdtype, serial, + is_udp) as inbound: + done = False + tsig_ctx = None + while not done: + (_, mexpiration) = _compute_times(timeout) + if mexpiration is None or \ + (expiration is not None and mexpiration > expiration): + mexpiration = expiration + if is_udp: + (rwire, _) = _udp_recv(s, 65535, mexpiration) + else: + ldata = _net_read(s, 2, mexpiration) + (l,) = struct.unpack("!H", ldata) + rwire = _net_read(s, l, mexpiration) + r = dns.message.from_wire(rwire, keyring=query.keyring, + request_mac=query.mac, xfr=True, + origin=origin, tsig_ctx=tsig_ctx, + multi=(not is_udp), + one_rr_per_rrset=is_ixfr) + try: + done = inbound.process_message(r) + except dns.xfr.UseTCP: + assert is_udp # should not happen if we used TCP! + if udp_mode == UDPMode.ONLY: + raise + done = True + retry = True + udp_mode = UDPMode.NEVER + continue + tsig_ctx = r.tsig_ctx + if not retry and query.keyring and not r.had_tsig: + raise dns.exception.FormError("missing TSIG") diff --git a/lib/dns/query.pyi b/lib/dns/query.pyi new file mode 100644 index 00000000..a22e229f --- /dev/null +++ b/lib/dns/query.pyi @@ -0,0 +1,64 @@ +from typing import Optional, Union, Dict, Generator, Any +from . import tsig, rdatatype, rdataclass, name, message +from requests.sessions import Session + +import socket + +# If the ssl import works, then +# +# error: Name 'ssl' already defined (by an import) +# +# is expected and can be ignored. +try: + import ssl +except ImportError: + class ssl: # type: ignore + SSLContext : Dict = {} + +have_doh: bool + +def https(q : message.Message, where: str, timeout : Optional[float] = None, + port : Optional[int] = 443, source : Optional[str] = None, + source_port : Optional[int] = 0, + session: Optional[Session] = None, + path : Optional[str] = '/dns-query', post : Optional[bool] = True, + bootstrap_address : Optional[str] = None, + verify : Optional[bool] = True) -> message.Message: + pass + +def tcp(q : message.Message, where : str, timeout : float = None, port=53, + af : Optional[int] = None, source : Optional[str] = None, + source_port : Optional[int] = 0, + one_rr_per_rrset : Optional[bool] = False, + ignore_trailing : Optional[bool] = False, + sock : Optional[socket.socket] = None) -> message.Message: + pass + +def xfr(where : None, zone : Union[name.Name,str], rdtype=rdatatype.AXFR, + rdclass=rdataclass.IN, + timeout : Optional[float] = None, port=53, + keyring : Optional[Dict[name.Name, bytes]] = None, + keyname : Union[str,name.Name]= None, relativize=True, + lifetime : Optional[float] = None, + source : Optional[str] = None, source_port=0, serial=0, + use_udp : Optional[bool] = False, + keyalgorithm=tsig.default_algorithm) \ + -> Generator[Any,Any,message.Message]: + pass + +def udp(q : message.Message, where : str, timeout : Optional[float] = None, + port=53, source : Optional[str] = None, source_port : Optional[int] = 0, + ignore_unexpected : Optional[bool] = False, + one_rr_per_rrset : Optional[bool] = False, + ignore_trailing : Optional[bool] = False, + sock : Optional[socket.socket] = None) -> message.Message: + pass + +def tls(q : message.Message, where : str, timeout : Optional[float] = None, + port=53, source : Optional[str] = None, source_port : Optional[int] = 0, + one_rr_per_rrset : Optional[bool] = False, + ignore_trailing : Optional[bool] = False, + sock : Optional[socket.socket] = None, + ssl_context: Optional[ssl.SSLContext] = None, + server_hostname: Optional[str] = None) -> message.Message: + pass diff --git a/lib/dns/rcode.py b/lib/dns/rcode.py index d9ea0051..49fee695 100644 --- a/lib/dns/rcode.py +++ b/lib/dns/rcode.py @@ -72,7 +72,6 @@ class Rcode(dns.enum.IntEnum): def _unknown_exception_class(cls): return UnknownRcode -globals().update(Rcode.__members__) class UnknownRcode(dns.exception.DNSException): """A DNS rcode is unknown.""" @@ -104,8 +103,6 @@ def from_flags(flags, ednsflags): """ value = (flags & 0x000f) | ((ednsflags >> 20) & 0xff0) - if value < 0 or value > 4095: - raise ValueError('rcode must be >= 0 and <= 4095') return value @@ -139,3 +136,29 @@ def to_text(value, tsig=False): if tsig and value == Rcode.BADVERS: return 'BADSIG' return Rcode.to_text(value) + +### BEGIN generated Rcode constants + +NOERROR = Rcode.NOERROR +FORMERR = Rcode.FORMERR +SERVFAIL = Rcode.SERVFAIL +NXDOMAIN = Rcode.NXDOMAIN +NOTIMP = Rcode.NOTIMP +REFUSED = Rcode.REFUSED +YXDOMAIN = Rcode.YXDOMAIN +YXRRSET = Rcode.YXRRSET +NXRRSET = Rcode.NXRRSET +NOTAUTH = Rcode.NOTAUTH +NOTZONE = Rcode.NOTZONE +DSOTYPENI = Rcode.DSOTYPENI +BADVERS = Rcode.BADVERS +BADSIG = Rcode.BADSIG +BADKEY = Rcode.BADKEY +BADTIME = Rcode.BADTIME +BADMODE = Rcode.BADMODE +BADNAME = Rcode.BADNAME +BADALG = Rcode.BADALG +BADTRUNC = Rcode.BADTRUNC +BADCOOKIE = Rcode.BADCOOKIE + +### END generated Rcode constants diff --git a/lib/dns/rdata.py b/lib/dns/rdata.py index e114fe32..624063e0 100644 --- a/lib/dns/rdata.py +++ b/lib/dns/rdata.py @@ -23,43 +23,68 @@ import binascii import io import inspect import itertools +import random import dns.wire import dns.exception +import dns.immutable +import dns.ipv4 +import dns.ipv6 import dns.name import dns.rdataclass import dns.rdatatype import dns.tokenizer +import dns.ttl _chunksize = 32 +# We currently allow comparisons for rdata with relative names for backwards +# compatibility, but in the future we will not, as these kinds of comparisons +# can lead to subtle bugs if code is not carefully written. +# +# This switch allows the future behavior to be turned on so code can be +# tested with it. +_allow_relative_comparisons = True -def _wordbreak(data, chunksize=_chunksize): + +class NoRelativeRdataOrdering(dns.exception.DNSException): + """An attempt was made to do an ordered comparison of one or more + rdata with relative names. The only reliable way of sorting rdata + is to use non-relativized rdata. + + """ + + +def _wordbreak(data, chunksize=_chunksize, separator=b' '): """Break a binary string into chunks of chunksize characters separated by a space. """ if not chunksize: return data.decode() - return b' '.join([data[i:i + chunksize] - for i - in range(0, len(data), chunksize)]).decode() + return separator.join([data[i:i + chunksize] + for i + in range(0, len(data), chunksize)]).decode() -def _hexify(data, chunksize=_chunksize): +# pylint: disable=unused-argument + +def _hexify(data, chunksize=_chunksize, separator=b' ', **kw): """Convert a binary string into its hex encoding, broken up into chunks - of chunksize characters separated by a space. + of chunksize characters separated by a separator. """ - return _wordbreak(binascii.hexlify(data), chunksize) + return _wordbreak(binascii.hexlify(data), chunksize, separator) -def _base64ify(data, chunksize=_chunksize): +def _base64ify(data, chunksize=_chunksize, separator=b' ', **kw): """Convert a binary string into its base64 encoding, broken up into chunks - of chunksize characters separated by a space. + of chunksize characters separated by a separator. """ - return _wordbreak(base64.b64encode(data), chunksize) + return _wordbreak(base64.b64encode(data), chunksize, separator) + +# pylint: enable=unused-argument __escaped = b'"\\' @@ -92,26 +117,15 @@ def _truncate_bitmap(what): return what[0: i + 1] return what[0:1] -def _constify(o): - """ - Convert mutable types to immutable types. - """ - if isinstance(o, bytearray): - return bytes(o) - if isinstance(o, tuple): - try: - hash(o) - return o - except Exception: - return tuple(_constify(elt) for elt in o) - if isinstance(o, list): - return tuple(_constify(elt) for elt in o) - return o +# So we don't have to edit all the rdata classes... +_constify = dns.immutable.constify + +@dns.immutable.immutable class Rdata: """Base class for all DNS rdata types.""" - __slots__ = ['rdclass', 'rdtype'] + __slots__ = ['rdclass', 'rdtype', 'rdcomment'] def __init__(self, rdclass, rdtype): """Initialize an rdata. @@ -121,16 +135,9 @@ class Rdata: *rdtype*, an ``int`` is the rdatatype of the Rdata. """ - object.__setattr__(self, 'rdclass', rdclass) - object.__setattr__(self, 'rdtype', rdtype) - - def __setattr__(self, name, value): - # Rdatas are immutable - raise TypeError("object doesn't support attribute assignment") - - def __delattr__(self, name): - # Rdatas are immutable - raise TypeError("object doesn't support attribute deletion") + self.rdclass = self._as_rdataclass(rdclass) + self.rdtype = self._as_rdatatype(rdtype) + self.rdcomment = None def _get_all_slots(self): return itertools.chain.from_iterable(getattr(cls, '__slots__', []) @@ -153,6 +160,10 @@ class Rdata: def __setstate__(self, state): for slot, val in state.items(): object.__setattr__(self, slot, val) + if not hasattr(self, 'rdcomment'): + # Pickled rdata from 2.0.x might not have a rdcomment, so add + # it if needed. + object.__setattr__(self, 'rdcomment', None) def covers(self): """Return the type a Rdata covers. @@ -184,10 +195,10 @@ class Rdata: Returns a ``str``. """ - raise NotImplementedError + raise NotImplementedError # pragma: no cover def _to_wire(self, file, compress=None, origin=None, canonicalize=False): - raise NotImplementedError + raise NotImplementedError # pragma: no cover def to_wire(self, file=None, compress=None, origin=None, canonicalize=False): @@ -237,12 +248,42 @@ class Rdata: """Compare an rdata with another rdata of the same rdtype and rdclass. - Return < 0 if self < other in the DNSSEC ordering, 0 if self - == other, and > 0 if self > other. - + For rdata with only absolute names: + Return < 0 if self < other in the DNSSEC ordering, 0 if self + == other, and > 0 if self > other. + For rdata with at least one relative names: + The rdata sorts before any rdata with only absolute names. + When compared with another relative rdata, all names are + made absolute as if they were relative to the root, as the + proper origin is not available. While this creates a stable + ordering, it is NOT guaranteed to be the DNSSEC ordering. + In the future, all ordering comparisons for rdata with + relative names will be disallowed. """ - our = self.to_digestable(dns.name.root) - their = other.to_digestable(dns.name.root) + try: + our = self.to_digestable() + our_relative = False + except dns.name.NeedAbsoluteNameOrOrigin: + if _allow_relative_comparisons: + our = self.to_digestable(dns.name.root) + our_relative = True + try: + their = other.to_digestable() + their_relative = False + except dns.name.NeedAbsoluteNameOrOrigin: + if _allow_relative_comparisons: + their = other.to_digestable(dns.name.root) + their_relative = True + if _allow_relative_comparisons: + if our_relative != their_relative: + # For the purpose of comparison, all rdata with at least one + # relative name is less than an rdata with only absolute names. + if our_relative: + return -1 + else: + return 1 + elif our_relative or their_relative: + raise NoRelativeRdataOrdering if our == their: return 0 elif our > their: @@ -255,14 +296,28 @@ class Rdata: return False if self.rdclass != other.rdclass or self.rdtype != other.rdtype: return False - return self._cmp(other) == 0 + our_relative = False + their_relative = False + try: + our = self.to_digestable() + except dns.name.NeedAbsoluteNameOrOrigin: + our = self.to_digestable(dns.name.root) + our_relative = True + try: + their = other.to_digestable() + except dns.name.NeedAbsoluteNameOrOrigin: + their = other.to_digestable(dns.name.root) + their_relative = True + if our_relative != their_relative: + return False + return our == their def __ne__(self, other): if not isinstance(other, Rdata): return True if self.rdclass != other.rdclass or self.rdtype != other.rdtype: return True - return self._cmp(other) != 0 + return not self.__eq__(other) def __lt__(self, other): if not isinstance(other, Rdata) or \ @@ -295,11 +350,11 @@ class Rdata: @classmethod def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None): - raise NotImplementedError + raise NotImplementedError # pragma: no cover @classmethod - def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None): - raise NotImplementedError + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + raise NotImplementedError # pragma: no cover def replace(self, **kwargs): """ @@ -319,6 +374,8 @@ class Rdata: # Ensure that all of the arguments correspond to valid fields. # Don't allow rdclass or rdtype to be changed, though. for key in kwargs: + if key == 'rdcomment': + continue if key not in parameters: raise AttributeError("'{}' object has no attribute '{}'" .format(self.__class__.__name__, key)) @@ -331,13 +388,149 @@ class Rdata: args = (kwargs.get(key, getattr(self, key)) for key in parameters) # Create, validate, and return the new object. - # - # Note that if we make constructors do validation in the future, - # this validation can go away. rd = self.__class__(*args) - dns.rdata.from_text(rd.rdclass, rd.rdtype, rd.to_text()) + # The comment is not set in the constructor, so give it special + # handling. + rdcomment = kwargs.get('rdcomment', self.rdcomment) + if rdcomment is not None: + object.__setattr__(rd, 'rdcomment', rdcomment) return rd + # Type checking and conversion helpers. These are class methods as + # they don't touch object state and may be useful to others. + + @classmethod + def _as_rdataclass(cls, value): + return dns.rdataclass.RdataClass.make(value) + + @classmethod + def _as_rdatatype(cls, value): + return dns.rdatatype.RdataType.make(value) + + @classmethod + def _as_bytes(cls, value, encode=False, max_length=None, empty_ok=True): + if encode and isinstance(value, str): + value = value.encode() + elif isinstance(value, bytearray): + value = bytes(value) + elif not isinstance(value, bytes): + raise ValueError('not bytes') + if max_length is not None and len(value) > max_length: + raise ValueError('too long') + if not empty_ok and len(value) == 0: + raise ValueError('empty bytes not allowed') + return value + + @classmethod + def _as_name(cls, value): + # Note that proper name conversion (e.g. with origin and IDNA + # awareness) is expected to be done via from_text. This is just + # a simple thing for people invoking the constructor directly. + if isinstance(value, str): + return dns.name.from_text(value) + elif not isinstance(value, dns.name.Name): + raise ValueError('not a name') + return value + + @classmethod + def _as_uint8(cls, value): + if not isinstance(value, int): + raise ValueError('not an integer') + if value < 0 or value > 255: + raise ValueError('not a uint8') + return value + + @classmethod + def _as_uint16(cls, value): + if not isinstance(value, int): + raise ValueError('not an integer') + if value < 0 or value > 65535: + raise ValueError('not a uint16') + return value + + @classmethod + def _as_uint32(cls, value): + if not isinstance(value, int): + raise ValueError('not an integer') + if value < 0 or value > 4294967295: + raise ValueError('not a uint32') + return value + + @classmethod + def _as_uint48(cls, value): + if not isinstance(value, int): + raise ValueError('not an integer') + if value < 0 or value > 281474976710655: + raise ValueError('not a uint48') + return value + + @classmethod + def _as_int(cls, value, low=None, high=None): + if not isinstance(value, int): + raise ValueError('not an integer') + if low is not None and value < low: + raise ValueError('value too small') + if high is not None and value > high: + raise ValueError('value too large') + return value + + @classmethod + def _as_ipv4_address(cls, value): + if isinstance(value, str): + # call to check validity + dns.ipv4.inet_aton(value) + return value + elif isinstance(value, bytes): + return dns.ipv4.inet_ntoa(value) + else: + raise ValueError('not an IPv4 address') + + @classmethod + def _as_ipv6_address(cls, value): + if isinstance(value, str): + # call to check validity + dns.ipv6.inet_aton(value) + return value + elif isinstance(value, bytes): + return dns.ipv6.inet_ntoa(value) + else: + raise ValueError('not an IPv6 address') + + @classmethod + def _as_bool(cls, value): + if isinstance(value, bool): + return value + else: + raise ValueError('not a boolean') + + @classmethod + def _as_ttl(cls, value): + if isinstance(value, int): + return cls._as_int(value, 0, dns.ttl.MAX_TTL) + elif isinstance(value, str): + return dns.ttl.from_text(value) + else: + raise ValueError('not a TTL') + + @classmethod + def _as_tuple(cls, value, as_value): + try: + # For user convenience, if value is a singleton of the list + # element type, wrap it in a tuple. + return (as_value(value),) + except Exception: + # Otherwise, check each element of the iterable *value* + # against *as_value*. + return tuple(as_value(v) for v in value) + + # Processing order + + @classmethod + def _processing_order(cls, iterable): + items = list(iterable) + random.shuffle(items) + return items + class GenericRdata(Rdata): @@ -354,7 +547,7 @@ class GenericRdata(Rdata): object.__setattr__(self, 'data', data) def to_text(self, origin=None, relativize=True, **kw): - return r'\# %d ' % len(self.data) + _hexify(self.data) + return r'\# %d ' % len(self.data) + _hexify(self.data, **kw) @classmethod def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, @@ -364,13 +557,7 @@ class GenericRdata(Rdata): raise dns.exception.SyntaxError( r'generic rdata does not start with \#') length = tok.get_int() - chunks = [] - while 1: - token = tok.get() - if token.is_eol_or_eof(): - break - chunks.append(token.value.encode()) - hex = b''.join(chunks) + hex = tok.concatenate_remaining_identifiers().encode() data = binascii.unhexlify(hex) if len(data) != length: raise dns.exception.SyntaxError( @@ -453,29 +640,45 @@ def from_text(rdclass, rdtype, tok, origin=None, relativize=True, Returns an instance of the chosen Rdata subclass. """ - if isinstance(tok, str): tok = dns.tokenizer.Tokenizer(tok, idna_codec=idna_codec) rdclass = dns.rdataclass.RdataClass.make(rdclass) rdtype = dns.rdatatype.RdataType.make(rdtype) cls = get_rdata_class(rdclass, rdtype) - if cls != GenericRdata: - # peek at first token - token = tok.get() - tok.unget(token) - if token.is_identifier() and \ - token.value == r'\#': - # - # Known type using the generic syntax. Extract the - # wire form from the generic syntax, and then run - # from_wire on it. - # - rdata = GenericRdata.from_text(rdclass, rdtype, tok, origin, - relativize, relativize_to) - return from_wire(rdclass, rdtype, rdata.data, 0, len(rdata.data), - origin) - return cls.from_text(rdclass, rdtype, tok, origin, relativize, - relativize_to) + with dns.exception.ExceptionWrapper(dns.exception.SyntaxError): + rdata = None + if cls != GenericRdata: + # peek at first token + token = tok.get() + tok.unget(token) + if token.is_identifier() and \ + token.value == r'\#': + # + # Known type using the generic syntax. Extract the + # wire form from the generic syntax, and then run + # from_wire on it. + # + grdata = GenericRdata.from_text(rdclass, rdtype, tok, origin, + relativize, relativize_to) + rdata = from_wire(rdclass, rdtype, grdata.data, 0, + len(grdata.data), origin) + # + # If this comparison isn't equal, then there must have been + # compressed names in the wire format, which is an error, + # there being no reasonable context to decompress with. + # + rwire = rdata.to_wire() + if rwire != grdata.data: + raise dns.exception.SyntaxError('compressed data in ' + 'generic syntax form ' + 'of known rdatatype') + if rdata is None: + rdata = cls.from_text(rdclass, rdtype, tok, origin, relativize, + relativize_to) + token = tok.get_eol_as_token() + if token.comment is not None: + object.__setattr__(rdata, 'rdcomment', token.comment) + return rdata def from_wire_parser(rdclass, rdtype, parser, origin=None): @@ -505,7 +708,8 @@ def from_wire_parser(rdclass, rdtype, parser, origin=None): rdclass = dns.rdataclass.RdataClass.make(rdclass) rdtype = dns.rdatatype.RdataType.make(rdtype) cls = get_rdata_class(rdclass, rdtype) - return cls.from_wire_parser(rdclass, rdtype, parser, origin) + with dns.exception.ExceptionWrapper(dns.exception.FormError): + return cls.from_wire_parser(rdclass, rdtype, parser, origin) def from_wire(rdclass, rdtype, wire, current, rdlen, origin=None): @@ -543,7 +747,7 @@ def from_wire(rdclass, rdtype, wire, current, rdlen, origin=None): class RdatatypeExists(dns.exception.DNSException): """DNS rdatatype already exists.""" supp_kwargs = {'rdclass', 'rdtype'} - fmt = "The rdata type with class {rdclass} and rdtype {rdtype} " + \ + fmt = "The rdata type with class {rdclass:d} and rdtype {rdtype:d} " + \ "already exists." diff --git a/lib/dns/rdata.pyi b/lib/dns/rdata.pyi new file mode 100644 index 00000000..f394791f --- /dev/null +++ b/lib/dns/rdata.pyi @@ -0,0 +1,19 @@ +from typing import Dict, Tuple, Any, Optional, BinaryIO +from .name import Name, IDNACodec +class Rdata: + def __init__(self): + self.address : str + def to_wire(self, file : Optional[BinaryIO], compress : Optional[Dict[Name,int]], origin : Optional[Name], canonicalize : Optional[bool]) -> Optional[bytes]: + ... + @classmethod + def from_text(cls, rdclass : int, rdtype : int, tok, origin=None, relativize=True): + ... +_rdata_modules : Dict[Tuple[Any,Rdata],Any] + +def from_text(rdclass : int, rdtype : int, tok : Optional[str], origin : Optional[Name] = None, + relativize : bool = True, relativize_to : Optional[Name] = None, + idna_codec : Optional[IDNACodec] = None): + ... + +def from_wire(rdclass : int, rdtype : int, wire : bytes, current : int, rdlen : int, origin : Optional[Name] = None): + ... diff --git a/lib/dns/rdataclass.py b/lib/dns/rdataclass.py index 7943a95a..41bba693 100644 --- a/lib/dns/rdataclass.py +++ b/lib/dns/rdataclass.py @@ -48,7 +48,6 @@ class RdataClass(dns.enum.IntEnum): def _unknown_exception_class(cls): return UnknownRdataclass -globals().update(RdataClass.__members__) _metaclasses = {RdataClass.NONE, RdataClass.ANY} @@ -100,3 +99,17 @@ def is_metaclass(rdclass): if rdclass in _metaclasses: return True return False + +### BEGIN generated RdataClass constants + +RESERVED0 = RdataClass.RESERVED0 +IN = RdataClass.IN +INTERNET = RdataClass.INTERNET +CH = RdataClass.CH +CHAOS = RdataClass.CHAOS +HS = RdataClass.HS +HESIOD = RdataClass.HESIOD +NONE = RdataClass.NONE +ANY = RdataClass.ANY + +### END generated RdataClass constants diff --git a/lib/dns/rdataset.py b/lib/dns/rdataset.py index 660415e7..e69ee232 100644 --- a/lib/dns/rdataset.py +++ b/lib/dns/rdataset.py @@ -22,6 +22,7 @@ import random import struct import dns.exception +import dns.immutable import dns.rdatatype import dns.rdataclass import dns.rdata @@ -79,15 +80,15 @@ class Rdataset(dns.set.Set): TTL or the specified TTL. If the set contains no rdatas, set the TTL to the specified TTL. - *ttl*, an ``int``. + *ttl*, an ``int`` or ``str``. """ - + ttl = dns.ttl.make(ttl) if len(self) == 0: self.ttl = ttl elif ttl < self.ttl: self.ttl = ttl - def add(self, rd, ttl=None): + def add(self, rd, ttl=None): # pylint: disable=arguments-differ """Add the specified rdata to the rdataset. If the optional *ttl* parameter is supplied, then @@ -176,8 +177,8 @@ class Rdataset(dns.set.Set): return not self.__eq__(other) def to_text(self, name=None, origin=None, relativize=True, - override_rdclass=None, **kw): - """Convert the rdataset into DNS master file format. + override_rdclass=None, want_comments=False, **kw): + """Convert the rdataset into DNS zone file format. See ``dns.name.Name.choose_relativity`` for more information on how *origin* and *relativize* determine the way names @@ -194,6 +195,12 @@ class Rdataset(dns.set.Set): *relativize*, a ``bool``. If ``True``, names will be relativized to *origin*. + + *override_rdclass*, a ``dns.rdataclass.RdataClass`` or ``None``. + If not ``None``, use this class instead of the Rdataset's class. + + *want_comments*, a ``bool``. If ``True``, emit comments for rdata + which have them. The default is ``False``. """ if name is not None: @@ -219,11 +226,16 @@ class Rdataset(dns.set.Set): dns.rdatatype.to_text(self.rdtype))) else: for rd in self: - s.write('%s%s%d %s %s %s\n' % + extra = '' + if want_comments: + if rd.rdcomment: + extra = f' ;{rd.rdcomment}' + s.write('%s%s%d %s %s %s%s\n' % (ntext, pad, self.ttl, dns.rdataclass.to_text(rdclass), dns.rdatatype.to_text(self.rdtype), rd.to_text(origin=origin, relativize=relativize, - **kw))) + **kw), + extra)) # # We strip off the final \n for the caller's convenience in printing # @@ -260,7 +272,7 @@ class Rdataset(dns.set.Set): want_shuffle = False else: rdclass = self.rdclass - file.seek(0, 2) + file.seek(0, io.SEEK_END) if len(self) == 0: name.to_wire(file, compress, origin) stuff = struct.pack("!HHIH", self.rdtype, rdclass, 0, 0) @@ -284,7 +296,7 @@ class Rdataset(dns.set.Set): file.seek(start - 2) stuff = struct.pack("!H", end - start) file.write(stuff) - file.seek(0, 2) + file.seek(0, io.SEEK_END) return len(self) def match(self, rdclass, rdtype, covers): @@ -297,8 +309,86 @@ class Rdataset(dns.set.Set): return True return False + def processing_order(self): + """Return rdatas in a valid processing order according to the type's + specification. For example, MX records are in preference order from + lowest to highest preferences, with items of the same perference + shuffled. -def from_text_list(rdclass, rdtype, ttl, text_rdatas, idna_codec=None): + For types that do not define a processing order, the rdatas are + simply shuffled. + """ + if len(self) == 0: + return [] + else: + return self[0]._processing_order(iter(self)) + + +@dns.immutable.immutable +class ImmutableRdataset(Rdataset): + + """An immutable DNS rdataset.""" + + _clone_class = Rdataset + + def __init__(self, rdataset): + """Create an immutable rdataset from the specified rdataset.""" + + super().__init__(rdataset.rdclass, rdataset.rdtype, rdataset.covers, + rdataset.ttl) + self.items = dns.immutable.Dict(rdataset.items) + + def update_ttl(self, ttl): + raise TypeError('immutable') + + def add(self, rd, ttl=None): + raise TypeError('immutable') + + def union_update(self, other): + raise TypeError('immutable') + + def intersection_update(self, other): + raise TypeError('immutable') + + def update(self, other): + raise TypeError('immutable') + + def __delitem__(self, i): + raise TypeError('immutable') + + def __ior__(self, other): + raise TypeError('immutable') + + def __iand__(self, other): + raise TypeError('immutable') + + def __iadd__(self, other): + raise TypeError('immutable') + + def __isub__(self, other): + raise TypeError('immutable') + + def clear(self): + raise TypeError('immutable') + + def __copy__(self): + return ImmutableRdataset(super().copy()) + + def copy(self): + return ImmutableRdataset(super().copy()) + + def union(self, other): + return ImmutableRdataset(super().union(other)) + + def intersection(self, other): + return ImmutableRdataset(super().intersection(other)) + + def difference(self, other): + return ImmutableRdataset(super().difference(other)) + + +def from_text_list(rdclass, rdtype, ttl, text_rdatas, idna_codec=None, + origin=None, relativize=True, relativize_to=None): """Create an rdataset with the specified class, type, and TTL, and with the specified list of rdatas in text format. @@ -306,6 +396,14 @@ def from_text_list(rdclass, rdtype, ttl, text_rdatas, idna_codec=None): encoder/decoder to use; if ``None``, the default IDNA 2003 encoder/decoder is used. + *origin*, a ``dns.name.Name`` (or ``None``), the + origin to use for relative names. + + *relativize*, a ``bool``. If true, name will be relativized. + + *relativize_to*, a ``dns.name.Name`` (or ``None``), the origin to use + when relativizing names. If not set, the *origin* value will be used. + Returns a ``dns.rdataset.Rdataset`` object. """ @@ -314,7 +412,8 @@ def from_text_list(rdclass, rdtype, ttl, text_rdatas, idna_codec=None): r = Rdataset(rdclass, rdtype) r.update_ttl(ttl) for t in text_rdatas: - rd = dns.rdata.from_text(r.rdclass, r.rdtype, t, idna_codec=idna_codec) + rd = dns.rdata.from_text(r.rdclass, r.rdtype, t, origin, relativize, + relativize_to, idna_codec) r.add(rd) return r diff --git a/lib/dns/rdataset.pyi b/lib/dns/rdataset.pyi new file mode 100644 index 00000000..a7bbf2d4 --- /dev/null +++ b/lib/dns/rdataset.pyi @@ -0,0 +1,58 @@ +from typing import Optional, Dict, List, Union +from io import BytesIO +from . import exception, name, set, rdatatype, rdata, rdataset + +class DifferingCovers(exception.DNSException): + """An attempt was made to add a DNS SIG/RRSIG whose covered type + is not the same as that of the other rdatas in the rdataset.""" + + +class IncompatibleTypes(exception.DNSException): + """An attempt was made to add DNS RR data of an incompatible type.""" + + +class Rdataset(set.Set): + def __init__(self, rdclass, rdtype, covers=rdatatype.NONE, ttl=0): + self.rdclass : int = rdclass + self.rdtype : int = rdtype + self.covers : int = covers + self.ttl : int = ttl + + def update_ttl(self, ttl : int) -> None: + ... + + def add(self, rd : rdata.Rdata, ttl : Optional[int] =None): + ... + + def union_update(self, other : Rdataset): + ... + + def intersection_update(self, other : Rdataset): + ... + + def update(self, other : Rdataset): + ... + + def to_text(self, name : Optional[name.Name] =None, origin : Optional[name.Name] =None, relativize=True, + override_rdclass : Optional[int] =None, **kw) -> bytes: + ... + + def to_wire(self, name : Optional[name.Name], file : BytesIO, compress : Optional[Dict[name.Name, int]] = None, origin : Optional[name.Name] = None, + override_rdclass : Optional[int] = None, want_shuffle=True) -> int: + ... + + def match(self, rdclass : int, rdtype : int, covers : int) -> bool: + ... + + +def from_text_list(rdclass : Union[int,str], rdtype : Union[int,str], ttl : int, text_rdatas : str, idna_codec : Optional[name.IDNACodec] = None) -> rdataset.Rdataset: + ... + +def from_text(rdclass : Union[int,str], rdtype : Union[int,str], ttl : int, *text_rdatas : str) -> rdataset.Rdataset: + ... + +def from_rdata_list(ttl : int, rdatas : List[rdata.Rdata]) -> rdataset.Rdataset: + ... + +def from_rdata(ttl : int, *rdatas : List[rdata.Rdata]) -> rdataset.Rdataset: + ... diff --git a/lib/dns/rdatatype.py b/lib/dns/rdatatype.py index c793d5a0..9499c7b9 100644 --- a/lib/dns/rdatatype.py +++ b/lib/dns/rdatatype.py @@ -72,14 +72,22 @@ class RdataType(dns.enum.IntEnum): NSEC3 = 50 NSEC3PARAM = 51 TLSA = 52 + SMIMEA = 53 HIP = 55 NINFO = 56 CDS = 59 CDNSKEY = 60 OPENPGPKEY = 61 CSYNC = 62 + ZONEMD = 63 + SVCB = 64 + HTTPS = 65 SPF = 99 UNSPEC = 103 + NID = 104 + L32 = 105 + L64 = 106 + LP = 107 EUI48 = 108 EUI64 = 109 TKEY = 249 @@ -92,7 +100,7 @@ class RdataType(dns.enum.IntEnum): URI = 256 CAA = 257 AVC = 258 - AMTRELAY = 259 + AMTRELAY = 260 TA = 32768 DLV = 32769 @@ -115,8 +123,6 @@ class RdataType(dns.enum.IntEnum): _registered_by_text = {} _registered_by_value = {} -globals().update(RdataType.__members__) - _metatypes = {RdataType.OPT} _singletons = {RdataType.SOA, RdataType.NXT, RdataType.DNAME, @@ -219,3 +225,89 @@ def register_type(rdtype, rdtype_text, is_singleton=False): _registered_by_value[rdtype] = rdtype_text if is_singleton: _singletons.add(rdtype) + +### BEGIN generated RdataType constants + +TYPE0 = RdataType.TYPE0 +NONE = RdataType.NONE +A = RdataType.A +NS = RdataType.NS +MD = RdataType.MD +MF = RdataType.MF +CNAME = RdataType.CNAME +SOA = RdataType.SOA +MB = RdataType.MB +MG = RdataType.MG +MR = RdataType.MR +NULL = RdataType.NULL +WKS = RdataType.WKS +PTR = RdataType.PTR +HINFO = RdataType.HINFO +MINFO = RdataType.MINFO +MX = RdataType.MX +TXT = RdataType.TXT +RP = RdataType.RP +AFSDB = RdataType.AFSDB +X25 = RdataType.X25 +ISDN = RdataType.ISDN +RT = RdataType.RT +NSAP = RdataType.NSAP +NSAP_PTR = RdataType.NSAP_PTR +SIG = RdataType.SIG +KEY = RdataType.KEY +PX = RdataType.PX +GPOS = RdataType.GPOS +AAAA = RdataType.AAAA +LOC = RdataType.LOC +NXT = RdataType.NXT +SRV = RdataType.SRV +NAPTR = RdataType.NAPTR +KX = RdataType.KX +CERT = RdataType.CERT +A6 = RdataType.A6 +DNAME = RdataType.DNAME +OPT = RdataType.OPT +APL = RdataType.APL +DS = RdataType.DS +SSHFP = RdataType.SSHFP +IPSECKEY = RdataType.IPSECKEY +RRSIG = RdataType.RRSIG +NSEC = RdataType.NSEC +DNSKEY = RdataType.DNSKEY +DHCID = RdataType.DHCID +NSEC3 = RdataType.NSEC3 +NSEC3PARAM = RdataType.NSEC3PARAM +TLSA = RdataType.TLSA +SMIMEA = RdataType.SMIMEA +HIP = RdataType.HIP +NINFO = RdataType.NINFO +CDS = RdataType.CDS +CDNSKEY = RdataType.CDNSKEY +OPENPGPKEY = RdataType.OPENPGPKEY +CSYNC = RdataType.CSYNC +ZONEMD = RdataType.ZONEMD +SVCB = RdataType.SVCB +HTTPS = RdataType.HTTPS +SPF = RdataType.SPF +UNSPEC = RdataType.UNSPEC +NID = RdataType.NID +L32 = RdataType.L32 +L64 = RdataType.L64 +LP = RdataType.LP +EUI48 = RdataType.EUI48 +EUI64 = RdataType.EUI64 +TKEY = RdataType.TKEY +TSIG = RdataType.TSIG +IXFR = RdataType.IXFR +AXFR = RdataType.AXFR +MAILB = RdataType.MAILB +MAILA = RdataType.MAILA +ANY = RdataType.ANY +URI = RdataType.URI +CAA = RdataType.CAA +AVC = RdataType.AVC +AMTRELAY = RdataType.AMTRELAY +TA = RdataType.TA +DLV = RdataType.DLV + +### END generated RdataType constants diff --git a/lib/dns/rdtypes/ANY/AFSDB.py b/lib/dns/rdtypes/ANY/AFSDB.py index 40878900..d7838e7e 100644 --- a/lib/dns/rdtypes/ANY/AFSDB.py +++ b/lib/dns/rdtypes/ANY/AFSDB.py @@ -16,8 +16,10 @@ # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. import dns.rdtypes.mxbase +import dns.immutable +@dns.immutable.immutable class AFSDB(dns.rdtypes.mxbase.UncompressedDowncasingMX): """AFSDB record""" diff --git a/lib/dns/rdtypes/ANY/AMTRELAY.py b/lib/dns/rdtypes/ANY/AMTRELAY.py index 4e012a27..9f093dee 100644 --- a/lib/dns/rdtypes/ANY/AMTRELAY.py +++ b/lib/dns/rdtypes/ANY/AMTRELAY.py @@ -18,12 +18,19 @@ import struct import dns.exception +import dns.immutable import dns.rdtypes.util class Relay(dns.rdtypes.util.Gateway): name = 'AMTRELAY relay' + @property + def relay(self): + return self.gateway + + +@dns.immutable.immutable class AMTRELAY(dns.rdata.Rdata): """AMTRELAY record""" @@ -35,11 +42,11 @@ class AMTRELAY(dns.rdata.Rdata): def __init__(self, rdclass, rdtype, precedence, discovery_optional, relay_type, relay): super().__init__(rdclass, rdtype) - Relay(relay_type, relay).check() - object.__setattr__(self, 'precedence', precedence) - object.__setattr__(self, 'discovery_optional', discovery_optional) - object.__setattr__(self, 'relay_type', relay_type) - object.__setattr__(self, 'relay', relay) + relay = Relay(relay_type, relay) + self.precedence = self._as_uint8(precedence) + self.discovery_optional = self._as_bool(discovery_optional) + self.relay_type = relay.type + self.relay = relay.relay def to_text(self, origin=None, relativize=True, **kw): relay = Relay(self.relay_type, self.relay).to_text(origin, relativize) @@ -57,10 +64,10 @@ class AMTRELAY(dns.rdata.Rdata): relay_type = tok.get_uint8() if relay_type > 0x7f: raise dns.exception.SyntaxError('expecting an integer <= 127') - relay = Relay(relay_type).from_text(tok, origin, relativize, - relativize_to) + relay = Relay.from_text(relay_type, tok, origin, relativize, + relativize_to) return cls(rdclass, rdtype, precedence, discovery_optional, relay_type, - relay) + relay.relay) def _to_wire(self, file, compress=None, origin=None, canonicalize=False): relay_type = self.relay_type | (self.discovery_optional << 7) @@ -74,6 +81,6 @@ class AMTRELAY(dns.rdata.Rdata): (precedence, relay_type) = parser.get_struct('!BB') discovery_optional = bool(relay_type >> 7) relay_type &= 0x7f - relay = Relay(relay_type).from_wire_parser(parser, origin) + relay = Relay.from_wire_parser(relay_type, parser, origin) return cls(rdclass, rdtype, precedence, discovery_optional, relay_type, - relay) + relay.relay) diff --git a/lib/dns/rdtypes/ANY/AVC.py b/lib/dns/rdtypes/ANY/AVC.py index 1fa5ecfd..11e026d0 100644 --- a/lib/dns/rdtypes/ANY/AVC.py +++ b/lib/dns/rdtypes/ANY/AVC.py @@ -16,8 +16,10 @@ # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. import dns.rdtypes.txtbase +import dns.immutable +@dns.immutable.immutable class AVC(dns.rdtypes.txtbase.TXTBase): """AVC record""" diff --git a/lib/dns/rdtypes/ANY/CAA.py b/lib/dns/rdtypes/ANY/CAA.py index b7edae87..c86b45ea 100644 --- a/lib/dns/rdtypes/ANY/CAA.py +++ b/lib/dns/rdtypes/ANY/CAA.py @@ -18,10 +18,12 @@ import struct import dns.exception +import dns.immutable import dns.rdata import dns.tokenizer +@dns.immutable.immutable class CAA(dns.rdata.Rdata): """CAA (Certification Authority Authorization) record""" @@ -32,9 +34,11 @@ class CAA(dns.rdata.Rdata): def __init__(self, rdclass, rdtype, flags, tag, value): super().__init__(rdclass, rdtype) - object.__setattr__(self, 'flags', flags) - object.__setattr__(self, 'tag', tag) - object.__setattr__(self, 'value', value) + self.flags = self._as_uint8(flags) + self.tag = self._as_bytes(tag, True, 255) + if not tag.isalnum(): + raise ValueError("tag is not alphanumeric") + self.value = self._as_bytes(value) def to_text(self, origin=None, relativize=True, **kw): return '%u %s "%s"' % (self.flags, @@ -46,10 +50,6 @@ class CAA(dns.rdata.Rdata): relativize_to=None): flags = tok.get_uint8() tag = tok.get_string().encode() - if len(tag) > 255: - raise dns.exception.SyntaxError("tag too long") - if not tag.isalnum(): - raise dns.exception.SyntaxError("tag is not alphanumeric") value = tok.get_string().encode() return cls(rdclass, rdtype, flags, tag, value) diff --git a/lib/dns/rdtypes/ANY/CDNSKEY.py b/lib/dns/rdtypes/ANY/CDNSKEY.py index 72253183..14b19417 100644 --- a/lib/dns/rdtypes/ANY/CDNSKEY.py +++ b/lib/dns/rdtypes/ANY/CDNSKEY.py @@ -16,9 +16,13 @@ # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. import dns.rdtypes.dnskeybase +import dns.immutable + +# pylint: disable=unused-import from dns.rdtypes.dnskeybase import SEP, REVOKE, ZONE # noqa: F401 +# pylint: enable=unused-import - +@dns.immutable.immutable class CDNSKEY(dns.rdtypes.dnskeybase.DNSKEYBase): """CDNSKEY record""" diff --git a/lib/dns/rdtypes/ANY/CDS.py b/lib/dns/rdtypes/ANY/CDS.py index a63041dd..094de12b 100644 --- a/lib/dns/rdtypes/ANY/CDS.py +++ b/lib/dns/rdtypes/ANY/CDS.py @@ -16,8 +16,15 @@ # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. import dns.rdtypes.dsbase +import dns.immutable +@dns.immutable.immutable class CDS(dns.rdtypes.dsbase.DSBase): """CDS record""" + + _digest_length_by_type = { + **dns.rdtypes.dsbase.DSBase._digest_length_by_type, + 0: 1, # delete, RFC 8078 Sec. 4 (including Errata ID 5049) + } diff --git a/lib/dns/rdtypes/ANY/CERT.py b/lib/dns/rdtypes/ANY/CERT.py index 62df241c..f35ce3ad 100644 --- a/lib/dns/rdtypes/ANY/CERT.py +++ b/lib/dns/rdtypes/ANY/CERT.py @@ -19,6 +19,7 @@ import struct import base64 import dns.exception +import dns.immutable import dns.dnssec import dns.rdata import dns.tokenizer @@ -27,6 +28,11 @@ _ctype_by_value = { 1: 'PKIX', 2: 'SPKI', 3: 'PGP', + 4: 'IPKIX', + 5: 'ISPKI', + 6: 'IPGP', + 7: 'ACPKIX', + 8: 'IACPKIX', 253: 'URI', 254: 'OID', } @@ -35,6 +41,11 @@ _ctype_by_name = { 'PKIX': 1, 'SPKI': 2, 'PGP': 3, + 'IPKIX': 4, + 'ISPKI': 5, + 'IPGP': 6, + 'ACPKIX': 7, + 'IACPKIX': 8, 'URI': 253, 'OID': 254, } @@ -54,27 +65,28 @@ def _ctype_to_text(what): return str(what) +@dns.immutable.immutable class CERT(dns.rdata.Rdata): """CERT record""" - # see RFC 2538 + # see RFC 4398 __slots__ = ['certificate_type', 'key_tag', 'algorithm', 'certificate'] def __init__(self, rdclass, rdtype, certificate_type, key_tag, algorithm, certificate): super().__init__(rdclass, rdtype) - object.__setattr__(self, 'certificate_type', certificate_type) - object.__setattr__(self, 'key_tag', key_tag) - object.__setattr__(self, 'algorithm', algorithm) - object.__setattr__(self, 'certificate', certificate) + self.certificate_type = self._as_uint16(certificate_type) + self.key_tag = self._as_uint16(key_tag) + self.algorithm = self._as_uint8(algorithm) + self.certificate = self._as_bytes(certificate) def to_text(self, origin=None, relativize=True, **kw): certificate_type = _ctype_to_text(self.certificate_type) return "%s %d %s %s" % (certificate_type, self.key_tag, dns.dnssec.algorithm_to_text(self.algorithm), - dns.rdata._base64ify(self.certificate)) + dns.rdata._base64ify(self.certificate, **kw)) @classmethod def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, @@ -82,8 +94,6 @@ class CERT(dns.rdata.Rdata): certificate_type = _ctype_from_text(tok.get_string()) key_tag = tok.get_uint16() algorithm = dns.dnssec.algorithm_from_text(tok.get_string()) - if algorithm < 0 or algorithm > 255: - raise dns.exception.SyntaxError("bad algorithm type") b64 = tok.concatenate_remaining_identifiers().encode() certificate = base64.b64decode(b64) return cls(rdclass, rdtype, certificate_type, key_tag, diff --git a/lib/dns/rdtypes/ANY/CNAME.py b/lib/dns/rdtypes/ANY/CNAME.py index 11d42aa7..a4fcfa88 100644 --- a/lib/dns/rdtypes/ANY/CNAME.py +++ b/lib/dns/rdtypes/ANY/CNAME.py @@ -16,8 +16,10 @@ # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. import dns.rdtypes.nsbase +import dns.immutable +@dns.immutable.immutable class CNAME(dns.rdtypes.nsbase.NSBase): """CNAME record diff --git a/lib/dns/rdtypes/ANY/CSYNC.py b/lib/dns/rdtypes/ANY/CSYNC.py index 9cba5fad..979028ae 100644 --- a/lib/dns/rdtypes/ANY/CSYNC.py +++ b/lib/dns/rdtypes/ANY/CSYNC.py @@ -18,16 +18,19 @@ import struct import dns.exception +import dns.immutable import dns.rdata import dns.rdatatype import dns.name import dns.rdtypes.util +@dns.immutable.immutable class Bitmap(dns.rdtypes.util.Bitmap): type_name = 'CSYNC' +@dns.immutable.immutable class CSYNC(dns.rdata.Rdata): """CSYNC record""" @@ -36,9 +39,11 @@ class CSYNC(dns.rdata.Rdata): def __init__(self, rdclass, rdtype, serial, flags, windows): super().__init__(rdclass, rdtype) - object.__setattr__(self, 'serial', serial) - object.__setattr__(self, 'flags', flags) - object.__setattr__(self, 'windows', dns.rdata._constify(windows)) + self.serial = self._as_uint32(serial) + self.flags = self._as_uint16(flags) + if not isinstance(windows, Bitmap): + windows = Bitmap(windows) + self.windows = tuple(windows.windows) def to_text(self, origin=None, relativize=True, **kw): text = Bitmap(self.windows).to_text() @@ -49,8 +54,8 @@ class CSYNC(dns.rdata.Rdata): relativize_to=None): serial = tok.get_uint32() flags = tok.get_uint16() - windows = Bitmap().from_text(tok) - return cls(rdclass, rdtype, serial, flags, windows) + bitmap = Bitmap.from_text(tok) + return cls(rdclass, rdtype, serial, flags, bitmap) def _to_wire(self, file, compress=None, origin=None, canonicalize=False): file.write(struct.pack('!IH', self.serial, self.flags)) @@ -59,5 +64,5 @@ class CSYNC(dns.rdata.Rdata): @classmethod def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): (serial, flags) = parser.get_struct("!IH") - windows = Bitmap().from_wire_parser(parser) - return cls(rdclass, rdtype, serial, flags, windows) + bitmap = Bitmap.from_wire_parser(parser) + return cls(rdclass, rdtype, serial, flags, bitmap) diff --git a/lib/dns/rdtypes/ANY/DLV.py b/lib/dns/rdtypes/ANY/DLV.py index 16352125..947dc42e 100644 --- a/lib/dns/rdtypes/ANY/DLV.py +++ b/lib/dns/rdtypes/ANY/DLV.py @@ -16,8 +16,10 @@ # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. import dns.rdtypes.dsbase +import dns.immutable +@dns.immutable.immutable class DLV(dns.rdtypes.dsbase.DSBase): """DLV record""" diff --git a/lib/dns/rdtypes/ANY/DNAME.py b/lib/dns/rdtypes/ANY/DNAME.py index 2000d9b0..f4984b55 100644 --- a/lib/dns/rdtypes/ANY/DNAME.py +++ b/lib/dns/rdtypes/ANY/DNAME.py @@ -16,8 +16,10 @@ # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. import dns.rdtypes.nsbase +import dns.immutable +@dns.immutable.immutable class DNAME(dns.rdtypes.nsbase.UncompressedNS): """DNAME record""" diff --git a/lib/dns/rdtypes/ANY/DNSKEY.py b/lib/dns/rdtypes/ANY/DNSKEY.py index 2ee37988..e69a7c19 100644 --- a/lib/dns/rdtypes/ANY/DNSKEY.py +++ b/lib/dns/rdtypes/ANY/DNSKEY.py @@ -16,9 +16,13 @@ # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. import dns.rdtypes.dnskeybase +import dns.immutable + +# pylint: disable=unused-import from dns.rdtypes.dnskeybase import SEP, REVOKE, ZONE # noqa: F401 +# pylint: enable=unused-import - +@dns.immutable.immutable class DNSKEY(dns.rdtypes.dnskeybase.DNSKEYBase): """DNSKEY record""" diff --git a/lib/dns/rdtypes/ANY/DS.py b/lib/dns/rdtypes/ANY/DS.py index 7d457b22..3f6c3ee8 100644 --- a/lib/dns/rdtypes/ANY/DS.py +++ b/lib/dns/rdtypes/ANY/DS.py @@ -16,8 +16,10 @@ # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. import dns.rdtypes.dsbase +import dns.immutable +@dns.immutable.immutable class DS(dns.rdtypes.dsbase.DSBase): """DS record""" diff --git a/lib/dns/rdtypes/ANY/EUI48.py b/lib/dns/rdtypes/ANY/EUI48.py index b16e81f3..0ab88ad0 100644 --- a/lib/dns/rdtypes/ANY/EUI48.py +++ b/lib/dns/rdtypes/ANY/EUI48.py @@ -17,8 +17,10 @@ # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. import dns.rdtypes.euibase +import dns.immutable +@dns.immutable.immutable class EUI48(dns.rdtypes.euibase.EUIBase): """EUI48 record""" diff --git a/lib/dns/rdtypes/ANY/EUI64.py b/lib/dns/rdtypes/ANY/EUI64.py index cc080760..c42957ef 100644 --- a/lib/dns/rdtypes/ANY/EUI64.py +++ b/lib/dns/rdtypes/ANY/EUI64.py @@ -17,8 +17,10 @@ # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. import dns.rdtypes.euibase +import dns.immutable +@dns.immutable.immutable class EUI64(dns.rdtypes.euibase.EUIBase): """EUI64 record""" diff --git a/lib/dns/rdtypes/ANY/GPOS.py b/lib/dns/rdtypes/ANY/GPOS.py index 03677fd2..29fa8f8b 100644 --- a/lib/dns/rdtypes/ANY/GPOS.py +++ b/lib/dns/rdtypes/ANY/GPOS.py @@ -18,6 +18,7 @@ import struct import dns.exception +import dns.immutable import dns.rdata import dns.tokenizer @@ -41,12 +42,7 @@ def _validate_float_string(what): raise dns.exception.FormError -def _sanitize(value): - if isinstance(value, str): - return value.encode() - return value - - +@dns.immutable.immutable class GPOS(dns.rdata.Rdata): """GPOS record""" @@ -66,15 +62,15 @@ class GPOS(dns.rdata.Rdata): if isinstance(altitude, float) or \ isinstance(altitude, int): altitude = str(altitude) - latitude = _sanitize(latitude) - longitude = _sanitize(longitude) - altitude = _sanitize(altitude) + latitude = self._as_bytes(latitude, True, 255) + longitude = self._as_bytes(longitude, True, 255) + altitude = self._as_bytes(altitude, True, 255) _validate_float_string(latitude) _validate_float_string(longitude) _validate_float_string(altitude) - object.__setattr__(self, 'latitude', latitude) - object.__setattr__(self, 'longitude', longitude) - object.__setattr__(self, 'altitude', altitude) + self.latitude = latitude + self.longitude = longitude + self.altitude = altitude flat = self.float_latitude if flat < -90.0 or flat > 90.0: raise dns.exception.FormError('bad latitude') @@ -93,7 +89,6 @@ class GPOS(dns.rdata.Rdata): latitude = tok.get_string() longitude = tok.get_string() altitude = tok.get_string() - tok.get_eol() return cls(rdclass, rdtype, latitude, longitude, altitude) def _to_wire(self, file, compress=None, origin=None, canonicalize=False): diff --git a/lib/dns/rdtypes/ANY/HINFO.py b/lib/dns/rdtypes/ANY/HINFO.py index 587e0ad1..cd049693 100644 --- a/lib/dns/rdtypes/ANY/HINFO.py +++ b/lib/dns/rdtypes/ANY/HINFO.py @@ -18,10 +18,12 @@ import struct import dns.exception +import dns.immutable import dns.rdata import dns.tokenizer +@dns.immutable.immutable class HINFO(dns.rdata.Rdata): """HINFO record""" @@ -32,14 +34,8 @@ class HINFO(dns.rdata.Rdata): def __init__(self, rdclass, rdtype, cpu, os): super().__init__(rdclass, rdtype) - if isinstance(cpu, str): - object.__setattr__(self, 'cpu', cpu.encode()) - else: - object.__setattr__(self, 'cpu', cpu) - if isinstance(os, str): - object.__setattr__(self, 'os', os.encode()) - else: - object.__setattr__(self, 'os', os) + self.cpu = self._as_bytes(cpu, True, 255) + self.os = self._as_bytes(os, True, 255) def to_text(self, origin=None, relativize=True, **kw): return '"{}" "{}"'.format(dns.rdata._escapify(self.cpu), @@ -50,7 +46,6 @@ class HINFO(dns.rdata.Rdata): relativize_to=None): cpu = tok.get_string(max_length=255) os = tok.get_string(max_length=255) - tok.get_eol() return cls(rdclass, rdtype, cpu, os) def _to_wire(self, file, compress=None, origin=None, canonicalize=False): diff --git a/lib/dns/rdtypes/ANY/HIP.py b/lib/dns/rdtypes/ANY/HIP.py index 1c774bbf..e887359b 100644 --- a/lib/dns/rdtypes/ANY/HIP.py +++ b/lib/dns/rdtypes/ANY/HIP.py @@ -20,10 +20,12 @@ import base64 import binascii import dns.exception +import dns.immutable import dns.rdata import dns.rdatatype +@dns.immutable.immutable class HIP(dns.rdata.Rdata): """HIP record""" @@ -34,10 +36,10 @@ class HIP(dns.rdata.Rdata): def __init__(self, rdclass, rdtype, hit, algorithm, key, servers): super().__init__(rdclass, rdtype) - object.__setattr__(self, 'hit', hit) - object.__setattr__(self, 'algorithm', algorithm) - object.__setattr__(self, 'key', key) - object.__setattr__(self, 'servers', dns.rdata._constify(servers)) + self.hit = self._as_bytes(hit, True, 255) + self.algorithm = self._as_uint8(algorithm) + self.key = self._as_bytes(key, True) + self.servers = self._as_tuple(servers, self._as_name) def to_text(self, origin=None, relativize=True, **kw): hit = binascii.hexlify(self.hit).decode() @@ -55,14 +57,9 @@ class HIP(dns.rdata.Rdata): relativize_to=None): algorithm = tok.get_uint8() hit = binascii.unhexlify(tok.get_string().encode()) - if len(hit) > 255: - raise dns.exception.SyntaxError("HIT too long") key = base64.b64decode(tok.get_string().encode()) servers = [] - while 1: - token = tok.get() - if token.is_eol_or_eof(): - break + for token in tok.get_remaining(): server = tok.as_name(token, origin, relativize, relativize_to) servers.append(server) return cls(rdclass, rdtype, hit, algorithm, key, servers) diff --git a/lib/dns/rdtypes/ANY/ISDN.py b/lib/dns/rdtypes/ANY/ISDN.py index 6834b3c7..b9a49adb 100644 --- a/lib/dns/rdtypes/ANY/ISDN.py +++ b/lib/dns/rdtypes/ANY/ISDN.py @@ -18,10 +18,12 @@ import struct import dns.exception +import dns.immutable import dns.rdata import dns.tokenizer +@dns.immutable.immutable class ISDN(dns.rdata.Rdata): """ISDN record""" @@ -32,14 +34,8 @@ class ISDN(dns.rdata.Rdata): def __init__(self, rdclass, rdtype, address, subaddress): super().__init__(rdclass, rdtype) - if isinstance(address, str): - object.__setattr__(self, 'address', address.encode()) - else: - object.__setattr__(self, 'address', address) - if isinstance(address, str): - object.__setattr__(self, 'subaddress', subaddress.encode()) - else: - object.__setattr__(self, 'subaddress', subaddress) + self.address = self._as_bytes(address, True, 255) + self.subaddress = self._as_bytes(subaddress, True, 255) def to_text(self, origin=None, relativize=True, **kw): if self.subaddress: @@ -52,14 +48,11 @@ class ISDN(dns.rdata.Rdata): def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None): address = tok.get_string() - t = tok.get() - if not t.is_eol_or_eof(): - tok.unget(t) - subaddress = tok.get_string() + tokens = tok.get_remaining(max_tokens=1) + if len(tokens) >= 1: + subaddress = tokens[0].unescape().value else: - tok.unget(t) subaddress = '' - tok.get_eol() return cls(rdclass, rdtype, address, subaddress) def _to_wire(self, file, compress=None, origin=None, canonicalize=False): diff --git a/lib/dns/rdtypes/ANY/L32.py b/lib/dns/rdtypes/ANY/L32.py new file mode 100644 index 00000000..47eff958 --- /dev/null +++ b/lib/dns/rdtypes/ANY/L32.py @@ -0,0 +1,40 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +import struct + +import dns.immutable + + +@dns.immutable.immutable +class L32(dns.rdata.Rdata): + + """L32 record""" + + # see: rfc6742.txt + + __slots__ = ['preference', 'locator32'] + + def __init__(self, rdclass, rdtype, preference, locator32): + super().__init__(rdclass, rdtype) + self.preference = self._as_uint16(preference) + self.locator32 = self._as_ipv4_address(locator32) + + def to_text(self, origin=None, relativize=True, **kw): + return f'{self.preference} {self.locator32}' + + @classmethod + def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, + relativize_to=None): + preference = tok.get_uint16() + nodeid = tok.get_identifier() + return cls(rdclass, rdtype, preference, nodeid) + + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): + file.write(struct.pack('!H', self.preference)) + file.write(dns.ipv4.inet_aton(self.locator32)) + + @classmethod + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + preference = parser.get_uint16() + locator32 = parser.get_remaining() + return cls(rdclass, rdtype, preference, locator32) diff --git a/lib/dns/rdtypes/ANY/L64.py b/lib/dns/rdtypes/ANY/L64.py new file mode 100644 index 00000000..aab36a82 --- /dev/null +++ b/lib/dns/rdtypes/ANY/L64.py @@ -0,0 +1,48 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +import struct + +import dns.immutable +import dns.rdtypes.util + + +@dns.immutable.immutable +class L64(dns.rdata.Rdata): + + """L64 record""" + + # see: rfc6742.txt + + __slots__ = ['preference', 'locator64'] + + def __init__(self, rdclass, rdtype, preference, locator64): + super().__init__(rdclass, rdtype) + self.preference = self._as_uint16(preference) + if isinstance(locator64, bytes): + if len(locator64) != 8: + raise ValueError('invalid locator64') + self.locator64 = dns.rdata._hexify(locator64, 4, b':') + else: + dns.rdtypes.util.parse_formatted_hex(locator64, 4, 4, ':') + self.locator64 = locator64 + + def to_text(self, origin=None, relativize=True, **kw): + return f'{self.preference} {self.locator64}' + + @classmethod + def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, + relativize_to=None): + preference = tok.get_uint16() + locator64 = tok.get_identifier() + return cls(rdclass, rdtype, preference, locator64) + + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): + file.write(struct.pack('!H', self.preference)) + file.write(dns.rdtypes.util.parse_formatted_hex(self.locator64, + 4, 4, ':')) + + @classmethod + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + preference = parser.get_uint16() + locator64 = parser.get_remaining() + return cls(rdclass, rdtype, preference, locator64) diff --git a/lib/dns/rdtypes/ANY/LOC.py b/lib/dns/rdtypes/ANY/LOC.py index eb00a1cd..c9398994 100644 --- a/lib/dns/rdtypes/ANY/LOC.py +++ b/lib/dns/rdtypes/ANY/LOC.py @@ -18,6 +18,7 @@ import struct import dns.exception +import dns.immutable import dns.rdata @@ -34,17 +35,13 @@ _MIN_LATITUDE = 0x80000000 - 90 * 3600000 _MAX_LONGITUDE = 0x80000000 + 180 * 3600000 _MIN_LONGITUDE = 0x80000000 - 180 * 3600000 -# pylint complains about division since we don't have a from __future__ for -# it, but we don't care about python 2 warnings, so turn them off. -# -# pylint: disable=old-division def _exponent_of(what, desc): if what == 0: return 0 exp = None for (i, pow) in enumerate(_pows): - if what // pow == 0: + if what < pow: exp = i - 1 break if exp is None or exp < 0: @@ -58,7 +55,7 @@ def _float_to_tuple(what): what *= -1 else: sign = 1 - what = round(what * 3600000) # pylint: disable=round-builtin + what = round(what * 3600000) degrees = int(what // 3600000) what -= degrees * 3600000 minutes = int(what // 60000) @@ -94,6 +91,20 @@ def _decode_size(what, desc): return base * pow(10, exponent) +def _check_coordinate_list(value, low, high): + if value[0] < low or value[0] > high: + raise ValueError(f'not in range [{low}, {high}]') + if value[1] < 0 or value[1] > 59: + raise ValueError('bad minutes value') + if value[2] < 0 or value[2] > 59: + raise ValueError('bad seconds value') + if value[3] < 0 or value[3] > 999: + raise ValueError('bad milliseconds value') + if value[4] != 1 and value[4] != -1: + raise ValueError('bad hemisphere value') + + +@dns.immutable.immutable class LOC(dns.rdata.Rdata): """LOC record""" @@ -119,16 +130,18 @@ class LOC(dns.rdata.Rdata): latitude = float(latitude) if isinstance(latitude, float): latitude = _float_to_tuple(latitude) - object.__setattr__(self, 'latitude', dns.rdata._constify(latitude)) + _check_coordinate_list(latitude, -90, 90) + self.latitude = tuple(latitude) if isinstance(longitude, int): longitude = float(longitude) if isinstance(longitude, float): longitude = _float_to_tuple(longitude) - object.__setattr__(self, 'longitude', dns.rdata._constify(longitude)) - object.__setattr__(self, 'altitude', float(altitude)) - object.__setattr__(self, 'size', float(size)) - object.__setattr__(self, 'horizontal_precision', float(hprec)) - object.__setattr__(self, 'vertical_precision', float(vprec)) + _check_coordinate_list(longitude, -180, 180) + self.longitude = tuple(longitude) + self.altitude = float(altitude) + self.size = float(size) + self.horizontal_precision = float(hprec) + self.vertical_precision = float(vprec) def to_text(self, origin=None, relativize=True, **kw): if self.latitude[4] > 0: @@ -167,13 +180,9 @@ class LOC(dns.rdata.Rdata): vprec = _default_vprec latitude[0] = tok.get_int() - if latitude[0] > 90: - raise dns.exception.SyntaxError('latitude >= 90') t = tok.get_string() if t.isdigit(): latitude[1] = int(t) - if latitude[1] >= 60: - raise dns.exception.SyntaxError('latitude minutes >= 60') t = tok.get_string() if '.' in t: (seconds, milliseconds) = t.split('.') @@ -181,8 +190,6 @@ class LOC(dns.rdata.Rdata): raise dns.exception.SyntaxError( 'bad latitude seconds value') latitude[2] = int(seconds) - if latitude[2] >= 60: - raise dns.exception.SyntaxError('latitude seconds >= 60') l = len(milliseconds) if l == 0 or l > 3 or not milliseconds.isdigit(): raise dns.exception.SyntaxError( @@ -204,13 +211,9 @@ class LOC(dns.rdata.Rdata): raise dns.exception.SyntaxError('bad latitude hemisphere value') longitude[0] = tok.get_int() - if longitude[0] > 180: - raise dns.exception.SyntaxError('longitude > 180') t = tok.get_string() if t.isdigit(): longitude[1] = int(t) - if longitude[1] >= 60: - raise dns.exception.SyntaxError('longitude minutes >= 60') t = tok.get_string() if '.' in t: (seconds, milliseconds) = t.split('.') @@ -218,8 +221,6 @@ class LOC(dns.rdata.Rdata): raise dns.exception.SyntaxError( 'bad longitude seconds value') longitude[2] = int(seconds) - if longitude[2] >= 60: - raise dns.exception.SyntaxError('longitude seconds >= 60') l = len(milliseconds) if l == 0 or l > 3 or not milliseconds.isdigit(): raise dns.exception.SyntaxError( @@ -245,25 +246,22 @@ class LOC(dns.rdata.Rdata): t = t[0: -1] altitude = float(t) * 100.0 # m -> cm - token = tok.get().unescape() - if not token.is_eol_or_eof(): - value = token.value + tokens = tok.get_remaining(max_tokens=3) + if len(tokens) >= 1: + value = tokens[0].unescape().value if value[-1] == 'm': value = value[0: -1] size = float(value) * 100.0 # m -> cm - token = tok.get().unescape() - if not token.is_eol_or_eof(): - value = token.value + if len(tokens) >= 2: + value = tokens[1].unescape().value if value[-1] == 'm': value = value[0: -1] hprec = float(value) * 100.0 # m -> cm - token = tok.get().unescape() - if not token.is_eol_or_eof(): - value = token.value + if len(tokens) >= 3: + value = tokens[2].unescape().value if value[-1] == 'm': value = value[0: -1] vprec = float(value) * 100.0 # m -> cm - tok.get_eol() # Try encoding these now so we raise if they are bad _encode_size(size, "size") @@ -296,6 +294,8 @@ class LOC(dns.rdata.Rdata): def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): (version, size, hprec, vprec, latitude, longitude, altitude) = \ parser.get_struct("!BBBBIII") + if version != 0: + raise dns.exception.FormError("LOC version not zero") if latitude < _MIN_LATITUDE or latitude > _MAX_LATITUDE: raise dns.exception.FormError("bad latitude") if latitude > 0x80000000: diff --git a/lib/dns/rdtypes/ANY/LP.py b/lib/dns/rdtypes/ANY/LP.py new file mode 100644 index 00000000..b6a2e36c --- /dev/null +++ b/lib/dns/rdtypes/ANY/LP.py @@ -0,0 +1,41 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +import struct + +import dns.immutable + + +@dns.immutable.immutable +class LP(dns.rdata.Rdata): + + """LP record""" + + # see: rfc6742.txt + + __slots__ = ['preference', 'fqdn'] + + def __init__(self, rdclass, rdtype, preference, fqdn): + super().__init__(rdclass, rdtype) + self.preference = self._as_uint16(preference) + self.fqdn = self._as_name(fqdn) + + def to_text(self, origin=None, relativize=True, **kw): + fqdn = self.fqdn.choose_relativity(origin, relativize) + return '%d %s' % (self.preference, fqdn) + + @classmethod + def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, + relativize_to=None): + preference = tok.get_uint16() + fqdn = tok.get_name(origin, relativize, relativize_to) + return cls(rdclass, rdtype, preference, fqdn) + + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): + file.write(struct.pack('!H', self.preference)) + self.fqdn.to_wire(file, compress, origin, canonicalize) + + @classmethod + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + preference = parser.get_uint16() + fqdn = parser.get_name(origin) + return cls(rdclass, rdtype, preference, fqdn) diff --git a/lib/dns/rdtypes/ANY/MX.py b/lib/dns/rdtypes/ANY/MX.py index 0a06494f..a697ea45 100644 --- a/lib/dns/rdtypes/ANY/MX.py +++ b/lib/dns/rdtypes/ANY/MX.py @@ -16,8 +16,10 @@ # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. import dns.rdtypes.mxbase +import dns.immutable +@dns.immutable.immutable class MX(dns.rdtypes.mxbase.MXBase): """MX record""" diff --git a/lib/dns/rdtypes/ANY/NID.py b/lib/dns/rdtypes/ANY/NID.py new file mode 100644 index 00000000..74951bbf --- /dev/null +++ b/lib/dns/rdtypes/ANY/NID.py @@ -0,0 +1,47 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +import struct + +import dns.immutable +import dns.rdtypes.util + + +@dns.immutable.immutable +class NID(dns.rdata.Rdata): + + """NID record""" + + # see: rfc6742.txt + + __slots__ = ['preference', 'nodeid'] + + def __init__(self, rdclass, rdtype, preference, nodeid): + super().__init__(rdclass, rdtype) + self.preference = self._as_uint16(preference) + if isinstance(nodeid, bytes): + if len(nodeid) != 8: + raise ValueError('invalid nodeid') + self.nodeid = dns.rdata._hexify(nodeid, 4, b':') + else: + dns.rdtypes.util.parse_formatted_hex(nodeid, 4, 4, ':') + self.nodeid = nodeid + + def to_text(self, origin=None, relativize=True, **kw): + return f'{self.preference} {self.nodeid}' + + @classmethod + def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, + relativize_to=None): + preference = tok.get_uint16() + nodeid = tok.get_identifier() + return cls(rdclass, rdtype, preference, nodeid) + + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): + file.write(struct.pack('!H', self.preference)) + file.write(dns.rdtypes.util.parse_formatted_hex(self.nodeid, 4, 4, ':')) + + @classmethod + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + preference = parser.get_uint16() + nodeid = parser.get_remaining() + return cls(rdclass, rdtype, preference, nodeid) diff --git a/lib/dns/rdtypes/ANY/NINFO.py b/lib/dns/rdtypes/ANY/NINFO.py index d4c8572c..d53e9676 100644 --- a/lib/dns/rdtypes/ANY/NINFO.py +++ b/lib/dns/rdtypes/ANY/NINFO.py @@ -16,8 +16,10 @@ # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. import dns.rdtypes.txtbase +import dns.immutable +@dns.immutable.immutable class NINFO(dns.rdtypes.txtbase.TXTBase): """NINFO record""" diff --git a/lib/dns/rdtypes/ANY/NS.py b/lib/dns/rdtypes/ANY/NS.py index f9fcf637..a0cc232a 100644 --- a/lib/dns/rdtypes/ANY/NS.py +++ b/lib/dns/rdtypes/ANY/NS.py @@ -16,8 +16,10 @@ # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. import dns.rdtypes.nsbase +import dns.immutable +@dns.immutable.immutable class NS(dns.rdtypes.nsbase.NSBase): """NS record""" diff --git a/lib/dns/rdtypes/ANY/NSEC.py b/lib/dns/rdtypes/ANY/NSEC.py index 626d3399..dc31f4c4 100644 --- a/lib/dns/rdtypes/ANY/NSEC.py +++ b/lib/dns/rdtypes/ANY/NSEC.py @@ -16,16 +16,19 @@ # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. import dns.exception +import dns.immutable import dns.rdata import dns.rdatatype import dns.name import dns.rdtypes.util +@dns.immutable.immutable class Bitmap(dns.rdtypes.util.Bitmap): type_name = 'NSEC' +@dns.immutable.immutable class NSEC(dns.rdata.Rdata): """NSEC record""" @@ -34,8 +37,10 @@ class NSEC(dns.rdata.Rdata): def __init__(self, rdclass, rdtype, next, windows): super().__init__(rdclass, rdtype) - object.__setattr__(self, 'next', next) - object.__setattr__(self, 'windows', dns.rdata._constify(windows)) + self.next = self._as_name(next) + if not isinstance(windows, Bitmap): + windows = Bitmap(windows) + self.windows = tuple(windows.windows) def to_text(self, origin=None, relativize=True, **kw): next = self.next.choose_relativity(origin, relativize) @@ -46,15 +51,17 @@ class NSEC(dns.rdata.Rdata): def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None): next = tok.get_name(origin, relativize, relativize_to) - windows = Bitmap().from_text(tok) + windows = Bitmap.from_text(tok) return cls(rdclass, rdtype, next, windows) def _to_wire(self, file, compress=None, origin=None, canonicalize=False): + # Note that NSEC downcasing, originally mandated by RFC 4034 + # section 6.2 was removed by RFC 6840 section 5.1. self.next.to_wire(file, None, origin, False) Bitmap(self.windows).to_wire(file) @classmethod def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): next = parser.get_name(origin) - windows = Bitmap().from_wire_parser(parser) - return cls(rdclass, rdtype, next, windows) + bitmap = Bitmap.from_wire_parser(parser) + return cls(rdclass, rdtype, next, bitmap) diff --git a/lib/dns/rdtypes/ANY/NSEC3.py b/lib/dns/rdtypes/ANY/NSEC3.py index 91471f0f..14242bda 100644 --- a/lib/dns/rdtypes/ANY/NSEC3.py +++ b/lib/dns/rdtypes/ANY/NSEC3.py @@ -20,6 +20,7 @@ import binascii import struct import dns.exception +import dns.immutable import dns.rdata import dns.rdatatype import dns.rdtypes.util @@ -37,10 +38,12 @@ SHA1 = 1 OPTOUT = 1 +@dns.immutable.immutable class Bitmap(dns.rdtypes.util.Bitmap): type_name = 'NSEC3' +@dns.immutable.immutable class NSEC3(dns.rdata.Rdata): """NSEC3 record""" @@ -50,15 +53,14 @@ class NSEC3(dns.rdata.Rdata): def __init__(self, rdclass, rdtype, algorithm, flags, iterations, salt, next, windows): super().__init__(rdclass, rdtype) - object.__setattr__(self, 'algorithm', algorithm) - object.__setattr__(self, 'flags', flags) - object.__setattr__(self, 'iterations', iterations) - if isinstance(salt, str): - object.__setattr__(self, 'salt', salt.encode()) - else: - object.__setattr__(self, 'salt', salt) - object.__setattr__(self, 'next', next) - object.__setattr__(self, 'windows', dns.rdata._constify(windows)) + self.algorithm = self._as_uint8(algorithm) + self.flags = self._as_uint8(flags) + self.iterations = self._as_uint16(iterations) + self.salt = self._as_bytes(salt, True, 255) + self.next = self._as_bytes(next, True, 255) + if not isinstance(windows, Bitmap): + windows = Bitmap(windows) + self.windows = tuple(windows.windows) def to_text(self, origin=None, relativize=True, **kw): next = base64.b32encode(self.next).translate( @@ -85,9 +87,9 @@ class NSEC3(dns.rdata.Rdata): next = tok.get_string().encode( 'ascii').upper().translate(b32_hex_to_normal) next = base64.b32decode(next) - windows = Bitmap().from_text(tok) + bitmap = Bitmap.from_text(tok) return cls(rdclass, rdtype, algorithm, flags, iterations, salt, next, - windows) + bitmap) def _to_wire(self, file, compress=None, origin=None, canonicalize=False): l = len(self.salt) @@ -104,6 +106,6 @@ class NSEC3(dns.rdata.Rdata): (algorithm, flags, iterations) = parser.get_struct('!BBH') salt = parser.get_counted_bytes() next = parser.get_counted_bytes() - windows = Bitmap().from_wire_parser(parser) + bitmap = Bitmap.from_wire_parser(parser) return cls(rdclass, rdtype, algorithm, flags, iterations, salt, next, - windows) + bitmap) diff --git a/lib/dns/rdtypes/ANY/NSEC3PARAM.py b/lib/dns/rdtypes/ANY/NSEC3PARAM.py index 8ac76271..299bf6ed 100644 --- a/lib/dns/rdtypes/ANY/NSEC3PARAM.py +++ b/lib/dns/rdtypes/ANY/NSEC3PARAM.py @@ -19,9 +19,11 @@ import struct import binascii import dns.exception +import dns.immutable import dns.rdata +@dns.immutable.immutable class NSEC3PARAM(dns.rdata.Rdata): """NSEC3PARAM record""" @@ -30,13 +32,10 @@ class NSEC3PARAM(dns.rdata.Rdata): def __init__(self, rdclass, rdtype, algorithm, flags, iterations, salt): super().__init__(rdclass, rdtype) - object.__setattr__(self, 'algorithm', algorithm) - object.__setattr__(self, 'flags', flags) - object.__setattr__(self, 'iterations', iterations) - if isinstance(salt, str): - object.__setattr__(self, 'salt', salt.encode()) - else: - object.__setattr__(self, 'salt', salt) + self.algorithm = self._as_uint8(algorithm) + self.flags = self._as_uint8(flags) + self.iterations = self._as_uint16(iterations) + self.salt = self._as_bytes(salt, True, 255) def to_text(self, origin=None, relativize=True, **kw): if self.salt == b'': @@ -57,7 +56,6 @@ class NSEC3PARAM(dns.rdata.Rdata): salt = '' else: salt = binascii.unhexlify(salt.encode()) - tok.get_eol() return cls(rdclass, rdtype, algorithm, flags, iterations, salt) def _to_wire(self, file, compress=None, origin=None, canonicalize=False): diff --git a/lib/dns/rdtypes/ANY/OPENPGPKEY.py b/lib/dns/rdtypes/ANY/OPENPGPKEY.py index f632132e..dcfa028d 100644 --- a/lib/dns/rdtypes/ANY/OPENPGPKEY.py +++ b/lib/dns/rdtypes/ANY/OPENPGPKEY.py @@ -18,9 +18,11 @@ import base64 import dns.exception +import dns.immutable import dns.rdata import dns.tokenizer +@dns.immutable.immutable class OPENPGPKEY(dns.rdata.Rdata): """OPENPGPKEY record""" @@ -29,10 +31,10 @@ class OPENPGPKEY(dns.rdata.Rdata): def __init__(self, rdclass, rdtype, key): super().__init__(rdclass, rdtype) - object.__setattr__(self, 'key', key) + self.key = self._as_bytes(key) def to_text(self, origin=None, relativize=True, **kw): - return dns.rdata._base64ify(self.key) + return dns.rdata._base64ify(self.key, chunksize=None, **kw) @classmethod def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, diff --git a/lib/dns/rdtypes/ANY/OPT.py b/lib/dns/rdtypes/ANY/OPT.py index c48aa12f..69b8fe75 100644 --- a/lib/dns/rdtypes/ANY/OPT.py +++ b/lib/dns/rdtypes/ANY/OPT.py @@ -18,10 +18,15 @@ import struct import dns.edns +import dns.immutable import dns.exception import dns.rdata +# We don't implement from_text, and that's ok. +# pylint: disable=abstract-method + +@dns.immutable.immutable class OPT(dns.rdata.Rdata): """OPT record""" @@ -40,7 +45,11 @@ class OPT(dns.rdata.Rdata): """ super().__init__(rdclass, rdtype) - object.__setattr__(self, 'options', dns.rdata._constify(options)) + def as_option(option): + if not isinstance(option, dns.edns.Option): + raise ValueError('option is not a dns.edns.option') + return option + self.options = self._as_tuple(options, as_option) def _to_wire(self, file, compress=None, origin=None, canonicalize=False): for opt in self.options: diff --git a/lib/dns/rdtypes/ANY/PTR.py b/lib/dns/rdtypes/ANY/PTR.py index 20cd5076..265bed03 100644 --- a/lib/dns/rdtypes/ANY/PTR.py +++ b/lib/dns/rdtypes/ANY/PTR.py @@ -16,8 +16,10 @@ # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. import dns.rdtypes.nsbase +import dns.immutable +@dns.immutable.immutable class PTR(dns.rdtypes.nsbase.NSBase): """PTR record""" diff --git a/lib/dns/rdtypes/ANY/RP.py b/lib/dns/rdtypes/ANY/RP.py index 7446de6d..a4e2297d 100644 --- a/lib/dns/rdtypes/ANY/RP.py +++ b/lib/dns/rdtypes/ANY/RP.py @@ -16,10 +16,12 @@ # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. import dns.exception +import dns.immutable import dns.rdata import dns.name +@dns.immutable.immutable class RP(dns.rdata.Rdata): """RP record""" @@ -30,8 +32,8 @@ class RP(dns.rdata.Rdata): def __init__(self, rdclass, rdtype, mbox, txt): super().__init__(rdclass, rdtype) - object.__setattr__(self, 'mbox', mbox) - object.__setattr__(self, 'txt', txt) + self.mbox = self._as_name(mbox) + self.txt = self._as_name(txt) def to_text(self, origin=None, relativize=True, **kw): mbox = self.mbox.choose_relativity(origin, relativize) @@ -43,7 +45,6 @@ class RP(dns.rdata.Rdata): relativize_to=None): mbox = tok.get_name(origin, relativize, relativize_to) txt = tok.get_name(origin, relativize, relativize_to) - tok.get_eol() return cls(rdclass, rdtype, mbox, txt) def _to_wire(self, file, compress=None, origin=None, canonicalize=False): diff --git a/lib/dns/rdtypes/ANY/RRSIG.py b/lib/dns/rdtypes/ANY/RRSIG.py index 2077d905..d050ccc6 100644 --- a/lib/dns/rdtypes/ANY/RRSIG.py +++ b/lib/dns/rdtypes/ANY/RRSIG.py @@ -21,6 +21,7 @@ import struct import time import dns.dnssec +import dns.immutable import dns.exception import dns.rdata import dns.rdatatype @@ -50,6 +51,7 @@ def posixtime_to_sigtime(what): return time.strftime('%Y%m%d%H%M%S', time.gmtime(what)) +@dns.immutable.immutable class RRSIG(dns.rdata.Rdata): """RRSIG record""" @@ -62,15 +64,15 @@ class RRSIG(dns.rdata.Rdata): original_ttl, expiration, inception, key_tag, signer, signature): super().__init__(rdclass, rdtype) - object.__setattr__(self, 'type_covered', type_covered) - object.__setattr__(self, 'algorithm', algorithm) - object.__setattr__(self, 'labels', labels) - object.__setattr__(self, 'original_ttl', original_ttl) - object.__setattr__(self, 'expiration', expiration) - object.__setattr__(self, 'inception', inception) - object.__setattr__(self, 'key_tag', key_tag) - object.__setattr__(self, 'signer', signer) - object.__setattr__(self, 'signature', signature) + self.type_covered = self._as_rdatatype(type_covered) + self.algorithm = dns.dnssec.Algorithm.make(algorithm) + self.labels = self._as_uint8(labels) + self.original_ttl = self._as_ttl(original_ttl) + self.expiration = self._as_uint32(expiration) + self.inception = self._as_uint32(inception) + self.key_tag = self._as_uint16(key_tag) + self.signer = self._as_name(signer) + self.signature = self._as_bytes(signature) def covers(self): return self.type_covered @@ -85,7 +87,7 @@ class RRSIG(dns.rdata.Rdata): posixtime_to_sigtime(self.inception), self.key_tag, self.signer.choose_relativity(origin, relativize), - dns.rdata._base64ify(self.signature) + dns.rdata._base64ify(self.signature, **kw) ) @classmethod diff --git a/lib/dns/rdtypes/ANY/RT.py b/lib/dns/rdtypes/ANY/RT.py index d0feb79e..8d9c6bd0 100644 --- a/lib/dns/rdtypes/ANY/RT.py +++ b/lib/dns/rdtypes/ANY/RT.py @@ -16,8 +16,10 @@ # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. import dns.rdtypes.mxbase +import dns.immutable +@dns.immutable.immutable class RT(dns.rdtypes.mxbase.UncompressedDowncasingMX): """RT record""" diff --git a/lib/dns/rdtypes/ANY/SMIMEA.py b/lib/dns/rdtypes/ANY/SMIMEA.py new file mode 100644 index 00000000..55d87bf8 --- /dev/null +++ b/lib/dns/rdtypes/ANY/SMIMEA.py @@ -0,0 +1,9 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +import dns.immutable +import dns.rdtypes.tlsabase + + +@dns.immutable.immutable +class SMIMEA(dns.rdtypes.tlsabase.TLSABase): + """SMIMEA record""" diff --git a/lib/dns/rdtypes/ANY/SOA.py b/lib/dns/rdtypes/ANY/SOA.py index e93274ed..7ce88652 100644 --- a/lib/dns/rdtypes/ANY/SOA.py +++ b/lib/dns/rdtypes/ANY/SOA.py @@ -18,10 +18,12 @@ import struct import dns.exception +import dns.immutable import dns.rdata import dns.name +@dns.immutable.immutable class SOA(dns.rdata.Rdata): """SOA record""" @@ -34,13 +36,13 @@ class SOA(dns.rdata.Rdata): def __init__(self, rdclass, rdtype, mname, rname, serial, refresh, retry, expire, minimum): super().__init__(rdclass, rdtype) - object.__setattr__(self, 'mname', mname) - object.__setattr__(self, 'rname', rname) - object.__setattr__(self, 'serial', serial) - object.__setattr__(self, 'refresh', refresh) - object.__setattr__(self, 'retry', retry) - object.__setattr__(self, 'expire', expire) - object.__setattr__(self, 'minimum', minimum) + self.mname = self._as_name(mname) + self.rname = self._as_name(rname) + self.serial = self._as_uint32(serial) + self.refresh = self._as_ttl(refresh) + self.retry = self._as_ttl(retry) + self.expire = self._as_ttl(expire) + self.minimum = self._as_ttl(minimum) def to_text(self, origin=None, relativize=True, **kw): mname = self.mname.choose_relativity(origin, relativize) @@ -59,7 +61,6 @@ class SOA(dns.rdata.Rdata): retry = tok.get_ttl() expire = tok.get_ttl() minimum = tok.get_ttl() - tok.get_eol() return cls(rdclass, rdtype, mname, rname, serial, refresh, retry, expire, minimum) diff --git a/lib/dns/rdtypes/ANY/SPF.py b/lib/dns/rdtypes/ANY/SPF.py index f1f6834e..1190e0de 100644 --- a/lib/dns/rdtypes/ANY/SPF.py +++ b/lib/dns/rdtypes/ANY/SPF.py @@ -16,8 +16,10 @@ # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. import dns.rdtypes.txtbase +import dns.immutable +@dns.immutable.immutable class SPF(dns.rdtypes.txtbase.TXTBase): """SPF record""" diff --git a/lib/dns/rdtypes/ANY/SSHFP.py b/lib/dns/rdtypes/ANY/SSHFP.py index a3cc0039..cc035195 100644 --- a/lib/dns/rdtypes/ANY/SSHFP.py +++ b/lib/dns/rdtypes/ANY/SSHFP.py @@ -19,9 +19,11 @@ import struct import binascii import dns.rdata +import dns.immutable import dns.rdatatype +@dns.immutable.immutable class SSHFP(dns.rdata.Rdata): """SSHFP record""" @@ -33,15 +35,18 @@ class SSHFP(dns.rdata.Rdata): def __init__(self, rdclass, rdtype, algorithm, fp_type, fingerprint): super().__init__(rdclass, rdtype) - object.__setattr__(self, 'algorithm', algorithm) - object.__setattr__(self, 'fp_type', fp_type) - object.__setattr__(self, 'fingerprint', fingerprint) + self.algorithm = self._as_uint8(algorithm) + self.fp_type = self._as_uint8(fp_type) + self.fingerprint = self._as_bytes(fingerprint, True) def to_text(self, origin=None, relativize=True, **kw): + kw = kw.copy() + chunksize = kw.pop('chunksize', 128) return '%d %d %s' % (self.algorithm, self.fp_type, dns.rdata._hexify(self.fingerprint, - chunksize=128)) + chunksize=chunksize, + **kw)) @classmethod def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, diff --git a/lib/dns/rdtypes/ANY/TKEY.py b/lib/dns/rdtypes/ANY/TKEY.py new file mode 100644 index 00000000..f8c47372 --- /dev/null +++ b/lib/dns/rdtypes/ANY/TKEY.py @@ -0,0 +1,118 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2004-2007, 2009-2011 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +import base64 +import struct + +import dns.dnssec +import dns.immutable +import dns.exception +import dns.rdata + + +@dns.immutable.immutable +class TKEY(dns.rdata.Rdata): + + """TKEY Record""" + + __slots__ = ['algorithm', 'inception', 'expiration', 'mode', 'error', + 'key', 'other'] + + def __init__(self, rdclass, rdtype, algorithm, inception, expiration, + mode, error, key, other=b''): + super().__init__(rdclass, rdtype) + self.algorithm = self._as_name(algorithm) + self.inception = self._as_uint32(inception) + self.expiration = self._as_uint32(expiration) + self.mode = self._as_uint16(mode) + self.error = self._as_uint16(error) + self.key = self._as_bytes(key) + self.other = self._as_bytes(other) + + def to_text(self, origin=None, relativize=True, **kw): + _algorithm = self.algorithm.choose_relativity(origin, relativize) + text = '%s %u %u %u %u %s' % (str(_algorithm), self.inception, + self.expiration, self.mode, self.error, + dns.rdata._base64ify(self.key, 0)) + if len(self.other) > 0: + text += ' %s' % (dns.rdata._base64ify(self.other, 0)) + + return text + + @classmethod + def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, + relativize_to=None): + algorithm = tok.get_name(relativize=False) + inception = tok.get_uint32() + expiration = tok.get_uint32() + mode = tok.get_uint16() + error = tok.get_uint16() + key_b64 = tok.get_string().encode() + key = base64.b64decode(key_b64) + other_b64 = tok.concatenate_remaining_identifiers().encode() + other = base64.b64decode(other_b64) + + return cls(rdclass, rdtype, algorithm, inception, expiration, mode, + error, key, other) + + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): + self.algorithm.to_wire(file, compress, origin) + file.write(struct.pack("!IIHH", self.inception, self.expiration, + self.mode, self.error)) + file.write(struct.pack("!H", len(self.key))) + file.write(self.key) + file.write(struct.pack("!H", len(self.other))) + if len(self.other) > 0: + file.write(self.other) + + @classmethod + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + algorithm = parser.get_name(origin) + inception, expiration, mode, error = parser.get_struct("!IIHH") + key = parser.get_counted_bytes(2) + other = parser.get_counted_bytes(2) + + return cls(rdclass, rdtype, algorithm, inception, expiration, mode, + error, key, other) + + # Constants for the mode field - from RFC 2930: + # 2.5 The Mode Field + # + # The mode field specifies the general scheme for key agreement or + # the purpose of the TKEY DNS message. Servers and resolvers + # supporting this specification MUST implement the Diffie-Hellman key + # agreement mode and the key deletion mode for queries. All other + # modes are OPTIONAL. A server supporting TKEY that receives a TKEY + # request with a mode it does not support returns the BADMODE error. + # The following values of the Mode octet are defined, available, or + # reserved: + # + # Value Description + # ----- ----------- + # 0 - reserved, see section 7 + # 1 server assignment + # 2 Diffie-Hellman exchange + # 3 GSS-API negotiation + # 4 resolver assignment + # 5 key deletion + # 6-65534 - available, see section 7 + # 65535 - reserved, see section 7 + SERVER_ASSIGNMENT = 1 + DIFFIE_HELLMAN_EXCHANGE = 2 + GSSAPI_NEGOTIATION = 3 + RESOLVER_ASSIGNMENT = 4 + KEY_DELETION = 5 diff --git a/lib/dns/rdtypes/ANY/TLSA.py b/lib/dns/rdtypes/ANY/TLSA.py index 9c9c8662..c9ba1991 100644 --- a/lib/dns/rdtypes/ANY/TLSA.py +++ b/lib/dns/rdtypes/ANY/TLSA.py @@ -1,67 +1,10 @@ # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license -# Copyright (C) 2005-2007, 2009-2011 Nominum, Inc. -# -# Permission to use, copy, modify, and distribute this software and its -# documentation for any purpose with or without fee is hereby granted, -# provided that the above copyright notice and this permission notice -# appear in all copies. -# -# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES -# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF -# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR -# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES -# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN -# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT -# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. - -import struct -import binascii - -import dns.rdata -import dns.rdatatype +import dns.immutable +import dns.rdtypes.tlsabase -class TLSA(dns.rdata.Rdata): +@dns.immutable.immutable +class TLSA(dns.rdtypes.tlsabase.TLSABase): """TLSA record""" - - # see: RFC 6698 - - __slots__ = ['usage', 'selector', 'mtype', 'cert'] - - def __init__(self, rdclass, rdtype, usage, selector, - mtype, cert): - super().__init__(rdclass, rdtype) - object.__setattr__(self, 'usage', usage) - object.__setattr__(self, 'selector', selector) - object.__setattr__(self, 'mtype', mtype) - object.__setattr__(self, 'cert', cert) - - def to_text(self, origin=None, relativize=True, **kw): - return '%d %d %d %s' % (self.usage, - self.selector, - self.mtype, - dns.rdata._hexify(self.cert, - chunksize=128)) - - @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): - usage = tok.get_uint8() - selector = tok.get_uint8() - mtype = tok.get_uint8() - cert = tok.concatenate_remaining_identifiers().encode() - cert = binascii.unhexlify(cert) - return cls(rdclass, rdtype, usage, selector, mtype, cert) - - def _to_wire(self, file, compress=None, origin=None, canonicalize=False): - header = struct.pack("!BBB", self.usage, self.selector, self.mtype) - file.write(header) - file.write(self.cert) - - @classmethod - def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): - header = parser.get_struct("BBB") - cert = parser.get_remaining() - return cls(rdclass, rdtype, header[0], header[1], header[2], cert) diff --git a/lib/dns/rdtypes/ANY/TSIG.py b/lib/dns/rdtypes/ANY/TSIG.py index 18db4c9e..b43a78f1 100644 --- a/lib/dns/rdtypes/ANY/TSIG.py +++ b/lib/dns/rdtypes/ANY/TSIG.py @@ -15,12 +15,16 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. +import base64 import struct import dns.exception +import dns.immutable +import dns.rcode import dns.rdata +@dns.immutable.immutable class TSIG(dns.rdata.Rdata): """TSIG record""" @@ -52,20 +56,45 @@ class TSIG(dns.rdata.Rdata): """ super().__init__(rdclass, rdtype) - object.__setattr__(self, 'algorithm', algorithm) - object.__setattr__(self, 'time_signed', time_signed) - object.__setattr__(self, 'fudge', fudge) - object.__setattr__(self, 'mac', dns.rdata._constify(mac)) - object.__setattr__(self, 'original_id', original_id) - object.__setattr__(self, 'error', error) - object.__setattr__(self, 'other', dns.rdata._constify(other)) + self.algorithm = self._as_name(algorithm) + self.time_signed = self._as_uint48(time_signed) + self.fudge = self._as_uint16(fudge) + self.mac = self._as_bytes(mac) + self.original_id = self._as_uint16(original_id) + self.error = dns.rcode.Rcode.make(error) + self.other = self._as_bytes(other) def to_text(self, origin=None, relativize=True, **kw): algorithm = self.algorithm.choose_relativity(origin, relativize) - return f"{algorithm} {self.fudge} {self.time_signed} " + \ + error = dns.rcode.to_text(self.error, True) + text = f"{algorithm} {self.time_signed} {self.fudge} " + \ f"{len(self.mac)} {dns.rdata._base64ify(self.mac, 0)} " + \ - f"{self.original_id} {self.error} " + \ - f"{len(self.other)} {dns.rdata._base64ify(self.other, 0)}" + f"{self.original_id} {error} {len(self.other)}" + if self.other: + text += f" {dns.rdata._base64ify(self.other, 0)}" + return text + + @classmethod + def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, + relativize_to=None): + algorithm = tok.get_name(relativize=False) + time_signed = tok.get_uint48() + fudge = tok.get_uint16() + mac_len = tok.get_uint16() + mac = base64.b64decode(tok.get_string()) + if len(mac) != mac_len: + raise SyntaxError('invalid MAC') + original_id = tok.get_uint16() + error = dns.rcode.from_text(tok.get_string()) + other_len = tok.get_uint16() + if other_len > 0: + other = base64.b64decode(tok.get_string()) + if len(other) != other_len: + raise SyntaxError('invalid other data') + else: + other = b'' + return cls(rdclass, rdtype, algorithm, time_signed, fudge, mac, + original_id, error, other) def _to_wire(self, file, compress=None, origin=None, canonicalize=False): self.algorithm.to_wire(file, None, origin, False) @@ -81,9 +110,9 @@ class TSIG(dns.rdata.Rdata): @classmethod def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): - algorithm = parser.get_name(origin) - (time_hi, time_lo, fudge) = parser.get_struct('!HIH') - time_signed = (time_hi << 32) + time_lo + algorithm = parser.get_name() + time_signed = parser.get_uint48() + fudge = parser.get_uint16() mac = parser.get_counted_bytes(2) (original_id, error) = parser.get_struct('!HH') other = parser.get_counted_bytes(2) diff --git a/lib/dns/rdtypes/ANY/TXT.py b/lib/dns/rdtypes/ANY/TXT.py index c5ae919c..cc4b6611 100644 --- a/lib/dns/rdtypes/ANY/TXT.py +++ b/lib/dns/rdtypes/ANY/TXT.py @@ -16,8 +16,10 @@ # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. import dns.rdtypes.txtbase +import dns.immutable +@dns.immutable.immutable class TXT(dns.rdtypes.txtbase.TXTBase): """TXT record""" diff --git a/lib/dns/rdtypes/ANY/URI.py b/lib/dns/rdtypes/ANY/URI.py index 84296f52..524fa1ba 100644 --- a/lib/dns/rdtypes/ANY/URI.py +++ b/lib/dns/rdtypes/ANY/URI.py @@ -19,10 +19,13 @@ import struct import dns.exception +import dns.immutable import dns.rdata +import dns.rdtypes.util import dns.name +@dns.immutable.immutable class URI(dns.rdata.Rdata): """URI record""" @@ -33,14 +36,11 @@ class URI(dns.rdata.Rdata): def __init__(self, rdclass, rdtype, priority, weight, target): super().__init__(rdclass, rdtype) - object.__setattr__(self, 'priority', priority) - object.__setattr__(self, 'weight', weight) - if len(target) < 1: + self.priority = self._as_uint16(priority) + self.weight = self._as_uint16(weight) + self.target = self._as_bytes(target, True) + if len(self.target) == 0: raise dns.exception.SyntaxError("URI target cannot be empty") - if isinstance(target, str): - object.__setattr__(self, 'target', target.encode()) - else: - object.__setattr__(self, 'target', target) def to_text(self, origin=None, relativize=True, **kw): return '%d %d "%s"' % (self.priority, self.weight, @@ -54,7 +54,6 @@ class URI(dns.rdata.Rdata): target = tok.get().unescape() if not (target.is_quoted_string() or target.is_identifier()): raise dns.exception.SyntaxError("URI target must be a string") - tok.get_eol() return cls(rdclass, rdtype, priority, weight, target.value) def _to_wire(self, file, compress=None, origin=None, canonicalize=False): @@ -69,3 +68,13 @@ class URI(dns.rdata.Rdata): if len(target) == 0: raise dns.exception.FormError('URI target may not be empty') return cls(rdclass, rdtype, priority, weight, target) + + def _processing_priority(self): + return self.priority + + def _processing_weight(self): + return self.weight + + @classmethod + def _processing_order(cls, iterable): + return dns.rdtypes.util.weighted_processing_order(iterable) diff --git a/lib/dns/rdtypes/ANY/X25.py b/lib/dns/rdtypes/ANY/X25.py index 214f1dca..4f7230c0 100644 --- a/lib/dns/rdtypes/ANY/X25.py +++ b/lib/dns/rdtypes/ANY/X25.py @@ -18,10 +18,12 @@ import struct import dns.exception +import dns.immutable import dns.rdata import dns.tokenizer +@dns.immutable.immutable class X25(dns.rdata.Rdata): """X25 record""" @@ -32,10 +34,7 @@ class X25(dns.rdata.Rdata): def __init__(self, rdclass, rdtype, address): super().__init__(rdclass, rdtype) - if isinstance(address, str): - object.__setattr__(self, 'address', address.encode()) - else: - object.__setattr__(self, 'address', address) + self.address = self._as_bytes(address, True, 255) def to_text(self, origin=None, relativize=True, **kw): return '"%s"' % dns.rdata._escapify(self.address) @@ -44,7 +43,6 @@ class X25(dns.rdata.Rdata): def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None): address = tok.get_string() - tok.get_eol() return cls(rdclass, rdtype, address) def _to_wire(self, file, compress=None, origin=None, canonicalize=False): diff --git a/lib/dns/rdtypes/ANY/ZONEMD.py b/lib/dns/rdtypes/ANY/ZONEMD.py new file mode 100644 index 00000000..035f7b32 --- /dev/null +++ b/lib/dns/rdtypes/ANY/ZONEMD.py @@ -0,0 +1,65 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +import struct +import binascii + +import dns.immutable +import dns.rdata +import dns.rdatatype +import dns.zone + + +@dns.immutable.immutable +class ZONEMD(dns.rdata.Rdata): + + """ZONEMD record""" + + # See RFC 8976 + + __slots__ = ['serial', 'scheme', 'hash_algorithm', 'digest'] + + def __init__(self, rdclass, rdtype, serial, scheme, hash_algorithm, digest): + super().__init__(rdclass, rdtype) + self.serial = self._as_uint32(serial) + self.scheme = dns.zone.DigestScheme.make(scheme) + self.hash_algorithm = dns.zone.DigestHashAlgorithm.make(hash_algorithm) + self.digest = self._as_bytes(digest) + + if self.scheme == 0: # reserved, RFC 8976 Sec. 5.2 + raise ValueError('scheme 0 is reserved') + if self.hash_algorithm == 0: # reserved, RFC 8976 Sec. 5.3 + raise ValueError('hash_algorithm 0 is reserved') + + hasher = dns.zone._digest_hashers.get(self.hash_algorithm) + if hasher and hasher().digest_size != len(self.digest): + raise ValueError('digest length inconsistent with hash algorithm') + + def to_text(self, origin=None, relativize=True, **kw): + kw = kw.copy() + chunksize = kw.pop('chunksize', 128) + return '%d %d %d %s' % (self.serial, self.scheme, self.hash_algorithm, + dns.rdata._hexify(self.digest, + chunksize=chunksize, + **kw)) + + @classmethod + def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, + relativize_to=None): + serial = tok.get_uint32() + scheme = tok.get_uint8() + hash_algorithm = tok.get_uint8() + digest = tok.concatenate_remaining_identifiers().encode() + digest = binascii.unhexlify(digest) + return cls(rdclass, rdtype, serial, scheme, hash_algorithm, digest) + + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): + header = struct.pack("!IBB", self.serial, self.scheme, + self.hash_algorithm) + file.write(header) + file.write(self.digest) + + @classmethod + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + header = parser.get_struct("!IBB") + digest = parser.get_remaining() + return cls(rdclass, rdtype, header[0], header[1], header[2], digest) diff --git a/lib/dns/rdtypes/ANY/__init__.py b/lib/dns/rdtypes/ANY/__init__.py index ea704c86..6c56baff 100644 --- a/lib/dns/rdtypes/ANY/__init__.py +++ b/lib/dns/rdtypes/ANY/__init__.py @@ -19,6 +19,7 @@ __all__ = [ 'AFSDB', + 'AMTRELAY', 'AVC', 'CAA', 'CDNSKEY', @@ -38,6 +39,7 @@ __all__ = [ 'ISDN', 'LOC', 'MX', + 'NINFO', 'NS', 'NSEC', 'NSEC3', @@ -48,12 +50,15 @@ __all__ = [ 'RP', 'RRSIG', 'RT', + 'SMIMEA', 'SOA', 'SPF', 'SSHFP', + 'TKEY', 'TLSA', 'TSIG', 'TXT', 'URI', 'X25', + 'ZONEMD', ] diff --git a/lib/dns/rdtypes/CH/A.py b/lib/dns/rdtypes/CH/A.py index b738ac6c..828701b4 100644 --- a/lib/dns/rdtypes/CH/A.py +++ b/lib/dns/rdtypes/CH/A.py @@ -15,9 +15,12 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. -import dns.rdtypes.mxbase import struct +import dns.rdtypes.mxbase +import dns.immutable + +@dns.immutable.immutable class A(dns.rdata.Rdata): """A record for Chaosnet""" @@ -29,8 +32,8 @@ class A(dns.rdata.Rdata): def __init__(self, rdclass, rdtype, domain, address): super().__init__(rdclass, rdtype) - object.__setattr__(self, 'domain', domain) - object.__setattr__(self, 'address', address) + self.domain = self._as_name(domain) + self.address = self._as_uint16(address) def to_text(self, origin=None, relativize=True, **kw): domain = self.domain.choose_relativity(origin, relativize) @@ -41,7 +44,6 @@ class A(dns.rdata.Rdata): relativize_to=None): domain = tok.get_name(origin, relativize, relativize_to) address = tok.get_uint16(base=8) - tok.get_eol() return cls(rdclass, rdtype, domain, address) def _to_wire(self, file, compress=None, origin=None, canonicalize=False): diff --git a/lib/dns/rdtypes/IN/A.py b/lib/dns/rdtypes/IN/A.py index 8b71e329..74b591ef 100644 --- a/lib/dns/rdtypes/IN/A.py +++ b/lib/dns/rdtypes/IN/A.py @@ -16,11 +16,13 @@ # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. import dns.exception +import dns.immutable import dns.ipv4 import dns.rdata import dns.tokenizer +@dns.immutable.immutable class A(dns.rdata.Rdata): """A record.""" @@ -29,9 +31,7 @@ class A(dns.rdata.Rdata): def __init__(self, rdclass, rdtype, address): super().__init__(rdclass, rdtype) - # check that it's OK - dns.ipv4.inet_aton(address) - object.__setattr__(self, 'address', address) + self.address = self._as_ipv4_address(address) def to_text(self, origin=None, relativize=True, **kw): return self.address @@ -40,7 +40,6 @@ class A(dns.rdata.Rdata): def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None): address = tok.get_identifier() - tok.get_eol() return cls(rdclass, rdtype, address) def _to_wire(self, file, compress=None, origin=None, canonicalize=False): @@ -48,5 +47,5 @@ class A(dns.rdata.Rdata): @classmethod def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): - address = dns.ipv4.inet_ntoa(parser.get_remaining()) + address = parser.get_remaining() return cls(rdclass, rdtype, address) diff --git a/lib/dns/rdtypes/IN/AAAA.py b/lib/dns/rdtypes/IN/AAAA.py index 08f9d679..2d3ec902 100644 --- a/lib/dns/rdtypes/IN/AAAA.py +++ b/lib/dns/rdtypes/IN/AAAA.py @@ -16,11 +16,13 @@ # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. import dns.exception +import dns.immutable import dns.ipv6 import dns.rdata import dns.tokenizer +@dns.immutable.immutable class AAAA(dns.rdata.Rdata): """AAAA record.""" @@ -29,9 +31,7 @@ class AAAA(dns.rdata.Rdata): def __init__(self, rdclass, rdtype, address): super().__init__(rdclass, rdtype) - # check that it's OK - dns.ipv6.inet_aton(address) - object.__setattr__(self, 'address', address) + self.address = self._as_ipv6_address(address) def to_text(self, origin=None, relativize=True, **kw): return self.address @@ -40,7 +40,6 @@ class AAAA(dns.rdata.Rdata): def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None): address = tok.get_identifier() - tok.get_eol() return cls(rdclass, rdtype, address) def _to_wire(self, file, compress=None, origin=None, canonicalize=False): @@ -48,5 +47,5 @@ class AAAA(dns.rdata.Rdata): @classmethod def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): - address = dns.ipv6.inet_ntoa(parser.get_remaining()) + address = parser.get_remaining() return cls(rdclass, rdtype, address) diff --git a/lib/dns/rdtypes/IN/APL.py b/lib/dns/rdtypes/IN/APL.py index ab7fe4bc..ae94fb24 100644 --- a/lib/dns/rdtypes/IN/APL.py +++ b/lib/dns/rdtypes/IN/APL.py @@ -20,11 +20,13 @@ import codecs import struct import dns.exception +import dns.immutable import dns.ipv4 import dns.ipv6 import dns.rdata import dns.tokenizer +@dns.immutable.immutable class APLItem: """An APL list item.""" @@ -32,10 +34,17 @@ class APLItem: __slots__ = ['family', 'negation', 'address', 'prefix'] def __init__(self, family, negation, address, prefix): - self.family = family - self.negation = negation - self.address = address - self.prefix = prefix + self.family = dns.rdata.Rdata._as_uint16(family) + self.negation = dns.rdata.Rdata._as_bool(negation) + if self.family == 1: + self.address = dns.rdata.Rdata._as_ipv4_address(address) + self.prefix = dns.rdata.Rdata._as_int(prefix, 0, 32) + elif self.family == 2: + self.address = dns.rdata.Rdata._as_ipv6_address(address) + self.prefix = dns.rdata.Rdata._as_int(prefix, 0, 128) + else: + self.address = dns.rdata.Rdata._as_bytes(address, max_length=127) + self.prefix = dns.rdata.Rdata._as_uint8(prefix) def __str__(self): if self.negation: @@ -68,6 +77,7 @@ class APLItem: file.write(address) +@dns.immutable.immutable class APL(dns.rdata.Rdata): """APL record.""" @@ -78,7 +88,10 @@ class APL(dns.rdata.Rdata): def __init__(self, rdclass, rdtype, items): super().__init__(rdclass, rdtype) - object.__setattr__(self, 'items', dns.rdata._constify(items)) + for item in items: + if not isinstance(item, APLItem): + raise ValueError('item not an APLItem') + self.items = tuple(items) def to_text(self, origin=None, relativize=True, **kw): return ' '.join(map(str, self.items)) @@ -87,11 +100,8 @@ class APL(dns.rdata.Rdata): def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None): items = [] - while True: - token = tok.get().unescape() - if token.is_eol_or_eof(): - break - item = token.value + for token in tok.get_remaining(): + item = token.unescape().value if item[0] == '!': negation = True item = item[1:] @@ -127,11 +137,9 @@ class APL(dns.rdata.Rdata): if header[0] == 1: if l < 4: address += b'\x00' * (4 - l) - address = dns.ipv4.inet_ntoa(address) elif header[0] == 2: if l < 16: address += b'\x00' * (16 - l) - address = dns.ipv6.inet_ntoa(address) else: # # This isn't really right according to the RFC, but it diff --git a/lib/dns/rdtypes/IN/DHCID.py b/lib/dns/rdtypes/IN/DHCID.py index 6f66eb89..a9185989 100644 --- a/lib/dns/rdtypes/IN/DHCID.py +++ b/lib/dns/rdtypes/IN/DHCID.py @@ -18,8 +18,10 @@ import base64 import dns.exception +import dns.immutable +@dns.immutable.immutable class DHCID(dns.rdata.Rdata): """DHCID record""" @@ -30,10 +32,10 @@ class DHCID(dns.rdata.Rdata): def __init__(self, rdclass, rdtype, data): super().__init__(rdclass, rdtype) - object.__setattr__(self, 'data', data) + self.data = self._as_bytes(data) def to_text(self, origin=None, relativize=True, **kw): - return dns.rdata._base64ify(self.data) + return dns.rdata._base64ify(self.data, **kw) @classmethod def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, diff --git a/lib/dns/rdtypes/IN/HTTPS.py b/lib/dns/rdtypes/IN/HTTPS.py new file mode 100644 index 00000000..6a67e8ed --- /dev/null +++ b/lib/dns/rdtypes/IN/HTTPS.py @@ -0,0 +1,8 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +import dns.rdtypes.svcbbase +import dns.immutable + +@dns.immutable.immutable +class HTTPS(dns.rdtypes.svcbbase.SVCBBase): + """HTTPS record""" diff --git a/lib/dns/rdtypes/IN/IPSECKEY.py b/lib/dns/rdtypes/IN/IPSECKEY.py index 182ad2cb..d1d39438 100644 --- a/lib/dns/rdtypes/IN/IPSECKEY.py +++ b/lib/dns/rdtypes/IN/IPSECKEY.py @@ -19,12 +19,14 @@ import struct import base64 import dns.exception +import dns.immutable import dns.rdtypes.util class Gateway(dns.rdtypes.util.Gateway): name = 'IPSECKEY gateway' +@dns.immutable.immutable class IPSECKEY(dns.rdata.Rdata): """IPSECKEY record""" @@ -36,19 +38,19 @@ class IPSECKEY(dns.rdata.Rdata): def __init__(self, rdclass, rdtype, precedence, gateway_type, algorithm, gateway, key): super().__init__(rdclass, rdtype) - Gateway(gateway_type, gateway).check() - object.__setattr__(self, 'precedence', precedence) - object.__setattr__(self, 'gateway_type', gateway_type) - object.__setattr__(self, 'algorithm', algorithm) - object.__setattr__(self, 'gateway', gateway) - object.__setattr__(self, 'key', key) + gateway = Gateway(gateway_type, gateway) + self.precedence = self._as_uint8(precedence) + self.gateway_type = gateway.type + self.algorithm = self._as_uint8(algorithm) + self.gateway = gateway.gateway + self.key = self._as_bytes(key) def to_text(self, origin=None, relativize=True, **kw): gateway = Gateway(self.gateway_type, self.gateway).to_text(origin, relativize) return '%d %d %d %s %s' % (self.precedence, self.gateway_type, self.algorithm, gateway, - dns.rdata._base64ify(self.key)) + dns.rdata._base64ify(self.key, **kw)) @classmethod def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, @@ -56,12 +58,12 @@ class IPSECKEY(dns.rdata.Rdata): precedence = tok.get_uint8() gateway_type = tok.get_uint8() algorithm = tok.get_uint8() - gateway = Gateway(gateway_type).from_text(tok, origin, relativize, - relativize_to) + gateway = Gateway.from_text(gateway_type, tok, origin, relativize, + relativize_to) b64 = tok.concatenate_remaining_identifiers().encode() key = base64.b64decode(b64) return cls(rdclass, rdtype, precedence, gateway_type, algorithm, - gateway, key) + gateway.gateway, key) def _to_wire(self, file, compress=None, origin=None, canonicalize=False): header = struct.pack("!BBB", self.precedence, self.gateway_type, @@ -75,7 +77,7 @@ class IPSECKEY(dns.rdata.Rdata): def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): header = parser.get_struct('!BBB') gateway_type = header[1] - gateway = Gateway(gateway_type).from_wire_parser(parser, origin) + gateway = Gateway.from_wire_parser(gateway_type, parser, origin) key = parser.get_remaining() return cls(rdclass, rdtype, header[0], gateway_type, header[2], - gateway, key) + gateway.gateway, key) diff --git a/lib/dns/rdtypes/IN/KX.py b/lib/dns/rdtypes/IN/KX.py index ebf8fd77..c27e9215 100644 --- a/lib/dns/rdtypes/IN/KX.py +++ b/lib/dns/rdtypes/IN/KX.py @@ -16,8 +16,10 @@ # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. import dns.rdtypes.mxbase +import dns.immutable +@dns.immutable.immutable class KX(dns.rdtypes.mxbase.UncompressedDowncasingMX): """KX record""" diff --git a/lib/dns/rdtypes/IN/NAPTR.py b/lib/dns/rdtypes/IN/NAPTR.py index 48d43562..b107974d 100644 --- a/lib/dns/rdtypes/IN/NAPTR.py +++ b/lib/dns/rdtypes/IN/NAPTR.py @@ -18,8 +18,10 @@ import struct import dns.exception +import dns.immutable import dns.name import dns.rdata +import dns.rdtypes.util def _write_string(file, s): @@ -29,12 +31,7 @@ def _write_string(file, s): file.write(s) -def _sanitize(value): - if isinstance(value, str): - return value.encode() - return value - - +@dns.immutable.immutable class NAPTR(dns.rdata.Rdata): """NAPTR record""" @@ -47,12 +44,12 @@ class NAPTR(dns.rdata.Rdata): def __init__(self, rdclass, rdtype, order, preference, flags, service, regexp, replacement): super().__init__(rdclass, rdtype) - object.__setattr__(self, 'flags', _sanitize(flags)) - object.__setattr__(self, 'service', _sanitize(service)) - object.__setattr__(self, 'regexp', _sanitize(regexp)) - object.__setattr__(self, 'order', order) - object.__setattr__(self, 'preference', preference) - object.__setattr__(self, 'replacement', replacement) + self.flags = self._as_bytes(flags, True, 255) + self.service = self._as_bytes(service, True, 255) + self.regexp = self._as_bytes(regexp, True, 255) + self.order = self._as_uint16(order) + self.preference = self._as_uint16(preference) + self.replacement = self._as_name(replacement) def to_text(self, origin=None, relativize=True, **kw): replacement = self.replacement.choose_relativity(origin, relativize) @@ -72,7 +69,6 @@ class NAPTR(dns.rdata.Rdata): service = tok.get_string() regexp = tok.get_string() replacement = tok.get_name(origin, relativize, relativize_to) - tok.get_eol() return cls(rdclass, rdtype, order, preference, flags, service, regexp, replacement) @@ -88,9 +84,16 @@ class NAPTR(dns.rdata.Rdata): def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): (order, preference) = parser.get_struct('!HH') strings = [] - for i in range(3): + for _ in range(3): s = parser.get_counted_bytes() strings.append(s) replacement = parser.get_name(origin) return cls(rdclass, rdtype, order, preference, strings[0], strings[1], strings[2], replacement) + + def _processing_priority(self): + return (self.order, self.preference) + + @classmethod + def _processing_order(cls, iterable): + return dns.rdtypes.util.priority_processing_order(iterable) diff --git a/lib/dns/rdtypes/IN/NSAP.py b/lib/dns/rdtypes/IN/NSAP.py index 227465fa..23ae9b1a 100644 --- a/lib/dns/rdtypes/IN/NSAP.py +++ b/lib/dns/rdtypes/IN/NSAP.py @@ -18,10 +18,12 @@ import binascii import dns.exception +import dns.immutable import dns.rdata import dns.tokenizer +@dns.immutable.immutable class NSAP(dns.rdata.Rdata): """NSAP record.""" @@ -32,7 +34,7 @@ class NSAP(dns.rdata.Rdata): def __init__(self, rdclass, rdtype, address): super().__init__(rdclass, rdtype) - object.__setattr__(self, 'address', address) + self.address = self._as_bytes(address) def to_text(self, origin=None, relativize=True, **kw): return "0x%s" % binascii.hexlify(self.address).decode() @@ -41,7 +43,6 @@ class NSAP(dns.rdata.Rdata): def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None): address = tok.get_string() - tok.get_eol() if address[0:2] != '0x': raise dns.exception.SyntaxError('string does not start with 0x') address = address[2:].replace('.', '') diff --git a/lib/dns/rdtypes/IN/NSAP_PTR.py b/lib/dns/rdtypes/IN/NSAP_PTR.py index a5b66c80..57dadd47 100644 --- a/lib/dns/rdtypes/IN/NSAP_PTR.py +++ b/lib/dns/rdtypes/IN/NSAP_PTR.py @@ -16,8 +16,10 @@ # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. import dns.rdtypes.nsbase +import dns.immutable +@dns.immutable.immutable class NSAP_PTR(dns.rdtypes.nsbase.UncompressedNS): """NSAP-PTR record""" diff --git a/lib/dns/rdtypes/IN/PX.py b/lib/dns/rdtypes/IN/PX.py index 946d79f8..113d409c 100644 --- a/lib/dns/rdtypes/IN/PX.py +++ b/lib/dns/rdtypes/IN/PX.py @@ -18,10 +18,13 @@ import struct import dns.exception +import dns.immutable import dns.rdata +import dns.rdtypes.util import dns.name +@dns.immutable.immutable class PX(dns.rdata.Rdata): """PX record.""" @@ -32,9 +35,9 @@ class PX(dns.rdata.Rdata): def __init__(self, rdclass, rdtype, preference, map822, mapx400): super().__init__(rdclass, rdtype) - object.__setattr__(self, 'preference', preference) - object.__setattr__(self, 'map822', map822) - object.__setattr__(self, 'mapx400', mapx400) + self.preference = self._as_uint16(preference) + self.map822 = self._as_name(map822) + self.mapx400 = self._as_name(mapx400) def to_text(self, origin=None, relativize=True, **kw): map822 = self.map822.choose_relativity(origin, relativize) @@ -47,7 +50,6 @@ class PX(dns.rdata.Rdata): preference = tok.get_uint16() map822 = tok.get_name(origin, relativize, relativize_to) mapx400 = tok.get_name(origin, relativize, relativize_to) - tok.get_eol() return cls(rdclass, rdtype, preference, map822, mapx400) def _to_wire(self, file, compress=None, origin=None, canonicalize=False): @@ -62,3 +64,10 @@ class PX(dns.rdata.Rdata): map822 = parser.get_name(origin) mapx400 = parser.get_name(origin) return cls(rdclass, rdtype, preference, map822, mapx400) + + def _processing_priority(self): + return self.preference + + @classmethod + def _processing_order(cls, iterable): + return dns.rdtypes.util.priority_processing_order(iterable) diff --git a/lib/dns/rdtypes/IN/SRV.py b/lib/dns/rdtypes/IN/SRV.py index 485153f4..5b5ff422 100644 --- a/lib/dns/rdtypes/IN/SRV.py +++ b/lib/dns/rdtypes/IN/SRV.py @@ -18,10 +18,13 @@ import struct import dns.exception +import dns.immutable import dns.rdata +import dns.rdtypes.util import dns.name +@dns.immutable.immutable class SRV(dns.rdata.Rdata): """SRV record""" @@ -32,10 +35,10 @@ class SRV(dns.rdata.Rdata): def __init__(self, rdclass, rdtype, priority, weight, port, target): super().__init__(rdclass, rdtype) - object.__setattr__(self, 'priority', priority) - object.__setattr__(self, 'weight', weight) - object.__setattr__(self, 'port', port) - object.__setattr__(self, 'target', target) + self.priority = self._as_uint16(priority) + self.weight = self._as_uint16(weight) + self.port = self._as_uint16(port) + self.target = self._as_name(target) def to_text(self, origin=None, relativize=True, **kw): target = self.target.choose_relativity(origin, relativize) @@ -49,7 +52,6 @@ class SRV(dns.rdata.Rdata): weight = tok.get_uint16() port = tok.get_uint16() target = tok.get_name(origin, relativize, relativize_to) - tok.get_eol() return cls(rdclass, rdtype, priority, weight, port, target) def _to_wire(self, file, compress=None, origin=None, canonicalize=False): @@ -62,3 +64,13 @@ class SRV(dns.rdata.Rdata): (priority, weight, port) = parser.get_struct('!HHH') target = parser.get_name(origin) return cls(rdclass, rdtype, priority, weight, port, target) + + def _processing_priority(self): + return self.priority + + def _processing_weight(self): + return self.weight + + @classmethod + def _processing_order(cls, iterable): + return dns.rdtypes.util.weighted_processing_order(iterable) diff --git a/lib/dns/rdtypes/IN/SVCB.py b/lib/dns/rdtypes/IN/SVCB.py new file mode 100644 index 00000000..14838e16 --- /dev/null +++ b/lib/dns/rdtypes/IN/SVCB.py @@ -0,0 +1,8 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +import dns.rdtypes.svcbbase +import dns.immutable + +@dns.immutable.immutable +class SVCB(dns.rdtypes.svcbbase.SVCBBase): + """SVCB record""" diff --git a/lib/dns/rdtypes/IN/WKS.py b/lib/dns/rdtypes/IN/WKS.py index d66d8583..264e45d3 100644 --- a/lib/dns/rdtypes/IN/WKS.py +++ b/lib/dns/rdtypes/IN/WKS.py @@ -19,12 +19,18 @@ import socket import struct import dns.ipv4 +import dns.immutable import dns.rdata -_proto_tcp = socket.getprotobyname('tcp') -_proto_udp = socket.getprotobyname('udp') - +try: + _proto_tcp = socket.getprotobyname('tcp') + _proto_udp = socket.getprotobyname('udp') +except OSError: + # Fall back to defaults in case /etc/protocols is unavailable. + _proto_tcp = 6 + _proto_udp = 17 +@dns.immutable.immutable class WKS(dns.rdata.Rdata): """WKS record""" @@ -35,14 +41,13 @@ class WKS(dns.rdata.Rdata): def __init__(self, rdclass, rdtype, address, protocol, bitmap): super().__init__(rdclass, rdtype) - object.__setattr__(self, 'address', address) - object.__setattr__(self, 'protocol', protocol) - object.__setattr__(self, 'bitmap', dns.rdata._constify(bitmap)) + self.address = self._as_ipv4_address(address) + self.protocol = self._as_uint8(protocol) + self.bitmap = self._as_bytes(bitmap) def to_text(self, origin=None, relativize=True, **kw): bits = [] - for i in range(0, len(self.bitmap)): - byte = self.bitmap[i] + for i, byte in enumerate(self.bitmap): for j in range(0, 8): if byte & (0x80 >> j): bits.append(str(i * 8 + j)) @@ -59,12 +64,10 @@ class WKS(dns.rdata.Rdata): else: protocol = socket.getprotobyname(protocol) bitmap = bytearray() - while 1: - token = tok.get().unescape() - if token.is_eol_or_eof(): - break - if token.value.isdigit(): - serv = int(token.value) + for token in tok.get_remaining(): + value = token.unescape().value + if value.isdigit(): + serv = int(value) else: if protocol != _proto_udp and protocol != _proto_tcp: raise NotImplementedError("protocol must be TCP or UDP") @@ -72,11 +75,11 @@ class WKS(dns.rdata.Rdata): protocol_text = "udp" else: protocol_text = "tcp" - serv = socket.getservbyname(token.value, protocol_text) + serv = socket.getservbyname(value, protocol_text) i = serv // 8 l = len(bitmap) if l < i + 1: - for j in range(l, i + 1): + for _ in range(l, i + 1): bitmap.append(0) bitmap[i] = bitmap[i] | (0x80 >> (serv % 8)) bitmap = dns.rdata._truncate_bitmap(bitmap) @@ -90,7 +93,7 @@ class WKS(dns.rdata.Rdata): @classmethod def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): - address = dns.ipv4.inet_ntoa(parser.get_bytes(4)) + address = parser.get_bytes(4) protocol = parser.get_uint8() bitmap = parser.get_remaining() return cls(rdclass, rdtype, address, protocol, bitmap) diff --git a/lib/dns/rdtypes/IN/__init__.py b/lib/dns/rdtypes/IN/__init__.py index d7e69c9f..d51b99e7 100644 --- a/lib/dns/rdtypes/IN/__init__.py +++ b/lib/dns/rdtypes/IN/__init__.py @@ -22,6 +22,7 @@ __all__ = [ 'AAAA', 'APL', 'DHCID', + 'HTTPS', 'IPSECKEY', 'KX', 'NAPTR', @@ -29,5 +30,6 @@ __all__ = [ 'NSAP_PTR', 'PX', 'SRV', + 'SVCB', 'WKS', ] diff --git a/lib/dns/rdtypes/__init__.py b/lib/dns/rdtypes/__init__.py index ccc848cf..c3af264e 100644 --- a/lib/dns/rdtypes/__init__.py +++ b/lib/dns/rdtypes/__init__.py @@ -21,8 +21,13 @@ __all__ = [ 'ANY', 'IN', 'CH', + 'dnskeybase', + 'dsbase', 'euibase', 'mxbase', 'nsbase', + 'svcbbase', + 'tlsabase', + 'txtbase', 'util' ] diff --git a/lib/dns/rdtypes/dnskeybase.py b/lib/dns/rdtypes/dnskeybase.py index 0243d6f3..788bb2bf 100644 --- a/lib/dns/rdtypes/dnskeybase.py +++ b/lib/dns/rdtypes/dnskeybase.py @@ -20,6 +20,7 @@ import enum import struct import dns.exception +import dns.immutable import dns.dnssec import dns.rdata @@ -31,9 +32,8 @@ class Flag(enum.IntFlag): REVOKE = 0x0080 ZONE = 0x0100 -globals().update(Flag.__members__) - +@dns.immutable.immutable class DNSKEYBase(dns.rdata.Rdata): """Base class for rdata that is like a DNSKEY record""" @@ -42,21 +42,21 @@ class DNSKEYBase(dns.rdata.Rdata): def __init__(self, rdclass, rdtype, flags, protocol, algorithm, key): super().__init__(rdclass, rdtype) - object.__setattr__(self, 'flags', flags) - object.__setattr__(self, 'protocol', protocol) - object.__setattr__(self, 'algorithm', algorithm) - object.__setattr__(self, 'key', key) + self.flags = self._as_uint16(flags) + self.protocol = self._as_uint8(protocol) + self.algorithm = dns.dnssec.Algorithm.make(algorithm) + self.key = self._as_bytes(key) def to_text(self, origin=None, relativize=True, **kw): return '%d %d %d %s' % (self.flags, self.protocol, self.algorithm, - dns.rdata._base64ify(self.key)) + dns.rdata._base64ify(self.key, **kw)) @classmethod def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None): flags = tok.get_uint16() protocol = tok.get_uint8() - algorithm = dns.dnssec.algorithm_from_text(tok.get_string()) + algorithm = tok.get_string() b64 = tok.concatenate_remaining_identifiers().encode() key = base64.b64decode(b64) return cls(rdclass, rdtype, flags, protocol, algorithm, key) @@ -72,3 +72,11 @@ class DNSKEYBase(dns.rdata.Rdata): key = parser.get_remaining() return cls(rdclass, rdtype, header[0], header[1], header[2], key) + +### BEGIN generated Flag constants + +SEP = Flag.SEP +REVOKE = Flag.REVOKE +ZONE = Flag.ZONE + +### END generated Flag constants diff --git a/lib/dns/rdtypes/dnskeybase.pyi b/lib/dns/rdtypes/dnskeybase.pyi new file mode 100644 index 00000000..1b999cfd --- /dev/null +++ b/lib/dns/rdtypes/dnskeybase.pyi @@ -0,0 +1,38 @@ +from typing import Set, Any + +SEP : int +REVOKE : int +ZONE : int + +def flags_to_text_set(flags : int) -> Set[str]: + ... + +def flags_from_text_set(texts_set) -> int: + ... + +from .. import rdata + +class DNSKEYBase(rdata.Rdata): + def __init__(self, rdclass, rdtype, flags, protocol, algorithm, key): + self.flags : int + self.protocol : int + self.key : str + self.algorithm : int + + def to_text(self, origin : Any = None, relativize=True, **kw : Any): + ... + + @classmethod + def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, + relativize_to=None): + ... + + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): + ... + + @classmethod + def from_parser(cls, rdclass, rdtype, parser, origin=None): + ... + + def flags_to_text_set(self) -> Set[str]: + ... diff --git a/lib/dns/rdtypes/dsbase.py b/lib/dns/rdtypes/dsbase.py index d7850bee..0c2e7471 100644 --- a/lib/dns/rdtypes/dsbase.py +++ b/lib/dns/rdtypes/dsbase.py @@ -19,35 +19,54 @@ import struct import binascii import dns.dnssec +import dns.immutable import dns.rdata import dns.rdatatype +@dns.immutable.immutable class DSBase(dns.rdata.Rdata): """Base class for rdata that is like a DS record""" __slots__ = ['key_tag', 'algorithm', 'digest_type', 'digest'] + # Digest types registry: https://www.iana.org/assignments/ds-rr-types/ds-rr-types.xhtml + _digest_length_by_type = { + 1: 20, # SHA-1, RFC 3658 Sec. 2.4 + 2: 32, # SHA-256, RFC 4509 Sec. 2.2 + 3: 32, # GOST R 34.11-94, RFC 5933 Sec. 4 in conjunction with RFC 4490 Sec. 2.1 + 4: 48, # SHA-384, RFC 6605 Sec. 2 + } + def __init__(self, rdclass, rdtype, key_tag, algorithm, digest_type, digest): super().__init__(rdclass, rdtype) - object.__setattr__(self, 'key_tag', key_tag) - object.__setattr__(self, 'algorithm', algorithm) - object.__setattr__(self, 'digest_type', digest_type) - object.__setattr__(self, 'digest', digest) + self.key_tag = self._as_uint16(key_tag) + self.algorithm = dns.dnssec.Algorithm.make(algorithm) + self.digest_type = self._as_uint8(digest_type) + self.digest = self._as_bytes(digest) + try: + if len(self.digest) != self._digest_length_by_type[self.digest_type]: + raise ValueError('digest length inconsistent with digest type') + except KeyError: + if self.digest_type == 0: # reserved, RFC 3658 Sec. 2.4 + raise ValueError('digest type 0 is reserved') def to_text(self, origin=None, relativize=True, **kw): + kw = kw.copy() + chunksize = kw.pop('chunksize', 128) return '%d %d %d %s' % (self.key_tag, self.algorithm, self.digest_type, dns.rdata._hexify(self.digest, - chunksize=128)) + chunksize=chunksize, + **kw)) @classmethod def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None): key_tag = tok.get_uint16() - algorithm = dns.dnssec.algorithm_from_text(tok.get_string()) + algorithm = tok.get_string() digest_type = tok.get_uint8() digest = tok.concatenate_remaining_identifiers().encode() digest = binascii.unhexlify(digest) diff --git a/lib/dns/rdtypes/euibase.py b/lib/dns/rdtypes/euibase.py index c1677a81..48b69bd3 100644 --- a/lib/dns/rdtypes/euibase.py +++ b/lib/dns/rdtypes/euibase.py @@ -17,8 +17,10 @@ import binascii import dns.rdata +import dns.immutable +@dns.immutable.immutable class EUIBase(dns.rdata.Rdata): """EUIxx record""" @@ -32,19 +34,18 @@ class EUIBase(dns.rdata.Rdata): def __init__(self, rdclass, rdtype, eui): super().__init__(rdclass, rdtype) - if len(eui) != self.byte_len: + self.eui = self._as_bytes(eui) + if len(self.eui) != self.byte_len: raise dns.exception.FormError('EUI%s rdata has to have %s bytes' % (self.byte_len * 8, self.byte_len)) - object.__setattr__(self, 'eui', eui) def to_text(self, origin=None, relativize=True, **kw): - return dns.rdata._hexify(self.eui, chunksize=2).replace(' ', '-') + return dns.rdata._hexify(self.eui, chunksize=2, separator=b'-', **kw) @classmethod def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None): text = tok.get_string() - tok.get_eol() if len(text) != cls.text_len: raise dns.exception.SyntaxError( 'Input text must have %s characters' % cls.text_len) diff --git a/lib/dns/rdtypes/mxbase.py b/lib/dns/rdtypes/mxbase.py index d6a6efed..56418234 100644 --- a/lib/dns/rdtypes/mxbase.py +++ b/lib/dns/rdtypes/mxbase.py @@ -20,10 +20,13 @@ import struct import dns.exception +import dns.immutable import dns.rdata import dns.name +import dns.rdtypes.util +@dns.immutable.immutable class MXBase(dns.rdata.Rdata): """Base class for rdata that is like an MX record.""" @@ -32,8 +35,8 @@ class MXBase(dns.rdata.Rdata): def __init__(self, rdclass, rdtype, preference, exchange): super().__init__(rdclass, rdtype) - object.__setattr__(self, 'preference', preference) - object.__setattr__(self, 'exchange', exchange) + self.preference = self._as_uint16(preference) + self.exchange = self._as_name(exchange) def to_text(self, origin=None, relativize=True, **kw): exchange = self.exchange.choose_relativity(origin, relativize) @@ -44,7 +47,6 @@ class MXBase(dns.rdata.Rdata): relativize_to=None): preference = tok.get_uint16() exchange = tok.get_name(origin, relativize, relativize_to) - tok.get_eol() return cls(rdclass, rdtype, preference, exchange) def _to_wire(self, file, compress=None, origin=None, canonicalize=False): @@ -58,7 +60,15 @@ class MXBase(dns.rdata.Rdata): exchange = parser.get_name(origin) return cls(rdclass, rdtype, preference, exchange) + def _processing_priority(self): + return self.preference + @classmethod + def _processing_order(cls, iterable): + return dns.rdtypes.util.priority_processing_order(iterable) + + +@dns.immutable.immutable class UncompressedMX(MXBase): """Base class for rdata that is like an MX record, but whose name @@ -69,6 +79,7 @@ class UncompressedMX(MXBase): super()._to_wire(file, None, origin, False) +@dns.immutable.immutable class UncompressedDowncasingMX(MXBase): """Base class for rdata that is like an MX record, but whose name diff --git a/lib/dns/rdtypes/nsbase.py b/lib/dns/rdtypes/nsbase.py index 93d3ee53..b3e25506 100644 --- a/lib/dns/rdtypes/nsbase.py +++ b/lib/dns/rdtypes/nsbase.py @@ -18,10 +18,12 @@ """NS-like base classes.""" import dns.exception +import dns.immutable import dns.rdata import dns.name +@dns.immutable.immutable class NSBase(dns.rdata.Rdata): """Base class for rdata that is like an NS record.""" @@ -30,7 +32,7 @@ class NSBase(dns.rdata.Rdata): def __init__(self, rdclass, rdtype, target): super().__init__(rdclass, rdtype) - object.__setattr__(self, 'target', target) + self.target = self._as_name(target) def to_text(self, origin=None, relativize=True, **kw): target = self.target.choose_relativity(origin, relativize) @@ -40,7 +42,6 @@ class NSBase(dns.rdata.Rdata): def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None): target = tok.get_name(origin, relativize, relativize_to) - tok.get_eol() return cls(rdclass, rdtype, target) def _to_wire(self, file, compress=None, origin=None, canonicalize=False): @@ -52,6 +53,7 @@ class NSBase(dns.rdata.Rdata): return cls(rdclass, rdtype, target) +@dns.immutable.immutable class UncompressedNS(NSBase): """Base class for rdata that is like an NS record, but whose name diff --git a/lib/dns/rdtypes/svcbbase.py b/lib/dns/rdtypes/svcbbase.py new file mode 100644 index 00000000..09d7a52b --- /dev/null +++ b/lib/dns/rdtypes/svcbbase.py @@ -0,0 +1,555 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +import base64 +import enum +import io +import struct + +import dns.enum +import dns.exception +import dns.immutable +import dns.ipv4 +import dns.ipv6 +import dns.name +import dns.rdata +import dns.rdtypes.util +import dns.tokenizer +import dns.wire + +# Until there is an RFC, this module is experimental and may be changed in +# incompatible ways. + + +class UnknownParamKey(dns.exception.DNSException): + """Unknown SVCB ParamKey""" + + +class ParamKey(dns.enum.IntEnum): + """SVCB ParamKey""" + + MANDATORY = 0 + ALPN = 1 + NO_DEFAULT_ALPN = 2 + PORT = 3 + IPV4HINT = 4 + ECH = 5 + IPV6HINT = 6 + + @classmethod + def _maximum(cls): + return 65535 + + @classmethod + def _short_name(cls): + return "SVCBParamKey" + + @classmethod + def _prefix(cls): + return "KEY" + + @classmethod + def _unknown_exception_class(cls): + return UnknownParamKey + + +class Emptiness(enum.IntEnum): + NEVER = 0 + ALWAYS = 1 + ALLOWED = 2 + + +def _validate_key(key): + force_generic = False + if isinstance(key, bytes): + # We decode to latin-1 so we get 0-255 as valid and do NOT interpret + # UTF-8 sequences + key = key.decode('latin-1') + if isinstance(key, str): + if key.lower().startswith('key'): + force_generic = True + if key[3:].startswith('0') and len(key) != 4: + # key has leading zeros + raise ValueError('leading zeros in key') + key = key.replace('-', '_') + return (ParamKey.make(key), force_generic) + +def key_to_text(key): + return ParamKey.to_text(key).replace('_', '-').lower() + +# Like rdata escapify, but escapes ',' too. + +_escaped = b'",\\' + +def _escapify(qstring): + text = '' + for c in qstring: + if c in _escaped: + text += '\\' + chr(c) + elif c >= 0x20 and c < 0x7F: + text += chr(c) + else: + text += '\\%03d' % c + return text + +def _unescape(value): + if value == '': + return value + unescaped = b'' + l = len(value) + i = 0 + while i < l: + c = value[i] + i += 1 + if c == '\\': + if i >= l: # pragma: no cover (can't happen via tokenizer get()) + raise dns.exception.UnexpectedEnd + c = value[i] + i += 1 + if c.isdigit(): + if i >= l: + raise dns.exception.UnexpectedEnd + c2 = value[i] + i += 1 + if i >= l: + raise dns.exception.UnexpectedEnd + c3 = value[i] + i += 1 + if not (c2.isdigit() and c3.isdigit()): + raise dns.exception.SyntaxError + codepoint = int(c) * 100 + int(c2) * 10 + int(c3) + if codepoint > 255: + raise dns.exception.SyntaxError + unescaped += b'%c' % (codepoint) + continue + unescaped += c.encode() + return unescaped + + +def _split(value): + l = len(value) + i = 0 + items = [] + unescaped = b'' + while i < l: + c = value[i] + i += 1 + if c == ord('\\'): + if i >= l: # pragma: no cover (can't happen via tokenizer get()) + raise dns.exception.UnexpectedEnd + c = value[i] + i += 1 + unescaped += b'%c' % (c) + elif c == ord(','): + items.append(unescaped) + unescaped = b'' + else: + unescaped += b'%c' % (c) + items.append(unescaped) + return items + + +@dns.immutable.immutable +class Param: + """Abstract base class for SVCB parameters""" + + @classmethod + def emptiness(cls): + return Emptiness.NEVER + + +@dns.immutable.immutable +class GenericParam(Param): + """Generic SVCB parameter + """ + def __init__(self, value): + self.value = dns.rdata.Rdata._as_bytes(value, True) + + @classmethod + def emptiness(cls): + return Emptiness.ALLOWED + + @classmethod + def from_value(cls, value): + if value is None or len(value) == 0: + return None + else: + return cls(_unescape(value)) + + def to_text(self): + return '"' + dns.rdata._escapify(self.value) + '"' + + @classmethod + def from_wire_parser(cls, parser, origin=None): # pylint: disable=W0613 + value = parser.get_bytes(parser.remaining()) + if len(value) == 0: + return None + else: + return cls(value) + + def to_wire(self, file, origin=None): # pylint: disable=W0613 + file.write(self.value) + + +@dns.immutable.immutable +class MandatoryParam(Param): + def __init__(self, keys): + # check for duplicates + keys = sorted([_validate_key(key)[0] for key in keys]) + prior_k = None + for k in keys: + if k == prior_k: + raise ValueError(f'duplicate key {k:d}') + prior_k = k + if k == ParamKey.MANDATORY: + raise ValueError('listed the mandatory key as mandatory') + self.keys = tuple(keys) + + @classmethod + def from_value(cls, value): + keys = [k.encode() for k in value.split(',')] + return cls(keys) + + def to_text(self): + return '"' + ','.join([key_to_text(key) for key in self.keys]) + '"' + + @classmethod + def from_wire_parser(cls, parser, origin=None): # pylint: disable=W0613 + keys = [] + last_key = -1 + while parser.remaining() > 0: + key = parser.get_uint16() + if key < last_key: + raise dns.exception.FormError('manadatory keys not ascending') + last_key = key + keys.append(key) + return cls(keys) + + def to_wire(self, file, origin=None): # pylint: disable=W0613 + for key in self.keys: + file.write(struct.pack('!H', key)) + + +@dns.immutable.immutable +class ALPNParam(Param): + def __init__(self, ids): + self.ids = dns.rdata.Rdata._as_tuple( + ids, lambda x: dns.rdata.Rdata._as_bytes(x, True, 255, False)) + + @classmethod + def from_value(cls, value): + return cls(_split(_unescape(value))) + + def to_text(self): + value = ','.join([_escapify(id) for id in self.ids]) + return '"' + dns.rdata._escapify(value.encode()) + '"' + + @classmethod + def from_wire_parser(cls, parser, origin=None): # pylint: disable=W0613 + ids = [] + while parser.remaining() > 0: + id = parser.get_counted_bytes() + ids.append(id) + return cls(ids) + + def to_wire(self, file, origin=None): # pylint: disable=W0613 + for id in self.ids: + file.write(struct.pack('!B', len(id))) + file.write(id) + + +@dns.immutable.immutable +class NoDefaultALPNParam(Param): + # We don't ever expect to instantiate this class, but we need + # a from_value() and a from_wire_parser(), so we just return None + # from the class methods when things are OK. + + @classmethod + def emptiness(cls): + return Emptiness.ALWAYS + + @classmethod + def from_value(cls, value): + if value is None or value == '': + return None + else: + raise ValueError('no-default-alpn with non-empty value') + + def to_text(self): + raise NotImplementedError # pragma: no cover + + @classmethod + def from_wire_parser(cls, parser, origin=None): # pylint: disable=W0613 + if parser.remaining() != 0: + raise dns.exception.FormError + return None + + def to_wire(self, file, origin=None): # pylint: disable=W0613 + raise NotImplementedError # pragma: no cover + + +@dns.immutable.immutable +class PortParam(Param): + def __init__(self, port): + self.port = dns.rdata.Rdata._as_uint16(port) + + @classmethod + def from_value(cls, value): + value = int(value) + return cls(value) + + def to_text(self): + return f'"{self.port}"' + + @classmethod + def from_wire_parser(cls, parser, origin=None): # pylint: disable=W0613 + port = parser.get_uint16() + return cls(port) + + def to_wire(self, file, origin=None): # pylint: disable=W0613 + file.write(struct.pack('!H', self.port)) + + +@dns.immutable.immutable +class IPv4HintParam(Param): + def __init__(self, addresses): + self.addresses = dns.rdata.Rdata._as_tuple( + addresses, dns.rdata.Rdata._as_ipv4_address) + + @classmethod + def from_value(cls, value): + addresses = value.split(',') + return cls(addresses) + + def to_text(self): + return '"' + ','.join(self.addresses) + '"' + + @classmethod + def from_wire_parser(cls, parser, origin=None): # pylint: disable=W0613 + addresses = [] + while parser.remaining() > 0: + ip = parser.get_bytes(4) + addresses.append(dns.ipv4.inet_ntoa(ip)) + return cls(addresses) + + def to_wire(self, file, origin=None): # pylint: disable=W0613 + for address in self.addresses: + file.write(dns.ipv4.inet_aton(address)) + + +@dns.immutable.immutable +class IPv6HintParam(Param): + def __init__(self, addresses): + self.addresses = dns.rdata.Rdata._as_tuple( + addresses, dns.rdata.Rdata._as_ipv6_address) + + @classmethod + def from_value(cls, value): + addresses = value.split(',') + return cls(addresses) + + def to_text(self): + return '"' + ','.join(self.addresses) + '"' + + @classmethod + def from_wire_parser(cls, parser, origin=None): # pylint: disable=W0613 + addresses = [] + while parser.remaining() > 0: + ip = parser.get_bytes(16) + addresses.append(dns.ipv6.inet_ntoa(ip)) + return cls(addresses) + + def to_wire(self, file, origin=None): # pylint: disable=W0613 + for address in self.addresses: + file.write(dns.ipv6.inet_aton(address)) + + +@dns.immutable.immutable +class ECHParam(Param): + def __init__(self, ech): + self.ech = dns.rdata.Rdata._as_bytes(ech, True) + + @classmethod + def from_value(cls, value): + if '\\' in value: + raise ValueError('escape in ECH value') + value = base64.b64decode(value.encode()) + return cls(value) + + def to_text(self): + b64 = base64.b64encode(self.ech).decode('ascii') + return f'"{b64}"' + + @classmethod + def from_wire_parser(cls, parser, origin=None): # pylint: disable=W0613 + value = parser.get_bytes(parser.remaining()) + return cls(value) + + def to_wire(self, file, origin=None): # pylint: disable=W0613 + file.write(self.ech) + + +_class_for_key = { + ParamKey.MANDATORY: MandatoryParam, + ParamKey.ALPN: ALPNParam, + ParamKey.NO_DEFAULT_ALPN: NoDefaultALPNParam, + ParamKey.PORT: PortParam, + ParamKey.IPV4HINT: IPv4HintParam, + ParamKey.ECH: ECHParam, + ParamKey.IPV6HINT: IPv6HintParam, +} + + +def _validate_and_define(params, key, value): + (key, force_generic) = _validate_key(_unescape(key)) + if key in params: + raise SyntaxError(f'duplicate key "{key:d}"') + cls = _class_for_key.get(key, GenericParam) + emptiness = cls.emptiness() + if value is None: + if emptiness == Emptiness.NEVER: + raise SyntaxError('value cannot be empty') + value = cls.from_value(value) + else: + if force_generic: + value = cls.from_wire_parser(dns.wire.Parser(_unescape(value))) + else: + value = cls.from_value(value) + params[key] = value + + +@dns.immutable.immutable +class SVCBBase(dns.rdata.Rdata): + + """Base class for SVCB-like records""" + + # see: draft-ietf-dnsop-svcb-https-01 + + __slots__ = ['priority', 'target', 'params'] + + def __init__(self, rdclass, rdtype, priority, target, params): + super().__init__(rdclass, rdtype) + self.priority = self._as_uint16(priority) + self.target = self._as_name(target) + for k, v in params.items(): + k = ParamKey.make(k) + if not isinstance(v, Param) and v is not None: + raise ValueError("not a Param") + self.params = dns.immutable.Dict(params) + # Make sure any paramater listed as mandatory is present in the + # record. + mandatory = params.get(ParamKey.MANDATORY) + if mandatory: + for key in mandatory.keys: + # Note we have to say "not in" as we have None as a value + # so a get() and a not None test would be wrong. + if key not in params: + raise ValueError(f'key {key:d} declared mandatory but not ' + 'present') + # The no-default-alpn parameter requires the alpn parameter. + if ParamKey.NO_DEFAULT_ALPN in params: + if ParamKey.ALPN not in params: + raise ValueError('no-default-alpn present, but alpn missing') + + def to_text(self, origin=None, relativize=True, **kw): + target = self.target.choose_relativity(origin, relativize) + params = [] + for key in sorted(self.params.keys()): + value = self.params[key] + if value is None: + params.append(key_to_text(key)) + else: + kv = key_to_text(key) + '=' + value.to_text() + params.append(kv) + if len(params) > 0: + space = ' ' + else: + space = '' + return '%d %s%s%s' % (self.priority, target, space, ' '.join(params)) + + @classmethod + def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, + relativize_to=None): + priority = tok.get_uint16() + target = tok.get_name(origin, relativize, relativize_to) + if priority == 0: + token = tok.get() + if not token.is_eol_or_eof(): + raise SyntaxError('parameters in AliasMode') + tok.unget(token) + params = {} + while True: + token = tok.get() + if token.is_eol_or_eof(): + tok.unget(token) + break + if token.ttype != dns.tokenizer.IDENTIFIER: + raise SyntaxError('parameter is not an identifier') + equals = token.value.find('=') + if equals == len(token.value) - 1: + # 'key=', so next token should be a quoted string without + # any intervening whitespace. + key = token.value[:-1] + token = tok.get(want_leading=True) + if token.ttype != dns.tokenizer.QUOTED_STRING: + raise SyntaxError('whitespace after =') + value = token.value + elif equals > 0: + # key=value + key = token.value[:equals] + value = token.value[equals + 1:] + elif equals == 0: + # =key + raise SyntaxError('parameter cannot start with "="') + else: + # key + key = token.value + value = None + _validate_and_define(params, key, value) + return cls(rdclass, rdtype, priority, target, params) + + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): + file.write(struct.pack("!H", self.priority)) + self.target.to_wire(file, None, origin, False) + for key in sorted(self.params): + file.write(struct.pack("!H", key)) + value = self.params[key] + # placeholder for length (or actual length of empty values) + file.write(struct.pack("!H", 0)) + if value is None: + continue + else: + start = file.tell() + value.to_wire(file, origin) + end = file.tell() + assert end - start < 65536 + file.seek(start - 2) + stuff = struct.pack("!H", end - start) + file.write(stuff) + file.seek(0, io.SEEK_END) + + @classmethod + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + priority = parser.get_uint16() + target = parser.get_name(origin) + if priority == 0 and parser.remaining() != 0: + raise dns.exception.FormError('parameters in AliasMode') + params = {} + prior_key = -1 + while parser.remaining() > 0: + key = parser.get_uint16() + if key < prior_key: + raise dns.exception.FormError('keys not in order') + prior_key = key + vlen = parser.get_uint16() + pcls = _class_for_key.get(key, GenericParam) + with parser.restrict_to(vlen): + value = pcls.from_wire_parser(parser, origin) + params[key] = value + return cls(rdclass, rdtype, priority, target, params) + + def _processing_priority(self): + return self.priority + + @classmethod + def _processing_order(cls, iterable): + return dns.rdtypes.util.priority_processing_order(iterable) diff --git a/lib/dns/rdtypes/tlsabase.py b/lib/dns/rdtypes/tlsabase.py new file mode 100644 index 00000000..786fca55 --- /dev/null +++ b/lib/dns/rdtypes/tlsabase.py @@ -0,0 +1,72 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2005-2007, 2009-2011 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +import struct +import binascii + +import dns.rdata +import dns.immutable +import dns.rdatatype + + +@dns.immutable.immutable +class TLSABase(dns.rdata.Rdata): + + """Base class for TLSA and SMIMEA records""" + + # see: RFC 6698 + + __slots__ = ['usage', 'selector', 'mtype', 'cert'] + + def __init__(self, rdclass, rdtype, usage, selector, + mtype, cert): + super().__init__(rdclass, rdtype) + self.usage = self._as_uint8(usage) + self.selector = self._as_uint8(selector) + self.mtype = self._as_uint8(mtype) + self.cert = self._as_bytes(cert) + + def to_text(self, origin=None, relativize=True, **kw): + kw = kw.copy() + chunksize = kw.pop('chunksize', 128) + return '%d %d %d %s' % (self.usage, + self.selector, + self.mtype, + dns.rdata._hexify(self.cert, + chunksize=chunksize, + **kw)) + + @classmethod + def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, + relativize_to=None): + usage = tok.get_uint8() + selector = tok.get_uint8() + mtype = tok.get_uint8() + cert = tok.concatenate_remaining_identifiers().encode() + cert = binascii.unhexlify(cert) + return cls(rdclass, rdtype, usage, selector, mtype, cert) + + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): + header = struct.pack("!BBB", self.usage, self.selector, self.mtype) + file.write(header) + file.write(self.cert) + + @classmethod + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + header = parser.get_struct("BBB") + cert = parser.get_remaining() + return cls(rdclass, rdtype, header[0], header[1], header[2], cert) diff --git a/lib/dns/rdtypes/txtbase.py b/lib/dns/rdtypes/txtbase.py index ad0093da..68071ee0 100644 --- a/lib/dns/rdtypes/txtbase.py +++ b/lib/dns/rdtypes/txtbase.py @@ -20,10 +20,12 @@ import struct import dns.exception +import dns.immutable import dns.rdata import dns.tokenizer +@dns.immutable.immutable class TXTBase(dns.rdata.Rdata): """Base class for rdata that is like a TXT record (see RFC 1035).""" @@ -40,16 +42,8 @@ class TXTBase(dns.rdata.Rdata): *strings*, a tuple of ``bytes`` """ super().__init__(rdclass, rdtype) - if isinstance(strings, (bytes, str)): - strings = (strings,) - encoded_strings = [] - for string in strings: - if isinstance(string, str): - string = string.encode() - else: - string = dns.rdata._constify(string) - encoded_strings.append(string) - object.__setattr__(self, 'strings', tuple(encoded_strings)) + self.strings = self._as_tuple(strings, + lambda x: self._as_bytes(x, True, 255)) def to_text(self, origin=None, relativize=True, **kw): txt = '' @@ -63,11 +57,12 @@ class TXTBase(dns.rdata.Rdata): def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None): strings = [] - while 1: - token = tok.get().unescape_to_bytes() - if token.is_eol_or_eof(): - break - if not (token.is_quoted_string() or token.is_identifier()): + for token in tok.get_remaining(): + token = token.unescape_to_bytes() + # The 'if' below is always true in the current code, but we + # are leaving this check in in case things change some day. + if not (token.is_quoted_string() or + token.is_identifier()): # pragma: no cover raise dns.exception.SyntaxError("expected a string") if len(token.value) > 255: raise dns.exception.SyntaxError("string too long") diff --git a/lib/dns/rdtypes/txtbase.pyi b/lib/dns/rdtypes/txtbase.pyi new file mode 100644 index 00000000..af447d50 --- /dev/null +++ b/lib/dns/rdtypes/txtbase.pyi @@ -0,0 +1,6 @@ +from .. import rdata + +class TXTBase(rdata.Rdata): + ... +class TXT(TXTBase): + ... diff --git a/lib/dns/rdtypes/util.py b/lib/dns/rdtypes/util.py index a63d1a0a..9bf8f7e9 100644 --- a/lib/dns/rdtypes/util.py +++ b/lib/dns/rdtypes/util.py @@ -15,25 +15,31 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. +import collections +import random import struct import dns.exception -import dns.name import dns.ipv4 import dns.ipv6 +import dns.name +import dns.rdata + class Gateway: """A helper class for the IPSECKEY gateway and AMTRELAY relay fields""" name = "" def __init__(self, type, gateway=None): - self.type = type + self.type = dns.rdata.Rdata._as_uint8(type) self.gateway = gateway + self._check() - def _invalid_type(self): - return f"invalid {self.name} type: {self.type}" + @classmethod + def _invalid_type(cls, gateway_type): + return f"invalid {cls.name} type: {gateway_type}" - def check(self): + def _check(self): if self.type == 0: if self.gateway not in (".", None): raise SyntaxError(f"invalid {self.name} for type 0") @@ -48,7 +54,7 @@ class Gateway: if not isinstance(self.gateway, dns.name.Name): raise SyntaxError(f"invalid {self.name}; not a name") else: - raise SyntaxError(self._invalid_type()) + raise SyntaxError(self._invalid_type(self.type)) def to_text(self, origin=None, relativize=True): if self.type == 0: @@ -58,16 +64,21 @@ class Gateway: elif self.type == 3: return str(self.gateway.choose_relativity(origin, relativize)) else: - raise ValueError(self._invalid_type()) + raise ValueError(self._invalid_type(self.type)) # pragma: no cover - def from_text(self, tok, origin=None, relativize=True, relativize_to=None): - if self.type in (0, 1, 2): - return tok.get_string() - elif self.type == 3: - return tok.get_name(origin, relativize, relativize_to) + @classmethod + def from_text(cls, gateway_type, tok, origin=None, relativize=True, + relativize_to=None): + if gateway_type in (0, 1, 2): + gateway = tok.get_string() + elif gateway_type == 3: + gateway = tok.get_name(origin, relativize, relativize_to) else: - raise dns.exception.SyntaxError(self._invalid_type()) + raise dns.exception.SyntaxError( + cls._invalid_type(gateway_type)) # pragma: no cover + return cls(gateway_type, gateway) + # pylint: disable=unused-argument def to_wire(self, file, compress=None, origin=None, canonicalize=False): if self.type == 0: pass @@ -78,26 +89,43 @@ class Gateway: elif self.type == 3: self.gateway.to_wire(file, None, origin, False) else: - raise ValueError(self._invalid_type()) + raise ValueError(self._invalid_type(self.type)) # pragma: no cover + # pylint: enable=unused-argument - def from_wire_parser(self, parser, origin=None): - if self.type == 0: - return None - elif self.type == 1: - return dns.ipv4.inet_ntoa(parser.get_bytes(4)) - elif self.type == 2: - return dns.ipv6.inet_ntoa(parser.get_bytes(16)) - elif self.type == 3: - return parser.get_name(origin) + @classmethod + def from_wire_parser(cls, gateway_type, parser, origin=None): + if gateway_type == 0: + gateway = None + elif gateway_type == 1: + gateway = dns.ipv4.inet_ntoa(parser.get_bytes(4)) + elif gateway_type == 2: + gateway = dns.ipv6.inet_ntoa(parser.get_bytes(16)) + elif gateway_type == 3: + gateway = parser.get_name(origin) else: - raise dns.exception.FormError(self._invalid_type()) + raise dns.exception.FormError(cls._invalid_type(gateway_type)) + return cls(gateway_type, gateway) + class Bitmap: """A helper class for the NSEC/NSEC3/CSYNC type bitmaps""" type_name = "" def __init__(self, windows=None): + last_window = -1 self.windows = windows + for (window, bitmap) in self.windows: + if not isinstance(window, int): + raise ValueError(f"bad {self.type_name} window type") + if window <= last_window: + raise ValueError(f"bad {self.type_name} window order") + if window > 256: + raise ValueError(f"bad {self.type_name} window number") + last_window = window + if not isinstance(bitmap, bytes): + raise ValueError(f"bad {self.type_name} octets type") + if len(bitmap) == 0 or len(bitmap) > 32: + raise ValueError(f"bad {self.type_name} octets") def to_text(self): text = "" @@ -111,15 +139,13 @@ class Bitmap: text += (' ' + ' '.join(bits)) return text - def from_text(self, tok): + @classmethod + def from_text(cls, tok): rdtypes = [] - while True: - token = tok.get().unescape() - if token.is_eol_or_eof(): - break - rdtype = dns.rdatatype.from_text(token.value) + for token in tok.get_remaining(): + rdtype = dns.rdatatype.from_text(token.unescape().value) if rdtype == 0: - raise dns.exception.SyntaxError(f"{self.type_name} with bit 0") + raise dns.exception.SyntaxError(f"{cls.type_name} with bit 0") rdtypes.append(rdtype) rdtypes.sort() window = 0 @@ -134,7 +160,7 @@ class Bitmap: new_window = rdtype // 256 if new_window != window: if octets != 0: - windows.append((window, bitmap[0:octets])) + windows.append((window, bytes(bitmap[0:octets]))) bitmap = bytearray(b'\0' * 32) window = new_window offset = rdtype % 256 @@ -143,24 +169,76 @@ class Bitmap: octets = byte + 1 bitmap[byte] = bitmap[byte] | (0x80 >> bit) if octets != 0: - windows.append((window, bitmap[0:octets])) - return windows + windows.append((window, bytes(bitmap[0:octets]))) + return cls(windows) def to_wire(self, file): for (window, bitmap) in self.windows: file.write(struct.pack('!BB', window, len(bitmap))) file.write(bitmap) - def from_wire_parser(self, parser): + @classmethod + def from_wire_parser(cls, parser): windows = [] - last_window = -1 while parser.remaining() > 0: window = parser.get_uint8() - if window <= last_window: - raise dns.exception.FormError(f"bad {self.type_name} bitmap") bitmap = parser.get_counted_bytes() - if len(bitmap) == 0 or len(bitmap) > 32: - raise dns.exception.FormError(f"bad {self.type_name} octets") windows.append((window, bitmap)) - last_window = window - return windows + return cls(windows) + + +def _priority_table(items): + by_priority = collections.defaultdict(list) + for rdata in items: + by_priority[rdata._processing_priority()].append(rdata) + return by_priority + +def priority_processing_order(iterable): + items = list(iterable) + if len(items) == 1: + return items + by_priority = _priority_table(items) + ordered = [] + for k in sorted(by_priority.keys()): + rdatas = by_priority[k] + random.shuffle(rdatas) + ordered.extend(rdatas) + return ordered + +_no_weight = 0.1 + +def weighted_processing_order(iterable): + items = list(iterable) + if len(items) == 1: + return items + by_priority = _priority_table(items) + ordered = [] + for k in sorted(by_priority.keys()): + rdatas = by_priority[k] + total = sum(rdata._processing_weight() or _no_weight + for rdata in rdatas) + while len(rdatas) > 1: + r = random.uniform(0, total) + for (n, rdata) in enumerate(rdatas): + weight = rdata._processing_weight() or _no_weight + if weight > r: + break + r -= weight + total -= weight + ordered.append(rdata) # pylint: disable=undefined-loop-variable + del rdatas[n] # pylint: disable=undefined-loop-variable + ordered.append(rdatas[0]) + return ordered + +def parse_formatted_hex(formatted, num_chunks, chunk_size, separator): + if len(formatted) != num_chunks * (chunk_size + 1) - 1: + raise ValueError('invalid formatted hex string') + value = b'' + for _ in range(num_chunks): + chunk = formatted[0:chunk_size] + value += int(chunk, 16).to_bytes(chunk_size // 2, 'big') + formatted = formatted[chunk_size:] + if len(formatted) > 0 and formatted[0] != separator: + raise ValueError('invalid formatted hex string') + formatted = formatted[1:] + return value diff --git a/lib/dns/resolver.py b/lib/dns/resolver.py index 4f630e4d..166f8492 100644 --- a/lib/dns/resolver.py +++ b/lib/dns/resolver.py @@ -43,14 +43,17 @@ import dns.reversename import dns.tsig if sys.platform == 'win32': - import winreg # pragma: no cover + import dns.win32util class NXDOMAIN(dns.exception.DNSException): """The DNS query name does not exist.""" supp_kwargs = {'qnames', 'responses'} fmt = None # we have our own __str__ implementation - def _check_kwargs(self, qnames, responses=None): + # pylint: disable=arguments-differ + + def _check_kwargs(self, qnames, + responses=None): if not isinstance(qnames, (list, tuple, set)): raise AttributeError("qnames must be a list, tuple or set") if len(qnames) == 0: @@ -78,17 +81,16 @@ class NXDOMAIN(dns.exception.DNSException): """Return the unresolved canonical name.""" if 'qnames' not in self.kwargs: raise TypeError("parametrized exception required") - IN = dns.rdataclass.IN - CNAME = dns.rdatatype.CNAME - cname = None for qname in self.kwargs['qnames']: response = self.kwargs['responses'][qname] - for answer in response.answer: - if answer.rdtype != CNAME or answer.rdclass != IN: - continue - cname = answer[0].target.to_text() - if cname is not None: - return dns.name.from_text(cname) + try: + cname = response.canonical_name() + if cname != qname: + return cname + except Exception: + # We can just eat this exception as it means there was + # something wrong with the response. + pass return self.kwargs['qnames'][0] def __add__(self, e_nx): @@ -129,11 +131,33 @@ class NXDOMAIN(dns.exception.DNSException): class YXDOMAIN(dns.exception.DNSException): """The DNS query name is too long after DNAME substitution.""" -# The definition of the Timeout exception has moved from here to the -# dns.exception module. We keep dns.resolver.Timeout defined for -# backwards compatibility. -Timeout = dns.exception.Timeout +def _errors_to_text(errors): + """Turn a resolution errors trace into a list of text.""" + texts = [] + for err in errors: + texts.append('Server {} {} port {} answered {}'.format(err[0], + 'TCP' if err[1] else 'UDP', err[2], err[3])) + return texts + + +class LifetimeTimeout(dns.exception.Timeout): + """The resolution lifetime expired.""" + + msg = "The resolution lifetime expired." + fmt = "%s after {timeout} seconds: {errors}" % msg[:-1] + supp_kwargs = {'timeout', 'errors'} + + def _fmt_kwargs(self, **kwargs): + srv_msgs = _errors_to_text(kwargs['errors']) + return super()._fmt_kwargs(timeout=kwargs['timeout'], + errors='; '.join(srv_msgs)) + + +# We added more detail to resolution timeouts, but they are still +# subclasses of dns.exception.Timeout for backwards compatibility. We also +# keep dns.resolver.Timeout defined for backwards compatibility. +Timeout = LifetimeTimeout class NoAnswer(dns.exception.DNSException): @@ -145,6 +169,9 @@ class NoAnswer(dns.exception.DNSException): def _fmt_kwargs(self, **kwargs): return super()._fmt_kwargs(query=kwargs['response'].question) + def response(self): + return self.kwargs['response'] + class NoNameservers(dns.exception.DNSException): """All nameservers failed to answer the query. @@ -160,10 +187,7 @@ class NoNameservers(dns.exception.DNSException): supp_kwargs = {'request', 'errors'} def _fmt_kwargs(self, **kwargs): - srv_msgs = [] - for err in kwargs['errors']: - srv_msgs.append('Server {} {} port {} answered {}'.format(err[0], - 'TCP' if err[1] else 'UDP', err[2], err[3])) + srv_msgs = _errors_to_text(kwargs['errors']) return super()._fmt_kwargs(query=kwargs['request'].question, errors='; '.join(srv_msgs)) @@ -206,51 +230,12 @@ class Answer: self.response = response self.nameserver = nameserver self.port = port - min_ttl = -1 - rrset = None - for count in range(0, 15): - try: - rrset = response.find_rrset(response.answer, qname, - rdclass, rdtype) - if min_ttl == -1 or rrset.ttl < min_ttl: - min_ttl = rrset.ttl - break - except KeyError: - if rdtype != dns.rdatatype.CNAME: - try: - crrset = response.find_rrset(response.answer, - qname, - rdclass, - dns.rdatatype.CNAME) - if min_ttl == -1 or crrset.ttl < min_ttl: - min_ttl = crrset.ttl - for rd in crrset: - qname = rd.target - break - continue - except KeyError: - # Exit the chaining loop - break - self.canonical_name = qname - self.rrset = rrset - if rrset is None: - while 1: - # Look for a SOA RR whose owner name is a superdomain - # of qname. - try: - srrset = response.find_rrset(response.authority, qname, - rdclass, dns.rdatatype.SOA) - if min_ttl == -1 or srrset.ttl < min_ttl: - min_ttl = srrset.ttl - if srrset[0].minimum < min_ttl: - min_ttl = srrset[0].minimum - break - except KeyError: - try: - qname = qname.parent() - except dns.name.NoParent: - break - self.expiration = time.time() + min_ttl + self.chaining_result = response.resolve_chaining() + # Copy some attributes out of chaining_result for backwards + # compatibility and convenience. + self.canonical_name = self.chaining_result.canonical_name + self.rrset = self.chaining_result.answer + self.expiration = time.time() + self.chaining_result.minimum_ttl def __getattr__(self, attr): # pragma: no cover if attr == 'name': @@ -283,7 +268,54 @@ class Answer: del self.rrset[i] -class Cache: +class CacheStatistics: + """Cache Statistics + """ + + def __init__(self, hits=0, misses=0): + self.hits = hits + self.misses = misses + + def reset(self): + self.hits = 0 + self.misses = 0 + + def clone(self): + return CacheStatistics(self.hits, self.misses) + + +class CacheBase: + def __init__(self): + self.lock = _threading.Lock() + self.statistics = CacheStatistics() + + def reset_statistics(self): + """Reset all statistics to zero.""" + with self.lock: + self.statistics.reset() + + def hits(self): + """How many hits has the cache had?""" + with self.lock: + return self.statistics.hits + + def misses(self): + """How many misses has the cache had?""" + with self.lock: + return self.statistics.misses + + def get_statistics_snapshot(self): + """Return a consistent snapshot of all the statistics. + + If running with multiple threads, it's better to take a + snapshot than to call statistics methods such as hits() and + misses() individually. + """ + with self.lock: + return self.statistics.clone() + + +class Cache(CacheBase): """Simple thread-safe DNS answer cache.""" def __init__(self, cleaning_interval=300.0): @@ -291,10 +323,10 @@ class Cache: periodic cleanings. """ + super().__init__() self.data = {} self.cleaning_interval = cleaning_interval self.next_cleaning = time.time() + self.cleaning_interval - self.lock = _threading.Lock() def _maybe_clean(self): """Clean the cache if it's time to do so.""" @@ -325,7 +357,9 @@ class Cache: self._maybe_clean() v = self.data.get(key) if v is None or v.expiration <= time.time(): + self.statistics.misses += 1 return None + self.statistics.hits += 1 return v def put(self, key, value): @@ -366,6 +400,7 @@ class LRUCacheNode: def __init__(self, key, value): self.key = key self.value = value + self.hits = 0 self.prev = self self.next = self @@ -380,7 +415,7 @@ class LRUCacheNode: self.prev.next = self.next -class LRUCache: +class LRUCache(CacheBase): """Thread-safe, bounded, least-recently-used DNS answer cache. This cache is better than the simple cache (above) if you're @@ -395,12 +430,12 @@ class LRUCache: it must be greater than 0. """ + super().__init__() self.data = {} self.set_max_size(max_size) self.sentinel = LRUCacheNode(None, None) self.sentinel.prev = self.sentinel self.sentinel.next = self.sentinel - self.lock = _threading.Lock() def set_max_size(self, max_size): if max_size < 1: @@ -421,16 +456,29 @@ class LRUCache: with self.lock: node = self.data.get(key) if node is None: + self.statistics.misses += 1 return None # Unlink because we're either going to move the node to the front # of the LRU list or we're going to free it. node.unlink() if node.value.expiration <= time.time(): del self.data[node.key] + self.statistics.misses += 1 return None node.link_after(self.sentinel) + self.statistics.hits += 1 + node.hits += 1 return node.value + def get_hits_for_key(self, key): + """Return the number of cache hits associated with the specified key.""" + with self.lock: + node = self.data.get(key) + if node is None or node.value.expiration <= time.time(): + return 0 + else: + return node.hits + def put(self, key, value): """Associate key and value in the cache. @@ -632,8 +680,15 @@ class _Resolution: assert response is not None rcode = response.rcode() if rcode == dns.rcode.NOERROR: - answer = Answer(self.qname, self.rdtype, self.rdclass, response, - self.nameserver, self.port) + try: + answer = Answer(self.qname, self.rdtype, self.rdclass, response, + self.nameserver, self.port) + except Exception as e: + self.errors.append((self.nameserver, self.tcp_attempt, + self.port, e, response)) + # The nameserver is no good, take it out of the mix. + self.nameservers.remove(self.nameserver) + return (None, False) if self.resolver.cache: self.resolver.cache.put((self.qname, self.rdtype, self.rdclass), answer) @@ -641,16 +696,24 @@ class _Resolution: raise NoAnswer(response=answer.response) return (answer, True) elif rcode == dns.rcode.NXDOMAIN: - self.nxdomain_responses[self.qname] = response - # Make next_nameserver() return None, so caller breaks its - # inner loop and calls next_request(). - if self.resolver.cache: + # Further validate the response by making an Answer, even + # if we aren't going to cache it. + try: answer = Answer(self.qname, dns.rdatatype.ANY, dns.rdataclass.IN, response) + except Exception as e: + self.errors.append((self.nameserver, self.tcp_attempt, + self.port, e, response)) + # The nameserver is no good, take it out of the mix. + self.nameservers.remove(self.nameserver) + return (None, False) + self.nxdomain_responses[self.qname] = response + if self.resolver.cache: self.resolver.cache.put((self.qname, dns.rdatatype.ANY, self.rdclass), answer) - + # Make next_nameserver() return None, so caller breaks its + # inner loop and calls next_request(). return (None, True) elif rcode == dns.rcode.YXDOMAIN: yex = YXDOMAIN() @@ -668,7 +731,7 @@ class _Resolution: dns.rcode.to_text(rcode), response)) return (None, False) -class Resolver: +class BaseResolver: """DNS stub resolver.""" # We initialize in reset() @@ -690,7 +753,7 @@ class Resolver: self.reset() if configure: if sys.platform == 'win32': - self.read_registry() # pragma: no cover + self.read_registry() elif filename: self.read_resolv_conf(filename) @@ -743,7 +806,7 @@ class Resolver: f = stack.enter_context(open(f)) except OSError: # /etc/resolv.conf doesn't exist, can't be read, etc. - raise NoResolverConfiguration + raise NoResolverConfiguration(f'cannot open {f}') for l in f: if len(l) == 0 or l[0] == '#' or l[0] == ';': @@ -758,15 +821,21 @@ class Resolver: self.nameservers.append(tokens[1]) elif tokens[0] == 'domain': self.domain = dns.name.from_text(tokens[1]) + # domain and search are exclusive + self.search = [] elif tokens[0] == 'search': + # the last search wins + self.search = [] for suffix in tokens[1:]: self.search.append(dns.name.from_text(suffix)) + # We don't set domain as it is not used if + # len(self.search) > 0 elif tokens[0] == 'options': for opt in tokens[1:]: if opt == 'rotate': self.rotate = True elif opt == 'edns0': - self.use_edns(0, 0, 0) + self.use_edns() elif 'timeout' in opt: try: self.timeout = int(opt.split(':')[1]) @@ -778,176 +847,36 @@ class Resolver: except (ValueError, IndexError): pass if len(self.nameservers) == 0: - raise NoResolverConfiguration - - def _determine_split_char(self, entry): - # - # The windows registry irritatingly changes the list element - # delimiter in between ' ' and ',' (and vice-versa) in various - # versions of windows. - # - if entry.find(' ') >= 0: # pragma: no cover - split_char = ' ' - elif entry.find(',') >= 0: # pragma: no cover - split_char = ',' - else: - # probably a singleton; treat as a space-separated list. - split_char = ' ' - return split_char - - def _config_win32_nameservers(self, nameservers): - # we call str() on nameservers to convert it from unicode to ascii - nameservers = str(nameservers) - split_char = self._determine_split_char(nameservers) - ns_list = nameservers.split(split_char) - for ns in ns_list: - if ns not in self.nameservers: - self.nameservers.append(ns) - - def _config_win32_domain(self, domain): # pragma: no cover - # we call str() on domain to convert it from unicode to ascii - self.domain = dns.name.from_text(str(domain)) - - def _config_win32_search(self, search): # pragma: no cover - # we call str() on search to convert it from unicode to ascii - search = str(search) - split_char = self._determine_split_char(search) - search_list = search.split(split_char) - for s in search_list: - if s not in self.search: - self.search.append(dns.name.from_text(s)) - - def _config_win32_fromkey(self, key, always_try_domain): - try: - servers, rtype = winreg.QueryValueEx(key, 'NameServer') - except WindowsError: # pylint: disable=undefined-variable - servers = None - if servers: - self._config_win32_nameservers(servers) - if servers or always_try_domain: - try: - dom, rtype = winreg.QueryValueEx(key, 'Domain') - if dom: - self._config_win32_domain(dom) - except WindowsError: # pragma: no cover - pass - else: - try: - servers, rtype = winreg.QueryValueEx(key, 'DhcpNameServer') - except WindowsError: # pragma: no cover - servers = None - if servers: # pragma: no cover - self._config_win32_nameservers(servers) - try: - dom, rtype = winreg.QueryValueEx(key, 'DhcpDomain') - if dom: # pragma: no cover - self._config_win32_domain(dom) - except WindowsError: # pragma: no cover - pass - try: - search, rtype = winreg.QueryValueEx(key, 'SearchList') - except WindowsError: # pylint: disable=undefined-variable - search = None - if search: # pragma: no cover - self._config_win32_search(search) + raise NoResolverConfiguration('no nameservers') def read_registry(self): """Extract resolver configuration from the Windows registry.""" - - lm = winreg.ConnectRegistry(None, winreg.HKEY_LOCAL_MACHINE) try: - tcp_params = winreg.OpenKey(lm, - r'SYSTEM\CurrentControlSet' - r'\Services\Tcpip\Parameters') - try: - self._config_win32_fromkey(tcp_params, True) - finally: - tcp_params.Close() - interfaces = winreg.OpenKey(lm, - r'SYSTEM\CurrentControlSet' - r'\Services\Tcpip\Parameters' - r'\Interfaces') - try: - i = 0 - while True: - try: - guid = winreg.EnumKey(interfaces, i) - i += 1 - key = winreg.OpenKey(interfaces, guid) - if not self._win32_is_nic_enabled(lm, guid, key): - continue - try: - self._config_win32_fromkey(key, False) - finally: - key.Close() - except EnvironmentError: # pragma: no cover - break - finally: - interfaces.Close() - finally: - lm.Close() + info = dns.win32util.get_dns_info() + if info.domain is not None: + self.domain = info.domain + self.nameservers = info.nameservers + self.search = info.search + except AttributeError: + raise NotImplementedError - def _win32_is_nic_enabled(self, lm, guid, - interface_key): - # Look in the Windows Registry to determine whether the network - # interface corresponding to the given guid is enabled. - # - # (Code contributed by Paul Marks, thanks!) - # - try: - # This hard-coded location seems to be consistent, at least - # from Windows 2000 through Vista. - connection_key = winreg.OpenKey( - lm, - r'SYSTEM\CurrentControlSet\Control\Network' - r'\{4D36E972-E325-11CE-BFC1-08002BE10318}' - r'\%s\Connection' % guid) - - try: - # The PnpInstanceID points to a key inside Enum - (pnp_id, ttype) = winreg.QueryValueEx( - connection_key, 'PnpInstanceID') - - if ttype != winreg.REG_SZ: # pragma: no cover - raise ValueError - - device_key = winreg.OpenKey( - lm, r'SYSTEM\CurrentControlSet\Enum\%s' % pnp_id) - - try: - # Get ConfigFlags for this device - (flags, ttype) = winreg.QueryValueEx( - device_key, 'ConfigFlags') - - if ttype != winreg.REG_DWORD: # pragma: no cover - raise ValueError - - # Based on experimentation, bit 0x1 indicates that the - # device is disabled. - return not flags & 0x1 - - finally: - device_key.Close() - finally: - connection_key.Close() - except Exception: # pragma: no cover - return False - - def _compute_timeout(self, start, lifetime=None): + def _compute_timeout(self, start, lifetime=None, errors=None): lifetime = self.lifetime if lifetime is None else lifetime now = time.time() duration = now - start + if errors is None: + errors = [] if duration < 0: if duration < -1: # Time going backwards is bad. Just give up. - raise Timeout(timeout=duration) + raise LifetimeTimeout(timeout=duration, errors=errors) else: # Time went backwards, but only a little. This can # happen, e.g. under vmware with older linux kernels. # Pretend it didn't happen. now = start if duration >= lifetime: - raise Timeout(timeout=duration) + raise LifetimeTimeout(timeout=duration, errors=errors) return min(lifetime - duration, self.timeout) def _get_qnames_to_try(self, qname, search): @@ -959,19 +888,113 @@ class Resolver: if qname.is_absolute(): qnames_to_try.append(qname) else: - if len(qname) > 1 or not search: - qnames_to_try.append(qname.concatenate(dns.name.root)) - if search and self.search: - for suffix in self.search: - if self.ndots is None or len(qname.labels) >= self.ndots: - qnames_to_try.append(qname.concatenate(suffix)) - elif search: - qnames_to_try.append(qname.concatenate(self.domain)) + abs_qname = qname.concatenate(dns.name.root) + if search: + if len(self.search) > 0: + # There is a search list, so use it exclusively + search_list = self.search[:] + elif self.domain != dns.name.root and self.domain is not None: + # We have some notion of a domain that isn't the root, so + # use it as the search list. + search_list = [self.domain] + else: + search_list = [] + # Figure out the effective ndots (default is 1) + if self.ndots is None: + ndots = 1 + else: + ndots = self.ndots + for suffix in search_list: + qnames_to_try.append(qname + suffix) + if len(qname) > ndots: + # The name has at least ndots dots, so we should try an + # absolute query first. + qnames_to_try.insert(0, abs_qname) + else: + # The name has less than ndots dots, so we should search + # first, then try the absolute name. + qnames_to_try.append(abs_qname) + else: + qnames_to_try.append(abs_qname) return qnames_to_try + def use_tsig(self, keyring, keyname=None, + algorithm=dns.tsig.default_algorithm): + """Add a TSIG signature to each query. + + The parameters are passed to ``dns.message.Message.use_tsig()``; + see its documentation for details. + """ + + self.keyring = keyring + self.keyname = keyname + self.keyalgorithm = algorithm + + def use_edns(self, edns=0, ednsflags=0, + payload=dns.message.DEFAULT_EDNS_PAYLOAD): + """Configure EDNS behavior. + + *edns*, an ``int``, is the EDNS level to use. Specifying + ``None``, ``False``, or ``-1`` means "do not use EDNS", and in this case + the other parameters are ignored. Specifying ``True`` is + equivalent to specifying 0, i.e. "use EDNS0". + + *ednsflags*, an ``int``, the EDNS flag values. + + *payload*, an ``int``, is the EDNS sender's payload field, which is the + maximum size of UDP datagram the sender can handle. I.e. how big + a response to this message can be. + """ + + if edns is None or edns is False: + edns = -1 + elif edns is True: + edns = 0 + self.edns = edns + self.ednsflags = ednsflags + self.payload = payload + + def set_flags(self, flags): + """Overrides the default flags with your own. + + *flags*, an ``int``, the message flags to use. + """ + + self.flags = flags + + @property + def nameservers(self): + return self._nameservers + + @nameservers.setter + def nameservers(self, nameservers): + """ + *nameservers*, a ``list`` of nameservers. + + Raises ``ValueError`` if *nameservers* is anything other than a + ``list``. + """ + if isinstance(nameservers, list): + for nameserver in nameservers: + if not dns.inet.is_address(nameserver): + try: + if urlparse(nameserver).scheme != 'https': + raise NotImplementedError + except Exception: + raise ValueError(f'nameserver {nameserver} is not an ' + 'IP address or valid https URL') + self._nameservers = nameservers + else: + raise ValueError('nameservers must be a list' + ' (not a {})'.format(type(nameservers))) + + +class Resolver(BaseResolver): + """DNS stub resolver.""" + def resolve(self, qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN, tcp=False, source=None, raise_on_no_answer=True, source_port=0, - lifetime=None, search=None): + lifetime=None, search=None): # pylint: disable=arguments-differ """Query nameservers to find the answer to the question. The *qname*, *rdtype*, and *rdclass* parameters may be objects @@ -1004,7 +1027,7 @@ class Resolver: which causes the value of the resolver's ``use_search_by_default`` attribute to be used. - Raises ``dns.exception.Timeout`` if no answers could be found + Raises ``dns.resolver.LifetimeTimeout`` if no answers could be found in the specified lifetime. Raises ``dns.resolver.NXDOMAIN`` if the query name does not exist. @@ -1040,7 +1063,8 @@ class Resolver: (nameserver, port, tcp, backoff) = resolution.next_nameserver() if backoff: time.sleep(backoff) - timeout = self._compute_timeout(start, lifetime) + timeout = self._compute_timeout(start, lifetime, + resolution.errors) try: if dns.inet.is_address(nameserver): if tcp: @@ -1058,9 +1082,6 @@ class Resolver: source_port=source_port, raise_on_truncation=True) else: - protocol = urlparse(nameserver).scheme - if protocol != 'https': - raise NotImplementedError response = dns.query.https(request, nameserver, timeout=timeout) except Exception as ex: @@ -1109,64 +1130,31 @@ class Resolver: rdclass=dns.rdataclass.IN, *args, **kwargs) - def use_tsig(self, keyring, keyname=None, - algorithm=dns.tsig.default_algorithm): - """Add a TSIG signature to each query. + # pylint: disable=redefined-outer-name - The parameters are passed to ``dns.message.Message.use_tsig()``; - see its documentation for details. + def canonical_name(self, name): + """Determine the canonical name of *name*. + + The canonical name is the name the resolver uses for queries + after all CNAME and DNAME renamings have been applied. + + *name*, a ``dns.name.Name`` or ``str``, the query name. + + This method can raise any exception that ``resolve()`` can + raise, other than ``dns.resolver.NoAnswer`` and + ``dns.resolver.NXDOMAIN``. + + Returns a ``dns.name.Name``. """ + try: + answer = self.resolve(name, raise_on_no_answer=False) + canonical_name = answer.canonical_name + except dns.resolver.NXDOMAIN as e: + canonical_name = e.canonical_name + return canonical_name - self.keyring = keyring - self.keyname = keyname - self.keyalgorithm = algorithm + # pylint: enable=redefined-outer-name - def use_edns(self, edns, ednsflags, payload): - """Configure EDNS behavior. - - *edns*, an ``int``, is the EDNS level to use. Specifying - ``None``, ``False``, or ``-1`` means "do not use EDNS", and in this case - the other parameters are ignored. Specifying ``True`` is - equivalent to specifying 0, i.e. "use EDNS0". - - *ednsflags*, an ``int``, the EDNS flag values. - - *payload*, an ``int``, is the EDNS sender's payload field, which is the - maximum size of UDP datagram the sender can handle. I.e. how big - a response to this message can be. - """ - - if edns is None: - edns = -1 - self.edns = edns - self.ednsflags = ednsflags - self.payload = payload - - def set_flags(self, flags): - """Overrides the default flags with your own. - - *flags*, an ``int``, the message flags to use. - """ - - self.flags = flags - - @property - def nameservers(self): - return self._nameservers - - @nameservers.setter - def nameservers(self, nameservers): - """ - *nameservers*, a ``list`` of nameservers. - - Raises ``ValueError`` if *nameservers* is anything other than a - ``list``. - """ - if isinstance(nameservers, list): - self._nameservers = nameservers - else: - raise ValueError('nameservers must be a list' - ' (not a {})'.format(type(nameservers))) #: The default resolver. default_resolver = None @@ -1233,7 +1221,18 @@ def resolve_address(ipaddr, *args, **kwargs): return get_default_resolver().resolve_address(ipaddr, *args, **kwargs) -def zone_for_name(name, rdclass=dns.rdataclass.IN, tcp=False, resolver=None): +def canonical_name(name): + """Determine the canonical name of *name*. + + See ``dns.resolver.Resolver.canonical_name`` for more information on the + parameters and possible exceptions. + """ + + return get_default_resolver().canonical_name(name) + + +def zone_for_name(name, rdclass=dns.rdataclass.IN, tcp=False, resolver=None, + lifetime=None): """Find the name of the zone which contains the specified name. *name*, an absolute ``dns.name.Name`` or ``str``, the query name. @@ -1243,12 +1242,19 @@ def zone_for_name(name, rdclass=dns.rdataclass.IN, tcp=False, resolver=None): *tcp*, a ``bool``. If ``True``, use TCP to make the query. *resolver*, a ``dns.resolver.Resolver`` or ``None``, the resolver to use. - If ``None``, the default resolver is used. + If ``None``, the default, then the default resolver is used. + + *lifetime*, a ``float``, the total time to allow for the queries needed + to determine the zone. If ``None``, the default, then only the individual + query limits of the resolver apply. Raises ``dns.resolver.NoRootSOA`` if there is no SOA RR at the DNS root. (This is only likely to happen if you're using non-default root servers in your network and they are misconfigured.) + Raises ``dns.resolver.LifetimeTimeout`` if the answer could not be + found in the alotted lifetime. + Returns a ``dns.name.Name``. """ @@ -1258,14 +1264,44 @@ def zone_for_name(name, rdclass=dns.rdataclass.IN, tcp=False, resolver=None): resolver = get_default_resolver() if not name.is_absolute(): raise NotAbsolute(name) + start = time.time() + if lifetime is not None: + expiration = start + lifetime + else: + expiration = None while 1: try: - answer = resolver.resolve(name, dns.rdatatype.SOA, rdclass, tcp) + if expiration: + rlifetime = expiration - time.time() + if rlifetime <= 0: + rlifetime = 0 + else: + rlifetime = None + answer = resolver.resolve(name, dns.rdatatype.SOA, rdclass, tcp, + lifetime=rlifetime) if answer.rrset.name == name: return name # otherwise we were CNAMEd or DNAMEd and need to look higher - except (dns.resolver.NXDOMAIN, dns.resolver.NoAnswer): - pass + except (dns.resolver.NXDOMAIN, dns.resolver.NoAnswer) as e: + if isinstance(e, dns.resolver.NXDOMAIN): + response = e.responses().get(name) + else: + response = e.response() # pylint: disable=no-value-for-parameter + if response: + for rrs in response.authority: + if rrs.rdtype == dns.rdatatype.SOA and \ + rrs.rdclass == rdclass: + (nr, _, _) = rrs.name.fullcompare(name) + if nr == dns.name.NAMERELN_SUPERDOMAIN: + # We're doing a proper superdomain check as + # if the name were equal we ought to have gotten + # it in the answer section! We are ignoring the + # possibility that the authority is insane and + # is including multiple SOA RRs for different + # authorities. + return rrs.name + # we couldn't extract anything useful from the response (e.g. it's + # a type 3 NXDOMAIN) try: name = name.parent() except dns.name.NoParent: @@ -1315,7 +1351,7 @@ def _getaddrinfo(host=None, service=None, family=socket.AF_UNSPEC, socktype=0, raise socket.gaierror(socket.EAI_NONAME, 'Name or service not known') v6addrs = [] v4addrs = [] - canonical_name = None + canonical_name = None # pylint: disable=redefined-outer-name # Is host None or an address literal? If so, use the system's # getaddrinfo(). if host is None: @@ -1352,8 +1388,7 @@ def _getaddrinfo(host=None, service=None, family=socket.AF_UNSPEC, socktype=0, v4addrs.append(rdata.address) except dns.resolver.NXDOMAIN: raise socket.gaierror(socket.EAI_NONAME, 'Name or service not known') - except Exception as e: - print(e) + except Exception: # We raise EAI_AGAIN here as the failure may be temporary # (e.g. a timeout) and EAI_SYSTEM isn't defined on Windows. # [Issue #416] @@ -1482,7 +1517,7 @@ def _gethostbyaddr(ip): 'Name or service not known') sockaddr = (ip, 80) family = socket.AF_INET - (name, port) = _getnameinfo(sockaddr, socket.NI_NAMEREQD) + (name, _) = _getnameinfo(sockaddr, socket.NI_NAMEREQD) aliases = [] addresses = [] tuples = _getaddrinfo(name, 0, family, socket.SOCK_STREAM, socket.SOL_TCP, diff --git a/lib/dns/resolver.pyi b/lib/dns/resolver.pyi new file mode 100644 index 00000000..6da21f12 --- /dev/null +++ b/lib/dns/resolver.pyi @@ -0,0 +1,61 @@ +from typing import Union, Optional, List, Any, Dict +from . import exception, rdataclass, name, rdatatype + +import socket +_gethostbyname = socket.gethostbyname + +class NXDOMAIN(exception.DNSException): ... +class YXDOMAIN(exception.DNSException): ... +class NoAnswer(exception.DNSException): ... +class NoNameservers(exception.DNSException): ... +class NotAbsolute(exception.DNSException): ... +class NoRootSOA(exception.DNSException): ... +class NoMetaqueries(exception.DNSException): ... +class NoResolverConfiguration(exception.DNSException): ... +Timeout = exception.Timeout + +def resolve(qname : str, rdtype : Union[int,str] = 0, + rdclass : Union[int,str] = 0, + tcp=False, source=None, raise_on_no_answer=True, + source_port=0, lifetime : Optional[float]=None, + search : Optional[bool]=None): + ... +def query(qname : str, rdtype : Union[int,str] = 0, + rdclass : Union[int,str] = 0, + tcp=False, source=None, raise_on_no_answer=True, + source_port=0, lifetime : Optional[float]=None): + ... +def resolve_address(self, ipaddr: str, *args: Any, **kwargs: Optional[Dict]): + ... +class LRUCache: + def __init__(self, max_size=1000): + ... + def get(self, key): + ... + def put(self, key, val): + ... +class Answer: + def __init__(self, qname, rdtype, rdclass, response, + raise_on_no_answer=True): + ... +def zone_for_name(name, rdclass : int = rdataclass.IN, tcp=False, + resolver : Optional[Resolver] = None): + ... + +class Resolver: + def __init__(self, filename : Optional[str] = '/etc/resolv.conf', + configure : Optional[bool] = True): + self.nameservers : List[str] + def resolve(self, qname : str, rdtype : Union[int,str] = rdatatype.A, + rdclass : Union[int,str] = rdataclass.IN, + tcp : bool = False, source : Optional[str] = None, + raise_on_no_answer=True, source_port : int = 0, + lifetime : Optional[float]=None, + search : Optional[bool]=None): + ... + def query(self, qname : str, rdtype : Union[int,str] = rdatatype.A, + rdclass : Union[int,str] = rdataclass.IN, + tcp : bool = False, source : Optional[str] = None, + raise_on_no_answer=True, source_port : int = 0, + lifetime : Optional[float]=None): + ... diff --git a/lib/dns/reversename.pyi b/lib/dns/reversename.pyi new file mode 100644 index 00000000..97f072ea --- /dev/null +++ b/lib/dns/reversename.pyi @@ -0,0 +1,6 @@ +from . import name +def from_address(text : str) -> name.Name: + ... + +def to_address(name : name.Name) -> str: + ... diff --git a/lib/dns/rrset.py b/lib/dns/rrset.py index 68136f40..a71d4573 100644 --- a/lib/dns/rrset.py +++ b/lib/dns/rrset.py @@ -69,25 +69,45 @@ class RRset(dns.rdataset.Rdataset): return self.to_text() def __eq__(self, other): - if not isinstance(other, RRset): - return False - if self.name != other.name: + if isinstance(other, RRset): + if self.name != other.name: + return False + elif not isinstance(other, dns.rdataset.Rdataset): return False return super().__eq__(other) - def match(self, name, rdclass, rdtype, covers, deleting=None): - """Returns ``True`` if this rrset matches the specified class, type, - covers, and deletion state. - """ + def match(self, *args, **kwargs): + """Does this rrset match the specified attributes? + Behaves as :py:func:`full_match()` if the first argument is a + ``dns.name.Name``, and as :py:func:`dns.rdataset.Rdataset.match()` + otherwise. + + (This behavior fixes a design mistake where the signature of this + method became incompatible with that of its superclass. The fix + makes RRsets matchable as Rdatasets while preserving backwards + compatibility.) + """ + if isinstance(args[0], dns.name.Name): + return self.full_match(*args, **kwargs) + else: + return super().match(*args, **kwargs) + + def full_match(self, name, rdclass, rdtype, covers, + deleting=None): + """Returns ``True`` if this rrset matches the specified name, class, + type, covers, and deletion state. + """ if not super().match(rdclass, rdtype, covers): return False if self.name != name or self.deleting != deleting: return False return True + # pylint: disable=arguments-differ + def to_text(self, origin=None, relativize=True, **kw): - """Convert the RRset into DNS master file format. + """Convert the RRset into DNS zone file format. See ``dns.name.Name.choose_relativity`` for more information on how *origin* and *relativize* determine the way names @@ -106,7 +126,8 @@ class RRset(dns.rdataset.Rdataset): return super().to_text(self.name, origin, relativize, self.deleting, **kw) - def to_wire(self, file, compress=None, origin=None, **kw): + def to_wire(self, file, compress=None, origin=None, + **kw): """Convert the RRset to wire format. All keyword arguments are passed to ``dns.rdataset.to_wire()``; see @@ -118,6 +139,8 @@ class RRset(dns.rdataset.Rdataset): return super().to_wire(self.name, file, compress, origin, self.deleting, **kw) + # pylint: enable=arguments-differ + def to_rdataset(self): """Convert an RRset into an Rdataset. @@ -127,7 +150,8 @@ class RRset(dns.rdataset.Rdataset): def from_text_list(name, ttl, rdclass, rdtype, text_rdatas, - idna_codec=None): + idna_codec=None, origin=None, relativize=True, + relativize_to=None): """Create an RRset with the specified name, TTL, class, and type, and with the specified list of rdatas in text format. @@ -135,6 +159,14 @@ def from_text_list(name, ttl, rdclass, rdtype, text_rdatas, encoder/decoder to use; if ``None``, the default IDNA 2003 encoder/decoder is used. + *origin*, a ``dns.name.Name`` (or ``None``), the + origin to use for relative names. + + *relativize*, a ``bool``. If true, name will be relativized. + + *relativize_to*, a ``dns.name.Name`` (or ``None``), the origin to use + when relativizing names. If not set, the *origin* value will be used. + Returns a ``dns.rrset.RRset`` object. """ @@ -145,7 +177,8 @@ def from_text_list(name, ttl, rdclass, rdtype, text_rdatas, r = RRset(name, rdclass, rdtype) r.update_ttl(ttl) for t in text_rdatas: - rd = dns.rdata.from_text(r.rdclass, r.rdtype, t, idna_codec=idna_codec) + rd = dns.rdata.from_text(r.rdclass, r.rdtype, t, origin, relativize, + relativize_to, idna_codec) r.add(rd) return r diff --git a/lib/dns/rrset.pyi b/lib/dns/rrset.pyi new file mode 100644 index 00000000..0a81a2a0 --- /dev/null +++ b/lib/dns/rrset.pyi @@ -0,0 +1,10 @@ +from typing import List, Optional +from . import rdataset, rdatatype + +class RRset(rdataset.Rdataset): + def __init__(self, name, rdclass : int , rdtype : int, covers=rdatatype.NONE, + deleting : Optional[int] =None) -> None: + self.name = name + self.deleting = deleting +def from_text(name : str, ttl : int, rdclass : str, rdtype : str, *text_rdatas : str): + ... diff --git a/lib/dns/set.py b/lib/dns/set.py index 0982d787..1fd4d0ae 100644 --- a/lib/dns/set.py +++ b/lib/dns/set.py @@ -84,9 +84,13 @@ class Set: subclasses. """ - cls = self.__class__ + if hasattr(self, '_clone_class'): + cls = self._clone_class + else: + cls = self.__class__ obj = cls.__new__(cls) - obj.items = self.items.copy() + obj.items = odict() + obj.items.update(self.items) return obj def __copy__(self): diff --git a/lib/dns/tokenizer.py b/lib/dns/tokenizer.py index 3e5d2ba9..7ddc7a96 100644 --- a/lib/dns/tokenizer.py +++ b/lib/dns/tokenizer.py @@ -15,7 +15,7 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. -"""Tokenize DNS master file format""" +"""Tokenize DNS zone file format""" import io import sys @@ -41,19 +41,20 @@ class UngetBufferFull(dns.exception.DNSException): class Token: - """A DNS master file format token. + """A DNS zone file format token. ttype: The token type value: The token value has_escape: Does the token value contain escapes? """ - def __init__(self, ttype, value='', has_escape=False): + def __init__(self, ttype, value='', has_escape=False, comment=None): """Initialize a token instance.""" self.ttype = ttype self.value = value self.has_escape = has_escape + self.comment = comment def is_eof(self): return self.ttype == EOF @@ -104,7 +105,7 @@ class Token: c = self.value[i] i += 1 if c == '\\': - if i >= l: + if i >= l: # pragma: no cover (can't happen via get()) raise dns.exception.UnexpectedEnd c = self.value[i] i += 1 @@ -119,7 +120,10 @@ class Token: i += 1 if not (c2.isdigit() and c3.isdigit()): raise dns.exception.SyntaxError - c = chr(int(c) * 100 + int(c2) * 10 + int(c3)) + codepoint = int(c) * 100 + int(c2) * 10 + int(c3) + if codepoint > 255: + raise dns.exception.SyntaxError + c = chr(codepoint) unescaped += c return Token(self.ttype, unescaped) @@ -155,7 +159,7 @@ class Token: c = self.value[i] i += 1 if c == '\\': - if i >= l: + if i >= l: # pragma: no cover (can't happen via get()) raise dns.exception.UnexpectedEnd c = self.value[i] i += 1 @@ -170,7 +174,10 @@ class Token: i += 1 if not (c2.isdigit() and c3.isdigit()): raise dns.exception.SyntaxError - unescaped += b'%c' % (int(c) * 100 + int(c2) * 10 + int(c3)) + codepoint = int(c) * 100 + int(c2) * 10 + int(c3) + if codepoint > 255: + raise dns.exception.SyntaxError + unescaped += b'%c' % (codepoint) else: # Note that as mentioned above, if c is a Unicode # code point outside of the ASCII range, then this @@ -184,7 +191,7 @@ class Token: class Tokenizer: - """A DNS master file format tokenizer. + """A DNS zone file format tokenizer. A token object is basically a (type, value) tuple. The valid types are EOF, EOL, WHITESPACE, IDENTIFIER, QUOTED_STRING, @@ -396,13 +403,13 @@ class Tokenizer: if self.multiline: raise dns.exception.SyntaxError( 'unbalanced parentheses') - return Token(EOF) + return Token(EOF, comment=token) elif self.multiline: self.skip_whitespace() token = '' continue else: - return Token(EOL, '\n') + return Token(EOL, '\n', comment=token) else: # This code exists in case we ever want a # delimiter to be returned. It never produces @@ -422,7 +429,7 @@ class Tokenizer: token += c has_escape = True c = self._get_char() - if c == '' or c == '\n': + if c == '' or (c == '\n' and not self.quoting): raise dns.exception.UnexpectedEnd token += c if token == '' and ttype != QUOTED_STRING: @@ -529,6 +536,21 @@ class Tokenizer: '%d is not an unsigned 32-bit integer' % value) return value + def get_uint48(self, base=10): + """Read the next token and interpret it as a 48-bit unsigned + integer. + + Raises dns.exception.SyntaxError if not a 48-bit unsigned integer. + + Returns an int. + """ + + value = self.get_int(base=base) + if value < 0 or value > 281474976710655: + raise dns.exception.SyntaxError( + '%d is not an unsigned 48-bit integer' % value) + return value + def get_string(self, max_length=None): """Read the next token and interpret it as a string. @@ -559,6 +581,25 @@ class Tokenizer: raise dns.exception.SyntaxError('expecting an identifier') return token.value + def get_remaining(self, max_tokens=None): + """Return the remaining tokens on the line, until an EOL or EOF is seen. + + max_tokens: If not None, stop after this number of tokens. + + Returns a list of tokens. + """ + + tokens = [] + while True: + token = self.get() + if token.is_eol_or_eof(): + self.unget(token) + break + tokens.append(token) + if len(tokens) == max_tokens: + break + return tokens + def concatenate_remaining_identifiers(self): """Read the remaining tokens on the line, which should be identifiers. @@ -572,6 +613,7 @@ class Tokenizer: while True: token = self.get().unescape() if token.is_eol_or_eof(): + self.unget(token) break if not token.is_identifier(): raise dns.exception.SyntaxError @@ -601,7 +643,7 @@ class Tokenizer: token = self.get() return self.as_name(token, origin, relativize, relativize_to) - def get_eol(self): + def get_eol_as_token(self): """Read the next token and raise an exception if it isn't EOL or EOF. @@ -613,7 +655,10 @@ class Tokenizer: raise dns.exception.SyntaxError( 'expected EOL or EOF, got %d "%s"' % (token.ttype, token.value)) - return token.value + return token + + def get_eol(self): + return self.get_eol_as_token().value def get_ttl(self): """Read the next token and interpret it as a DNS TTL. diff --git a/lib/dns/transaction.py b/lib/dns/transaction.py new file mode 100644 index 00000000..ae7417ed --- /dev/null +++ b/lib/dns/transaction.py @@ -0,0 +1,587 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +import collections + +import dns.exception +import dns.name +import dns.rdataclass +import dns.rdataset +import dns.rdatatype +import dns.rrset +import dns.serial +import dns.ttl + + +class TransactionManager: + def reader(self): + """Begin a read-only transaction.""" + raise NotImplementedError # pragma: no cover + + def writer(self, replacement=False): + """Begin a writable transaction. + + *replacement*, a ``bool``. If `True`, the content of the + transaction completely replaces any prior content. If False, + the default, then the content of the transaction updates the + existing content. + """ + raise NotImplementedError # pragma: no cover + + def origin_information(self): + """Returns a tuple + + (absolute_origin, relativize, effective_origin) + + giving the absolute name of the default origin for any + relative domain names, the "effective origin", and whether + names should be relativized. The "effective origin" is the + absolute origin if relativize is False, and the empty name if + relativize is true. (The effective origin is provided even + though it can be computed from the absolute_origin and + relativize setting because it avoids a lot of code + duplication.) + + If the returned names are `None`, then no origin information is + available. + + This information is used by code working with transactions to + allow it to coordinate relativization. The transaction code + itself takes what it gets (i.e. does not change name + relativity). + + """ + raise NotImplementedError # pragma: no cover + + def get_class(self): + """The class of the transaction manager. + """ + raise NotImplementedError # pragma: no cover + + def from_wire_origin(self): + """Origin to use in from_wire() calls. + """ + (absolute_origin, relativize, _) = self.origin_information() + if relativize: + return absolute_origin + else: + return None + + +class DeleteNotExact(dns.exception.DNSException): + """Existing data did not match data specified by an exact delete.""" + + +class ReadOnly(dns.exception.DNSException): + """Tried to write to a read-only transaction.""" + + +class AlreadyEnded(dns.exception.DNSException): + """Tried to use an already-ended transaction.""" + + +def _ensure_immutable_rdataset(rdataset): + if rdataset is None or isinstance(rdataset, dns.rdataset.ImmutableRdataset): + return rdataset + return dns.rdataset.ImmutableRdataset(rdataset) + +def _ensure_immutable_node(node): + if node is None or node.is_immutable(): + return node + return dns.node.ImmutableNode(node) + + +class Transaction: + + def __init__(self, manager, replacement=False, read_only=False): + self.manager = manager + self.replacement = replacement + self.read_only = read_only + self._ended = False + self._check_put_rdataset = [] + self._check_delete_rdataset = [] + self._check_delete_name = [] + + # + # This is the high level API + # + + def get(self, name, rdtype, covers=dns.rdatatype.NONE): + """Return the rdataset associated with *name*, *rdtype*, and *covers*, + or `None` if not found. + + Note that the returned rdataset is immutable. + """ + self._check_ended() + if isinstance(name, str): + name = dns.name.from_text(name, None) + rdtype = dns.rdatatype.RdataType.make(rdtype) + rdataset = self._get_rdataset(name, rdtype, covers) + return _ensure_immutable_rdataset(rdataset) + + def get_node(self, name): + """Return the node at *name*, if any. + + Returns an immutable node or ``None``. + """ + return _ensure_immutable_node(self._get_node(name)) + + def _check_read_only(self): + if self.read_only: + raise ReadOnly + + def add(self, *args): + """Add records. + + The arguments may be: + + - rrset + + - name, rdataset... + + - name, ttl, rdata... + """ + self._check_ended() + self._check_read_only() + return self._add(False, args) + + def replace(self, *args): + """Replace the existing rdataset at the name with the specified + rdataset, or add the specified rdataset if there was no existing + rdataset. + + The arguments may be: + + - rrset + + - name, rdataset... + + - name, ttl, rdata... + + Note that if you want to replace the entire node, you should do + a delete of the name followed by one or more calls to add() or + replace(). + """ + self._check_ended() + self._check_read_only() + return self._add(True, args) + + def delete(self, *args): + """Delete records. + + It is not an error if some of the records are not in the existing + set. + + The arguments may be: + + - rrset + + - name + + - name, rdataclass, rdatatype, [covers] + + - name, rdataset... + + - name, rdata... + """ + self._check_ended() + self._check_read_only() + return self._delete(False, args) + + def delete_exact(self, *args): + """Delete records. + + The arguments may be: + + - rrset + + - name + + - name, rdataclass, rdatatype, [covers] + + - name, rdataset... + + - name, rdata... + + Raises dns.transaction.DeleteNotExact if some of the records + are not in the existing set. + + """ + self._check_ended() + self._check_read_only() + return self._delete(True, args) + + def name_exists(self, name): + """Does the specified name exist?""" + self._check_ended() + if isinstance(name, str): + name = dns.name.from_text(name, None) + return self._name_exists(name) + + def update_serial(self, value=1, relative=True, name=dns.name.empty): + """Update the serial number. + + *value*, an `int`, is an increment if *relative* is `True`, or the + actual value to set if *relative* is `False`. + + Raises `KeyError` if there is no SOA rdataset at *name*. + + Raises `ValueError` if *value* is negative or if the increment is + so large that it would cause the new serial to be less than the + prior value. + """ + self._check_ended() + if value < 0: + raise ValueError('negative update_serial() value') + if isinstance(name, str): + name = dns.name.from_text(name, None) + rdataset = self._get_rdataset(name, dns.rdatatype.SOA, + dns.rdatatype.NONE) + if rdataset is None or len(rdataset) == 0: + raise KeyError + if relative: + serial = dns.serial.Serial(rdataset[0].serial) + value + else: + serial = dns.serial.Serial(value) + serial = serial.value # convert back to int + if serial == 0: + serial = 1 + rdata = rdataset[0].replace(serial=serial) + new_rdataset = dns.rdataset.from_rdata(rdataset.ttl, rdata) + self.replace(name, new_rdataset) + + def __iter__(self): + self._check_ended() + return self._iterate_rdatasets() + + def changed(self): + """Has this transaction changed anything? + + For read-only transactions, the result is always `False`. + + For writable transactions, the result is `True` if at some time + during the life of the transaction, the content was changed. + """ + self._check_ended() + return self._changed() + + def commit(self): + """Commit the transaction. + + Normally transactions are used as context managers and commit + or rollback automatically, but it may be done explicitly if needed. + A ``dns.transaction.Ended`` exception will be raised if you try + to use a transaction after it has been committed or rolled back. + + Raises an exception if the commit fails (in which case the transaction + is also rolled back. + """ + self._end(True) + + def rollback(self): + """Rollback the transaction. + + Normally transactions are used as context managers and commit + or rollback automatically, but it may be done explicitly if needed. + A ``dns.transaction.AlreadyEnded`` exception will be raised if you try + to use a transaction after it has been committed or rolled back. + + Rollback cannot otherwise fail. + """ + self._end(False) + + def check_put_rdataset(self, check): + """Call *check* before putting (storing) an rdataset. + + The function is called with the transaction, the name, and the rdataset. + + The check function may safely make non-mutating transaction method + calls, but behavior is undefined if mutating transaction methods are + called. The check function should raise an exception if it objects to + the put, and otherwise should return ``None``. + """ + self._check_put_rdataset.append(check) + + def check_delete_rdataset(self, check): + """Call *check* before deleting an rdataset. + + The function is called with the transaction, the name, the rdatatype, + and the covered rdatatype. + + The check function may safely make non-mutating transaction method + calls, but behavior is undefined if mutating transaction methods are + called. The check function should raise an exception if it objects to + the put, and otherwise should return ``None``. + """ + self._check_delete_rdataset.append(check) + + def check_delete_name(self, check): + """Call *check* before putting (storing) an rdataset. + + The function is called with the transaction and the name. + + The check function may safely make non-mutating transaction method + calls, but behavior is undefined if mutating transaction methods are + called. The check function should raise an exception if it objects to + the put, and otherwise should return ``None``. + """ + self._check_delete_name.append(check) + + # + # Helper methods + # + + def _raise_if_not_empty(self, method, args): + if len(args) != 0: + raise TypeError(f'extra parameters to {method}') + + def _rdataset_from_args(self, method, deleting, args): + try: + arg = args.popleft() + if isinstance(arg, dns.rrset.RRset): + rdataset = arg.to_rdataset() + elif isinstance(arg, dns.rdataset.Rdataset): + rdataset = arg + else: + if deleting: + ttl = 0 + else: + if isinstance(arg, int): + ttl = arg + if ttl > dns.ttl.MAX_TTL: + raise ValueError(f'{method}: TTL value too big') + else: + raise TypeError(f'{method}: expected a TTL') + arg = args.popleft() + if isinstance(arg, dns.rdata.Rdata): + rdataset = dns.rdataset.from_rdata(ttl, arg) + else: + raise TypeError(f'{method}: expected an Rdata') + return rdataset + except IndexError: + if deleting: + return None + else: + # reraise + raise TypeError(f'{method}: expected more arguments') + + def _add(self, replace, args): + try: + args = collections.deque(args) + if replace: + method = 'replace()' + else: + method = 'add()' + arg = args.popleft() + if isinstance(arg, str): + arg = dns.name.from_text(arg, None) + if isinstance(arg, dns.name.Name): + name = arg + rdataset = self._rdataset_from_args(method, False, args) + elif isinstance(arg, dns.rrset.RRset): + rrset = arg + name = rrset.name + # rrsets are also rdatasets, but they don't print the + # same and can't be stored in nodes, so convert. + rdataset = rrset.to_rdataset() + else: + raise TypeError(f'{method} requires a name or RRset ' + + 'as the first argument') + if rdataset.rdclass != self.manager.get_class(): + raise ValueError(f'{method} has objects of wrong RdataClass') + if rdataset.rdtype == dns.rdatatype.SOA: + (_, _, origin) = self.manager.origin_information() + if name != origin: + raise ValueError(f'{method} has non-origin SOA') + self._raise_if_not_empty(method, args) + if not replace: + existing = self._get_rdataset(name, rdataset.rdtype, + rdataset.covers) + if existing is not None: + if isinstance(existing, dns.rdataset.ImmutableRdataset): + trds = dns.rdataset.Rdataset(existing.rdclass, + existing.rdtype, + existing.covers) + trds.update(existing) + existing = trds + rdataset = existing.union(rdataset) + self._checked_put_rdataset(name, rdataset) + except IndexError: + raise TypeError(f'not enough parameters to {method}') + + def _delete(self, exact, args): + try: + args = collections.deque(args) + if exact: + method = 'delete_exact()' + else: + method = 'delete()' + arg = args.popleft() + if isinstance(arg, str): + arg = dns.name.from_text(arg, None) + if isinstance(arg, dns.name.Name): + name = arg + if len(args) > 0 and (isinstance(args[0], int) or + isinstance(args[0], str)): + # deleting by type and (optionally) covers + rdtype = dns.rdatatype.RdataType.make(args.popleft()) + if len(args) > 0: + covers = dns.rdatatype.RdataType.make(args.popleft()) + else: + covers = dns.rdatatype.NONE + self._raise_if_not_empty(method, args) + existing = self._get_rdataset(name, rdtype, covers) + if existing is None: + if exact: + raise DeleteNotExact(f'{method}: missing rdataset') + else: + self._delete_rdataset(name, rdtype, covers) + return + else: + rdataset = self._rdataset_from_args(method, True, args) + elif isinstance(arg, dns.rrset.RRset): + rdataset = arg # rrsets are also rdatasets + name = rdataset.name + else: + raise TypeError(f'{method} requires a name or RRset ' + + 'as the first argument') + self._raise_if_not_empty(method, args) + if rdataset: + if rdataset.rdclass != self.manager.get_class(): + raise ValueError(f'{method} has objects of wrong ' + 'RdataClass') + existing = self._get_rdataset(name, rdataset.rdtype, + rdataset.covers) + if existing is not None: + if exact: + intersection = existing.intersection(rdataset) + if intersection != rdataset: + raise DeleteNotExact(f'{method}: missing rdatas') + rdataset = existing.difference(rdataset) + if len(rdataset) == 0: + self._checked_delete_rdataset(name, rdataset.rdtype, + rdataset.covers) + else: + self._checked_put_rdataset(name, rdataset) + elif exact: + raise DeleteNotExact(f'{method}: missing rdataset') + else: + if exact and not self._name_exists(name): + raise DeleteNotExact(f'{method}: name not known') + self._checked_delete_name(name) + except IndexError: + raise TypeError(f'not enough parameters to {method}') + + def _check_ended(self): + if self._ended: + raise AlreadyEnded + + def _end(self, commit): + self._check_ended() + if self._ended: + raise AlreadyEnded + try: + self._end_transaction(commit) + finally: + self._ended = True + + def _checked_put_rdataset(self, name, rdataset): + for check in self._check_put_rdataset: + check(self, name, rdataset) + self._put_rdataset(name, rdataset) + + def _checked_delete_rdataset(self, name, rdtype, covers): + for check in self._check_delete_rdataset: + check(self, name, rdtype, covers) + self._delete_rdataset(name, rdtype, covers) + + def _checked_delete_name(self, name): + for check in self._check_delete_name: + check(self, name) + self._delete_name(name) + + # + # Transactions are context managers. + # + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if not self._ended: + if exc_type is None: + self.commit() + else: + self.rollback() + return False + + # + # This is the low level API, which must be implemented by subclasses + # of Transaction. + # + + def _get_rdataset(self, name, rdtype, covers): + """Return the rdataset associated with *name*, *rdtype*, and *covers*, + or `None` if not found. + """ + raise NotImplementedError # pragma: no cover + + def _put_rdataset(self, name, rdataset): + """Store the rdataset.""" + raise NotImplementedError # pragma: no cover + + def _delete_name(self, name): + """Delete all data associated with *name*. + + It is not an error if the name does not exist. + """ + raise NotImplementedError # pragma: no cover + + def _delete_rdataset(self, name, rdtype, covers): + """Delete all data associated with *name*, *rdtype*, and *covers*. + + It is not an error if the rdataset does not exist. + """ + raise NotImplementedError # pragma: no cover + + def _name_exists(self, name): + """Does name exist? + + Returns a bool. + """ + raise NotImplementedError # pragma: no cover + + def _changed(self): + """Has this transaction changed anything?""" + raise NotImplementedError # pragma: no cover + + def _end_transaction(self, commit): + """End the transaction. + + *commit*, a bool. If ``True``, commit the transaction, otherwise + roll it back. + + If committing adn the commit fails, then roll back and raise an + exception. + """ + raise NotImplementedError # pragma: no cover + + def _set_origin(self, origin): + """Set the origin. + + This method is called when reading a possibly relativized + source, and an origin setting operation occurs (e.g. $ORIGIN + in a zone file). + """ + raise NotImplementedError # pragma: no cover + + def _iterate_rdatasets(self): + """Return an iterator that yields (name, rdataset) tuples. + """ + raise NotImplementedError # pragma: no cover + + def _get_node(self, name): + """Return the node at *name*, if any. + + Returns a node or ``None``. + """ + raise NotImplementedError # pragma: no cover diff --git a/lib/dns/tsig.py b/lib/dns/tsig.py index 8f34fe67..50b2d47e 100644 --- a/lib/dns/tsig.py +++ b/lib/dns/tsig.py @@ -71,31 +71,142 @@ class PeerBadTruncation(PeerError): """The peer didn't like amount of truncation in the TSIG we sent""" + # TSIG Algorithms HMAC_MD5 = dns.name.from_text("HMAC-MD5.SIG-ALG.REG.INT") HMAC_SHA1 = dns.name.from_text("hmac-sha1") HMAC_SHA224 = dns.name.from_text("hmac-sha224") HMAC_SHA256 = dns.name.from_text("hmac-sha256") +HMAC_SHA256_128 = dns.name.from_text("hmac-sha256-128") HMAC_SHA384 = dns.name.from_text("hmac-sha384") +HMAC_SHA384_192 = dns.name.from_text("hmac-sha384-192") HMAC_SHA512 = dns.name.from_text("hmac-sha512") - -_hashes = { - HMAC_SHA224: hashlib.sha224, - HMAC_SHA256: hashlib.sha256, - HMAC_SHA384: hashlib.sha384, - HMAC_SHA512: hashlib.sha512, - HMAC_SHA1: hashlib.sha1, - HMAC_MD5: hashlib.md5, -} +HMAC_SHA512_256 = dns.name.from_text("hmac-sha512-256") +GSS_TSIG = dns.name.from_text("gss-tsig") default_algorithm = HMAC_SHA256 +class GSSTSig: + """ + GSS-TSIG TSIG implementation. This uses the GSS-API context established + in the TKEY message handshake to sign messages using GSS-API message + integrity codes, per the RFC. + + In order to avoid a direct GSSAPI dependency, the keyring holds a ref + to the GSSAPI object required, rather than the key itself. + """ + def __init__(self, gssapi_context): + self.gssapi_context = gssapi_context + self.data = b'' + self.name = 'gss-tsig' + + def update(self, data): + self.data += data + + def sign(self): + # defer to the GSSAPI function to sign + return self.gssapi_context.get_signature(self.data) + + def verify(self, expected): + try: + # defer to the GSSAPI function to verify + return self.gssapi_context.verify_signature(self.data, expected) + except Exception: + # note the usage of a bare exception + raise BadSignature + + +class GSSTSigAdapter: + def __init__(self, keyring): + self.keyring = keyring + + def __call__(self, message, keyname): + if keyname in self.keyring: + key = self.keyring[keyname] + if isinstance(key, Key) and key.algorithm == GSS_TSIG: + if message: + GSSTSigAdapter.parse_tkey_and_step(key, message, keyname) + return key + else: + return None + + @classmethod + def parse_tkey_and_step(cls, key, message, keyname): + # if the message is a TKEY type, absorb the key material + # into the context using step(); this is used to allow the + # client to complete the GSSAPI negotiation before attempting + # to verify the signed response to a TKEY message exchange + try: + rrset = message.find_rrset(message.answer, keyname, + dns.rdataclass.ANY, + dns.rdatatype.TKEY) + if rrset: + token = rrset[0].key + gssapi_context = key.secret + return gssapi_context.step(token) + except KeyError: + pass + + +class HMACTSig: + """ + HMAC TSIG implementation. This uses the HMAC python module to handle the + sign/verify operations. + """ + + _hashes = { + HMAC_SHA1: hashlib.sha1, + HMAC_SHA224: hashlib.sha224, + HMAC_SHA256: hashlib.sha256, + HMAC_SHA256_128: (hashlib.sha256, 128), + HMAC_SHA384: hashlib.sha384, + HMAC_SHA384_192: (hashlib.sha384, 192), + HMAC_SHA512: hashlib.sha512, + HMAC_SHA512_256: (hashlib.sha512, 256), + HMAC_MD5: hashlib.md5, + } + + def __init__(self, key, algorithm): + try: + hashinfo = self._hashes[algorithm] + except KeyError: + raise NotImplementedError(f"TSIG algorithm {algorithm} " + + "is not supported") + + # create the HMAC context + if isinstance(hashinfo, tuple): + self.hmac_context = hmac.new(key, digestmod=hashinfo[0]) + self.size = hashinfo[1] + else: + self.hmac_context = hmac.new(key, digestmod=hashinfo) + self.size = None + self.name = self.hmac_context.name + if self.size: + self.name += f'-{self.size}' + + def update(self, data): + return self.hmac_context.update(data) + + def sign(self): + # defer to the HMAC digest() function for that digestmod + digest = self.hmac_context.digest() + if self.size: + digest = digest[: (self.size // 8)] + return digest + + def verify(self, expected): + # re-digest and compare the results + mac = self.sign() + if not hmac.compare_digest(mac, expected): + raise BadSignature + + def _digest(wire, key, rdata, time=None, request_mac=None, ctx=None, multi=None): """Return a context containing the TSIG rdata for the input parameters - @rtype: hmac.HMAC object + @rtype: dns.tsig.HMACTSig or dns.tsig.GSSTSig object @raises ValueError: I{other_data} is too long @raises NotImplementedError: I{algorithm} is not supported """ @@ -131,7 +242,7 @@ def _digest(wire, key, rdata, time=None, request_mac=None, ctx=None, def _maybe_start_digest(key, mac, multi): """If this is the first message in a multi-message sequence, start a new context. - @rtype: hmac.HMAC object + @rtype: dns.tsig.HMACTSig or dns.tsig.GSSTSig object """ if multi: ctx = get_context(key) @@ -146,17 +257,14 @@ def sign(wire, key, rdata, time=None, request_mac=None, ctx=None, multi=False): """Return a (tsig_rdata, mac, ctx) tuple containing the HMAC TSIG rdata for the input parameters, the HMAC MAC calculated by applying the TSIG signature algorithm, and the TSIG digest context. - @rtype: (string, hmac.HMAC object) + @rtype: (string, dns.tsig.HMACTSig or dns.tsig.GSSTSig object) @raises ValueError: I{other_data} is too long @raises NotImplementedError: I{algorithm} is not supported """ ctx = _digest(wire, key, rdata, time, request_mac, ctx, multi) - mac = ctx.digest() - tsig = dns.rdtypes.ANY.TSIG.TSIG(dns.rdataclass.ANY, dns.rdatatype.TSIG, - key.algorithm, time, rdata.fudge, mac, - rdata.original_id, rdata.error, - rdata.other) + mac = ctx.sign() + tsig = rdata.replace(time_signed=time, mac=mac) return (tsig, _maybe_start_digest(key, mac, multi)) @@ -169,7 +277,7 @@ def validate(wire, key, owner, rdata, now, request_mac, tsig_start, ctx=None, @raises BadTime: There is too much time skew between the client and the server. @raises BadSignature: The TSIG signature did not validate - @rtype: hmac.HMAC object""" + @rtype: dns.tsig.HMACTSig or dns.tsig.GSSTSig object""" (adcount,) = struct.unpack("!H", wire[10:12]) if adcount == 0: @@ -194,25 +302,21 @@ def validate(wire, key, owner, rdata, now, request_mac, tsig_start, ctx=None, if key.algorithm != rdata.algorithm: raise BadAlgorithm ctx = _digest(new_wire, key, rdata, None, request_mac, ctx, multi) - mac = ctx.digest() - if not hmac.compare_digest(mac, rdata.mac): - raise BadSignature - return _maybe_start_digest(key, mac, multi) + ctx.verify(rdata.mac) + return _maybe_start_digest(key, rdata.mac, multi) def get_context(key): - """Returns an HMAC context foe the specified key. + """Returns an HMAC context for the specified key. @rtype: HMAC context @raises NotImplementedError: I{algorithm} is not supported """ - try: - digestmod = _hashes[key.algorithm] - except KeyError: - raise NotImplementedError(f"TSIG algorithm {key.algorithm} " + - "is not supported") - return hmac.new(key.secret, digestmod=digestmod) + if key.algorithm == GSS_TSIG: + return GSSTSig(key.secret) + else: + return HMACTSig(key.secret, key.algorithm) class Key: @@ -232,3 +336,11 @@ class Key: self.name == other.name and self.secret == other.secret and self.algorithm == other.algorithm) + + def __repr__(self): + r = f" Dict[name.Name,bytes]: + ... +def to_text(keyring : Dict[name.Name,bytes]) -> Dict[str, str]: + ... diff --git a/lib/dns/ttl.py b/lib/dns/ttl.py index 55ae5e16..df92b2b6 100644 --- a/lib/dns/ttl.py +++ b/lib/dns/ttl.py @@ -19,6 +19,13 @@ import dns.exception +# Technically TTLs are supposed to be between 0 and 2**31 - 1, with values +# greater than that interpreted as 0, but we do not impose this policy here +# as values > 2**31 - 1 occur in real world data. +# +# We leave it to applications to impose tighter bounds if desired. +MAX_TTL = 2**32 - 1 + class BadTTL(dns.exception.SyntaxError): """DNS TTL value is not well-formed.""" @@ -38,16 +45,20 @@ def from_text(text): if text.isdigit(): total = int(text) + elif len(text) == 0: + raise BadTTL else: - if not text[0].isdigit(): - raise BadTTL total = 0 current = 0 + need_digit = True for c in text: if c.isdigit(): current *= 10 current += int(c) + need_digit = False else: + if need_digit: + raise BadTTL c = c.lower() if c == 'w': total += current * 604800 @@ -62,8 +73,18 @@ def from_text(text): else: raise BadTTL("unknown unit '%s'" % c) current = 0 + need_digit = True if not current == 0: raise BadTTL("trailing integer") - if total < 0 or total > 2147483647: - raise BadTTL("TTL should be between 0 and 2^31 - 1 (inclusive)") + if total < 0 or total > MAX_TTL: + raise BadTTL("TTL should be between 0 and 2**32 - 1 (inclusive)") return total + + +def make(value): + if isinstance(value, int): + return value + elif isinstance(value, str): + return dns.ttl.from_text(value) + else: + raise ValueError('cannot convert value to TTL') diff --git a/lib/dns/update.py b/lib/dns/update.py index 8e796504..a541af22 100644 --- a/lib/dns/update.py +++ b/lib/dns/update.py @@ -38,8 +38,6 @@ class UpdateSection(dns.enum.IntEnum): def _maximum(cls): return 3 -globals().update(UpdateSection.__members__) - class UpdateMessage(dns.message.Message): @@ -310,3 +308,12 @@ class UpdateMessage(dns.message.Message): # backwards compatibility Update = UpdateMessage + +### BEGIN generated UpdateSection constants + +ZONE = UpdateSection.ZONE +PREREQ = UpdateSection.PREREQ +UPDATE = UpdateSection.UPDATE +ADDITIONAL = UpdateSection.ADDITIONAL + +### END generated UpdateSection constants diff --git a/lib/dns/update.pyi b/lib/dns/update.pyi new file mode 100644 index 00000000..eeac0591 --- /dev/null +++ b/lib/dns/update.pyi @@ -0,0 +1,21 @@ +from typing import Optional,Dict,Union,Any + +from . import message, tsig, rdataclass, name + +class Update(message.Message): + def __init__(self, zone : Union[name.Name, str], rdclass : Union[int,str] = rdataclass.IN, keyring : Optional[Dict[name.Name,bytes]] = None, + keyname : Optional[name.Name] = None, keyalgorithm : Optional[name.Name] = tsig.default_algorithm) -> None: + self.id : int + def add(self, name : Union[str,name.Name], *args : Any): + ... + def delete(self, name, *args : Any): + ... + def replace(self, name : Union[str,name.Name], *args : Any): + ... + def present(self, name : Union[str,name.Name], *args : Any): + ... + def absent(self, name : Union[str,name.Name], rdtype=None): + """Require that an owner name (and optionally an rdata type) does + not exist as a prerequisite to the execution of the update.""" + def to_wire(self, origin : Optional[name.Name] = None, max_size=65535, **kw) -> bytes: + ... diff --git a/lib/dns/version.py b/lib/dns/version.py index 0b7c1d13..745a5c7f 100644 --- a/lib/dns/version.py +++ b/lib/dns/version.py @@ -20,7 +20,7 @@ #: MAJOR MAJOR = 2 #: MINOR -MINOR = 0 +MINOR = 2 #: MICRO MICRO = 0 #: RELEASELEVEL diff --git a/lib/dns/versioned.py b/lib/dns/versioned.py new file mode 100644 index 00000000..42f2c814 --- /dev/null +++ b/lib/dns/versioned.py @@ -0,0 +1,274 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +"""DNS Versioned Zones.""" + +import collections +try: + import threading as _threading +except ImportError: # pragma: no cover + import dummy_threading as _threading # type: ignore + +import dns.exception +import dns.immutable +import dns.name +import dns.rdataclass +import dns.rdatatype +import dns.rdtypes.ANY.SOA +import dns.zone + + +class UseTransaction(dns.exception.DNSException): + """To alter a versioned zone, use a transaction.""" + + +# Backwards compatibility +Node = dns.zone.VersionedNode +ImmutableNode = dns.zone.ImmutableVersionedNode +Version = dns.zone.Version +WritableVersion = dns.zone.WritableVersion +ImmutableVersion = dns.zone.ImmutableVersion +Transaction = dns.zone.Transaction + + +class Zone(dns.zone.Zone): + + __slots__ = ['_versions', '_versions_lock', '_write_txn', + '_write_waiters', '_write_event', '_pruning_policy', + '_readers'] + + node_factory = Node + + def __init__(self, origin, rdclass=dns.rdataclass.IN, relativize=True, + pruning_policy=None): + """Initialize a versioned zone object. + + *origin* is the origin of the zone. It may be a ``dns.name.Name``, + a ``str``, or ``None``. If ``None``, then the zone's origin will + be set by the first ``$ORIGIN`` line in a zone file. + + *rdclass*, an ``int``, the zone's rdata class; the default is class IN. + + *relativize*, a ``bool``, determine's whether domain names are + relativized to the zone's origin. The default is ``True``. + + *pruning policy*, a function taking a `Version` and returning + a `bool`, or `None`. Should the version be pruned? If `None`, + the default policy, which retains one version is used. + """ + super().__init__(origin, rdclass, relativize) + self._versions = collections.deque() + self._version_lock = _threading.Lock() + if pruning_policy is None: + self._pruning_policy = self._default_pruning_policy + else: + self._pruning_policy = pruning_policy + self._write_txn = None + self._write_event = None + self._write_waiters = collections.deque() + self._readers = set() + self._commit_version_unlocked(None, + WritableVersion(self, replacement=True), + origin) + + def reader(self, id=None, serial=None): # pylint: disable=arguments-differ + if id is not None and serial is not None: + raise ValueError('cannot specify both id and serial') + with self._version_lock: + if id is not None: + version = None + for v in reversed(self._versions): + if v.id == id: + version = v + break + if version is None: + raise KeyError('version not found') + elif serial is not None: + if self.relativize: + oname = dns.name.empty + else: + oname = self.origin + version = None + for v in reversed(self._versions): + n = v.nodes.get(oname) + if n: + rds = n.get_rdataset(self.rdclass, dns.rdatatype.SOA) + if rds and rds[0].serial == serial: + version = v + break + if version is None: + raise KeyError('serial not found') + else: + version = self._versions[-1] + txn = Transaction(self, False, version) + self._readers.add(txn) + return txn + + def writer(self, replacement=False): + event = None + while True: + with self._version_lock: + # Checking event == self._write_event ensures that either + # no one was waiting before we got lucky and found no write + # txn, or we were the one who was waiting and got woken up. + # This prevents "taking cuts" when creating a write txn. + if self._write_txn is None and event == self._write_event: + # Creating the transaction defers version setup + # (i.e. copying the nodes dictionary) until we + # give up the lock, so that we hold the lock as + # short a time as possible. This is why we call + # _setup_version() below. + self._write_txn = Transaction(self, replacement, + make_immutable=True) + # give up our exclusive right to make a Transaction + self._write_event = None + break + # Someone else is writing already, so we will have to + # wait, but we want to do the actual wait outside the + # lock. + event = _threading.Event() + self._write_waiters.append(event) + # wait (note we gave up the lock!) + # + # We only wake one sleeper at a time, so it's important + # that no event waiter can exit this method (e.g. via + # cancelation) without returning a transaction or waking + # someone else up. + # + # This is not a problem with Threading module threads as + # they cannot be canceled, but could be an issue with trio + # or curio tasks when we do the async version of writer(). + # I.e. we'd need to do something like: + # + # try: + # event.wait() + # except trio.Cancelled: + # with self._version_lock: + # self._maybe_wakeup_one_waiter_unlocked() + # raise + # + event.wait() + # Do the deferred version setup. + self._write_txn._setup_version() + return self._write_txn + + def _maybe_wakeup_one_waiter_unlocked(self): + if len(self._write_waiters) > 0: + self._write_event = self._write_waiters.popleft() + self._write_event.set() + + # pylint: disable=unused-argument + def _default_pruning_policy(self, zone, version): + return True + # pylint: enable=unused-argument + + def _prune_versions_unlocked(self): + assert len(self._versions) > 0 + # Don't ever prune a version greater than or equal to one that + # a reader has open. This pins versions in memory while the + # reader is open, and importantly lets the reader open a txn on + # a successor version (e.g. if generating an IXFR). + # + # Note our definition of least_kept also ensures we do not try to + # delete the greatest version. + if len(self._readers) > 0: + least_kept = min(txn.version.id for txn in self._readers) + else: + least_kept = self._versions[-1].id + while self._versions[0].id < least_kept and \ + self._pruning_policy(self, self._versions[0]): + self._versions.popleft() + + def set_max_versions(self, max_versions): + """Set a pruning policy that retains up to the specified number + of versions + """ + if max_versions is not None and max_versions < 1: + raise ValueError('max versions must be at least 1') + if max_versions is None: + def policy(*_): + return False + else: + def policy(zone, _): + return len(zone._versions) > max_versions + self.set_pruning_policy(policy) + + def set_pruning_policy(self, policy): + """Set the pruning policy for the zone. + + The *policy* function takes a `Version` and returns `True` if + the version should be pruned, and `False` otherwise. `None` + may also be specified for policy, in which case the default policy + is used. + + Pruning checking proceeds from the least version and the first + time the function returns `False`, the checking stops. I.e. the + retained versions are always a consecutive sequence. + """ + if policy is None: + policy = self._default_pruning_policy + with self._version_lock: + self._pruning_policy = policy + self._prune_versions_unlocked() + + def _end_read(self, txn): + with self._version_lock: + self._readers.remove(txn) + self._prune_versions_unlocked() + + def _end_write_unlocked(self, txn): + assert self._write_txn == txn + self._write_txn = None + self._maybe_wakeup_one_waiter_unlocked() + + def _end_write(self, txn): + with self._version_lock: + self._end_write_unlocked(txn) + + def _commit_version_unlocked(self, txn, version, origin): + self._versions.append(version) + self._prune_versions_unlocked() + self.nodes = version.nodes + if self.origin is None: + self.origin = origin + # txn can be None in __init__ when we make the empty version. + if txn is not None: + self._end_write_unlocked(txn) + + def _commit_version(self, txn, version, origin): + with self._version_lock: + self._commit_version_unlocked(txn, version, origin) + + def _get_next_version_id(self): + if len(self._versions) > 0: + id = self._versions[-1].id + 1 + else: + id = 1 + return id + + def find_node(self, name, create=False): + if create: + raise UseTransaction + return super().find_node(name) + + def delete_node(self, name): + raise UseTransaction + + def find_rdataset(self, name, rdtype, covers=dns.rdatatype.NONE, + create=False): + if create: + raise UseTransaction + rdataset = super().find_rdataset(name, rdtype, covers) + return dns.rdataset.ImmutableRdataset(rdataset) + + def get_rdataset(self, name, rdtype, covers=dns.rdatatype.NONE, + create=False): + if create: + raise UseTransaction + rdataset = super().get_rdataset(name, rdtype, covers) + return dns.rdataset.ImmutableRdataset(rdataset) + + def delete_rdataset(self, name, rdtype, covers=dns.rdatatype.NONE): + raise UseTransaction + + def replace_rdataset(self, name, replacement): + raise UseTransaction diff --git a/lib/dns/win32util.py b/lib/dns/win32util.py new file mode 100644 index 00000000..745317a3 --- /dev/null +++ b/lib/dns/win32util.py @@ -0,0 +1,235 @@ +import sys + +if sys.platform == 'win32': + + import dns.name + + _prefer_wmi = True + + import winreg + + try: + try: + import threading as _threading + except ImportError: # pragma: no cover + import dummy_threading as _threading # type: ignore + import pythoncom + import wmi + _have_wmi = True + except Exception: + _have_wmi = False + + def _config_domain(domain): + # Sometimes DHCP servers add a '.' prefix to the default domain, and + # Windows just stores such values in the registry (see #687). + # Check for this and fix it. + if domain.startswith('.'): + domain = domain[1:] + return dns.name.from_text(domain) + + class DnsInfo: + def __init__(self): + self.domain = None + self.nameservers = [] + self.search = [] + + if _have_wmi: + class _WMIGetter(_threading.Thread): + def __init__(self): + super().__init__() + self.info = DnsInfo() + + def run(self): + pythoncom.CoInitialize() + try: + system = wmi.WMI() + for interface in system.Win32_NetworkAdapterConfiguration(): + if interface.IPEnabled: + self.info.domain = _config_domain(interface.DNSDomain) + self.info.nameservers = list(interface.DNSServerSearchOrder) + self.info.search = [dns.name.from_text(x) for x in + interface.DNSDomainSuffixSearchOrder] + break + finally: + pythoncom.CoUninitialize() + + def get(self): + # We always run in a separate thread to avoid any issues with + # the COM threading model. + self.start() + self.join() + return self.info + else: + class _WMIGetter: + pass + + + class _RegistryGetter: + def __init__(self): + self.info = DnsInfo() + + def _determine_split_char(self, entry): + # + # The windows registry irritatingly changes the list element + # delimiter in between ' ' and ',' (and vice-versa) in various + # versions of windows. + # + if entry.find(' ') >= 0: + split_char = ' ' + elif entry.find(',') >= 0: + split_char = ',' + else: + # probably a singleton; treat as a space-separated list. + split_char = ' ' + return split_char + + def _config_nameservers(self, nameservers): + split_char = self._determine_split_char(nameservers) + ns_list = nameservers.split(split_char) + for ns in ns_list: + if ns not in self.info.nameservers: + self.info.nameservers.append(ns) + + def _config_search(self, search): + split_char = self._determine_split_char(search) + search_list = search.split(split_char) + for s in search_list: + s = dns.name.from_text(s) + if s not in self.info.search: + self.info.search.append(s) + + def _config_fromkey(self, key, always_try_domain): + try: + servers, _ = winreg.QueryValueEx(key, 'NameServer') + except WindowsError: + servers = None + if servers: + self._config_nameservers(servers) + if servers or always_try_domain: + try: + dom, _ = winreg.QueryValueEx(key, 'Domain') + if dom: + self.info.domain = _config_domain(dom) + except WindowsError: + pass + else: + try: + servers, _ = winreg.QueryValueEx(key, 'DhcpNameServer') + except WindowsError: + servers = None + if servers: + self._config_nameservers(servers) + try: + dom, _ = winreg.QueryValueEx(key, 'DhcpDomain') + if dom: + self.info.domain = _config_domain(dom) + except WindowsError: + pass + try: + search, _ = winreg.QueryValueEx(key, 'SearchList') + except WindowsError: + search = None + if search is None: + try: + search, _ = winreg.QueryValueEx(key, 'DhcpSearchList') + except WindowsError: + search = None + if search: + self._config_search(search) + + def _is_nic_enabled(self, lm, guid): + # Look in the Windows Registry to determine whether the network + # interface corresponding to the given guid is enabled. + # + # (Code contributed by Paul Marks, thanks!) + # + try: + # This hard-coded location seems to be consistent, at least + # from Windows 2000 through Vista. + connection_key = winreg.OpenKey( + lm, + r'SYSTEM\CurrentControlSet\Control\Network' + r'\{4D36E972-E325-11CE-BFC1-08002BE10318}' + r'\%s\Connection' % guid) + + try: + # The PnpInstanceID points to a key inside Enum + (pnp_id, ttype) = winreg.QueryValueEx( + connection_key, 'PnpInstanceID') + + if ttype != winreg.REG_SZ: + raise ValueError # pragma: no cover + + device_key = winreg.OpenKey( + lm, r'SYSTEM\CurrentControlSet\Enum\%s' % pnp_id) + + try: + # Get ConfigFlags for this device + (flags, ttype) = winreg.QueryValueEx( + device_key, 'ConfigFlags') + + if ttype != winreg.REG_DWORD: + raise ValueError # pragma: no cover + + # Based on experimentation, bit 0x1 indicates that the + # device is disabled. + # + # XXXRTH I suspect we really want to & with 0x03 so + # that CONFIGFLAGS_REMOVED devices are also ignored, + # but we're shifting to WMI as ConfigFlags is not + # supposed to be used. + return not flags & 0x1 + + finally: + device_key.Close() + finally: + connection_key.Close() + except Exception: # pragma: no cover + return False + + def get(self): + """Extract resolver configuration from the Windows registry.""" + + lm = winreg.ConnectRegistry(None, winreg.HKEY_LOCAL_MACHINE) + try: + tcp_params = winreg.OpenKey(lm, + r'SYSTEM\CurrentControlSet' + r'\Services\Tcpip\Parameters') + try: + self._config_fromkey(tcp_params, True) + finally: + tcp_params.Close() + interfaces = winreg.OpenKey(lm, + r'SYSTEM\CurrentControlSet' + r'\Services\Tcpip\Parameters' + r'\Interfaces') + try: + i = 0 + while True: + try: + guid = winreg.EnumKey(interfaces, i) + i += 1 + key = winreg.OpenKey(interfaces, guid) + try: + if not self._is_nic_enabled(lm, guid): + continue + self._config_fromkey(key, False) + finally: + key.Close() + except EnvironmentError: + break + finally: + interfaces.Close() + finally: + lm.Close() + return self.info + + if _have_wmi and _prefer_wmi: + _getter_class = _WMIGetter + else: + _getter_class = _RegistryGetter + + def get_dns_info(): + """Extract resolver configuration.""" + getter = _getter_class() + return getter.get() diff --git a/lib/dns/wire.py b/lib/dns/wire.py index a3149605..572e27e7 100644 --- a/lib/dns/wire.py +++ b/lib/dns/wire.py @@ -42,6 +42,9 @@ class Parser: def get_uint32(self): return struct.unpack('!I', self.get_bytes(4))[0] + def get_uint48(self): + return int.from_bytes(self.get_bytes(6), 'big') + def get_struct(self, format): return struct.unpack(format, self.get_bytes(struct.calcsize(format))) diff --git a/lib/dns/xfr.py b/lib/dns/xfr.py new file mode 100644 index 00000000..cf9a163e --- /dev/null +++ b/lib/dns/xfr.py @@ -0,0 +1,313 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2017 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +import dns.exception +import dns.message +import dns.name +import dns.rcode +import dns.serial +import dns.rdatatype +import dns.zone + + +class TransferError(dns.exception.DNSException): + """A zone transfer response got a non-zero rcode.""" + + def __init__(self, rcode): + message = 'Zone transfer error: %s' % dns.rcode.to_text(rcode) + super().__init__(message) + self.rcode = rcode + + +class SerialWentBackwards(dns.exception.FormError): + """The current serial number is less than the serial we know.""" + + +class UseTCP(dns.exception.DNSException): + """This IXFR cannot be completed with UDP.""" + + +class Inbound: + """ + State machine for zone transfers. + """ + + def __init__(self, txn_manager, rdtype=dns.rdatatype.AXFR, + serial=None, is_udp=False): + """Initialize an inbound zone transfer. + + *txn_manager* is a :py:class:`dns.transaction.TransactionManager`. + + *rdtype* can be `dns.rdatatype.AXFR` or `dns.rdatatype.IXFR` + + *serial* is the base serial number for IXFRs, and is required in + that case. + + *is_udp*, a ``bool`` indidicates if UDP is being used for this + XFR. + """ + self.txn_manager = txn_manager + self.txn = None + self.rdtype = rdtype + if rdtype == dns.rdatatype.IXFR: + if serial is None: + raise ValueError('a starting serial must be supplied for IXFRs') + elif is_udp: + raise ValueError('is_udp specified for AXFR') + self.serial = serial + self.is_udp = is_udp + (_, _, self.origin) = txn_manager.origin_information() + self.soa_rdataset = None + self.done = False + self.expecting_SOA = False + self.delete_mode = False + + def process_message(self, message): + """Process one message in the transfer. + + The message should have the same relativization as was specified when + the `dns.xfr.Inbound` was created. The message should also have been + created with `one_rr_per_rrset=True` because order matters. + + Returns `True` if the transfer is complete, and `False` otherwise. + """ + if self.txn is None: + replacement = self.rdtype == dns.rdatatype.AXFR + self.txn = self.txn_manager.writer(replacement) + rcode = message.rcode() + if rcode != dns.rcode.NOERROR: + raise TransferError(rcode) + # + # We don't require a question section, but if it is present is + # should be correct. + # + if len(message.question) > 0: + if message.question[0].name != self.origin: + raise dns.exception.FormError("wrong question name") + if message.question[0].rdtype != self.rdtype: + raise dns.exception.FormError("wrong question rdatatype") + answer_index = 0 + if self.soa_rdataset is None: + # + # This is the first message. We're expecting an SOA at + # the origin. + # + if not message.answer or message.answer[0].name != self.origin: + raise dns.exception.FormError("No answer or RRset not " + "for zone origin") + rrset = message.answer[0] + name = rrset.name + rdataset = rrset + if rdataset.rdtype != dns.rdatatype.SOA: + raise dns.exception.FormError("first RRset is not an SOA") + answer_index = 1 + self.soa_rdataset = rdataset.copy() + if self.rdtype == dns.rdatatype.IXFR: + if self.soa_rdataset[0].serial == self.serial: + # + # We're already up-to-date. + # + self.done = True + elif dns.serial.Serial(self.soa_rdataset[0].serial) < \ + self.serial: + # It went backwards! + raise SerialWentBackwards + else: + if self.is_udp and len(message.answer[answer_index:]) == 0: + # + # There are no more records, so this is the + # "truncated" response. Say to use TCP + # + raise UseTCP + # + # Note we're expecting another SOA so we can detect + # if this IXFR response is an AXFR-style response. + # + self.expecting_SOA = True + # + # Process the answer section (other than the initial SOA in + # the first message). + # + for rrset in message.answer[answer_index:]: + name = rrset.name + rdataset = rrset + if self.done: + raise dns.exception.FormError("answers after final SOA") + if rdataset.rdtype == dns.rdatatype.SOA and \ + name == self.origin: + # + # Every time we see an origin SOA delete_mode inverts + # + if self.rdtype == dns.rdatatype.IXFR: + self.delete_mode = not self.delete_mode + # + # If this SOA Rdataset is equal to the first we saw + # then we're finished. If this is an IXFR we also + # check that we're seeing the record in the expected + # part of the response. + # + if rdataset == self.soa_rdataset and \ + (self.rdtype == dns.rdatatype.AXFR or + (self.rdtype == dns.rdatatype.IXFR and + self.delete_mode)): + # + # This is the final SOA + # + if self.expecting_SOA: + # We got an empty IXFR sequence! + raise dns.exception.FormError('empty IXFR sequence') + if self.rdtype == dns.rdatatype.IXFR \ + and self.serial != rdataset[0].serial: + raise dns.exception.FormError('unexpected end of IXFR ' + 'sequence') + self.txn.replace(name, rdataset) + self.txn.commit() + self.txn = None + self.done = True + else: + # + # This is not the final SOA + # + self.expecting_SOA = False + if self.rdtype == dns.rdatatype.IXFR: + if self.delete_mode: + # This is the start of an IXFR deletion set + if rdataset[0].serial != self.serial: + raise dns.exception.FormError( + "IXFR base serial mismatch") + else: + # This is the start of an IXFR addition set + self.serial = rdataset[0].serial + self.txn.replace(name, rdataset) + else: + # We saw a non-final SOA for the origin in an AXFR. + raise dns.exception.FormError('unexpected origin SOA ' + 'in AXFR') + continue + if self.expecting_SOA: + # + # We made an IXFR request and are expecting another + # SOA RR, but saw something else, so this must be an + # AXFR response. + # + self.rdtype = dns.rdatatype.AXFR + self.expecting_SOA = False + self.delete_mode = False + self.txn.rollback() + self.txn = self.txn_manager.writer(True) + # + # Note we are falling through into the code below + # so whatever rdataset this was gets written. + # + # Add or remove the data + if self.delete_mode: + self.txn.delete_exact(name, rdataset) + else: + self.txn.add(name, rdataset) + if self.is_udp and not self.done: + # + # This is a UDP IXFR and we didn't get to done, and we didn't + # get the proper "truncated" response + # + raise dns.exception.FormError('unexpected end of UDP IXFR') + return self.done + + # + # Inbounds are context managers. + # + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if self.txn: + self.txn.rollback() + return False + + +def make_query(txn_manager, serial=0, + use_edns=None, ednsflags=None, payload=None, + request_payload=None, options=None, + keyring=None, keyname=None, + keyalgorithm=dns.tsig.default_algorithm): + """Make an AXFR or IXFR query. + + *txn_manager* is a ``dns.transaction.TransactionManager``, typically a + ``dns.zone.Zone``. + + *serial* is an ``int`` or ``None``. If 0, then IXFR will be + attempted using the most recent serial number from the + *txn_manager*; it is the caller's responsibility to ensure there + are no write transactions active that could invalidate the + retrieved serial. If a serial cannot be determined, AXFR will be + forced. Other integer values are the starting serial to use. + ``None`` forces an AXFR. + + Please see the documentation for :py:func:`dns.message.make_query` and + :py:func:`dns.message.Message.use_tsig` for details on the other parameters + to this function. + + Returns a `(query, serial)` tuple. + """ + (zone_origin, _, origin) = txn_manager.origin_information() + if serial is None: + rdtype = dns.rdatatype.AXFR + elif not isinstance(serial, int): + raise ValueError('serial is not an integer') + elif serial == 0: + with txn_manager.reader() as txn: + rdataset = txn.get(origin, 'SOA') + if rdataset: + serial = rdataset[0].serial + rdtype = dns.rdatatype.IXFR + else: + serial = None + rdtype = dns.rdatatype.AXFR + elif serial > 0 and serial < 4294967296: + rdtype = dns.rdatatype.IXFR + else: + raise ValueError('serial out-of-range') + rdclass = txn_manager.get_class() + q = dns.message.make_query(zone_origin, rdtype, rdclass, + use_edns, False, ednsflags, payload, + request_payload, options) + if serial is not None: + rdata = dns.rdata.from_text(rdclass, 'SOA', f'. . {serial} 0 0 0 0') + rrset = q.find_rrset(q.authority, zone_origin, rdclass, + dns.rdatatype.SOA, create=True) + rrset.add(rdata, 0) + if keyring is not None: + q.use_tsig(keyring, keyname, algorithm=keyalgorithm) + return (q, serial) + +def extract_serial_from_query(query): + """Extract the SOA serial number from query if it is an IXFR and return + it, otherwise return None. + + *query* is a dns.message.QueryMessage that is an IXFR or AXFR request. + + Raises if the query is not an IXFR or AXFR, or if an IXFR doesn't have + an appropriate SOA RRset in the authority section.""" + + question = query.question[0] + if question.rdtype == dns.rdatatype.AXFR: + return None + elif question.rdtype != dns.rdatatype.IXFR: + raise ValueError("query is not an AXFR or IXFR") + soa = query.find_rrset(query.authority, question.name, question.rdclass, + dns.rdatatype.SOA) + return soa[0].serial diff --git a/lib/dns/zone.py b/lib/dns/zone.py index e8413c08..2e731446 100644 --- a/lib/dns/zone.py +++ b/lib/dns/zone.py @@ -18,22 +18,26 @@ """DNS Zones.""" import contextlib +import hashlib import io import os -import re -import sys +import struct import dns.exception +import dns.immutable import dns.name import dns.node import dns.rdataclass import dns.rdatatype import dns.rdata import dns.rdtypes.ANY.SOA +import dns.rdtypes.ANY.ZONEMD import dns.rrset import dns.tokenizer +import dns.transaction import dns.ttl import dns.grange +import dns.zonefile class BadZone(dns.exception.DNSException): @@ -56,7 +60,54 @@ class UnknownOrigin(BadZone): """The DNS zone's origin is unknown.""" -class Zone: +class UnsupportedDigestScheme(dns.exception.DNSException): + + """The zone digest's scheme is unsupported.""" + + +class UnsupportedDigestHashAlgorithm(dns.exception.DNSException): + + """The zone digest's origin is unsupported.""" + + +class NoDigest(dns.exception.DNSException): + + """The DNS zone has no ZONEMD RRset at its origin.""" + + +class DigestVerificationFailure(dns.exception.DNSException): + + """The ZONEMD digest failed to verify.""" + + +class DigestScheme(dns.enum.IntEnum): + """ZONEMD Scheme""" + + SIMPLE = 1 + + @classmethod + def _maximum(cls): + return 255 + + +class DigestHashAlgorithm(dns.enum.IntEnum): + """ZONEMD Hash Algorithm""" + + SHA384 = 1 + SHA512 = 2 + + @classmethod + def _maximum(cls): + return 255 + + +_digest_hashers = { + DigestHashAlgorithm.SHA384: hashlib.sha384, + DigestHashAlgorithm.SHA512: hashlib.sha512, +} + + +class Zone(dns.transaction.TransactionManager): """A DNS zone. @@ -77,7 +128,7 @@ class Zone: *origin* is the origin of the zone. It may be a ``dns.name.Name``, a ``str``, or ``None``. If ``None``, then the zone's origin will - be set by the first ``$ORIGIN`` line in a masterfile. + be set by the first ``$ORIGIN`` line in a zone file. *rdclass*, an ``int``, the zone's rdata class; the default is class IN. @@ -150,20 +201,21 @@ class Zone: return self.nodes.__iter__() def keys(self): - return self.nodes.keys() # pylint: disable=dict-keys-not-iterating + return self.nodes.keys() def values(self): - return self.nodes.values() # pylint: disable=dict-values-not-iterating + return self.nodes.values() def items(self): - return self.nodes.items() # pylint: disable=dict-items-not-iterating + return self.nodes.items() def get(self, key): key = self._validate_name(key) return self.nodes.get(key) - def __contains__(self, other): - return other in self.nodes + def __contains__(self, key): + key = self._validate_name(key) + return key in self.nodes def find_node(self, name, create=False): """Find a node in the zone, possibly creating it. @@ -532,7 +584,8 @@ class Zone: for rdata in rds: yield (name, rds.ttl, rdata) - def to_file(self, f, sorted=True, relativize=True, nl=None): + def to_file(self, f, sorted=True, relativize=True, nl=None, + want_comments=False, want_origin=False): """Write a zone to a file. *f*, a file or `str`. If *f* is a string, it is treated @@ -550,6 +603,14 @@ class Zone: *nl*, a ``str`` or None. The end of line string. If not ``None``, the output will use the platform's native end-of-line marker (i.e. LF on POSIX, CRLF on Windows). + + *want_comments*, a ``bool``. If ``True``, emit end-of-line comments + as part of writing the file. If ``False``, the default, do not + emit them. + + *want_origin*, a ``bool``. If ``True``, emit a $ORIGIN line at + the start of the file. If ``False``, the default, do not emit + one. """ with contextlib.ExitStack() as stack: @@ -572,6 +633,16 @@ class Zone: nl_b = nl nl = nl.decode() + if want_origin: + l = '$ORIGIN ' + self.origin.to_text() + l_b = l.encode(file_enc) + try: + f.write(l_b) + f.write(nl_b) + except TypeError: # textual mode + f.write(l) + f.write(nl) + if sorted: names = list(self.keys()) names.sort() @@ -579,12 +650,9 @@ class Zone: names = self.keys() for n in names: l = self[n].to_text(n, origin=self.origin, - relativize=relativize) - if isinstance(l, str): - l_b = l.encode(file_enc) - else: - l_b = l - l = l.decode() + relativize=relativize, + want_comments=want_comments) + l_b = l.encode(file_enc) try: f.write(l_b) @@ -593,7 +661,8 @@ class Zone: f.write(l) f.write(nl) - def to_text(self, sorted=True, relativize=True, nl=None): + def to_text(self, sorted=True, relativize=True, nl=None, + want_comments=False, want_origin=False): """Return a zone's text as though it were written to a file. *sorted*, a ``bool``. If True, the default, then the file @@ -609,10 +678,19 @@ class Zone: ``None``, the output will use the platform's native end-of-line marker (i.e. LF on POSIX, CRLF on Windows). + *want_comments*, a ``bool``. If ``True``, emit end-of-line comments + as part of writing the file. If ``False``, the default, do not + emit them. + + *want_origin*, a ``bool``. If ``True``, emit a $ORIGIN line at + the start of the output. If ``False``, the default, do not emit + one. + Returns a ``str``. """ temp_buffer = io.StringIO() - self.to_file(temp_buffer, sorted, relativize, nl) + self.to_file(temp_buffer, sorted, relativize, nl, want_comments, + want_origin) return_value = temp_buffer.getvalue() temp_buffer.close() return return_value @@ -635,425 +713,334 @@ class Zone: if self.get_rdataset(name, dns.rdatatype.NS) is None: raise NoNS + def _compute_digest(self, hash_algorithm, scheme=DigestScheme.SIMPLE): + hashinfo = _digest_hashers.get(hash_algorithm) + if not hashinfo: + raise UnsupportedDigestHashAlgorithm + if scheme != DigestScheme.SIMPLE: + raise UnsupportedDigestScheme -class _MasterReader: - - """Read a DNS master file - - @ivar tok: The tokenizer - @type tok: dns.tokenizer.Tokenizer object - @ivar last_ttl: The last seen explicit TTL for an RR - @type last_ttl: int - @ivar last_ttl_known: Has last TTL been detected - @type last_ttl_known: bool - @ivar default_ttl: The default TTL from a $TTL directive or SOA RR - @type default_ttl: int - @ivar default_ttl_known: Has default TTL been detected - @type default_ttl_known: bool - @ivar last_name: The last name read - @type last_name: dns.name.Name object - @ivar current_origin: The current origin - @type current_origin: dns.name.Name object - @ivar relativize: should names in the zone be relativized? - @type relativize: bool - @ivar zone: the zone - @type zone: dns.zone.Zone object - @ivar saved_state: saved reader state (used when processing $INCLUDE) - @type saved_state: list of (tokenizer, current_origin, last_name, file, - last_ttl, last_ttl_known, default_ttl, default_ttl_known) tuples. - @ivar current_file: the file object of the $INCLUDed file being parsed - (None if no $INCLUDE is active). - @ivar allow_include: is $INCLUDE allowed? - @type allow_include: bool - @ivar check_origin: should sanity checks of the origin node be done? - The default is True. - @type check_origin: bool - """ - - def __init__(self, tok, origin, rdclass, relativize, zone_factory=Zone, - allow_include=False, check_origin=True): - if isinstance(origin, str): - origin = dns.name.from_text(origin) - self.tok = tok - self.current_origin = origin - self.relativize = relativize - self.last_ttl = 0 - self.last_ttl_known = False - self.default_ttl = 0 - self.default_ttl_known = False - self.last_name = self.current_origin - self.zone = zone_factory(origin, rdclass, relativize=relativize) - self.saved_state = [] - self.current_file = None - self.allow_include = allow_include - self.check_origin = check_origin - - def _eat_line(self): - while 1: - token = self.tok.get() - if token.is_eol_or_eof(): - break - - def _rr_line(self): - """Process one line from a DNS master file.""" - # Name - if self.current_origin is None: - raise UnknownOrigin - token = self.tok.get(want_leading=True) - if not token.is_whitespace(): - self.last_name = self.tok.as_name(token, self.current_origin) - else: - token = self.tok.get() - if token.is_eol_or_eof(): - # treat leading WS followed by EOL/EOF as if they were EOL/EOF. - return - self.tok.unget(token) - name = self.last_name - if not name.is_subdomain(self.zone.origin): - self._eat_line() - return if self.relativize: - name = name.relativize(self.zone.origin) - token = self.tok.get() - if not token.is_identifier(): - raise dns.exception.SyntaxError + origin_name = dns.name.empty + else: + origin_name = self.origin + hasher = hashinfo() + for (name, node) in sorted(self.items()): + rrnamebuf = name.to_digestable(self.origin) + for rdataset in sorted(node, + key=lambda rds: (rds.rdtype, rds.covers)): + if name == origin_name and \ + dns.rdatatype.ZONEMD in (rdataset.rdtype, rdataset.covers): + continue + rrfixed = struct.pack('!HHI', rdataset.rdtype, + rdataset.rdclass, rdataset.ttl) + rdatas = [rdata.to_digestable(self.origin) + for rdata in rdataset] + for rdata in sorted(rdatas): + rrlen = struct.pack('!H', len(rdata)) + hasher.update(rrnamebuf + rrfixed + rrlen + rdata) + return hasher.digest() - # TTL - ttl = None - try: - ttl = dns.ttl.from_text(token.value) - self.last_ttl = ttl - self.last_ttl_known = True - token = self.tok.get() - if not token.is_identifier(): - raise dns.exception.SyntaxError - except dns.ttl.BadTTL: - if self.default_ttl_known: - ttl = self.default_ttl - elif self.last_ttl_known: - ttl = self.last_ttl + def compute_digest(self, hash_algorithm, scheme=DigestScheme.SIMPLE): + if self.relativize: + origin_name = dns.name.empty + else: + origin_name = self.origin + serial = self.get_rdataset(origin_name, dns.rdatatype.SOA)[0].serial + digest = self._compute_digest(hash_algorithm, scheme) + return dns.rdtypes.ANY.ZONEMD.ZONEMD(self.rdclass, + dns.rdatatype.ZONEMD, + serial, scheme, hash_algorithm, + digest) - # Class - try: - rdclass = dns.rdataclass.from_text(token.value) - token = self.tok.get() - if not token.is_identifier(): - raise dns.exception.SyntaxError - except dns.exception.SyntaxError: - raise - except Exception: - rdclass = self.zone.rdclass - if rdclass != self.zone.rdclass: - raise dns.exception.SyntaxError("RR class is not zone's class") - # Type - try: - rdtype = dns.rdatatype.from_text(token.value) - except Exception: - raise dns.exception.SyntaxError( - "unknown rdatatype '%s'" % token.value) - n = self.zone.nodes.get(name) - if n is None: - n = self.zone.node_factory() - self.zone.nodes[name] = n - try: - rd = dns.rdata.from_text(rdclass, rdtype, self.tok, - self.current_origin, self.relativize, - self.zone.origin) - except dns.exception.SyntaxError: - # Catch and reraise. - raise - except Exception: - # All exceptions that occur in the processing of rdata - # are treated as syntax errors. This is not strictly - # correct, but it is correct almost all of the time. - # We convert them to syntax errors so that we can emit - # helpful filename:line info. - (ty, va) = sys.exc_info()[:2] - raise dns.exception.SyntaxError( - "caught exception {}: {}".format(str(ty), str(va))) - - if not self.default_ttl_known and rdtype == dns.rdatatype.SOA: - # The pre-RFC2308 and pre-BIND9 behavior inherits the zone default - # TTL from the SOA minttl if no $TTL statement is present before the - # SOA is parsed. - self.default_ttl = rd.minimum - self.default_ttl_known = True - if ttl is None: - # if we didn't have a TTL on the SOA, set it! - ttl = rd.minimum - - # TTL check. We had to wait until now to do this as the SOA RR's - # own TTL can be inferred from its minimum. - if ttl is None: - raise dns.exception.SyntaxError("Missing default TTL value") - - covers = rd.covers() - rds = n.find_rdataset(rdclass, rdtype, covers, True) - rds.add(rd, ttl) - - def _parse_modify(self, side): - # Here we catch everything in '{' '}' in a group so we can replace it - # with ''. - is_generate1 = re.compile(r"^.*\$({(\+|-?)(\d+),(\d+),(.)}).*$") - is_generate2 = re.compile(r"^.*\$({(\+|-?)(\d+)}).*$") - is_generate3 = re.compile(r"^.*\$({(\+|-?)(\d+),(\d+)}).*$") - # Sometimes there are modifiers in the hostname. These come after - # the dollar sign. They are in the form: ${offset[,width[,base]]}. - # Make names - g1 = is_generate1.match(side) - if g1: - mod, sign, offset, width, base = g1.groups() - if sign == '': - sign = '+' - g2 = is_generate2.match(side) - if g2: - mod, sign, offset = g2.groups() - if sign == '': - sign = '+' - width = 0 - base = 'd' - g3 = is_generate3.match(side) - if g3: - mod, sign, offset, width = g3.groups() - if sign == '': - sign = '+' - base = 'd' - - if not (g1 or g2 or g3): - mod = '' - sign = '+' - offset = 0 - width = 0 - base = 'd' - - if base != 'd': - raise NotImplementedError() - - return mod, sign, offset, width, base - - def _generate_line(self): - # range lhs [ttl] [class] type rhs [ comment ] - """Process one line containing the GENERATE statement from a DNS - master file.""" - if self.current_origin is None: - raise UnknownOrigin - - token = self.tok.get() - # Range (required) - try: - start, stop, step = dns.grange.from_text(token.value) - token = self.tok.get() - if not token.is_identifier(): - raise dns.exception.SyntaxError - except Exception: - raise dns.exception.SyntaxError - - # lhs (required) - try: - lhs = token.value - token = self.tok.get() - if not token.is_identifier(): - raise dns.exception.SyntaxError - except Exception: - raise dns.exception.SyntaxError - - # TTL - try: - ttl = dns.ttl.from_text(token.value) - self.last_ttl = ttl - self.last_ttl_known = True - token = self.tok.get() - if not token.is_identifier(): - raise dns.exception.SyntaxError - except dns.ttl.BadTTL: - if not (self.last_ttl_known or self.default_ttl_known): - raise dns.exception.SyntaxError("Missing default TTL value") - if self.default_ttl_known: - ttl = self.default_ttl - elif self.last_ttl_known: - ttl = self.last_ttl - # Class - try: - rdclass = dns.rdataclass.from_text(token.value) - token = self.tok.get() - if not token.is_identifier(): - raise dns.exception.SyntaxError - except dns.exception.SyntaxError: - raise dns.exception.SyntaxError - except Exception: - rdclass = self.zone.rdclass - if rdclass != self.zone.rdclass: - raise dns.exception.SyntaxError("RR class is not zone's class") - # Type - try: - rdtype = dns.rdatatype.from_text(token.value) - token = self.tok.get() - if not token.is_identifier(): - raise dns.exception.SyntaxError - except Exception: - raise dns.exception.SyntaxError("unknown rdatatype '%s'" % - token.value) - - # rhs (required) - rhs = token.value - - lmod, lsign, loffset, lwidth, lbase = self._parse_modify(lhs) - rmod, rsign, roffset, rwidth, rbase = self._parse_modify(rhs) - for i in range(start, stop + 1, step): - # +1 because bind is inclusive and python is exclusive - - if lsign == '+': - lindex = i + int(loffset) - elif lsign == '-': - lindex = i - int(loffset) - - if rsign == '-': - rindex = i - int(roffset) - elif rsign == '+': - rindex = i + int(roffset) - - lzfindex = str(lindex).zfill(int(lwidth)) - rzfindex = str(rindex).zfill(int(rwidth)) - - name = lhs.replace('$%s' % (lmod), lzfindex) - rdata = rhs.replace('$%s' % (rmod), rzfindex) - - self.last_name = dns.name.from_text(name, self.current_origin, - self.tok.idna_codec) - name = self.last_name - if not name.is_subdomain(self.zone.origin): - self._eat_line() - return - if self.relativize: - name = name.relativize(self.zone.origin) - - n = self.zone.nodes.get(name) - if n is None: - n = self.zone.node_factory() - self.zone.nodes[name] = n + def verify_digest(self, zonemd=None): + if zonemd: + digests = [zonemd] + else: + digests = self.get_rdataset(self.origin, dns.rdatatype.ZONEMD) + if digests is None: + raise NoDigest + for digest in digests: try: - rd = dns.rdata.from_text(rdclass, rdtype, rdata, - self.current_origin, self.relativize, - self.zone.origin) - except dns.exception.SyntaxError: - # Catch and reraise. - raise + computed = self._compute_digest(digest.hash_algorithm, + digest.scheme) + if computed == digest.digest: + return except Exception: - # All exceptions that occur in the processing of rdata - # are treated as syntax errors. This is not strictly - # correct, but it is correct almost all of the time. - # We convert them to syntax errors so that we can emit - # helpful filename:line info. - (ty, va) = sys.exc_info()[:2] - raise dns.exception.SyntaxError("caught exception %s: %s" % - (str(ty), str(va))) + pass + raise DigestVerificationFailure - covers = rd.covers() - rds = n.find_rdataset(rdclass, rdtype, covers, True) - rds.add(rd, ttl) + # TransactionManager methods - def read(self): - """Read a DNS master file and build a zone object. + def reader(self): + return Transaction(self, False, + Version(self, 1, self.nodes, self.origin)) - @raises dns.zone.NoSOA: No SOA RR was found at the zone origin - @raises dns.zone.NoNS: No NS RRset was found at the zone origin - """ + def writer(self, replacement=False): + txn = Transaction(self, replacement) + txn._setup_version() + return txn - try: - while 1: - token = self.tok.get(True, True) - if token.is_eof(): - if self.current_file is not None: - self.current_file.close() - if len(self.saved_state) > 0: - (self.tok, - self.current_origin, - self.last_name, - self.current_file, - self.last_ttl, - self.last_ttl_known, - self.default_ttl, - self.default_ttl_known) = self.saved_state.pop(-1) - continue - break - elif token.is_eol(): - continue - elif token.is_comment(): - self.tok.get_eol() - continue - elif token.value[0] == '$': - c = token.value.upper() - if c == '$TTL': - token = self.tok.get() - if not token.is_identifier(): - raise dns.exception.SyntaxError("bad $TTL") - self.default_ttl = dns.ttl.from_text(token.value) - self.default_ttl_known = True - self.tok.get_eol() - elif c == '$ORIGIN': - self.current_origin = self.tok.get_name() - self.tok.get_eol() - if self.zone.origin is None: - self.zone.origin = self.current_origin - elif c == '$INCLUDE' and self.allow_include: - token = self.tok.get() - filename = token.value - token = self.tok.get() - if token.is_identifier(): - new_origin =\ - dns.name.from_text(token.value, - self.current_origin, - self.tok.idna_codec) - self.tok.get_eol() - elif not token.is_eol_or_eof(): - raise dns.exception.SyntaxError( - "bad origin in $INCLUDE") - else: - new_origin = self.current_origin - self.saved_state.append((self.tok, - self.current_origin, - self.last_name, - self.current_file, - self.last_ttl, - self.last_ttl_known, - self.default_ttl, - self.default_ttl_known)) - self.current_file = open(filename, 'r') - self.tok = dns.tokenizer.Tokenizer(self.current_file, - filename) - self.current_origin = new_origin - elif c == '$GENERATE': - self._generate_line() - else: - raise dns.exception.SyntaxError( - "Unknown master file directive '" + c + "'") - continue - self.tok.unget(token) - self._rr_line() - except dns.exception.SyntaxError as detail: - (filename, line_number) = self.tok.where() - if detail is None: - detail = "syntax error" - ex = dns.exception.SyntaxError( - "%s:%d: %s" % (filename, line_number, detail)) - tb = sys.exc_info()[2] - raise ex.with_traceback(tb) from None + def origin_information(self): + if self.relativize: + effective = dns.name.empty + else: + effective = self.origin + return (self.origin, self.relativize, effective) - # Now that we're done reading, do some basic checking of the zone. - if self.check_origin: - self.zone.check_origin() + def get_class(self): + return self.rdclass + + # Transaction methods + + def _end_read(self, txn): + pass + + def _end_write(self, txn): + pass + + def _commit_version(self, _, version, origin): + self.nodes = version.nodes + if self.origin is None: + self.origin = origin + + def _get_next_version_id(self): + # Versions are ephemeral and all have id 1 + return 1 + + +# These classes used to be in dns.versioned, but have moved here so we can use +# the copy-on-write transaction mechanism for both kinds of zones. In a +# regular zone, the version only exists during the transaction, and the nodes +# are regular dns.node.Nodes. + +# A node with a version id. + +class VersionedNode(dns.node.Node): + __slots__ = ['id'] + + def __init__(self): + super().__init__() + # A proper id will get set by the Version + self.id = 0 + + +@dns.immutable.immutable +class ImmutableVersionedNode(VersionedNode): + __slots__ = ['id'] + + def __init__(self, node): + super().__init__() + self.id = node.id + self.rdatasets = tuple( + [dns.rdataset.ImmutableRdataset(rds) for rds in node.rdatasets] + ) + + def find_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE, + create=False): + if create: + raise TypeError("immutable") + return super().find_rdataset(rdclass, rdtype, covers, False) + + def get_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE, + create=False): + if create: + raise TypeError("immutable") + return super().get_rdataset(rdclass, rdtype, covers, False) + + def delete_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE): + raise TypeError("immutable") + + def replace_rdataset(self, replacement): + raise TypeError("immutable") + + def is_immutable(self): + return True + + +class Version: + def __init__(self, zone, id, nodes=None, origin=None): + self.zone = zone + self.id = id + if nodes is not None: + self.nodes = nodes + else: + self.nodes = {} + self.origin = origin + + def _validate_name(self, name): + if name.is_absolute(): + if not name.is_subdomain(self.zone.origin): + raise KeyError("name is not a subdomain of the zone origin") + if self.zone.relativize: + # XXXRTH should it be an error if self.origin is still None? + name = name.relativize(self.origin) + return name + + def get_node(self, name): + name = self._validate_name(name) + return self.nodes.get(name) + + def get_rdataset(self, name, rdtype, covers): + node = self.get_node(name) + if node is None: + return None + return node.get_rdataset(self.zone.rdclass, rdtype, covers) + + def items(self): + return self.nodes.items() + + +class WritableVersion(Version): + def __init__(self, zone, replacement=False): + # The zone._versions_lock must be held by our caller in a versioned + # zone. + id = zone._get_next_version_id() + super().__init__(zone, id) + if not replacement: + # We copy the map, because that gives us a simple and thread-safe + # way of doing versions, and we have a garbage collector to help + # us. We only make new node objects if we actually change the + # node. + self.nodes.update(zone.nodes) + # We have to copy the zone origin as it may be None in the first + # version, and we don't want to mutate the zone until we commit. + self.origin = zone.origin + self.changed = set() + + def _maybe_cow(self, name): + name = self._validate_name(name) + node = self.nodes.get(name) + if node is None or name not in self.changed: + new_node = self.zone.node_factory() + if hasattr(new_node, 'id'): + # We keep doing this for backwards compatibility, as earlier + # code used new_node.id != self.id for the "do we need to CoW?" + # test. Now we use the changed set as this works with both + # regular zones and versioned zones. + new_node.id = self.id + if node is not None: + # moo! copy on write! + new_node.rdatasets.extend(node.rdatasets) + self.nodes[name] = new_node + self.changed.add(name) + return new_node + else: + return node + + def delete_node(self, name): + name = self._validate_name(name) + if name in self.nodes: + del self.nodes[name] + self.changed.add(name) + + def put_rdataset(self, name, rdataset): + node = self._maybe_cow(name) + node.replace_rdataset(rdataset) + + def delete_rdataset(self, name, rdtype, covers): + node = self._maybe_cow(name) + node.delete_rdataset(self.zone.rdclass, rdtype, covers) + if len(node) == 0: + del self.nodes[name] + + +@dns.immutable.immutable +class ImmutableVersion(Version): + def __init__(self, version): + # We tell super() that it's a replacement as we don't want it + # to copy the nodes, as we're about to do that with an + # immutable Dict. + super().__init__(version.zone, True) + # set the right id! + self.id = version.id + # keep the origin + self.origin = version.origin + # Make changed nodes immutable + for name in version.changed: + node = version.nodes.get(name) + # it might not exist if we deleted it in the version + if node: + version.nodes[name] = ImmutableVersionedNode(node) + self.nodes = dns.immutable.Dict(version.nodes, True) + + +class Transaction(dns.transaction.Transaction): + + def __init__(self, zone, replacement, version=None, make_immutable=False): + read_only = version is not None + super().__init__(zone, replacement, read_only) + self.version = version + self.make_immutable = make_immutable + + @property + def zone(self): + return self.manager + + def _setup_version(self): + assert self.version is None + self.version = WritableVersion(self.zone, self.replacement) + + def _get_rdataset(self, name, rdtype, covers): + return self.version.get_rdataset(name, rdtype, covers) + + def _put_rdataset(self, name, rdataset): + assert not self.read_only + self.version.put_rdataset(name, rdataset) + + def _delete_name(self, name): + assert not self.read_only + self.version.delete_node(name) + + def _delete_rdataset(self, name, rdtype, covers): + assert not self.read_only + self.version.delete_rdataset(name, rdtype, covers) + + def _name_exists(self, name): + return self.version.get_node(name) is not None + + def _changed(self): + if self.read_only: + return False + else: + return len(self.version.changed) > 0 + + def _end_transaction(self, commit): + if self.read_only: + self.zone._end_read(self) + elif commit and len(self.version.changed) > 0: + if self.make_immutable: + version = ImmutableVersion(self.version) + else: + version = self.version + self.zone._commit_version(self, version, self.version.origin) + else: + # rollback + self.zone._end_write(self) + + def _set_origin(self, origin): + if self.version.origin is None: + self.version.origin = origin + + def _iterate_rdatasets(self): + for (name, node) in self.version.items(): + for rdataset in node: + yield (name, rdataset) + + def _get_node(self, name): + return self.version.get_node(name) def from_text(text, origin=None, rdclass=dns.rdataclass.IN, relativize=True, zone_factory=Zone, filename=None, allow_include=False, check_origin=True, idna_codec=None): - """Build a zone object from a master file format string. + """Build a zone object from a zone file format string. - *text*, a ``str``, the master file format input. + *text*, a ``str``, the zone file format input. *origin*, a ``dns.name.Name``, a ``str``, or ``None``. The origin of the zone; if not specified, the first ``$ORIGIN`` statement in the - masterfile will determine the origin of the zone. + zone file will determine the origin of the zone. *rdclass*, an ``int``, the zone's rdata class; the default is class IN. @@ -1094,25 +1081,33 @@ def from_text(text, origin=None, rdclass=dns.rdataclass.IN, if filename is None: filename = '' - tok = dns.tokenizer.Tokenizer(text, filename, idna_codec=idna_codec) - reader = _MasterReader(tok, origin, rdclass, relativize, zone_factory, - allow_include=allow_include, - check_origin=check_origin) - reader.read() - return reader.zone + zone = zone_factory(origin, rdclass, relativize=relativize) + with zone.writer(True) as txn: + tok = dns.tokenizer.Tokenizer(text, filename, idna_codec=idna_codec) + reader = dns.zonefile.Reader(tok, rdclass, txn, + allow_include=allow_include) + try: + reader.read() + except dns.zonefile.UnknownOrigin: + # for backwards compatibility + raise dns.zone.UnknownOrigin + # Now that we're done reading, do some basic checking of the zone. + if check_origin: + zone.check_origin() + return zone def from_file(f, origin=None, rdclass=dns.rdataclass.IN, relativize=True, zone_factory=Zone, filename=None, allow_include=True, check_origin=True): - """Read a master file and build a zone object. + """Read a zone file and build a zone object. *f*, a file or ``str``. If *f* is a string, it is treated as the name of a file to open. *origin*, a ``dns.name.Name``, a ``str``, or ``None``. The origin of the zone; if not specified, the first ``$ORIGIN`` statement in the - masterfile will determine the origin of the zone. + zone file will determine the origin of the zone. *rdclass*, an ``int``, the zone's rdata class; the default is class IN. diff --git a/lib/dns/zone.pyi b/lib/dns/zone.pyi new file mode 100644 index 00000000..272814fe --- /dev/null +++ b/lib/dns/zone.pyi @@ -0,0 +1,55 @@ +from typing import Generator, Optional, Union, Tuple, Iterable, Callable, Any, Iterator, TextIO, BinaryIO, Dict +from . import rdata, zone, rdataclass, name, rdataclass, message, rdatatype, exception, node, rdataset, rrset, rdatatype + +class BadZone(exception.DNSException): ... +class NoSOA(BadZone): ... +class NoNS(BadZone): ... +class UnknownOrigin(BadZone): ... + +class Zone: + def __getitem__(self, key : str) -> node.Node: + ... + def __init__(self, origin : Union[str,name.Name], rdclass : int = rdataclass.IN, relativize : bool = True) -> None: + self.nodes : Dict[str,node.Node] + self.origin = origin + def values(self): + return self.nodes.values() + def iterate_rdatas(self, rdtype : Union[int,str] = rdatatype.ANY, covers : Union[int,str] = None) -> Iterable[Tuple[name.Name, int, rdata.Rdata]]: + ... + def __iter__(self) -> Iterator[str]: + ... + def get_node(self, name : Union[name.Name,str], create=False) -> Optional[node.Node]: + ... + def find_rrset(self, name : Union[str,name.Name], rdtype : Union[int,str], covers=rdatatype.NONE) -> rrset.RRset: + ... + def find_rdataset(self, name : Union[str,name.Name], rdtype : Union[str,int], covers=rdatatype.NONE, + create=False) -> rdataset.Rdataset: + ... + def get_rdataset(self, name : Union[str,name.Name], rdtype : Union[str,int], covers=rdatatype.NONE, create=False) -> Optional[rdataset.Rdataset]: + ... + def get_rrset(self, name : Union[str,name.Name], rdtype : Union[str,int], covers=rdatatype.NONE) -> Optional[rrset.RRset]: + ... + def replace_rdataset(self, name : Union[str,name.Name], replacement : rdataset.Rdataset) -> None: + ... + def delete_rdataset(self, name : Union[str,name.Name], rdtype : Union[str,int], covers=rdatatype.NONE) -> None: + ... + def iterate_rdatasets(self, rdtype : Union[str,int] =rdatatype.ANY, + covers : Union[str,int] =rdatatype.NONE): + ... + def to_file(self, f : Union[TextIO, BinaryIO, str], sorted=True, relativize=True, nl : Optional[bytes] = None): + ... + def to_text(self, sorted=True, relativize=True, nl : Optional[str] = None) -> str: + ... + +def from_xfr(xfr : Generator[Any,Any,message.Message], zone_factory : Callable[..., zone.Zone] = zone.Zone, relativize=True, check_origin=True): + ... + +def from_text(text : str, origin : Optional[Union[str,name.Name]] = None, rdclass : int = rdataclass.IN, + relativize=True, zone_factory : Callable[...,zone.Zone] = zone.Zone, filename : Optional[str] = None, + allow_include=False, check_origin=True) -> zone.Zone: + ... + +def from_file(f, origin : Optional[Union[str,name.Name]] = None, rdclass=rdataclass.IN, + relativize=True, zone_factory : Callable[..., zone.Zone] = Zone, filename : Optional[str] = None, + allow_include=True, check_origin=True) -> zone.Zone: + ... diff --git a/lib/dns/zonefile.py b/lib/dns/zonefile.py new file mode 100644 index 00000000..53b40880 --- /dev/null +++ b/lib/dns/zonefile.py @@ -0,0 +1,624 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +"""DNS Zones.""" + +import re +import sys + +import dns.exception +import dns.name +import dns.node +import dns.rdataclass +import dns.rdatatype +import dns.rdata +import dns.rdtypes.ANY.SOA +import dns.rrset +import dns.tokenizer +import dns.transaction +import dns.ttl +import dns.grange + + +class UnknownOrigin(dns.exception.DNSException): + """Unknown origin""" + + +class CNAMEAndOtherData(dns.exception.DNSException): + """A node has a CNAME and other data""" + + +def _check_cname_and_other_data(txn, name, rdataset): + rdataset_kind = dns.node.NodeKind.classify_rdataset(rdataset) + node = txn.get_node(name) + if node is None: + # empty nodes are neutral. + return + node_kind = node.classify() + if node_kind == dns.node.NodeKind.CNAME and \ + rdataset_kind == dns.node.NodeKind.REGULAR: + raise CNAMEAndOtherData('rdataset type is not compatible with a ' + 'CNAME node') + elif node_kind == dns.node.NodeKind.REGULAR and \ + rdataset_kind == dns.node.NodeKind.CNAME: + raise CNAMEAndOtherData('CNAME rdataset is not compatible with a ' + 'regular data node') + # Otherwise at least one of the node and the rdataset is neutral, so + # adding the rdataset is ok + + +class Reader: + + """Read a DNS zone file into a transaction.""" + + def __init__(self, tok, rdclass, txn, allow_include=False, + allow_directives=True, force_name=None, + force_ttl=None, force_rdclass=None, force_rdtype=None, + default_ttl=None): + self.tok = tok + (self.zone_origin, self.relativize, _) = \ + txn.manager.origin_information() + self.current_origin = self.zone_origin + self.last_ttl = 0 + self.last_ttl_known = False + if force_ttl is not None: + default_ttl = force_ttl + if default_ttl is None: + self.default_ttl = 0 + self.default_ttl_known = False + else: + self.default_ttl = default_ttl + self.default_ttl_known = True + self.last_name = self.current_origin + self.zone_rdclass = rdclass + self.txn = txn + self.saved_state = [] + self.current_file = None + self.allow_include = allow_include + self.allow_directives = allow_directives + self.force_name = force_name + self.force_ttl = force_ttl + self.force_rdclass = force_rdclass + self.force_rdtype = force_rdtype + self.txn.check_put_rdataset(_check_cname_and_other_data) + + def _eat_line(self): + while 1: + token = self.tok.get() + if token.is_eol_or_eof(): + break + + def _get_identifier(self): + token = self.tok.get() + if not token.is_identifier(): + raise dns.exception.SyntaxError + return token + + def _rr_line(self): + """Process one line from a DNS zone file.""" + token = None + # Name + if self.force_name is not None: + name = self.force_name + else: + if self.current_origin is None: + raise UnknownOrigin + token = self.tok.get(want_leading=True) + if not token.is_whitespace(): + self.last_name = self.tok.as_name(token, self.current_origin) + else: + token = self.tok.get() + if token.is_eol_or_eof(): + # treat leading WS followed by EOL/EOF as if they were EOL/EOF. + return + self.tok.unget(token) + name = self.last_name + if not name.is_subdomain(self.zone_origin): + self._eat_line() + return + if self.relativize: + name = name.relativize(self.zone_origin) + + # TTL + if self.force_ttl is not None: + ttl = self.force_ttl + self.last_ttl = ttl + self.last_ttl_known = True + else: + token = self._get_identifier() + ttl = None + try: + ttl = dns.ttl.from_text(token.value) + self.last_ttl = ttl + self.last_ttl_known = True + token = None + except dns.ttl.BadTTL: + if self.default_ttl_known: + ttl = self.default_ttl + elif self.last_ttl_known: + ttl = self.last_ttl + self.tok.unget(token) + + # Class + if self.force_rdclass is not None: + rdclass = self.force_rdclass + else: + token = self._get_identifier() + try: + rdclass = dns.rdataclass.from_text(token.value) + except dns.exception.SyntaxError: + raise + except Exception: + rdclass = self.zone_rdclass + self.tok.unget(token) + if rdclass != self.zone_rdclass: + raise dns.exception.SyntaxError("RR class is not zone's class") + + # Type + if self.force_rdtype is not None: + rdtype = self.force_rdtype + else: + token = self._get_identifier() + try: + rdtype = dns.rdatatype.from_text(token.value) + except Exception: + raise dns.exception.SyntaxError( + "unknown rdatatype '%s'" % token.value) + + try: + rd = dns.rdata.from_text(rdclass, rdtype, self.tok, + self.current_origin, self.relativize, + self.zone_origin) + except dns.exception.SyntaxError: + # Catch and reraise. + raise + except Exception: + # All exceptions that occur in the processing of rdata + # are treated as syntax errors. This is not strictly + # correct, but it is correct almost all of the time. + # We convert them to syntax errors so that we can emit + # helpful filename:line info. + (ty, va) = sys.exc_info()[:2] + raise dns.exception.SyntaxError( + "caught exception {}: {}".format(str(ty), str(va))) + + if not self.default_ttl_known and rdtype == dns.rdatatype.SOA: + # The pre-RFC2308 and pre-BIND9 behavior inherits the zone default + # TTL from the SOA minttl if no $TTL statement is present before the + # SOA is parsed. + self.default_ttl = rd.minimum + self.default_ttl_known = True + if ttl is None: + # if we didn't have a TTL on the SOA, set it! + ttl = rd.minimum + + # TTL check. We had to wait until now to do this as the SOA RR's + # own TTL can be inferred from its minimum. + if ttl is None: + raise dns.exception.SyntaxError("Missing default TTL value") + + self.txn.add(name, ttl, rd) + + def _parse_modify(self, side): + # Here we catch everything in '{' '}' in a group so we can replace it + # with ''. + is_generate1 = re.compile(r"^.*\$({(\+|-?)(\d+),(\d+),(.)}).*$") + is_generate2 = re.compile(r"^.*\$({(\+|-?)(\d+)}).*$") + is_generate3 = re.compile(r"^.*\$({(\+|-?)(\d+),(\d+)}).*$") + # Sometimes there are modifiers in the hostname. These come after + # the dollar sign. They are in the form: ${offset[,width[,base]]}. + # Make names + g1 = is_generate1.match(side) + if g1: + mod, sign, offset, width, base = g1.groups() + if sign == '': + sign = '+' + g2 = is_generate2.match(side) + if g2: + mod, sign, offset = g2.groups() + if sign == '': + sign = '+' + width = 0 + base = 'd' + g3 = is_generate3.match(side) + if g3: + mod, sign, offset, width = g3.groups() + if sign == '': + sign = '+' + base = 'd' + + if not (g1 or g2 or g3): + mod = '' + sign = '+' + offset = 0 + width = 0 + base = 'd' + + if base != 'd': + raise NotImplementedError() + + return mod, sign, offset, width, base + + def _generate_line(self): + # range lhs [ttl] [class] type rhs [ comment ] + """Process one line containing the GENERATE statement from a DNS + zone file.""" + if self.current_origin is None: + raise UnknownOrigin + + token = self.tok.get() + # Range (required) + try: + start, stop, step = dns.grange.from_text(token.value) + token = self.tok.get() + if not token.is_identifier(): + raise dns.exception.SyntaxError + except Exception: + raise dns.exception.SyntaxError + + # lhs (required) + try: + lhs = token.value + token = self.tok.get() + if not token.is_identifier(): + raise dns.exception.SyntaxError + except Exception: + raise dns.exception.SyntaxError + + # TTL + try: + ttl = dns.ttl.from_text(token.value) + self.last_ttl = ttl + self.last_ttl_known = True + token = self.tok.get() + if not token.is_identifier(): + raise dns.exception.SyntaxError + except dns.ttl.BadTTL: + if not (self.last_ttl_known or self.default_ttl_known): + raise dns.exception.SyntaxError("Missing default TTL value") + if self.default_ttl_known: + ttl = self.default_ttl + elif self.last_ttl_known: + ttl = self.last_ttl + # Class + try: + rdclass = dns.rdataclass.from_text(token.value) + token = self.tok.get() + if not token.is_identifier(): + raise dns.exception.SyntaxError + except dns.exception.SyntaxError: + raise dns.exception.SyntaxError + except Exception: + rdclass = self.zone_rdclass + if rdclass != self.zone_rdclass: + raise dns.exception.SyntaxError("RR class is not zone's class") + # Type + try: + rdtype = dns.rdatatype.from_text(token.value) + token = self.tok.get() + if not token.is_identifier(): + raise dns.exception.SyntaxError + except Exception: + raise dns.exception.SyntaxError("unknown rdatatype '%s'" % + token.value) + + # rhs (required) + rhs = token.value + + # The code currently only supports base 'd', so the last value + # in the tuple _parse_modify returns is ignored + lmod, lsign, loffset, lwidth, _ = self._parse_modify(lhs) + rmod, rsign, roffset, rwidth, _ = self._parse_modify(rhs) + for i in range(start, stop + 1, step): + # +1 because bind is inclusive and python is exclusive + + if lsign == '+': + lindex = i + int(loffset) + elif lsign == '-': + lindex = i - int(loffset) + + if rsign == '-': + rindex = i - int(roffset) + elif rsign == '+': + rindex = i + int(roffset) + + lzfindex = str(lindex).zfill(int(lwidth)) + rzfindex = str(rindex).zfill(int(rwidth)) + + name = lhs.replace('$%s' % (lmod), lzfindex) + rdata = rhs.replace('$%s' % (rmod), rzfindex) + + self.last_name = dns.name.from_text(name, self.current_origin, + self.tok.idna_codec) + name = self.last_name + if not name.is_subdomain(self.zone_origin): + self._eat_line() + return + if self.relativize: + name = name.relativize(self.zone_origin) + + try: + rd = dns.rdata.from_text(rdclass, rdtype, rdata, + self.current_origin, self.relativize, + self.zone_origin) + except dns.exception.SyntaxError: + # Catch and reraise. + raise + except Exception: + # All exceptions that occur in the processing of rdata + # are treated as syntax errors. This is not strictly + # correct, but it is correct almost all of the time. + # We convert them to syntax errors so that we can emit + # helpful filename:line info. + (ty, va) = sys.exc_info()[:2] + raise dns.exception.SyntaxError("caught exception %s: %s" % + (str(ty), str(va))) + + self.txn.add(name, ttl, rd) + + def read(self): + """Read a DNS zone file and build a zone object. + + @raises dns.zone.NoSOA: No SOA RR was found at the zone origin + @raises dns.zone.NoNS: No NS RRset was found at the zone origin + """ + + try: + while 1: + token = self.tok.get(True, True) + if token.is_eof(): + if self.current_file is not None: + self.current_file.close() + if len(self.saved_state) > 0: + (self.tok, + self.current_origin, + self.last_name, + self.current_file, + self.last_ttl, + self.last_ttl_known, + self.default_ttl, + self.default_ttl_known) = self.saved_state.pop(-1) + continue + break + elif token.is_eol(): + continue + elif token.is_comment(): + self.tok.get_eol() + continue + elif token.value[0] == '$' and self.allow_directives: + c = token.value.upper() + if c == '$TTL': + token = self.tok.get() + if not token.is_identifier(): + raise dns.exception.SyntaxError("bad $TTL") + self.default_ttl = dns.ttl.from_text(token.value) + self.default_ttl_known = True + self.tok.get_eol() + elif c == '$ORIGIN': + self.current_origin = self.tok.get_name() + self.tok.get_eol() + if self.zone_origin is None: + self.zone_origin = self.current_origin + self.txn._set_origin(self.current_origin) + elif c == '$INCLUDE' and self.allow_include: + token = self.tok.get() + filename = token.value + token = self.tok.get() + if token.is_identifier(): + new_origin =\ + dns.name.from_text(token.value, + self.current_origin, + self.tok.idna_codec) + self.tok.get_eol() + elif not token.is_eol_or_eof(): + raise dns.exception.SyntaxError( + "bad origin in $INCLUDE") + else: + new_origin = self.current_origin + self.saved_state.append((self.tok, + self.current_origin, + self.last_name, + self.current_file, + self.last_ttl, + self.last_ttl_known, + self.default_ttl, + self.default_ttl_known)) + self.current_file = open(filename, 'r') + self.tok = dns.tokenizer.Tokenizer(self.current_file, + filename) + self.current_origin = new_origin + elif c == '$GENERATE': + self._generate_line() + else: + raise dns.exception.SyntaxError( + "Unknown zone file directive '" + c + "'") + continue + self.tok.unget(token) + self._rr_line() + except dns.exception.SyntaxError as detail: + (filename, line_number) = self.tok.where() + if detail is None: + detail = "syntax error" + ex = dns.exception.SyntaxError( + "%s:%d: %s" % (filename, line_number, detail)) + tb = sys.exc_info()[2] + raise ex.with_traceback(tb) from None + + +class RRsetsReaderTransaction(dns.transaction.Transaction): + + def __init__(self, manager, replacement, read_only): + assert not read_only + super().__init__(manager, replacement, read_only) + self.rdatasets = {} + + def _get_rdataset(self, name, rdtype, covers): + return self.rdatasets.get((name, rdtype, covers)) + + def _get_node(self, name): + rdatasets = [] + for (rdataset_name, _, _), rdataset in self.rdatasets.items(): + if name == rdataset_name: + rdatasets.append(rdataset) + if len(rdatasets) == 0: + return None + node = dns.node.Node() + node.rdatasets = rdatasets + return node + + def _put_rdataset(self, name, rdataset): + self.rdatasets[(name, rdataset.rdtype, rdataset.covers)] = rdataset + + def _delete_name(self, name): + # First remove any changes involving the name + remove = [] + for key in self.rdatasets: + if key[0] == name: + remove.append(key) + if len(remove) > 0: + for key in remove: + del self.rdatasets[key] + + def _delete_rdataset(self, name, rdtype, covers): + try: + del self.rdatasets[(name, rdtype, covers)] + except KeyError: + pass + + def _name_exists(self, name): + for (n, _, _) in self.rdatasets: + if n == name: + return True + return False + + def _changed(self): + return len(self.rdatasets) > 0 + + def _end_transaction(self, commit): + if commit and self._changed(): + rrsets = [] + for (name, _, _), rdataset in self.rdatasets.items(): + rrset = dns.rrset.RRset(name, rdataset.rdclass, rdataset.rdtype, + rdataset.covers) + rrset.update(rdataset) + rrsets.append(rrset) + self.manager.set_rrsets(rrsets) + + def _set_origin(self, origin): + pass + + +class RRSetsReaderManager(dns.transaction.TransactionManager): + def __init__(self, origin=dns.name.root, relativize=False, + rdclass=dns.rdataclass.IN): + self.origin = origin + self.relativize = relativize + self.rdclass = rdclass + self.rrsets = [] + + def writer(self, replacement=False): + assert replacement is True + return RRsetsReaderTransaction(self, True, False) + + def get_class(self): + return self.rdclass + + def origin_information(self): + if self.relativize: + effective = dns.name.empty + else: + effective = self.origin + return (self.origin, self.relativize, effective) + + def set_rrsets(self, rrsets): + self.rrsets = rrsets + + +def read_rrsets(text, name=None, ttl=None, rdclass=dns.rdataclass.IN, + default_rdclass=dns.rdataclass.IN, + rdtype=None, default_ttl=None, idna_codec=None, + origin=dns.name.root, relativize=False): + """Read one or more rrsets from the specified text, possibly subject + to restrictions. + + *text*, a file object or a string, is the input to process. + + *name*, a string, ``dns.name.Name``, or ``None``, is the owner name of + the rrset. If not ``None``, then the owner name is "forced", and the + input must not specify an owner name. If ``None``, then any owner names + are allowed and must be present in the input. + + *ttl*, an ``int``, string, or None. If not ``None``, the the TTL is + forced to be the specified value and the input must not specify a TTL. + If ``None``, then a TTL may be specified in the input. If it is not + specified, then the *default_ttl* will be used. + + *rdclass*, a ``dns.rdataclass.RdataClass``, string, or ``None``. If + not ``None``, then the class is forced to the specified value, and the + input must not specify a class. If ``None``, then the input may specify + a class that matches *default_rdclass*. Note that it is not possible to + return rrsets with differing classes; specifying ``None`` for the class + simply allows the user to optionally type a class as that may be convenient + when cutting and pasting. + + *default_rdclass*, a ``dns.rdataclass.RdataClass`` or string. The class + of the returned rrsets. + + *rdtype*, a ``dns.rdatatype.RdataType``, string, or ``None``. If not + ``None``, then the type is forced to the specified value, and the + input must not specify a type. If ``None``, then a type must be present + for each RR. + + *default_ttl*, an ``int``, string, or ``None``. If not ``None``, then if + the TTL is not forced and is not specified, then this value will be used. + if ``None``, then if the TTL is not forced an error will occur if the TTL + is not specified. + + *idna_codec*, a ``dns.name.IDNACodec``, specifies the IDNA + encoder/decoder. If ``None``, the default IDNA 2003 encoder/decoder + is used. Note that codecs only apply to the owner name; dnspython does + not do IDNA for names in rdata, as there is no IDNA zonefile format. + + *origin*, a string, ``dns.name.Name``, or ``None``, is the origin for any + relative names in the input, and also the origin to relativize to if + *relativize* is ``True``. + + *relativize*, a bool. If ``True``, names are relativized to the *origin*; + if ``False`` then any relative names in the input are made absolute by + appending the *origin*. + """ + if isinstance(origin, str): + origin = dns.name.from_text(origin, dns.name.root, idna_codec) + if isinstance(name, str): + name = dns.name.from_text(name, origin, idna_codec) + if isinstance(ttl, str): + ttl = dns.ttl.from_text(ttl) + if isinstance(default_ttl, str): + default_ttl = dns.ttl.from_text(default_ttl) + if rdclass is not None: + rdclass = dns.rdataclass.RdataClass.make(rdclass) + default_rdclass = dns.rdataclass.RdataClass.make(default_rdclass) + if rdtype is not None: + rdtype = dns.rdatatype.RdataType.make(rdtype) + manager = RRSetsReaderManager(origin, relativize, default_rdclass) + with manager.writer(True) as txn: + tok = dns.tokenizer.Tokenizer(text, '', idna_codec=idna_codec) + reader = Reader(tok, default_rdclass, txn, allow_directives=False, + force_name=name, force_ttl=ttl, force_rdclass=rdclass, + force_rdtype=rdtype, default_ttl=default_ttl) + reader.read() + return manager.rrsets diff --git a/requirements.txt b/requirements.txt index 4a3e3518..2425b598 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,7 +11,7 @@ cheroot==8.6.0 cherrypy==18.6.1 cloudinary==1.28.1 distro==1.6.0 -dnspython==2.0.0 +dnspython==2.2.0 facebook-sdk==3.1.0 future==0.18.2 gntp==1.0.3