From 4d62245cf54ef070506a68ae307d8b9fd07454a6 Mon Sep 17 00:00:00 2001 From: JonnyWong16 <9099342+JonnyWong16@users.noreply.github.com> Date: Thu, 14 Oct 2021 21:36:41 -0700 Subject: [PATCH] Update dnspython-2.2.0 --- lib/dns/__init__.py | 11 +- lib/dns/_asyncbackend.py | 60 + lib/dns/_asyncio_backend.py | 140 +++ lib/dns/_compat.py | 21 - lib/dns/_curio_backend.py | 106 ++ lib/dns/_trio_backend.py | 119 ++ lib/dns/asyncbackend.py | 96 ++ lib/dns/asyncquery.py | 500 ++++++++ lib/dns/asyncresolver.py | 257 ++++ lib/dns/dnssec.py | 586 +++++---- lib/dns/e164.py | 78 +- lib/dns/edns.py | 296 ++++- lib/dns/entropy.py | 106 +- lib/dns/enum.py | 90 ++ lib/dns/exception.py | 60 +- lib/dns/flags.py | 104 +- lib/dns/grange.py | 24 +- lib/dns/inet.py | 141 ++- lib/dns/ipv4.py | 39 +- lib/dns/ipv6.py | 89 +- lib/dns/message.py | 1480 +++++++++++++---------- lib/dns/name.py | 740 ++++++++---- lib/dns/namedict.py | 40 +- lib/dns/node.py | 125 +- lib/dns/opcode.py | 94 +- lib/dns/py.typed | 0 lib/dns/query.py | 1156 ++++++++++++------ lib/dns/rcode.py | 148 ++- lib/dns/rdata.py | 512 +++++--- lib/dns/rdataclass.py | 130 +- lib/dns/rdataset.py | 193 +-- lib/dns/rdatatype.py | 378 +++--- lib/dns/rdtypes/ANY/AFSDB.py | 27 +- lib/dns/rdtypes/ANY/AMTRELAY.py | 79 ++ lib/dns/{hash.py => rdtypes/ANY/AVC.py} | 21 +- lib/dns/rdtypes/ANY/CAA.py | 35 +- lib/dns/rdtypes/ANY/CDNSKEY.py | 7 +- lib/dns/rdtypes/ANY/CDS.py | 2 + lib/dns/rdtypes/ANY/CERT.py | 51 +- lib/dns/rdtypes/ANY/CNAME.py | 2 + lib/dns/rdtypes/ANY/CSYNC.py | 105 +- lib/dns/rdtypes/ANY/DLV.py | 2 + lib/dns/rdtypes/ANY/DNAME.py | 6 +- lib/dns/rdtypes/ANY/DNSKEY.py | 7 +- lib/dns/rdtypes/ANY/DS.py | 2 + lib/dns/rdtypes/ANY/EUI48.py | 8 +- lib/dns/rdtypes/ANY/EUI64.py | 8 +- lib/dns/rdtypes/ANY/GPOS.py | 113 +- lib/dns/rdtypes/ANY/HINFO.py | 57 +- lib/dns/rdtypes/ANY/HIP.py | 71 +- lib/dns/rdtypes/ANY/ISDN.py | 57 +- lib/dns/rdtypes/ANY/LOC.py | 137 ++- lib/dns/rdtypes/ANY/MX.py | 2 + lib/dns/rdtypes/ANY/NINFO.py | 25 + lib/dns/rdtypes/ANY/NS.py | 2 + lib/dns/rdtypes/ANY/NSEC.py | 112 +- lib/dns/rdtypes/ANY/NSEC3.py | 157 +-- lib/dns/rdtypes/ANY/NSEC3PARAM.py | 48 +- lib/dns/rdtypes/ANY/OPENPGPKEY.py | 50 + lib/dns/rdtypes/ANY/OPT.py | 67 + lib/dns/rdtypes/ANY/PTR.py | 2 + lib/dns/rdtypes/ANY/RP.py | 59 +- lib/dns/rdtypes/ANY/RRSIG.py | 86 +- lib/dns/rdtypes/ANY/RT.py | 2 + lib/dns/rdtypes/ANY/SOA.py | 83 +- lib/dns/rdtypes/ANY/SPF.py | 6 +- lib/dns/rdtypes/ANY/SSHFP.py | 44 +- lib/dns/rdtypes/ANY/TLSA.py | 48 +- lib/dns/rdtypes/ANY/TSIG.py | 91 ++ lib/dns/rdtypes/ANY/TXT.py | 2 + lib/dns/rdtypes/ANY/URI.py | 46 +- lib/dns/rdtypes/ANY/X25.py | 32 +- lib/dns/rdtypes/ANY/__init__.py | 11 +- lib/dns/rdtypes/CH/A.py | 56 + lib/dns/rdtypes/CH/__init__.py | 22 + lib/dns/rdtypes/IN/A.py | 21 +- lib/dns/rdtypes/IN/AAAA.py | 28 +- lib/dns/rdtypes/IN/APL.py | 83 +- lib/dns/rdtypes/IN/DHCID.py | 33 +- lib/dns/rdtypes/IN/IPSECKEY.py | 126 +- lib/dns/rdtypes/IN/KX.py | 4 +- lib/dns/rdtypes/IN/NAPTR.py | 73 +- lib/dns/rdtypes/IN/NSAP.py | 22 +- lib/dns/rdtypes/IN/NSAP_PTR.py | 2 + lib/dns/rdtypes/IN/PX.py | 61 +- lib/dns/rdtypes/IN/SRV.py | 51 +- lib/dns/rdtypes/IN/WKS.py | 46 +- lib/dns/rdtypes/IN/__init__.py | 3 + lib/dns/rdtypes/__init__.py | 4 + lib/dns/rdtypes/dnskeybase.py | 106 +- lib/dns/rdtypes/dsbase.py | 51 +- lib/dns/rdtypes/euibase.py | 23 +- lib/dns/rdtypes/mxbase.py | 59 +- lib/dns/rdtypes/nsbase.py | 47 +- lib/dns/rdtypes/txtbase.py | 63 +- lib/dns/rdtypes/util.py | 166 +++ lib/dns/renderer.py | 272 ++--- lib/dns/resolver.py | 1451 ++++++++++++---------- lib/dns/reversename.py | 99 +- lib/dns/rrset.py | 97 +- lib/dns/serial.py | 117 ++ lib/dns/set.py | 147 +-- lib/dns/tokenizer.py | 372 +++--- lib/dns/tsig.py | 223 ++-- lib/dns/tsigkeyring.py | 37 +- lib/dns/ttl.py | 33 +- lib/dns/update.py | 211 ++-- lib/dns/version.py | 26 +- lib/dns/wire.py | 82 ++ lib/dns/wiredata.py | 84 -- lib/dns/zone.py | 862 +++++++------ 111 files changed, 9077 insertions(+), 5877 deletions(-) create mode 100644 lib/dns/_asyncbackend.py create mode 100644 lib/dns/_asyncio_backend.py delete mode 100644 lib/dns/_compat.py create mode 100644 lib/dns/_curio_backend.py create mode 100644 lib/dns/_trio_backend.py create mode 100644 lib/dns/asyncbackend.py create mode 100644 lib/dns/asyncquery.py create mode 100644 lib/dns/asyncresolver.py create mode 100644 lib/dns/enum.py create mode 100644 lib/dns/py.typed create mode 100644 lib/dns/rdtypes/ANY/AMTRELAY.py rename lib/dns/{hash.py => rdtypes/ANY/AVC.py} (66%) create mode 100644 lib/dns/rdtypes/ANY/NINFO.py create mode 100644 lib/dns/rdtypes/ANY/OPENPGPKEY.py create mode 100644 lib/dns/rdtypes/ANY/OPT.py create mode 100644 lib/dns/rdtypes/ANY/TSIG.py create mode 100644 lib/dns/rdtypes/CH/A.py create mode 100644 lib/dns/rdtypes/CH/__init__.py create mode 100644 lib/dns/rdtypes/util.py create mode 100644 lib/dns/serial.py create mode 100644 lib/dns/wire.py delete mode 100644 lib/dns/wiredata.py diff --git a/lib/dns/__init__.py b/lib/dns/__init__.py index c848e485..b944701d 100644 --- a/lib/dns/__init__.py +++ b/lib/dns/__init__.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2003-2007, 2009, 2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its @@ -16,13 +18,15 @@ """dnspython DNS toolkit""" __all__ = [ + 'asyncbackend', + 'asyncquery', + 'asyncresolver', 'dnssec', 'e164', 'edns', 'entropy', 'exception', 'flags', - 'hash', 'inet', 'ipv4', 'ipv6', @@ -41,6 +45,7 @@ __all__ = [ 'resolver', 'reversename', 'rrset', + 'serial', 'set', 'tokenizer', 'tsig', @@ -49,6 +54,8 @@ __all__ = [ 'rdtypes', 'update', 'version', - 'wiredata', + 'wire', 'zone', ] + +from dns.version import version as __version__ # noqa diff --git a/lib/dns/_asyncbackend.py b/lib/dns/_asyncbackend.py new file mode 100644 index 00000000..c7ecfada --- /dev/null +++ b/lib/dns/_asyncbackend.py @@ -0,0 +1,60 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# This is a nullcontext for both sync and async. 3.7 has a nullcontext, +# but it is only for sync use. + +class NullContext: + def __init__(self, enter_result=None): + self.enter_result = enter_result + + def __enter__(self): + return self.enter_result + + def __exit__(self, exc_type, exc_value, traceback): + pass + + async def __aenter__(self): + return self.enter_result + + async def __aexit__(self, exc_type, exc_value, traceback): + pass + + +# These are declared here so backends can import them without creating +# circular dependencies with dns.asyncbackend. + +class Socket: # pragma: no cover + async def close(self): + pass + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_value, traceback): + await self.close() + + +class DatagramSocket(Socket): # pragma: no cover + async def sendto(self, what, destination, timeout): + pass + + async def recvfrom(self, size, timeout): + pass + + +class StreamSocket(Socket): # pragma: no cover + async def sendall(self, what, destination, timeout): + pass + + async def recv(self, size, timeout): + pass + + +class Backend: # pragma: no cover + def name(self): + return 'unknown' + + async def make_socket(self, af, socktype, proto=0, + source=None, destination=None, timeout=None, + ssl_context=None, server_hostname=None): + raise NotImplementedError diff --git a/lib/dns/_asyncio_backend.py b/lib/dns/_asyncio_backend.py new file mode 100644 index 00000000..3af34ff8 --- /dev/null +++ b/lib/dns/_asyncio_backend.py @@ -0,0 +1,140 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +"""asyncio library query support""" + +import socket +import asyncio + +import dns._asyncbackend +import dns.exception + + +def _get_running_loop(): + try: + return asyncio.get_running_loop() + except AttributeError: # pragma: no cover + return asyncio.get_event_loop() + + +class _DatagramProtocol: + def __init__(self): + self.transport = None + self.recvfrom = None + + def connection_made(self, transport): + self.transport = transport + + def datagram_received(self, data, addr): + if self.recvfrom: + self.recvfrom.set_result((data, addr)) + self.recvfrom = None + + def error_received(self, exc): # pragma: no cover + if self.recvfrom: + self.recvfrom.set_exception(exc) + + def connection_lost(self, exc): + if self.recvfrom: + self.recvfrom.set_exception(exc) + + def close(self): + self.transport.close() + + +async def _maybe_wait_for(awaitable, timeout): + if timeout: + try: + return await asyncio.wait_for(awaitable, timeout) + except asyncio.TimeoutError: + raise dns.exception.Timeout(timeout=timeout) + else: + return await awaitable + + +class DatagramSocket(dns._asyncbackend.DatagramSocket): + def __init__(self, family, transport, protocol): + self.family = family + self.transport = transport + self.protocol = protocol + + async def sendto(self, what, destination, timeout): # pragma: no cover + # no timeout for asyncio sendto + self.transport.sendto(what, destination) + + async def recvfrom(self, size, timeout): + # ignore size as there's no way I know to tell protocol about it + done = _get_running_loop().create_future() + assert self.protocol.recvfrom is None + self.protocol.recvfrom = done + await _maybe_wait_for(done, timeout) + return done.result() + + async def close(self): + self.protocol.close() + + async def getpeername(self): + return self.transport.get_extra_info('peername') + + async def getsockname(self): + return self.transport.get_extra_info('sockname') + + +class StreamSocket(dns._asyncbackend.DatagramSocket): + def __init__(self, af, reader, writer): + self.family = af + self.reader = reader + self.writer = writer + + async def sendall(self, what, timeout): + self.writer.write(what), + return await _maybe_wait_for(self.writer.drain(), timeout) + raise dns.exception.Timeout(timeout=timeout) + + async def recv(self, count, timeout): + return await _maybe_wait_for(self.reader.read(count), + timeout) + raise dns.exception.Timeout(timeout=timeout) + + async def close(self): + self.writer.close() + try: + await self.writer.wait_closed() + except AttributeError: # pragma: no cover + pass + + async def getpeername(self): + return self.writer.get_extra_info('peername') + + async def getsockname(self): + return self.writer.get_extra_info('sockname') + + +class Backend(dns._asyncbackend.Backend): + def name(self): + return 'asyncio' + + async def make_socket(self, af, socktype, proto=0, + source=None, destination=None, timeout=None, + ssl_context=None, server_hostname=None): + loop = _get_running_loop() + if socktype == socket.SOCK_DGRAM: + transport, protocol = await loop.create_datagram_endpoint( + _DatagramProtocol, source, family=af, + proto=proto) + return DatagramSocket(af, transport, protocol) + elif socktype == socket.SOCK_STREAM: + (r, w) = await _maybe_wait_for( + asyncio.open_connection(destination[0], + destination[1], + ssl=ssl_context, + family=af, + proto=proto, + local_addr=source, + server_hostname=server_hostname), + timeout) + return StreamSocket(af, r, w) + raise NotImplementedError('unsupported socket ' + + f'type {socktype}') # pragma: no cover + + async def sleep(self, interval): + await asyncio.sleep(interval) diff --git a/lib/dns/_compat.py b/lib/dns/_compat.py deleted file mode 100644 index cffe4bb9..00000000 --- a/lib/dns/_compat.py +++ /dev/null @@ -1,21 +0,0 @@ -import sys - - -if sys.version_info > (3,): - long = int - xrange = range -else: - long = long - xrange = xrange - -# unicode / binary types -if sys.version_info > (3,): - text_type = str - binary_type = bytes - string_types = (str,) - unichr = chr -else: - text_type = unicode - binary_type = str - string_types = (basestring,) - unichr = unichr diff --git a/lib/dns/_curio_backend.py b/lib/dns/_curio_backend.py new file mode 100644 index 00000000..300e1b89 --- /dev/null +++ b/lib/dns/_curio_backend.py @@ -0,0 +1,106 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +"""curio async I/O library query support""" + +import socket +import curio +import curio.socket # type: ignore + +import dns._asyncbackend +import dns.exception +import dns.inet + + +def _maybe_timeout(timeout): + if timeout: + return curio.ignore_after(timeout) + else: + return dns._asyncbackend.NullContext() + + +# for brevity +_lltuple = dns.inet.low_level_address_tuple + + +class DatagramSocket(dns._asyncbackend.DatagramSocket): + def __init__(self, socket): + self.socket = socket + self.family = socket.family + + async def sendto(self, what, destination, timeout): + async with _maybe_timeout(timeout): + return await self.socket.sendto(what, destination) + raise dns.exception.Timeout(timeout=timeout) # pragma: no cover + + async def recvfrom(self, size, timeout): + async with _maybe_timeout(timeout): + return await self.socket.recvfrom(size) + raise dns.exception.Timeout(timeout=timeout) + + async def close(self): + await self.socket.close() + + async def getpeername(self): + return self.socket.getpeername() + + async def getsockname(self): + return self.socket.getsockname() + + +class StreamSocket(dns._asyncbackend.DatagramSocket): + def __init__(self, socket): + self.socket = socket + self.family = socket.family + + async def sendall(self, what, timeout): + async with _maybe_timeout(timeout): + return await self.socket.sendall(what) + raise dns.exception.Timeout(timeout=timeout) + + async def recv(self, size, timeout): + async with _maybe_timeout(timeout): + return await self.socket.recv(size) + raise dns.exception.Timeout(timeout=timeout) + + async def close(self): + await self.socket.close() + + async def getpeername(self): + return self.socket.getpeername() + + async def getsockname(self): + return self.socket.getsockname() + + +class Backend(dns._asyncbackend.Backend): + def name(self): + return 'curio' + + async def make_socket(self, af, socktype, proto=0, + source=None, destination=None, timeout=None, + ssl_context=None, server_hostname=None): + if socktype == socket.SOCK_DGRAM: + s = curio.socket.socket(af, socktype, proto) + try: + if source: + s.bind(_lltuple(source, af)) + except Exception: # pragma: no cover + await s.close() + raise + return DatagramSocket(s) + elif socktype == socket.SOCK_STREAM: + if source: + source_addr = _lltuple(source, af) + else: + source_addr = None + async with _maybe_timeout(timeout): + s = await curio.open_connection(destination[0], destination[1], + ssl=ssl_context, + source_addr=source_addr, + server_hostname=server_hostname) + return StreamSocket(s) + raise NotImplementedError('unsupported socket ' + + f'type {socktype}') # pragma: no cover + + async def sleep(self, interval): + await curio.sleep(interval) diff --git a/lib/dns/_trio_backend.py b/lib/dns/_trio_backend.py new file mode 100644 index 00000000..92ea8796 --- /dev/null +++ b/lib/dns/_trio_backend.py @@ -0,0 +1,119 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +"""trio async I/O library query support""" + +import socket +import trio +import trio.socket # type: ignore + +import dns._asyncbackend +import dns.exception +import dns.inet + + +def _maybe_timeout(timeout): + if timeout: + return trio.move_on_after(timeout) + else: + return dns._asyncbackend.NullContext() + + +# for brevity +_lltuple = dns.inet.low_level_address_tuple + + +class DatagramSocket(dns._asyncbackend.DatagramSocket): + def __init__(self, socket): + self.socket = socket + self.family = socket.family + + async def sendto(self, what, destination, timeout): + with _maybe_timeout(timeout): + return await self.socket.sendto(what, destination) + raise dns.exception.Timeout(timeout=timeout) # pragma: no cover + + async def recvfrom(self, size, timeout): + with _maybe_timeout(timeout): + return await self.socket.recvfrom(size) + raise dns.exception.Timeout(timeout=timeout) + + async def close(self): + self.socket.close() + + async def getpeername(self): + return self.socket.getpeername() + + async def getsockname(self): + return self.socket.getsockname() + + +class StreamSocket(dns._asyncbackend.DatagramSocket): + def __init__(self, family, stream, tls=False): + self.family = family + self.stream = stream + self.tls = tls + + async def sendall(self, what, timeout): + with _maybe_timeout(timeout): + return await self.stream.send_all(what) + raise dns.exception.Timeout(timeout=timeout) + + async def recv(self, size, timeout): + with _maybe_timeout(timeout): + return await self.stream.receive_some(size) + raise dns.exception.Timeout(timeout=timeout) + + async def close(self): + await self.stream.aclose() + + async def getpeername(self): + if self.tls: + return self.stream.transport_stream.socket.getpeername() + else: + return self.stream.socket.getpeername() + + async def getsockname(self): + if self.tls: + return self.stream.transport_stream.socket.getsockname() + else: + return self.stream.socket.getsockname() + + +class Backend(dns._asyncbackend.Backend): + def name(self): + return 'trio' + + async def make_socket(self, af, socktype, proto=0, source=None, + destination=None, timeout=None, + ssl_context=None, server_hostname=None): + s = trio.socket.socket(af, socktype, proto) + stream = None + try: + if source: + await s.bind(_lltuple(source, af)) + if socktype == socket.SOCK_STREAM: + with _maybe_timeout(timeout): + await s.connect(_lltuple(destination, af)) + except Exception: # pragma: no cover + s.close() + raise + if socktype == socket.SOCK_DGRAM: + return DatagramSocket(s) + elif socktype == socket.SOCK_STREAM: + stream = trio.SocketStream(s) + s = None + tls = False + if ssl_context: + tls = True + try: + stream = trio.SSLStream(stream, ssl_context, + server_hostname=server_hostname) + except Exception: # pragma: no cover + await stream.aclose() + raise + return StreamSocket(af, stream, tls) + raise NotImplementedError('unsupported socket ' + + f'type {socktype}') # pragma: no cover + + async def sleep(self, interval): + await trio.sleep(interval) diff --git a/lib/dns/asyncbackend.py b/lib/dns/asyncbackend.py new file mode 100644 index 00000000..9582a6f8 --- /dev/null +++ b/lib/dns/asyncbackend.py @@ -0,0 +1,96 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +import dns.exception + +from dns._asyncbackend import Socket, DatagramSocket, \ + StreamSocket, Backend # noqa: + + +_default_backend = None + +_backends = {} + +# Allow sniffio import to be disabled for testing purposes +_no_sniffio = False + +class AsyncLibraryNotFoundError(dns.exception.DNSException): + pass + + +def get_backend(name): + """Get the specified asychronous backend. + + *name*, a ``str``, the name of the backend. Currently the "trio", + "curio", and "asyncio" backends are available. + + Raises NotImplementError if an unknown backend name is specified. + """ + backend = _backends.get(name) + if backend: + return backend + if name == 'trio': + import dns._trio_backend + backend = dns._trio_backend.Backend() + elif name == 'curio': + import dns._curio_backend + backend = dns._curio_backend.Backend() + elif name == 'asyncio': + import dns._asyncio_backend + backend = dns._asyncio_backend.Backend() + else: + raise NotImplementedError(f'unimplemented async backend {name}') + _backends[name] = backend + return backend + + +def sniff(): + """Attempt to determine the in-use asynchronous I/O library by using + the ``sniffio`` module if it is available. + + Returns the name of the library, or raises AsyncLibraryNotFoundError + if the library cannot be determined. + """ + try: + if _no_sniffio: + raise ImportError + import sniffio + try: + return sniffio.current_async_library() + except sniffio.AsyncLibraryNotFoundError: + raise AsyncLibraryNotFoundError('sniffio cannot determine ' + + 'async library') + except ImportError: + import asyncio + try: + asyncio.get_running_loop() + return 'asyncio' + except RuntimeError: + raise AsyncLibraryNotFoundError('no async library detected') + except AttributeError: # pragma: no cover + # we have to check current_task on 3.6 + if not asyncio.Task.current_task(): + raise AsyncLibraryNotFoundError('no async library detected') + return 'asyncio' + + +def get_default_backend(): + """Get the default backend, initializing it if necessary. + """ + if _default_backend: + return _default_backend + + return set_default_backend(sniff()) + + +def set_default_backend(name): + """Set the default backend. + + It's not normally necessary to call this method, as + ``get_default_backend()`` will initialize the backend + appropriately in many cases. If ``sniffio`` is not installed, or + in testing situations, this function allows the backend to be set + explicitly. + """ + global _default_backend + _default_backend = get_backend(name) + return _default_backend diff --git a/lib/dns/asyncquery.py b/lib/dns/asyncquery.py new file mode 100644 index 00000000..b7926480 --- /dev/null +++ b/lib/dns/asyncquery.py @@ -0,0 +1,500 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2017 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +"""Talk to a DNS server.""" + +import socket +import struct +import time + +import dns.asyncbackend +import dns.exception +import dns.inet +import dns.name +import dns.message +import dns.rcode +import dns.rdataclass +import dns.rdatatype + +from dns.query import _compute_times, _matches_destination, BadResponse, ssl + + +# for brevity +_lltuple = dns.inet.low_level_address_tuple + + +def _source_tuple(af, address, port): + # Make a high level source tuple, or return None if address and port + # are both None + if address or port: + if address is None: + if af == socket.AF_INET: + address = '0.0.0.0' + elif af == socket.AF_INET6: + address = '::' + else: + raise NotImplementedError(f'unknown address family {af}') + return (address, port) + else: + return None + + +def _timeout(expiration, now=None): + if expiration: + if not now: + now = time.time() + return max(expiration - now, 0) + else: + return None + + +async def send_udp(sock, what, destination, expiration=None): + """Send a DNS message to the specified UDP socket. + + *sock*, a ``dns.asyncbackend.DatagramSocket``. + + *what*, a ``bytes`` or ``dns.message.Message``, the message to send. + + *destination*, a destination tuple appropriate for the address family + of the socket, specifying where to send the query. + + *expiration*, a ``float`` or ``None``, the absolute time at which + a timeout exception should be raised. If ``None``, no timeout will + occur. + + Returns an ``(int, float)`` tuple of bytes sent and the sent time. + """ + + if isinstance(what, dns.message.Message): + what = what.to_wire() + sent_time = time.time() + n = await sock.sendto(what, destination, _timeout(expiration, sent_time)) + return (n, sent_time) + + +async def receive_udp(sock, destination=None, expiration=None, + ignore_unexpected=False, one_rr_per_rrset=False, + keyring=None, request_mac=b'', ignore_trailing=False, + raise_on_truncation=False): + """Read a DNS message from a UDP socket. + + *sock*, a ``dns.asyncbackend.DatagramSocket``. + + *destination*, a destination tuple appropriate for the address family + of the socket, specifying where the message is expected to arrive from. + When receiving a response, this would be where the associated query was + sent. + + *expiration*, a ``float`` or ``None``, the absolute time at which + a timeout exception should be raised. If ``None``, no timeout will + occur. + + *ignore_unexpected*, a ``bool``. If ``True``, ignore responses from + unexpected sources. + + *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own + RRset. + + *keyring*, a ``dict``, the keyring to use for TSIG. + + *request_mac*, a ``bytes``, the MAC of the request (for TSIG). + + *ignore_trailing*, a ``bool``. If ``True``, ignore trailing + junk at end of the received message. + + *raise_on_truncation*, a ``bool``. If ``True``, raise an exception if + the TC bit is set. + + Raises if the message is malformed, if network errors occur, of if + there is a timeout. + + Returns a ``(dns.message.Message, float, tuple)`` tuple of the received + message, the received time, and the address where the message arrived from. + """ + + wire = b'' + while 1: + (wire, from_address) = await sock.recvfrom(65535, _timeout(expiration)) + if _matches_destination(sock.family, from_address, destination, + ignore_unexpected): + break + received_time = time.time() + r = dns.message.from_wire(wire, keyring=keyring, request_mac=request_mac, + one_rr_per_rrset=one_rr_per_rrset, + ignore_trailing=ignore_trailing, + raise_on_truncation=raise_on_truncation) + return (r, received_time, from_address) + +async def udp(q, where, timeout=None, port=53, source=None, source_port=0, + ignore_unexpected=False, one_rr_per_rrset=False, + ignore_trailing=False, raise_on_truncation=False, sock=None, + backend=None): + """Return the response obtained after sending a query via UDP. + + *q*, a ``dns.message.Message``, the query to send + + *where*, a ``str`` containing an IPv4 or IPv6 address, where + to send the message. + + *timeout*, a ``float`` or ``None``, the number of seconds to wait before the + query times out. If ``None``, the default, wait forever. + + *port*, an ``int``, the port send the message to. The default is 53. + + *source*, a ``str`` containing an IPv4 or IPv6 address, specifying + the source address. The default is the wildcard address. + + *source_port*, an ``int``, the port from which to send the message. + The default is 0. + + *ignore_unexpected*, a ``bool``. If ``True``, ignore responses from + unexpected sources. + + *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own + RRset. + + *ignore_trailing*, a ``bool``. If ``True``, ignore trailing + junk at end of the received message. + + *raise_on_truncation*, a ``bool``. If ``True``, raise an exception if + the TC bit is set. + + *sock*, a ``dns.asyncbackend.DatagramSocket``, or ``None``, + the socket to use for the query. If ``None``, the default, a + socket is created. Note that if a socket is provided, the + *source*, *source_port*, and *backend* are ignored. + + *backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``, + the default, then dnspython will use the default backend. + + Returns a ``dns.message.Message``. + """ + wire = q.to_wire() + (begin_time, expiration) = _compute_times(timeout) + s = None + # After 3.6 is no longer supported, this can use an AsyncExitStack. + try: + af = dns.inet.af_for_address(where) + destination = _lltuple((where, port), af) + if sock: + s = sock + else: + if not backend: + backend = dns.asyncbackend.get_default_backend() + stuple = _source_tuple(af, source, source_port) + s = await backend.make_socket(af, socket.SOCK_DGRAM, 0, stuple) + await send_udp(s, wire, destination, expiration) + (r, received_time, _) = await receive_udp(s, destination, expiration, + ignore_unexpected, + one_rr_per_rrset, + q.keyring, q.mac, + ignore_trailing, + raise_on_truncation) + r.time = received_time - begin_time + if not q.is_response(r): + raise BadResponse + return r + finally: + if not sock and s: + await s.close() + +async def udp_with_fallback(q, where, timeout=None, port=53, source=None, + source_port=0, ignore_unexpected=False, + one_rr_per_rrset=False, ignore_trailing=False, + udp_sock=None, tcp_sock=None, backend=None): + """Return the response to the query, trying UDP first and falling back + to TCP if UDP results in a truncated response. + + *q*, a ``dns.message.Message``, the query to send + + *where*, a ``str`` containing an IPv4 or IPv6 address, where + to send the message. + + *timeout*, a ``float`` or ``None``, the number of seconds to wait before the + query times out. If ``None``, the default, wait forever. + + *port*, an ``int``, the port send the message to. The default is 53. + + *source*, a ``str`` containing an IPv4 or IPv6 address, specifying + the source address. The default is the wildcard address. + + *source_port*, an ``int``, the port from which to send the message. + The default is 0. + + *ignore_unexpected*, a ``bool``. If ``True``, ignore responses from + unexpected sources. + + *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own + RRset. + + *ignore_trailing*, a ``bool``. If ``True``, ignore trailing + junk at end of the received message. + + *udp_sock*, a ``dns.asyncbackend.DatagramSocket``, or ``None``, + the socket to use for the UDP query. If ``None``, the default, a + socket is created. Note that if a socket is provided the *source*, + *source_port*, and *backend* are ignored for the UDP query. + + *tcp_sock*, a ``dns.asyncbackend.StreamSocket``, or ``None``, the + socket to use for the TCP query. If ``None``, the default, a + socket is created. Note that if a socket is provided *where*, + *source*, *source_port*, and *backend* are ignored for the TCP query. + + *backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``, + the default, then dnspython will use the default backend. + + Returns a (``dns.message.Message``, tcp) tuple where tcp is ``True`` + if and only if TCP was used. + """ + try: + response = await udp(q, where, timeout, port, source, source_port, + ignore_unexpected, one_rr_per_rrset, + ignore_trailing, True, udp_sock, backend) + return (response, False) + except dns.message.Truncated: + response = await tcp(q, where, timeout, port, source, source_port, + one_rr_per_rrset, ignore_trailing, tcp_sock, + backend) + return (response, True) + + +async def send_tcp(sock, what, expiration=None): + """Send a DNS message to the specified TCP socket. + + *sock*, a ``socket``. + + *what*, a ``bytes`` or ``dns.message.Message``, the message to send. + + *expiration*, a ``float`` or ``None``, the absolute time at which + a timeout exception should be raised. If ``None``, no timeout will + occur. + + Returns an ``(int, float)`` tuple of bytes sent and the sent time. + """ + + if isinstance(what, dns.message.Message): + what = what.to_wire() + l = len(what) + # copying the wire into tcpmsg is inefficient, but lets us + # avoid writev() or doing a short write that would get pushed + # onto the net + tcpmsg = struct.pack("!H", l) + what + sent_time = time.time() + await sock.sendall(tcpmsg, expiration) + return (len(tcpmsg), sent_time) + + +async def _read_exactly(sock, count, expiration): + """Read the specified number of bytes from stream. Keep trying until we + either get the desired amount, or we hit EOF. + """ + s = b'' + while count > 0: + n = await sock.recv(count, _timeout(expiration)) + if n == b'': + raise EOFError + count = count - len(n) + s = s + n + return s + + +async def receive_tcp(sock, expiration=None, one_rr_per_rrset=False, + keyring=None, request_mac=b'', ignore_trailing=False): + """Read a DNS message from a TCP socket. + + *sock*, a ``socket``. + + *expiration*, a ``float`` or ``None``, the absolute time at which + a timeout exception should be raised. If ``None``, no timeout will + occur. + + *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own + RRset. + + *keyring*, a ``dict``, the keyring to use for TSIG. + + *request_mac*, a ``bytes``, the MAC of the request (for TSIG). + + *ignore_trailing*, a ``bool``. If ``True``, ignore trailing + junk at end of the received message. + + Raises if the message is malformed, if network errors occur, of if + there is a timeout. + + Returns a ``(dns.message.Message, float)`` tuple of the received message + and the received time. + """ + + ldata = await _read_exactly(sock, 2, expiration) + (l,) = struct.unpack("!H", ldata) + wire = await _read_exactly(sock, l, expiration) + received_time = time.time() + r = dns.message.from_wire(wire, keyring=keyring, request_mac=request_mac, + one_rr_per_rrset=one_rr_per_rrset, + ignore_trailing=ignore_trailing) + return (r, received_time) + + +async def tcp(q, where, timeout=None, port=53, source=None, source_port=0, + one_rr_per_rrset=False, ignore_trailing=False, sock=None, + backend=None): + """Return the response obtained after sending a query via TCP. + + *q*, a ``dns.message.Message``, the query to send + + *where*, a ``str`` containing an IPv4 or IPv6 address, where + to send the message. + + *timeout*, a ``float`` or ``None``, the number of seconds to wait before the + query times out. If ``None``, the default, wait forever. + + *port*, an ``int``, the port send the message to. The default is 53. + + *source*, a ``str`` containing an IPv4 or IPv6 address, specifying + the source address. The default is the wildcard address. + + *source_port*, an ``int``, the port from which to send the message. + The default is 0. + + *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own + RRset. + + *ignore_trailing*, a ``bool``. If ``True``, ignore trailing + junk at end of the received message. + + *sock*, a ``dns.asyncbacket.StreamSocket``, or ``None``, the + socket to use for the query. If ``None``, the default, a socket + is created. Note that if a socket is provided + *where*, *port*, *source*, *source_port*, and *backend* are ignored. + + *backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``, + the default, then dnspython will use the default backend. + + Returns a ``dns.message.Message``. + """ + + wire = q.to_wire() + (begin_time, expiration) = _compute_times(timeout) + s = None + # After 3.6 is no longer supported, this can use an AsyncExitStack. + try: + if sock: + # Verify that the socket is connected, as if it's not connected, + # it's not writable, and the polling in send_tcp() will time out or + # hang forever. + await sock.getpeername() + s = sock + else: + # These are simple (address, port) pairs, not + # family-dependent tuples you pass to lowlevel socket + # code. + af = dns.inet.af_for_address(where) + stuple = _source_tuple(af, source, source_port) + dtuple = (where, port) + if not backend: + backend = dns.asyncbackend.get_default_backend() + s = await backend.make_socket(af, socket.SOCK_STREAM, 0, stuple, + dtuple, timeout) + await send_tcp(s, wire, expiration) + (r, received_time) = await receive_tcp(s, expiration, one_rr_per_rrset, + q.keyring, q.mac, + ignore_trailing) + r.time = received_time - begin_time + if not q.is_response(r): + raise BadResponse + return r + finally: + if not sock and s: + await s.close() + +async def tls(q, where, timeout=None, port=853, source=None, source_port=0, + one_rr_per_rrset=False, ignore_trailing=False, sock=None, + backend=None, ssl_context=None, server_hostname=None): + """Return the response obtained after sending a query via TLS. + + *q*, a ``dns.message.Message``, the query to send + + *where*, a ``str`` containing an IPv4 or IPv6 address, where + to send the message. + + *timeout*, a ``float`` or ``None``, the number of seconds to wait before the + query times out. If ``None``, the default, wait forever. + + *port*, an ``int``, the port send the message to. The default is 853. + + *source*, a ``str`` containing an IPv4 or IPv6 address, specifying + the source address. The default is the wildcard address. + + *source_port*, an ``int``, the port from which to send the message. + The default is 0. + + *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own + RRset. + + *ignore_trailing*, a ``bool``. If ``True``, ignore trailing + junk at end of the received message. + + *sock*, an ``asyncbackend.StreamSocket``, or ``None``, the socket + to use for the query. If ``None``, the default, a socket is + created. Note that if a socket is provided, it must be a + connected SSL stream socket, and *where*, *port*, + *source*, *source_port*, *backend*, *ssl_context*, and *server_hostname* + are ignored. + + *backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``, + the default, then dnspython will use the default backend. + + *ssl_context*, an ``ssl.SSLContext``, the context to use when establishing + a TLS connection. If ``None``, the default, creates one with the default + configuration. + + *server_hostname*, a ``str`` containing the server's hostname. The + default is ``None``, which means that no hostname is known, and if an + SSL context is created, hostname checking will be disabled. + + Returns a ``dns.message.Message``. + """ + # After 3.6 is no longer supported, this can use an AsyncExitStack. + (begin_time, expiration) = _compute_times(timeout) + if not sock: + if ssl_context is None: + ssl_context = ssl.create_default_context() + if server_hostname is None: + ssl_context.check_hostname = False + else: + ssl_context = None + server_hostname = None + af = dns.inet.af_for_address(where) + stuple = _source_tuple(af, source, source_port) + dtuple = (where, port) + if not backend: + backend = dns.asyncbackend.get_default_backend() + s = await backend.make_socket(af, socket.SOCK_STREAM, 0, stuple, + dtuple, timeout, ssl_context, + server_hostname) + else: + s = sock + try: + timeout = _timeout(expiration) + response = await tcp(q, where, timeout, port, source, source_port, + one_rr_per_rrset, ignore_trailing, s, backend) + end_time = time.time() + response.time = end_time - begin_time + return response + finally: + if not sock and s: + await s.close() diff --git a/lib/dns/asyncresolver.py b/lib/dns/asyncresolver.py new file mode 100644 index 00000000..3ac334f5 --- /dev/null +++ b/lib/dns/asyncresolver.py @@ -0,0 +1,257 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2017 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +"""Asynchronous DNS stub resolver.""" + +import time + +import dns.asyncbackend +import dns.asyncquery +import dns.exception +import dns.query +import dns.resolver + +# import some resolver symbols for brevity +from dns.resolver import NXDOMAIN, NoAnswer, NotAbsolute, NoRootSOA + + +# for indentation purposes below +_udp = dns.asyncquery.udp +_tcp = dns.asyncquery.tcp + + +class Resolver(dns.resolver.Resolver): + + async def resolve(self, qname, rdtype=dns.rdatatype.A, + rdclass=dns.rdataclass.IN, + tcp=False, source=None, raise_on_no_answer=True, + source_port=0, lifetime=None, search=None, + backend=None): + """Query nameservers asynchronously to find the answer to the question. + + The *qname*, *rdtype*, and *rdclass* parameters may be objects + of the appropriate type, or strings that can be converted into objects + of the appropriate type. + + *qname*, a ``dns.name.Name`` or ``str``, the query name. + + *rdtype*, an ``int`` or ``str``, the query type. + + *rdclass*, an ``int`` or ``str``, the query class. + + *tcp*, a ``bool``. If ``True``, use TCP to make the query. + + *source*, a ``str`` or ``None``. If not ``None``, bind to this IP + address when making queries. + + *raise_on_no_answer*, a ``bool``. If ``True``, raise + ``dns.resolver.NoAnswer`` if there's no answer to the question. + + *source_port*, an ``int``, the port from which to send the message. + + *lifetime*, a ``float``, how many seconds a query should run + before timing out. + + *search*, a ``bool`` or ``None``, determines whether the + search list configured in the system's resolver configuration + are used for relative names, and whether the resolver's domain + may be added to relative names. The default is ``None``, + which causes the value of the resolver's + ``use_search_by_default`` attribute to be used. + + *backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``, + the default, then dnspython will use the default backend. + + Raises ``dns.resolver.NXDOMAIN`` if the query name does not exist. + + Raises ``dns.resolver.YXDOMAIN`` if the query name is too long after + DNAME substitution. + + Raises ``dns.resolver.NoAnswer`` if *raise_on_no_answer* is + ``True`` and the query name exists but has no RRset of the + desired type and class. + + Raises ``dns.resolver.NoNameservers`` if no non-broken + nameservers are available to answer the question. + + Returns a ``dns.resolver.Answer`` instance. + + """ + + resolution = dns.resolver._Resolution(self, qname, rdtype, rdclass, tcp, + raise_on_no_answer, search) + if not backend: + backend = dns.asyncbackend.get_default_backend() + start = time.time() + while True: + (request, answer) = resolution.next_request() + # Note we need to say "if answer is not None" and not just + # "if answer" because answer implements __len__, and python + # will call that. We want to return if we have an answer + # object, including in cases where its length is 0. + if answer is not None: + # cache hit! + return answer + done = False + while not done: + (nameserver, port, tcp, backoff) = resolution.next_nameserver() + if backoff: + await backend.sleep(backoff) + timeout = self._compute_timeout(start, lifetime) + try: + if dns.inet.is_address(nameserver): + if tcp: + response = await _tcp(request, nameserver, + timeout, port, + source, source_port, + backend=backend) + else: + response = await _udp(request, nameserver, + timeout, port, + source, source_port, + raise_on_truncation=True, + backend=backend) + else: + # We don't do DoH yet. + raise NotImplementedError + except Exception as ex: + (_, done) = resolution.query_result(None, ex) + continue + (answer, done) = resolution.query_result(response, None) + # Note we need to say "if answer is not None" and not just + # "if answer" because answer implements __len__, and python + # will call that. We want to return if we have an answer + # object, including in cases where its length is 0. + if answer is not None: + return answer + + async def query(self, *args, **kwargs): + # We have to define something here as we don't want to inherit the + # parent's query(). + raise NotImplementedError + + async def resolve_address(self, ipaddr, *args, **kwargs): + """Use an asynchronous resolver to run a reverse query for PTR + records. + + This utilizes the resolve() method to perform a PTR lookup on the + specified IP address. + + *ipaddr*, a ``str``, the IPv4 or IPv6 address you want to get + the PTR record for. + + All other arguments that can be passed to the resolve() function + except for rdtype and rdclass are also supported by this + function. + + """ + + return await self.resolve(dns.reversename.from_address(ipaddr), + rdtype=dns.rdatatype.PTR, + rdclass=dns.rdataclass.IN, + *args, **kwargs) + +default_resolver = None + + +def get_default_resolver(): + """Get the default asynchronous resolver, initializing it if necessary.""" + if default_resolver is None: + reset_default_resolver() + return default_resolver + + +def reset_default_resolver(): + """Re-initialize default asynchronous resolver. + + Note that the resolver configuration (i.e. /etc/resolv.conf on UNIX + systems) will be re-read immediately. + """ + + global default_resolver + default_resolver = Resolver() + + +async def resolve(qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN, + tcp=False, source=None, raise_on_no_answer=True, + source_port=0, search=None, backend=None): + """Query nameservers asynchronously to find the answer to the question. + + This is a convenience function that uses the default resolver + object to make the query. + + See ``dns.asyncresolver.Resolver.resolve`` for more information on the + parameters. + """ + + return await get_default_resolver().resolve(qname, rdtype, rdclass, tcp, + source, raise_on_no_answer, + source_port, search, backend) + + +async def resolve_address(ipaddr, *args, **kwargs): + """Use a resolver to run a reverse query for PTR records. + + See ``dns.asyncresolver.Resolver.resolve_address`` for more + information on the parameters. + """ + + return await get_default_resolver().resolve_address(ipaddr, *args, **kwargs) + + +async def zone_for_name(name, rdclass=dns.rdataclass.IN, tcp=False, + resolver=None, backend=None): + """Find the name of the zone which contains the specified name. + + *name*, an absolute ``dns.name.Name`` or ``str``, the query name. + + *rdclass*, an ``int``, the query class. + + *tcp*, a ``bool``. If ``True``, use TCP to make the query. + + *resolver*, a ``dns.asyncresolver.Resolver`` or ``None``, the + resolver to use. If ``None``, the default resolver is used. + + *backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``, + the default, then dnspython will use the default backend. + + Raises ``dns.resolver.NoRootSOA`` if there is no SOA RR at the DNS + root. (This is only likely to happen if you're using non-default + root servers in your network and they are misconfigured.) + + Returns a ``dns.name.Name``. + """ + + if isinstance(name, str): + name = dns.name.from_text(name, dns.name.root) + if resolver is None: + resolver = get_default_resolver() + if not name.is_absolute(): + raise NotAbsolute(name) + while True: + try: + answer = await resolver.resolve(name, dns.rdatatype.SOA, rdclass, + tcp, backend=backend) + if answer.rrset.name == name: + return name + # otherwise we were CNAMEd or DNAMEd and need to look higher + except (NXDOMAIN, NoAnswer): + pass + try: + name = name.parent() + except dns.name.NoParent: # pragma: no cover + raise NoRootSOA diff --git a/lib/dns/dnssec.py b/lib/dns/dnssec.py index fec12082..c50abf8d 100644 --- a/lib/dns/dnssec.py +++ b/lib/dns/dnssec.py @@ -1,4 +1,6 @@ -# Copyright (C) 2003-2007, 2009, 2011 Nominum, Inc. +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2017 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its # documentation for any purpose with or without fee is hereby granted, @@ -15,99 +17,88 @@ """Common DNSSEC-related functions and constants.""" -from io import BytesIO +import hashlib import struct import time +import base64 +import dns.enum import dns.exception -import dns.hash import dns.name import dns.node import dns.rdataset import dns.rdata import dns.rdatatype import dns.rdataclass -from ._compat import string_types class UnsupportedAlgorithm(dns.exception.DNSException): - """The DNSSEC algorithm is not supported.""" class ValidationFailure(dns.exception.DNSException): - """The DNSSEC signature is invalid.""" -RSAMD5 = 1 -DH = 2 -DSA = 3 -ECC = 4 -RSASHA1 = 5 -DSANSEC3SHA1 = 6 -RSASHA1NSEC3SHA1 = 7 -RSASHA256 = 8 -RSASHA512 = 10 -ECDSAP256SHA256 = 13 -ECDSAP384SHA384 = 14 -INDIRECT = 252 -PRIVATEDNS = 253 -PRIVATEOID = 254 -_algorithm_by_text = { - 'RSAMD5': RSAMD5, - 'DH': DH, - 'DSA': DSA, - 'ECC': ECC, - 'RSASHA1': RSASHA1, - 'DSANSEC3SHA1': DSANSEC3SHA1, - 'RSASHA1NSEC3SHA1': RSASHA1NSEC3SHA1, - 'RSASHA256': RSASHA256, - 'RSASHA512': RSASHA512, - 'INDIRECT': INDIRECT, - 'ECDSAP256SHA256': ECDSAP256SHA256, - 'ECDSAP384SHA384': ECDSAP384SHA384, - 'PRIVATEDNS': PRIVATEDNS, - 'PRIVATEOID': PRIVATEOID, -} +class Algorithm(dns.enum.IntEnum): + RSAMD5 = 1 + DH = 2 + DSA = 3 + ECC = 4 + RSASHA1 = 5 + DSANSEC3SHA1 = 6 + RSASHA1NSEC3SHA1 = 7 + RSASHA256 = 8 + RSASHA512 = 10 + ECCGOST = 12 + ECDSAP256SHA256 = 13 + ECDSAP384SHA384 = 14 + ED25519 = 15 + ED448 = 16 + INDIRECT = 252 + PRIVATEDNS = 253 + PRIVATEOID = 254 -# We construct the inverse mapping programmatically to ensure that we -# cannot make any mistakes (e.g. omissions, cut-and-paste errors) that -# would cause the mapping not to be true inverse. + @classmethod + def _maximum(cls): + return 255 -_algorithm_by_value = dict((y, x) for x, y in _algorithm_by_text.items()) + +globals().update(Algorithm.__members__) def algorithm_from_text(text): - """Convert text into a DNSSEC algorithm value - @rtype: int""" + """Convert text into a DNSSEC algorithm value. - value = _algorithm_by_text.get(text.upper()) - if value is None: - value = int(text) - return value + *text*, a ``str``, the text to convert to into an algorithm value. + + Returns an ``int``. + """ + + return Algorithm.from_text(text) def algorithm_to_text(value): """Convert a DNSSEC algorithm value to text - @rtype: string""" - text = _algorithm_by_value.get(value) - if text is None: - text = str(value) - return text + *value*, an ``int`` a DNSSEC algorithm. + + Returns a ``str``, the name of a DNSSEC algorithm. + """ + + return Algorithm.to_text(value) -def _to_rdata(record, origin): - s = BytesIO() - record.to_wire(s, origin=origin) - return s.getvalue() +def key_id(key): + """Return the key id (a 16-bit number) for the specified key. + *key*, a ``dns.rdtypes.ANY.DNSKEY.DNSKEY`` -def key_id(key, origin=None): - rdata = _to_rdata(key, origin) - rdata = bytearray(rdata) - if key.algorithm == RSAMD5: + Returns an ``int`` between 0 and 65535 + """ + + rdata = key.to_wire() + if key.algorithm == Algorithm.RSAMD5: return (rdata[-3] << 8) + rdata[-2] else: total = 0 @@ -119,24 +110,60 @@ def key_id(key, origin=None): total += ((total >> 16) & 0xffff) return total & 0xffff +class DSDigest(dns.enum.IntEnum): + """DNSSEC Delgation Signer Digest Algorithm""" + + SHA1 = 1 + SHA256 = 2 + SHA384 = 4 + + @classmethod + def _maximum(cls): + return 255 + def make_ds(name, key, algorithm, origin=None): - if algorithm.upper() == 'SHA1': - dsalg = 1 - hash = dns.hash.hashes['SHA1']() - elif algorithm.upper() == 'SHA256': - dsalg = 2 - hash = dns.hash.hashes['SHA256']() + """Create a DS record for a DNSSEC key. + + *name*, a ``dns.name.Name`` or ``str``, the owner name of the DS record. + + *key*, a ``dns.rdtypes.ANY.DNSKEY.DNSKEY``, the key the DS is about. + + *algorithm*, a ``str`` or ``int`` specifying the hash algorithm. + The currently supported hashes are "SHA1", "SHA256", and "SHA384". Case + does not matter for these strings. + + *origin*, a ``dns.name.Name`` or ``None``. If `key` is a relative name, + then it will be made absolute using the specified origin. + + Raises ``UnsupportedAlgorithm`` if the algorithm is unknown. + + Returns a ``dns.rdtypes.ANY.DS.DS`` + """ + + try: + if isinstance(algorithm, str): + algorithm = DSDigest[algorithm.upper()] + except Exception: + raise UnsupportedAlgorithm('unsupported algorithm "%s"' % algorithm) + + if algorithm == DSDigest.SHA1: + dshash = hashlib.sha1() + elif algorithm == DSDigest.SHA256: + dshash = hashlib.sha256() + elif algorithm == DSDigest.SHA384: + dshash = hashlib.sha384() else: raise UnsupportedAlgorithm('unsupported algorithm "%s"' % algorithm) - if isinstance(name, string_types): + if isinstance(name, str): name = dns.name.from_text(name, origin) - hash.update(name.canonicalize().to_wire()) - hash.update(_to_rdata(key, origin)) - digest = hash.digest() + dshash.update(name.canonicalize().to_wire()) + dshash.update(key.to_wire(origin=origin)) + digest = dshash.digest() - dsrdata = struct.pack("!HBB", key_id(key), key.algorithm, dsalg) + digest + dsrdata = struct.pack("!HBB", key_id(key), key.algorithm, algorithm) + \ + digest return dns.rdata.from_wire(dns.rdataclass.IN, dns.rdatatype.DS, dsrdata, 0, len(dsrdata)) @@ -162,101 +189,109 @@ def _find_candidate_keys(keys, rrsig): def _is_rsa(algorithm): - return algorithm in (RSAMD5, RSASHA1, - RSASHA1NSEC3SHA1, RSASHA256, - RSASHA512) + return algorithm in (Algorithm.RSAMD5, Algorithm.RSASHA1, + Algorithm.RSASHA1NSEC3SHA1, Algorithm.RSASHA256, + Algorithm.RSASHA512) def _is_dsa(algorithm): - return algorithm in (DSA, DSANSEC3SHA1) + return algorithm in (Algorithm.DSA, Algorithm.DSANSEC3SHA1) def _is_ecdsa(algorithm): - return _have_ecdsa and (algorithm in (ECDSAP256SHA256, ECDSAP384SHA384)) + return algorithm in (Algorithm.ECDSAP256SHA256, Algorithm.ECDSAP384SHA384) + + +def _is_eddsa(algorithm): + return algorithm in (Algorithm.ED25519, Algorithm.ED448) + + +def _is_gost(algorithm): + return algorithm == Algorithm.ECCGOST def _is_md5(algorithm): - return algorithm == RSAMD5 + return algorithm == Algorithm.RSAMD5 def _is_sha1(algorithm): - return algorithm in (DSA, RSASHA1, - DSANSEC3SHA1, RSASHA1NSEC3SHA1) + return algorithm in (Algorithm.DSA, Algorithm.RSASHA1, + Algorithm.DSANSEC3SHA1, Algorithm.RSASHA1NSEC3SHA1) def _is_sha256(algorithm): - return algorithm in (RSASHA256, ECDSAP256SHA256) + return algorithm in (Algorithm.RSASHA256, Algorithm.ECDSAP256SHA256) def _is_sha384(algorithm): - return algorithm == ECDSAP384SHA384 + return algorithm == Algorithm.ECDSAP384SHA384 def _is_sha512(algorithm): - return algorithm == RSASHA512 + return algorithm == Algorithm.RSASHA512 def _make_hash(algorithm): if _is_md5(algorithm): - return dns.hash.hashes['MD5']() + return hashes.MD5() if _is_sha1(algorithm): - return dns.hash.hashes['SHA1']() + return hashes.SHA1() if _is_sha256(algorithm): - return dns.hash.hashes['SHA256']() + return hashes.SHA256() if _is_sha384(algorithm): - return dns.hash.hashes['SHA384']() + return hashes.SHA384() if _is_sha512(algorithm): - return dns.hash.hashes['SHA512']() + return hashes.SHA512() + if algorithm == Algorithm.ED25519: + return hashes.SHA512() + if algorithm == Algorithm.ED448: + return hashes.SHAKE256(114) + raise ValidationFailure('unknown hash for algorithm %u' % algorithm) -def _make_algorithm_id(algorithm): - if _is_md5(algorithm): - oid = [0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x02, 0x05] - elif _is_sha1(algorithm): - oid = [0x2b, 0x0e, 0x03, 0x02, 0x1a] - elif _is_sha256(algorithm): - oid = [0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x01] - elif _is_sha512(algorithm): - oid = [0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x03] - else: - raise ValidationFailure('unknown algorithm %u' % algorithm) - olen = len(oid) - dlen = _make_hash(algorithm).digest_size - idbytes = [0x30] + [8 + olen + dlen] + \ - [0x30, olen + 4] + [0x06, olen] + oid + \ - [0x05, 0x00] + [0x04, dlen] - return struct.pack('!%dB' % len(idbytes), *idbytes) +def _bytes_to_long(b): + return int.from_bytes(b, 'big') def _validate_rrsig(rrset, rrsig, keys, origin=None, now=None): - """Validate an RRset against a single signature rdata + """Validate an RRset against a single signature rdata, throwing an + exception if validation is not successful. - The owner name of the rrsig is assumed to be the same as the owner name - of the rrset. + *rrset*, the RRset to validate. This can be a + ``dns.rrset.RRset`` or a (``dns.name.Name``, ``dns.rdataset.Rdataset``) + tuple. - @param rrset: The RRset to validate - @type rrset: dns.rrset.RRset or (dns.name.Name, dns.rdataset.Rdataset) - tuple - @param rrsig: The signature rdata - @type rrsig: dns.rrset.Rdata - @param keys: The key dictionary. - @type keys: a dictionary keyed by dns.name.Name with node or rdataset - values - @param origin: The origin to use for relative names - @type origin: dns.name.Name or None - @param now: The time to use when validating the signatures. The default - is the current time. - @type now: int + *rrsig*, a ``dns.rdata.Rdata``, the signature to validate. + + *keys*, the key dictionary, used to find the DNSKEY associated + with a given name. The dictionary is keyed by a + ``dns.name.Name``, and has ``dns.node.Node`` or + ``dns.rdataset.Rdataset`` values. + + *origin*, a ``dns.name.Name`` or ``None``, the origin to use for relative + names. + + *now*, an ``int`` or ``None``, the time, in seconds since the epoch, to + use as the current time when validating. If ``None``, the actual current + time is used. + + Raises ``ValidationFailure`` if the signature is expired, not yet valid, + the public key is invalid, the algorithm is unknown, the verification + fails, etc. + + Raises ``UnsupportedAlgorithm`` if the algorithm is recognized by + dnspython but not implemented. """ - if isinstance(origin, string_types): + if isinstance(origin, str): origin = dns.name.from_text(origin, dns.name.root) - for candidate_key in _find_candidate_keys(keys, rrsig): - if not candidate_key: - raise ValidationFailure('unknown key') + candidate_keys = _find_candidate_keys(keys, rrsig) + if candidate_keys is None: + raise ValidationFailure('unknown key') + for candidate_key in candidate_keys: # For convenience, allow the rrset to be specified as a (name, # rdataset) tuple as well as a proper rrset if isinstance(rrset, tuple): @@ -273,8 +308,6 @@ def _validate_rrsig(rrset, rrsig, keys, origin=None, now=None): if rrsig.inception > now: raise ValidationFailure('not yet valid') - hash = _make_hash(rrsig.algorithm) - if _is_rsa(rrsig.algorithm): keyptr = candidate_key.key (bytes_,) = struct.unpack('!B', keyptr[0:1]) @@ -284,11 +317,13 @@ def _validate_rrsig(rrset, rrsig, keys, origin=None, now=None): keyptr = keyptr[2:] rsa_e = keyptr[0:bytes_] rsa_n = keyptr[bytes_:] - keylen = len(rsa_n) * 8 - pubkey = Crypto.PublicKey.RSA.construct( - (Crypto.Util.number.bytes_to_long(rsa_n), - Crypto.Util.number.bytes_to_long(rsa_e))) - sig = (Crypto.Util.number.bytes_to_long(rrsig.signature),) + try: + public_key = rsa.RSAPublicNumbers( + _bytes_to_long(rsa_e), + _bytes_to_long(rsa_n)).public_key(default_backend()) + except ValueError: + raise ValidationFailure('invalid public key') + sig = rrsig.signature elif _is_dsa(rrsig.algorithm): keyptr = candidate_key.key (t,) = struct.unpack('!B', keyptr[0:1]) @@ -301,41 +336,62 @@ def _validate_rrsig(rrset, rrsig, keys, origin=None, now=None): dsa_g = keyptr[0:octets] keyptr = keyptr[octets:] dsa_y = keyptr[0:octets] - pubkey = Crypto.PublicKey.DSA.construct( - (Crypto.Util.number.bytes_to_long(dsa_y), - Crypto.Util.number.bytes_to_long(dsa_g), - Crypto.Util.number.bytes_to_long(dsa_p), - Crypto.Util.number.bytes_to_long(dsa_q))) - (dsa_r, dsa_s) = struct.unpack('!20s20s', rrsig.signature[1:]) - sig = (Crypto.Util.number.bytes_to_long(dsa_r), - Crypto.Util.number.bytes_to_long(dsa_s)) + try: + public_key = dsa.DSAPublicNumbers( + _bytes_to_long(dsa_y), + dsa.DSAParameterNumbers( + _bytes_to_long(dsa_p), + _bytes_to_long(dsa_q), + _bytes_to_long(dsa_g))).public_key(default_backend()) + except ValueError: + raise ValidationFailure('invalid public key') + sig_r = rrsig.signature[1:21] + sig_s = rrsig.signature[21:] + sig = utils.encode_dss_signature(_bytes_to_long(sig_r), + _bytes_to_long(sig_s)) elif _is_ecdsa(rrsig.algorithm): - if rrsig.algorithm == ECDSAP256SHA256: - curve = ecdsa.curves.NIST256p - key_len = 32 - elif rrsig.algorithm == ECDSAP384SHA384: - curve = ecdsa.curves.NIST384p - key_len = 48 - else: - # shouldn't happen - raise ValidationFailure('unknown ECDSA curve') keyptr = candidate_key.key - x = Crypto.Util.number.bytes_to_long(keyptr[0:key_len]) - y = Crypto.Util.number.bytes_to_long(keyptr[key_len:key_len * 2]) - assert ecdsa.ecdsa.point_is_valid(curve.generator, x, y) - point = ecdsa.ellipticcurve.Point(curve.curve, x, y, curve.order) - verifying_key = ecdsa.keys.VerifyingKey.from_public_point(point, - curve) - pubkey = ECKeyWrapper(verifying_key, key_len) - r = rrsig.signature[:key_len] - s = rrsig.signature[key_len:] - sig = ecdsa.ecdsa.Signature(Crypto.Util.number.bytes_to_long(r), - Crypto.Util.number.bytes_to_long(s)) + if rrsig.algorithm == Algorithm.ECDSAP256SHA256: + curve = ec.SECP256R1() + octets = 32 + else: + curve = ec.SECP384R1() + octets = 48 + ecdsa_x = keyptr[0:octets] + ecdsa_y = keyptr[octets:octets * 2] + try: + public_key = ec.EllipticCurvePublicNumbers( + curve=curve, + x=_bytes_to_long(ecdsa_x), + y=_bytes_to_long(ecdsa_y)).public_key(default_backend()) + except ValueError: + raise ValidationFailure('invalid public key') + sig_r = rrsig.signature[0:octets] + sig_s = rrsig.signature[octets:] + sig = utils.encode_dss_signature(_bytes_to_long(sig_r), + _bytes_to_long(sig_s)) + + elif _is_eddsa(rrsig.algorithm): + keyptr = candidate_key.key + if rrsig.algorithm == Algorithm.ED25519: + loader = ed25519.Ed25519PublicKey + else: + loader = ed448.Ed448PublicKey + try: + public_key = loader.from_public_bytes(keyptr) + except ValueError: + raise ValidationFailure('invalid public key') + sig = rrsig.signature + elif _is_gost(rrsig.algorithm): + raise UnsupportedAlgorithm( + 'algorithm "%s" not supported by dnspython' % + algorithm_to_text(rrsig.algorithm)) else: raise ValidationFailure('unknown algorithm %u' % rrsig.algorithm) - hash.update(_to_rdata(rrsig, origin)[:18]) - hash.update(rrsig.signer.to_digestable(origin)) + data = b'' + data += rrsig.to_wire(origin=origin)[:18] + data += rrsig.signer.to_digestable(origin) if rrsig.labels < len(rrname) - 1: suffix = rrname.split(rrsig.labels + 1)[1] @@ -345,54 +401,69 @@ def _validate_rrsig(rrset, rrsig, keys, origin=None, now=None): rrsig.original_ttl) rrlist = sorted(rdataset) for rr in rrlist: - hash.update(rrnamebuf) - hash.update(rrfixed) + data += rrnamebuf + data += rrfixed rrdata = rr.to_digestable(origin) rrlen = struct.pack('!H', len(rrdata)) - hash.update(rrlen) - hash.update(rrdata) + data += rrlen + data += rrdata - digest = hash.digest() - - if _is_rsa(rrsig.algorithm): - # PKCS1 algorithm identifier goop - digest = _make_algorithm_id(rrsig.algorithm) + digest - padlen = keylen // 8 - len(digest) - 3 - digest = struct.pack('!%dB' % (2 + padlen + 1), - *([0, 1] + [0xFF] * padlen + [0])) + digest - elif _is_dsa(rrsig.algorithm) or _is_ecdsa(rrsig.algorithm): - pass - else: - # Raise here for code clarity; this won't actually ever happen - # since if the algorithm is really unknown we'd already have - # raised an exception above - raise ValidationFailure('unknown algorithm %u' % rrsig.algorithm) - - if pubkey.verify(digest, sig): + chosen_hash = _make_hash(rrsig.algorithm) + try: + if _is_rsa(rrsig.algorithm): + public_key.verify(sig, data, padding.PKCS1v15(), chosen_hash) + elif _is_dsa(rrsig.algorithm): + public_key.verify(sig, data, chosen_hash) + elif _is_ecdsa(rrsig.algorithm): + public_key.verify(sig, data, ec.ECDSA(chosen_hash)) + elif _is_eddsa(rrsig.algorithm): + public_key.verify(sig, data) + else: + # Raise here for code clarity; this won't actually ever happen + # since if the algorithm is really unknown we'd already have + # raised an exception above + raise ValidationFailure('unknown algorithm %u' % + rrsig.algorithm) # pragma: no cover + # If we got here, we successfully verified so we can return + # without error return + except InvalidSignature: + # this happens on an individual validation failure + continue + # nothing verified -- raise failure: raise ValidationFailure('verify failure') def _validate(rrset, rrsigset, keys, origin=None, now=None): - """Validate an RRset + """Validate an RRset against a signature RRset, throwing an exception + if none of the signatures validate. - @param rrset: The RRset to validate - @type rrset: dns.rrset.RRset or (dns.name.Name, dns.rdataset.Rdataset) - tuple - @param rrsigset: The signature RRset - @type rrsigset: dns.rrset.RRset or (dns.name.Name, dns.rdataset.Rdataset) - tuple - @param keys: The key dictionary. - @type keys: a dictionary keyed by dns.name.Name with node or rdataset - values - @param origin: The origin to use for relative names - @type origin: dns.name.Name or None - @param now: The time to use when validating the signatures. The default - is the current time. - @type now: int + *rrset*, the RRset to validate. This can be a + ``dns.rrset.RRset`` or a (``dns.name.Name``, ``dns.rdataset.Rdataset``) + tuple. + + *rrsigset*, the signature RRset. This can be a + ``dns.rrset.RRset`` or a (``dns.name.Name``, ``dns.rdataset.Rdataset``) + tuple. + + *keys*, the key dictionary, used to find the DNSKEY associated + with a given name. The dictionary is keyed by a + ``dns.name.Name``, and has ``dns.node.Node`` or + ``dns.rdataset.Rdataset`` values. + + *origin*, a ``dns.name.Name``, the origin to use for relative names; + defaults to None. + + *now*, an ``int`` or ``None``, the time, in seconds since the epoch, to + use as the current time when validating. If ``None``, the actual current + time is used. + + Raises ``ValidationFailure`` if the signature is expired, not yet valid, + the public key is invalid, the algorithm is unknown, the verification + fails, etc. """ - if isinstance(origin, string_types): + if isinstance(origin, str): origin = dns.name.from_text(origin, dns.name.root) if isinstance(rrset, tuple): @@ -408,7 +479,7 @@ def _validate(rrset, rrsigset, keys, origin=None, now=None): rrsigrdataset = rrsigset rrname = rrname.choose_relativity(origin) - rrsigname = rrname.choose_relativity(origin) + rrsigname = rrsigname.choose_relativity(origin) if rrname != rrsigname: raise ValidationFailure("owner names do not match") @@ -416,42 +487,95 @@ def _validate(rrset, rrsigset, keys, origin=None, now=None): try: _validate_rrsig(rrset, rrsig, keys, origin, now) return - except ValidationFailure: + except (ValidationFailure, UnsupportedAlgorithm): pass raise ValidationFailure("no RRSIGs validated") -def _need_pycrypto(*args, **kwargs): - raise NotImplementedError("DNSSEC validation requires pycrypto") +class NSEC3Hash(dns.enum.IntEnum): + """NSEC3 hash algorithm""" + + SHA1 = 1 + + @classmethod + def _maximum(cls): + return 255 + +def nsec3_hash(domain, salt, iterations, algorithm): + """ + Calculate the NSEC3 hash, according to + https://tools.ietf.org/html/rfc5155#section-5 + + *domain*, a ``dns.name.Name`` or ``str``, the name to hash. + + *salt*, a ``str``, ``bytes``, or ``None``, the hash salt. If a + string, it is decoded as a hex string. + + *iterations*, an ``int``, the number of iterations. + + *algorithm*, a ``str`` or ``int``, the hash algorithm. + The only defined algorithm is SHA1. + + Returns a ``str``, the encoded NSEC3 hash. + """ + + b32_conversion = str.maketrans( + "ABCDEFGHIJKLMNOPQRSTUVWXYZ234567", "0123456789ABCDEFGHIJKLMNOPQRSTUV" + ) + + try: + if isinstance(algorithm, str): + algorithm = NSEC3Hash[algorithm.upper()] + except Exception: + raise ValueError("Wrong hash algorithm (only SHA1 is supported)") + + if algorithm != NSEC3Hash.SHA1: + raise ValueError("Wrong hash algorithm (only SHA1 is supported)") + + salt_encoded = salt + if salt is None: + salt_encoded = b'' + elif isinstance(salt, str): + if len(salt) % 2 == 0: + salt_encoded = bytes.fromhex(salt) + else: + raise ValueError("Invalid salt length") + + if not isinstance(domain, dns.name.Name): + domain = dns.name.from_text(domain) + domain_encoded = domain.canonicalize().to_wire() + + digest = hashlib.sha1(domain_encoded + salt_encoded).digest() + for i in range(iterations): + digest = hashlib.sha1(digest + salt_encoded).digest() + + output = base64.b32encode(digest).decode("utf-8") + output = output.translate(b32_conversion) + + return output + + +def _need_pyca(*args, **kwargs): + raise ImportError("DNSSEC validation requires " + + "python cryptography") # pragma: no cover + try: - import Crypto.PublicKey.RSA - import Crypto.PublicKey.DSA - import Crypto.Util.number - validate = _validate - validate_rrsig = _validate_rrsig - _have_pycrypto = True -except ImportError: - validate = _need_pycrypto - validate_rrsig = _need_pycrypto - _have_pycrypto = False - -try: - import ecdsa - import ecdsa.ecdsa - import ecdsa.ellipticcurve - import ecdsa.keys - _have_ecdsa = True - - class ECKeyWrapper(object): - - def __init__(self, key, key_len): - self.key = key - self.key_len = key_len - - def verify(self, digest, sig): - diglong = Crypto.Util.number.bytes_to_long(digest) - return self.key.pubkey.verifies(diglong, sig) - -except ImportError: - _have_ecdsa = False + from cryptography.exceptions import InvalidSignature + from cryptography.hazmat.backends import default_backend + from cryptography.hazmat.primitives import hashes + from cryptography.hazmat.primitives.asymmetric import padding + from cryptography.hazmat.primitives.asymmetric import utils + from cryptography.hazmat.primitives.asymmetric import dsa + from cryptography.hazmat.primitives.asymmetric import ec + from cryptography.hazmat.primitives.asymmetric import ed25519 + from cryptography.hazmat.primitives.asymmetric import ed448 + from cryptography.hazmat.primitives.asymmetric import rsa +except ImportError: # pragma: no cover + validate = _need_pyca + validate_rrsig = _need_pyca + _have_pyca = False +else: + validate = _validate # type: ignore + validate_rrsig = _validate_rrsig # type: ignore + _have_pyca = True diff --git a/lib/dns/e164.py b/lib/dns/e164.py index 2cc911cd..83731b2c 100644 --- a/lib/dns/e164.py +++ b/lib/dns/e164.py @@ -1,4 +1,6 @@ -# Copyright (C) 2006, 2007, 2009, 2011 Nominum, Inc. +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2006-2017 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its # documentation for any purpose with or without fee is hereby granted, @@ -13,31 +15,31 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. -"""DNS E.164 helpers - -@var public_enum_domain: The DNS public ENUM domain, e164.arpa. -@type public_enum_domain: dns.name.Name object -""" - +"""DNS E.164 helpers.""" import dns.exception import dns.name import dns.resolver -from ._compat import string_types +#: The public E.164 domain. public_enum_domain = dns.name.from_text('e164.arpa.') def from_e164(text, origin=public_enum_domain): """Convert an E.164 number in textual form into a Name object whose value is the ENUM domain name for that number. - @param text: an E.164 number in textual form. - @type text: str - @param origin: The domain in which the number should be constructed. - The default is e164.arpa. - @type origin: dns.name.Name object or None - @rtype: dns.name.Name object + + Non-digits in the text are ignored, i.e. "16505551212", + "+1.650.555.1212" and "1 (650) 555-1212" are all the same. + + *text*, a ``str``, is an E.164 number in textual form. + + *origin*, a ``dns.name.Name``, the domain in which the number + should be constructed. The default is ``e164.arpa.``. + + Returns a ``dns.name.Name``. """ + parts = [d for d in text if d.isdigit()] parts.reverse() return dns.name.from_text('.'.join(parts), origin=origin) @@ -45,40 +47,58 @@ def from_e164(text, origin=public_enum_domain): def to_e164(name, origin=public_enum_domain, want_plus_prefix=True): """Convert an ENUM domain name into an E.164 number. - @param name: the ENUM domain name. - @type name: dns.name.Name object. - @param origin: A domain containing the ENUM domain name. The - name is relativized to this domain before being converted to text. - @type origin: dns.name.Name object or None - @param want_plus_prefix: if True, add a '+' to the beginning of the - returned number. - @rtype: str + + Note that dnspython does not have any information about preferred + number formats within national numbering plans, so all numbers are + emitted as a simple string of digits, prefixed by a '+' (unless + *want_plus_prefix* is ``False``). + + *name* is a ``dns.name.Name``, the ENUM domain name. + + *origin* is a ``dns.name.Name``, a domain containing the ENUM + domain name. The name is relativized to this domain before being + converted to text. If ``None``, no relativization is done. + + *want_plus_prefix* is a ``bool``. If True, add a '+' to the beginning of + the returned number. + + Returns a ``str``. + """ if origin is not None: name = name.relativize(origin) - dlabels = [d for d in name.labels if (d.isdigit() and len(d) == 1)] + dlabels = [d for d in name.labels if d.isdigit() and len(d) == 1] if len(dlabels) != len(name.labels): raise dns.exception.SyntaxError('non-digit labels in ENUM domain name') dlabels.reverse() text = b''.join(dlabels) if want_plus_prefix: text = b'+' + text - return text + return text.decode() def query(number, domains, resolver=None): """Look for NAPTR RRs for the specified number in the specified domains. e.g. lookup('16505551212', ['e164.dnspython.org.', 'e164.arpa.']) + + *number*, a ``str`` is the number to look for. + + *domains* is an iterable containing ``dns.name.Name`` values. + + *resolver*, a ``dns.resolver.Resolver``, is the resolver to use. If + ``None``, the default resolver is used. """ + if resolver is None: resolver = dns.resolver.get_default_resolver() + e_nx = dns.resolver.NXDOMAIN() for domain in domains: - if isinstance(domain, string_types): + if isinstance(domain, str): domain = dns.name.from_text(domain) qname = dns.e164.from_e164(number, domain) try: - return resolver.query(qname, 'NAPTR') - except dns.resolver.NXDOMAIN: - pass - raise dns.resolver.NXDOMAIN + return resolver.resolve(qname, 'NAPTR') + except dns.resolver.NXDOMAIN as e: + e_nx += e + raise e_nx diff --git a/lib/dns/edns.py b/lib/dns/edns.py index 8ac676bc..28718d52 100644 --- a/lib/dns/edns.py +++ b/lib/dns/edns.py @@ -1,4 +1,6 @@ -# Copyright (C) 2009, 2011 Nominum, Inc. +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2009-2017 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its # documentation for any purpose with or without fee is hereby granted, @@ -15,47 +17,85 @@ """EDNS Options""" -NSID = 3 +import math +import socket +import struct +import dns.enum +import dns.inet -class Option(object): +class OptionType(dns.enum.IntEnum): + #: NSID + NSID = 3 + #: DAU + DAU = 5 + #: DHU + DHU = 6 + #: N3U + N3U = 7 + #: ECS (client-subnet) + ECS = 8 + #: EXPIRE + EXPIRE = 9 + #: COOKIE + COOKIE = 10 + #: KEEPALIVE + KEEPALIVE = 11 + #: PADDING + PADDING = 12 + #: CHAIN + CHAIN = 13 - """Base class for all EDNS option types. - """ + @classmethod + def _maximum(cls): + return 65535 + +globals().update(OptionType.__members__) + +class Option: + + """Base class for all EDNS option types.""" def __init__(self, otype): """Initialize an option. - @param otype: The rdata type - @type otype: int + + *otype*, an ``int``, is the option type. """ self.otype = otype - def to_wire(self, file): + def to_wire(self, file=None): """Convert an option to wire format. + + Returns a ``bytes`` or ``None``. + """ - raise NotImplementedError + raise NotImplementedError # pragma: no cover @classmethod - def from_wire(cls, otype, wire, current, olen): - """Build an EDNS option object from wire format + def from_wire_parser(cls, otype, parser): + """Build an EDNS option object from wire format. - @param otype: The option type - @type otype: int - @param wire: The wire-format message - @type wire: string - @param current: The offset in wire of the beginning of the rdata. - @type current: int - @param olen: The length of the wire-format option data - @type olen: int - @rtype: dns.edns.Option instance""" - raise NotImplementedError + *otype*, an ``int``, is the option type. + + *parser*, a ``dns.wire.Parser``, the parser, which should be + restructed to the option length. + + Returns a ``dns.edns.Option``. + """ + raise NotImplementedError # pragma: no cover def _cmp(self, other): """Compare an EDNS option with another option of the same type. - Return < 0 if self < other, 0 if self == other, - and > 0 if self > other. + + Returns < 0 if < *other*, 0 if == *other*, and > 0 if > *other*. """ - raise NotImplementedError + wire = self.to_wire() + owire = other.to_wire() + if wire == owire: + return 0 + if wire > owire: + return 1 + return -1 def __eq__(self, other): if not isinstance(other, Option): @@ -66,9 +106,9 @@ class Option(object): def __ne__(self, other): if not isinstance(other, Option): - return False + return True if self.otype != other.otype: - return False + return True return self._cmp(other) != 0 def __lt__(self, other): @@ -95,56 +135,210 @@ class Option(object): return NotImplemented return self._cmp(other) > 0 + def __str__(self): + return self.to_text() + class GenericOption(Option): - """Generate Rdata Class + """Generic Option Class This class is used for EDNS option types for which we have no better implementation. """ def __init__(self, otype, data): - super(GenericOption, self).__init__(otype) + super().__init__(otype) self.data = data - def to_wire(self, file): - file.write(self.data) + def to_wire(self, file=None): + if file: + file.write(self.data) + else: + return self.data + + def to_text(self): + return "Generic %d" % self.otype @classmethod - def from_wire(cls, otype, wire, current, olen): - return cls(otype, wire[current: current + olen]) + def from_wire_parser(cls, otype, parser): + return cls(otype, parser.get_remaining()) + + +class ECSOption(Option): + """EDNS Client Subnet (ECS, RFC7871)""" + + def __init__(self, address, srclen=None, scopelen=0): + """*address*, a ``str``, is the client address information. + + *srclen*, an ``int``, the source prefix length, which is the + leftmost number of bits of the address to be used for the + lookup. The default is 24 for IPv4 and 56 for IPv6. + + *scopelen*, an ``int``, the scope prefix length. This value + must be 0 in queries, and should be set in responses. + """ + + super().__init__(OptionType.ECS) + af = dns.inet.af_for_address(address) + + if af == socket.AF_INET6: + self.family = 2 + if srclen is None: + srclen = 56 + elif af == socket.AF_INET: + self.family = 1 + if srclen is None: + srclen = 24 + else: + raise ValueError('Bad ip family') + + self.address = address + self.srclen = srclen + self.scopelen = scopelen + + addrdata = dns.inet.inet_pton(af, address) + nbytes = int(math.ceil(srclen / 8.0)) + + # Truncate to srclen and pad to the end of the last octet needed + # See RFC section 6 + self.addrdata = addrdata[:nbytes] + nbits = srclen % 8 + if nbits != 0: + last = struct.pack('B', + ord(self.addrdata[-1:]) & (0xff << (8 - nbits))) + self.addrdata = self.addrdata[:-1] + last + + def to_text(self): + return "ECS {}/{} scope/{}".format(self.address, self.srclen, + self.scopelen) + + @staticmethod + def from_text(text): + """Convert a string into a `dns.edns.ECSOption` + + *text*, a `str`, the text form of the option. + + Returns a `dns.edns.ECSOption`. + + Examples: + + >>> import dns.edns + >>> + >>> # basic example + >>> dns.edns.ECSOption.from_text('1.2.3.4/24') + >>> + >>> # also understands scope + >>> dns.edns.ECSOption.from_text('1.2.3.4/24/32') + >>> + >>> # IPv6 + >>> dns.edns.ECSOption.from_text('2001:4b98::1/64/64') + >>> + >>> # it understands results from `dns.edns.ECSOption.to_text()` + >>> dns.edns.ECSOption.from_text('ECS 1.2.3.4/24/32') + """ + optional_prefix = 'ECS' + tokens = text.split() + ecs_text = None + if len(tokens) == 1: + ecs_text = tokens[0] + elif len(tokens) == 2: + if tokens[0] != optional_prefix: + raise ValueError('could not parse ECS from "{}"'.format(text)) + ecs_text = tokens[1] + else: + raise ValueError('could not parse ECS from "{}"'.format(text)) + n_slashes = ecs_text.count('/') + if n_slashes == 1: + address, srclen = ecs_text.split('/') + scope = 0 + elif n_slashes == 2: + address, srclen, scope = ecs_text.split('/') + else: + raise ValueError('could not parse ECS from "{}"'.format(text)) + try: + scope = int(scope) + except ValueError: + raise ValueError('invalid scope ' + + '"{}": scope must be an integer'.format(scope)) + try: + srclen = int(srclen) + except ValueError: + raise ValueError('invalid srclen ' + + '"{}": srclen must be an integer'.format(srclen)) + return ECSOption(address, srclen, scope) + + def to_wire(self, file=None): + value = (struct.pack('!HBB', self.family, self.srclen, self.scopelen) + + self.addrdata) + if file: + file.write(value) + else: + return value + + @classmethod + def from_wire_parser(cls, otype, parser): + family, src, scope = parser.get_struct('!HBB') + addrlen = int(math.ceil(src / 8.0)) + prefix = parser.get_bytes(addrlen) + if family == 1: + pad = 4 - addrlen + addr = dns.ipv4.inet_ntoa(prefix + b'\x00' * pad) + elif family == 2: + pad = 16 - addrlen + addr = dns.ipv6.inet_ntoa(prefix + b'\x00' * pad) + else: + raise ValueError('unsupported family') + + return cls(addr, src, scope) - def _cmp(self, other): - if self.data == other.data: - return 0 - if self.data > other.data: - return 1 - return -1 _type_to_class = { + OptionType.ECS: ECSOption } - def get_option_class(otype): + """Return the class for the specified option type. + + The GenericOption class is used if a more specific class is not + known. + """ + cls = _type_to_class.get(otype) if cls is None: cls = GenericOption return cls -def option_from_wire(otype, wire, current, olen): - """Build an EDNS option object from wire format +def option_from_wire_parser(otype, parser): + """Build an EDNS option object from wire format. - @param otype: The option type - @type otype: int - @param wire: The wire-format message - @type wire: string - @param current: The offset in wire of the beginning of the rdata. - @type current: int - @param olen: The length of the wire-format option data - @type olen: int - @rtype: dns.edns.Option instance""" + *otype*, an ``int``, is the option type. + *parser*, a ``dns.wire.Parser``, the parser, which should be + restricted to the option length. + + Returns an instance of a subclass of ``dns.edns.Option``. + """ cls = get_option_class(otype) - return cls.from_wire(otype, wire, current, olen) + otype = OptionType.make(otype) + return cls.from_wire_parser(otype, parser) + + +def option_from_wire(otype, wire, current, olen): + """Build an EDNS option object from wire format. + + *otype*, an ``int``, is the option type. + + *wire*, a ``bytes``, is the wire-format message. + + *current*, an ``int``, is the offset in *wire* of the beginning + of the rdata. + + *olen*, an ``int``, is the length of the wire-format option data + + Returns an instance of a subclass of ``dns.edns.Option``. + """ + parser = dns.wire.Parser(wire, current) + with parser.restrict_to(olen): + return option_from_wire_parser(otype, parser) diff --git a/lib/dns/entropy.py b/lib/dns/entropy.py index 43841a7a..086bba78 100644 --- a/lib/dns/entropy.py +++ b/lib/dns/entropy.py @@ -1,4 +1,6 @@ -# Copyright (C) 2009, 2011 Nominum, Inc. +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2009-2017 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its # documentation for any purpose with or without fee is hereby granted, @@ -14,85 +16,76 @@ # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. import os +import hashlib +import random import time -from ._compat import long, binary_type try: import threading as _threading -except ImportError: - import dummy_threading as _threading +except ImportError: # pragma: no cover + import dummy_threading as _threading # type: ignore -class EntropyPool(object): +class EntropyPool: + + # This is an entropy pool for Python implementations that do not + # have a working SystemRandom. I'm not sure there are any, but + # leaving this code doesn't hurt anything as the library code + # is used if present. def __init__(self, seed=None): self.pool_index = 0 self.digest = None self.next_byte = 0 self.lock = _threading.Lock() - try: - import hashlib - self.hash = hashlib.sha1() - self.hash_len = 20 - except: - try: - import sha - self.hash = sha.new() - self.hash_len = 20 - except: - import md5 - self.hash = md5.new() - self.hash_len = 16 + self.hash = hashlib.sha1() + self.hash_len = 20 self.pool = bytearray(b'\0' * self.hash_len) if seed is not None: - self.stir(bytearray(seed)) + self._stir(bytearray(seed)) self.seeded = True + self.seed_pid = os.getpid() else: self.seeded = False + self.seed_pid = 0 - def stir(self, entropy, already_locked=False): - if not already_locked: - self.lock.acquire() - try: - for c in entropy: - if self.pool_index == self.hash_len: - self.pool_index = 0 - b = c & 0xff - self.pool[self.pool_index] ^= b - self.pool_index += 1 - finally: - if not already_locked: - self.lock.release() + def _stir(self, entropy): + for c in entropy: + if self.pool_index == self.hash_len: + self.pool_index = 0 + b = c & 0xff + self.pool[self.pool_index] ^= b + self.pool_index += 1 + + def stir(self, entropy): + with self.lock: + self._stir(entropy) def _maybe_seed(self): - if not self.seeded: + if not self.seeded or self.seed_pid != os.getpid(): try: seed = os.urandom(16) - except: + except Exception: # pragma: no cover try: - r = open('/dev/urandom', 'rb', 0) - try: + with open('/dev/urandom', 'rb', 0) as r: seed = r.read(16) - finally: - r.close() - except: + except Exception: seed = str(time.time()) self.seeded = True + self.seed_pid = os.getpid() + self.digest = None seed = bytearray(seed) - self.stir(seed, True) + self._stir(seed) def random_8(self): - self.lock.acquire() - try: + with self.lock: self._maybe_seed() if self.digest is None or self.next_byte == self.hash_len: - self.hash.update(binary_type(self.pool)) + self.hash.update(bytes(self.pool)) self.digest = bytearray(self.hash.digest()) - self.stir(self.digest, True) + self._stir(self.digest) self.next_byte = 0 value = self.digest[self.next_byte] self.next_byte += 1 - finally: - self.lock.release() return value def random_16(self): @@ -103,25 +96,34 @@ class EntropyPool(object): def random_between(self, first, last): size = last - first + 1 - if size > long(4294967296): + if size > 4294967296: raise ValueError('too big') if size > 65536: rand = self.random_32 - max = long(4294967295) + max = 4294967295 elif size > 256: rand = self.random_16 max = 65535 else: rand = self.random_8 max = 255 - return (first + size * rand() // (max + 1)) + return first + size * rand() // (max + 1) pool = EntropyPool() +try: + system_random = random.SystemRandom() +except Exception: # pragma: no cover + system_random = None def random_16(): - return pool.random_16() - + if system_random is not None: + return system_random.randrange(0, 65536) + else: + return pool.random_16() def between(first, last): - return pool.random_between(first, last) + if system_random is not None: + return system_random.randrange(first, last + 1) + else: + return pool.random_between(first, last) diff --git a/lib/dns/enum.py b/lib/dns/enum.py new file mode 100644 index 00000000..11536f2b --- /dev/null +++ b/lib/dns/enum.py @@ -0,0 +1,90 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2017 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +import enum + +class IntEnum(enum.IntEnum): + @classmethod + def _check_value(cls, value): + max = cls._maximum() + if value < 0 or value > max: + name = cls._short_name() + raise ValueError(f"{name} must be between >= 0 and <= {max}") + + @classmethod + def from_text(cls, text): + text = text.upper() + try: + return cls[text] + except KeyError: + pass + prefix = cls._prefix() + if text.startswith(prefix) and text[len(prefix):].isdigit(): + value = int(text[len(prefix):]) + cls._check_value(value) + try: + return cls(value) + except ValueError: + return value + raise cls._unknown_exception_class() + + @classmethod + def to_text(cls, value): + cls._check_value(value) + try: + return cls(value).name + except ValueError: + return f"{cls._prefix()}{value}" + + @classmethod + def make(cls, value): + """Convert text or a value into an enumerated type, if possible. + + *value*, the ``int`` or ``str`` to convert. + + Raises a class-specific exception if a ``str`` is provided that + cannot be converted. + + Raises ``ValueError`` if the value is out of range. + + Returns an enumeration from the calling class corresponding to the + value, if one is defined, or an ``int`` otherwise. + """ + + if isinstance(value, str): + return cls.from_text(value) + cls._check_value(value) + try: + return cls(value) + except ValueError: + return value + + @classmethod + def _maximum(cls): + raise NotImplementedError + + @classmethod + def _short_name(cls): + return cls.__name__.lower() + + @classmethod + def _prefix(cls): + return '' + + @classmethod + def _unknown_exception_class(cls): + return ValueError diff --git a/lib/dns/exception.py b/lib/dns/exception.py index 62fbe2cb..8f1d4888 100644 --- a/lib/dns/exception.py +++ b/lib/dns/exception.py @@ -1,4 +1,6 @@ -# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2017 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its # documentation for any purpose with or without fee is hereby granted, @@ -13,47 +15,53 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. -"""Common DNS Exceptions.""" +"""Common DNS Exceptions. +Dnspython modules may also define their own exceptions, which will +always be subclasses of ``DNSException``. +""" class DNSException(Exception): - """Abstract base class shared by all dnspython exceptions. It supports two basic modes of operation: - a) Old/compatible mode is used if __init__ was called with - empty **kwargs. - In compatible mode all *args are passed to standard Python Exception class - as before and all *args are printed by standard __str__ implementation. - Class variable msg (or doc string if msg is None) is returned from str() - if *args is empty. + a) Old/compatible mode is used if ``__init__`` was called with + empty *kwargs*. In compatible mode all *args* are passed + to the standard Python Exception class as before and all *args* are + printed by the standard ``__str__`` implementation. Class variable + ``msg`` (or doc string if ``msg`` is ``None``) is returned from ``str()`` + if *args* is empty. - b) New/parametrized mode is used if __init__ was called with - non-empty **kwargs. - In the new mode *args has to be empty and all kwargs has to exactly match - set in class variable self.supp_kwargs. All kwargs are stored inside - self.kwargs and used in new __str__ implementation to construct - formatted message based on self.fmt string. + b) New/parametrized mode is used if ``__init__`` was called with + non-empty *kwargs*. + In the new mode *args* must be empty and all kwargs must match + those set in class variable ``supp_kwargs``. All kwargs are stored inside + ``self.kwargs`` and used in a new ``__str__`` implementation to construct + a formatted message based on the ``fmt`` class variable, a ``string``. - In the simplest case it is enough to override supp_kwargs and fmt - class variables to get nice parametrized messages. + In the simplest case it is enough to override the ``supp_kwargs`` + and ``fmt`` class variables to get nice parametrized messages. """ + msg = None # non-parametrized message supp_kwargs = set() # accepted parameters for _fmt_kwargs (sanity check) fmt = None # message parametrized with results from _fmt_kwargs def __init__(self, *args, **kwargs): self._check_params(*args, **kwargs) - self._check_kwargs(**kwargs) - self.kwargs = kwargs + if kwargs: + self.kwargs = self._check_kwargs(**kwargs) + self.msg = str(self) + else: + self.kwargs = dict() # defined but empty for old mode exceptions if self.msg is None: # doc string is better implicit message than empty string self.msg = self.__doc__ if args: - super(DNSException, self).__init__(*args) + super().__init__(*args) else: - super(DNSException, self).__init__(self.msg) + super().__init__(self.msg) def _check_params(self, *args, **kwargs): """Old exceptions supported only args and not kwargs. @@ -68,6 +76,7 @@ class DNSException(Exception): assert set(kwargs.keys()) == self.supp_kwargs, \ 'following set of keyword args is required: %s' % ( self.supp_kwargs) + return kwargs def _fmt_kwargs(self, **kwargs): """Format kwargs before printing them. @@ -94,31 +103,26 @@ class DNSException(Exception): return self.fmt.format(**fmtargs) else: # print *args directly in the same way as old DNSException - return super(DNSException, self).__str__() + return super().__str__() class FormError(DNSException): - """DNS message is malformed.""" class SyntaxError(DNSException): - """Text input is malformed.""" class UnexpectedEnd(SyntaxError): - """Text input ended unexpectedly.""" class TooBig(DNSException): - """The DNS message is too big.""" class Timeout(DNSException): - """The DNS operation timed out.""" - supp_kwargs = set(['timeout']) + supp_kwargs = {'timeout'} fmt = "The DNS operation timed out after {timeout} seconds" diff --git a/lib/dns/flags.py b/lib/dns/flags.py index 388d6aaa..4eb6d90c 100644 --- a/lib/dns/flags.py +++ b/lib/dns/flags.py @@ -1,4 +1,6 @@ -# Copyright (C) 2001-2007, 2009-2011 Nominum, Inc. +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2001-2017 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its # documentation for any purpose with or without fee is hereby granted, @@ -15,98 +17,90 @@ """DNS Message Flags.""" +import enum + # Standard DNS flags -QR = 0x8000 -AA = 0x0400 -TC = 0x0200 -RD = 0x0100 -RA = 0x0080 -AD = 0x0020 -CD = 0x0010 +class Flag(enum.IntFlag): + #: Query Response + QR = 0x8000 + #: Authoritative Answer + AA = 0x0400 + #: Truncated Response + TC = 0x0200 + #: Recursion Desired + RD = 0x0100 + #: Recursion Available + RA = 0x0080 + #: Authentic Data + AD = 0x0020 + #: Checking Disabled + CD = 0x0010 + +globals().update(Flag.__members__) + # EDNS flags -DO = 0x8000 - -_by_text = { - 'QR': QR, - 'AA': AA, - 'TC': TC, - 'RD': RD, - 'RA': RA, - 'AD': AD, - 'CD': CD -} - -_edns_by_text = { - 'DO': DO -} +class EDNSFlag(enum.IntFlag): + #: DNSSEC answer OK + DO = 0x8000 -# We construct the inverse mappings programmatically to ensure that we -# cannot make any mistakes (e.g. omissions, cut-and-paste errors) that -# would cause the mappings not to be true inverses. - -_by_value = dict((y, x) for x, y in _by_text.items()) - -_edns_by_value = dict((y, x) for x, y in _edns_by_text.items()) +globals().update(EDNSFlag.__members__) -def _order_flags(table): - order = list(table.items()) - order.sort() - order.reverse() - return order - -_flags_order = _order_flags(_by_value) - -_edns_flags_order = _order_flags(_edns_by_value) - - -def _from_text(text, table): +def _from_text(text, enum_class): flags = 0 tokens = text.split() for t in tokens: - flags = flags | table[t.upper()] + flags |= enum_class[t.upper()] return flags -def _to_text(flags, table, order): +def _to_text(flags, enum_class): text_flags = [] - for k, v in order: - if flags & k != 0: - text_flags.append(v) + for k, v in enum_class.__members__.items(): + if flags & v != 0: + text_flags.append(k) return ' '.join(text_flags) def from_text(text): """Convert a space-separated list of flag text values into a flags value. - @rtype: int""" - return _from_text(text, _by_text) + Returns an ``int`` + """ + + return _from_text(text, Flag) def to_text(flags): """Convert a flags value into a space-separated list of flag text values. - @rtype: string""" - return _to_text(flags, _by_value, _flags_order) + Returns a ``str``. + """ + + return _to_text(flags, Flag) def edns_from_text(text): """Convert a space-separated list of EDNS flag text values into a EDNS flags value. - @rtype: int""" - return _from_text(text, _edns_by_text) + Returns an ``int`` + """ + + return _from_text(text, EDNSFlag) def edns_to_text(flags): """Convert an EDNS flags value into a space-separated list of EDNS flag text values. - @rtype: string""" - return _to_text(flags, _edns_by_value, _edns_flags_order) + Returns a ``str``. + """ + + return _to_text(flags, EDNSFlag) diff --git a/lib/dns/grange.py b/lib/dns/grange.py index 01a3257b..ffe8be7c 100644 --- a/lib/dns/grange.py +++ b/lib/dns/grange.py @@ -1,4 +1,6 @@ -# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2012-2017 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its # documentation for any purpose with or without fee is hereby granted, @@ -17,23 +19,25 @@ import dns - def from_text(text): - """Convert the text form of a range in a GENERATE statement to an + """Convert the text form of a range in a ``$GENERATE`` statement to an integer. - @param text: the textual range - @type text: string - @return: The start, stop and step values. - @rtype: tuple - """ - # TODO, figure out the bounds on start, stop and step. + *text*, a ``str``, the textual range in ``$GENERATE`` form. + Returns a tuple of three ``int`` values ``(start, stop, step)``. + """ + + # TODO, figure out the bounds on start, stop and step. step = 1 cur = '' state = 0 # state 0 1 2 3 4 # x - y / z + + if text and text[0] == '-': + raise dns.exception.SyntaxError("Start cannot be a negative number") + for c in text: if c == '-' and state == 0: start = int(cur) @@ -49,7 +53,7 @@ def from_text(text): raise dns.exception.SyntaxError("Could not parse %s" % (c)) if state in (1, 3): - raise dns.exception.SyntaxError + raise dns.exception.SyntaxError() if state == 2: stop = int(cur) diff --git a/lib/dns/inet.py b/lib/dns/inet.py index 966285e7..25d99c2c 100644 --- a/lib/dns/inet.py +++ b/lib/dns/inet.py @@ -1,4 +1,6 @@ -# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2017 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its # documentation for any purpose with or without fee is hereby granted, @@ -21,36 +23,30 @@ import dns.ipv4 import dns.ipv6 -# We assume that AF_INET is always defined. - +# We assume that AF_INET and AF_INET6 are always defined. We keep +# these here for the benefit of any old code (unlikely though that +# is!). AF_INET = socket.AF_INET - -# AF_INET6 might not be defined in the socket module, but we need it. -# We'll try to use the socket module's value, and if it doesn't work, -# we'll use our own value. - -try: - AF_INET6 = socket.AF_INET6 -except AttributeError: - AF_INET6 = 9999 +AF_INET6 = socket.AF_INET6 def inet_pton(family, text): """Convert the textual form of a network address into its binary form. - @param family: the address family - @type family: int - @param text: the textual address - @type text: string - @raises NotImplementedError: the address family specified is not + *family* is an ``int``, the address family. + + *text* is a ``str``, the textual address. + + Raises ``NotImplementedError`` if the address family specified is not implemented. - @rtype: string + + Returns a ``bytes``. """ if family == AF_INET: return dns.ipv4.inet_aton(text) elif family == AF_INET6: - return dns.ipv6.inet_aton(text) + return dns.ipv6.inet_aton(text, True) else: raise NotImplementedError @@ -58,14 +54,16 @@ def inet_pton(family, text): def inet_ntop(family, address): """Convert the binary form of a network address into its textual form. - @param family: the address family - @type family: int - @param address: the binary address - @type address: string - @raises NotImplementedError: the address family specified is not + *family* is an ``int``, the address family. + + *address* is a ``bytes``, the network address in binary form. + + Raises ``NotImplementedError`` if the address family specified is not implemented. - @rtype: string + + Returns a ``str``. """ + if family == AF_INET: return dns.ipv4.inet_ntoa(address) elif family == AF_INET6: @@ -77,35 +75,96 @@ def inet_ntop(family, address): def af_for_address(text): """Determine the address family of a textual-form network address. - @param text: the textual address - @type text: string - @raises ValueError: the address family cannot be determined from the input. - @rtype: int + *text*, a ``str``, the textual address. + + Raises ``ValueError`` if the address family cannot be determined + from the input. + + Returns an ``int``. """ + try: dns.ipv4.inet_aton(text) return AF_INET - except: + except Exception: try: - dns.ipv6.inet_aton(text) + dns.ipv6.inet_aton(text, True) return AF_INET6 - except: + except Exception: raise ValueError def is_multicast(text): """Is the textual-form network address a multicast address? - @param text: the textual address - @raises ValueError: the address family cannot be determined from the input. - @rtype: bool + *text*, a ``str``, the textual address. + + Raises ``ValueError`` if the address family cannot be determined + from the input. + + Returns a ``bool``. """ + try: - first = ord(dns.ipv4.inet_aton(text)[0]) - return (first >= 224 and first <= 239) - except: + first = dns.ipv4.inet_aton(text)[0] + return first >= 224 and first <= 239 + except Exception: try: - first = ord(dns.ipv6.inet_aton(text)[0]) - return (first == 255) - except: + first = dns.ipv6.inet_aton(text, True)[0] + return first == 255 + except Exception: raise ValueError + + +def is_address(text): + """Is the specified string an IPv4 or IPv6 address? + + *text*, a ``str``, the textual address. + + Returns a ``bool``. + """ + + try: + dns.ipv4.inet_aton(text) + return True + except Exception: + try: + dns.ipv6.inet_aton(text, True) + return True + except Exception: + return False + + +def low_level_address_tuple(high_tuple, af=None): + """Given a "high-level" address tuple, i.e. + an (address, port) return the appropriate "low-level" address tuple + suitable for use in socket calls. + + If an *af* other than ``None`` is provided, it is assumed the + address in the high-level tuple is valid and has that af. If af + is ``None``, then af_for_address will be called. + + """ + address, port = high_tuple + if af is None: + af = af_for_address(address) + if af == AF_INET: + return (address, port) + elif af == AF_INET6: + i = address.find('%') + if i < 0: + # no scope, shortcut! + return (address, port, 0, 0) + # try to avoid getaddrinfo() + addrpart = address[:i] + scope = address[i + 1:] + if scope.isdigit(): + return (addrpart, port, 0, int(scope)) + try: + return (addrpart, port, 0, socket.if_nametoindex(scope)) + except AttributeError: + ai_flags = socket.AI_NUMERICHOST + ((*_, tup), *_) = socket.getaddrinfo(address, port, flags=ai_flags) + return tup + else: + raise NotImplementedError(f'unknown address family {af}') diff --git a/lib/dns/ipv4.py b/lib/dns/ipv4.py index 3fef282b..e1f38d3d 100644 --- a/lib/dns/ipv4.py +++ b/lib/dns/ipv4.py @@ -1,4 +1,6 @@ -# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2017 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its # documentation for any purpose with or without fee is hereby granted, @@ -18,30 +20,29 @@ import struct import dns.exception -from ._compat import binary_type def inet_ntoa(address): - """Convert an IPv4 address in network form to text form. + """Convert an IPv4 address in binary form to text form. - @param address: The IPv4 address - @type address: string - @returns: string + *address*, a ``bytes``, the IPv4 address in binary form. + + Returns a ``str``. """ + if len(address) != 4: raise dns.exception.SyntaxError - if not isinstance(address, bytearray): - address = bytearray(address) - return (u'%u.%u.%u.%u' % (address[0], address[1], - address[2], address[3])).encode() + return ('%u.%u.%u.%u' % (address[0], address[1], + address[2], address[3])) def inet_aton(text): - """Convert an IPv4 address in text form to network form. + """Convert an IPv4 address in text form to binary form. - @param text: The IPv4 address - @type text: string - @returns: string + *text*, a ``str``, the IPv4 address in textual form. + + Returns a ``bytes``. """ - if not isinstance(text, binary_type): + + if not isinstance(text, bytes): text = text.encode() parts = text.split(b'.') if len(parts) != 4: @@ -49,11 +50,11 @@ def inet_aton(text): for part in parts: if not part.isdigit(): raise dns.exception.SyntaxError - if len(part) > 1 and part[0] == '0': + if len(part) > 1 and part[0] == ord('0'): # No leading zeros raise dns.exception.SyntaxError try: - bytes = [int(part) for part in parts] - return struct.pack('BBBB', *bytes) - except: + b = [int(part) for part in parts] + return struct.pack('BBBB', *b) + except Exception: raise dns.exception.SyntaxError diff --git a/lib/dns/ipv6.py b/lib/dns/ipv6.py index ee991e85..5424fcea 100644 --- a/lib/dns/ipv6.py +++ b/lib/dns/ipv6.py @@ -1,4 +1,6 @@ -# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2017 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its # documentation for any purpose with or without fee is hereby granted, @@ -20,17 +22,16 @@ import binascii import dns.exception import dns.ipv4 -from ._compat import xrange, binary_type -_leading_zero = re.compile(b'0+([0-9a-f]+)') +_leading_zero = re.compile(r'0+([0-9a-f]+)') def inet_ntoa(address): - """Convert a network format IPv6 address into text. + """Convert an IPv6 address in binary form to text form. - @param address: the binary address - @type address: string - @rtype: string - @raises ValueError: the address isn't 16 bytes long + *address*, a ``bytes``, the IPv6 address in binary form. + + Raises ``ValueError`` if the address isn't 16 bytes long. + Returns a ``str``. """ if len(address) != 16: @@ -40,12 +41,12 @@ def inet_ntoa(address): i = 0 l = len(hex) while i < l: - chunk = hex[i : i + 4] + chunk = hex[i:i + 4].decode() # strip leading zeros. we do this with an re instead of # with lstrip() because lstrip() didn't support chars until # python 2.2.2 m = _leading_zero.match(chunk) - if not m is None: + if m is not None: chunk = m.group(1) chunks.append(chunk) i += 4 @@ -56,8 +57,8 @@ def inet_ntoa(address): best_len = 0 start = -1 last_was_zero = False - for i in xrange(8): - if chunks[i] != b'0': + for i in range(8): + if chunks[i] != '0': if last_was_zero: end = i current_len = end - start @@ -77,59 +78,70 @@ def inet_ntoa(address): if best_len > 1: if best_start == 0 and \ (best_len == 6 or - best_len == 5 and chunks[5] == b'ffff'): + best_len == 5 and chunks[5] == 'ffff'): # We have an embedded IPv4 address if best_len == 6: - prefix = b'::' + prefix = '::' else: - prefix = b'::ffff:' + prefix = '::ffff:' hex = prefix + dns.ipv4.inet_ntoa(address[12:]) else: - hex = b':'.join(chunks[:best_start]) + b'::' + \ - b':'.join(chunks[best_start + best_len:]) + hex = ':'.join(chunks[:best_start]) + '::' + \ + ':'.join(chunks[best_start + best_len:]) else: - hex = b':'.join(chunks) + hex = ':'.join(chunks) return hex -_v4_ending = re.compile(b'(.*):(\d+\.\d+\.\d+\.\d+)$') -_colon_colon_start = re.compile(b'::.*') -_colon_colon_end = re.compile(b'.*::$') +_v4_ending = re.compile(br'(.*):(\d+\.\d+\.\d+\.\d+)$') +_colon_colon_start = re.compile(br'::.*') +_colon_colon_end = re.compile(br'.*::$') -def inet_aton(text): - """Convert a text format IPv6 address into network format. +def inet_aton(text, ignore_scope=False): + """Convert an IPv6 address in text form to binary form. - @param text: the textual address - @type text: string - @rtype: string - @raises dns.exception.SyntaxError: the text was not properly formatted + *text*, a ``str``, the IPv6 address in textual form. + + *ignore_scope*, a ``bool``. If ``True``, a scope will be ignored. + If ``False``, the default, it is an error for a scope to be present. + + Returns a ``bytes``. """ # # Our aim here is not something fast; we just want something that works. # - if not isinstance(text, binary_type): + if not isinstance(text, bytes): text = text.encode() + if ignore_scope: + parts = text.split(b'%') + l = len(parts) + if l == 2: + text = parts[0] + elif l > 2: + raise dns.exception.SyntaxError + if text == b'::': text = b'0::' # # Get rid of the icky dot-quad syntax if we have it. # m = _v4_ending.match(text) - if not m is None: - b = bytearray(dns.ipv4.inet_aton(m.group(2))) - text = (u"%s:%02x%02x:%02x%02x" % (m.group(1).decode(), b[0], b[1], - b[2], b[3])).encode() + if m is not None: + b = dns.ipv4.inet_aton(m.group(2)) + text = (u"{}:{:02x}{:02x}:{:02x}{:02x}".format(m.group(1).decode(), + b[0], b[1], b[2], + b[3])).encode() # # Try to turn '::' into ':'; if no match try to # turn '::' into ':' # m = _colon_colon_start.match(text) - if not m is None: + if m is not None: text = text[1:] else: m = _colon_colon_end.match(text) - if not m is None: + if m is not None: text = text[:-1] # # Now canonicalize into 8 chunks of 4 hex digits each @@ -145,7 +157,7 @@ def inet_aton(text): if seen_empty: raise dns.exception.SyntaxError seen_empty = True - for i in xrange(0, 8 - l + 1): + for i in range(0, 8 - l + 1): canonical.append(b'0000') else: lc = len(c) @@ -169,4 +181,11 @@ def inet_aton(text): _mapped_prefix = b'\x00' * 10 + b'\xff\xff' def is_mapped(address): + """Is the specified address a mapped IPv4 address? + + *address*, a ``bytes`` is an IPv6 address in binary form. + + Returns a ``bool``. + """ + return address.startswith(_mapped_prefix) diff --git a/lib/dns/message.py b/lib/dns/message.py index 9b8dcd0f..60b74c19 100644 --- a/lib/dns/message.py +++ b/lib/dns/message.py @@ -1,4 +1,6 @@ -# Copyright (C) 2001-2007, 2009-2011 Nominum, Inc. +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2001-2017 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its # documentation for any purpose with or without fee is hereby granted, @@ -15,14 +17,13 @@ """DNS Messages""" -from __future__ import absolute_import - -from io import StringIO -import struct -import sys +import contextlib +import io import time +import dns.wire import dns.edns +import dns.enum import dns.exception import dns.flags import dns.name @@ -35,120 +36,68 @@ import dns.rdatatype import dns.rrset import dns.renderer import dns.tsig -import dns.wiredata - -from ._compat import long, xrange, string_types +import dns.rdtypes.ANY.OPT +import dns.rdtypes.ANY.TSIG class ShortHeader(dns.exception.FormError): - """The DNS packet passed to from_wire() is too short.""" class TrailingJunk(dns.exception.FormError): - """The DNS packet passed to from_wire() has extra junk at the end of it.""" class UnknownHeaderField(dns.exception.DNSException): - """The header field name was not recognized when converting from text into a message.""" class BadEDNS(dns.exception.FormError): - - """OPT record occurred somewhere other than the start of + """An OPT record occurred somewhere other than the additional data section.""" class BadTSIG(dns.exception.FormError): - """A TSIG record occurred somewhere other than the end of the additional data section.""" class UnknownTSIGKey(dns.exception.DNSException): - """A TSIG with an unknown key was received.""" -class Message(object): +class Truncated(dns.exception.DNSException): + """The truncated flag is set.""" - """A DNS message. + supp_kwargs = {'message'} - @ivar id: The query id; the default is a randomly chosen id. - @type id: int - @ivar flags: The DNS flags of the message. @see: RFC 1035 for an - explanation of these flags. - @type flags: int - @ivar question: The question section. - @type question: list of dns.rrset.RRset objects - @ivar answer: The answer section. - @type answer: list of dns.rrset.RRset objects - @ivar authority: The authority section. - @type authority: list of dns.rrset.RRset objects - @ivar additional: The additional data section. - @type additional: list of dns.rrset.RRset objects - @ivar edns: The EDNS level to use. The default is -1, no Edns. - @type edns: int - @ivar ednsflags: The EDNS flags - @type ednsflags: long - @ivar payload: The EDNS payload size. The default is 0. - @type payload: int - @ivar options: The EDNS options - @type options: list of dns.edns.Option objects - @ivar request_payload: The associated request's EDNS payload size. - @type request_payload: int - @ivar keyring: The TSIG keyring to use. The default is None. - @type keyring: dict - @ivar keyname: The TSIG keyname to use. The default is None. - @type keyname: dns.name.Name object - @ivar keyalgorithm: The TSIG algorithm to use; defaults to - dns.tsig.default_algorithm. Constants for TSIG algorithms are defined - in dns.tsig, and the currently implemented algorithms are - HMAC_MD5, HMAC_SHA1, HMAC_SHA224, HMAC_SHA256, HMAC_SHA384, and - HMAC_SHA512. - @type keyalgorithm: string - @ivar request_mac: The TSIG MAC of the request message associated with - this message; used when validating TSIG signatures. @see: RFC 2845 for - more information on TSIG fields. - @type request_mac: string - @ivar fudge: TSIG time fudge; default is 300 seconds. - @type fudge: int - @ivar original_id: TSIG original id; defaults to the message's id - @type original_id: int - @ivar tsig_error: TSIG error code; default is 0. - @type tsig_error: int - @ivar other_data: TSIG other data. - @type other_data: string - @ivar mac: The TSIG MAC for this message. - @type mac: string - @ivar xfr: Is the message being used to contain the results of a DNS - zone transfer? The default is False. - @type xfr: bool - @ivar origin: The origin of the zone in messages which are used for - zone transfers or for DNS dynamic updates. The default is None. - @type origin: dns.name.Name object - @ivar tsig_ctx: The TSIG signature context associated with this - message. The default is None. - @type tsig_ctx: hmac.HMAC object - @ivar had_tsig: Did the message decoded from wire format have a TSIG - signature? - @type had_tsig: bool - @ivar multi: Is this message part of a multi-message sequence? The - default is false. This variable is used when validating TSIG signatures - on messages which are part of a zone transfer. - @type multi: bool - @ivar first: Is this message standalone, or the first of a multi - message sequence? This variable is used when validating TSIG signatures - on messages which are part of a zone transfer. - @type first: bool - @ivar index: An index of rrsets in the message. The index key is - (section, name, rdclass, rdtype, covers, deleting). Indexing can be - disabled by setting the index to None. - @type index: dict - """ + def message(self): + """As much of the message as could be processed. + + Returns a ``dns.message.Message``. + """ + return self.kwargs['message'] + + +class MessageSection(dns.enum.IntEnum): + """Message sections""" + QUESTION = 0 + ANSWER = 1 + AUTHORITY = 2 + ADDITIONAL = 3 + + @classmethod + def _maximum(cls): + return 3 + +globals().update(MessageSection.__members__) + + +class Message: + """A DNS message.""" + + _section_enum = MessageSection def __init__(self, id=None): if id is None: @@ -156,86 +105,88 @@ class Message(object): else: self.id = id self.flags = 0 - self.question = [] - self.answer = [] - self.authority = [] - self.additional = [] - self.edns = -1 - self.ednsflags = 0 - self.payload = 0 - self.options = [] + self.sections = [[], [], [], []] + self.opt = None self.request_payload = 0 self.keyring = None - self.keyname = None - self.keyalgorithm = dns.tsig.default_algorithm - self.request_mac = '' - self.other_data = '' - self.tsig_error = 0 - self.fudge = 300 - self.original_id = self.id - self.mac = '' + self.tsig = None + self.request_mac = b'' self.xfr = False self.origin = None self.tsig_ctx = None - self.had_tsig = False - self.multi = False - self.first = True self.index = {} + @property + def question(self): + """ The question section.""" + return self.sections[0] + + @question.setter + def question(self, v): + self.sections[0] = v + + @property + def answer(self): + """ The answer section.""" + return self.sections[1] + + @answer.setter + def answer(self, v): + self.sections[1] = v + + @property + def authority(self): + """ The authority section.""" + return self.sections[2] + + @authority.setter + def authority(self, v): + self.sections[2] = v + + @property + def additional(self): + """ The additional data section.""" + return self.sections[3] + + @additional.setter + def additional(self, v): + self.sections[3] = v + def __repr__(self): return '' def __str__(self): return self.to_text() - def to_text(self, origin=None, relativize=True, **kw): + def to_text(self, origin=None, relativize=True, **kw): """Convert the message to text. - The I{origin}, I{relativize}, and any other keyword - arguments are passed to the rrset to_wire() method. + The *origin*, *relativize*, and any other keyword + arguments are passed to the RRset ``to_wire()`` method. - @rtype: string + Returns a ``str``. """ - s = StringIO() - s.write(u'id %d\n' % self.id) - s.write(u'opcode %s\n' % + s = io.StringIO() + s.write('id %d\n' % self.id) + s.write('opcode %s\n' % dns.opcode.to_text(dns.opcode.from_flags(self.flags))) rc = dns.rcode.from_flags(self.flags, self.ednsflags) - s.write(u'rcode %s\n' % dns.rcode.to_text(rc)) - s.write(u'flags %s\n' % dns.flags.to_text(self.flags)) + s.write('rcode %s\n' % dns.rcode.to_text(rc)) + s.write('flags %s\n' % dns.flags.to_text(self.flags)) if self.edns >= 0: - s.write(u'edns %s\n' % self.edns) + s.write('edns %s\n' % self.edns) if self.ednsflags != 0: - s.write(u'eflags %s\n' % + s.write('eflags %s\n' % dns.flags.edns_to_text(self.ednsflags)) - s.write(u'payload %d\n' % self.payload) - is_update = dns.opcode.is_update(self.flags) - if is_update: - s.write(u';ZONE\n') - else: - s.write(u';QUESTION\n') - for rrset in self.question: - s.write(rrset.to_text(origin, relativize, **kw)) - s.write(u'\n') - if is_update: - s.write(u';PREREQ\n') - else: - s.write(u';ANSWER\n') - for rrset in self.answer: - s.write(rrset.to_text(origin, relativize, **kw)) - s.write(u'\n') - if is_update: - s.write(u';UPDATE\n') - else: - s.write(u';AUTHORITY\n') - for rrset in self.authority: - s.write(rrset.to_text(origin, relativize, **kw)) - s.write(u'\n') - s.write(u';ADDITIONAL\n') - for rrset in self.additional: - s.write(rrset.to_text(origin, relativize, **kw)) - s.write(u'\n') + s.write('payload %d\n' % self.payload) + for opt in self.options: + s.write('option %s\n' % opt.to_text()) + for (name, which) in self._section_enum.__members__.items(): + s.write(f';{name}\n') + for rrset in self.section_from_number(which): + s.write(rrset.to_text(origin, relativize, **kw)) + s.write('\n') # # We strip off the final \n so the caller can print the result without # doing weird things to get around eccentricities in Python print @@ -246,41 +197,35 @@ class Message(object): def __eq__(self, other): """Two messages are equal if they have the same content in the header, question, answer, and authority sections. - @rtype: bool""" + + Returns a ``bool``. + """ + if not isinstance(other, Message): return False if self.id != other.id: return False if self.flags != other.flags: return False - for n in self.question: - if n not in other.question: - return False - for n in other.question: - if n not in self.question: - return False - for n in self.answer: - if n not in other.answer: - return False - for n in other.answer: - if n not in self.answer: - return False - for n in self.authority: - if n not in other.authority: - return False - for n in other.authority: - if n not in self.authority: - return False + for i, section in enumerate(self.sections): + other_section = other.sections[i] + for n in section: + if n not in other_section: + return False + for n in other_section: + if n not in section: + return False return True def __ne__(self, other): - """Are two messages not equal? - @rtype: bool""" return not self.__eq__(other) def is_response(self, other): - """Is other a response to self? - @rtype: bool""" + """Is *other* a response this message? + + Returns a ``bool``. + """ + if other.flags & dns.flags.QR == 0 or \ self.id != other.id or \ dns.opcode.from_flags(self.flags) != \ @@ -290,6 +235,10 @@ class Message(object): dns.rcode.NOERROR: return True if dns.opcode.is_update(self.flags): + # This is assuming the "sender doesn't include anything + # from the update", but we don't care to check the other + # case, which is that all the sections are returned and + # identical. return True for n in self.question: if n not in other.question: @@ -300,46 +249,80 @@ class Message(object): return True def section_number(self, section): - if section is self.question: - return 0 - elif section is self.answer: - return 1 - elif section is self.authority: - return 2 - elif section is self.additional: - return 3 - else: - raise ValueError('unknown section') + """Return the "section number" of the specified section for use + in indexing. + + *section* is one of the section attributes of this message. + + Raises ``ValueError`` if the section isn't known. + + Returns an ``int``. + """ + + for i, our_section in enumerate(self.sections): + if section is our_section: + return self._section_enum(i) + raise ValueError('unknown section') + + def section_from_number(self, number): + """Return the section list associated with the specified section + number. + + *number* is a section number `int` or the text form of a section + name. + + Raises ``ValueError`` if the section isn't known. + + Returns a ``list``. + """ + + section = self._section_enum.make(number) + return self.sections[section] def find_rrset(self, section, name, rdclass, rdtype, covers=dns.rdatatype.NONE, deleting=None, create=False, force_unique=False): """Find the RRset with the given attributes in the specified section. - @param section: the section of the message to look in, e.g. - self.answer. - @type section: list of dns.rrset.RRset objects - @param name: the name of the RRset - @type name: dns.name.Name object - @param rdclass: the class of the RRset - @type rdclass: int - @param rdtype: the type of the RRset - @type rdtype: int - @param covers: the covers value of the RRset - @type covers: int - @param deleting: the deleting value of the RRset - @type deleting: int - @param create: If True, create the RRset if it is not found. - The created RRset is appended to I{section}. - @type create: bool - @param force_unique: If True and create is also True, create a - new RRset regardless of whether a matching RRset exists already. - @type force_unique: bool - @raises KeyError: the RRset was not found and create was False - @rtype: dns.rrset.RRset object""" + *section*, an ``int`` section number, or one of the section + attributes of this message. This specifies the + the section of the message to search. For example:: - key = (self.section_number(section), - name, rdclass, rdtype, covers, deleting) + my_message.find_rrset(my_message.answer, name, rdclass, rdtype) + my_message.find_rrset(dns.message.ANSWER, name, rdclass, rdtype) + + *name*, a ``dns.name.Name``, the name of the RRset. + + *rdclass*, an ``int``, the class of the RRset. + + *rdtype*, an ``int``, the type of the RRset. + + *covers*, an ``int`` or ``None``, the covers value of the RRset. + The default is ``None``. + + *deleting*, an ``int`` or ``None``, the deleting value of the RRset. + The default is ``None``. + + *create*, a ``bool``. If ``True``, create the RRset if it is not found. + The created RRset is appended to *section*. + + *force_unique*, a ``bool``. If ``True`` and *create* is also ``True``, + create a new RRset regardless of whether a matching RRset exists + already. The default is ``False``. This is useful when creating + DDNS Update messages, as order matters for them. + + Raises ``KeyError`` if the RRset was not found and create was + ``False``. + + Returns a ``dns.rrset.RRset object``. + """ + + if isinstance(section, int): + section_number = section + section = self.section_from_number(section_number) + else: + section_number = self.section_number(section) + key = (section_number, name, rdclass, rdtype, covers, deleting) if not force_unique: if self.index is not None: rrset = self.index.get(key) @@ -364,26 +347,35 @@ class Message(object): If the RRset is not found, None is returned. - @param section: the section of the message to look in, e.g. - self.answer. - @type section: list of dns.rrset.RRset objects - @param name: the name of the RRset - @type name: dns.name.Name object - @param rdclass: the class of the RRset - @type rdclass: int - @param rdtype: the type of the RRset - @type rdtype: int - @param covers: the covers value of the RRset - @type covers: int - @param deleting: the deleting value of the RRset - @type deleting: int - @param create: If True, create the RRset if it is not found. - The created RRset is appended to I{section}. - @type create: bool - @param force_unique: If True and create is also True, create a - new RRset regardless of whether a matching RRset exists already. - @type force_unique: bool - @rtype: dns.rrset.RRset object or None""" + *section*, an ``int`` section number, or one of the section + attributes of this message. This specifies the + the section of the message to search. For example:: + + my_message.get_rrset(my_message.answer, name, rdclass, rdtype) + my_message.get_rrset(dns.message.ANSWER, name, rdclass, rdtype) + + *name*, a ``dns.name.Name``, the name of the RRset. + + *rdclass*, an ``int``, the class of the RRset. + + *rdtype*, an ``int``, the type of the RRset. + + *covers*, an ``int`` or ``None``, the covers value of the RRset. + The default is ``None``. + + *deleting*, an ``int`` or ``None``, the deleting value of the RRset. + The default is ``None``. + + *create*, a ``bool``. If ``True``, create the RRset if it is not found. + The created RRset is appended to *section*. + + *force_unique*, a ``bool``. If ``True`` and *create* is also ``True``, + create a new RRset regardless of whether a matching RRset exists + already. The default is ``False``. This is useful when creating + DDNS Update messages, as order matters for them. + + Returns a ``dns.rrset.RRset object`` or ``None``. + """ try: rrset = self.find_rrset(section, name, rdclass, rdtype, covers, @@ -392,23 +384,35 @@ class Message(object): rrset = None return rrset - def to_wire(self, origin=None, max_size=0, **kw): + def to_wire(self, origin=None, max_size=0, multi=False, tsig_ctx=None, + **kw): """Return a string containing the message in DNS compressed wire format. - Additional keyword arguments are passed to the rrset to_wire() + Additional keyword arguments are passed to the RRset ``to_wire()`` method. - @param origin: The origin to be appended to any relative names. - @type origin: dns.name.Name object - @param max_size: The maximum size of the wire format output; default - is 0, which means 'the message's request payload, if nonzero, or - 65536'. - @type max_size: int - @raises dns.exception.TooBig: max_size was exceeded - @rtype: string + *origin*, a ``dns.name.Name`` or ``None``, the origin to be appended + to any relative names. If ``None``, and the message has an origin + attribute that is not ``None``, then it will be used. + + *max_size*, an ``int``, the maximum size of the wire format + output; default is 0, which means "the message's request + payload, if nonzero, or 65535". + + *multi*, a ``bool``, should be set to ``True`` if this message is + part of a multiple message sequence. + + *tsig_ctx*, a ``hmac.HMAC`` object, the ongoing TSIG context, used + when signing zone transfers. + + Raises ``dns.exception.TooBig`` if *max_size* was exceeded. + + Returns a ``bytes``. """ + if origin is None and self.origin is not None: + origin = self.origin if max_size == 0: if self.request_payload != 0: max_size = self.request_payload @@ -425,82 +429,150 @@ class Message(object): r.add_rrset(dns.renderer.ANSWER, rrset, **kw) for rrset in self.authority: r.add_rrset(dns.renderer.AUTHORITY, rrset, **kw) - if self.edns >= 0: - r.add_edns(self.edns, self.ednsflags, self.payload, self.options) + if self.opt is not None: + r.add_rrset(dns.renderer.ADDITIONAL, self.opt) for rrset in self.additional: r.add_rrset(dns.renderer.ADDITIONAL, rrset, **kw) r.write_header() - if self.keyname is not None: - r.add_tsig(self.keyname, self.keyring[self.keyname], - self.fudge, self.original_id, self.tsig_error, - self.other_data, self.request_mac, - self.keyalgorithm) - self.mac = r.mac + if self.tsig is not None: + (new_tsig, ctx) = dns.tsig.sign(r.get_wire(), + self.keyring, + self.tsig[0], + int(time.time()), + self.request_mac, + tsig_ctx, + multi) + self.tsig.clear() + self.tsig.add(new_tsig) + r.add_rrset(dns.renderer.ADDITIONAL, self.tsig) + r.write_header() + if multi: + self.tsig_ctx = ctx return r.get_wire() - def use_tsig(self, keyring, keyname=None, fudge=300, - original_id=None, tsig_error=0, other_data='', - algorithm=dns.tsig.default_algorithm): - """When sending, a TSIG signature using the specified keyring - and keyname should be added. + @staticmethod + def _make_tsig(keyname, algorithm, time_signed, fudge, mac, original_id, + error, other): + tsig = dns.rdtypes.ANY.TSIG.TSIG(dns.rdataclass.ANY, dns.rdatatype.TSIG, + algorithm, time_signed, fudge, mac, + original_id, error, other) + return dns.rrset.from_rdata(keyname, 0, tsig) - @param keyring: The TSIG keyring to use; defaults to None. - @type keyring: dict - @param keyname: The name of the TSIG key to use; defaults to None. - The key must be defined in the keyring. If a keyring is specified - but a keyname is not, then the key used will be the first key in the - keyring. Note that the order of keys in a dictionary is not defined, - so applications should supply a keyname when a keyring is used, unless - they know the keyring contains only one key. - @type keyname: dns.name.Name or string - @param fudge: TSIG time fudge; default is 300 seconds. - @type fudge: int - @param original_id: TSIG original id; defaults to the message's id - @type original_id: int - @param tsig_error: TSIG error code; default is 0. - @type tsig_error: int - @param other_data: TSIG other data. - @type other_data: string - @param algorithm: The TSIG algorithm to use; defaults to - dns.tsig.default_algorithm + def use_tsig(self, keyring, keyname=None, fudge=300, + original_id=None, tsig_error=0, other_data=b'', + algorithm=dns.tsig.default_algorithm): + """When sending, a TSIG signature using the specified key + should be added. + + *key*, a ``dns.tsig.Key`` is the key to use. If a key is specified, + the *keyring* and *algorithm* fields are not used. + + *keyring*, a ``dict`` or ``dns.tsig.Key``, is either the TSIG + keyring or key to use. + + The format of a keyring dict is a mapping from TSIG key name, as + ``dns.name.Name`` to ``dns.tsig.Key`` or a TSIG secret, a ``bytes``. + If a ``dict`` *keyring* is specified but a *keyname* is not, the key + used will be the first key in the *keyring*. Note that the order of + keys in a dictionary is not defined, so applications should supply a + keyname when a ``dict`` keyring is used, unless they know the keyring + contains only one key. + + *keyname*, a ``dns.name.Name``, ``str`` or ``None``, the name of + thes TSIG key to use; defaults to ``None``. If *keyring* is a + ``dict``, the key must be defined in it. If *keyring* is a + ``dns.tsig.Key``, this is ignored. + + *fudge*, an ``int``, the TSIG time fudge. + + *original_id*, an ``int``, the TSIG original id. If ``None``, + the message's id is used. + + *tsig_error*, an ``int``, the TSIG error code. + + *other_data*, a ``bytes``, the TSIG other data. + + *algorithm*, a ``dns.name.Name``, the TSIG algorithm to use. This is + only used if *keyring* is a ``dict``, and the key entry is a ``bytes``. """ - self.keyring = keyring - if keyname is None: - self.keyname = list(self.keyring.keys())[0] + if isinstance(keyring, dns.tsig.Key): + self.keyring = keyring else: - if isinstance(keyname, string_types): + if isinstance(keyname, str): keyname = dns.name.from_text(keyname) - self.keyname = keyname - self.keyalgorithm = algorithm - self.fudge = fudge + if keyname is None: + keyname = next(iter(keyring)) + key = keyring[keyname] + if isinstance(key, bytes): + key = dns.tsig.Key(keyname, key, algorithm) + self.keyring = key if original_id is None: - self.original_id = self.id + original_id = self.id + self.tsig = self._make_tsig(keyname, self.keyring.algorithm, 0, fudge, + b'', original_id, tsig_error, other_data) + + @property + def keyname(self): + if self.tsig: + return self.tsig.name else: - self.original_id = original_id - self.tsig_error = tsig_error - self.other_data = other_data + return None + + @property + def keyalgorithm(self): + if self.tsig: + return self.tsig[0].algorithm + else: + return None + + @property + def mac(self): + if self.tsig: + return self.tsig[0].mac + else: + return None + + @property + def tsig_error(self): + if self.tsig: + return self.tsig[0].error + else: + return None + + @property + def had_tsig(self): + return bool(self.tsig) + + @staticmethod + def _make_opt(flags=0, payload=1280, options=None): + opt = dns.rdtypes.ANY.OPT.OPT(payload, dns.rdatatype.OPT, + options or ()) + return dns.rrset.from_rdata(dns.name.root, int(flags), opt) def use_edns(self, edns=0, ednsflags=0, payload=1280, request_payload=None, options=None): """Configure EDNS behavior. - @param edns: The EDNS level to use. Specifying None, False, or -1 - means 'do not use EDNS', and in this case the other parameters are - ignored. Specifying True is equivalent to specifying 0, i.e. 'use - EDNS0'. - @type edns: int or bool or None - @param ednsflags: EDNS flag values. - @type ednsflags: int - @param payload: The EDNS sender's payload field, which is the maximum - size of UDP datagram the sender can handle. - @type payload: int - @param request_payload: The EDNS payload size to use when sending - this message. If not specified, defaults to the value of payload. - @type request_payload: int or None - @param options: The EDNS options - @type options: None or list of dns.edns.Option objects - @see: RFC 2671 + + *edns*, an ``int``, is the EDNS level to use. Specifying + ``None``, ``False``, or ``-1`` means "do not use EDNS", and in this case + the other parameters are ignored. Specifying ``True`` is + equivalent to specifying 0, i.e. "use EDNS0". + + *ednsflags*, an ``int``, the EDNS flag values. + + *payload*, an ``int``, is the EDNS sender's payload field, which is the + maximum size of UDP datagram the sender can handle. I.e. how big + a response to this message can be. + + *request_payload*, an ``int``, is the EDNS payload size to use when + sending this message. If not specified, defaults to the value of + *payload*. + + *options*, a list of ``dns.edns.Option`` objects or ``None``, the EDNS + options. """ + if edns is None or edns is False: edns = -1 if edns is True: @@ -514,326 +586,395 @@ class Message(object): options = [] else: # make sure the EDNS version in ednsflags agrees with edns - ednsflags &= long(0xFF00FFFF) + ednsflags &= 0xFF00FFFF ednsflags |= (edns << 16) if options is None: options = [] - self.edns = edns - self.ednsflags = ednsflags - self.payload = payload - self.options = options + if edns >= 0: + self.opt = self._make_opt(ednsflags, payload, options) + else: + self.opt = None self.request_payload = request_payload + @property + def edns(self): + if self.opt: + return (self.ednsflags & 0xff0000) >> 16 + else: + return -1 + + @property + def ednsflags(self): + if self.opt: + return self.opt.ttl + else: + return 0 + + @ednsflags.setter + def ednsflags(self, v): + if self.opt: + self.opt.ttl = v + elif v: + self.opt = self._make_opt(v) + + @property + def payload(self): + if self.opt: + return self.opt[0].payload + else: + return 0 + + @property + def options(self): + if self.opt: + return self.opt[0].options + else: + return () + def want_dnssec(self, wanted=True): """Enable or disable 'DNSSEC desired' flag in requests. - @param wanted: Is DNSSEC desired? If True, EDNS is enabled if - required, and then the DO bit is set. If False, the DO bit is - cleared if EDNS is enabled. - @type wanted: bool + + *wanted*, a ``bool``. If ``True``, then DNSSEC data is + desired in the response, EDNS is enabled if required, and then + the DO bit is set. If ``False``, the DO bit is cleared if + EDNS is enabled. """ + if wanted: - if self.edns < 0: - self.use_edns() self.ednsflags |= dns.flags.DO - elif self.edns >= 0: + elif self.opt: self.ednsflags &= ~dns.flags.DO def rcode(self): """Return the rcode. - @rtype: int + + Returns an ``int``. """ return dns.rcode.from_flags(self.flags, self.ednsflags) def set_rcode(self, rcode): """Set the rcode. - @param rcode: the rcode - @type rcode: int + + *rcode*, an ``int``, is the rcode to set. """ (value, evalue) = dns.rcode.to_flags(rcode) self.flags &= 0xFFF0 self.flags |= value - self.ednsflags &= long(0x00FFFFFF) + self.ednsflags &= 0x00FFFFFF self.ednsflags |= evalue - if self.ednsflags != 0 and self.edns < 0: - self.edns = 0 def opcode(self): """Return the opcode. - @rtype: int + + Returns an ``int``. """ return dns.opcode.from_flags(self.flags) def set_opcode(self, opcode): """Set the opcode. - @param opcode: the opcode - @type opcode: int + + *opcode*, an ``int``, is the opcode to set. """ self.flags &= 0x87FF self.flags |= dns.opcode.to_flags(opcode) + def _get_one_rr_per_rrset(self, value): + # What the caller picked is fine. + return value -class _WireReader(object): + def _parse_rr_header(self, section, name, rdclass, rdtype): + return (rdclass, rdtype, None, False) + + def _parse_special_rr_header(self, section, count, position, + name, rdclass, rdtype): + if rdtype == dns.rdatatype.OPT: + if section != MessageSection.ADDITIONAL or self.opt or \ + name != dns.name.root: + raise BadEDNS + elif rdtype == dns.rdatatype.TSIG: + if section != MessageSection.ADDITIONAL or \ + rdclass != dns.rdatatype.ANY or \ + position != count - 1: + raise BadTSIG + return (rdclass, rdtype, None, False) + + +class QueryMessage(Message): + pass + + +def _maybe_import_update(): + # We avoid circular imports by doing this here. We do it in another + # function as doing it in _message_factory_from_opcode() makes "dns" + # a local symbol, and the first line fails :) + import dns.update # noqa: F401 + + +def _message_factory_from_opcode(opcode): + if opcode == dns.opcode.QUERY: + return QueryMessage + elif opcode == dns.opcode.UPDATE: + _maybe_import_update() + return dns.update.UpdateMessage + else: + return Message + + +class _WireReader: """Wire format reader. - @ivar wire: the wire-format message. - @type wire: string - @ivar message: The message object being built - @type message: dns.message.Message object - @ivar current: When building a message object from wire format, this - variable contains the offset from the beginning of wire of the next octet - to be read. - @type current: int - @ivar updating: Is the message a dynamic update? - @type updating: bool - @ivar one_rr_per_rrset: Put each RR into its own RRset? - @type one_rr_per_rrset: bool - @ivar ignore_trailing: Ignore trailing junk at end of request? - @type ignore_trailing: bool - @ivar zone_rdclass: The class of the zone in messages which are + parser: the binary parser + message: The message object being built + initialize_message: Callback to set message parsing options + question_only: Are we only reading the question? + one_rr_per_rrset: Put each RR into its own RRset? + keyring: TSIG keyring + ignore_trailing: Ignore trailing junk at end of request? + multi: Is this message part of a multi-message sequence? DNS dynamic updates. - @type zone_rdclass: int """ - def __init__(self, wire, message, question_only=False, - one_rr_per_rrset=False, ignore_trailing=False): - self.wire = dns.wiredata.maybe_wrap(wire) - self.message = message - self.current = 0 - self.updating = False - self.zone_rdclass = dns.rdataclass.IN + def __init__(self, wire, initialize_message, question_only=False, + one_rr_per_rrset=False, ignore_trailing=False, + keyring=None, multi=False): + self.parser = dns.wire.Parser(wire) + self.message = None + self.initialize_message = initialize_message self.question_only = question_only self.one_rr_per_rrset = one_rr_per_rrset self.ignore_trailing = ignore_trailing + self.keyring = keyring + self.multi = multi - def _get_question(self, qcount): - """Read the next I{qcount} records from the wire data and add them to + def _get_question(self, section_number, qcount): + """Read the next *qcount* records from the wire data and add them to the question section. - @param qcount: the number of questions in the message - @type qcount: int""" + """ - if self.updating and qcount > 1: - raise dns.exception.FormError + section = self.message.sections[section_number] + for i in range(qcount): + qname = self.parser.get_name(self.message.origin) + (rdtype, rdclass) = self.parser.get_struct('!HH') + (rdclass, rdtype, _, _) = \ + self.message._parse_rr_header(section_number, qname, rdclass, + rdtype) + self.message.find_rrset(section, qname, rdclass, rdtype, + create=True, force_unique=True) - for i in xrange(0, qcount): - (qname, used) = dns.name.from_wire(self.wire, self.current) - if self.message.origin is not None: - qname = qname.relativize(self.message.origin) - self.current = self.current + used - (rdtype, rdclass) = \ - struct.unpack('!HH', - self.wire[self.current:self.current + 4]) - self.current = self.current + 4 - self.message.find_rrset(self.message.question, qname, - rdclass, rdtype, create=True, - force_unique=True) - if self.updating: - self.zone_rdclass = rdclass - - def _get_section(self, section, count): + def _get_section(self, section_number, count): """Read the next I{count} records from the wire data and add them to the specified section. - @param section: the section of the message to which to add records - @type section: list of dns.rrset.RRset objects - @param count: the number of records to read - @type count: int""" - if self.updating or self.one_rr_per_rrset: - force_unique = True - else: - force_unique = False - seen_opt = False - for i in xrange(0, count): - rr_start = self.current - (name, used) = dns.name.from_wire(self.wire, self.current) - absolute_name = name + section: the section of the message to which to add records + count: the number of records to read + """ + + section = self.message.sections[section_number] + force_unique = self.one_rr_per_rrset + for i in range(count): + rr_start = self.parser.current + absolute_name = self.parser.get_name() if self.message.origin is not None: - name = name.relativize(self.message.origin) - self.current = self.current + used - (rdtype, rdclass, ttl, rdlen) = \ - struct.unpack('!HHIH', - self.wire[self.current:self.current + 10]) - self.current = self.current + 10 + name = absolute_name.relativize(self.message.origin) + else: + name = absolute_name + (rdtype, rdclass, ttl, rdlen) = self.parser.get_struct('!HHIH') + if rdtype in (dns.rdatatype.OPT, dns.rdatatype.TSIG): + (rdclass, rdtype, deleting, empty) = \ + self.message._parse_special_rr_header(section_number, + count, i, name, + rdclass, rdtype) + else: + (rdclass, rdtype, deleting, empty) = \ + self.message._parse_rr_header(section_number, + name, rdclass, rdtype) + if empty: + if rdlen > 0: + raise dns.exception.FormError + rd = None + covers = dns.rdatatype.NONE + else: + with self.parser.restrict_to(rdlen): + rd = dns.rdata.from_wire_parser(rdclass, rdtype, + self.parser, + self.message.origin) + covers = rd.covers() + if self.message.xfr and rdtype == dns.rdatatype.SOA: + force_unique = True if rdtype == dns.rdatatype.OPT: - if section is not self.message.additional or seen_opt: - raise BadEDNS - self.message.payload = rdclass - self.message.ednsflags = ttl - self.message.edns = (ttl & 0xff0000) >> 16 - self.message.options = [] - current = self.current - optslen = rdlen - while optslen > 0: - (otype, olen) = \ - struct.unpack('!HH', - self.wire[current:current + 4]) - current = current + 4 - opt = dns.edns.option_from_wire( - otype, self.wire, current, olen) - self.message.options.append(opt) - current = current + olen - optslen = optslen - 4 - olen - seen_opt = True + self.message.opt = dns.rrset.from_rdata(name, ttl, rd) elif rdtype == dns.rdatatype.TSIG: - if not (section is self.message.additional and - i == (count - 1)): - raise BadTSIG - if self.message.keyring is None: + if self.keyring is None: raise UnknownTSIGKey('got signed message without keyring') - secret = self.message.keyring.get(absolute_name) - if secret is None: + if isinstance(self.keyring, dict): + key = self.keyring.get(absolute_name) + if isinstance(key, bytes): + key = dns.tsig.Key(absolute_name, key, rd.algorithm) + else: + key = self.keyring + if key is None: raise UnknownTSIGKey("key '%s' unknown" % name) - self.message.keyname = absolute_name - (self.message.keyalgorithm, self.message.mac) = \ - dns.tsig.get_algorithm_and_mac(self.wire, self.current, - rdlen) + self.message.keyring = key self.message.tsig_ctx = \ - dns.tsig.validate(self.wire, + dns.tsig.validate(self.parser.wire, + key, absolute_name, - secret, + rd, int(time.time()), self.message.request_mac, rr_start, - self.current, - rdlen, self.message.tsig_ctx, - self.message.multi, - self.message.first) - self.message.had_tsig = True + self.multi) + self.message.tsig = dns.rrset.from_rdata(absolute_name, 0, rd) else: - if ttl < 0: - ttl = 0 - if self.updating and \ - (rdclass == dns.rdataclass.ANY or - rdclass == dns.rdataclass.NONE): - deleting = rdclass - rdclass = self.zone_rdclass - else: - deleting = None - if deleting == dns.rdataclass.ANY or \ - (deleting == dns.rdataclass.NONE and - section is self.message.answer): - covers = dns.rdatatype.NONE - rd = None - else: - rd = dns.rdata.from_wire(rdclass, rdtype, self.wire, - self.current, rdlen, - self.message.origin) - covers = rd.covers() - if self.message.xfr and rdtype == dns.rdatatype.SOA: - force_unique = True rrset = self.message.find_rrset(section, name, rdclass, rdtype, covers, - deleting, True, force_unique) + deleting, True, + force_unique) if rd is not None: + if ttl > 0x7fffffff: + ttl = 0 rrset.add(rd, ttl) - self.current = self.current + rdlen def read(self): """Read a wire format DNS message and build a dns.message.Message object.""" - l = len(self.wire) - if l < 12: + if self.parser.remaining() < 12: raise ShortHeader - (self.message.id, self.message.flags, qcount, ancount, - aucount, adcount) = struct.unpack('!HHHHHH', self.wire[:12]) - self.current = 12 - if dns.opcode.is_update(self.message.flags): - self.updating = True - self._get_question(qcount) + (id, flags, qcount, ancount, aucount, adcount) = \ + self.parser.get_struct('!HHHHHH') + factory = _message_factory_from_opcode(dns.opcode.from_flags(flags)) + self.message = factory(id=id) + self.message.flags = flags + self.initialize_message(self.message) + self.one_rr_per_rrset = \ + self.message._get_one_rr_per_rrset(self.one_rr_per_rrset) + self._get_question(MessageSection.QUESTION, qcount) if self.question_only: return - self._get_section(self.message.answer, ancount) - self._get_section(self.message.authority, aucount) - self._get_section(self.message.additional, adcount) - if not self.ignore_trailing and self.current != l: + self._get_section(MessageSection.ANSWER, ancount) + self._get_section(MessageSection.AUTHORITY, aucount) + self._get_section(MessageSection.ADDITIONAL, adcount) + if not self.ignore_trailing and self.parser.remaining() != 0: raise TrailingJunk - if self.message.multi and self.message.tsig_ctx and \ - not self.message.had_tsig: - self.message.tsig_ctx.update(self.wire) + if self.multi and self.message.tsig_ctx and not self.message.had_tsig: + self.message.tsig_ctx.update(self.parser.wire) + return self.message -def from_wire(wire, keyring=None, request_mac='', xfr=False, origin=None, - tsig_ctx=None, multi=False, first=True, +def from_wire(wire, keyring=None, request_mac=b'', xfr=False, origin=None, + tsig_ctx=None, multi=False, question_only=False, one_rr_per_rrset=False, - ignore_trailing=False): + ignore_trailing=False, raise_on_truncation=False): """Convert a DNS wire format message into a message object. - @param keyring: The keyring to use if the message is signed. - @type keyring: dict - @param request_mac: If the message is a response to a TSIG-signed request, - I{request_mac} should be set to the MAC of that request. - @type request_mac: string - @param xfr: Is this message part of a zone transfer? - @type xfr: bool - @param origin: If the message is part of a zone transfer, I{origin} - should be the origin name of the zone. - @type origin: dns.name.Name object - @param tsig_ctx: The ongoing TSIG context, used when validating zone - transfers. - @type tsig_ctx: hmac.HMAC object - @param multi: Is this message part of a multiple message sequence? - @type multi: bool - @param first: Is this message standalone, or the first of a multi - message sequence? - @type first: bool - @param question_only: Read only up to the end of the question section? - @type question_only: bool - @param one_rr_per_rrset: Put each RR into its own RRset - @type one_rr_per_rrset: bool - @param ignore_trailing: Ignore trailing junk at end of request? - @type ignore_trailing: bool - @raises ShortHeader: The message is less than 12 octets long. - @raises TrailingJunk: There were octets in the message past the end - of the proper DNS message. - @raises BadEDNS: An OPT record was in the wrong section, or occurred more - than once. - @raises BadTSIG: A TSIG record was not the last record of the additional - data section. - @rtype: dns.message.Message object""" + *keyring*, a ``dns.tsig.Key`` or ``dict``, the key or keyring to use + if the message is signed. - m = Message(id=0) - m.keyring = keyring - m.request_mac = request_mac - m.xfr = xfr - m.origin = origin - m.tsig_ctx = tsig_ctx - m.multi = multi - m.first = first + *request_mac*, a ``bytes``. If the message is a response to a + TSIG-signed request, *request_mac* should be set to the MAC of + that request. - reader = _WireReader(wire, m, question_only, one_rr_per_rrset, - ignore_trailing) - reader.read() + *xfr*, a ``bool``, should be set to ``True`` if this message is part of + a zone transfer. + + *origin*, a ``dns.name.Name`` or ``None``. If the message is part + of a zone transfer, *origin* should be the origin name of the + zone. If not ``None``, names will be relativized to the origin. + + *tsig_ctx*, a ``hmac.HMAC`` object, the ongoing TSIG context, used + when validating zone transfers. + + *multi*, a ``bool``, should be set to ``True`` if this message is + part of a multiple message sequence. + + *question_only*, a ``bool``. If ``True``, read only up to + the end of the question section. + + *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its + own RRset. + + *ignore_trailing*, a ``bool``. If ``True``, ignore trailing + junk at end of the message. + + *raise_on_truncation*, a ``bool``. If ``True``, raise an exception if + the TC bit is set. + + Raises ``dns.message.ShortHeader`` if the message is less than 12 octets + long. + + Raises ``dns.message.TrailingJunk`` if there were octets in the message + past the end of the proper DNS message, and *ignore_trailing* is ``False``. + + Raises ``dns.message.BadEDNS`` if an OPT record was in the + wrong section, or occurred more than once. + + Raises ``dns.message.BadTSIG`` if a TSIG record was not the last + record of the additional data section. + + Raises ``dns.message.Truncated`` if the TC flag is set and + *raise_on_truncation* is ``True``. + + Returns a ``dns.message.Message``. + """ + + def initialize_message(message): + message.request_mac = request_mac + message.xfr = xfr + message.origin = origin + message.tsig_ctx = tsig_ctx + + reader = _WireReader(wire, initialize_message, question_only, + one_rr_per_rrset, ignore_trailing, keyring, multi) + try: + m = reader.read() + except dns.exception.FormError: + if reader.message and (reader.message.flags & dns.flags.TC) and \ + raise_on_truncation: + raise Truncated(message=reader.message) + else: + raise + # Reading a truncated message might not have any errors, so we + # have to do this check here too. + if m.flags & dns.flags.TC and raise_on_truncation: + raise Truncated(message=m) return m -class _TextReader(object): +class _TextReader: """Text format reader. - @ivar tok: the tokenizer - @type tok: dns.tokenizer.Tokenizer object - @ivar message: The message object being built - @type message: dns.message.Message object - @ivar updating: Is the message a dynamic update? - @type updating: bool - @ivar zone_rdclass: The class of the zone in messages which are + tok: the tokenizer. + message: The message object being built. DNS dynamic updates. - @type zone_rdclass: int - @ivar last_name: The most recently read name when building a message object - from text format. - @type last_name: dns.name.Name object + last_name: The most recently read name when building a message object. + one_rr_per_rrset: Put each RR into its own RRset? + origin: The origin for relative names + relativize: relativize names? + relativize_to: the origin to relativize to. """ - def __init__(self, text, message): - self.message = message - self.tok = dns.tokenizer.Tokenizer(text) + def __init__(self, text, idna_codec, one_rr_per_rrset=False, + origin=None, relativize=True, relativize_to=None): + self.message = None + self.tok = dns.tokenizer.Tokenizer(text, idna_codec=idna_codec) self.last_name = None - self.zone_rdclass = dns.rdataclass.IN - self.updating = False + self.one_rr_per_rrset = one_rr_per_rrset + self.origin = origin + self.relativize = relativize + self.relativize_to = relativize_to + self.id = None + self.edns = -1 + self.ednsflags = 0 + self.payload = None + self.rcode = None + self.opcode = dns.opcode.QUERY + self.flags = 0 def _header_line(self, section): """Process one line from the text format header section.""" @@ -841,52 +982,51 @@ class _TextReader(object): token = self.tok.get() what = token.value if what == 'id': - self.message.id = self.tok.get_int() + self.id = self.tok.get_int() elif what == 'flags': while True: token = self.tok.get() if not token.is_identifier(): self.tok.unget(token) break - self.message.flags = self.message.flags | \ - dns.flags.from_text(token.value) - if dns.opcode.is_update(self.message.flags): - self.updating = True + self.flags = self.flags | dns.flags.from_text(token.value) elif what == 'edns': - self.message.edns = self.tok.get_int() - self.message.ednsflags = self.message.ednsflags | \ - (self.message.edns << 16) + self.edns = self.tok.get_int() + self.ednsflags = self.ednsflags | (self.edns << 16) elif what == 'eflags': - if self.message.edns < 0: - self.message.edns = 0 + if self.edns < 0: + self.edns = 0 while True: token = self.tok.get() if not token.is_identifier(): self.tok.unget(token) break - self.message.ednsflags = self.message.ednsflags | \ + self.ednsflags = self.ednsflags | \ dns.flags.edns_from_text(token.value) elif what == 'payload': - self.message.payload = self.tok.get_int() - if self.message.edns < 0: - self.message.edns = 0 + self.payload = self.tok.get_int() + if self.edns < 0: + self.edns = 0 elif what == 'opcode': text = self.tok.get_string() - self.message.flags = self.message.flags | \ - dns.opcode.to_flags(dns.opcode.from_text(text)) + self.opcode = dns.opcode.from_text(text) + self.flags = self.flags | dns.opcode.to_flags(self.opcode) elif what == 'rcode': text = self.tok.get_string() - self.message.set_rcode(dns.rcode.from_text(text)) + self.rcode = dns.rcode.from_text(text) else: raise UnknownHeaderField self.tok.get_eol() - def _question_line(self, section): + def _question_line(self, section_number): """Process one line from the text format question section.""" + section = self.message.sections[section_number] token = self.tok.get(want_leading=True) if not token.is_whitespace(): - self.last_name = dns.name.from_text(token.value, None) + self.last_name = self.tok.as_name(token, self.message.origin, + self.relativize, + self.relativize_to) name = self.last_name token = self.tok.get() if not token.is_identifier(): @@ -899,27 +1039,28 @@ class _TextReader(object): raise dns.exception.SyntaxError except dns.exception.SyntaxError: raise dns.exception.SyntaxError - except: + except Exception: rdclass = dns.rdataclass.IN # Type rdtype = dns.rdatatype.from_text(token.value) - self.message.find_rrset(self.message.question, name, - rdclass, rdtype, create=True, + (rdclass, rdtype, _, _) = \ + self.message._parse_rr_header(section_number, name, rdclass, rdtype) + self.message.find_rrset(section, name, rdclass, rdtype, create=True, force_unique=True) - if self.updating: - self.zone_rdclass = rdclass self.tok.get_eol() - def _rr_line(self, section): + def _rr_line(self, section_number): """Process one line from the text format answer, authority, or additional data sections. """ - deleting = None + section = self.message.sections[section_number] # Name token = self.tok.get(want_leading=True) if not token.is_whitespace(): - self.last_name = dns.name.from_text(token.value, None) + self.last_name = self.tok.as_name(token, self.message.origin, + self.relativize, + self.relativize_to) name = self.last_name token = self.tok.get() if not token.is_identifier(): @@ -932,7 +1073,7 @@ class _TextReader(object): raise dns.exception.SyntaxError except dns.exception.SyntaxError: raise dns.exception.SyntaxError - except: + except Exception: ttl = 0 # Class try: @@ -940,35 +1081,50 @@ class _TextReader(object): token = self.tok.get() if not token.is_identifier(): raise dns.exception.SyntaxError - if rdclass == dns.rdataclass.ANY or rdclass == dns.rdataclass.NONE: - deleting = rdclass - rdclass = self.zone_rdclass except dns.exception.SyntaxError: raise dns.exception.SyntaxError - except: + except Exception: rdclass = dns.rdataclass.IN # Type rdtype = dns.rdatatype.from_text(token.value) + (rdclass, rdtype, deleting, empty) = \ + self.message._parse_rr_header(section_number, name, rdclass, rdtype) token = self.tok.get() + if empty and not token.is_eol_or_eof(): + raise dns.exception.SyntaxError if not token.is_eol_or_eof(): self.tok.unget(token) - rd = dns.rdata.from_text(rdclass, rdtype, self.tok, None) + rd = dns.rdata.from_text(rdclass, rdtype, self.tok, + self.message.origin, self.relativize, + self.relativize_to) covers = rd.covers() else: rd = None covers = dns.rdatatype.NONE rrset = self.message.find_rrset(section, name, rdclass, rdtype, covers, - deleting, True, self.updating) + deleting, True, self.one_rr_per_rrset) if rd is not None: rrset.add(rd, ttl) + def _make_message(self): + factory = _message_factory_from_opcode(self.opcode) + message = factory(id=self.id) + message.flags = self.flags + if self.edns >= 0: + message.use_edns(self.edns, self.ednsflags, self.payload) + if self.rcode: + message.set_rcode(self.rcode) + if self.origin: + message.origin = self.origin + return message + def read(self): """Read a text format DNS message and build a dns.message.Message object.""" line_method = self._header_line - section = None + section_number = None while 1: token = self.tok.get(True, True) if token.is_eol_or_eof(): @@ -977,74 +1133,109 @@ class _TextReader(object): u = token.value.upper() if u == 'HEADER': line_method = self._header_line - elif u == 'QUESTION' or u == 'ZONE': - line_method = self._question_line - section = self.message.question - elif u == 'ANSWER' or u == 'PREREQ': - line_method = self._rr_line - section = self.message.answer - elif u == 'AUTHORITY' or u == 'UPDATE': - line_method = self._rr_line - section = self.message.authority - elif u == 'ADDITIONAL': - line_method = self._rr_line - section = self.message.additional + + if self.message: + message = self.message + else: + # If we don't have a message, create one with the current + # opcode, so that we know which section names to parse. + message = self._make_message() + try: + section_number = message._section_enum.from_text(u) + # We found a section name. If we don't have a message, + # use the one we just created. + if not self.message: + self.message = message + self.one_rr_per_rrset = \ + message._get_one_rr_per_rrset(self.one_rr_per_rrset) + if section_number == MessageSection.QUESTION: + line_method = self._question_line + else: + line_method = self._rr_line + except Exception: + # It's just a comment. + pass self.tok.get_eol() continue self.tok.unget(token) - line_method(section) + line_method(section_number) + if not self.message: + self.message = self._make_message() + return self.message -def from_text(text): +def from_text(text, idna_codec=None, one_rr_per_rrset=False, + origin=None, relativize=True, relativize_to=None): """Convert the text format message into a message object. - @param text: The text format message. - @type text: string - @raises UnknownHeaderField: - @raises dns.exception.SyntaxError: - @rtype: dns.message.Message object""" + The reader stops after reading the first blank line in the input to + facilitate reading multiple messages from a single file with + ``dns.message.from_file()``. + + *text*, a ``str``, the text format message. + + *idna_codec*, a ``dns.name.IDNACodec``, specifies the IDNA + encoder/decoder. If ``None``, the default IDNA 2003 encoder/decoder + is used. + + *one_rr_per_rrset*, a ``bool``. If ``True``, then each RR is put + into its own rrset. The default is ``False``. + + *origin*, a ``dns.name.Name`` (or ``None``), the + origin to use for relative names. + + *relativize*, a ``bool``. If true, name will be relativized. + + *relativize_to*, a ``dns.name.Name`` (or ``None``), the origin to use + when relativizing names. If not set, the *origin* value will be used. + + Raises ``dns.message.UnknownHeaderField`` if a header is unknown. + + Raises ``dns.exception.SyntaxError`` if the text is badly formed. + + Returns a ``dns.message.Message object`` + """ # 'text' can also be a file, but we don't publish that fact # since it's an implementation detail. The official file # interface is from_file(). - m = Message() - - reader = _TextReader(text, m) - reader.read() - - return m + reader = _TextReader(text, idna_codec, one_rr_per_rrset, origin, + relativize, relativize_to) + return reader.read() -def from_file(f): +def from_file(f, idna_codec=None, one_rr_per_rrset=False): """Read the next text format message from the specified file. - @param f: file or string. If I{f} is a string, it is treated - as the name of a file to open. - @raises UnknownHeaderField: - @raises dns.exception.SyntaxError: - @rtype: dns.message.Message object""" + Message blocks are separated by a single blank line. - str_type = string_types - opts = 'rU' + *f*, a ``file`` or ``str``. If *f* is text, it is treated as the + pathname of a file to open. - if isinstance(f, str_type): - f = open(f, opts) - want_close = True - else: - want_close = False + *idna_codec*, a ``dns.name.IDNACodec``, specifies the IDNA + encoder/decoder. If ``None``, the default IDNA 2003 encoder/decoder + is used. - try: - m = from_text(f) - finally: - if want_close: - f.close() - return m + *one_rr_per_rrset*, a ``bool``. If ``True``, then each RR is put + into its own rrset. The default is ``False``. + + Raises ``dns.message.UnknownHeaderField`` if a header is unknown. + + Raises ``dns.exception.SyntaxError`` if the text is badly formed. + + Returns a ``dns.message.Message object`` + """ + + with contextlib.ExitStack() as stack: + if isinstance(f, str): + f = stack.enter_context(open(f)) + return from_text(f, idna_codec, one_rr_per_rrset) def make_query(qname, rdtype, rdclass=dns.rdataclass.IN, use_edns=None, want_dnssec=False, ednsflags=None, payload=None, - request_payload=None, options=None): + request_payload=None, options=None, idna_codec=None): """Make a query message. The query name, type, and class may all be specified either @@ -1053,38 +1244,45 @@ def make_query(qname, rdtype, rdclass=dns.rdataclass.IN, use_edns=None, The query will have a randomly chosen query id, and its DNS flags will be set to dns.flags.RD. - @param qname: The query name. - @type qname: dns.name.Name object or string - @param rdtype: The desired rdata type. - @type rdtype: int - @param rdclass: The desired rdata class; the default is class IN. - @type rdclass: int - @param use_edns: The EDNS level to use; the default is None (no EDNS). + qname, a ``dns.name.Name`` or ``str``, the query name. + + *rdtype*, an ``int`` or ``str``, the desired rdata type. + + *rdclass*, an ``int`` or ``str``, the desired rdata class; the default + is class IN. + + *use_edns*, an ``int``, ``bool`` or ``None``. The EDNS level to use; the + default is None (no EDNS). See the description of dns.message.Message.use_edns() for the possible values for use_edns and their meanings. - @type use_edns: int or bool or None - @param want_dnssec: Should the query indicate that DNSSEC is desired? - @type want_dnssec: bool - @param ednsflags: EDNS flag values. - @type ednsflags: int - @param payload: The EDNS sender's payload field, which is the maximum - size of UDP datagram the sender can handle. - @type payload: int - @param request_payload: The EDNS payload size to use when sending - this message. If not specified, defaults to the value of payload. - @type request_payload: int or None - @param options: The EDNS options - @type options: None or list of dns.edns.Option objects - @see: RFC 2671 - @rtype: dns.message.Message object""" - if isinstance(qname, string_types): - qname = dns.name.from_text(qname) - if isinstance(rdtype, string_types): - rdtype = dns.rdatatype.from_text(rdtype) - if isinstance(rdclass, string_types): - rdclass = dns.rdataclass.from_text(rdclass) - m = Message() + *want_dnssec*, a ``bool``. If ``True``, DNSSEC data is desired. + + *ednsflags*, an ``int``, the EDNS flag values. + + *payload*, an ``int``, is the EDNS sender's payload field, which is the + maximum size of UDP datagram the sender can handle. I.e. how big + a response to this message can be. + + *request_payload*, an ``int``, is the EDNS payload size to use when + sending this message. If not specified, defaults to the value of + *payload*. + + *options*, a list of ``dns.edns.Option`` objects or ``None``, the EDNS + options. + + *idna_codec*, a ``dns.name.IDNACodec``, specifies the IDNA + encoder/decoder. If ``None``, the default IDNA 2003 encoder/decoder + is used. + + Returns a ``dns.message.QueryMessage`` + """ + + if isinstance(qname, str): + qname = dns.name.from_text(qname, idna_codec=idna_codec) + rdtype = dns.rdatatype.RdataType.make(rdtype) + rdclass = dns.rdataclass.RdataClass.make(rdclass) + m = QueryMessage() m.flags |= dns.flags.RD m.find_rrset(m.question, qname, rdclass, rdtype, create=True, force_unique=True) @@ -1115,7 +1313,7 @@ def make_query(qname, rdtype, rdclass=dns.rdataclass.IN, use_edns=None, def make_response(query, recursion_available=False, our_payload=8192, - fudge=300): + fudge=300, tsig_error=0): """Make a message which is a response for the specified query. The message returned is really a response skeleton; it has all of the infrastructure required of a response, but none of the @@ -1125,20 +1323,26 @@ def make_response(query, recursion_available=False, our_payload=8192, question section, so the query's question RRsets should not be changed. - @param query: the query to respond to - @type query: dns.message.Message object - @param recursion_available: should RA be set in the response? - @type recursion_available: bool - @param our_payload: payload size to advertise in EDNS responses; default - is 8192. - @type our_payload: int - @param fudge: TSIG time fudge; default is 300 seconds. - @type fudge: int - @rtype: dns.message.Message object""" + *query*, a ``dns.message.Message``, the query to respond to. + + *recursion_available*, a ``bool``, should RA be set in the response? + + *our_payload*, an ``int``, the payload size to advertise in EDNS + responses. + + *fudge*, an ``int``, the TSIG time fudge. + + *tsig_error*, an ``int``, the TSIG error. + + Returns a ``dns.message.Message`` object whose specific class is + appropriate for the query. For example, if query is a + ``dns.update.UpdateMessage``, response will be too. + """ if query.flags & dns.flags.QR: raise dns.exception.FormError('specified query message is not a query') - response = dns.message.Message(query.id) + factory = _message_factory_from_opcode(query.opcode()) + response = factory(id=query.id) response.flags = dns.flags.QR | (query.flags & dns.flags.RD) if recursion_available: response.flags |= dns.flags.RA @@ -1147,7 +1351,7 @@ def make_response(query, recursion_available=False, our_payload=8192, if query.edns >= 0: response.use_edns(0, 0, our_payload, query.payload) if query.had_tsig: - response.use_tsig(query.keyring, query.keyname, fudge, None, 0, '', - query.keyalgorithm) + response.use_tsig(query.keyring, query.keyname, fudge, None, + tsig_error, b'', query.keyalgorithm) response.request_mac = query.mac return response diff --git a/lib/dns/name.py b/lib/dns/name.py index 2a74694c..529ae7f9 100644 --- a/lib/dns/name.py +++ b/lib/dns/name.py @@ -1,4 +1,6 @@ -# Copyright (C) 2001-2007, 2009-2011 Nominum, Inc. +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2001-2017 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its # documentation for any purpose with or without fee is hereby granted, @@ -14,126 +16,262 @@ # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. """DNS Names. - -@var root: The DNS root name. -@type root: dns.name.Name object -@var empty: The empty DNS name. -@type empty: dns.name.Name object """ -from io import BytesIO -import struct -import sys import copy -import encodings.idna - -import dns.exception -import dns.wiredata - -from ._compat import long, binary_type, text_type, unichr +import struct +import encodings.idna # type: ignore try: - maxint = sys.maxint -except: - maxint = (1 << (8 * struct.calcsize("P"))) / 2 - 1 + import idna # type: ignore + have_idna_2008 = True +except ImportError: # pragma: no cover + have_idna_2008 = False +import dns.wire +import dns.exception + +# fullcompare() result values + +#: The compared names have no relationship to each other. NAMERELN_NONE = 0 +#: the first name is a superdomain of the second. NAMERELN_SUPERDOMAIN = 1 +#: The first name is a subdomain of the second. NAMERELN_SUBDOMAIN = 2 +#: The compared names are equal. NAMERELN_EQUAL = 3 +#: The compared names have a common ancestor. NAMERELN_COMMONANCESTOR = 4 class EmptyLabel(dns.exception.SyntaxError): - """A DNS label is empty.""" class BadEscape(dns.exception.SyntaxError): - """An escaped code in a text format of DNS name is invalid.""" class BadPointer(dns.exception.FormError): - """A DNS compression pointer points forward instead of backward.""" class BadLabelType(dns.exception.FormError): - """The label type in DNS name wire format is unknown.""" class NeedAbsoluteNameOrOrigin(dns.exception.DNSException): - """An attempt was made to convert a non-absolute name to wire when there was also a non-absolute (or missing) origin.""" class NameTooLong(dns.exception.FormError): - """A DNS name is > 255 octets long.""" class LabelTooLong(dns.exception.SyntaxError): - """A DNS label is > 63 octets long.""" class AbsoluteConcatenation(dns.exception.DNSException): - """An attempt was made to append anything other than the empty name to an absolute DNS name.""" class NoParent(dns.exception.DNSException): - """An attempt was made to get the parent of the root name or the empty name.""" -_escaped = bytearray(b'"().;\\@$') +class NoIDNA2008(dns.exception.DNSException): + """IDNA 2008 processing was requested but the idna module is not + available.""" -def _escapify(label, unicode_mode=False): +class IDNAException(dns.exception.DNSException): + """IDNA processing raised an exception.""" + + supp_kwargs = {'idna_exception'} + fmt = "IDNA processing exception: {idna_exception}" + + +class IDNACodec: + """Abstract base class for IDNA encoder/decoders.""" + + def __init__(self): + pass + + def is_idna(self, label): + return label.lower().startswith(b'xn--') + + def encode(self, label): + raise NotImplementedError # pragma: no cover + + def decode(self, label): + # We do not apply any IDNA policy on decode. + if self.is_idna(label): + try: + label = label[4:].decode('punycode') + except Exception as e: + raise IDNAException(idna_exception=e) + return _escapify(label) + + +class IDNA2003Codec(IDNACodec): + """IDNA 2003 encoder/decoder.""" + + def __init__(self, strict_decode=False): + """Initialize the IDNA 2003 encoder/decoder. + + *strict_decode* is a ``bool``. If `True`, then IDNA2003 checking + is done when decoding. This can cause failures if the name + was encoded with IDNA2008. The default is `False`. + """ + + super().__init__() + self.strict_decode = strict_decode + + def encode(self, label): + """Encode *label*.""" + + if label == '': + return b'' + try: + return encodings.idna.ToASCII(label) + except UnicodeError: + raise LabelTooLong + + def decode(self, label): + """Decode *label*.""" + if not self.strict_decode: + return super().decode(label) + if label == b'': + return '' + try: + return _escapify(encodings.idna.ToUnicode(label)) + except Exception as e: + raise IDNAException(idna_exception=e) + + +class IDNA2008Codec(IDNACodec): + """IDNA 2008 encoder/decoder. + """ + + def __init__(self, uts_46=False, transitional=False, + allow_pure_ascii=False, strict_decode=False): + """Initialize the IDNA 2008 encoder/decoder. + + *uts_46* is a ``bool``. If True, apply Unicode IDNA + compatibility processing as described in Unicode Technical + Standard #46 (http://unicode.org/reports/tr46/). + If False, do not apply the mapping. The default is False. + + *transitional* is a ``bool``: If True, use the + "transitional" mode described in Unicode Technical Standard + #46. The default is False. + + *allow_pure_ascii* is a ``bool``. If True, then a label which + consists of only ASCII characters is allowed. This is less + strict than regular IDNA 2008, but is also necessary for mixed + names, e.g. a name with starting with "_sip._tcp." and ending + in an IDN suffix which would otherwise be disallowed. The + default is False. + + *strict_decode* is a ``bool``: If True, then IDNA2008 checking + is done when decoding. This can cause failures if the name + was encoded with IDNA2003. The default is False. + """ + super().__init__() + self.uts_46 = uts_46 + self.transitional = transitional + self.allow_pure_ascii = allow_pure_ascii + self.strict_decode = strict_decode + + def encode(self, label): + if label == '': + return b'' + if self.allow_pure_ascii and is_all_ascii(label): + encoded = label.encode('ascii') + if len(encoded) > 63: + raise LabelTooLong + return encoded + if not have_idna_2008: + raise NoIDNA2008 + try: + if self.uts_46: + label = idna.uts46_remap(label, False, self.transitional) + return idna.alabel(label) + except idna.IDNAError as e: + if e.args[0] == 'Label too long': + raise LabelTooLong + else: + raise IDNAException(idna_exception=e) + + def decode(self, label): + if not self.strict_decode: + return super().decode(label) + if label == b'': + return '' + if not have_idna_2008: + raise NoIDNA2008 + try: + if self.uts_46: + label = idna.uts46_remap(label, False, False) + return _escapify(idna.ulabel(label)) + except (idna.IDNAError, UnicodeError) as e: + raise IDNAException(idna_exception=e) + +_escaped = b'"().;\\@$' +_escaped_text = '"().;\\@$' + +IDNA_2003_Practical = IDNA2003Codec(False) +IDNA_2003_Strict = IDNA2003Codec(True) +IDNA_2003 = IDNA_2003_Practical +IDNA_2008_Practical = IDNA2008Codec(True, False, True, False) +IDNA_2008_UTS_46 = IDNA2008Codec(True, False, False, False) +IDNA_2008_Strict = IDNA2008Codec(False, False, False, True) +IDNA_2008_Transitional = IDNA2008Codec(True, True, False, False) +IDNA_2008 = IDNA_2008_Practical + +def _escapify(label): """Escape the characters in label which need it. - @param unicode_mode: escapify only special and whitespace (<= 0x20) - characters @returns: the escaped string @rtype: string""" - if not unicode_mode: + if isinstance(label, bytes): + # Ordinary DNS label mode. Escape special characters and values + # < 0x20 or > 0x7f. text = '' - if isinstance(label, text_type): - label = label.encode() - for c in bytearray(label): - packed = struct.pack('!B', c).decode() + for c in label: if c in _escaped: - text += '\\' + packed + text += '\\' + chr(c) elif c > 0x20 and c < 0x7F: - text += packed + text += chr(c) else: text += '\\%03d' % c - return text.encode() + return text - text = u'' - if isinstance(label, binary_type): - label = label.decode() + # Unicode label mode. Escape only special characters and values < 0x20 + text = '' for c in label: - if c > u'\x20' and c < u'\x7f': - text += c + if c in _escaped_text: + text += '\\' + c + elif c <= '\x20': + text += '\\%03d' % ord(c) else: - if c >= u'\x7f': - text += c - else: - text += u'\\%03d' % c + text += c return text - def _validate_labels(labels): """Check for empty labels in the middle of a label sequence, labels that are too long, and for too many labels. - @raises NameTooLong: the name as a whole is too long - @raises EmptyLabel: a label is empty (i.e. the root label) and appears - in a position other than the end of the label sequence""" + + Raises ``dns.name.NameTooLong`` if the name as a whole is too long. + + Raises ``dns.name.EmptyLabel`` if a label is empty (i.e. the root + label) and appears in a position other than the end of the label + sequence + + """ l = len(labels) total = 0 @@ -153,38 +291,46 @@ def _validate_labels(labels): raise EmptyLabel -def _ensure_bytes(label): - if isinstance(label, binary_type): +def _maybe_convert_to_binary(label): + """If label is ``str``, convert it to ``bytes``. If it is already + ``bytes`` just return it. + + """ + + if isinstance(label, bytes): return label - if isinstance(label, text_type): + if isinstance(label, str): return label.encode() - raise ValueError + raise ValueError # pragma: no cover -class Name(object): +class Name: """A DNS name. - The dns.name.Name class represents a DNS name as a tuple of labels. - Instances of the class are immutable. - - @ivar labels: The tuple of labels in the name. Each label is a string of - up to 63 octets.""" + The dns.name.Name class represents a DNS name as a tuple of + labels. Each label is a ``bytes`` in DNS wire format. Instances + of the class are immutable. + """ __slots__ = ['labels'] def __init__(self, labels): - """Initialize a domain name from a list of labels. - @param labels: the labels - @type labels: any iterable whose values are strings + """*labels* is any iterable whose values are ``str`` or ``bytes``. """ - labels = [_ensure_bytes(x) for x in labels] - super(Name, self).__setattr__('labels', tuple(labels)) + + labels = [_maybe_convert_to_binary(x) for x in labels] + super().__setattr__('labels', tuple(labels)) _validate_labels(self.labels) def __setattr__(self, name, value): + # Names are immutable raise TypeError("object doesn't support attribute assignment") + def __delattr__(self, name): + # Names are immutable + raise TypeError("object doesn't support attribute deletion") + def __copy__(self): return Name(self.labels) @@ -192,52 +338,71 @@ class Name(object): return Name(copy.deepcopy(self.labels, memo)) def __getstate__(self): + # Names can be pickled return {'labels': self.labels} def __setstate__(self, state): - super(Name, self).__setattr__('labels', state['labels']) + super().__setattr__('labels', state['labels']) _validate_labels(self.labels) def is_absolute(self): """Is the most significant label of this name the root label? - @rtype: bool + + Returns a ``bool``. """ return len(self.labels) > 0 and self.labels[-1] == b'' def is_wild(self): """Is this name wild? (I.e. Is the least significant label '*'?) - @rtype: bool + + Returns a ``bool``. """ return len(self.labels) > 0 and self.labels[0] == b'*' def __hash__(self): """Return a case-insensitive hash of the name. - @rtype: int + + Returns an ``int``. """ - h = long(0) + h = 0 for label in self.labels: - for c in bytearray(label.lower()): + for c in label.lower(): h += (h << 3) + c - return int(h % maxint) + return h def fullcompare(self, other): - """Compare two names, returning a 3-tuple (relation, order, nlabels). + """Compare two names, returning a 3-tuple + ``(relation, order, nlabels)``. - I{relation} describes the relation ship between the names, - and is one of: dns.name.NAMERELN_NONE, - dns.name.NAMERELN_SUPERDOMAIN, dns.name.NAMERELN_SUBDOMAIN, - dns.name.NAMERELN_EQUAL, or dns.name.NAMERELN_COMMONANCESTOR + *relation* describes the relation ship between the names, + and is one of: ``dns.name.NAMERELN_NONE``, + ``dns.name.NAMERELN_SUPERDOMAIN``, ``dns.name.NAMERELN_SUBDOMAIN``, + ``dns.name.NAMERELN_EQUAL``, or ``dns.name.NAMERELN_COMMONANCESTOR``. - I{order} is < 0 if self < other, > 0 if self > other, and == - 0 if self == other. A relative name is always less than an + *order* is < 0 if *self* < *other*, > 0 if *self* > *other*, and == + 0 if *self* == *other*. A relative name is always less than an absolute name. If both names have the same relativity, then the DNSSEC order relation is used to order them. - I{nlabels} is the number of significant labels that the two names + *nlabels* is the number of significant labels that the two names have in common. + + Here are some examples. Names ending in "." are absolute names, + those not ending in "." are relative names. + + ============= ============= =========== ===== ======= + self other relation order nlabels + ============= ============= =========== ===== ======= + www.example. www.example. equal 0 3 + www.example. example. subdomain > 0 2 + example. www.example. superdomain < 0 2 + example1.com. example2.com. common anc. < 0 2 + example1 example2. none < 0 0 + example1. example2 none > 0 0 + ============= ============= =========== ===== ======= """ sabs = self.is_absolute() @@ -287,8 +452,10 @@ class Name(object): def is_subdomain(self, other): """Is self a subdomain of other? - The notion of subdomain includes equality. - @rtype: bool + Note that the notion of subdomain includes equality, e.g. + "dnpython.org" is a subdomain of itself. + + Returns a ``bool``. """ (nr, o, nl) = self.fullcompare(other) @@ -299,8 +466,10 @@ class Name(object): def is_superdomain(self, other): """Is self a superdomain of other? - The notion of subdomain includes equality. - @rtype: bool + Note that the notion of superdomain includes equality, e.g. + "dnpython.org" is a superdomain of itself. + + Returns a ``bool``. """ (nr, o, nl) = self.fullcompare(other) @@ -311,7 +480,6 @@ class Name(object): def canonicalize(self): """Return a name which is equal to the current name, but is in DNSSEC canonical form. - @rtype: dns.name.Name object """ return Name([x.lower() for x in self.labels]) @@ -356,96 +524,124 @@ class Name(object): return '' def __str__(self): - return self.to_text(False).decode() + return self.to_text(False) def to_text(self, omit_final_dot=False): - """Convert name to text format. - @param omit_final_dot: If True, don't emit the final dot (denoting the - root label) for absolute names. The default is False. - @rtype: string + """Convert name to DNS text format. + + *omit_final_dot* is a ``bool``. If True, don't emit the final + dot (denoting the root label) for absolute names. The default + is False. + + Returns a ``str``. """ if len(self.labels) == 0: - return b'@' + return '@' if len(self.labels) == 1 and self.labels[0] == b'': - return b'.' + return '.' if omit_final_dot and self.is_absolute(): l = self.labels[:-1] else: l = self.labels - s = b'.'.join(map(_escapify, l)) + s = '.'.join(map(_escapify, l)) return s - def to_unicode(self, omit_final_dot=False): + def to_unicode(self, omit_final_dot=False, idna_codec=None): """Convert name to Unicode text format. IDN ACE labels are converted to Unicode. - @param omit_final_dot: If True, don't emit the final dot (denoting the - root label) for absolute names. The default is False. - @rtype: string + *omit_final_dot* is a ``bool``. If True, don't emit the final + dot (denoting the root label) for absolute names. The default + is False. + *idna_codec* specifies the IDNA encoder/decoder. If None, the + dns.name.IDNA_2003_Practical encoder/decoder is used. + The IDNA_2003_Practical decoder does + not impose any policy, it just decodes punycode, so if you + don't want checking for compliance, you can use this decoder + for IDNA2008 as well. + + Returns a ``str``. """ if len(self.labels) == 0: - return u'@' - if len(self.labels) == 1 and self.labels[0] == '': - return u'.' + return '@' + if len(self.labels) == 1 and self.labels[0] == b'': + return '.' if omit_final_dot and self.is_absolute(): l = self.labels[:-1] else: l = self.labels - s = u'.'.join([_escapify(encodings.idna.ToUnicode(x), True) - for x in l]) - return s + if idna_codec is None: + idna_codec = IDNA_2003_Practical + return '.'.join([idna_codec.decode(x) for x in l]) def to_digestable(self, origin=None): """Convert name to a format suitable for digesting in hashes. - The name is canonicalized and converted to uncompressed wire format. + The name is canonicalized and converted to uncompressed wire + format. All names in wire format are absolute. If the name + is a relative name, then an origin must be supplied. - @param origin: If the name is relative and origin is not None, then - origin will be appended to it. - @type origin: dns.name.Name object - @raises NeedAbsoluteNameOrOrigin: All names in wire format are - absolute. If self is a relative name, then an origin must be supplied; - if it is missing, then this exception is raised - @rtype: string + *origin* is a ``dns.name.Name`` or ``None``. If the name is + relative and origin is not ``None``, then origin will be appended + to the name. + + Raises ``dns.name.NeedAbsoluteNameOrOrigin`` if the name is + relative and no origin was provided. + + Returns a ``bytes``. """ - if not self.is_absolute(): - if origin is None or not origin.is_absolute(): - raise NeedAbsoluteNameOrOrigin - labels = list(self.labels) - labels.extend(list(origin.labels)) - else: - labels = self.labels - dlabels = [struct.pack('!B%ds' % len(x), len(x), x.lower()) - for x in labels] - return b''.join(dlabels) + return self.to_wire(origin=origin, canonicalize=True) - def to_wire(self, file=None, compress=None, origin=None): + def to_wire(self, file=None, compress=None, origin=None, + canonicalize=False): """Convert name to wire format, possibly compressing it. - @param file: the file where the name is emitted (typically - a BytesIO file). If None, a string containing the wire name - will be returned. - @type file: file or None - @param compress: The compression table. If None (the default) names - will not be compressed. - @type compress: dict - @param origin: If the name is relative and origin is not None, then - origin will be appended to it. - @type origin: dns.name.Name object - @raises NeedAbsoluteNameOrOrigin: All names in wire format are - absolute. If self is a relative name, then an origin must be supplied; - if it is missing, then this exception is raised + *file* is the file where the name is emitted (typically an + io.BytesIO file). If ``None`` (the default), a ``bytes`` + containing the wire name will be returned. + + *compress*, a ``dict``, is the compression table to use. If + ``None`` (the default), names will not be compressed. Note that + the compression code assumes that compression offset 0 is the + start of *file*, and thus compression will not be correct + if this is not the case. + + *origin* is a ``dns.name.Name`` or ``None``. If the name is + relative and origin is not ``None``, then *origin* will be appended + to it. + + *canonicalize*, a ``bool``, indicates whether the name should + be canonicalized; that is, converted to a format suitable for + digesting in hashes. + + Raises ``dns.name.NeedAbsoluteNameOrOrigin`` if the name is + relative and no origin was provided. + + Returns a ``bytes`` or ``None``. """ if file is None: - file = BytesIO() - want_return = True - else: - want_return = False + out = bytearray() + for label in self.labels: + out.append(len(label)) + if canonicalize: + out += label.lower() + else: + out += label + if not self.is_absolute(): + if origin is None or not origin.is_absolute(): + raise NeedAbsoluteNameOrOrigin + for label in origin.labels: + out.append(len(label)) + if canonicalize: + out += label.lower() + else: + out += label + return bytes(out) if not self.is_absolute(): if origin is None or not origin.is_absolute(): @@ -475,13 +671,15 @@ class Name(object): l = len(label) file.write(struct.pack('!B', l)) if l > 0: - file.write(label) - if want_return: - return file.getvalue() + if canonicalize: + file.write(label.lower()) + else: + file.write(label) def __len__(self): """The length of the name (in labels). - @rtype: int + + Returns an ``int``. """ return len(self.labels) @@ -489,9 +687,6 @@ class Name(object): def __getitem__(self, index): return self.labels[index] - def __getslice__(self, start, stop): - return self.labels[start:stop] - def __add__(self, other): return self.concatenate(other) @@ -499,14 +694,14 @@ class Name(object): return self.relativize(other) def split(self, depth): - """Split a name into a prefix and suffix at depth. + """Split a name into a prefix and suffix names at the specified depth. - @param depth: the number of labels in the suffix - @type depth: int - @raises ValueError: the depth was not >= 0 and <= the length of the + *depth* is an ``int`` specifying the number of labels in the suffix + + Raises ``ValueError`` if *depth* was not >= 0 and <= the length of the name. - @returns: the tuple (prefix, suffix) - @rtype: tuple + + Returns the tuple ``(prefix, suffix)``. """ l = len(self.labels) @@ -521,9 +716,11 @@ class Name(object): def concatenate(self, other): """Return a new name which is the concatenation of self and other. - @rtype: dns.name.Name object - @raises AbsoluteConcatenation: self is absolute and other is - not the empty name + + Raises ``dns.name.AbsoluteConcatenation`` if the name is + absolute and *other* is not the empty name. + + Returns a ``dns.name.Name``. """ if self.is_absolute() and len(other) > 0: @@ -533,9 +730,14 @@ class Name(object): return Name(labels) def relativize(self, origin): - """If self is a subdomain of origin, return a new name which is self - relative to origin. Otherwise return self. - @rtype: dns.name.Name object + """If the name is a subdomain of *origin*, return a new name which is + the name relative to origin. Otherwise return the name. + + For example, relativizing ``www.dnspython.org.`` to origin + ``dnspython.org.`` returns the name ``www``. Relativizing ``example.`` + to origin ``dnspython.org.`` returns ``example.``. + + Returns a ``dns.name.Name``. """ if origin is not None and self.is_subdomain(origin): @@ -544,9 +746,14 @@ class Name(object): return self def derelativize(self, origin): - """If self is a relative name, return a new name which is the - concatenation of self and origin. Otherwise return self. - @rtype: dns.name.Name object + """If the name is a relative name, return a new name which is the + concatenation of the name and origin. Otherwise return the name. + + For example, derelativizing ``www`` to origin ``dnspython.org.`` + returns the name ``www.dnspython.org.``. Derelativizing ``example.`` + to origin ``dnspython.org.`` returns ``example.``. + + Returns a ``dns.name.Name``. """ if not self.is_absolute(): @@ -555,11 +762,14 @@ class Name(object): return self def choose_relativity(self, origin=None, relativize=True): - """Return a name with the relativity desired by the caller. If - origin is None, then self is returned. Otherwise, if - relativize is true the name is relativized, and if relativize is - false the name is derelativized. - @rtype: dns.name.Name object + """Return a name with the relativity desired by the caller. + + If *origin* is ``None``, then the name is returned. + Otherwise, if *relativize* is ``True`` the name is + relativized, and if *relativize* is ``False`` the name is + derelativized. + + Returns a ``dns.name.Name``. """ if origin: @@ -572,39 +782,58 @@ class Name(object): def parent(self): """Return the parent of the name. - @rtype: dns.name.Name object - @raises NoParent: the name is either the root name or the empty name, - and thus has no parent. + + For example, the parent of ``www.dnspython.org.`` is ``dnspython.org``. + + Raises ``dns.name.NoParent`` if the name is either the root name or the + empty name, and thus has no parent. + + Returns a ``dns.name.Name``. """ + if self == root or self == empty: raise NoParent return Name(self.labels[1:]) +#: The root name, '.' root = Name([b'']) + +#: The empty name. empty = Name([]) - -def from_unicode(text, origin=root): +def from_unicode(text, origin=root, idna_codec=None): """Convert unicode text into a Name object. - Labels are encoded in IDN ACE form. + Labels are encoded in IDN ACE form according to rules specified by + the IDNA codec. - @rtype: dns.name.Name object + *text*, a ``str``, is the text to convert into a name. + + *origin*, a ``dns.name.Name``, specifies the origin to + append to non-absolute names. The default is the root name. + + *idna_codec*, a ``dns.name.IDNACodec``, specifies the IDNA + encoder/decoder. If ``None``, the default IDNA 2003 encoder/decoder + is used. + + Returns a ``dns.name.Name``. """ - if not isinstance(text, text_type): + if not isinstance(text, str): raise ValueError("input to from_unicode() must be a unicode string") if not (origin is None or isinstance(origin, Name)): raise ValueError("origin must be a Name or None") labels = [] - label = u'' + label = '' escaping = False edigits = 0 total = 0 - if text == u'@': - text = u'' + if idna_codec is None: + idna_codec = IDNA_2003 + if text == '@': + text = '' if text: - if text == u'.': + if text in ['.', '\u3002', '\uff0e', '\uff61']: return Name([b'']) # no Unicode "u" on this constant! for c in text: if escaping: @@ -623,16 +852,13 @@ def from_unicode(text, origin=root): edigits += 1 if edigits == 3: escaping = False - label += unichr(total) - elif c in [u'.', u'\u3002', u'\uff0e', u'\uff61']: + label += chr(total) + elif c in ['.', '\u3002', '\uff0e', '\uff61']: if len(label) == 0: raise EmptyLabel - try: - labels.append(encodings.idna.ToASCII(label)) - except UnicodeError: - raise LabelTooLong - label = u'' - elif c == u'\\': + labels.append(idna_codec.encode(label)) + label = '' + elif c == '\\': escaping = True edigits = 0 total = 0 @@ -641,10 +867,7 @@ def from_unicode(text, origin=root): if escaping: raise BadEscape if len(label) > 0: - try: - labels.append(encodings.idna.ToASCII(label)) - except UnicodeError: - raise LabelTooLong + labels.append(idna_codec.encode(label)) else: labels.append(b'') @@ -652,15 +875,41 @@ def from_unicode(text, origin=root): labels.extend(list(origin.labels)) return Name(labels) +def is_all_ascii(text): + for c in text: + if ord(c) > 0x7f: + return False + return True -def from_text(text, origin=root): +def from_text(text, origin=root, idna_codec=None): """Convert text into a Name object. - @rtype: dns.name.Name object + + *text*, a ``str``, is the text to convert into a name. + + *origin*, a ``dns.name.Name``, specifies the origin to + append to non-absolute names. The default is the root name. + + *idna_codec*, a ``dns.name.IDNACodec``, specifies the IDNA + encoder/decoder. If ``None``, the default IDNA 2003 encoder/decoder + is used. + + Returns a ``dns.name.Name``. """ - if isinstance(text, text_type): - return from_unicode(text, origin) - if not isinstance(text, binary_type): + if isinstance(text, str): + if not is_all_ascii(text): + # Some codepoint in the input text is > 127, so IDNA applies. + return from_unicode(text, origin, idna_codec) + # The input is all ASCII, so treat this like an ordinary non-IDNA + # domain name. Note that "all ASCII" is about the input text, + # not the codepoints in the domain name. E.g. if text has value + # + # r'\150\151\152\153\154\155\156\157\158\159' + # + # then it's still "all ASCII" even though the domain name has + # codepoints > 127. + text = text.encode('ascii') + if not isinstance(text, bytes): raise ValueError("input to from_text() must be a string") if not (origin is None or isinstance(origin, Name)): raise ValueError("origin must be a Name or None") @@ -674,7 +923,7 @@ def from_text(text, origin=root): if text: if text == b'.': return Name([b'']) - for c in bytearray(text): + for c in text: byte_ = struct.pack('!B', c) if escaping: if edigits == 0: @@ -715,49 +964,60 @@ def from_text(text, origin=root): return Name(labels) -def from_wire(message, current): +def from_wire_parser(parser): """Convert possibly compressed wire format into a Name. - @param message: the entire DNS message - @type message: string - @param current: the offset of the beginning of the name from the start - of the message - @type current: int - @raises dns.name.BadPointer: a compression pointer did not point backwards - in the message - @raises dns.name.BadLabelType: an invalid label type was encountered. - @returns: a tuple consisting of the name that was read and the number - of bytes of the wire format message which were consumed reading it - @rtype: (dns.name.Name object, int) tuple + + *parser* is a dns.wire.Parser. + + Raises ``dns.name.BadPointer`` if a compression pointer did not + point backwards in the message. + + Raises ``dns.name.BadLabelType`` if an invalid label type was encountered. + + Returns a ``dns.name.Name`` """ - if not isinstance(message, binary_type): - raise ValueError("input to from_wire() must be a byte string") - message = dns.wiredata.maybe_wrap(message) labels = [] - biggest_pointer = current - hops = 0 - count = message[current] - current += 1 - cused = 1 - while count != 0: - if count < 64: - labels.append(message[current: current + count].unwrap()) - current += count - if hops == 0: - cused += count - elif count >= 192: - current = (count & 0x3f) * 256 + message[current] - if hops == 0: - cused += 1 - if current >= biggest_pointer: - raise BadPointer - biggest_pointer = current - hops += 1 - else: - raise BadLabelType - count = message[current] - current += 1 - if hops == 0: - cused += 1 - labels.append('') - return (Name(labels), cused) + biggest_pointer = parser.current + with parser.restore_furthest(): + count = parser.get_uint8() + while count != 0: + if count < 64: + labels.append(parser.get_bytes(count)) + elif count >= 192: + current = (count & 0x3f) * 256 + parser.get_uint8() + if current >= biggest_pointer: + raise BadPointer + biggest_pointer = current + parser.seek(current) + else: + raise BadLabelType + count = parser.get_uint8() + labels.append(b'') + return Name(labels) + + +def from_wire(message, current): + """Convert possibly compressed wire format into a Name. + + *message* is a ``bytes`` containing an entire DNS message in DNS + wire form. + + *current*, an ``int``, is the offset of the beginning of the name + from the start of the message + + Raises ``dns.name.BadPointer`` if a compression pointer did not + point backwards in the message. + + Raises ``dns.name.BadLabelType`` if an invalid label type was encountered. + + Returns a ``(dns.name.Name, int)`` tuple consisting of the name + that was read and the number of bytes of the wire format message + which were consumed reading it. + """ + + if not isinstance(message, bytes): + raise ValueError("input to from_wire() must be a byte string") + parser = dns.wire.Parser(message, current) + name = from_wire_parser(parser) + return (name, parser.current - current) diff --git a/lib/dns/namedict.py b/lib/dns/namedict.py index 58e40344..4c8f9abd 100644 --- a/lib/dns/namedict.py +++ b/lib/dns/namedict.py @@ -1,4 +1,6 @@ -# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2017 Nominum, Inc. # Copyright (C) 2016 Coresec Systems AB # # Permission to use, copy, modify, and distribute this software and its @@ -25,26 +27,26 @@ """DNS name dictionary""" -import collections +from collections.abc import MutableMapping + import dns.name -from ._compat import xrange -class NameDict(collections.MutableMapping): - +class NameDict(MutableMapping): """A dictionary whose keys are dns.name.Name objects. - @ivar max_depth: the maximum depth of the keys that have ever been - added to the dictionary. - @type max_depth: int - @ivar max_depth_items: the number of items of maximum depth - @type max_depth_items: int + + In addition to being like a regular Python dictionary, this + dictionary can also get the deepest match for a given key. """ __slots__ = ["max_depth", "max_depth_items", "__store"] def __init__(self, *args, **kwargs): + super().__init__() self.__store = dict() + #: the maximum depth of the keys that have ever been added self.max_depth = 0 + #: the number of items of maximum depth self.max_depth_items = 0 self.update(dict(*args, **kwargs)) @@ -65,8 +67,8 @@ class NameDict(collections.MutableMapping): self.__update_max_depth(key) def __delitem__(self, key): - value = self.__store.pop(key) - if len(value) == self.max_depth: + self.__store.pop(key) + if len(key) == self.max_depth: self.max_depth_items = self.max_depth_items - 1 if self.max_depth_items == 0: self.max_depth = 0 @@ -83,20 +85,22 @@ class NameDict(collections.MutableMapping): return key in self.__store def get_deepest_match(self, name): - """Find the deepest match to I{name} in the dictionary. + """Find the deepest match to *fname* in the dictionary. The deepest match is the longest name in the dictionary which is - a superdomain of I{name}. + a superdomain of *name*. Note that *superdomain* includes matching + *name* itself. - @param name: the name - @type name: dns.name.Name object - @rtype: (key, value) tuple + *name*, a ``dns.name.Name``, the name to find. + + Returns a ``(key, value)`` where *key* is the deepest + ``dns.name.Name``, and *value* is the value associated with *key*. """ depth = len(name) if depth > self.max_depth: depth = self.max_depth - for i in xrange(-depth, 0): + for i in range(-depth, 0): n = dns.name.Name(name[i:]) if n in self: return (n, self[n]) diff --git a/lib/dns/node.py b/lib/dns/node.py index 7c25060e..b7e21b54 100644 --- a/lib/dns/node.py +++ b/lib/dns/node.py @@ -1,4 +1,6 @@ -# Copyright (C) 2001-2007, 2009-2011 Nominum, Inc. +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2001-2017 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its # documentation for any purpose with or without fee is hereby granted, @@ -15,28 +17,21 @@ """DNS nodes. A node is a set of rdatasets.""" -from io import StringIO +import io import dns.rdataset import dns.rdatatype import dns.renderer -class Node(object): +class Node: - """A DNS node. - - A node is a set of rdatasets - - @ivar rdatasets: the node's rdatasets - @type rdatasets: list of dns.rdataset.Rdataset objects""" + """A Node is a set of rdatasets.""" __slots__ = ['rdatasets'] def __init__(self): - """Initialize a DNS node. - """ - + # the set of rdatasets, represented as a list. self.rdatasets = [] def to_text(self, name, **kw): @@ -44,26 +39,25 @@ class Node(object): Each rdataset at the node is printed. Any keyword arguments to this method are passed on to the rdataset's to_text() method. - @param name: the owner name of the rdatasets - @type name: dns.name.Name object - @rtype: string + + *name*, a ``dns.name.Name`` or ``str``, the owner name of the + rdatasets. + + Returns a ``str``. + """ - s = StringIO() + s = io.StringIO() for rds in self.rdatasets: if len(rds) > 0: s.write(rds.to_text(name, **kw)) - s.write(u'\n') + s.write('\n') return s.getvalue()[:-1] def __repr__(self): return '' def __eq__(self, other): - """Two nodes are equal if they have the same rdatasets. - - @rtype: bool - """ # # This is inefficient. Good thing we don't need to do it much. # @@ -89,24 +83,26 @@ class Node(object): """Find an rdataset matching the specified properties in the current node. - @param rdclass: The class of the rdataset - @type rdclass: int - @param rdtype: The type of the rdataset - @type rdtype: int - @param covers: The covered type. Usually this value is - dns.rdatatype.NONE, but if the rdtype is dns.rdatatype.SIG or - dns.rdatatype.RRSIG, then the covers value will be the rdata - type the SIG/RRSIG covers. The library treats the SIG and RRSIG - types as if they were a family of - types, e.g. RRSIG(A), RRSIG(NS), RRSIG(SOA). This makes RRSIGs much - easier to work with than if RRSIGs covering different rdata - types were aggregated into a single RRSIG rdataset. - @type covers: int - @param create: If True, create the rdataset if it is not found. - @type create: bool - @raises KeyError: An rdataset of the desired type and class does - not exist and I{create} is not True. - @rtype: dns.rdataset.Rdataset object + *rdclass*, an ``int``, the class of the rdataset. + + *rdtype*, an ``int``, the type of the rdataset. + + *covers*, an ``int`` or ``None``, the covered type. + Usually this value is ``dns.rdatatype.NONE``, but if the + rdtype is ``dns.rdatatype.SIG`` or ``dns.rdatatype.RRSIG``, + then the covers value will be the rdata type the SIG/RRSIG + covers. The library treats the SIG and RRSIG types as if they + were a family of types, e.g. RRSIG(A), RRSIG(NS), RRSIG(SOA). + This makes RRSIGs much easier to work with than if RRSIGs + covering different rdata types were aggregated into a single + RRSIG rdataset. + + *create*, a ``bool``. If True, create the rdataset if it is not found. + + Raises ``KeyError`` if an rdataset of the desired type and class does + not exist and *create* is not ``True``. + + Returns a ``dns.rdataset.Rdataset``. """ for rds in self.rdatasets: @@ -124,17 +120,24 @@ class Node(object): current node. None is returned if an rdataset of the specified type and - class does not exist and I{create} is not True. + class does not exist and *create* is not ``True``. - @param rdclass: The class of the rdataset - @type rdclass: int - @param rdtype: The type of the rdataset - @type rdtype: int - @param covers: The covered type. - @type covers: int - @param create: If True, create the rdataset if it is not found. - @type create: bool - @rtype: dns.rdataset.Rdataset object or None + *rdclass*, an ``int``, the class of the rdataset. + + *rdtype*, an ``int``, the type of the rdataset. + + *covers*, an ``int``, the covered type. Usually this value is + dns.rdatatype.NONE, but if the rdtype is dns.rdatatype.SIG or + dns.rdatatype.RRSIG, then the covers value will be the rdata + type the SIG/RRSIG covers. The library treats the SIG and RRSIG + types as if they were a family of + types, e.g. RRSIG(A), RRSIG(NS), RRSIG(SOA). This makes RRSIGs much + easier to work with than if RRSIGs covering different rdata + types were aggregated into a single RRSIG rdataset. + + *create*, a ``bool``. If True, create the rdataset if it is not found. + + Returns a ``dns.rdataset.Rdataset`` or ``None``. """ try: @@ -149,12 +152,11 @@ class Node(object): If a matching rdataset does not exist, it is not an error. - @param rdclass: The class of the rdataset - @type rdclass: int - @param rdtype: The type of the rdataset - @type rdtype: int - @param covers: The covered type. - @type covers: int + *rdclass*, an ``int``, the class of the rdataset. + + *rdtype*, an ``int``, the type of the rdataset. + + *covers*, an ``int``, the covered type. """ rds = self.get_rdataset(rdclass, rdtype, covers) @@ -164,11 +166,16 @@ class Node(object): def replace_rdataset(self, replacement): """Replace an rdataset. - It is not an error if there is no rdataset matching I{replacement}. + It is not an error if there is no rdataset matching *replacement*. - Ownership of the I{replacement} object is transferred to the node; - in other words, this method does not store a copy of I{replacement} - at the node, it stores I{replacement} itself. + Ownership of the *replacement* object is transferred to the node; + in other words, this method does not store a copy of *replacement* + at the node, it stores *replacement* itself. + + *replacement*, a ``dns.rdataset.Rdataset``. + + Raises ``ValueError`` if *replacement* is not a + ``dns.rdataset.Rdataset``. """ if not isinstance(replacement, dns.rdataset.Rdataset): diff --git a/lib/dns/opcode.py b/lib/dns/opcode.py index 2b2918e9..5a76326a 100644 --- a/lib/dns/opcode.py +++ b/lib/dns/opcode.py @@ -1,4 +1,6 @@ -# Copyright (C) 2001-2007, 2009-2011 Nominum, Inc. +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2001-2017 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its # documentation for any purpose with or without fee is hereby granted, @@ -15,58 +17,55 @@ """DNS Opcodes.""" +import dns.enum import dns.exception -QUERY = 0 -IQUERY = 1 -STATUS = 2 -NOTIFY = 4 -UPDATE = 5 +class Opcode(dns.enum.IntEnum): + #: Query + QUERY = 0 + #: Inverse Query (historical) + IQUERY = 1 + #: Server Status (unspecified and unimplemented anywhere) + STATUS = 2 + #: Notify + NOTIFY = 4 + #: Dynamic Update + UPDATE = 5 -_by_text = { - 'QUERY': QUERY, - 'IQUERY': IQUERY, - 'STATUS': STATUS, - 'NOTIFY': NOTIFY, - 'UPDATE': UPDATE -} + @classmethod + def _maximum(cls): + return 15 -# We construct the inverse mapping programmatically to ensure that we -# cannot make any mistakes (e.g. omissions, cut-and-paste errors) that -# would cause the mapping not to be true inverse. + @classmethod + def _unknown_exception_class(cls): + return UnknownOpcode -_by_value = dict((y, x) for x, y in _by_text.items()) +globals().update(Opcode.__members__) class UnknownOpcode(dns.exception.DNSException): - """An DNS opcode is unknown.""" def from_text(text): """Convert text into an opcode. - @param text: the textual opcode - @type text: string - @raises UnknownOpcode: the opcode is unknown - @rtype: int + *text*, a ``str``, the textual opcode + + Raises ``dns.opcode.UnknownOpcode`` if the opcode is unknown. + + Returns an ``int``. """ - if text.isdigit(): - value = int(text) - if value >= 0 and value <= 15: - return value - value = _by_text.get(text.upper()) - if value is None: - raise UnknownOpcode - return value + return Opcode.from_text(text) def from_flags(flags): """Extract an opcode from DNS message flags. - @param flags: int - @rtype: int + *flags*, an ``int``, the DNS flags. + + Returns an ``int``. """ return (flags & 0x7800) >> 11 @@ -75,7 +74,10 @@ def from_flags(flags): def to_flags(value): """Convert an opcode to a value suitable for ORing into DNS message flags. - @rtype: int + + *value*, an ``int``, the DNS opcode value. + + Returns an ``int``. """ return (value << 11) & 0x7800 @@ -84,26 +86,22 @@ def to_flags(value): def to_text(value): """Convert an opcode to text. - @param value: the opcdoe - @type value: int - @raises UnknownOpcode: the opcode is unknown - @rtype: string + *value*, an ``int`` the opcode value, + + Raises ``dns.opcode.UnknownOpcode`` if the opcode is unknown. + + Returns a ``str``. """ - text = _by_value.get(value) - if text is None: - text = str(value) - return text + return Opcode.to_text(value) def is_update(flags): - """True if the opcode in flags is UPDATE. + """Is the opcode in flags UPDATE? - @param flags: DNS flags - @type flags: int - @rtype: bool + *flags*, an ``int``, the DNS message flags. + + Returns a ``bool``. """ - if (from_flags(flags) == UPDATE): - return True - return False + return from_flags(flags) == Opcode.UPDATE diff --git a/lib/dns/py.typed b/lib/dns/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/lib/dns/query.py b/lib/dns/query.py index 35670983..7df565d8 100644 --- a/lib/dns/query.py +++ b/lib/dns/query.py @@ -1,4 +1,6 @@ -# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2017 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its # documentation for any purpose with or without fee is hereby granted, @@ -15,58 +17,91 @@ """Talk to a DNS server.""" -from __future__ import generators - +import contextlib import errno +import os import select import socket import struct -import sys import time +import base64 +import urllib.parse import dns.exception import dns.inet import dns.name import dns.message +import dns.rcode import dns.rdataclass import dns.rdatatype -from ._compat import long, string_types +import dns.serial -if sys.version_info > (3,): - select_error = OSError -else: - select_error = select.error +try: + import requests + from requests_toolbelt.adapters.source import SourceAddressAdapter + from requests_toolbelt.adapters.host_header_ssl import HostHeaderSSLAdapter + have_doh = True +except ImportError: # pragma: no cover + have_doh = False +try: + import ssl +except ImportError: # pragma: no cover + class ssl: # type: ignore + + class WantReadException(Exception): + pass + + class WantWriteException(Exception): + pass + + class SSLSocket: + pass + + def create_default_context(self, *args, **kwargs): + raise Exception('no ssl support') + +# Function used to create a socket. Can be overridden if needed in special +# situations. +socket_factory = socket.socket class UnexpectedSource(dns.exception.DNSException): - """A DNS query response came from an unexpected address or port.""" class BadResponse(dns.exception.FormError): - """A DNS query response does not respond to the question asked.""" -def _compute_expiration(timeout): - if timeout is None: - return None - else: - return time.time() + timeout +class TransferError(dns.exception.DNSException): + """A zone transfer response got a non-zero rcode.""" + def __init__(self, rcode): + message = 'Zone transfer error: %s' % dns.rcode.to_text(rcode) + super().__init__(message) + self.rcode = rcode + + +class NoDOH(dns.exception.DNSException): + """DNS over HTTPS (DOH) was requested but the requests module is not + available.""" + + +def _compute_times(timeout): + now = time.time() + if timeout is None: + return (now, None) + else: + return (now, now + timeout) + +# This module can use either poll() or select() as the "polling backend". +# +# A backend function takes an fd, bools for readability, writablity, and +# error detection, and a timeout. def _poll_for(fd, readable, writable, error, timeout): - """Poll polling backend. - @param fd: File descriptor - @type fd: int - @param readable: Whether to wait for readability - @type readable: bool - @param writable: Whether to wait for writability - @type writable: bool - @param timeout: Deadline timeout (expiration time, in seconds) - @type timeout: float - @return True on success, False on timeout - """ + """Poll polling backend.""" + event_mask = 0 if readable: event_mask |= select.POLLIN @@ -79,7 +114,7 @@ def _poll_for(fd, readable, writable, error, timeout): pollable.register(fd, event_mask) if timeout: - event_list = pollable.poll(long(timeout * 1000)) + event_list = pollable.poll(timeout * 1000) else: event_list = pollable.poll() @@ -87,17 +122,8 @@ def _poll_for(fd, readable, writable, error, timeout): def _select_for(fd, readable, writable, error, timeout): - """Select polling backend. - @param fd: File descriptor - @type fd: int - @param readable: Whether to wait for readability - @type readable: bool - @param writable: Whether to wait for writability - @type writable: bool - @param timeout: Deadline timeout (expiration time, in seconds) - @type timeout: float - @return True on success, False on timeout - """ + """Select polling backend.""" + rset, wset, xset = [], [], [] if readable: @@ -116,6 +142,10 @@ def _select_for(fd, readable, writable, error, timeout): def _wait_for(fd, readable, writable, error, expiration): + # Use the selected polling backend to wait for any of the specified + # events. An "expiration" absolute time is converted into a relative + # timeout. + done = False while not done: if expiration is None: @@ -125,18 +155,19 @@ def _wait_for(fd, readable, writable, error, expiration): if timeout <= 0.0: raise dns.exception.Timeout try: + if isinstance(fd, ssl.SSLSocket) and readable and fd.pending() > 0: + return True if not _polling_backend(fd, readable, writable, error, timeout): raise dns.exception.Timeout - except select_error as e: + except OSError as e: # pragma: no cover if e.args[0] != errno.EINTR: raise e done = True def _set_polling_backend(fn): - """ - Internal API. Do not use. - """ + # Internal API. Do not use. + global _polling_backend _polling_backend = fn @@ -147,7 +178,7 @@ if hasattr(select, 'poll'): # be more efficient for high values). _polling_backend = _poll_for else: - _polling_backend = _select_for + _polling_backend = _select_for # pragma: no cover def _wait_for_readable(s, expiration): @@ -162,101 +193,402 @@ def _addresses_equal(af, a1, a2): # Convert the first value of the tuple, which is a textual format # address into binary form, so that we are not confused by different # textual representations of the same address - n1 = dns.inet.inet_pton(af, a1[0]) - n2 = dns.inet.inet_pton(af, a2[0]) + try: + n1 = dns.inet.inet_pton(af, a1[0]) + n2 = dns.inet.inet_pton(af, a2[0]) + except dns.exception.SyntaxError: + return False return n1 == n2 and a1[1:] == a2[1:] -def _destination_and_source(af, where, port, source, source_port): +def _matches_destination(af, from_address, destination, ignore_unexpected): + # Check that from_address is appropriate for a response to a query + # sent to destination. + if not destination: + return True + if _addresses_equal(af, from_address, destination) or \ + (dns.inet.is_multicast(destination[0]) and + from_address[1:] == destination[1:]): + return True + elif ignore_unexpected: + return False + raise UnexpectedSource(f'got a response from {from_address} instead of ' + f'{destination}') + + +def _destination_and_source(where, port, source, source_port, + where_must_be_address=True): # Apply defaults and compute destination and source tuples # suitable for use in connect(), sendto(), or bind(). - if af is None: - try: - af = dns.inet.af_for_address(where) - except: - af = dns.inet.AF_INET - if af == dns.inet.AF_INET: - destination = (where, port) - if source is not None or source_port != 0: - if source is None: - source = '0.0.0.0' - source = (source, source_port) - elif af == dns.inet.AF_INET6: - destination = (where, port, 0, 0) - if source is not None or source_port != 0: - if source is None: - source = '::' - source = (source, source_port, 0, 0) + af = None + destination = None + try: + af = dns.inet.af_for_address(where) + destination = where + except Exception: + if where_must_be_address: + raise + # URLs are ok so eat the exception + if source: + saf = dns.inet.af_for_address(source) + if af: + # We know the destination af, so source had better agree! + if saf != af: + raise ValueError('different address families for source ' + + 'and destination') + else: + # We didn't know the destination af, but we know the source, + # so that's our af. + af = saf + if source_port and not source: + # Caller has specified a source_port but not an address, so we + # need to return a source, and we need to use the appropriate + # wildcard address as the address. + if af == socket.AF_INET: + source = '0.0.0.0' + elif af == socket.AF_INET6: + source = '::' + else: + raise ValueError('source_port specified but address family is ' + 'unknown') + # Convert high-level (address, port) tuples into low-level address + # tuples. + if destination: + destination = dns.inet.low_level_address_tuple((destination, port), af) + if source: + source = dns.inet.low_level_address_tuple((source, source_port), af) return (af, destination, source) - -def udp(q, where, timeout=None, port=53, af=None, source=None, source_port=0, - ignore_unexpected=False, one_rr_per_rrset=False): - """Return the response obtained after sending a query via UDP. - - @param q: the query - @type q: dns.message.Message - @param where: where to send the message - @type where: string containing an IPv4 or IPv6 address - @param timeout: The number of seconds to wait before the query times out. - If None, the default, wait forever. - @type timeout: float - @param port: The port to which to send the message. The default is 53. - @type port: int - @param af: the address family to use. The default is None, which - causes the address family to use to be inferred from the form of where. - If the inference attempt fails, AF_INET is used. - @type af: int - @rtype: dns.message.Message object - @param source: source address. The default is the wildcard address. - @type source: string - @param source_port: The port from which to send the message. - The default is 0. - @type source_port: int - @param ignore_unexpected: If True, ignore responses from unexpected - sources. The default is False. - @type ignore_unexpected: bool - @param one_rr_per_rrset: Put each RR into its own RRset - @type one_rr_per_rrset: bool - """ - - wire = q.to_wire() - (af, destination, source) = _destination_and_source(af, where, port, - source, source_port) - s = socket.socket(af, socket.SOCK_DGRAM, 0) - begin_time = None +def _make_socket(af, type, source, ssl_context=None, server_hostname=None): + s = socket_factory(af, type) try: - expiration = _compute_expiration(timeout) - s.setblocking(0) + s.setblocking(False) if source is not None: s.bind(source) - _wait_for_writable(s, expiration) - begin_time = time.time() - s.sendto(wire, destination) - while 1: - _wait_for_readable(s, expiration) - (wire, from_address) = s.recvfrom(65535) - if _addresses_equal(af, from_address, destination) or \ - (dns.inet.is_multicast(where) and - from_address[1:] == destination[1:]): - break - if not ignore_unexpected: - raise UnexpectedSource('got a response from ' - '%s instead of %s' % (from_address, - destination)) - finally: - if begin_time is None: - response_time = 0 + if ssl_context: + return ssl_context.wrap_socket(s, do_handshake_on_connect=False, + server_hostname=server_hostname) else: - response_time = time.time() - begin_time + return s + except Exception: s.close() - r = dns.message.from_wire(wire, keyring=q.keyring, request_mac=q.mac, - one_rr_per_rrset=one_rr_per_rrset) - r.time = response_time + raise + +def https(q, where, timeout=None, port=443, source=None, source_port=0, + one_rr_per_rrset=False, ignore_trailing=False, + session=None, path='/dns-query', post=True, + bootstrap_address=None, verify=True): + """Return the response obtained after sending a query via DNS-over-HTTPS. + + *q*, a ``dns.message.Message``, the query to send. + + *where*, a ``str``, the nameserver IP address or the full URL. If an IP + address is given, the URL will be constructed using the following schema: + https://:/. + + *timeout*, a ``float`` or ``None``, the number of seconds to + wait before the query times out. If ``None``, the default, wait forever. + + *port*, a ``int``, the port to send the query to. The default is 443. + + *source*, a ``str`` containing an IPv4 or IPv6 address, specifying + the source address. The default is the wildcard address. + + *source_port*, an ``int``, the port from which to send the message. + The default is 0. + + *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own + RRset. + + *ignore_trailing*, a ``bool``. If ``True``, ignore trailing + junk at end of the received message. + + *session*, a ``requests.session.Session``. If provided, the session to use + to send the queries. + + *path*, a ``str``. If *where* is an IP address, then *path* will be used to + construct the URL to send the DNS query to. + + *post*, a ``bool``. If ``True``, the default, POST method will be used. + + *bootstrap_address*, a ``str``, the IP address to use to bypass the + system's DNS resolver. + + *verify*, a ``str``, containing a path to a certificate file or directory. + + Returns a ``dns.message.Message``. + """ + + if not have_doh: + raise NoDOH # pragma: no cover + + wire = q.to_wire() + (af, destination, source) = _destination_and_source(where, port, + source, source_port, + False) + transport_adapter = None + headers = { + "accept": "application/dns-message" + } + try: + where_af = dns.inet.af_for_address(where) + if where_af == socket.AF_INET: + url = 'https://{}:{}{}'.format(where, port, path) + elif where_af == socket.AF_INET6: + url = 'https://[{}]:{}{}'.format(where, port, path) + except ValueError: + if bootstrap_address is not None: + split_url = urllib.parse.urlsplit(where) + headers['Host'] = split_url.hostname + url = where.replace(split_url.hostname, bootstrap_address) + transport_adapter = HostHeaderSSLAdapter() + else: + url = where + if source is not None: + # set source port and source address + transport_adapter = SourceAddressAdapter(source) + + with contextlib.ExitStack() as stack: + if not session: + session = stack.enter_context(requests.sessions.Session()) + + if transport_adapter: + session.mount(url, transport_adapter) + + # see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH + # GET and POST examples + if post: + headers.update({ + "content-type": "application/dns-message", + "content-length": str(len(wire)) + }) + response = session.post(url, headers=headers, data=wire, + timeout=timeout, verify=verify) + else: + wire = base64.urlsafe_b64encode(wire).rstrip(b"=") + response = session.get(url, headers=headers, + timeout=timeout, verify=verify, + params={"dns": wire}) + + # see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH + # status codes + if response.status_code < 200 or response.status_code > 299: + raise ValueError('{} responded with status code {}' + '\nResponse body: {}'.format(where, + response.status_code, + response.content)) + r = dns.message.from_wire(response.content, + keyring=q.keyring, + request_mac=q.request_mac, + one_rr_per_rrset=one_rr_per_rrset, + ignore_trailing=ignore_trailing) + r.time = response.elapsed if not q.is_response(r): raise BadResponse return r +def send_udp(sock, what, destination, expiration=None): + """Send a DNS message to the specified UDP socket. + + *sock*, a ``socket``. + + *what*, a ``bytes`` or ``dns.message.Message``, the message to send. + + *destination*, a destination tuple appropriate for the address family + of the socket, specifying where to send the query. + + *expiration*, a ``float`` or ``None``, the absolute time at which + a timeout exception should be raised. If ``None``, no timeout will + occur. + + Returns an ``(int, float)`` tuple of bytes sent and the sent time. + """ + + if isinstance(what, dns.message.Message): + what = what.to_wire() + _wait_for_writable(sock, expiration) + sent_time = time.time() + n = sock.sendto(what, destination) + return (n, sent_time) + + +def receive_udp(sock, destination=None, expiration=None, + ignore_unexpected=False, one_rr_per_rrset=False, + keyring=None, request_mac=b'', ignore_trailing=False, + raise_on_truncation=False): + """Read a DNS message from a UDP socket. + + *sock*, a ``socket``. + + *destination*, a destination tuple appropriate for the address family + of the socket, specifying where the message is expected to arrive from. + When receiving a response, this would be where the associated query was + sent. + + *expiration*, a ``float`` or ``None``, the absolute time at which + a timeout exception should be raised. If ``None``, no timeout will + occur. + + *ignore_unexpected*, a ``bool``. If ``True``, ignore responses from + unexpected sources. + + *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own + RRset. + + *keyring*, a ``dict``, the keyring to use for TSIG. + + *request_mac*, a ``bytes``, the MAC of the request (for TSIG). + + *ignore_trailing*, a ``bool``. If ``True``, ignore trailing + junk at end of the received message. + + *raise_on_truncation*, a ``bool``. If ``True``, raise an exception if + the TC bit is set. + + Raises if the message is malformed, if network errors occur, of if + there is a timeout. + + If *destination* is not ``None``, returns a ``(dns.message.Message, float)`` + tuple of the received message and the received time. + + If *destination* is ``None``, returns a + ``(dns.message.Message, float, tuple)`` + tuple of the received message, the received time, and the address where + the message arrived from. + """ + + wire = b'' + while 1: + _wait_for_readable(sock, expiration) + (wire, from_address) = sock.recvfrom(65535) + if _matches_destination(sock.family, from_address, destination, + ignore_unexpected): + break + received_time = time.time() + r = dns.message.from_wire(wire, keyring=keyring, request_mac=request_mac, + one_rr_per_rrset=one_rr_per_rrset, + ignore_trailing=ignore_trailing, + raise_on_truncation=raise_on_truncation) + if destination: + return (r, received_time) + else: + return (r, received_time, from_address) + +def udp(q, where, timeout=None, port=53, source=None, source_port=0, + ignore_unexpected=False, one_rr_per_rrset=False, ignore_trailing=False, + raise_on_truncation=False, sock=None): + """Return the response obtained after sending a query via UDP. + + *q*, a ``dns.message.Message``, the query to send + + *where*, a ``str`` containing an IPv4 or IPv6 address, where + to send the message. + + *timeout*, a ``float`` or ``None``, the number of seconds to wait before the + query times out. If ``None``, the default, wait forever. + + *port*, an ``int``, the port send the message to. The default is 53. + + *source*, a ``str`` containing an IPv4 or IPv6 address, specifying + the source address. The default is the wildcard address. + + *source_port*, an ``int``, the port from which to send the message. + The default is 0. + + *ignore_unexpected*, a ``bool``. If ``True``, ignore responses from + unexpected sources. + + *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own + RRset. + + *ignore_trailing*, a ``bool``. If ``True``, ignore trailing + junk at end of the received message. + + *raise_on_truncation*, a ``bool``. If ``True``, raise an exception if + the TC bit is set. + + *sock*, a ``socket.socket``, or ``None``, the socket to use for the + query. If ``None``, the default, a socket is created. Note that + if a socket is provided, it must be a nonblocking datagram socket, + and the *source* and *source_port* are ignored. + + Returns a ``dns.message.Message``. + """ + + wire = q.to_wire() + (af, destination, source) = _destination_and_source(where, port, + source, source_port) + (begin_time, expiration) = _compute_times(timeout) + with contextlib.ExitStack() as stack: + if sock: + s = sock + else: + s = stack.enter_context(_make_socket(af, socket.SOCK_DGRAM, source)) + send_udp(s, wire, destination, expiration) + (r, received_time) = receive_udp(s, destination, expiration, + ignore_unexpected, one_rr_per_rrset, + q.keyring, q.mac, ignore_trailing, + raise_on_truncation) + r.time = received_time - begin_time + if not q.is_response(r): + raise BadResponse + return r + +def udp_with_fallback(q, where, timeout=None, port=53, source=None, + source_port=0, ignore_unexpected=False, + one_rr_per_rrset=False, ignore_trailing=False, + udp_sock=None, tcp_sock=None): + """Return the response to the query, trying UDP first and falling back + to TCP if UDP results in a truncated response. + + *q*, a ``dns.message.Message``, the query to send + + *where*, a ``str`` containing an IPv4 or IPv6 address, where + to send the message. + + *timeout*, a ``float`` or ``None``, the number of seconds to wait before the + query times out. If ``None``, the default, wait forever. + + *port*, an ``int``, the port send the message to. The default is 53. + + *source*, a ``str`` containing an IPv4 or IPv6 address, specifying + the source address. The default is the wildcard address. + + *source_port*, an ``int``, the port from which to send the message. + The default is 0. + + *ignore_unexpected*, a ``bool``. If ``True``, ignore responses from + unexpected sources. + + *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own + RRset. + + *ignore_trailing*, a ``bool``. If ``True``, ignore trailing + junk at end of the received message. + + *udp_sock*, a ``socket.socket``, or ``None``, the socket to use for the + UDP query. If ``None``, the default, a socket is created. Note that + if a socket is provided, it must be a nonblocking datagram socket, + and the *source* and *source_port* are ignored for the UDP query. + + *tcp_sock*, a ``socket.socket``, or ``None``, the socket to use for the + TCP query. If ``None``, the default, a socket is created. Note that + if a socket is provided, it must be a nonblocking connected stream + socket, and *where*, *source* and *source_port* are ignored for the TCP + query. + + Returns a (``dns.message.Message``, tcp) tuple where tcp is ``True`` + if and only if TCP was used. + """ + try: + response = udp(q, where, timeout, port, source, source_port, + ignore_unexpected, one_rr_per_rrset, + ignore_trailing, True, udp_sock) + return (response, False) + except dns.message.Truncated: + response = tcp(q, where, timeout, port, source, source_port, + one_rr_per_rrset, ignore_trailing, tcp_sock) + return (response, True) def _net_read(sock, count, expiration): """Read the specified number of bytes from sock. Keep trying until we @@ -267,7 +599,13 @@ def _net_read(sock, count, expiration): s = b'' while count > 0: _wait_for_readable(sock, expiration) - n = sock.recv(count) + try: + n = sock.recv(count) + except ssl.SSLWantReadError: # pragma: no cover + continue + except ssl.SSLWantWriteError: # pragma: no cover + _wait_for_writable(sock, expiration) + continue if n == b'': raise EOFError count = count - len(n) @@ -284,144 +622,292 @@ def _net_write(sock, data, expiration): l = len(data) while current < l: _wait_for_writable(sock, expiration) - current += sock.send(data[current:]) + try: + current += sock.send(data[current:]) + except ssl.SSLWantReadError: # pragma: no cover + _wait_for_readable(sock, expiration) + continue + except ssl.SSLWantWriteError: # pragma: no cover + continue -def _connect(s, address): - try: - s.connect(address) - except socket.error: - (ty, v) = sys.exc_info()[:2] +def send_tcp(sock, what, expiration=None): + """Send a DNS message to the specified TCP socket. - if hasattr(v, 'errno'): - v_err = v.errno - else: - v_err = v[0] - if v_err not in [errno.EINPROGRESS, errno.EWOULDBLOCK, errno.EALREADY]: - raise v + *sock*, a ``socket``. + + *what*, a ``bytes`` or ``dns.message.Message``, the message to send. + + *expiration*, a ``float`` or ``None``, the absolute time at which + a timeout exception should be raised. If ``None``, no timeout will + occur. + + Returns an ``(int, float)`` tuple of bytes sent and the sent time. + """ + + if isinstance(what, dns.message.Message): + what = what.to_wire() + l = len(what) + # copying the wire into tcpmsg is inefficient, but lets us + # avoid writev() or doing a short write that would get pushed + # onto the net + tcpmsg = struct.pack("!H", l) + what + _wait_for_writable(sock, expiration) + sent_time = time.time() + _net_write(sock, tcpmsg, expiration) + return (len(tcpmsg), sent_time) + +def receive_tcp(sock, expiration=None, one_rr_per_rrset=False, + keyring=None, request_mac=b'', ignore_trailing=False): + """Read a DNS message from a TCP socket. + + *sock*, a ``socket``. + + *expiration*, a ``float`` or ``None``, the absolute time at which + a timeout exception should be raised. If ``None``, no timeout will + occur. + + *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own + RRset. + + *keyring*, a ``dict``, the keyring to use for TSIG. + + *request_mac*, a ``bytes``, the MAC of the request (for TSIG). + + *ignore_trailing*, a ``bool``. If ``True``, ignore trailing + junk at end of the received message. + + Raises if the message is malformed, if network errors occur, of if + there is a timeout. + + Returns a ``(dns.message.Message, float)`` tuple of the received message + and the received time. + """ + + ldata = _net_read(sock, 2, expiration) + (l,) = struct.unpack("!H", ldata) + wire = _net_read(sock, l, expiration) + received_time = time.time() + r = dns.message.from_wire(wire, keyring=keyring, request_mac=request_mac, + one_rr_per_rrset=one_rr_per_rrset, + ignore_trailing=ignore_trailing) + return (r, received_time) + +def _connect(s, address, expiration): + err = s.connect_ex(address) + if err == 0: + return + if err in (errno.EINPROGRESS, errno.EWOULDBLOCK, errno.EALREADY): + _wait_for_writable(s, expiration) + err = s.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) + if err != 0: + raise OSError(err, os.strerror(err)) -def tcp(q, where, timeout=None, port=53, af=None, source=None, source_port=0, - one_rr_per_rrset=False): +def tcp(q, where, timeout=None, port=53, source=None, source_port=0, + one_rr_per_rrset=False, ignore_trailing=False, sock=None): """Return the response obtained after sending a query via TCP. - @param q: the query - @type q: dns.message.Message object - @param where: where to send the message - @type where: string containing an IPv4 or IPv6 address - @param timeout: The number of seconds to wait before the query times out. - If None, the default, wait forever. - @type timeout: float - @param port: The port to which to send the message. The default is 53. - @type port: int - @param af: the address family to use. The default is None, which - causes the address family to use to be inferred from the form of where. - If the inference attempt fails, AF_INET is used. - @type af: int - @rtype: dns.message.Message object - @param source: source address. The default is the wildcard address. - @type source: string - @param source_port: The port from which to send the message. + *q*, a ``dns.message.Message``, the query to send + + *where*, a ``str`` containing an IPv4 or IPv6 address, where + to send the message. + + *timeout*, a ``float`` or ``None``, the number of seconds to wait before the + query times out. If ``None``, the default, wait forever. + + *port*, an ``int``, the port send the message to. The default is 53. + + *source*, a ``str`` containing an IPv4 or IPv6 address, specifying + the source address. The default is the wildcard address. + + *source_port*, an ``int``, the port from which to send the message. The default is 0. - @type source_port: int - @param one_rr_per_rrset: Put each RR into its own RRset - @type one_rr_per_rrset: bool + + *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own + RRset. + + *ignore_trailing*, a ``bool``. If ``True``, ignore trailing + junk at end of the received message. + + *sock*, a ``socket.socket``, or ``None``, the socket to use for the + query. If ``None``, the default, a socket is created. Note that + if a socket is provided, it must be a nonblocking connected stream + socket, and *where*, *port*, *source* and *source_port* are ignored. + + Returns a ``dns.message.Message``. """ wire = q.to_wire() - (af, destination, source) = _destination_and_source(af, where, port, - source, source_port) - s = socket.socket(af, socket.SOCK_STREAM, 0) - begin_time = None - try: - expiration = _compute_expiration(timeout) - s.setblocking(0) - begin_time = time.time() - if source is not None: - s.bind(source) - _connect(s, destination) - - l = len(wire) - - # copying the wire into tcpmsg is inefficient, but lets us - # avoid writev() or doing a short write that would get pushed - # onto the net - tcpmsg = struct.pack("!H", l) + wire - _net_write(s, tcpmsg, expiration) - ldata = _net_read(s, 2, expiration) - (l,) = struct.unpack("!H", ldata) - wire = _net_read(s, l, expiration) - finally: - if begin_time is None: - response_time = 0 + (begin_time, expiration) = _compute_times(timeout) + with contextlib.ExitStack() as stack: + if sock: + # + # Verify that the socket is connected, as if it's not connected, + # it's not writable, and the polling in send_tcp() will time out or + # hang forever. + sock.getpeername() + s = sock else: - response_time = time.time() - begin_time - s.close() - r = dns.message.from_wire(wire, keyring=q.keyring, request_mac=q.mac, - one_rr_per_rrset=one_rr_per_rrset) - r.time = response_time - if not q.is_response(r): - raise BadResponse - return r + (af, destination, source) = _destination_and_source(where, port, + source, + source_port) + s = stack.enter_context(_make_socket(af, socket.SOCK_STREAM, + source)) + _connect(s, destination, expiration) + send_tcp(s, wire, expiration) + (r, received_time) = receive_tcp(s, expiration, one_rr_per_rrset, + q.keyring, q.mac, ignore_trailing) + r.time = received_time - begin_time + if not q.is_response(r): + raise BadResponse + return r + + +def _tls_handshake(s, expiration): + while True: + try: + s.do_handshake() + return + except ssl.SSLWantReadError: + _wait_for_readable(s, expiration) + except ssl.SSLWantWriteError: # pragma: no cover + _wait_for_writable(s, expiration) + + +def tls(q, where, timeout=None, port=853, source=None, source_port=0, + one_rr_per_rrset=False, ignore_trailing=False, sock=None, + ssl_context=None, server_hostname=None): + """Return the response obtained after sending a query via TLS. + + *q*, a ``dns.message.Message``, the query to send + + *where*, a ``str`` containing an IPv4 or IPv6 address, where + to send the message. + + *timeout*, a ``float`` or ``None``, the number of seconds to wait before the + query times out. If ``None``, the default, wait forever. + + *port*, an ``int``, the port send the message to. The default is 853. + + *source*, a ``str`` containing an IPv4 or IPv6 address, specifying + the source address. The default is the wildcard address. + + *source_port*, an ``int``, the port from which to send the message. + The default is 0. + + *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own + RRset. + + *ignore_trailing*, a ``bool``. If ``True``, ignore trailing + junk at end of the received message. + + *sock*, an ``ssl.SSLSocket``, or ``None``, the socket to use for + the query. If ``None``, the default, a socket is created. Note + that if a socket is provided, it must be a nonblocking connected + SSL stream socket, and *where*, *port*, *source*, *source_port*, + and *ssl_context* are ignored. + + *ssl_context*, an ``ssl.SSLContext``, the context to use when establishing + a TLS connection. If ``None``, the default, creates one with the default + configuration. + + *server_hostname*, a ``str`` containing the server's hostname. The + default is ``None``, which means that no hostname is known, and if an + SSL context is created, hostname checking will be disabled. + + Returns a ``dns.message.Message``. + + """ + + if sock: + # + # If a socket was provided, there's no special TLS handling needed. + # + return tcp(q, where, timeout, port, source, source_port, + one_rr_per_rrset, ignore_trailing, sock) + + wire = q.to_wire() + (begin_time, expiration) = _compute_times(timeout) + (af, destination, source) = _destination_and_source(where, port, + source, source_port) + if ssl_context is None and not sock: + ssl_context = ssl.create_default_context() + if server_hostname is None: + ssl_context.check_hostname = False + + with _make_socket(af, socket.SOCK_STREAM, source, ssl_context=ssl_context, + server_hostname=server_hostname) as s: + _connect(s, destination, expiration) + _tls_handshake(s, expiration) + send_tcp(s, wire, expiration) + (r, received_time) = receive_tcp(s, expiration, one_rr_per_rrset, + q.keyring, q.mac, ignore_trailing) + r.time = received_time - begin_time + if not q.is_response(r): + raise BadResponse + return r def xfr(where, zone, rdtype=dns.rdatatype.AXFR, rdclass=dns.rdataclass.IN, timeout=None, port=53, keyring=None, keyname=None, relativize=True, - af=None, lifetime=None, source=None, source_port=0, serial=0, + lifetime=None, source=None, source_port=0, serial=0, use_udp=False, keyalgorithm=dns.tsig.default_algorithm): """Return a generator for the responses to a zone transfer. - @param where: where to send the message - @type where: string containing an IPv4 or IPv6 address - @param zone: The name of the zone to transfer - @type zone: dns.name.Name object or string - @param rdtype: The type of zone transfer. The default is - dns.rdatatype.AXFR. - @type rdtype: int or string - @param rdclass: The class of the zone transfer. The default is - dns.rdataclass.IN. - @type rdclass: int or string - @param timeout: The number of seconds to wait for each response message. - If None, the default, wait forever. - @type timeout: float - @param port: The port to which to send the message. The default is 53. - @type port: int - @param keyring: The TSIG keyring to use - @type keyring: dict - @param keyname: The name of the TSIG key to use - @type keyname: dns.name.Name object or string - @param relativize: If True, all names in the zone will be relativized to - the zone origin. It is essential that the relativize setting matches - the one specified to dns.zone.from_xfr(). - @type relativize: bool - @param af: the address family to use. The default is None, which - causes the address family to use to be inferred from the form of where. - If the inference attempt fails, AF_INET is used. - @type af: int - @param lifetime: The total number of seconds to spend doing the transfer. - If None, the default, then there is no limit on the time the transfer may - take. - @type lifetime: float - @rtype: generator of dns.message.Message objects. - @param source: source address. The default is the wildcard address. - @type source: string - @param source_port: The port from which to send the message. + *where*, a ``str`` containing an IPv4 or IPv6 address, where + to send the message. + + *zone*, a ``dns.name.Name`` or ``str``, the name of the zone to transfer. + + *rdtype*, an ``int`` or ``str``, the type of zone transfer. The + default is ``dns.rdatatype.AXFR``. ``dns.rdatatype.IXFR`` can be + used to do an incremental transfer instead. + + *rdclass*, an ``int`` or ``str``, the class of the zone transfer. + The default is ``dns.rdataclass.IN``. + + *timeout*, a ``float``, the number of seconds to wait for each + response message. If None, the default, wait forever. + + *port*, an ``int``, the port send the message to. The default is 53. + + *keyring*, a ``dict``, the keyring to use for TSIG. + + *keyname*, a ``dns.name.Name`` or ``str``, the name of the TSIG + key to use. + + *relativize*, a ``bool``. If ``True``, all names in the zone will be + relativized to the zone origin. It is essential that the + relativize setting matches the one specified to + ``dns.zone.from_xfr()`` if using this generator to make a zone. + + *lifetime*, a ``float``, the total number of seconds to spend + doing the transfer. If ``None``, the default, then there is no + limit on the time the transfer may take. + + *source*, a ``str`` containing an IPv4 or IPv6 address, specifying + the source address. The default is the wildcard address. + + *source_port*, an ``int``, the port from which to send the message. The default is 0. - @type source_port: int - @param serial: The SOA serial number to use as the base for an IXFR diff - sequence (only meaningful if rdtype == dns.rdatatype.IXFR). - @type serial: int - @param use_udp: Use UDP (only meaningful for IXFR) - @type use_udp: bool - @param keyalgorithm: The TSIG algorithm to use; defaults to - dns.tsig.default_algorithm - @type keyalgorithm: string + + *serial*, an ``int``, the SOA serial number to use as the base for + an IXFR diff sequence (only meaningful if *rdtype* is + ``dns.rdatatype.IXFR``). + + *use_udp*, a ``bool``. If ``True``, use UDP (only meaningful for IXFR). + + *keyalgorithm*, a ``dns.name.Name`` or ``str``, the TSIG algorithm to use. + + Raises on errors, and so does the generator. + + Returns a generator of ``dns.message.Message`` objects. """ - if isinstance(zone, string_types): + if isinstance(zone, str): zone = dns.name.from_text(zone) - if isinstance(rdtype, string_types): - rdtype = dns.rdatatype.from_text(rdtype) + rdtype = dns.rdatatype.RdataType.make(rdtype) q = dns.message.make_query(zone, rdtype, rdclass) if rdtype == dns.rdatatype.IXFR: rrset = dns.rrset.from_text(zone, 0, 'IN', 'SOA', @@ -430,107 +916,103 @@ def xfr(where, zone, rdtype=dns.rdatatype.AXFR, rdclass=dns.rdataclass.IN, if keyring is not None: q.use_tsig(keyring, keyname, algorithm=keyalgorithm) wire = q.to_wire() - (af, destination, source) = _destination_and_source(af, where, port, + (af, destination, source) = _destination_and_source(where, port, source, source_port) - if use_udp: - if rdtype != dns.rdatatype.IXFR: - raise ValueError('cannot do a UDP AXFR') - s = socket.socket(af, socket.SOCK_DGRAM, 0) - else: - s = socket.socket(af, socket.SOCK_STREAM, 0) - s.setblocking(0) - if source is not None: - s.bind(source) - expiration = _compute_expiration(lifetime) - _connect(s, destination) - l = len(wire) - if use_udp: - _wait_for_writable(s, expiration) - s.send(wire) - else: - tcpmsg = struct.pack("!H", l) + wire - _net_write(s, tcpmsg, expiration) - done = False - delete_mode = True - expecting_SOA = False - soa_rrset = None - if relativize: - origin = zone - oname = dns.name.empty - else: - origin = None - oname = zone - tsig_ctx = None - first = True - while not done: - mexpiration = _compute_expiration(timeout) - if mexpiration is None or mexpiration > expiration: - mexpiration = expiration + if use_udp and rdtype != dns.rdatatype.IXFR: + raise ValueError('cannot do a UDP AXFR') + sock_type = socket.SOCK_DGRAM if use_udp else socket.SOCK_STREAM + with _make_socket(af, sock_type, source) as s: + (_, expiration) = _compute_times(lifetime) + _connect(s, destination, expiration) + l = len(wire) if use_udp: - _wait_for_readable(s, expiration) - (wire, from_address) = s.recvfrom(65535) + _wait_for_writable(s, expiration) + s.send(wire) else: - ldata = _net_read(s, 2, mexpiration) - (l,) = struct.unpack("!H", ldata) - wire = _net_read(s, l, mexpiration) - is_ixfr = (rdtype == dns.rdatatype.IXFR) - r = dns.message.from_wire(wire, keyring=q.keyring, request_mac=q.mac, - xfr=True, origin=origin, tsig_ctx=tsig_ctx, - multi=True, first=first, - one_rr_per_rrset=is_ixfr) - tsig_ctx = r.tsig_ctx - first = False - answer_index = 0 - if soa_rrset is None: - if not r.answer or r.answer[0].name != oname: - raise dns.exception.FormError( - "No answer or RRset not for qname") - rrset = r.answer[0] - if rrset.rdtype != dns.rdatatype.SOA: - raise dns.exception.FormError("first RRset is not an SOA") - answer_index = 1 - soa_rrset = rrset.copy() - if rdtype == dns.rdatatype.IXFR: - if soa_rrset[0].serial <= serial: + tcpmsg = struct.pack("!H", l) + wire + _net_write(s, tcpmsg, expiration) + done = False + delete_mode = True + expecting_SOA = False + soa_rrset = None + if relativize: + origin = zone + oname = dns.name.empty + else: + origin = None + oname = zone + tsig_ctx = None + while not done: + (_, mexpiration) = _compute_times(timeout) + if mexpiration is None or \ + (expiration is not None and mexpiration > expiration): + mexpiration = expiration + if use_udp: + _wait_for_readable(s, expiration) + (wire, from_address) = s.recvfrom(65535) + else: + ldata = _net_read(s, 2, mexpiration) + (l,) = struct.unpack("!H", ldata) + wire = _net_read(s, l, mexpiration) + is_ixfr = (rdtype == dns.rdatatype.IXFR) + r = dns.message.from_wire(wire, keyring=q.keyring, + request_mac=q.mac, xfr=True, + origin=origin, tsig_ctx=tsig_ctx, + multi=True, one_rr_per_rrset=is_ixfr) + rcode = r.rcode() + if rcode != dns.rcode.NOERROR: + raise TransferError(rcode) + tsig_ctx = r.tsig_ctx + answer_index = 0 + if soa_rrset is None: + if not r.answer or r.answer[0].name != oname: + raise dns.exception.FormError( + "No answer or RRset not for qname") + rrset = r.answer[0] + if rrset.rdtype != dns.rdatatype.SOA: + raise dns.exception.FormError("first RRset is not an SOA") + answer_index = 1 + soa_rrset = rrset.copy() + if rdtype == dns.rdatatype.IXFR: + if dns.serial.Serial(soa_rrset[0].serial) <= serial: + # + # We're already up-to-date. + # + done = True + else: + expecting_SOA = True + # + # Process SOAs in the answer section (other than the initial + # SOA in the first message). + # + for rrset in r.answer[answer_index:]: + if done: + raise dns.exception.FormError("answers after final SOA") + if rrset.rdtype == dns.rdatatype.SOA and rrset.name == oname: + if expecting_SOA: + if rrset[0].serial != serial: + raise dns.exception.FormError( + "IXFR base serial mismatch") + expecting_SOA = False + elif rdtype == dns.rdatatype.IXFR: + delete_mode = not delete_mode # - # We're already up-to-date. + # If this SOA RRset is equal to the first we saw then we're + # finished. If this is an IXFR we also check that we're + # seeing the record in the expected part of the response. # - done = True - else: - expecting_SOA = True - # - # Process SOAs in the answer section (other than the initial - # SOA in the first message). - # - for rrset in r.answer[answer_index:]: - if done: - raise dns.exception.FormError("answers after final SOA") - if rrset.rdtype == dns.rdatatype.SOA and rrset.name == oname: - if expecting_SOA: - if rrset[0].serial != serial: - raise dns.exception.FormError( - "IXFR base serial mismatch") + if rrset == soa_rrset and \ + (rdtype == dns.rdatatype.AXFR or + (rdtype == dns.rdatatype.IXFR and delete_mode)): + done = True + elif expecting_SOA: + # + # We made an IXFR request and are expecting another + # SOA RR, but saw something else, so this must be an + # AXFR response. + # + rdtype = dns.rdatatype.AXFR expecting_SOA = False - elif rdtype == dns.rdatatype.IXFR: - delete_mode = not delete_mode - # - # If this SOA RRset is equal to the first we saw then we're - # finished. If this is an IXFR we also check that we're seeing - # the record in the expected part of the response. - # - if rrset == soa_rrset and \ - (rdtype == dns.rdatatype.AXFR or - (rdtype == dns.rdatatype.IXFR and delete_mode)): - done = True - elif expecting_SOA: - # - # We made an IXFR request and are expecting another - # SOA RR, but saw something else, so this must be an - # AXFR response. - # - rdtype = dns.rdatatype.AXFR - expecting_SOA = False - if done and q.keyring and not r.had_tsig: - raise dns.exception.FormError("missing TSIG") - yield r - s.close() + if done and q.keyring and not r.had_tsig: + raise dns.exception.FormError("missing TSIG") + yield r diff --git a/lib/dns/rcode.py b/lib/dns/rcode.py index 314815f7..d9ea0051 100644 --- a/lib/dns/rcode.py +++ b/lib/dns/rcode.py @@ -1,4 +1,6 @@ -# Copyright (C) 2001-2007, 2009-2011 Nominum, Inc. +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2001-2017 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its # documentation for any purpose with or without fee is hereby granted, @@ -15,78 +17,90 @@ """DNS Result Codes.""" +import dns.enum import dns.exception -from ._compat import long +class Rcode(dns.enum.IntEnum): + #: No error + NOERROR = 0 + #: Format error + FORMERR = 1 + #: Server failure + SERVFAIL = 2 + #: Name does not exist ("Name Error" in RFC 1025 terminology). + NXDOMAIN = 3 + #: Not implemented + NOTIMP = 4 + #: Refused + REFUSED = 5 + #: Name exists. + YXDOMAIN = 6 + #: RRset exists. + YXRRSET = 7 + #: RRset does not exist. + NXRRSET = 8 + #: Not authoritative. + NOTAUTH = 9 + #: Name not in zone. + NOTZONE = 10 + #: DSO-TYPE Not Implemented + DSOTYPENI = 11 + #: Bad EDNS version. + BADVERS = 16 + #: TSIG Signature Failure + BADSIG = 16 + #: Key not recognized. + BADKEY = 17 + #: Signature out of time window. + BADTIME = 18 + #: Bad TKEY Mode. + BADMODE = 19 + #: Duplicate key name. + BADNAME = 20 + #: Algorithm not supported. + BADALG = 21 + #: Bad Truncation + BADTRUNC = 22 + #: Bad/missing Server Cookie + BADCOOKIE = 23 -NOERROR = 0 -FORMERR = 1 -SERVFAIL = 2 -NXDOMAIN = 3 -NOTIMP = 4 -REFUSED = 5 -YXDOMAIN = 6 -YXRRSET = 7 -NXRRSET = 8 -NOTAUTH = 9 -NOTZONE = 10 -BADVERS = 16 + @classmethod + def _maximum(cls): + return 4095 -_by_text = { - 'NOERROR': NOERROR, - 'FORMERR': FORMERR, - 'SERVFAIL': SERVFAIL, - 'NXDOMAIN': NXDOMAIN, - 'NOTIMP': NOTIMP, - 'REFUSED': REFUSED, - 'YXDOMAIN': YXDOMAIN, - 'YXRRSET': YXRRSET, - 'NXRRSET': NXRRSET, - 'NOTAUTH': NOTAUTH, - 'NOTZONE': NOTZONE, - 'BADVERS': BADVERS -} - -# We construct the inverse mapping programmatically to ensure that we -# cannot make any mistakes (e.g. omissions, cut-and-paste errors) that -# would cause the mapping not to be a true inverse. - -_by_value = dict((y, x) for x, y in _by_text.items()) + @classmethod + def _unknown_exception_class(cls): + return UnknownRcode +globals().update(Rcode.__members__) class UnknownRcode(dns.exception.DNSException): - """A DNS rcode is unknown.""" def from_text(text): """Convert text into an rcode. - @param text: the textual rcode - @type text: string - @raises UnknownRcode: the rcode is unknown - @rtype: int + *text*, a ``str``, the textual rcode or an integer in textual form. + + Raises ``dns.rcode.UnknownRcode`` if the rcode mnemonic is unknown. + + Returns an ``int``. """ - if text.isdigit(): - v = int(text) - if v >= 0 and v <= 4095: - return v - v = _by_text.get(text.upper()) - if v is None: - raise UnknownRcode - return v + return Rcode.from_text(text) def from_flags(flags, ednsflags): """Return the rcode value encoded by flags and ednsflags. - @param flags: the DNS flags - @type flags: int - @param ednsflags: the EDNS flags - @type ednsflags: int - @raises ValueError: rcode is < 0 or > 4095 - @rtype: int + *flags*, an ``int``, the DNS flags field. + + *ednsflags*, an ``int``, the EDNS flags field. + + Raises ``ValueError`` if rcode is < 0 or > 4095 + + Returns an ``int``. """ value = (flags & 0x000f) | ((ednsflags >> 20) & 0xff0) @@ -98,28 +112,30 @@ def from_flags(flags, ednsflags): def to_flags(value): """Return a (flags, ednsflags) tuple which encodes the rcode. - @param value: the rcode - @type value: int - @raises ValueError: rcode is < 0 or > 4095 - @rtype: (int, int) tuple + *value*, an ``int``, the rcode. + + Raises ``ValueError`` if rcode is < 0 or > 4095. + + Returns an ``(int, int)`` tuple. """ if value < 0 or value > 4095: raise ValueError('rcode must be >= 0 and <= 4095') v = value & 0xf - ev = long(value & 0xff0) << 20 + ev = (value & 0xff0) << 20 return (v, ev) -def to_text(value): +def to_text(value, tsig=False): """Convert rcode into text. - @param value: the rcode - @type value: int - @rtype: string + *value*, an ``int``, the rcode. + + Raises ``ValueError`` if rcode is < 0 or > 4095. + + Returns a ``str``. """ - text = _by_value.get(value) - if text is None: - text = str(value) - return text + if tsig and value == Rcode.BADVERS: + return 'BADSIG' + return Rcode.to_text(value) diff --git a/lib/dns/rdata.py b/lib/dns/rdata.py index 824731c7..e114fe32 100644 --- a/lib/dns/rdata.py +++ b/lib/dns/rdata.py @@ -1,4 +1,6 @@ -# Copyright (C) 2001-2007, 2009-2011 Nominum, Inc. +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2001-2017 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its # documentation for any purpose with or without fee is hereby granted, @@ -13,95 +15,68 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. -"""DNS rdata. +"""DNS rdata.""" -@var _rdata_modules: A dictionary mapping a (rdclass, rdtype) tuple to -the module which implements that type. -@type _rdata_modules: dict -@var _module_prefix: The prefix to use when forming modules names. The -default is 'dns.rdtypes'. Changing this value will break the library. -@type _module_prefix: string -@var _hex_chunk: At most this many octets that will be represented in each -chunk of hexstring that _hexify() produces before whitespace occurs. -@type _hex_chunk: int""" - -from io import BytesIO +from importlib import import_module import base64 import binascii -import struct +import io +import inspect +import itertools +import dns.wire import dns.exception import dns.name import dns.rdataclass import dns.rdatatype import dns.tokenizer -import dns.wiredata -from ._compat import xrange, string_types, text_type -_hex_chunksize = 32 +_chunksize = 32 -def _hexify(data, chunksize=_hex_chunksize): +def _wordbreak(data, chunksize=_chunksize): + """Break a binary string into chunks of chunksize characters separated by + a space. + """ + + if not chunksize: + return data.decode() + return b' '.join([data[i:i + chunksize] + for i + in range(0, len(data), chunksize)]).decode() + + +def _hexify(data, chunksize=_chunksize): """Convert a binary string into its hex encoding, broken up into chunks - of I{chunksize} characters separated by a space. - - @param data: the binary string - @type data: string - @param chunksize: the chunk size. Default is L{dns.rdata._hex_chunksize} - @rtype: string + of chunksize characters separated by a space. """ - line = binascii.hexlify(data) - return b' '.join([line[i:i + chunksize] - for i - in range(0, len(line), chunksize)]).decode() - -_base64_chunksize = 32 + return _wordbreak(binascii.hexlify(data), chunksize) -def _base64ify(data, chunksize=_base64_chunksize): +def _base64ify(data, chunksize=_chunksize): """Convert a binary string into its base64 encoding, broken up into chunks - of I{chunksize} characters separated by a space. - - @param data: the binary string - @type data: string - @param chunksize: the chunk size. Default is - L{dns.rdata._base64_chunksize} - @rtype: string + of chunksize characters separated by a space. """ - line = base64.b64encode(data) - return b' '.join([line[i:i + chunksize] - for i - in range(0, len(line), chunksize)]).decode() - -__escaped = { - '"': True, - '\\': True, -} + return _wordbreak(base64.b64encode(data), chunksize) +__escaped = b'"\\' def _escapify(qstring): - """Escape the characters in a quoted string which need it. + """Escape the characters in a quoted string which need it.""" - @param qstring: the string - @type qstring: string - @returns: the escaped string - @rtype: string - """ - - if isinstance(qstring, text_type): + if isinstance(qstring, str): qstring = qstring.encode() if not isinstance(qstring, bytearray): qstring = bytearray(qstring) text = '' for c in qstring: - packed = struct.pack('!B', c).decode() - if packed in __escaped: - text += '\\' + packed + if c in __escaped: + text += '\\' + chr(c) elif c >= 0x20 and c < 0x7F: - text += packed + text += chr(c) else: text += '\\%03d' % c return text @@ -110,43 +85,85 @@ def _escapify(qstring): def _truncate_bitmap(what): """Determine the index of greatest byte that isn't all zeros, and return the bitmap that contains all the bytes less than that index. - - @param what: a string of octets representing a bitmap. - @type what: string - @rtype: string """ - for i in xrange(len(what) - 1, -1, -1): + for i in range(len(what) - 1, -1, -1): if what[i] != 0: - break - return what[0: i + 1] + return what[0: i + 1] + return what[0:1] - -class Rdata(object): - - """Base class for all DNS rdata types. +def _constify(o): """ + Convert mutable types to immutable types. + """ + if isinstance(o, bytearray): + return bytes(o) + if isinstance(o, tuple): + try: + hash(o) + return o + except Exception: + return tuple(_constify(elt) for elt in o) + if isinstance(o, list): + return tuple(_constify(elt) for elt in o) + return o + +class Rdata: + """Base class for all DNS rdata types.""" __slots__ = ['rdclass', 'rdtype'] def __init__(self, rdclass, rdtype): """Initialize an rdata. - @param rdclass: The rdata class - @type rdclass: int - @param rdtype: The rdata type - @type rdtype: int + + *rdclass*, an ``int`` is the rdataclass of the Rdata. + + *rdtype*, an ``int`` is the rdatatype of the Rdata. """ - self.rdclass = rdclass - self.rdtype = rdtype + object.__setattr__(self, 'rdclass', rdclass) + object.__setattr__(self, 'rdtype', rdtype) + + def __setattr__(self, name, value): + # Rdatas are immutable + raise TypeError("object doesn't support attribute assignment") + + def __delattr__(self, name): + # Rdatas are immutable + raise TypeError("object doesn't support attribute deletion") + + def _get_all_slots(self): + return itertools.chain.from_iterable(getattr(cls, '__slots__', []) + for cls in self.__class__.__mro__) + + def __getstate__(self): + # We used to try to do a tuple of all slots here, but it + # doesn't work as self._all_slots isn't available at + # __setstate__() time. Before that we tried to store a tuple + # of __slots__, but that didn't work as it didn't store the + # slots defined by ancestors. This older way didn't fail + # outright, but ended up with partially broken objects, e.g. + # if you unpickled an A RR it wouldn't have rdclass and rdtype + # attributes, and would compare badly. + state = {} + for slot in self._get_all_slots(): + state[slot] = getattr(self, slot) + return state + + def __setstate__(self, state): + for slot, val in state.items(): + object.__setattr__(self, slot, val) def covers(self): - """DNS SIG/RRSIG rdatas apply to a specific type; this type is + """Return the type a Rdata covers. + + DNS SIG/RRSIG rdatas apply to a specific type; this type is returned by the covers() function. If the rdata type is not SIG or RRSIG, dns.rdatatype.NONE is returned. This is useful when creating rdatasets, allowing the rdataset to contain only RRSIGs of a particular type, e.g. RRSIG(NS). - @rtype: int + + Returns an ``int``. """ return dns.rdatatype.NONE @@ -155,38 +172,53 @@ class Rdata(object): """Return a 32-bit type value, the least significant 16 bits of which are the ordinary DNS type, and the upper 16 bits of which are the "covered" type, if any. - @rtype: int + + Returns an ``int``. """ return self.covers() << 16 | self.rdtype def to_text(self, origin=None, relativize=True, **kw): """Convert an rdata to text format. - @rtype: string + + Returns a ``str``. """ + raise NotImplementedError - def to_wire(self, file, compress=None, origin=None): + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): + raise NotImplementedError + + def to_wire(self, file=None, compress=None, origin=None, + canonicalize=False): """Convert an rdata to wire format. - @rtype: string + + Returns a ``bytes`` or ``None``. """ - raise NotImplementedError + if file: + return self._to_wire(file, compress, origin, canonicalize) + else: + f = io.BytesIO() + self._to_wire(f, compress, origin, canonicalize) + return f.getvalue() + + def to_generic(self, origin=None): + """Creates a dns.rdata.GenericRdata equivalent of this rdata. + + Returns a ``dns.rdata.GenericRdata``. + """ + return dns.rdata.GenericRdata(self.rdclass, self.rdtype, + self.to_wire(origin=origin)) def to_digestable(self, origin=None): """Convert rdata to a format suitable for digesting in hashes. This - is also the DNSSEC canonical form.""" - f = BytesIO() - self.to_wire(f, None, origin) - return f.getvalue() + is also the DNSSEC canonical form. - def validate(self): - """Check that the current contents of the rdata's fields are - valid. If you change an rdata by assigning to its fields, - it is a good idea to call validate() when you are done making - changes. + Returns a ``bytes``. """ - dns.rdata.from_text(self.rdclass, self.rdtype, self.to_text()) + + return self.to_wire(origin=origin, canonicalize=True) def __repr__(self): covers = self.covers() @@ -203,17 +235,20 @@ class Rdata(object): def _cmp(self, other): """Compare an rdata with another rdata of the same rdtype and - rdclass. Return < 0 if self < other in the DNSSEC ordering, - 0 if self == other, and > 0 if self > other. + rdclass. + + Return < 0 if self < other in the DNSSEC ordering, 0 if self + == other, and > 0 if self > other. + """ our = self.to_digestable(dns.name.root) their = other.to_digestable(dns.name.root) if our == their: return 0 - if our > their: + elif our > their: return 1 - - return -1 + else: + return -1 def __eq__(self, other): if not isinstance(other, Rdata): @@ -258,56 +293,55 @@ class Rdata(object): return hash(self.to_digestable(dns.name.root)) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True): - """Build an rdata object from text format. - - @param rdclass: The rdata class - @type rdclass: int - @param rdtype: The rdata type - @type rdtype: int - @param tok: The tokenizer - @type tok: dns.tokenizer.Tokenizer - @param origin: The origin to use for relative names - @type origin: dns.name.Name - @param relativize: should names be relativized? - @type relativize: bool - @rtype: dns.rdata.Rdata instance - """ - + def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, + relativize_to=None): raise NotImplementedError @classmethod def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None): - """Build an rdata object from wire format - - @param rdclass: The rdata class - @type rdclass: int - @param rdtype: The rdata type - @type rdtype: int - @param wire: The wire-format message - @type wire: string - @param current: The offset in wire of the beginning of the rdata. - @type current: int - @param rdlen: The length of the wire-format rdata - @type rdlen: int - @param origin: The origin to use for relative names - @type origin: dns.name.Name - @rtype: dns.rdata.Rdata instance - """ - raise NotImplementedError - def choose_relativity(self, origin=None, relativize=True): - """Convert any domain names in the rdata to the specified - relativization. + def replace(self, **kwargs): + """ + Create a new Rdata instance based on the instance replace was + invoked on. It is possible to pass different parameters to + override the corresponding properties of the base Rdata. + + Any field specific to the Rdata type can be replaced, but the + *rdtype* and *rdclass* fields cannot. + + Returns an instance of the same Rdata subclass as *self*. """ - pass + # Get the constructor parameters. + parameters = inspect.signature(self.__init__).parameters + + # Ensure that all of the arguments correspond to valid fields. + # Don't allow rdclass or rdtype to be changed, though. + for key in kwargs: + if key not in parameters: + raise AttributeError("'{}' object has no attribute '{}'" + .format(self.__class__.__name__, key)) + if key in ('rdclass', 'rdtype'): + raise AttributeError("Cannot overwrite '{}' attribute '{}'" + .format(self.__class__.__name__, key)) + + # Construct the parameter list. For each field, use the value in + # kwargs if present, and the current value otherwise. + args = (kwargs.get(key, getattr(self, key)) for key in parameters) + + # Create, validate, and return the new object. + # + # Note that if we make constructors do validation in the future, + # this validation can go away. + rd = self.__class__(*args) + dns.rdata.from_text(rd.rdclass, rd.rdtype, rd.to_text()) + return rd class GenericRdata(Rdata): - """Generate Rdata Class + """Generic Rdata Class This class is used for rdata types for which we have no better implementation. It implements the DNS "unknown RRs" scheme. @@ -316,16 +350,17 @@ class GenericRdata(Rdata): __slots__ = ['data'] def __init__(self, rdclass, rdtype, data): - super(GenericRdata, self).__init__(rdclass, rdtype) - self.data = data + super().__init__(rdclass, rdtype) + object.__setattr__(self, 'data', data) def to_text(self, origin=None, relativize=True, **kw): return r'\# %d ' % len(self.data) + _hexify(self.data) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True): + def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, + relativize_to=None): token = tok.get() - if not token.is_identifier() or token.value != '\#': + if not token.is_identifier() or token.value != r'\#': raise dns.exception.SyntaxError( r'generic rdata does not start with \#') length = tok.get_int() @@ -342,52 +377,46 @@ class GenericRdata(Rdata): 'generic rdata hex data has wrong length') return cls(rdclass, rdtype, data) - def to_wire(self, file, compress=None, origin=None): + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): file.write(self.data) @classmethod - def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None): - return cls(rdclass, rdtype, wire[current: current + rdlen]) + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + return cls(rdclass, rdtype, parser.get_remaining()) -_rdata_modules = {} +_rdata_classes = {} _module_prefix = 'dns.rdtypes' - def get_rdata_class(rdclass, rdtype): - - def import_module(name): - mod = __import__(name) - components = name.split('.') - for comp in components[1:]: - mod = getattr(mod, comp) - return mod - - mod = _rdata_modules.get((rdclass, rdtype)) - rdclass_text = dns.rdataclass.to_text(rdclass) - rdtype_text = dns.rdatatype.to_text(rdtype) - rdtype_text = rdtype_text.replace('-', '_') - if not mod: - mod = _rdata_modules.get((dns.rdatatype.ANY, rdtype)) - if not mod: + cls = _rdata_classes.get((rdclass, rdtype)) + if not cls: + cls = _rdata_classes.get((dns.rdatatype.ANY, rdtype)) + if not cls: + rdclass_text = dns.rdataclass.to_text(rdclass) + rdtype_text = dns.rdatatype.to_text(rdtype) + rdtype_text = rdtype_text.replace('-', '_') try: mod = import_module('.'.join([_module_prefix, rdclass_text, rdtype_text])) - _rdata_modules[(rdclass, rdtype)] = mod + cls = getattr(mod, rdtype_text) + _rdata_classes[(rdclass, rdtype)] = cls except ImportError: try: mod = import_module('.'.join([_module_prefix, 'ANY', rdtype_text])) - _rdata_modules[(dns.rdataclass.ANY, rdtype)] = mod + cls = getattr(mod, rdtype_text) + _rdata_classes[(dns.rdataclass.ANY, rdtype)] = cls + _rdata_classes[(rdclass, rdtype)] = cls except ImportError: - mod = None - if mod: - cls = getattr(mod, rdtype_text) - else: + pass + if not cls: cls = GenericRdata + _rdata_classes[(rdclass, rdtype)] = cls return cls -def from_text(rdclass, rdtype, tok, origin=None, relativize=True): +def from_text(rdclass, rdtype, tok, origin=None, relativize=True, + relativize_to=None, idna_codec=None): """Build an rdata object from text format. This function attempts to dynamically load a class which @@ -398,23 +427,37 @@ def from_text(rdclass, rdtype, tok, origin=None, relativize=True): Once a class is chosen, its from_text() class method is called with the parameters to this function. - If I{tok} is a string, then a tokenizer is created and the string + If *tok* is a ``str``, then a tokenizer is created and the string is used as its input. - @param rdclass: The rdata class - @type rdclass: int - @param rdtype: The rdata type - @type rdtype: int - @param tok: The tokenizer or input text - @type tok: dns.tokenizer.Tokenizer or string - @param origin: The origin to use for relative names - @type origin: dns.name.Name - @param relativize: Should names be relativized? - @type relativize: bool - @rtype: dns.rdata.Rdata instance""" + *rdclass*, an ``int``, the rdataclass. - if isinstance(tok, string_types): - tok = dns.tokenizer.Tokenizer(tok) + *rdtype*, an ``int``, the rdatatype. + + *tok*, a ``dns.tokenizer.Tokenizer`` or a ``str``. + + *origin*, a ``dns.name.Name`` (or ``None``), the + origin to use for relative names. + + *relativize*, a ``bool``. If true, name will be relativized. + + *relativize_to*, a ``dns.name.Name`` (or ``None``), the origin to use + when relativizing names. If not set, the *origin* value will be used. + + *idna_codec*, a ``dns.name.IDNACodec``, specifies the IDNA + encoder/decoder to use if a tokenizer needs to be created. If + ``None``, the default IDNA 2003 encoder/decoder is used. If a + tokenizer is not created, then the codec associated with the tokenizer + is the one that is used. + + Returns an instance of the chosen Rdata subclass. + + """ + + if isinstance(tok, str): + tok = dns.tokenizer.Tokenizer(tok, idna_codec=idna_codec) + rdclass = dns.rdataclass.RdataClass.make(rdclass) + rdtype = dns.rdatatype.RdataType.make(rdtype) cls = get_rdata_class(rdclass, rdtype) if cls != GenericRdata: # peek at first token @@ -428,10 +471,41 @@ def from_text(rdclass, rdtype, tok, origin=None, relativize=True): # from_wire on it. # rdata = GenericRdata.from_text(rdclass, rdtype, tok, origin, - relativize) + relativize, relativize_to) return from_wire(rdclass, rdtype, rdata.data, 0, len(rdata.data), origin) - return cls.from_text(rdclass, rdtype, tok, origin, relativize) + return cls.from_text(rdclass, rdtype, tok, origin, relativize, + relativize_to) + + +def from_wire_parser(rdclass, rdtype, parser, origin=None): + """Build an rdata object from wire format + + This function attempts to dynamically load a class which + implements the specified rdata class and type. If there is no + class-and-type-specific implementation, the GenericRdata class + is used. + + Once a class is chosen, its from_wire() class method is called + with the parameters to this function. + + *rdclass*, an ``int``, the rdataclass. + + *rdtype*, an ``int``, the rdatatype. + + *parser*, a ``dns.wire.Parser``, the parser, which should be + restricted to the rdata length. + + *origin*, a ``dns.name.Name`` (or ``None``). If not ``None``, + then names will be relativized to this origin. + + Returns an instance of the chosen Rdata subclass. + """ + + rdclass = dns.rdataclass.RdataClass.make(rdclass) + rdtype = dns.rdatatype.RdataType.make(rdtype) + cls = get_rdata_class(rdclass, rdtype) + return cls.from_wire_parser(rdclass, rdtype, parser, origin) def from_wire(rdclass, rdtype, wire, current, rdlen, origin=None): @@ -445,20 +519,60 @@ def from_wire(rdclass, rdtype, wire, current, rdlen, origin=None): Once a class is chosen, its from_wire() class method is called with the parameters to this function. - @param rdclass: The rdata class - @type rdclass: int - @param rdtype: The rdata type - @type rdtype: int - @param wire: The wire-format message - @type wire: string - @param current: The offset in wire of the beginning of the rdata. - @type current: int - @param rdlen: The length of the wire-format rdata - @type rdlen: int - @param origin: The origin to use for relative names - @type origin: dns.name.Name - @rtype: dns.rdata.Rdata instance""" + *rdclass*, an ``int``, the rdataclass. - wire = dns.wiredata.maybe_wrap(wire) - cls = get_rdata_class(rdclass, rdtype) - return cls.from_wire(rdclass, rdtype, wire, current, rdlen, origin) + *rdtype*, an ``int``, the rdatatype. + + *wire*, a ``bytes``, the wire-format message. + + *current*, an ``int``, the offset in wire of the beginning of + the rdata. + + *rdlen*, an ``int``, the length of the wire-format rdata + + *origin*, a ``dns.name.Name`` (or ``None``). If not ``None``, + then names will be relativized to this origin. + + Returns an instance of the chosen Rdata subclass. + """ + parser = dns.wire.Parser(wire, current) + with parser.restrict_to(rdlen): + return from_wire_parser(rdclass, rdtype, parser, origin) + + +class RdatatypeExists(dns.exception.DNSException): + """DNS rdatatype already exists.""" + supp_kwargs = {'rdclass', 'rdtype'} + fmt = "The rdata type with class {rdclass} and rdtype {rdtype} " + \ + "already exists." + + +def register_type(implementation, rdtype, rdtype_text, is_singleton=False, + rdclass=dns.rdataclass.IN): + """Dynamically register a module to handle an rdatatype. + + *implementation*, a module implementing the type in the usual dnspython + way. + + *rdtype*, an ``int``, the rdatatype to register. + + *rdtype_text*, a ``str``, the textual form of the rdatatype. + + *is_singleton*, a ``bool``, indicating if the type is a singleton (i.e. + RRsets of the type can have only one member.) + + *rdclass*, the rdataclass of the type, or ``dns.rdataclass.ANY`` if + it applies to all classes. + """ + + existing_cls = get_rdata_class(rdclass, rdtype) + if existing_cls != GenericRdata or dns.rdatatype.is_metatype(rdtype): + raise RdatatypeExists(rdclass=rdclass, rdtype=rdtype) + try: + if dns.rdatatype.RdataType(rdtype).name != rdtype_text: + raise RdatatypeExists(rdclass=rdclass, rdtype=rdtype) + except ValueError: + pass + _rdata_classes[(rdclass, rdtype)] = getattr(implementation, + rdtype_text.replace('-', '_')) + dns.rdatatype.register_type(rdtype, rdtype_text, is_singleton) diff --git a/lib/dns/rdataclass.py b/lib/dns/rdataclass.py index 17a4810d..7943a95a 100644 --- a/lib/dns/rdataclass.py +++ b/lib/dns/rdataclass.py @@ -1,4 +1,6 @@ -# Copyright (C) 2001-2007, 2009-2011 Nominum, Inc. +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2001-2017 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its # documentation for any purpose with or without fee is hereby granted, @@ -13,105 +15,87 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. -"""DNS Rdata Classes. - -@var _by_text: The rdata class textual name to value mapping -@type _by_text: dict -@var _by_value: The rdata class value to textual name mapping -@type _by_value: dict -@var _metaclasses: If an rdataclass is a metaclass, there will be a mapping -whose key is the rdatatype value and whose value is True in this dictionary. -@type _metaclasses: dict""" - -import re +"""DNS Rdata Classes.""" +import dns.enum import dns.exception -RESERVED0 = 0 -IN = 1 -CH = 3 -HS = 4 -NONE = 254 -ANY = 255 +class RdataClass(dns.enum.IntEnum): + """DNS Rdata Class""" + RESERVED0 = 0 + IN = 1 + INTERNET = IN + CH = 3 + CHAOS = CH + HS = 4 + HESIOD = HS + NONE = 254 + ANY = 255 -_by_text = { - 'RESERVED0': RESERVED0, - 'IN': IN, - 'CH': CH, - 'HS': HS, - 'NONE': NONE, - 'ANY': ANY -} + @classmethod + def _maximum(cls): + return 65535 -# We construct the inverse mapping programmatically to ensure that we -# cannot make any mistakes (e.g. omissions, cut-and-paste errors) that -# would cause the mapping not to be true inverse. + @classmethod + def _short_name(cls): + return "class" -_by_value = dict((y, x) for x, y in _by_text.items()) + @classmethod + def _prefix(cls): + return "CLASS" -# Now that we've built the inverse map, we can add class aliases to -# the _by_text mapping. + @classmethod + def _unknown_exception_class(cls): + return UnknownRdataclass -_by_text.update({ - 'INTERNET': IN, - 'CHAOS': CH, - 'HESIOD': HS -}) +globals().update(RdataClass.__members__) -_metaclasses = { - NONE: True, - ANY: True -} - -_unknown_class_pattern = re.compile('CLASS([0-9]+)$', re.I) +_metaclasses = {RdataClass.NONE, RdataClass.ANY} class UnknownRdataclass(dns.exception.DNSException): - """A DNS class is unknown.""" def from_text(text): """Convert text into a DNS rdata class value. - @param text: the text - @type text: string - @rtype: int - @raises dns.rdataclass.UnknownRdataclass: the class is unknown - @raises ValueError: the rdata class value is not >= 0 and <= 65535 + + The input text can be a defined DNS RR class mnemonic or + instance of the DNS generic class syntax. + + For example, "IN" and "CLASS1" will both result in a value of 1. + + Raises ``dns.rdatatype.UnknownRdataclass`` if the class is unknown. + + Raises ``ValueError`` if the rdata class value is not >= 0 and <= 65535. + + Returns an ``int``. """ - value = _by_text.get(text.upper()) - if value is None: - match = _unknown_class_pattern.match(text) - if match is None: - raise UnknownRdataclass - value = int(match.group(1)) - if value < 0 or value > 65535: - raise ValueError("class must be between >= 0 and <= 65535") - return value + return RdataClass.from_text(text) def to_text(value): - """Convert a DNS rdata class to text. - @param value: the rdata class value - @type value: int - @rtype: string - @raises ValueError: the rdata class value is not >= 0 and <= 65535 + """Convert a DNS rdata class value to text. + + If the value has a known mnemonic, it will be used, otherwise the + DNS generic class syntax will be used. + + Raises ``ValueError`` if the rdata class value is not >= 0 and <= 65535. + + Returns a ``str``. """ - if value < 0 or value > 65535: - raise ValueError("class must be between >= 0 and <= 65535") - text = _by_value.get(value) - if text is None: - text = 'CLASS' + repr(value) - return text + return RdataClass.to_text(value) def is_metaclass(rdclass): - """True if the class is a metaclass. - @param rdclass: the rdata class - @type rdclass: int - @rtype: bool""" + """True if the specified class is a metaclass. + + The currently defined metaclasses are ANY and NONE. + + *rdclass* is an ``int``. + """ if rdclass in _metaclasses: return True diff --git a/lib/dns/rdataset.py b/lib/dns/rdataset.py index db266f2f..660415e7 100644 --- a/lib/dns/rdataset.py +++ b/lib/dns/rdataset.py @@ -1,4 +1,6 @@ -# Copyright (C) 2001-2007, 2009-2011 Nominum, Inc. +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2001-2017 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its # documentation for any purpose with or without fee is hereby granted, @@ -15,8 +17,8 @@ """DNS rdatasets (an rdataset is a set of rdatas of a given type and class)""" +import io import random -from io import StringIO import struct import dns.exception @@ -24,60 +26,46 @@ import dns.rdatatype import dns.rdataclass import dns.rdata import dns.set -from ._compat import string_types # define SimpleSet here for backwards compatibility SimpleSet = dns.set.Set class DifferingCovers(dns.exception.DNSException): - """An attempt was made to add a DNS SIG/RRSIG whose covered type is not the same as that of the other rdatas in the rdataset.""" class IncompatibleTypes(dns.exception.DNSException): - """An attempt was made to add DNS RR data of an incompatible type.""" class Rdataset(dns.set.Set): - """A DNS rdataset. - - @ivar rdclass: The class of the rdataset - @type rdclass: int - @ivar rdtype: The type of the rdataset - @type rdtype: int - @ivar covers: The covered type. Usually this value is - dns.rdatatype.NONE, but if the rdtype is dns.rdatatype.SIG or - dns.rdatatype.RRSIG, then the covers value will be the rdata - type the SIG/RRSIG covers. The library treats the SIG and RRSIG - types as if they were a family of - types, e.g. RRSIG(A), RRSIG(NS), RRSIG(SOA). This makes RRSIGs much - easier to work with than if RRSIGs covering different rdata - types were aggregated into a single RRSIG rdataset. - @type covers: int - @ivar ttl: The DNS TTL (Time To Live) value - @type ttl: int - """ + """A DNS rdataset.""" __slots__ = ['rdclass', 'rdtype', 'covers', 'ttl'] - def __init__(self, rdclass, rdtype, covers=dns.rdatatype.NONE): + def __init__(self, rdclass, rdtype, covers=dns.rdatatype.NONE, ttl=0): """Create a new rdataset of the specified class and type. - @see: the description of the class instance variables for the - meaning of I{rdclass} and I{rdtype}""" + *rdclass*, an ``int``, the rdataclass. - super(Rdataset, self).__init__() + *rdtype*, an ``int``, the rdatatype. + + *covers*, an ``int``, the covered rdatatype. + + *ttl*, an ``int``, the TTL. + """ + + super().__init__() self.rdclass = rdclass self.rdtype = rdtype self.covers = covers - self.ttl = 0 + self.ttl = ttl def _clone(self): - obj = super(Rdataset, self)._clone() + obj = super()._clone() obj.rdclass = self.rdclass obj.rdtype = self.rdtype obj.covers = self.covers @@ -85,11 +73,14 @@ class Rdataset(dns.set.Set): return obj def update_ttl(self, ttl): - """Set the TTL of the rdataset to be the lesser of the set's current + """Perform TTL minimization. + + Set the TTL of the rdataset to be the lesser of the set's current TTL or the specified TTL. If the set contains no rdatas, set the TTL to the specified TTL. - @param ttl: The TTL - @type ttl: int""" + + *ttl*, an ``int``. + """ if len(self) == 0: self.ttl = ttl @@ -99,13 +90,19 @@ class Rdataset(dns.set.Set): def add(self, rd, ttl=None): """Add the specified rdata to the rdataset. - If the optional I{ttl} parameter is supplied, then - self.update_ttl(ttl) will be called prior to adding the rdata. + If the optional *ttl* parameter is supplied, then + ``self.update_ttl(ttl)`` will be called prior to adding the rdata. - @param rd: The rdata - @type rd: dns.rdata.Rdata object - @param ttl: The TTL - @type ttl: int""" + *rd*, a ``dns.rdata.Rdata``, the rdata + + *ttl*, an ``int``, the TTL. + + Raises ``dns.rdataset.IncompatibleTypes`` if the type and class + do not match the type and class of the rdataset. + + Raises ``dns.rdataset.DifferingCovers`` if the type is a signature + type and the covered type does not match that of the rdataset. + """ # # If we're adding a signature, do some special handling to @@ -126,24 +123,33 @@ class Rdataset(dns.set.Set): raise DifferingCovers if dns.rdatatype.is_singleton(rd.rdtype) and len(self) > 0: self.clear() - super(Rdataset, self).add(rd) + super().add(rd) def union_update(self, other): self.update_ttl(other.ttl) - super(Rdataset, self).union_update(other) + super().union_update(other) def intersection_update(self, other): self.update_ttl(other.ttl) - super(Rdataset, self).intersection_update(other) + super().intersection_update(other) def update(self, other): """Add all rdatas in other to self. - @param other: The rdataset from which to update - @type other: dns.rdataset.Rdataset object""" + *other*, a ``dns.rdataset.Rdataset``, the rdataset from which + to update. + """ self.update_ttl(other.ttl) - super(Rdataset, self).update(other) + super().update(other) + + def _rdata_repr(self): + def maybe_truncate(s): + if len(s) > 100: + return s[:100] + '...' + return s + return '[%s]' % ', '.join('<%s>' % maybe_truncate(str(rr)) + for rr in self) def __repr__(self): if self.covers == 0: @@ -151,23 +157,20 @@ class Rdataset(dns.set.Set): else: ctext = '(' + dns.rdatatype.to_text(self.covers) + ')' return '' + dns.rdatatype.to_text(self.rdtype) + ctext + \ + ' rdataset: ' + self._rdata_repr() + '>' def __str__(self): return self.to_text() def __eq__(self, other): - """Two rdatasets are equal if they have the same class, type, and - covers, and contain the same rdata. - @rtype: bool""" - if not isinstance(other, Rdataset): return False if self.rdclass != other.rdclass or \ self.rdtype != other.rdtype or \ self.covers != other.covers: return False - return super(Rdataset, self).__eq__(other) + return super().__eq__(other) def __ne__(self, other): return not self.__eq__(other) @@ -176,20 +179,23 @@ class Rdataset(dns.set.Set): override_rdclass=None, **kw): """Convert the rdataset into DNS master file format. - @see: L{dns.name.Name.choose_relativity} for more information - on how I{origin} and I{relativize} determine the way names + See ``dns.name.Name.choose_relativity`` for more information + on how *origin* and *relativize* determine the way names are emitted. Any additional keyword arguments are passed on to the rdata - to_text() method. + ``to_text()`` method. + + *name*, a ``dns.name.Name``. If name is not ``None``, emit RRs with + *name* as the owner name. + + *origin*, a ``dns.name.Name`` or ``None``, the origin for relative + names. + + *relativize*, a ``bool``. If ``True``, names will be relativized + to *origin*. + """ - @param name: If name is not None, emit a RRs with I{name} as - the owner name. - @type name: dns.name.Name object - @param origin: The origin for relative names, or None. - @type origin: dns.name.Name object - @param relativize: True if names should names be relativized - @type relativize: bool""" if name is not None: name = name.choose_relativity(origin, relativize) ntext = str(name) @@ -197,7 +203,7 @@ class Rdataset(dns.set.Set): else: ntext = '' pad = '' - s = StringIO() + s = io.StringIO() if override_rdclass is not None: rdclass = override_rdclass else: @@ -208,12 +214,12 @@ class Rdataset(dns.set.Set): # some dynamic updates, so we don't need to print out the TTL # (which is meaningless anyway). # - s.write(u'%s%s%s %s\n' % (ntext, pad, - dns.rdataclass.to_text(rdclass), - dns.rdatatype.to_text(self.rdtype))) + s.write('{}{}{} {}\n'.format(ntext, pad, + dns.rdataclass.to_text(rdclass), + dns.rdatatype.to_text(self.rdtype))) else: for rd in self: - s.write(u'%s%s%d %s %s %s\n' % + s.write('%s%s%d %s %s %s\n' % (ntext, pad, self.ttl, dns.rdataclass.to_text(rdclass), dns.rdatatype.to_text(self.rdtype), rd.to_text(origin=origin, relativize=relativize, @@ -227,16 +233,26 @@ class Rdataset(dns.set.Set): override_rdclass=None, want_shuffle=True): """Convert the rdataset to wire format. - @param name: The owner name of the RRset that will be emitted - @type name: dns.name.Name object - @param file: The file to which the wire format data will be appended - @type file: file - @param compress: The compression table to use; the default is None. - @type compress: dict - @param origin: The origin to be appended to any relative names when - they are emitted. The default is None. - @returns: the number of records emitted - @rtype: int + *name*, a ``dns.name.Name`` is the owner name to use. + + *file* is the file where the name is emitted (typically a + BytesIO file). + + *compress*, a ``dict``, is the compression table to use. If + ``None`` (the default), names will not be compressed. + + *origin* is a ``dns.name.Name`` or ``None``. If the name is + relative and origin is not ``None``, then *origin* will be appended + to it. + + *override_rdclass*, an ``int``, is used as the class instead of the + class of the rdataset. This is useful when rendering rdatasets + associated with dynamic updates. + + *want_shuffle*, a ``bool``. If ``True``, then the order of the + Rdatas within the Rdataset will be shuffled before rendering. + + Returns an ``int``, the number of records emitted. """ if override_rdclass is not None: @@ -272,8 +288,9 @@ class Rdataset(dns.set.Set): return len(self) def match(self, rdclass, rdtype, covers): - """Returns True if this rdataset matches the specified class, type, - and covers""" + """Returns ``True`` if this rdataset matches the specified class, + type, and covers. + """ if self.rdclass == rdclass and \ self.rdtype == rdtype and \ self.covers == covers: @@ -281,21 +298,23 @@ class Rdataset(dns.set.Set): return False -def from_text_list(rdclass, rdtype, ttl, text_rdatas): +def from_text_list(rdclass, rdtype, ttl, text_rdatas, idna_codec=None): """Create an rdataset with the specified class, type, and TTL, and with the specified list of rdatas in text format. - @rtype: dns.rdataset.Rdataset object + *idna_codec*, a ``dns.name.IDNACodec``, specifies the IDNA + encoder/decoder to use; if ``None``, the default IDNA 2003 + encoder/decoder is used. + + Returns a ``dns.rdataset.Rdataset`` object. """ - if isinstance(rdclass, string_types): - rdclass = dns.rdataclass.from_text(rdclass) - if isinstance(rdtype, string_types): - rdtype = dns.rdatatype.from_text(rdtype) + rdclass = dns.rdataclass.RdataClass.make(rdclass) + rdtype = dns.rdatatype.RdataType.make(rdtype) r = Rdataset(rdclass, rdtype) r.update_ttl(ttl) for t in text_rdatas: - rd = dns.rdata.from_text(r.rdclass, r.rdtype, t) + rd = dns.rdata.from_text(r.rdclass, r.rdtype, t, idna_codec=idna_codec) r.add(rd) return r @@ -304,7 +323,7 @@ def from_text(rdclass, rdtype, ttl, *text_rdatas): """Create an rdataset with the specified class, type, and TTL, and with the specified rdatas in text format. - @rtype: dns.rdataset.Rdataset object + Returns a ``dns.rdataset.Rdataset`` object. """ return from_text_list(rdclass, rdtype, ttl, text_rdatas) @@ -314,7 +333,7 @@ def from_rdata_list(ttl, rdatas): """Create an rdataset with the specified TTL, and with the specified list of rdata objects. - @rtype: dns.rdataset.Rdataset object + Returns a ``dns.rdataset.Rdataset`` object. """ if len(rdatas) == 0: @@ -332,7 +351,7 @@ def from_rdata(ttl, *rdatas): """Create an rdataset with the specified TTL, and with the specified rdata objects. - @rtype: dns.rdataset.Rdataset object + Returns a ``dns.rdataset.Rdataset`` object. """ return from_rdata_list(ttl, rdatas) diff --git a/lib/dns/rdatatype.py b/lib/dns/rdatatype.py index cde1a0a1..c793d5a0 100644 --- a/lib/dns/rdatatype.py +++ b/lib/dns/rdatatype.py @@ -1,4 +1,6 @@ -# Copyright (C) 2001-2007, 2009-2011 Nominum, Inc. +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2001-2017 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its # documentation for any purpose with or without fee is hereby granted, @@ -13,241 +15,207 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. -"""DNS Rdata Types. - -@var _by_text: The rdata type textual name to value mapping -@type _by_text: dict -@var _by_value: The rdata type value to textual name mapping -@type _by_value: dict -@var _metatypes: If an rdatatype is a metatype, there will be a mapping -whose key is the rdatatype value and whose value is True in this dictionary. -@type _metatypes: dict -@var _singletons: If an rdatatype is a singleton, there will be a mapping -whose key is the rdatatype value and whose value is True in this dictionary. -@type _singletons: dict""" - -import re +"""DNS Rdata Types.""" +import dns.enum import dns.exception -NONE = 0 -A = 1 -NS = 2 -MD = 3 -MF = 4 -CNAME = 5 -SOA = 6 -MB = 7 -MG = 8 -MR = 9 -NULL = 10 -WKS = 11 -PTR = 12 -HINFO = 13 -MINFO = 14 -MX = 15 -TXT = 16 -RP = 17 -AFSDB = 18 -X25 = 19 -ISDN = 20 -RT = 21 -NSAP = 22 -NSAP_PTR = 23 -SIG = 24 -KEY = 25 -PX = 26 -GPOS = 27 -AAAA = 28 -LOC = 29 -NXT = 30 -SRV = 33 -NAPTR = 35 -KX = 36 -CERT = 37 -A6 = 38 -DNAME = 39 -OPT = 41 -APL = 42 -DS = 43 -SSHFP = 44 -IPSECKEY = 45 -RRSIG = 46 -NSEC = 47 -DNSKEY = 48 -DHCID = 49 -NSEC3 = 50 -NSEC3PARAM = 51 -TLSA = 52 -HIP = 55 -CDS = 59 -CDNSKEY = 60 -CSYNC = 62 -SPF = 99 -UNSPEC = 103 -EUI48 = 108 -EUI64 = 109 -TKEY = 249 -TSIG = 250 -IXFR = 251 -AXFR = 252 -MAILB = 253 -MAILA = 254 -ANY = 255 -URI = 256 -CAA = 257 -TA = 32768 -DLV = 32769 +class RdataType(dns.enum.IntEnum): + """DNS Rdata Type""" + TYPE0 = 0 + NONE = 0 + A = 1 + NS = 2 + MD = 3 + MF = 4 + CNAME = 5 + SOA = 6 + MB = 7 + MG = 8 + MR = 9 + NULL = 10 + WKS = 11 + PTR = 12 + HINFO = 13 + MINFO = 14 + MX = 15 + TXT = 16 + RP = 17 + AFSDB = 18 + X25 = 19 + ISDN = 20 + RT = 21 + NSAP = 22 + NSAP_PTR = 23 + SIG = 24 + KEY = 25 + PX = 26 + GPOS = 27 + AAAA = 28 + LOC = 29 + NXT = 30 + SRV = 33 + NAPTR = 35 + KX = 36 + CERT = 37 + A6 = 38 + DNAME = 39 + OPT = 41 + APL = 42 + DS = 43 + SSHFP = 44 + IPSECKEY = 45 + RRSIG = 46 + NSEC = 47 + DNSKEY = 48 + DHCID = 49 + NSEC3 = 50 + NSEC3PARAM = 51 + TLSA = 52 + HIP = 55 + NINFO = 56 + CDS = 59 + CDNSKEY = 60 + OPENPGPKEY = 61 + CSYNC = 62 + SPF = 99 + UNSPEC = 103 + EUI48 = 108 + EUI64 = 109 + TKEY = 249 + TSIG = 250 + IXFR = 251 + AXFR = 252 + MAILB = 253 + MAILA = 254 + ANY = 255 + URI = 256 + CAA = 257 + AVC = 258 + AMTRELAY = 259 + TA = 32768 + DLV = 32769 -_by_text = { - 'NONE': NONE, - 'A': A, - 'NS': NS, - 'MD': MD, - 'MF': MF, - 'CNAME': CNAME, - 'SOA': SOA, - 'MB': MB, - 'MG': MG, - 'MR': MR, - 'NULL': NULL, - 'WKS': WKS, - 'PTR': PTR, - 'HINFO': HINFO, - 'MINFO': MINFO, - 'MX': MX, - 'TXT': TXT, - 'RP': RP, - 'AFSDB': AFSDB, - 'X25': X25, - 'ISDN': ISDN, - 'RT': RT, - 'NSAP': NSAP, - 'NSAP-PTR': NSAP_PTR, - 'SIG': SIG, - 'KEY': KEY, - 'PX': PX, - 'GPOS': GPOS, - 'AAAA': AAAA, - 'LOC': LOC, - 'NXT': NXT, - 'SRV': SRV, - 'NAPTR': NAPTR, - 'KX': KX, - 'CERT': CERT, - 'A6': A6, - 'DNAME': DNAME, - 'OPT': OPT, - 'APL': APL, - 'DS': DS, - 'SSHFP': SSHFP, - 'IPSECKEY': IPSECKEY, - 'RRSIG': RRSIG, - 'NSEC': NSEC, - 'DNSKEY': DNSKEY, - 'DHCID': DHCID, - 'NSEC3': NSEC3, - 'NSEC3PARAM': NSEC3PARAM, - 'TLSA': TLSA, - 'HIP': HIP, - 'CDS': CDS, - 'CDNSKEY': CDNSKEY, - 'CSYNC': CSYNC, - 'SPF': SPF, - 'UNSPEC': UNSPEC, - 'EUI48': EUI48, - 'EUI64': EUI64, - 'TKEY': TKEY, - 'TSIG': TSIG, - 'IXFR': IXFR, - 'AXFR': AXFR, - 'MAILB': MAILB, - 'MAILA': MAILA, - 'ANY': ANY, - 'URI': URI, - 'CAA': CAA, - 'TA': TA, - 'DLV': DLV, -} + @classmethod + def _maximum(cls): + return 65535 -# We construct the inverse mapping programmatically to ensure that we -# cannot make any mistakes (e.g. omissions, cut-and-paste errors) that -# would cause the mapping not to be true inverse. + @classmethod + def _short_name(cls): + return "type" -_by_value = dict((y, x) for x, y in _by_text.items()) + @classmethod + def _prefix(cls): + return "TYPE" + @classmethod + def _unknown_exception_class(cls): + return UnknownRdatatype -_metatypes = { - OPT: True -} +_registered_by_text = {} +_registered_by_value = {} -_singletons = { - SOA: True, - NXT: True, - DNAME: True, - NSEC: True, - # CNAME is technically a singleton, but we allow multiple CNAMEs. -} +globals().update(RdataType.__members__) -_unknown_type_pattern = re.compile('TYPE([0-9]+)$', re.I) +_metatypes = {RdataType.OPT} + +_singletons = {RdataType.SOA, RdataType.NXT, RdataType.DNAME, + RdataType.NSEC, RdataType.CNAME} class UnknownRdatatype(dns.exception.DNSException): - """DNS resource record type is unknown.""" def from_text(text): """Convert text into a DNS rdata type value. - @param text: the text - @type text: string - @raises dns.rdatatype.UnknownRdatatype: the type is unknown - @raises ValueError: the rdata type value is not >= 0 and <= 65535 - @rtype: int""" - value = _by_text.get(text.upper()) - if value is None: - match = _unknown_type_pattern.match(text) - if match is None: - raise UnknownRdatatype - value = int(match.group(1)) - if value < 0 or value > 65535: - raise ValueError("type must be between >= 0 and <= 65535") - return value + The input text can be a defined DNS RR type mnemonic or + instance of the DNS generic type syntax. + + For example, "NS" and "TYPE2" will both result in a value of 2. + + Raises ``dns.rdatatype.UnknownRdatatype`` if the type is unknown. + + Raises ``ValueError`` if the rdata type value is not >= 0 and <= 65535. + + Returns an ``int``. + """ + + text = text.upper().replace('-', '_') + try: + return RdataType.from_text(text) + except UnknownRdatatype: + registered_type = _registered_by_text.get(text) + if registered_type: + return registered_type + raise def to_text(value): - """Convert a DNS rdata type to text. - @param value: the rdata type value - @type value: int - @raises ValueError: the rdata type value is not >= 0 and <= 65535 - @rtype: string""" + """Convert a DNS rdata type value to text. - if value < 0 or value > 65535: - raise ValueError("type must be between >= 0 and <= 65535") - text = _by_value.get(value) - if text is None: - text = 'TYPE' + repr(value) - return text + If the value has a known mnemonic, it will be used, otherwise the + DNS generic type syntax will be used. + + Raises ``ValueError`` if the rdata type value is not >= 0 and <= 65535. + + Returns a ``str``. + """ + + text = RdataType.to_text(value) + if text.startswith("TYPE"): + registered_text = _registered_by_value.get(value) + if registered_text: + text = registered_text + return text.replace('_', '-') def is_metatype(rdtype): - """True if the type is a metatype. - @param rdtype: the type - @type rdtype: int - @rtype: bool""" + """True if the specified type is a metatype. - if rdtype >= TKEY and rdtype <= ANY or rdtype in _metatypes: - return True - return False + *rdtype* is an ``int``. + + The currently defined metatypes are TKEY, TSIG, IXFR, AXFR, MAILA, + MAILB, ANY, and OPT. + + Returns a ``bool``. + """ + + return (256 > rdtype >= 128) or rdtype in _metatypes def is_singleton(rdtype): - """True if the type is a singleton. - @param rdtype: the type - @type rdtype: int - @rtype: bool""" + """Is the specified type a singleton type? + + Singleton types can only have a single rdata in an rdataset, or a single + RR in an RRset. + + The currently defined singleton types are CNAME, DNAME, NSEC, NXT, and + SOA. + + *rdtype* is an ``int``. + + Returns a ``bool``. + """ if rdtype in _singletons: return True return False + +# pylint: disable=redefined-outer-name +def register_type(rdtype, rdtype_text, is_singleton=False): + """Dynamically register an rdatatype. + + *rdtype*, an ``int``, the rdatatype to register. + + *rdtype_text*, a ``str``, the textual form of the rdatatype. + + *is_singleton*, a ``bool``, indicating if the type is a singleton (i.e. + RRsets of the type can have only one member.) + """ + + _registered_by_text[rdtype_text] = rdtype + _registered_by_value[rdtype] = rdtype_text + if is_singleton: + _singletons.add(rdtype) diff --git a/lib/dns/rdtypes/ANY/AFSDB.py b/lib/dns/rdtypes/ANY/AFSDB.py index f3d51540..40878900 100644 --- a/lib/dns/rdtypes/ANY/AFSDB.py +++ b/lib/dns/rdtypes/ANY/AFSDB.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its @@ -18,12 +20,7 @@ import dns.rdtypes.mxbase class AFSDB(dns.rdtypes.mxbase.UncompressedDowncasingMX): - """AFSDB record - - @ivar subtype: the subtype value - @type subtype: int - @ivar hostname: the hostname name - @type hostname: dns.name.Name object""" + """AFSDB record""" # Use the property mechanism to make "subtype" an alias for the # "preference" attribute, and "hostname" an alias for the "exchange" @@ -36,18 +33,12 @@ class AFSDB(dns.rdtypes.mxbase.UncompressedDowncasingMX): # implementation, but this way we don't copy code, and that's # good. - def get_subtype(self): + @property + def subtype(self): + "the AFSDB subtype" return self.preference - def set_subtype(self, subtype): - self.preference = subtype - - subtype = property(get_subtype, set_subtype) - - def get_hostname(self): + @property + def hostname(self): + "the AFSDB hostname" return self.exchange - - def set_hostname(self, hostname): - self.exchange = hostname - - hostname = property(get_hostname, set_hostname) diff --git a/lib/dns/rdtypes/ANY/AMTRELAY.py b/lib/dns/rdtypes/ANY/AMTRELAY.py new file mode 100644 index 00000000..4e012a27 --- /dev/null +++ b/lib/dns/rdtypes/ANY/AMTRELAY.py @@ -0,0 +1,79 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2006, 2007, 2009-2011 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +import struct + +import dns.exception +import dns.rdtypes.util + + +class Relay(dns.rdtypes.util.Gateway): + name = 'AMTRELAY relay' + +class AMTRELAY(dns.rdata.Rdata): + + """AMTRELAY record""" + + # see: RFC 8777 + + __slots__ = ['precedence', 'discovery_optional', 'relay_type', 'relay'] + + def __init__(self, rdclass, rdtype, precedence, discovery_optional, + relay_type, relay): + super().__init__(rdclass, rdtype) + Relay(relay_type, relay).check() + object.__setattr__(self, 'precedence', precedence) + object.__setattr__(self, 'discovery_optional', discovery_optional) + object.__setattr__(self, 'relay_type', relay_type) + object.__setattr__(self, 'relay', relay) + + def to_text(self, origin=None, relativize=True, **kw): + relay = Relay(self.relay_type, self.relay).to_text(origin, relativize) + return '%d %d %d %s' % (self.precedence, self.discovery_optional, + self.relay_type, relay) + + @classmethod + def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, + relativize_to=None): + precedence = tok.get_uint8() + discovery_optional = tok.get_uint8() + if discovery_optional > 1: + raise dns.exception.SyntaxError('expecting 0 or 1') + discovery_optional = bool(discovery_optional) + relay_type = tok.get_uint8() + if relay_type > 0x7f: + raise dns.exception.SyntaxError('expecting an integer <= 127') + relay = Relay(relay_type).from_text(tok, origin, relativize, + relativize_to) + return cls(rdclass, rdtype, precedence, discovery_optional, relay_type, + relay) + + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): + relay_type = self.relay_type | (self.discovery_optional << 7) + header = struct.pack("!BB", self.precedence, relay_type) + file.write(header) + Relay(self.relay_type, self.relay).to_wire(file, compress, origin, + canonicalize) + + @classmethod + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + (precedence, relay_type) = parser.get_struct('!BB') + discovery_optional = bool(relay_type >> 7) + relay_type &= 0x7f + relay = Relay(relay_type).from_wire_parser(parser, origin) + return cls(rdclass, rdtype, precedence, discovery_optional, relay_type, + relay) diff --git a/lib/dns/hash.py b/lib/dns/rdtypes/ANY/AVC.py similarity index 66% rename from lib/dns/hash.py rename to lib/dns/rdtypes/ANY/AVC.py index 27f7a7e2..1fa5ecfd 100644 --- a/lib/dns/hash.py +++ b/lib/dns/rdtypes/ANY/AVC.py @@ -1,4 +1,6 @@ -# Copyright (C) 2011 Nominum, Inc. +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2016 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its # documentation for any purpose with or without fee is hereby granted, @@ -13,20 +15,11 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. -"""Hashing backwards compatibility wrapper""" - -import sys -import hashlib +import dns.rdtypes.txtbase -hashes = {} -hashes['MD5'] = hashlib.md5 -hashes['SHA1'] = hashlib.sha1 -hashes['SHA224'] = hashlib.sha224 -hashes['SHA256'] = hashlib.sha256 -hashes['SHA384'] = hashlib.sha384 -hashes['SHA512'] = hashlib.sha512 +class AVC(dns.rdtypes.txtbase.TXTBase): + """AVC record""" -def get(algorithm): - return hashes[algorithm.upper()] + # See: IANA dns parameters for AVC diff --git a/lib/dns/rdtypes/ANY/CAA.py b/lib/dns/rdtypes/ANY/CAA.py index e80d4693..b7edae87 100644 --- a/lib/dns/rdtypes/ANY/CAA.py +++ b/lib/dns/rdtypes/ANY/CAA.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its @@ -22,23 +24,17 @@ import dns.tokenizer class CAA(dns.rdata.Rdata): - """CAA (Certification Authority Authorization) record + """CAA (Certification Authority Authorization) record""" - @ivar flags: the flags - @type flags: int - @ivar tag: the tag - @type tag: string - @ivar value: the value - @type value: string - @see: RFC 6844""" + # see: RFC 6844 __slots__ = ['flags', 'tag', 'value'] def __init__(self, rdclass, rdtype, flags, tag, value): - super(CAA, self).__init__(rdclass, rdtype) - self.flags = flags - self.tag = tag - self.value = value + super().__init__(rdclass, rdtype) + object.__setattr__(self, 'flags', flags) + object.__setattr__(self, 'tag', tag) + object.__setattr__(self, 'value', value) def to_text(self, origin=None, relativize=True, **kw): return '%u %s "%s"' % (self.flags, @@ -46,7 +42,8 @@ class CAA(dns.rdata.Rdata): dns.rdata._escapify(self.value)) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True): + def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, + relativize_to=None): flags = tok.get_uint8() tag = tok.get_string().encode() if len(tag) > 255: @@ -56,7 +53,7 @@ class CAA(dns.rdata.Rdata): value = tok.get_string().encode() return cls(rdclass, rdtype, flags, tag, value) - def to_wire(self, file, compress=None, origin=None): + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): file.write(struct.pack('!B', self.flags)) l = len(self.tag) assert l < 256 @@ -65,10 +62,8 @@ class CAA(dns.rdata.Rdata): file.write(self.value) @classmethod - def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None): - (flags, l) = struct.unpack('!BB', wire[current: current + 2]) - current += 2 - tag = wire[current: current + l] - value = wire[current + l:current + rdlen - 2] + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + flags = parser.get_uint8() + tag = parser.get_counted_bytes() + value = parser.get_remaining() return cls(rdclass, rdtype, flags, tag, value) - diff --git a/lib/dns/rdtypes/ANY/CDNSKEY.py b/lib/dns/rdtypes/ANY/CDNSKEY.py index 83f3d51f..72253183 100644 --- a/lib/dns/rdtypes/ANY/CDNSKEY.py +++ b/lib/dns/rdtypes/ANY/CDNSKEY.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2004-2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its @@ -14,10 +16,7 @@ # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. import dns.rdtypes.dnskeybase -from dns.rdtypes.dnskeybase import flags_to_text_set, flags_from_text_set - - -__all__ = ['flags_to_text_set', 'flags_from_text_set'] +from dns.rdtypes.dnskeybase import SEP, REVOKE, ZONE # noqa: F401 class CDNSKEY(dns.rdtypes.dnskeybase.DNSKEYBase): diff --git a/lib/dns/rdtypes/ANY/CDS.py b/lib/dns/rdtypes/ANY/CDS.py index e1abfc36..a63041dd 100644 --- a/lib/dns/rdtypes/ANY/CDS.py +++ b/lib/dns/rdtypes/ANY/CDS.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its diff --git a/lib/dns/rdtypes/ANY/CERT.py b/lib/dns/rdtypes/ANY/CERT.py index b7454409..62df241c 100644 --- a/lib/dns/rdtypes/ANY/CERT.py +++ b/lib/dns/rdtypes/ANY/CERT.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its @@ -54,27 +56,19 @@ def _ctype_to_text(what): class CERT(dns.rdata.Rdata): - """CERT record + """CERT record""" - @ivar certificate_type: certificate type - @type certificate_type: int - @ivar key_tag: key tag - @type key_tag: int - @ivar algorithm: algorithm - @type algorithm: int - @ivar certificate: the certificate or CRL - @type certificate: string - @see: RFC 2538""" + # see RFC 2538 __slots__ = ['certificate_type', 'key_tag', 'algorithm', 'certificate'] def __init__(self, rdclass, rdtype, certificate_type, key_tag, algorithm, certificate): - super(CERT, self).__init__(rdclass, rdtype) - self.certificate_type = certificate_type - self.key_tag = key_tag - self.algorithm = algorithm - self.certificate = certificate + super().__init__(rdclass, rdtype) + object.__setattr__(self, 'certificate_type', certificate_type) + object.__setattr__(self, 'key_tag', key_tag) + object.__setattr__(self, 'algorithm', algorithm) + object.__setattr__(self, 'certificate', certificate) def to_text(self, origin=None, relativize=True, **kw): certificate_type = _ctype_to_text(self.certificate_type) @@ -83,40 +77,27 @@ class CERT(dns.rdata.Rdata): dns.rdata._base64ify(self.certificate)) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True): + def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, + relativize_to=None): certificate_type = _ctype_from_text(tok.get_string()) key_tag = tok.get_uint16() algorithm = dns.dnssec.algorithm_from_text(tok.get_string()) if algorithm < 0 or algorithm > 255: raise dns.exception.SyntaxError("bad algorithm type") - chunks = [] - while 1: - t = tok.get().unescape() - if t.is_eol_or_eof(): - break - if not t.is_identifier(): - raise dns.exception.SyntaxError - chunks.append(t.value.encode()) - b64 = b''.join(chunks) + b64 = tok.concatenate_remaining_identifiers().encode() certificate = base64.b64decode(b64) return cls(rdclass, rdtype, certificate_type, key_tag, algorithm, certificate) - def to_wire(self, file, compress=None, origin=None): + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): prefix = struct.pack("!HHB", self.certificate_type, self.key_tag, self.algorithm) file.write(prefix) file.write(self.certificate) @classmethod - def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None): - prefix = wire[current: current + 5].unwrap() - current += 5 - rdlen -= 5 - if rdlen < 0: - raise dns.exception.FormError - (certificate_type, key_tag, algorithm) = struct.unpack("!HHB", prefix) - certificate = wire[current: current + rdlen].unwrap() + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + (certificate_type, key_tag, algorithm) = parser.get_struct("!HHB") + certificate = parser.get_remaining() return cls(rdclass, rdtype, certificate_type, key_tag, algorithm, certificate) - diff --git a/lib/dns/rdtypes/ANY/CNAME.py b/lib/dns/rdtypes/ANY/CNAME.py index 65cf570c..11d42aa7 100644 --- a/lib/dns/rdtypes/ANY/CNAME.py +++ b/lib/dns/rdtypes/ANY/CNAME.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its diff --git a/lib/dns/rdtypes/ANY/CSYNC.py b/lib/dns/rdtypes/ANY/CSYNC.py index bf95cb27..9cba5fad 100644 --- a/lib/dns/rdtypes/ANY/CSYNC.py +++ b/lib/dns/rdtypes/ANY/CSYNC.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2004-2007, 2009-2011, 2016 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its @@ -19,106 +21,43 @@ import dns.exception import dns.rdata import dns.rdatatype import dns.name -from dns._compat import xrange +import dns.rdtypes.util + + +class Bitmap(dns.rdtypes.util.Bitmap): + type_name = 'CSYNC' + class CSYNC(dns.rdata.Rdata): - """CSYNC record - - @ivar serial: the SOA serial number - @type serial: int - @ivar flags: the CSYNC flags - @type flags: int - @ivar windows: the windowed bitmap list - @type windows: list of (window number, string) tuples""" + """CSYNC record""" __slots__ = ['serial', 'flags', 'windows'] def __init__(self, rdclass, rdtype, serial, flags, windows): - super(CSYNC, self).__init__(rdclass, rdtype) - self.serial = serial - self.flags = flags - self.windows = windows + super().__init__(rdclass, rdtype) + object.__setattr__(self, 'serial', serial) + object.__setattr__(self, 'flags', flags) + object.__setattr__(self, 'windows', dns.rdata._constify(windows)) def to_text(self, origin=None, relativize=True, **kw): - text = '' - for (window, bitmap) in self.windows: - bits = [] - for i in xrange(0, len(bitmap)): - byte = bitmap[i] - for j in xrange(0, 8): - if byte & (0x80 >> j): - bits.append(dns.rdatatype.to_text(window * 256 + - i * 8 + j)) - text += (' ' + ' '.join(bits)) + text = Bitmap(self.windows).to_text() return '%d %d%s' % (self.serial, self.flags, text) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True): + def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, + relativize_to=None): serial = tok.get_uint32() flags = tok.get_uint16() - rdtypes = [] - while 1: - token = tok.get().unescape() - if token.is_eol_or_eof(): - break - nrdtype = dns.rdatatype.from_text(token.value) - if nrdtype == 0: - raise dns.exception.SyntaxError("CSYNC with bit 0") - if nrdtype > 65535: - raise dns.exception.SyntaxError("CSYNC with bit > 65535") - rdtypes.append(nrdtype) - rdtypes.sort() - window = 0 - octets = 0 - prior_rdtype = 0 - bitmap = bytearray(b'\0' * 32) - windows = [] - for nrdtype in rdtypes: - if nrdtype == prior_rdtype: - continue - prior_rdtype = nrdtype - new_window = nrdtype // 256 - if new_window != window: - windows.append((window, bitmap[0:octets])) - bitmap = bytearray(b'\0' * 32) - window = new_window - offset = nrdtype % 256 - byte = offset // 8 - bit = offset % 8 - octets = byte + 1 - bitmap[byte] = bitmap[byte] | (0x80 >> bit) - - windows.append((window, bitmap[0:octets])) + windows = Bitmap().from_text(tok) return cls(rdclass, rdtype, serial, flags, windows) - def to_wire(self, file, compress=None, origin=None): + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): file.write(struct.pack('!IH', self.serial, self.flags)) - for (window, bitmap) in self.windows: - file.write(struct.pack('!BB', window, len(bitmap))) - file.write(bitmap) + Bitmap(self.windows).to_wire(file) @classmethod - def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None): - if rdlen < 6: - raise dns.exception.FormError("CSYNC too short") - (serial, flags) = struct.unpack("!IH", wire[current: current + 6]) - current += 6 - rdlen -= 6 - windows = [] - while rdlen > 0: - if rdlen < 3: - raise dns.exception.FormError("CSYNC too short") - window = wire[current] - octets = wire[current + 1] - if octets == 0 or octets > 32: - raise dns.exception.FormError("bad CSYNC octets") - current += 2 - rdlen -= 2 - if rdlen < octets: - raise dns.exception.FormError("bad CSYNC bitmap length") - bitmap = bytearray(wire[current: current + octets].unwrap()) - current += octets - rdlen -= octets - windows.append((window, bitmap)) + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + (serial, flags) = parser.get_struct("!IH") + windows = Bitmap().from_wire_parser(parser) return cls(rdclass, rdtype, serial, flags, windows) diff --git a/lib/dns/rdtypes/ANY/DLV.py b/lib/dns/rdtypes/ANY/DLV.py index cd1244c1..16352125 100644 --- a/lib/dns/rdtypes/ANY/DLV.py +++ b/lib/dns/rdtypes/ANY/DLV.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its diff --git a/lib/dns/rdtypes/ANY/DNAME.py b/lib/dns/rdtypes/ANY/DNAME.py index dac97214..2000d9b0 100644 --- a/lib/dns/rdtypes/ANY/DNAME.py +++ b/lib/dns/rdtypes/ANY/DNAME.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its @@ -20,5 +22,5 @@ class DNAME(dns.rdtypes.nsbase.UncompressedNS): """DNAME record""" - def to_digestable(self, origin=None): - return self.target.to_digestable(origin) + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): + self.target.to_wire(file, None, origin, canonicalize) diff --git a/lib/dns/rdtypes/ANY/DNSKEY.py b/lib/dns/rdtypes/ANY/DNSKEY.py index e915e98b..2ee37988 100644 --- a/lib/dns/rdtypes/ANY/DNSKEY.py +++ b/lib/dns/rdtypes/ANY/DNSKEY.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2004-2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its @@ -14,10 +16,7 @@ # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. import dns.rdtypes.dnskeybase -from dns.rdtypes.dnskeybase import flags_to_text_set, flags_from_text_set - - -__all__ = ['flags_to_text_set', 'flags_from_text_set'] +from dns.rdtypes.dnskeybase import SEP, REVOKE, ZONE # noqa: F401 class DNSKEY(dns.rdtypes.dnskeybase.DNSKEYBase): diff --git a/lib/dns/rdtypes/ANY/DS.py b/lib/dns/rdtypes/ANY/DS.py index 577c8d84..7d457b22 100644 --- a/lib/dns/rdtypes/ANY/DS.py +++ b/lib/dns/rdtypes/ANY/DS.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its diff --git a/lib/dns/rdtypes/ANY/EUI48.py b/lib/dns/rdtypes/ANY/EUI48.py index aa260e20..b16e81f3 100644 --- a/lib/dns/rdtypes/ANY/EUI48.py +++ b/lib/dns/rdtypes/ANY/EUI48.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2015 Red Hat, Inc. # Author: Petr Spacek # @@ -19,11 +21,9 @@ import dns.rdtypes.euibase class EUI48(dns.rdtypes.euibase.EUIBase): - """EUI48 record + """EUI48 record""" - @ivar fingerprint: 48-bit Extended Unique Identifier (EUI-48) - @type fingerprint: string - @see: rfc7043.txt""" + # see: rfc7043.txt byte_len = 6 # 0123456789ab (in hex) text_len = byte_len * 3 - 1 # 01-23-45-67-89-ab diff --git a/lib/dns/rdtypes/ANY/EUI64.py b/lib/dns/rdtypes/ANY/EUI64.py index 5eba350d..cc080760 100644 --- a/lib/dns/rdtypes/ANY/EUI64.py +++ b/lib/dns/rdtypes/ANY/EUI64.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2015 Red Hat, Inc. # Author: Petr Spacek # @@ -19,11 +21,9 @@ import dns.rdtypes.euibase class EUI64(dns.rdtypes.euibase.EUIBase): - """EUI64 record + """EUI64 record""" - @ivar fingerprint: 64-bit Extended Unique Identifier (EUI-64) - @type fingerprint: string - @see: rfc7043.txt""" + # see: rfc7043.txt byte_len = 8 # 0123456789abcdef (in hex) text_len = byte_len * 3 - 1 # 01-23-45-67-89-ab-cd-ef diff --git a/lib/dns/rdtypes/ANY/GPOS.py b/lib/dns/rdtypes/ANY/GPOS.py index a359a771..03677fd2 100644 --- a/lib/dns/rdtypes/ANY/GPOS.py +++ b/lib/dns/rdtypes/ANY/GPOS.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its @@ -18,15 +20,19 @@ import struct import dns.exception import dns.rdata import dns.tokenizer -from dns._compat import long, text_type def _validate_float_string(what): + if len(what) == 0: + raise dns.exception.FormError if what[0] == b'-'[0] or what[0] == b'+'[0]: what = what[1:] if what.isdigit(): return - (left, right) = what.split(b'.') + try: + (left, right) = what.split(b'.') + except ValueError: + raise dns.exception.FormError if left == b'' and right == b'': raise dns.exception.FormError if not left == b'' and not left.decode().isdigit(): @@ -36,38 +42,29 @@ def _validate_float_string(what): def _sanitize(value): - if isinstance(value, text_type): + if isinstance(value, str): return value.encode() return value class GPOS(dns.rdata.Rdata): - """GPOS record + """GPOS record""" - @ivar latitude: latitude - @type latitude: string - @ivar longitude: longitude - @type longitude: string - @ivar altitude: altitude - @type altitude: string - @see: RFC 1712""" + # see: RFC 1712 __slots__ = ['latitude', 'longitude', 'altitude'] def __init__(self, rdclass, rdtype, latitude, longitude, altitude): - super(GPOS, self).__init__(rdclass, rdtype) + super().__init__(rdclass, rdtype) if isinstance(latitude, float) or \ - isinstance(latitude, int) or \ - isinstance(latitude, long): + isinstance(latitude, int): latitude = str(latitude) if isinstance(longitude, float) or \ - isinstance(longitude, int) or \ - isinstance(longitude, long): + isinstance(longitude, int): longitude = str(longitude) if isinstance(altitude, float) or \ - isinstance(altitude, int) or \ - isinstance(altitude, long): + isinstance(altitude, int): altitude = str(altitude) latitude = _sanitize(latitude) longitude = _sanitize(longitude) @@ -75,24 +72,31 @@ class GPOS(dns.rdata.Rdata): _validate_float_string(latitude) _validate_float_string(longitude) _validate_float_string(altitude) - self.latitude = latitude - self.longitude = longitude - self.altitude = altitude + object.__setattr__(self, 'latitude', latitude) + object.__setattr__(self, 'longitude', longitude) + object.__setattr__(self, 'altitude', altitude) + flat = self.float_latitude + if flat < -90.0 or flat > 90.0: + raise dns.exception.FormError('bad latitude') + flong = self.float_longitude + if flong < -180.0 or flong > 180.0: + raise dns.exception.FormError('bad longitude') def to_text(self, origin=None, relativize=True, **kw): - return '%s %s %s' % (self.latitude.decode(), - self.longitude.decode(), - self.altitude.decode()) + return '{} {} {}'.format(self.latitude.decode(), + self.longitude.decode(), + self.altitude.decode()) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True): + def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, + relativize_to=None): latitude = tok.get_string() longitude = tok.get_string() altitude = tok.get_string() tok.get_eol() return cls(rdclass, rdtype, latitude, longitude, altitude) - def to_wire(self, file, compress=None, origin=None): + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): l = len(self.latitude) assert l < 256 file.write(struct.pack('!B', l)) @@ -107,54 +111,23 @@ class GPOS(dns.rdata.Rdata): file.write(self.altitude) @classmethod - def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None): - l = wire[current] - current += 1 - rdlen -= 1 - if l > rdlen: - raise dns.exception.FormError - latitude = wire[current: current + l].unwrap() - current += l - rdlen -= l - l = wire[current] - current += 1 - rdlen -= 1 - if l > rdlen: - raise dns.exception.FormError - longitude = wire[current: current + l].unwrap() - current += l - rdlen -= l - l = wire[current] - current += 1 - rdlen -= 1 - if l != rdlen: - raise dns.exception.FormError - altitude = wire[current: current + l].unwrap() + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + latitude = parser.get_counted_bytes() + longitude = parser.get_counted_bytes() + altitude = parser.get_counted_bytes() return cls(rdclass, rdtype, latitude, longitude, altitude) - def _get_float_latitude(self): + @property + def float_latitude(self): + "latitude as a floating point value" return float(self.latitude) - def _set_float_latitude(self, value): - self.latitude = str(value) - - float_latitude = property(_get_float_latitude, _set_float_latitude, - doc="latitude as a floating point value") - - def _get_float_longitude(self): + @property + def float_longitude(self): + "longitude as a floating point value" return float(self.longitude) - def _set_float_longitude(self, value): - self.longitude = str(value) - - float_longitude = property(_get_float_longitude, _set_float_longitude, - doc="longitude as a floating point value") - - def _get_float_altitude(self): + @property + def float_altitude(self): + "altitude as a floating point value" return float(self.altitude) - - def _set_float_altitude(self, value): - self.altitude = str(value) - - float_altitude = property(_get_float_altitude, _set_float_altitude, - doc="altitude as a floating point value") diff --git a/lib/dns/rdtypes/ANY/HINFO.py b/lib/dns/rdtypes/ANY/HINFO.py index 52298bc4..587e0ad1 100644 --- a/lib/dns/rdtypes/ANY/HINFO.py +++ b/lib/dns/rdtypes/ANY/HINFO.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its @@ -18,44 +20,40 @@ import struct import dns.exception import dns.rdata import dns.tokenizer -from dns._compat import text_type class HINFO(dns.rdata.Rdata): - """HINFO record + """HINFO record""" - @ivar cpu: the CPU type - @type cpu: string - @ivar os: the OS type - @type os: string - @see: RFC 1035""" + # see: RFC 1035 __slots__ = ['cpu', 'os'] def __init__(self, rdclass, rdtype, cpu, os): - super(HINFO, self).__init__(rdclass, rdtype) - if isinstance(cpu, text_type): - self.cpu = cpu.encode() + super().__init__(rdclass, rdtype) + if isinstance(cpu, str): + object.__setattr__(self, 'cpu', cpu.encode()) else: - self.cpu = cpu - if isinstance(os, text_type): - self.os = os.encode() + object.__setattr__(self, 'cpu', cpu) + if isinstance(os, str): + object.__setattr__(self, 'os', os.encode()) else: - self.os = os + object.__setattr__(self, 'os', os) def to_text(self, origin=None, relativize=True, **kw): - return '"%s" "%s"' % (dns.rdata._escapify(self.cpu), - dns.rdata._escapify(self.os)) + return '"{}" "{}"'.format(dns.rdata._escapify(self.cpu), + dns.rdata._escapify(self.os)) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True): - cpu = tok.get_string() - os = tok.get_string() + def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, + relativize_to=None): + cpu = tok.get_string(max_length=255) + os = tok.get_string(max_length=255) tok.get_eol() return cls(rdclass, rdtype, cpu, os) - def to_wire(self, file, compress=None, origin=None): + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): l = len(self.cpu) assert l < 256 file.write(struct.pack('!B', l)) @@ -66,20 +64,7 @@ class HINFO(dns.rdata.Rdata): file.write(self.os) @classmethod - def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None): - l = wire[current] - current += 1 - rdlen -= 1 - if l > rdlen: - raise dns.exception.FormError - cpu = wire[current:current + l].unwrap() - current += l - rdlen -= l - l = wire[current] - current += 1 - rdlen -= 1 - if l != rdlen: - raise dns.exception.FormError - os = wire[current: current + l].unwrap() + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + cpu = parser.get_counted_bytes() + os = parser.get_counted_bytes() return cls(rdclass, rdtype, cpu, os) - diff --git a/lib/dns/rdtypes/ANY/HIP.py b/lib/dns/rdtypes/ANY/HIP.py index e0cd2755..1c774bbf 100644 --- a/lib/dns/rdtypes/ANY/HIP.py +++ b/lib/dns/rdtypes/ANY/HIP.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2010, 2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its @@ -24,40 +26,33 @@ import dns.rdatatype class HIP(dns.rdata.Rdata): - """HIP record + """HIP record""" - @ivar hit: the host identity tag - @type hit: string - @ivar algorithm: the public key cryptographic algorithm - @type algorithm: int - @ivar key: the public key - @type key: string - @ivar servers: the rendezvous servers - @type servers: list of dns.name.Name objects - @see: RFC 5205""" + # see: RFC 5205 __slots__ = ['hit', 'algorithm', 'key', 'servers'] def __init__(self, rdclass, rdtype, hit, algorithm, key, servers): - super(HIP, self).__init__(rdclass, rdtype) - self.hit = hit - self.algorithm = algorithm - self.key = key - self.servers = servers + super().__init__(rdclass, rdtype) + object.__setattr__(self, 'hit', hit) + object.__setattr__(self, 'algorithm', algorithm) + object.__setattr__(self, 'key', key) + object.__setattr__(self, 'servers', dns.rdata._constify(servers)) def to_text(self, origin=None, relativize=True, **kw): hit = binascii.hexlify(self.hit).decode() key = base64.b64encode(self.key).replace(b'\n', b'').decode() - text = u'' + text = '' servers = [] for server in self.servers: servers.append(server.choose_relativity(origin, relativize)) if len(servers) > 0: - text += (u' ' + u' '.join(map(lambda x: x.to_unicode(), servers))) - return u'%u %s %s%s' % (self.algorithm, hit, key, text) + text += (' ' + ' '.join((x.to_unicode() for x in servers))) + return '%u %s %s%s' % (self.algorithm, hit, key, text) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True): + def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, + relativize_to=None): algorithm = tok.get_uint8() hit = binascii.unhexlify(tok.get_string().encode()) if len(hit) > 255: @@ -68,46 +63,26 @@ class HIP(dns.rdata.Rdata): token = tok.get() if token.is_eol_or_eof(): break - server = dns.name.from_text(token.value, origin) - server.choose_relativity(origin, relativize) + server = tok.as_name(token, origin, relativize, relativize_to) servers.append(server) return cls(rdclass, rdtype, hit, algorithm, key, servers) - def to_wire(self, file, compress=None, origin=None): + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): lh = len(self.hit) lk = len(self.key) file.write(struct.pack("!BBH", lh, self.algorithm, lk)) file.write(self.hit) file.write(self.key) for server in self.servers: - server.to_wire(file, None, origin) + server.to_wire(file, None, origin, False) @classmethod - def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None): - (lh, algorithm, lk) = struct.unpack('!BBH', - wire[current: current + 4]) - current += 4 - rdlen -= 4 - hit = wire[current: current + lh].unwrap() - current += lh - rdlen -= lh - key = wire[current: current + lk].unwrap() - current += lk - rdlen -= lk + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + (lh, algorithm, lk) = parser.get_struct('!BBH') + hit = parser.get_bytes(lh) + key = parser.get_bytes(lk) servers = [] - while rdlen > 0: - (server, cused) = dns.name.from_wire(wire[: current + rdlen], - current) - current += cused - rdlen -= cused - if origin is not None: - server = server.relativize(origin) + while parser.remaining() > 0: + server = parser.get_name(origin) servers.append(server) return cls(rdclass, rdtype, hit, algorithm, key, servers) - - def choose_relativity(self, origin=None, relativize=True): - servers = [] - for server in self.servers: - server = server.choose_relativity(origin, relativize) - servers.append(server) - self.servers = servers diff --git a/lib/dns/rdtypes/ANY/ISDN.py b/lib/dns/rdtypes/ANY/ISDN.py index 01284a82..6834b3c7 100644 --- a/lib/dns/rdtypes/ANY/ISDN.py +++ b/lib/dns/rdtypes/ANY/ISDN.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its @@ -18,41 +20,37 @@ import struct import dns.exception import dns.rdata import dns.tokenizer -from dns._compat import text_type class ISDN(dns.rdata.Rdata): - """ISDN record + """ISDN record""" - @ivar address: the ISDN address - @type address: string - @ivar subaddress: the ISDN subaddress (or '' if not present) - @type subaddress: string - @see: RFC 1183""" + # see: RFC 1183 __slots__ = ['address', 'subaddress'] def __init__(self, rdclass, rdtype, address, subaddress): - super(ISDN, self).__init__(rdclass, rdtype) - if isinstance(address, text_type): - self.address = address.encode() + super().__init__(rdclass, rdtype) + if isinstance(address, str): + object.__setattr__(self, 'address', address.encode()) else: - self.address = address - if isinstance(address, text_type): - self.subaddress = subaddress.encode() + object.__setattr__(self, 'address', address) + if isinstance(address, str): + object.__setattr__(self, 'subaddress', subaddress.encode()) else: - self.subaddress = subaddress + object.__setattr__(self, 'subaddress', subaddress) def to_text(self, origin=None, relativize=True, **kw): if self.subaddress: - return '"%s" "%s"' % (dns.rdata._escapify(self.address), - dns.rdata._escapify(self.subaddress)) + return '"{}" "{}"'.format(dns.rdata._escapify(self.address), + dns.rdata._escapify(self.subaddress)) else: return '"%s"' % dns.rdata._escapify(self.address) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True): + def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, + relativize_to=None): address = tok.get_string() t = tok.get() if not t.is_eol_or_eof(): @@ -64,7 +62,7 @@ class ISDN(dns.rdata.Rdata): tok.get_eol() return cls(rdclass, rdtype, address, subaddress) - def to_wire(self, file, compress=None, origin=None): + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): l = len(self.address) assert l < 256 file.write(struct.pack('!B', l)) @@ -76,23 +74,10 @@ class ISDN(dns.rdata.Rdata): file.write(self.subaddress) @classmethod - def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None): - l = wire[current] - current += 1 - rdlen -= 1 - if l > rdlen: - raise dns.exception.FormError - address = wire[current: current + l].unwrap() - current += l - rdlen -= l - if rdlen > 0: - l = wire[current] - current += 1 - rdlen -= 1 - if l != rdlen: - raise dns.exception.FormError - subaddress = wire[current: current + l].unwrap() + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + address = parser.get_counted_bytes() + if parser.remaining() > 0: + subaddress = parser.get_counted_bytes() else: - subaddress = '' + subaddress = b'' return cls(rdclass, rdtype, address, subaddress) - diff --git a/lib/dns/rdtypes/ANY/LOC.py b/lib/dns/rdtypes/ANY/LOC.py index fbfcd70f..eb00a1cd 100644 --- a/lib/dns/rdtypes/ANY/LOC.py +++ b/lib/dns/rdtypes/ANY/LOC.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its @@ -17,23 +19,32 @@ import struct import dns.exception import dns.rdata -from dns._compat import long, xrange -_pows = tuple(long(10**i) for i in range(0, 11)) +_pows = tuple(10**i for i in range(0, 11)) # default values are in centimeters _default_size = 100.0 _default_hprec = 1000000.0 _default_vprec = 1000.0 +# for use by from_wire() +_MAX_LATITUDE = 0x80000000 + 90 * 3600000 +_MIN_LATITUDE = 0x80000000 - 90 * 3600000 +_MAX_LONGITUDE = 0x80000000 + 180 * 3600000 +_MIN_LONGITUDE = 0x80000000 - 180 * 3600000 + +# pylint complains about division since we don't have a from __future__ for +# it, but we don't care about python 2 warnings, so turn them off. +# +# pylint: disable=old-division def _exponent_of(what, desc): if what == 0: return 0 exp = None - for i in xrange(len(_pows)): - if what // _pows[i] == long(0): + for (i, pow) in enumerate(_pows): + if what // pow == 0: exp = i - 1 break if exp is None or exp < 0: @@ -47,7 +58,7 @@ def _float_to_tuple(what): what *= -1 else: sign = 1 - what = long(round(what * 3600000)) + what = round(what * 3600000) # pylint: disable=round-builtin degrees = int(what // 3600000) what -= degrees * 3600000 minutes = int(what // 60000) @@ -67,7 +78,7 @@ def _tuple_to_float(what): def _encode_size(what, desc): - what = long(what) + what = int(what) exponent = _exponent_of(what, desc) & 0xF base = what // pow(10, exponent) & 0xF return base * 16 + exponent @@ -76,32 +87,18 @@ def _encode_size(what, desc): def _decode_size(what, desc): exponent = what & 0x0F if exponent > 9: - raise dns.exception.SyntaxError("bad %s exponent" % desc) + raise dns.exception.FormError("bad %s exponent" % desc) base = (what & 0xF0) >> 4 if base > 9: - raise dns.exception.SyntaxError("bad %s base" % desc) - return long(base) * pow(10, exponent) + raise dns.exception.FormError("bad %s base" % desc) + return base * pow(10, exponent) class LOC(dns.rdata.Rdata): - """LOC record + """LOC record""" - @ivar latitude: latitude - @type latitude: (int, int, int, int, sign) tuple specifying the degrees, minutes, - seconds, milliseconds, and sign of the coordinate. - @ivar longitude: longitude - @type longitude: (int, int, int, int, sign) tuple specifying the degrees, - minutes, seconds, milliseconds, and sign of the coordinate. - @ivar altitude: altitude - @type altitude: float - @ivar size: size of the sphere - @type size: float - @ivar horizontal_precision: horizontal precision - @type horizontal_precision: float - @ivar vertical_precision: vertical precision - @type vertical_precision: float - @see: RFC 1876""" + # see: RFC 1876 __slots__ = ['latitude', 'longitude', 'altitude', 'size', 'horizontal_precision', 'vertical_precision'] @@ -117,35 +114,31 @@ class LOC(dns.rdata.Rdata): degrees. The other parameters are floats. Size, horizontal precision, and vertical precision are specified in centimeters.""" - super(LOC, self).__init__(rdclass, rdtype) - if isinstance(latitude, int) or isinstance(latitude, long): + super().__init__(rdclass, rdtype) + if isinstance(latitude, int): latitude = float(latitude) if isinstance(latitude, float): latitude = _float_to_tuple(latitude) - self.latitude = latitude - if isinstance(longitude, int) or isinstance(longitude, long): + object.__setattr__(self, 'latitude', dns.rdata._constify(latitude)) + if isinstance(longitude, int): longitude = float(longitude) if isinstance(longitude, float): longitude = _float_to_tuple(longitude) - self.longitude = longitude - self.altitude = float(altitude) - self.size = float(size) - self.horizontal_precision = float(hprec) - self.vertical_precision = float(vprec) + object.__setattr__(self, 'longitude', dns.rdata._constify(longitude)) + object.__setattr__(self, 'altitude', float(altitude)) + object.__setattr__(self, 'size', float(size)) + object.__setattr__(self, 'horizontal_precision', float(hprec)) + object.__setattr__(self, 'vertical_precision', float(vprec)) def to_text(self, origin=None, relativize=True, **kw): if self.latitude[4] > 0: lat_hemisphere = 'N' - lat_degrees = self.latitude[0] else: lat_hemisphere = 'S' - lat_degrees = -1 * self.latitude[0] if self.longitude[4] > 0: long_hemisphere = 'E' - long_degrees = self.longitude[0] else: long_hemisphere = 'W' - long_degrees = -1 * self.longitude[0] text = "%d %d %d.%03d %s %d %d %d.%03d %s %0.2fm" % ( self.latitude[0], self.latitude[1], self.latitude[2], self.latitude[3], lat_hemisphere, @@ -158,14 +151,15 @@ class LOC(dns.rdata.Rdata): if self.size != _default_size or \ self.horizontal_precision != _default_hprec or \ self.vertical_precision != _default_vprec: - text += " %0.2fm %0.2fm %0.2fm" % ( + text += " {:0.2f}m {:0.2f}m {:0.2f}m".format( self.size / 100.0, self.horizontal_precision / 100.0, self.vertical_precision / 100.0 ) return text @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True): + def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, + relativize_to=None): latitude = [0, 0, 0, 0, 1] longitude = [0, 0, 0, 0, 1] size = _default_size @@ -173,9 +167,13 @@ class LOC(dns.rdata.Rdata): vprec = _default_vprec latitude[0] = tok.get_int() + if latitude[0] > 90: + raise dns.exception.SyntaxError('latitude >= 90') t = tok.get_string() if t.isdigit(): latitude[1] = int(t) + if latitude[1] >= 60: + raise dns.exception.SyntaxError('latitude minutes >= 60') t = tok.get_string() if '.' in t: (seconds, milliseconds) = t.split('.') @@ -206,9 +204,13 @@ class LOC(dns.rdata.Rdata): raise dns.exception.SyntaxError('bad latitude hemisphere value') longitude[0] = tok.get_int() + if longitude[0] > 180: + raise dns.exception.SyntaxError('longitude > 180') t = tok.get_string() if t.isdigit(): longitude[1] = int(t) + if longitude[1] >= 60: + raise dns.exception.SyntaxError('longitude minutes >= 60') t = tok.get_string() if '.' in t: (seconds, milliseconds) = t.split('.') @@ -263,21 +265,26 @@ class LOC(dns.rdata.Rdata): vprec = float(value) * 100.0 # m -> cm tok.get_eol() + # Try encoding these now so we raise if they are bad + _encode_size(size, "size") + _encode_size(hprec, "horizontal precision") + _encode_size(vprec, "vertical precision") + return cls(rdclass, rdtype, latitude, longitude, altitude, size, hprec, vprec) - def to_wire(self, file, compress=None, origin=None): + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): milliseconds = (self.latitude[0] * 3600000 + self.latitude[1] * 60000 + self.latitude[2] * 1000 + self.latitude[3]) * self.latitude[4] - latitude = long(0x80000000) + milliseconds + latitude = 0x80000000 + milliseconds milliseconds = (self.longitude[0] * 3600000 + self.longitude[1] * 60000 + self.longitude[2] * 1000 + self.longitude[3]) * self.longitude[4] - longitude = long(0x80000000) + milliseconds - altitude = long(self.altitude) + long(10000000) + longitude = 0x80000000 + milliseconds + altitude = int(self.altitude) + 10000000 size = _encode_size(self.size, "size") hprec = _encode_size(self.horizontal_precision, "horizontal precision") vprec = _encode_size(self.vertical_precision, "vertical precision") @@ -286,21 +293,21 @@ class LOC(dns.rdata.Rdata): file.write(wire) @classmethod - def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None): + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): (version, size, hprec, vprec, latitude, longitude, altitude) = \ - struct.unpack("!BBBBIII", wire[current: current + rdlen]) - if latitude > long(0x80000000): - latitude = float(latitude - long(0x80000000)) / 3600000 - else: - latitude = -1 * float(long(0x80000000) - latitude) / 3600000 - if latitude < -90.0 or latitude > 90.0: + parser.get_struct("!BBBBIII") + if latitude < _MIN_LATITUDE or latitude > _MAX_LATITUDE: raise dns.exception.FormError("bad latitude") - if longitude > long(0x80000000): - longitude = float(longitude - long(0x80000000)) / 3600000 + if latitude > 0x80000000: + latitude = (latitude - 0x80000000) / 3600000 else: - longitude = -1 * float(long(0x80000000) - longitude) / 3600000 - if longitude < -180.0 or longitude > 180.0: + latitude = -1 * (0x80000000 - latitude) / 3600000 + if longitude < _MIN_LONGITUDE or longitude > _MAX_LONGITUDE: raise dns.exception.FormError("bad longitude") + if longitude > 0x80000000: + longitude = (longitude - 0x80000000) / 3600000 + else: + longitude = -1 * (0x80000000 - longitude) / 3600000 altitude = float(altitude) - 10000000.0 size = _decode_size(size, "size") hprec = _decode_size(hprec, "horizontal precision") @@ -308,20 +315,12 @@ class LOC(dns.rdata.Rdata): return cls(rdclass, rdtype, latitude, longitude, altitude, size, hprec, vprec) - def _get_float_latitude(self): + @property + def float_latitude(self): + "latitude as a floating point value" return _tuple_to_float(self.latitude) - def _set_float_latitude(self, value): - self.latitude = _float_to_tuple(value) - - float_latitude = property(_get_float_latitude, _set_float_latitude, - doc="latitude as a floating point value") - - def _get_float_longitude(self): + @property + def float_longitude(self): + "longitude as a floating point value" return _tuple_to_float(self.longitude) - - def _set_float_longitude(self, value): - self.longitude = _float_to_tuple(value) - - float_longitude = property(_get_float_longitude, _set_float_longitude, - doc="longitude as a floating point value") diff --git a/lib/dns/rdtypes/ANY/MX.py b/lib/dns/rdtypes/ANY/MX.py index 3a6735dc..0a06494f 100644 --- a/lib/dns/rdtypes/ANY/MX.py +++ b/lib/dns/rdtypes/ANY/MX.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its diff --git a/lib/dns/rdtypes/ANY/NINFO.py b/lib/dns/rdtypes/ANY/NINFO.py new file mode 100644 index 00000000..d4c8572c --- /dev/null +++ b/lib/dns/rdtypes/ANY/NINFO.py @@ -0,0 +1,25 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2006, 2007, 2009-2011 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +import dns.rdtypes.txtbase + + +class NINFO(dns.rdtypes.txtbase.TXTBase): + + """NINFO record""" + + # see: draft-reid-dnsext-zs-01 diff --git a/lib/dns/rdtypes/ANY/NS.py b/lib/dns/rdtypes/ANY/NS.py index ae56d819..f9fcf637 100644 --- a/lib/dns/rdtypes/ANY/NS.py +++ b/lib/dns/rdtypes/ANY/NS.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its diff --git a/lib/dns/rdtypes/ANY/NSEC.py b/lib/dns/rdtypes/ANY/NSEC.py index dfe96859..626d3399 100644 --- a/lib/dns/rdtypes/ANY/NSEC.py +++ b/lib/dns/rdtypes/ANY/NSEC.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2004-2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its @@ -13,114 +15,46 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. -import struct - import dns.exception import dns.rdata import dns.rdatatype import dns.name -from dns._compat import xrange +import dns.rdtypes.util + + +class Bitmap(dns.rdtypes.util.Bitmap): + type_name = 'NSEC' class NSEC(dns.rdata.Rdata): - """NSEC record - - @ivar next: the next name - @type next: dns.name.Name object - @ivar windows: the windowed bitmap list - @type windows: list of (window number, string) tuples""" + """NSEC record""" __slots__ = ['next', 'windows'] def __init__(self, rdclass, rdtype, next, windows): - super(NSEC, self).__init__(rdclass, rdtype) - self.next = next - self.windows = windows + super().__init__(rdclass, rdtype) + object.__setattr__(self, 'next', next) + object.__setattr__(self, 'windows', dns.rdata._constify(windows)) def to_text(self, origin=None, relativize=True, **kw): next = self.next.choose_relativity(origin, relativize) - text = '' - for (window, bitmap) in self.windows: - bits = [] - for i in xrange(0, len(bitmap)): - byte = bitmap[i] - for j in xrange(0, 8): - if byte & (0x80 >> j): - bits.append(dns.rdatatype.to_text(window * 256 + - i * 8 + j)) - text += (' ' + ' '.join(bits)) - return '%s%s' % (next, text) + text = Bitmap(self.windows).to_text() + return '{}{}'.format(next, text) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True): - next = tok.get_name() - next = next.choose_relativity(origin, relativize) - rdtypes = [] - while 1: - token = tok.get().unescape() - if token.is_eol_or_eof(): - break - nrdtype = dns.rdatatype.from_text(token.value) - if nrdtype == 0: - raise dns.exception.SyntaxError("NSEC with bit 0") - if nrdtype > 65535: - raise dns.exception.SyntaxError("NSEC with bit > 65535") - rdtypes.append(nrdtype) - rdtypes.sort() - window = 0 - octets = 0 - prior_rdtype = 0 - bitmap = bytearray(b'\0' * 32) - windows = [] - for nrdtype in rdtypes: - if nrdtype == prior_rdtype: - continue - prior_rdtype = nrdtype - new_window = nrdtype // 256 - if new_window != window: - windows.append((window, bitmap[0:octets])) - bitmap = bytearray(b'\0' * 32) - window = new_window - offset = nrdtype % 256 - byte = offset // 8 - bit = offset % 8 - octets = byte + 1 - bitmap[byte] = bitmap[byte] | (0x80 >> bit) - - windows.append((window, bitmap[0:octets])) + def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, + relativize_to=None): + next = tok.get_name(origin, relativize, relativize_to) + windows = Bitmap().from_text(tok) return cls(rdclass, rdtype, next, windows) - def to_wire(self, file, compress=None, origin=None): - self.next.to_wire(file, None, origin) - for (window, bitmap) in self.windows: - file.write(struct.pack('!BB', window, len(bitmap))) - file.write(bitmap) + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): + self.next.to_wire(file, None, origin, False) + Bitmap(self.windows).to_wire(file) @classmethod - def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None): - (next, cused) = dns.name.from_wire(wire[: current + rdlen], current) - current += cused - rdlen -= cused - windows = [] - while rdlen > 0: - if rdlen < 3: - raise dns.exception.FormError("NSEC too short") - window = wire[current] - octets = wire[current + 1] - if octets == 0 or octets > 32: - raise dns.exception.FormError("bad NSEC octets") - current += 2 - rdlen -= 2 - if rdlen < octets: - raise dns.exception.FormError("bad NSEC bitmap length") - bitmap = bytearray(wire[current: current + octets].unwrap()) - current += octets - rdlen -= octets - windows.append((window, bitmap)) - if origin is not None: - next = next.relativize(origin) + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + next = parser.get_name(origin) + windows = Bitmap().from_wire_parser(parser) return cls(rdclass, rdtype, next, windows) - - def choose_relativity(self, origin=None, relativize=True): - self.next = self.next.choose_relativity(origin, relativize) diff --git a/lib/dns/rdtypes/ANY/NSEC3.py b/lib/dns/rdtypes/ANY/NSEC3.py index 3982f4b4..91471f0f 100644 --- a/lib/dns/rdtypes/ANY/NSEC3.py +++ b/lib/dns/rdtypes/ANY/NSEC3.py @@ -1,4 +1,6 @@ -# Copyright (C) 2004-2007, 2009-2011 Nominum, Inc. +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2004-2017 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its # documentation for any purpose with or without fee is hereby granted, @@ -15,24 +17,18 @@ import base64 import binascii -import string import struct import dns.exception import dns.rdata import dns.rdatatype -from dns._compat import xrange, text_type +import dns.rdtypes.util -try: - b32_hex_to_normal = string.maketrans('0123456789ABCDEFGHIJKLMNOPQRSTUV', - 'ABCDEFGHIJKLMNOPQRSTUVWXYZ234567') - b32_normal_to_hex = string.maketrans('ABCDEFGHIJKLMNOPQRSTUVWXYZ234567', - '0123456789ABCDEFGHIJKLMNOPQRSTUV') -except AttributeError: - b32_hex_to_normal = bytes.maketrans(b'0123456789ABCDEFGHIJKLMNOPQRSTUV', - b'ABCDEFGHIJKLMNOPQRSTUVWXYZ234567') - b32_normal_to_hex = bytes.maketrans(b'ABCDEFGHIJKLMNOPQRSTUVWXYZ234567', - b'0123456789ABCDEFGHIJKLMNOPQRSTUV') + +b32_hex_to_normal = bytes.maketrans(b'0123456789ABCDEFGHIJKLMNOPQRSTUV', + b'ABCDEFGHIJKLMNOPQRSTUVWXYZ234567') +b32_normal_to_hex = bytes.maketrans(b'ABCDEFGHIJKLMNOPQRSTUVWXYZ234567', + b'0123456789ABCDEFGHIJKLMNOPQRSTUV') # hash algorithm constants SHA1 = 1 @@ -41,37 +37,28 @@ SHA1 = 1 OPTOUT = 1 +class Bitmap(dns.rdtypes.util.Bitmap): + type_name = 'NSEC3' + + class NSEC3(dns.rdata.Rdata): - """NSEC3 record - - @ivar algorithm: the hash algorithm number - @type algorithm: int - @ivar flags: the flags - @type flags: int - @ivar iterations: the number of iterations - @type iterations: int - @ivar salt: the salt - @type salt: string - @ivar next: the next name hash - @type next: string - @ivar windows: the windowed bitmap list - @type windows: list of (window number, string) tuples""" + """NSEC3 record""" __slots__ = ['algorithm', 'flags', 'iterations', 'salt', 'next', 'windows'] def __init__(self, rdclass, rdtype, algorithm, flags, iterations, salt, next, windows): - super(NSEC3, self).__init__(rdclass, rdtype) - self.algorithm = algorithm - self.flags = flags - self.iterations = iterations - if isinstance(salt, text_type): - self.salt = salt.encode() + super().__init__(rdclass, rdtype) + object.__setattr__(self, 'algorithm', algorithm) + object.__setattr__(self, 'flags', flags) + object.__setattr__(self, 'iterations', iterations) + if isinstance(salt, str): + object.__setattr__(self, 'salt', salt.encode()) else: - self.salt = salt - self.next = next - self.windows = windows + object.__setattr__(self, 'salt', salt) + object.__setattr__(self, 'next', next) + object.__setattr__(self, 'windows', dns.rdata._constify(windows)) def to_text(self, origin=None, relativize=True, **kw): next = base64.b32encode(self.next).translate( @@ -80,70 +67,29 @@ class NSEC3(dns.rdata.Rdata): salt = '-' else: salt = binascii.hexlify(self.salt).decode() - text = u'' - for (window, bitmap) in self.windows: - bits = [] - for i in xrange(0, len(bitmap)): - byte = bitmap[i] - for j in xrange(0, 8): - if byte & (0x80 >> j): - bits.append(dns.rdatatype.to_text(window * 256 + - i * 8 + j)) - text += (u' ' + u' '.join(bits)) - return u'%u %u %u %s %s%s' % (self.algorithm, self.flags, - self.iterations, salt, next, text) + text = Bitmap(self.windows).to_text() + return '%u %u %u %s %s%s' % (self.algorithm, self.flags, + self.iterations, salt, next, text) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True): + def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, + relativize_to=None): algorithm = tok.get_uint8() flags = tok.get_uint8() iterations = tok.get_uint16() salt = tok.get_string() - if salt == u'-': + if salt == '-': salt = b'' else: salt = binascii.unhexlify(salt.encode('ascii')) next = tok.get_string().encode( 'ascii').upper().translate(b32_hex_to_normal) next = base64.b32decode(next) - rdtypes = [] - while 1: - token = tok.get().unescape() - if token.is_eol_or_eof(): - break - nrdtype = dns.rdatatype.from_text(token.value) - if nrdtype == 0: - raise dns.exception.SyntaxError("NSEC3 with bit 0") - if nrdtype > 65535: - raise dns.exception.SyntaxError("NSEC3 with bit > 65535") - rdtypes.append(nrdtype) - rdtypes.sort() - window = 0 - octets = 0 - prior_rdtype = 0 - bitmap = bytearray(b'\0' * 32) - windows = [] - for nrdtype in rdtypes: - if nrdtype == prior_rdtype: - continue - prior_rdtype = nrdtype - new_window = nrdtype // 256 - if new_window != window: - if octets != 0: - windows.append((window, ''.join(bitmap[0:octets]))) - bitmap = bytearray(b'\0' * 32) - window = new_window - offset = nrdtype % 256 - byte = offset // 8 - bit = offset % 8 - octets = byte + 1 - bitmap[byte] = bitmap[byte] | (0x80 >> bit) - if octets != 0: - windows.append((window, bitmap[0:octets])) + windows = Bitmap().from_text(tok) return cls(rdclass, rdtype, algorithm, flags, iterations, salt, next, windows) - def to_wire(self, file, compress=None, origin=None): + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): l = len(self.salt) file.write(struct.pack("!BBHB", self.algorithm, self.flags, self.iterations, l)) @@ -151,42 +97,13 @@ class NSEC3(dns.rdata.Rdata): l = len(self.next) file.write(struct.pack("!B", l)) file.write(self.next) - for (window, bitmap) in self.windows: - file.write(struct.pack("!BB", window, len(bitmap))) - file.write(bitmap) + Bitmap(self.windows).to_wire(file) @classmethod - def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None): - (algorithm, flags, iterations, slen) = \ - struct.unpack('!BBHB', wire[current: current + 5]) - - current += 5 - rdlen -= 5 - salt = wire[current: current + slen].unwrap() - current += slen - rdlen -= slen - nlen = wire[current] - current += 1 - rdlen -= 1 - next = wire[current: current + nlen].unwrap() - current += nlen - rdlen -= nlen - windows = [] - while rdlen > 0: - if rdlen < 3: - raise dns.exception.FormError("NSEC3 too short") - window = wire[current] - octets = wire[current + 1] - if octets == 0 or octets > 32: - raise dns.exception.FormError("bad NSEC3 octets") - current += 2 - rdlen -= 2 - if rdlen < octets: - raise dns.exception.FormError("bad NSEC3 bitmap length") - bitmap = bytearray(wire[current: current + octets].unwrap()) - current += octets - rdlen -= octets - windows.append((window, bitmap)) + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + (algorithm, flags, iterations) = parser.get_struct('!BBH') + salt = parser.get_counted_bytes() + next = parser.get_counted_bytes() + windows = Bitmap().from_wire_parser(parser) return cls(rdclass, rdtype, algorithm, flags, iterations, salt, next, windows) - diff --git a/lib/dns/rdtypes/ANY/NSEC3PARAM.py b/lib/dns/rdtypes/ANY/NSEC3PARAM.py index b506282b..8ac76271 100644 --- a/lib/dns/rdtypes/ANY/NSEC3PARAM.py +++ b/lib/dns/rdtypes/ANY/NSEC3PARAM.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2004-2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its @@ -18,33 +20,23 @@ import binascii import dns.exception import dns.rdata -from dns._compat import text_type class NSEC3PARAM(dns.rdata.Rdata): - """NSEC3PARAM record - - @ivar algorithm: the hash algorithm number - @type algorithm: int - @ivar flags: the flags - @type flags: int - @ivar iterations: the number of iterations - @type iterations: int - @ivar salt: the salt - @type salt: string""" + """NSEC3PARAM record""" __slots__ = ['algorithm', 'flags', 'iterations', 'salt'] def __init__(self, rdclass, rdtype, algorithm, flags, iterations, salt): - super(NSEC3PARAM, self).__init__(rdclass, rdtype) - self.algorithm = algorithm - self.flags = flags - self.iterations = iterations - if isinstance(salt, text_type): - self.salt = salt.encode() + super().__init__(rdclass, rdtype) + object.__setattr__(self, 'algorithm', algorithm) + object.__setattr__(self, 'flags', flags) + object.__setattr__(self, 'iterations', iterations) + if isinstance(salt, str): + object.__setattr__(self, 'salt', salt.encode()) else: - self.salt = salt + object.__setattr__(self, 'salt', salt) def to_text(self, origin=None, relativize=True, **kw): if self.salt == b'': @@ -55,7 +47,8 @@ class NSEC3PARAM(dns.rdata.Rdata): salt) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True): + def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, + relativize_to=None): algorithm = tok.get_uint8() flags = tok.get_uint8() iterations = tok.get_uint16() @@ -67,23 +60,14 @@ class NSEC3PARAM(dns.rdata.Rdata): tok.get_eol() return cls(rdclass, rdtype, algorithm, flags, iterations, salt) - def to_wire(self, file, compress=None, origin=None): + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): l = len(self.salt) file.write(struct.pack("!BBHB", self.algorithm, self.flags, self.iterations, l)) file.write(self.salt) @classmethod - def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None): - (algorithm, flags, iterations, slen) = \ - struct.unpack('!BBHB', - wire[current: current + 5]) - current += 5 - rdlen -= 5 - salt = wire[current: current + slen].unwrap() - current += slen - rdlen -= slen - if rdlen != 0: - raise dns.exception.FormError + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + (algorithm, flags, iterations) = parser.get_struct('!BBH') + salt = parser.get_counted_bytes() return cls(rdclass, rdtype, algorithm, flags, iterations, salt) - diff --git a/lib/dns/rdtypes/ANY/OPENPGPKEY.py b/lib/dns/rdtypes/ANY/OPENPGPKEY.py new file mode 100644 index 00000000..f632132e --- /dev/null +++ b/lib/dns/rdtypes/ANY/OPENPGPKEY.py @@ -0,0 +1,50 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2016 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +import base64 + +import dns.exception +import dns.rdata +import dns.tokenizer + +class OPENPGPKEY(dns.rdata.Rdata): + + """OPENPGPKEY record""" + + # see: RFC 7929 + + def __init__(self, rdclass, rdtype, key): + super().__init__(rdclass, rdtype) + object.__setattr__(self, 'key', key) + + def to_text(self, origin=None, relativize=True, **kw): + return dns.rdata._base64ify(self.key) + + @classmethod + def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, + relativize_to=None): + b64 = tok.concatenate_remaining_identifiers().encode() + key = base64.b64decode(b64) + return cls(rdclass, rdtype, key) + + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): + file.write(self.key) + + @classmethod + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + key = parser.get_remaining() + return cls(rdclass, rdtype, key) diff --git a/lib/dns/rdtypes/ANY/OPT.py b/lib/dns/rdtypes/ANY/OPT.py new file mode 100644 index 00000000..c48aa12f --- /dev/null +++ b/lib/dns/rdtypes/ANY/OPT.py @@ -0,0 +1,67 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2001-2017 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +import struct + +import dns.edns +import dns.exception +import dns.rdata + + +class OPT(dns.rdata.Rdata): + + """OPT record""" + + __slots__ = ['options'] + + def __init__(self, rdclass, rdtype, options): + """Initialize an OPT rdata. + + *rdclass*, an ``int`` is the rdataclass of the Rdata, + which is also the payload size. + + *rdtype*, an ``int`` is the rdatatype of the Rdata. + + *options*, a tuple of ``bytes`` + """ + + super().__init__(rdclass, rdtype) + object.__setattr__(self, 'options', dns.rdata._constify(options)) + + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): + for opt in self.options: + owire = opt.to_wire() + file.write(struct.pack("!HH", opt.otype, len(owire))) + file.write(owire) + + def to_text(self, origin=None, relativize=True, **kw): + return ' '.join(opt.to_text() for opt in self.options) + + @classmethod + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + options = [] + while parser.remaining() > 0: + (otype, olen) = parser.get_struct('!HH') + with parser.restrict_to(olen): + opt = dns.edns.option_from_wire_parser(otype, parser) + options.append(opt) + return cls(rdclass, rdtype, options) + + @property + def payload(self): + "payload size" + return self.rdclass diff --git a/lib/dns/rdtypes/ANY/PTR.py b/lib/dns/rdtypes/ANY/PTR.py index 250187a6..20cd5076 100644 --- a/lib/dns/rdtypes/ANY/PTR.py +++ b/lib/dns/rdtypes/ANY/PTR.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its diff --git a/lib/dns/rdtypes/ANY/RP.py b/lib/dns/rdtypes/ANY/RP.py index e9071c76..7446de6d 100644 --- a/lib/dns/rdtypes/ANY/RP.py +++ b/lib/dns/rdtypes/ANY/RP.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its @@ -20,61 +22,36 @@ import dns.name class RP(dns.rdata.Rdata): - """RP record + """RP record""" - @ivar mbox: The responsible person's mailbox - @type mbox: dns.name.Name object - @ivar txt: The owner name of a node with TXT records, or the root name - if no TXT records are associated with this RP. - @type txt: dns.name.Name object - @see: RFC 1183""" + # see: RFC 1183 __slots__ = ['mbox', 'txt'] def __init__(self, rdclass, rdtype, mbox, txt): - super(RP, self).__init__(rdclass, rdtype) - self.mbox = mbox - self.txt = txt + super().__init__(rdclass, rdtype) + object.__setattr__(self, 'mbox', mbox) + object.__setattr__(self, 'txt', txt) def to_text(self, origin=None, relativize=True, **kw): mbox = self.mbox.choose_relativity(origin, relativize) txt = self.txt.choose_relativity(origin, relativize) - return "%s %s" % (str(mbox), str(txt)) + return "{} {}".format(str(mbox), str(txt)) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True): - mbox = tok.get_name() - txt = tok.get_name() - mbox = mbox.choose_relativity(origin, relativize) - txt = txt.choose_relativity(origin, relativize) + def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, + relativize_to=None): + mbox = tok.get_name(origin, relativize, relativize_to) + txt = tok.get_name(origin, relativize, relativize_to) tok.get_eol() return cls(rdclass, rdtype, mbox, txt) - def to_wire(self, file, compress=None, origin=None): - self.mbox.to_wire(file, None, origin) - self.txt.to_wire(file, None, origin) - - def to_digestable(self, origin=None): - return self.mbox.to_digestable(origin) + \ - self.txt.to_digestable(origin) + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): + self.mbox.to_wire(file, None, origin, canonicalize) + self.txt.to_wire(file, None, origin, canonicalize) @classmethod - def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None): - (mbox, cused) = dns.name.from_wire(wire[: current + rdlen], - current) - current += cused - rdlen -= cused - if rdlen <= 0: - raise dns.exception.FormError - (txt, cused) = dns.name.from_wire(wire[: current + rdlen], - current) - if cused != rdlen: - raise dns.exception.FormError - if origin is not None: - mbox = mbox.relativize(origin) - txt = txt.relativize(origin) + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + mbox = parser.get_name(origin) + txt = parser.get_name(origin) return cls(rdclass, rdtype, mbox, txt) - - def choose_relativity(self, origin=None, relativize=True): - self.mbox = self.mbox.choose_relativity(origin, relativize) - self.txt = self.txt.choose_relativity(origin, relativize) diff --git a/lib/dns/rdtypes/ANY/RRSIG.py b/lib/dns/rdtypes/ANY/RRSIG.py index 953dfb9a..2077d905 100644 --- a/lib/dns/rdtypes/ANY/RRSIG.py +++ b/lib/dns/rdtypes/ANY/RRSIG.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2004-2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its @@ -30,6 +32,8 @@ class BadSigTime(dns.exception.DNSException): def sigtime_to_posixtime(what): + if len(what) <= 10 and what.isdigit(): + return int(what) if len(what) != 14: raise BadSigTime year = int(what[0:4]) @@ -48,26 +52,7 @@ def posixtime_to_sigtime(what): class RRSIG(dns.rdata.Rdata): - """RRSIG record - - @ivar type_covered: the rdata type this signature covers - @type type_covered: int - @ivar algorithm: the algorithm used for the sig - @type algorithm: int - @ivar labels: number of labels - @type labels: int - @ivar original_ttl: the original TTL - @type original_ttl: long - @ivar expiration: signature expiration time - @type expiration: long - @ivar inception: signature inception time - @type inception: long - @ivar key_tag: the key tag - @type key_tag: int - @ivar signer: the signer - @type signer: dns.name.Name object - @ivar signature: the signature - @type signature: string""" + """RRSIG record""" __slots__ = ['type_covered', 'algorithm', 'labels', 'original_ttl', 'expiration', 'inception', 'key_tag', 'signer', @@ -76,16 +61,16 @@ class RRSIG(dns.rdata.Rdata): def __init__(self, rdclass, rdtype, type_covered, algorithm, labels, original_ttl, expiration, inception, key_tag, signer, signature): - super(RRSIG, self).__init__(rdclass, rdtype) - self.type_covered = type_covered - self.algorithm = algorithm - self.labels = labels - self.original_ttl = original_ttl - self.expiration = expiration - self.inception = inception - self.key_tag = key_tag - self.signer = signer - self.signature = signature + super().__init__(rdclass, rdtype) + object.__setattr__(self, 'type_covered', type_covered) + object.__setattr__(self, 'algorithm', algorithm) + object.__setattr__(self, 'labels', labels) + object.__setattr__(self, 'original_ttl', original_ttl) + object.__setattr__(self, 'expiration', expiration) + object.__setattr__(self, 'inception', inception) + object.__setattr__(self, 'key_tag', key_tag) + object.__setattr__(self, 'signer', signer) + object.__setattr__(self, 'signature', signature) def covers(self): return self.type_covered @@ -104,7 +89,8 @@ class RRSIG(dns.rdata.Rdata): ) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True): + def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, + relativize_to=None): type_covered = dns.rdatatype.from_text(tok.get_string()) algorithm = dns.dnssec.algorithm_from_text(tok.get_string()) labels = tok.get_int() @@ -112,45 +98,25 @@ class RRSIG(dns.rdata.Rdata): expiration = sigtime_to_posixtime(tok.get_string()) inception = sigtime_to_posixtime(tok.get_string()) key_tag = tok.get_int() - signer = tok.get_name() - signer = signer.choose_relativity(origin, relativize) - chunks = [] - while 1: - t = tok.get().unescape() - if t.is_eol_or_eof(): - break - if not t.is_identifier(): - raise dns.exception.SyntaxError - chunks.append(t.value.encode()) - b64 = b''.join(chunks) + signer = tok.get_name(origin, relativize, relativize_to) + b64 = tok.concatenate_remaining_identifiers().encode() signature = base64.b64decode(b64) return cls(rdclass, rdtype, type_covered, algorithm, labels, original_ttl, expiration, inception, key_tag, signer, signature) - def to_wire(self, file, compress=None, origin=None): + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): header = struct.pack('!HBBIIIH', self.type_covered, self.algorithm, self.labels, self.original_ttl, self.expiration, self.inception, self.key_tag) file.write(header) - self.signer.to_wire(file, None, origin) + self.signer.to_wire(file, None, origin, canonicalize) file.write(self.signature) @classmethod - def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None): - header = struct.unpack('!HBBIIIH', wire[current: current + 18]) - current += 18 - rdlen -= 18 - (signer, cused) = dns.name.from_wire(wire[: current + rdlen], current) - current += cused - rdlen -= cused - if origin is not None: - signer = signer.relativize(origin) - signature = wire[current: current + rdlen].unwrap() - return cls(rdclass, rdtype, header[0], header[1], header[2], - header[3], header[4], header[5], header[6], signer, - signature) - - def choose_relativity(self, origin=None, relativize=True): - self.signer = self.signer.choose_relativity(origin, relativize) + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + header = parser.get_struct('!HBBIIIH') + signer = parser.get_name(origin) + signature = parser.get_remaining() + return cls(rdclass, rdtype, *header, signer, signature) diff --git a/lib/dns/rdtypes/ANY/RT.py b/lib/dns/rdtypes/ANY/RT.py index 88b75486..d0feb79e 100644 --- a/lib/dns/rdtypes/ANY/RT.py +++ b/lib/dns/rdtypes/ANY/RT.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its diff --git a/lib/dns/rdtypes/ANY/SOA.py b/lib/dns/rdtypes/ANY/SOA.py index cc0098e8..e93274ed 100644 --- a/lib/dns/rdtypes/ANY/SOA.py +++ b/lib/dns/rdtypes/ANY/SOA.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its @@ -22,38 +24,23 @@ import dns.name class SOA(dns.rdata.Rdata): - """SOA record + """SOA record""" - @ivar mname: the SOA MNAME (master name) field - @type mname: dns.name.Name object - @ivar rname: the SOA RNAME (responsible name) field - @type rname: dns.name.Name object - @ivar serial: The zone's serial number - @type serial: int - @ivar refresh: The zone's refresh value (in seconds) - @type refresh: int - @ivar retry: The zone's retry value (in seconds) - @type retry: int - @ivar expire: The zone's expiration value (in seconds) - @type expire: int - @ivar minimum: The zone's negative caching time (in seconds, called - "minimum" for historical reasons) - @type minimum: int - @see: RFC 1035""" + # see: RFC 1035 __slots__ = ['mname', 'rname', 'serial', 'refresh', 'retry', 'expire', 'minimum'] def __init__(self, rdclass, rdtype, mname, rname, serial, refresh, retry, expire, minimum): - super(SOA, self).__init__(rdclass, rdtype) - self.mname = mname - self.rname = rname - self.serial = serial - self.refresh = refresh - self.retry = retry - self.expire = expire - self.minimum = minimum + super().__init__(rdclass, rdtype) + object.__setattr__(self, 'mname', mname) + object.__setattr__(self, 'rname', rname) + object.__setattr__(self, 'serial', serial) + object.__setattr__(self, 'refresh', refresh) + object.__setattr__(self, 'retry', retry) + object.__setattr__(self, 'expire', expire) + object.__setattr__(self, 'minimum', minimum) def to_text(self, origin=None, relativize=True, **kw): mname = self.mname.choose_relativity(origin, relativize) @@ -63,11 +50,10 @@ class SOA(dns.rdata.Rdata): self.expire, self.minimum) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True): - mname = tok.get_name() - rname = tok.get_name() - mname = mname.choose_relativity(origin, relativize) - rname = rname.choose_relativity(origin, relativize) + def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, + relativize_to=None): + mname = tok.get_name(origin, relativize, relativize_to) + rname = tok.get_name(origin, relativize, relativize_to) serial = tok.get_uint32() refresh = tok.get_ttl() retry = tok.get_ttl() @@ -77,38 +63,15 @@ class SOA(dns.rdata.Rdata): return cls(rdclass, rdtype, mname, rname, serial, refresh, retry, expire, minimum) - def to_wire(self, file, compress=None, origin=None): - self.mname.to_wire(file, compress, origin) - self.rname.to_wire(file, compress, origin) + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): + self.mname.to_wire(file, compress, origin, canonicalize) + self.rname.to_wire(file, compress, origin, canonicalize) five_ints = struct.pack('!IIIII', self.serial, self.refresh, self.retry, self.expire, self.minimum) file.write(five_ints) - def to_digestable(self, origin=None): - return self.mname.to_digestable(origin) + \ - self.rname.to_digestable(origin) + \ - struct.pack('!IIIII', self.serial, self.refresh, - self.retry, self.expire, self.minimum) - @classmethod - def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None): - (mname, cused) = dns.name.from_wire(wire[: current + rdlen], current) - current += cused - rdlen -= cused - (rname, cused) = dns.name.from_wire(wire[: current + rdlen], current) - current += cused - rdlen -= cused - if rdlen != 20: - raise dns.exception.FormError - five_ints = struct.unpack('!IIIII', - wire[current: current + rdlen]) - if origin is not None: - mname = mname.relativize(origin) - rname = rname.relativize(origin) - return cls(rdclass, rdtype, mname, rname, - five_ints[0], five_ints[1], five_ints[2], five_ints[3], - five_ints[4]) - - def choose_relativity(self, origin=None, relativize=True): - self.mname = self.mname.choose_relativity(origin, relativize) - self.rname = self.rname.choose_relativity(origin, relativize) + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + mname = parser.get_name(origin) + rname = parser.get_name(origin) + return cls(rdclass, rdtype, mname, rname, *parser.get_struct('!IIIII')) diff --git a/lib/dns/rdtypes/ANY/SPF.py b/lib/dns/rdtypes/ANY/SPF.py index f3e0904e..f1f6834e 100644 --- a/lib/dns/rdtypes/ANY/SPF.py +++ b/lib/dns/rdtypes/ANY/SPF.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2006, 2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its @@ -18,6 +20,6 @@ import dns.rdtypes.txtbase class SPF(dns.rdtypes.txtbase.TXTBase): - """SPF record + """SPF record""" - @see: RFC 4408""" + # see: RFC 4408 diff --git a/lib/dns/rdtypes/ANY/SSHFP.py b/lib/dns/rdtypes/ANY/SSHFP.py index b6ed396f..a3cc0039 100644 --- a/lib/dns/rdtypes/ANY/SSHFP.py +++ b/lib/dns/rdtypes/ANY/SSHFP.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2005-2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its @@ -22,24 +24,18 @@ import dns.rdatatype class SSHFP(dns.rdata.Rdata): - """SSHFP record + """SSHFP record""" - @ivar algorithm: the algorithm - @type algorithm: int - @ivar fp_type: the digest type - @type fp_type: int - @ivar fingerprint: the fingerprint - @type fingerprint: string - @see: draft-ietf-secsh-dns-05.txt""" + # See RFC 4255 __slots__ = ['algorithm', 'fp_type', 'fingerprint'] def __init__(self, rdclass, rdtype, algorithm, fp_type, fingerprint): - super(SSHFP, self).__init__(rdclass, rdtype) - self.algorithm = algorithm - self.fp_type = fp_type - self.fingerprint = fingerprint + super().__init__(rdclass, rdtype) + object.__setattr__(self, 'algorithm', algorithm) + object.__setattr__(self, 'fp_type', fp_type) + object.__setattr__(self, 'fingerprint', fingerprint) def to_text(self, origin=None, relativize=True, **kw): return '%d %d %s' % (self.algorithm, @@ -48,31 +44,21 @@ class SSHFP(dns.rdata.Rdata): chunksize=128)) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True): + def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, + relativize_to=None): algorithm = tok.get_uint8() fp_type = tok.get_uint8() - chunks = [] - while 1: - t = tok.get().unescape() - if t.is_eol_or_eof(): - break - if not t.is_identifier(): - raise dns.exception.SyntaxError - chunks.append(t.value.encode()) - fingerprint = b''.join(chunks) + fingerprint = tok.concatenate_remaining_identifiers().encode() fingerprint = binascii.unhexlify(fingerprint) return cls(rdclass, rdtype, algorithm, fp_type, fingerprint) - def to_wire(self, file, compress=None, origin=None): + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): header = struct.pack("!BB", self.algorithm, self.fp_type) file.write(header) file.write(self.fingerprint) @classmethod - def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None): - header = struct.unpack("!BB", wire[current: current + 2]) - current += 2 - rdlen -= 2 - fingerprint = wire[current: current + rdlen].unwrap() + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + header = parser.get_struct("BB") + fingerprint = parser.get_remaining() return cls(rdclass, rdtype, header[0], header[1], fingerprint) - diff --git a/lib/dns/rdtypes/ANY/TLSA.py b/lib/dns/rdtypes/ANY/TLSA.py index 23f4e94b..9c9c8662 100644 --- a/lib/dns/rdtypes/ANY/TLSA.py +++ b/lib/dns/rdtypes/ANY/TLSA.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2005-2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its @@ -22,27 +24,19 @@ import dns.rdatatype class TLSA(dns.rdata.Rdata): - """TLSA record + """TLSA record""" - @ivar usage: The certificate usage - @type usage: int - @ivar selector: The selector field - @type selector: int - @ivar mtype: The 'matching type' field - @type mtype: int - @ivar cert: The 'Certificate Association Data' field - @type cert: string - @see: RFC 6698""" + # see: RFC 6698 __slots__ = ['usage', 'selector', 'mtype', 'cert'] def __init__(self, rdclass, rdtype, usage, selector, mtype, cert): - super(TLSA, self).__init__(rdclass, rdtype) - self.usage = usage - self.selector = selector - self.mtype = mtype - self.cert = cert + super().__init__(rdclass, rdtype) + object.__setattr__(self, 'usage', usage) + object.__setattr__(self, 'selector', selector) + object.__setattr__(self, 'mtype', mtype) + object.__setattr__(self, 'cert', cert) def to_text(self, origin=None, relativize=True, **kw): return '%d %d %d %s' % (self.usage, @@ -52,32 +46,22 @@ class TLSA(dns.rdata.Rdata): chunksize=128)) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True): + def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, + relativize_to=None): usage = tok.get_uint8() selector = tok.get_uint8() mtype = tok.get_uint8() - cert_chunks = [] - while 1: - t = tok.get().unescape() - if t.is_eol_or_eof(): - break - if not t.is_identifier(): - raise dns.exception.SyntaxError - cert_chunks.append(t.value.encode()) - cert = b''.join(cert_chunks) + cert = tok.concatenate_remaining_identifiers().encode() cert = binascii.unhexlify(cert) return cls(rdclass, rdtype, usage, selector, mtype, cert) - def to_wire(self, file, compress=None, origin=None): + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): header = struct.pack("!BBB", self.usage, self.selector, self.mtype) file.write(header) file.write(self.cert) @classmethod - def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None): - header = struct.unpack("!BBB", wire[current: current + 3]) - current += 3 - rdlen -= 3 - cert = wire[current: current + rdlen].unwrap() + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + header = parser.get_struct("BBB") + cert = parser.get_remaining() return cls(rdclass, rdtype, header[0], header[1], header[2], cert) - diff --git a/lib/dns/rdtypes/ANY/TSIG.py b/lib/dns/rdtypes/ANY/TSIG.py new file mode 100644 index 00000000..18db4c9e --- /dev/null +++ b/lib/dns/rdtypes/ANY/TSIG.py @@ -0,0 +1,91 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2001-2017 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +import struct + +import dns.exception +import dns.rdata + + +class TSIG(dns.rdata.Rdata): + + """TSIG record""" + + __slots__ = ['algorithm', 'time_signed', 'fudge', 'mac', + 'original_id', 'error', 'other'] + + def __init__(self, rdclass, rdtype, algorithm, time_signed, fudge, mac, + original_id, error, other): + """Initialize a TSIG rdata. + + *rdclass*, an ``int`` is the rdataclass of the Rdata. + + *rdtype*, an ``int`` is the rdatatype of the Rdata. + + *algorithm*, a ``dns.name.Name``. + + *time_signed*, an ``int``. + + *fudge*, an ``int`. + + *mac*, a ``bytes`` + + *original_id*, an ``int`` + + *error*, an ``int`` + + *other*, a ``bytes`` + """ + + super().__init__(rdclass, rdtype) + object.__setattr__(self, 'algorithm', algorithm) + object.__setattr__(self, 'time_signed', time_signed) + object.__setattr__(self, 'fudge', fudge) + object.__setattr__(self, 'mac', dns.rdata._constify(mac)) + object.__setattr__(self, 'original_id', original_id) + object.__setattr__(self, 'error', error) + object.__setattr__(self, 'other', dns.rdata._constify(other)) + + def to_text(self, origin=None, relativize=True, **kw): + algorithm = self.algorithm.choose_relativity(origin, relativize) + return f"{algorithm} {self.fudge} {self.time_signed} " + \ + f"{len(self.mac)} {dns.rdata._base64ify(self.mac, 0)} " + \ + f"{self.original_id} {self.error} " + \ + f"{len(self.other)} {dns.rdata._base64ify(self.other, 0)}" + + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): + self.algorithm.to_wire(file, None, origin, False) + file.write(struct.pack('!HIHH', + (self.time_signed >> 32) & 0xffff, + self.time_signed & 0xffffffff, + self.fudge, + len(self.mac))) + file.write(self.mac) + file.write(struct.pack('!HHH', self.original_id, self.error, + len(self.other))) + file.write(self.other) + + @classmethod + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + algorithm = parser.get_name(origin) + (time_hi, time_lo, fudge) = parser.get_struct('!HIH') + time_signed = (time_hi << 32) + time_lo + mac = parser.get_counted_bytes(2) + (original_id, error) = parser.get_struct('!HH') + other = parser.get_counted_bytes(2) + return cls(rdclass, rdtype, algorithm, time_signed, fudge, mac, + original_id, error, other) diff --git a/lib/dns/rdtypes/ANY/TXT.py b/lib/dns/rdtypes/ANY/TXT.py index 6c7fa450..c5ae919c 100644 --- a/lib/dns/rdtypes/ANY/TXT.py +++ b/lib/dns/rdtypes/ANY/TXT.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its diff --git a/lib/dns/rdtypes/ANY/URI.py b/lib/dns/rdtypes/ANY/URI.py index 0c121d2c..84296f52 100644 --- a/lib/dns/rdtypes/ANY/URI.py +++ b/lib/dns/rdtypes/ANY/URI.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. # Copyright (C) 2015 Red Hat, Inc. # @@ -19,40 +21,34 @@ import struct import dns.exception import dns.rdata import dns.name -from dns._compat import text_type class URI(dns.rdata.Rdata): - """URI record + """URI record""" - @ivar priority: the priority - @type priority: int - @ivar weight: the weight - @type weight: int - @ivar target: the target host - @type target: dns.name.Name object - @see: draft-faltstrom-uri-13""" + # see RFC 7553 __slots__ = ['priority', 'weight', 'target'] def __init__(self, rdclass, rdtype, priority, weight, target): - super(URI, self).__init__(rdclass, rdtype) - self.priority = priority - self.weight = weight + super().__init__(rdclass, rdtype) + object.__setattr__(self, 'priority', priority) + object.__setattr__(self, 'weight', weight) if len(target) < 1: raise dns.exception.SyntaxError("URI target cannot be empty") - if isinstance(target, text_type): - self.target = target.encode() + if isinstance(target, str): + object.__setattr__(self, 'target', target.encode()) else: - self.target = target + object.__setattr__(self, 'target', target) def to_text(self, origin=None, relativize=True, **kw): return '%d %d "%s"' % (self.priority, self.weight, self.target.decode()) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True): + def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, + relativize_to=None): priority = tok.get_uint16() weight = tok.get_uint16() target = tok.get().unescape() @@ -61,21 +57,15 @@ class URI(dns.rdata.Rdata): tok.get_eol() return cls(rdclass, rdtype, priority, weight, target.value) - def to_wire(self, file, compress=None, origin=None): + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): two_ints = struct.pack("!HH", self.priority, self.weight) file.write(two_ints) file.write(self.target) @classmethod - def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None): - if rdlen < 5: - raise dns.exception.FormError('URI RR is shorter than 5 octets') - - (priority, weight) = struct.unpack('!HH', wire[current: current + 4]) - current += 4 - rdlen -= 4 - target = wire[current: current + rdlen] - current += rdlen - + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + (priority, weight) = parser.get_struct('!HH') + target = parser.get_remaining() + if len(target) == 0: + raise dns.exception.FormError('URI target may not be empty') return cls(rdclass, rdtype, priority, weight, target) - diff --git a/lib/dns/rdtypes/ANY/X25.py b/lib/dns/rdtypes/ANY/X25.py index f5cca114..214f1dca 100644 --- a/lib/dns/rdtypes/ANY/X25.py +++ b/lib/dns/rdtypes/ANY/X25.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its @@ -18,48 +20,40 @@ import struct import dns.exception import dns.rdata import dns.tokenizer -from dns._compat import text_type class X25(dns.rdata.Rdata): - """X25 record + """X25 record""" - @ivar address: the PSDN address - @type address: string - @see: RFC 1183""" + # see RFC 1183 __slots__ = ['address'] def __init__(self, rdclass, rdtype, address): - super(X25, self).__init__(rdclass, rdtype) - if isinstance(address, text_type): - self.address = address.encode() + super().__init__(rdclass, rdtype) + if isinstance(address, str): + object.__setattr__(self, 'address', address.encode()) else: - self.address = address + object.__setattr__(self, 'address', address) def to_text(self, origin=None, relativize=True, **kw): return '"%s"' % dns.rdata._escapify(self.address) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True): + def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, + relativize_to=None): address = tok.get_string() tok.get_eol() return cls(rdclass, rdtype, address) - def to_wire(self, file, compress=None, origin=None): + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): l = len(self.address) assert l < 256 file.write(struct.pack('!B', l)) file.write(self.address) @classmethod - def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None): - l = wire[current] - current += 1 - rdlen -= 1 - if l != rdlen: - raise dns.exception.FormError - address = wire[current: current + l].unwrap() + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + address = parser.get_counted_bytes() return cls(rdclass, rdtype, address) - diff --git a/lib/dns/rdtypes/ANY/__init__.py b/lib/dns/rdtypes/ANY/__init__.py index ea9c3e2e..ea704c86 100644 --- a/lib/dns/rdtypes/ANY/__init__.py +++ b/lib/dns/rdtypes/ANY/__init__.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its @@ -17,10 +19,13 @@ __all__ = [ 'AFSDB', + 'AVC', + 'CAA', 'CDNSKEY', 'CDS', 'CERT', 'CNAME', + 'CSYNC', 'DLV', 'DNAME', 'DNSKEY', @@ -37,7 +42,8 @@ __all__ = [ 'NSEC', 'NSEC3', 'NSEC3PARAM', - 'TLSA', + 'OPENPGPKEY', + 'OPT', 'PTR', 'RP', 'RRSIG', @@ -45,6 +51,9 @@ __all__ = [ 'SOA', 'SPF', 'SSHFP', + 'TLSA', + 'TSIG', 'TXT', + 'URI', 'X25', ] diff --git a/lib/dns/rdtypes/CH/A.py b/lib/dns/rdtypes/CH/A.py new file mode 100644 index 00000000..b738ac6c --- /dev/null +++ b/lib/dns/rdtypes/CH/A.py @@ -0,0 +1,56 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +import dns.rdtypes.mxbase +import struct + +class A(dns.rdata.Rdata): + + """A record for Chaosnet""" + + # domain: the domain of the address + # address: the 16-bit address + + __slots__ = ['domain', 'address'] + + def __init__(self, rdclass, rdtype, domain, address): + super().__init__(rdclass, rdtype) + object.__setattr__(self, 'domain', domain) + object.__setattr__(self, 'address', address) + + def to_text(self, origin=None, relativize=True, **kw): + domain = self.domain.choose_relativity(origin, relativize) + return '%s %o' % (domain, self.address) + + @classmethod + def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, + relativize_to=None): + domain = tok.get_name(origin, relativize, relativize_to) + address = tok.get_uint16(base=8) + tok.get_eol() + return cls(rdclass, rdtype, domain, address) + + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): + self.domain.to_wire(file, compress, origin, canonicalize) + pref = struct.pack("!H", self.address) + file.write(pref) + + @classmethod + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + domain = parser.get_name(origin) + address = parser.get_uint16() + return cls(rdclass, rdtype, domain, address) diff --git a/lib/dns/rdtypes/CH/__init__.py b/lib/dns/rdtypes/CH/__init__.py new file mode 100644 index 00000000..7184a733 --- /dev/null +++ b/lib/dns/rdtypes/CH/__init__.py @@ -0,0 +1,22 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +"""Class CH rdata type classes.""" + +__all__ = [ + 'A', +] diff --git a/lib/dns/rdtypes/IN/A.py b/lib/dns/rdtypes/IN/A.py index 42faf9ba..8b71e329 100644 --- a/lib/dns/rdtypes/IN/A.py +++ b/lib/dns/rdtypes/IN/A.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its @@ -21,33 +23,30 @@ import dns.tokenizer class A(dns.rdata.Rdata): - """A record. - - @ivar address: an IPv4 address - @type address: string (in the standard "dotted quad" format)""" + """A record.""" __slots__ = ['address'] def __init__(self, rdclass, rdtype, address): - super(A, self).__init__(rdclass, rdtype) + super().__init__(rdclass, rdtype) # check that it's OK dns.ipv4.inet_aton(address) - self.address = address + object.__setattr__(self, 'address', address) def to_text(self, origin=None, relativize=True, **kw): return self.address @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True): + def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, + relativize_to=None): address = tok.get_identifier() tok.get_eol() return cls(rdclass, rdtype, address) - def to_wire(self, file, compress=None, origin=None): + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): file.write(dns.ipv4.inet_aton(self.address)) @classmethod - def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None): - address = dns.ipv4.inet_ntoa(wire[current: current + rdlen]).decode() + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + address = dns.ipv4.inet_ntoa(parser.get_remaining()) return cls(rdclass, rdtype, address) - diff --git a/lib/dns/rdtypes/IN/AAAA.py b/lib/dns/rdtypes/IN/AAAA.py index d2c65c63..08f9d679 100644 --- a/lib/dns/rdtypes/IN/AAAA.py +++ b/lib/dns/rdtypes/IN/AAAA.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its @@ -14,41 +16,37 @@ # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. import dns.exception -import dns.inet +import dns.ipv6 import dns.rdata import dns.tokenizer class AAAA(dns.rdata.Rdata): - """AAAA record. - - @ivar address: an IPv6 address - @type address: string (in the standard IPv6 format)""" + """AAAA record.""" __slots__ = ['address'] def __init__(self, rdclass, rdtype, address): - super(AAAA, self).__init__(rdclass, rdtype) + super().__init__(rdclass, rdtype) # check that it's OK - dns.inet.inet_pton(dns.inet.AF_INET6, address) - self.address = address + dns.ipv6.inet_aton(address) + object.__setattr__(self, 'address', address) def to_text(self, origin=None, relativize=True, **kw): return self.address @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True): + def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, + relativize_to=None): address = tok.get_identifier() tok.get_eol() return cls(rdclass, rdtype, address) - def to_wire(self, file, compress=None, origin=None): - file.write(dns.inet.inet_pton(dns.inet.AF_INET6, self.address)) + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): + file.write(dns.ipv6.inet_aton(self.address)) @classmethod - def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None): - address = dns.inet.inet_ntop(dns.inet.AF_INET6, - wire[current: current + rdlen]) + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + address = dns.ipv6.inet_ntoa(parser.get_remaining()) return cls(rdclass, rdtype, address) - diff --git a/lib/dns/rdtypes/IN/APL.py b/lib/dns/rdtypes/IN/APL.py index 82026adf..ab7fe4bc 100644 --- a/lib/dns/rdtypes/IN/APL.py +++ b/lib/dns/rdtypes/IN/APL.py @@ -1,4 +1,6 @@ -# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2017 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its # documentation for any purpose with or without fee is hereby granted, @@ -13,29 +15,19 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. -import struct import binascii +import codecs +import struct import dns.exception -import dns.inet +import dns.ipv4 +import dns.ipv6 import dns.rdata import dns.tokenizer -from dns._compat import xrange +class APLItem: -class APLItem(object): - - """An APL list item. - - @ivar family: the address family (IANA address family registry) - @type family: int - @ivar negation: is this item negated? - @type negation: bool - @ivar address: the address - @type address: string - @ivar prefix: the prefix length - @type prefix: int - """ + """An APL list item.""" __slots__ = ['family', 'negation', 'address', 'prefix'] @@ -53,17 +45,17 @@ class APLItem(object): def to_wire(self, file): if self.family == 1: - address = dns.inet.inet_pton(dns.inet.AF_INET, self.address) + address = dns.ipv4.inet_aton(self.address) elif self.family == 2: - address = dns.inet.inet_pton(dns.inet.AF_INET6, self.address) + address = dns.ipv6.inet_aton(self.address) else: address = binascii.unhexlify(self.address) # # Truncate least significant zero bytes. # last = 0 - for i in xrange(len(address) - 1, -1, -1): - if address[i] != chr(0): + for i in range(len(address) - 1, -1, -1): + if address[i] != 0: last = i + 1 break address = address[0: last] @@ -78,25 +70,24 @@ class APLItem(object): class APL(dns.rdata.Rdata): - """APL record. + """APL record.""" - @ivar items: a list of APL items - @type items: list of APL_Item - @see: RFC 3123""" + # see: RFC 3123 __slots__ = ['items'] def __init__(self, rdclass, rdtype, items): - super(APL, self).__init__(rdclass, rdtype) - self.items = items + super().__init__(rdclass, rdtype) + object.__setattr__(self, 'items', dns.rdata._constify(items)) def to_text(self, origin=None, relativize=True, **kw): - return ' '.join(map(lambda x: str(x), self.items)) + return ' '.join(map(str, self.items)) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True): + def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, + relativize_to=None): items = [] - while 1: + while True: token = tok.get().unescape() if token.is_eol_or_eof(): break @@ -115,48 +106,38 @@ class APL(dns.rdata.Rdata): return cls(rdclass, rdtype, items) - def to_wire(self, file, compress=None, origin=None): + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): for item in self.items: item.to_wire(file) @classmethod - def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None): + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + items = [] - while 1: - if rdlen == 0: - break - if rdlen < 4: - raise dns.exception.FormError - header = struct.unpack('!HBB', wire[current: current + 4]) + while parser.remaining() > 0: + header = parser.get_struct('!HBB') afdlen = header[2] if afdlen > 127: negation = True afdlen -= 128 else: negation = False - current += 4 - rdlen -= 4 - if rdlen < afdlen: - raise dns.exception.FormError - address = wire[current: current + afdlen].unwrap() + address = parser.get_bytes(afdlen) l = len(address) if header[0] == 1: if l < 4: - address += '\x00' * (4 - l) - address = dns.inet.inet_ntop(dns.inet.AF_INET, address) + address += b'\x00' * (4 - l) + address = dns.ipv4.inet_ntoa(address) elif header[0] == 2: if l < 16: - address += '\x00' * (16 - l) - address = dns.inet.inet_ntop(dns.inet.AF_INET6, address) + address += b'\x00' * (16 - l) + address = dns.ipv6.inet_ntoa(address) else: # # This isn't really right according to the RFC, but it # seems better than throwing an exception # - address = address.encode('hex_codec') - current += afdlen - rdlen -= afdlen + address = codecs.encode(address, 'hex_codec') item = APLItem(header[0], negation, address, header[1]) items.append(item) return cls(rdclass, rdtype, items) - diff --git a/lib/dns/rdtypes/IN/DHCID.py b/lib/dns/rdtypes/IN/DHCID.py index 06a850ad..6f66eb89 100644 --- a/lib/dns/rdtypes/IN/DHCID.py +++ b/lib/dns/rdtypes/IN/DHCID.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2006, 2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its @@ -20,41 +22,30 @@ import dns.exception class DHCID(dns.rdata.Rdata): - """DHCID record + """DHCID record""" - @ivar data: the data (the content of the RR is opaque as far as the - DNS is concerned) - @type data: string - @see: RFC 4701""" + # see: RFC 4701 __slots__ = ['data'] def __init__(self, rdclass, rdtype, data): - super(DHCID, self).__init__(rdclass, rdtype) - self.data = data + super().__init__(rdclass, rdtype) + object.__setattr__(self, 'data', data) def to_text(self, origin=None, relativize=True, **kw): return dns.rdata._base64ify(self.data) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True): - chunks = [] - while 1: - t = tok.get().unescape() - if t.is_eol_or_eof(): - break - if not t.is_identifier(): - raise dns.exception.SyntaxError - chunks.append(t.value.encode()) - b64 = b''.join(chunks) + def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, + relativize_to=None): + b64 = tok.concatenate_remaining_identifiers().encode() data = base64.b64decode(b64) return cls(rdclass, rdtype, data) - def to_wire(self, file, compress=None, origin=None): + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): file.write(self.data) @classmethod - def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None): - data = wire[current: current + rdlen].unwrap() + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + data = parser.get_remaining() return cls(rdclass, rdtype, data) - diff --git a/lib/dns/rdtypes/IN/IPSECKEY.py b/lib/dns/rdtypes/IN/IPSECKEY.py index 4f07bd09..182ad2cb 100644 --- a/lib/dns/rdtypes/IN/IPSECKEY.py +++ b/lib/dns/rdtypes/IN/IPSECKEY.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2006, 2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its @@ -17,133 +19,63 @@ import struct import base64 import dns.exception -import dns.inet -import dns.name +import dns.rdtypes.util +class Gateway(dns.rdtypes.util.Gateway): + name = 'IPSECKEY gateway' + class IPSECKEY(dns.rdata.Rdata): - """IPSECKEY record + """IPSECKEY record""" - @ivar precedence: the precedence for this key data - @type precedence: int - @ivar gateway_type: the gateway type - @type gateway_type: int - @ivar algorithm: the algorithm to use - @type algorithm: int - @ivar gateway: the public key - @type gateway: None, IPv4 address, IPV6 address, or domain name - @ivar key: the public key - @type key: string - @see: RFC 4025""" + # see: RFC 4025 __slots__ = ['precedence', 'gateway_type', 'algorithm', 'gateway', 'key'] def __init__(self, rdclass, rdtype, precedence, gateway_type, algorithm, gateway, key): - super(IPSECKEY, self).__init__(rdclass, rdtype) - if gateway_type == 0: - if gateway != '.' and gateway is not None: - raise SyntaxError('invalid gateway for gateway type 0') - gateway = None - elif gateway_type == 1: - # check that it's OK - dns.inet.inet_pton(dns.inet.AF_INET, gateway) - elif gateway_type == 2: - # check that it's OK - dns.inet.inet_pton(dns.inet.AF_INET6, gateway) - elif gateway_type == 3: - pass - else: - raise SyntaxError( - 'invalid IPSECKEY gateway type: %d' % gateway_type) - self.precedence = precedence - self.gateway_type = gateway_type - self.algorithm = algorithm - self.gateway = gateway - self.key = key + super().__init__(rdclass, rdtype) + Gateway(gateway_type, gateway).check() + object.__setattr__(self, 'precedence', precedence) + object.__setattr__(self, 'gateway_type', gateway_type) + object.__setattr__(self, 'algorithm', algorithm) + object.__setattr__(self, 'gateway', gateway) + object.__setattr__(self, 'key', key) def to_text(self, origin=None, relativize=True, **kw): - if self.gateway_type == 0: - gateway = '.' - elif self.gateway_type == 1: - gateway = self.gateway - elif self.gateway_type == 2: - gateway = self.gateway - elif self.gateway_type == 3: - gateway = str(self.gateway.choose_relativity(origin, relativize)) - else: - raise ValueError('invalid gateway type') + gateway = Gateway(self.gateway_type, self.gateway).to_text(origin, + relativize) return '%d %d %d %s %s' % (self.precedence, self.gateway_type, self.algorithm, gateway, dns.rdata._base64ify(self.key)) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True): + def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, + relativize_to=None): precedence = tok.get_uint8() gateway_type = tok.get_uint8() algorithm = tok.get_uint8() - if gateway_type == 3: - gateway = tok.get_name().choose_relativity(origin, relativize) - else: - gateway = tok.get_string() - chunks = [] - while 1: - t = tok.get().unescape() - if t.is_eol_or_eof(): - break - if not t.is_identifier(): - raise dns.exception.SyntaxError - chunks.append(t.value.encode()) - b64 = b''.join(chunks) + gateway = Gateway(gateway_type).from_text(tok, origin, relativize, + relativize_to) + b64 = tok.concatenate_remaining_identifiers().encode() key = base64.b64decode(b64) return cls(rdclass, rdtype, precedence, gateway_type, algorithm, gateway, key) - def to_wire(self, file, compress=None, origin=None): + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): header = struct.pack("!BBB", self.precedence, self.gateway_type, self.algorithm) file.write(header) - if self.gateway_type == 0: - pass - elif self.gateway_type == 1: - file.write(dns.inet.inet_pton(dns.inet.AF_INET, self.gateway)) - elif self.gateway_type == 2: - file.write(dns.inet.inet_pton(dns.inet.AF_INET6, self.gateway)) - elif self.gateway_type == 3: - self.gateway.to_wire(file, None, origin) - else: - raise ValueError('invalid gateway type') + Gateway(self.gateway_type, self.gateway).to_wire(file, compress, + origin, canonicalize) file.write(self.key) @classmethod - def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None): - if rdlen < 3: - raise dns.exception.FormError - header = struct.unpack('!BBB', wire[current: current + 3]) + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + header = parser.get_struct('!BBB') gateway_type = header[1] - current += 3 - rdlen -= 3 - if gateway_type == 0: - gateway = None - elif gateway_type == 1: - gateway = dns.inet.inet_ntop(dns.inet.AF_INET, - wire[current: current + 4]) - current += 4 - rdlen -= 4 - elif gateway_type == 2: - gateway = dns.inet.inet_ntop(dns.inet.AF_INET6, - wire[current: current + 16]) - current += 16 - rdlen -= 16 - elif gateway_type == 3: - (gateway, cused) = dns.name.from_wire(wire[: current + rdlen], - current) - current += cused - rdlen -= cused - else: - raise dns.exception.FormError('invalid IPSECKEY gateway type') - key = wire[current: current + rdlen].unwrap() + gateway = Gateway(gateway_type).from_wire_parser(parser, origin) + key = parser.get_remaining() return cls(rdclass, rdtype, header[0], gateway_type, header[2], gateway, key) - diff --git a/lib/dns/rdtypes/IN/KX.py b/lib/dns/rdtypes/IN/KX.py index adbfe34b..ebf8fd77 100644 --- a/lib/dns/rdtypes/IN/KX.py +++ b/lib/dns/rdtypes/IN/KX.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its @@ -16,6 +18,6 @@ import dns.rdtypes.mxbase -class KX(dns.rdtypes.mxbase.UncompressedMX): +class KX(dns.rdtypes.mxbase.UncompressedDowncasingMX): """KX record""" diff --git a/lib/dns/rdtypes/IN/NAPTR.py b/lib/dns/rdtypes/IN/NAPTR.py index 5ae2feb1..48d43562 100644 --- a/lib/dns/rdtypes/IN/NAPTR.py +++ b/lib/dns/rdtypes/IN/NAPTR.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its @@ -18,7 +20,6 @@ import struct import dns.exception import dns.name import dns.rdata -from dns._compat import xrange, text_type def _write_string(file, s): @@ -29,41 +30,29 @@ def _write_string(file, s): def _sanitize(value): - if isinstance(value, text_type): + if isinstance(value, str): return value.encode() return value class NAPTR(dns.rdata.Rdata): - """NAPTR record + """NAPTR record""" - @ivar order: order - @type order: int - @ivar preference: preference - @type preference: int - @ivar flags: flags - @type flags: string - @ivar service: service - @type service: string - @ivar regexp: regular expression - @type regexp: string - @ivar replacement: replacement name - @type replacement: dns.name.Name object - @see: RFC 3403""" + # see: RFC 3403 __slots__ = ['order', 'preference', 'flags', 'service', 'regexp', 'replacement'] def __init__(self, rdclass, rdtype, order, preference, flags, service, regexp, replacement): - super(NAPTR, self).__init__(rdclass, rdtype) - self.flags = _sanitize(flags) - self.service = _sanitize(service) - self.regexp = _sanitize(regexp) - self.order = order - self.preference = preference - self.replacement = replacement + super().__init__(rdclass, rdtype) + object.__setattr__(self, 'flags', _sanitize(flags)) + object.__setattr__(self, 'service', _sanitize(service)) + object.__setattr__(self, 'regexp', _sanitize(regexp)) + object.__setattr__(self, 'order', order) + object.__setattr__(self, 'preference', preference) + object.__setattr__(self, 'replacement', replacement) def to_text(self, origin=None, relativize=True, **kw): replacement = self.replacement.choose_relativity(origin, relativize) @@ -75,51 +64,33 @@ class NAPTR(dns.rdata.Rdata): replacement) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True): + def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, + relativize_to=None): order = tok.get_uint16() preference = tok.get_uint16() flags = tok.get_string() service = tok.get_string() regexp = tok.get_string() - replacement = tok.get_name() - replacement = replacement.choose_relativity(origin, relativize) + replacement = tok.get_name(origin, relativize, relativize_to) tok.get_eol() return cls(rdclass, rdtype, order, preference, flags, service, regexp, replacement) - def to_wire(self, file, compress=None, origin=None): + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): two_ints = struct.pack("!HH", self.order, self.preference) file.write(two_ints) _write_string(file, self.flags) _write_string(file, self.service) _write_string(file, self.regexp) - self.replacement.to_wire(file, compress, origin) + self.replacement.to_wire(file, compress, origin, canonicalize) @classmethod - def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None): - (order, preference) = struct.unpack('!HH', wire[current: current + 4]) - current += 4 - rdlen -= 4 + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + (order, preference) = parser.get_struct('!HH') strings = [] - for i in xrange(3): - l = wire[current] - current += 1 - rdlen -= 1 - if l > rdlen or rdlen < 0: - raise dns.exception.FormError - s = wire[current: current + l].unwrap() - current += l - rdlen -= l + for i in range(3): + s = parser.get_counted_bytes() strings.append(s) - (replacement, cused) = dns.name.from_wire(wire[: current + rdlen], - current) - if cused != rdlen: - raise dns.exception.FormError - if origin is not None: - replacement = replacement.relativize(origin) + replacement = parser.get_name(origin) return cls(rdclass, rdtype, order, preference, strings[0], strings[1], strings[2], replacement) - - def choose_relativity(self, origin=None, relativize=True): - self.replacement = self.replacement.choose_relativity(origin, - relativize) diff --git a/lib/dns/rdtypes/IN/NSAP.py b/lib/dns/rdtypes/IN/NSAP.py index 6dbe5af0..227465fa 100644 --- a/lib/dns/rdtypes/IN/NSAP.py +++ b/lib/dns/rdtypes/IN/NSAP.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its @@ -22,23 +24,22 @@ import dns.tokenizer class NSAP(dns.rdata.Rdata): - """NSAP record. + """NSAP record.""" - @ivar address: a NASP - @type address: string - @see: RFC 1706""" + # see: RFC 1706 __slots__ = ['address'] def __init__(self, rdclass, rdtype, address): - super(NSAP, self).__init__(rdclass, rdtype) - self.address = address + super().__init__(rdclass, rdtype) + object.__setattr__(self, 'address', address) def to_text(self, origin=None, relativize=True, **kw): return "0x%s" % binascii.hexlify(self.address).decode() @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True): + def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, + relativize_to=None): address = tok.get_string() tok.get_eol() if address[0:2] != '0x': @@ -49,11 +50,10 @@ class NSAP(dns.rdata.Rdata): address = binascii.unhexlify(address.encode()) return cls(rdclass, rdtype, address) - def to_wire(self, file, compress=None, origin=None): + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): file.write(self.address) @classmethod - def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None): - address = wire[current: current + rdlen].unwrap() + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + address = parser.get_remaining() return cls(rdclass, rdtype, address) - diff --git a/lib/dns/rdtypes/IN/NSAP_PTR.py b/lib/dns/rdtypes/IN/NSAP_PTR.py index 56967df0..a5b66c80 100644 --- a/lib/dns/rdtypes/IN/NSAP_PTR.py +++ b/lib/dns/rdtypes/IN/NSAP_PTR.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its diff --git a/lib/dns/rdtypes/IN/PX.py b/lib/dns/rdtypes/IN/PX.py index e1ef102b..946d79f8 100644 --- a/lib/dns/rdtypes/IN/PX.py +++ b/lib/dns/rdtypes/IN/PX.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its @@ -22,23 +24,17 @@ import dns.name class PX(dns.rdata.Rdata): - """PX record. + """PX record.""" - @ivar preference: the preference value - @type preference: int - @ivar map822: the map822 name - @type map822: dns.name.Name object - @ivar mapx400: the mapx400 name - @type mapx400: dns.name.Name object - @see: RFC 2163""" + # see: RFC 2163 __slots__ = ['preference', 'map822', 'mapx400'] def __init__(self, rdclass, rdtype, preference, map822, mapx400): - super(PX, self).__init__(rdclass, rdtype) - self.preference = preference - self.map822 = map822 - self.mapx400 = mapx400 + super().__init__(rdclass, rdtype) + object.__setattr__(self, 'preference', preference) + object.__setattr__(self, 'map822', map822) + object.__setattr__(self, 'mapx400', mapx400) def to_text(self, origin=None, relativize=True, **kw): map822 = self.map822.choose_relativity(origin, relativize) @@ -46,42 +42,23 @@ class PX(dns.rdata.Rdata): return '%d %s %s' % (self.preference, map822, mapx400) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True): + def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, + relativize_to=None): preference = tok.get_uint16() - map822 = tok.get_name() - map822 = map822.choose_relativity(origin, relativize) - mapx400 = tok.get_name(None) - mapx400 = mapx400.choose_relativity(origin, relativize) + map822 = tok.get_name(origin, relativize, relativize_to) + mapx400 = tok.get_name(origin, relativize, relativize_to) tok.get_eol() return cls(rdclass, rdtype, preference, map822, mapx400) - def to_wire(self, file, compress=None, origin=None): + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): pref = struct.pack("!H", self.preference) file.write(pref) - self.map822.to_wire(file, None, origin) - self.mapx400.to_wire(file, None, origin) + self.map822.to_wire(file, None, origin, canonicalize) + self.mapx400.to_wire(file, None, origin, canonicalize) @classmethod - def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None): - (preference, ) = struct.unpack('!H', wire[current: current + 2]) - current += 2 - rdlen -= 2 - (map822, cused) = dns.name.from_wire(wire[: current + rdlen], - current) - if cused > rdlen: - raise dns.exception.FormError - current += cused - rdlen -= cused - if origin is not None: - map822 = map822.relativize(origin) - (mapx400, cused) = dns.name.from_wire(wire[: current + rdlen], - current) - if cused != rdlen: - raise dns.exception.FormError - if origin is not None: - mapx400 = mapx400.relativize(origin) + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + preference = parser.get_uint16() + map822 = parser.get_name(origin) + mapx400 = parser.get_name(origin) return cls(rdclass, rdtype, preference, map822, mapx400) - - def choose_relativity(self, origin=None, relativize=True): - self.map822 = self.map822.choose_relativity(origin, relativize) - self.mapx400 = self.mapx400.choose_relativity(origin, relativize) diff --git a/lib/dns/rdtypes/IN/SRV.py b/lib/dns/rdtypes/IN/SRV.py index f4396d61..485153f4 100644 --- a/lib/dns/rdtypes/IN/SRV.py +++ b/lib/dns/rdtypes/IN/SRV.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its @@ -22,26 +24,18 @@ import dns.name class SRV(dns.rdata.Rdata): - """SRV record + """SRV record""" - @ivar priority: the priority - @type priority: int - @ivar weight: the weight - @type weight: int - @ivar port: the port of the service - @type port: int - @ivar target: the target host - @type target: dns.name.Name object - @see: RFC 2782""" + # see: RFC 2782 __slots__ = ['priority', 'weight', 'port', 'target'] def __init__(self, rdclass, rdtype, priority, weight, port, target): - super(SRV, self).__init__(rdclass, rdtype) - self.priority = priority - self.weight = weight - self.port = port - self.target = target + super().__init__(rdclass, rdtype) + object.__setattr__(self, 'priority', priority) + object.__setattr__(self, 'weight', weight) + object.__setattr__(self, 'port', port) + object.__setattr__(self, 'target', target) def to_text(self, origin=None, relativize=True, **kw): target = self.target.choose_relativity(origin, relativize) @@ -49,33 +43,22 @@ class SRV(dns.rdata.Rdata): target) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True): + def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, + relativize_to=None): priority = tok.get_uint16() weight = tok.get_uint16() port = tok.get_uint16() - target = tok.get_name(None) - target = target.choose_relativity(origin, relativize) + target = tok.get_name(origin, relativize, relativize_to) tok.get_eol() return cls(rdclass, rdtype, priority, weight, port, target) - def to_wire(self, file, compress=None, origin=None): + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): three_ints = struct.pack("!HHH", self.priority, self.weight, self.port) file.write(three_ints) - self.target.to_wire(file, compress, origin) + self.target.to_wire(file, compress, origin, canonicalize) @classmethod - def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None): - (priority, weight, port) = struct.unpack('!HHH', - wire[current: current + 6]) - current += 6 - rdlen -= 6 - (target, cused) = dns.name.from_wire(wire[: current + rdlen], - current) - if cused != rdlen: - raise dns.exception.FormError - if origin is not None: - target = target.relativize(origin) + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + (priority, weight, port) = parser.get_struct('!HHH') + target = parser.get_name(origin) return cls(rdclass, rdtype, priority, weight, port, target) - - def choose_relativity(self, origin=None, relativize=True): - self.target = self.target.choose_relativity(origin, relativize) diff --git a/lib/dns/rdtypes/IN/WKS.py b/lib/dns/rdtypes/IN/WKS.py index da2a2d88..d66d8583 100644 --- a/lib/dns/rdtypes/IN/WKS.py +++ b/lib/dns/rdtypes/IN/WKS.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its @@ -18,7 +20,6 @@ import struct import dns.ipv4 import dns.rdata -from dns._compat import xrange _proto_tcp = socket.getprotobyname('tcp') _proto_udp = socket.getprotobyname('udp') @@ -26,39 +27,31 @@ _proto_udp = socket.getprotobyname('udp') class WKS(dns.rdata.Rdata): - """WKS record + """WKS record""" - @ivar address: the address - @type address: string - @ivar protocol: the protocol - @type protocol: int - @ivar bitmap: the bitmap - @type bitmap: string - @see: RFC 1035""" + # see: RFC 1035 __slots__ = ['address', 'protocol', 'bitmap'] def __init__(self, rdclass, rdtype, address, protocol, bitmap): - super(WKS, self).__init__(rdclass, rdtype) - self.address = address - self.protocol = protocol - if not isinstance(bitmap, bytearray): - self.bitmap = bytearray(bitmap) - else: - self.bitmap = bitmap + super().__init__(rdclass, rdtype) + object.__setattr__(self, 'address', address) + object.__setattr__(self, 'protocol', protocol) + object.__setattr__(self, 'bitmap', dns.rdata._constify(bitmap)) def to_text(self, origin=None, relativize=True, **kw): bits = [] - for i in xrange(0, len(self.bitmap)): + for i in range(0, len(self.bitmap)): byte = self.bitmap[i] - for j in xrange(0, 8): + for j in range(0, 8): if byte & (0x80 >> j): bits.append(str(i * 8 + j)) text = ' '.join(bits) return '%s %d %s' % (self.address, self.protocol, text) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True): + def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, + relativize_to=None): address = tok.get_string() protocol = tok.get_string() if protocol.isdigit(): @@ -83,24 +76,21 @@ class WKS(dns.rdata.Rdata): i = serv // 8 l = len(bitmap) if l < i + 1: - for j in xrange(l, i + 1): + for j in range(l, i + 1): bitmap.append(0) bitmap[i] = bitmap[i] | (0x80 >> (serv % 8)) bitmap = dns.rdata._truncate_bitmap(bitmap) return cls(rdclass, rdtype, address, protocol, bitmap) - def to_wire(self, file, compress=None, origin=None): + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): file.write(dns.ipv4.inet_aton(self.address)) protocol = struct.pack('!B', self.protocol) file.write(protocol) file.write(self.bitmap) @classmethod - def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None): - address = dns.ipv4.inet_ntoa(wire[current: current + 4]) - protocol, = struct.unpack('!B', wire[current + 4: current + 5]) - current += 5 - rdlen -= 5 - bitmap = wire[current: current + rdlen].unwrap() + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + address = dns.ipv4.inet_ntoa(parser.get_bytes(4)) + protocol = parser.get_uint8() + bitmap = parser.get_remaining() return cls(rdclass, rdtype, address, protocol, bitmap) - diff --git a/lib/dns/rdtypes/IN/__init__.py b/lib/dns/rdtypes/IN/__init__.py index 24cf1ece..d7e69c9f 100644 --- a/lib/dns/rdtypes/IN/__init__.py +++ b/lib/dns/rdtypes/IN/__init__.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its @@ -20,6 +22,7 @@ __all__ = [ 'AAAA', 'APL', 'DHCID', + 'IPSECKEY', 'KX', 'NAPTR', 'NSAP', diff --git a/lib/dns/rdtypes/__init__.py b/lib/dns/rdtypes/__init__.py index 826efbb6..ccc848cf 100644 --- a/lib/dns/rdtypes/__init__.py +++ b/lib/dns/rdtypes/__init__.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its @@ -18,7 +20,9 @@ __all__ = [ 'ANY', 'IN', + 'CH', 'euibase', 'mxbase', 'nsbase', + 'util' ] diff --git a/lib/dns/rdtypes/dnskeybase.py b/lib/dns/rdtypes/dnskeybase.py index 85c4b23f..0243d6f3 100644 --- a/lib/dns/rdtypes/dnskeybase.py +++ b/lib/dns/rdtypes/dnskeybase.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2004-2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its @@ -14,6 +16,7 @@ # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. import base64 +import enum import struct import dns.exception @@ -21,116 +24,51 @@ import dns.dnssec import dns.rdata # wildcard import -__all__ = ["SEP", "REVOKE", "ZONE", - "flags_to_text_set", "flags_from_text_set"] +__all__ = ["SEP", "REVOKE", "ZONE"] # noqa: F822 -# flag constants -SEP = 0x0001 -REVOKE = 0x0080 -ZONE = 0x0100 +class Flag(enum.IntFlag): + SEP = 0x0001 + REVOKE = 0x0080 + ZONE = 0x0100 -_flag_by_text = { - 'SEP': SEP, - 'REVOKE': REVOKE, - 'ZONE': ZONE -} - -# We construct the inverse mapping programmatically to ensure that we -# cannot make any mistakes (e.g. omissions, cut-and-paste errors) that -# would cause the mapping not to be true inverse. -_flag_by_value = dict((y, x) for x, y in _flag_by_text.items()) - - -def flags_to_text_set(flags): - """Convert a DNSKEY flags value to set texts - @rtype: set([string])""" - - flags_set = set() - mask = 0x1 - while mask <= 0x8000: - if flags & mask: - text = _flag_by_value.get(mask) - if not text: - text = hex(mask) - flags_set.add(text) - mask <<= 1 - return flags_set - - -def flags_from_text_set(texts_set): - """Convert set of DNSKEY flag mnemonic texts to DNSKEY flag value - @rtype: int""" - - flags = 0 - for text in texts_set: - try: - flags += _flag_by_text[text] - except KeyError: - raise NotImplementedError( - "DNSKEY flag '%s' is not supported" % text) - return flags +globals().update(Flag.__members__) class DNSKEYBase(dns.rdata.Rdata): - """Base class for rdata that is like a DNSKEY record - - @ivar flags: the key flags - @type flags: int - @ivar protocol: the protocol for which this key may be used - @type protocol: int - @ivar algorithm: the algorithm used for the key - @type algorithm: int - @ivar key: the public key - @type key: string""" + """Base class for rdata that is like a DNSKEY record""" __slots__ = ['flags', 'protocol', 'algorithm', 'key'] def __init__(self, rdclass, rdtype, flags, protocol, algorithm, key): - super(DNSKEYBase, self).__init__(rdclass, rdtype) - self.flags = flags - self.protocol = protocol - self.algorithm = algorithm - self.key = key + super().__init__(rdclass, rdtype) + object.__setattr__(self, 'flags', flags) + object.__setattr__(self, 'protocol', protocol) + object.__setattr__(self, 'algorithm', algorithm) + object.__setattr__(self, 'key', key) def to_text(self, origin=None, relativize=True, **kw): return '%d %d %d %s' % (self.flags, self.protocol, self.algorithm, dns.rdata._base64ify(self.key)) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True): + def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, + relativize_to=None): flags = tok.get_uint16() protocol = tok.get_uint8() algorithm = dns.dnssec.algorithm_from_text(tok.get_string()) - chunks = [] - while 1: - t = tok.get().unescape() - if t.is_eol_or_eof(): - break - if not t.is_identifier(): - raise dns.exception.SyntaxError - chunks.append(t.value.encode()) - b64 = b''.join(chunks) + b64 = tok.concatenate_remaining_identifiers().encode() key = base64.b64decode(b64) return cls(rdclass, rdtype, flags, protocol, algorithm, key) - def to_wire(self, file, compress=None, origin=None): + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): header = struct.pack("!HBB", self.flags, self.protocol, self.algorithm) file.write(header) file.write(self.key) @classmethod - def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None): - if rdlen < 4: - raise dns.exception.FormError - header = struct.unpack('!HBB', wire[current: current + 4]) - current += 4 - rdlen -= 4 - key = wire[current: current + rdlen].unwrap() + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + header = parser.get_struct('!HBB') + key = parser.get_remaining() return cls(rdclass, rdtype, header[0], header[1], header[2], key) - - def flags_to_text_set(self): - """Convert a DNSKEY flags value to set texts - @rtype: set([string])""" - return flags_to_text_set(self.flags) diff --git a/lib/dns/rdtypes/dsbase.py b/lib/dns/rdtypes/dsbase.py index 80f792ac..d7850bee 100644 --- a/lib/dns/rdtypes/dsbase.py +++ b/lib/dns/rdtypes/dsbase.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2010, 2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its @@ -16,33 +18,24 @@ import struct import binascii +import dns.dnssec import dns.rdata import dns.rdatatype class DSBase(dns.rdata.Rdata): - """Base class for rdata that is like a DS record - - @ivar key_tag: the key tag - @type key_tag: int - @ivar algorithm: the algorithm - @type algorithm: int - @ivar digest_type: the digest type - @type digest_type: int - @ivar digest: the digest - @type digest: int - @see: draft-ietf-dnsext-delegation-signer-14.txt""" + """Base class for rdata that is like a DS record""" __slots__ = ['key_tag', 'algorithm', 'digest_type', 'digest'] def __init__(self, rdclass, rdtype, key_tag, algorithm, digest_type, digest): - super(DSBase, self).__init__(rdclass, rdtype) - self.key_tag = key_tag - self.algorithm = algorithm - self.digest_type = digest_type - self.digest = digest + super().__init__(rdclass, rdtype) + object.__setattr__(self, 'key_tag', key_tag) + object.__setattr__(self, 'algorithm', algorithm) + object.__setattr__(self, 'digest_type', digest_type) + object.__setattr__(self, 'digest', digest) def to_text(self, origin=None, relativize=True, **kw): return '%d %d %d %s' % (self.key_tag, self.algorithm, @@ -51,34 +44,24 @@ class DSBase(dns.rdata.Rdata): chunksize=128)) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True): + def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, + relativize_to=None): key_tag = tok.get_uint16() - algorithm = tok.get_uint8() + algorithm = dns.dnssec.algorithm_from_text(tok.get_string()) digest_type = tok.get_uint8() - chunks = [] - while 1: - t = tok.get().unescape() - if t.is_eol_or_eof(): - break - if not t.is_identifier(): - raise dns.exception.SyntaxError - chunks.append(t.value.encode()) - digest = b''.join(chunks) + digest = tok.concatenate_remaining_identifiers().encode() digest = binascii.unhexlify(digest) return cls(rdclass, rdtype, key_tag, algorithm, digest_type, digest) - def to_wire(self, file, compress=None, origin=None): + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): header = struct.pack("!HBB", self.key_tag, self.algorithm, self.digest_type) file.write(header) file.write(self.digest) @classmethod - def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None): - header = struct.unpack("!HBB", wire[current: current + 4]) - current += 4 - rdlen -= 4 - digest = wire[current: current + rdlen].unwrap() + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + header = parser.get_struct("!HBB") + digest = parser.get_remaining() return cls(rdclass, rdtype, header[0], header[1], header[2], digest) - diff --git a/lib/dns/rdtypes/euibase.py b/lib/dns/rdtypes/euibase.py index 13109163..c1677a81 100644 --- a/lib/dns/rdtypes/euibase.py +++ b/lib/dns/rdtypes/euibase.py @@ -21,11 +21,9 @@ import dns.rdata class EUIBase(dns.rdata.Rdata): - """EUIxx record + """EUIxx record""" - @ivar fingerprint: xx-bit Extended Unique Identifier (EUI-xx) - @type fingerprint: string - @see: rfc7043.txt""" + # see: rfc7043.txt __slots__ = ['eui'] # define these in subclasses @@ -33,24 +31,24 @@ class EUIBase(dns.rdata.Rdata): # text_len = byte_len * 3 - 1 # 01-23-45-67-89-ab def __init__(self, rdclass, rdtype, eui): - super(EUIBase, self).__init__(rdclass, rdtype) + super().__init__(rdclass, rdtype) if len(eui) != self.byte_len: raise dns.exception.FormError('EUI%s rdata has to have %s bytes' % (self.byte_len * 8, self.byte_len)) - self.eui = eui + object.__setattr__(self, 'eui', eui) def to_text(self, origin=None, relativize=True, **kw): return dns.rdata._hexify(self.eui, chunksize=2).replace(' ', '-') @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True): + def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, + relativize_to=None): text = tok.get_string() tok.get_eol() if len(text) != cls.text_len: raise dns.exception.SyntaxError( 'Input text must have %s characters' % cls.text_len) - expected_dash_idxs = range(2, cls.byte_len * 3 - 1, 3) - for i in expected_dash_idxs: + for i in range(2, cls.byte_len * 3 - 1, 3): if text[i] != '-': raise dns.exception.SyntaxError('Dash expected at position %s' % i) @@ -61,11 +59,10 @@ class EUIBase(dns.rdata.Rdata): raise dns.exception.SyntaxError('Hex decoding error: %s' % str(ex)) return cls(rdclass, rdtype, data) - def to_wire(self, file, compress=None, origin=None): + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): file.write(self.eui) @classmethod - def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None): - eui = wire[current:current + rdlen].unwrap() + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + eui = parser.get_bytes(cls.byte_len) return cls(rdclass, rdtype, eui) - diff --git a/lib/dns/rdtypes/mxbase.py b/lib/dns/rdtypes/mxbase.py index 5ac8cef9..d6a6efed 100644 --- a/lib/dns/rdtypes/mxbase.py +++ b/lib/dns/rdtypes/mxbase.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its @@ -15,7 +17,6 @@ """MX-like base classes.""" -from io import BytesIO import struct import dns.exception @@ -25,57 +26,38 @@ import dns.name class MXBase(dns.rdata.Rdata): - """Base class for rdata that is like an MX record. - - @ivar preference: the preference value - @type preference: int - @ivar exchange: the exchange name - @type exchange: dns.name.Name object""" + """Base class for rdata that is like an MX record.""" __slots__ = ['preference', 'exchange'] def __init__(self, rdclass, rdtype, preference, exchange): - super(MXBase, self).__init__(rdclass, rdtype) - self.preference = preference - self.exchange = exchange + super().__init__(rdclass, rdtype) + object.__setattr__(self, 'preference', preference) + object.__setattr__(self, 'exchange', exchange) def to_text(self, origin=None, relativize=True, **kw): exchange = self.exchange.choose_relativity(origin, relativize) return '%d %s' % (self.preference, exchange) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True): + def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, + relativize_to=None): preference = tok.get_uint16() - exchange = tok.get_name() - exchange = exchange.choose_relativity(origin, relativize) + exchange = tok.get_name(origin, relativize, relativize_to) tok.get_eol() return cls(rdclass, rdtype, preference, exchange) - def to_wire(self, file, compress=None, origin=None): + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): pref = struct.pack("!H", self.preference) file.write(pref) - self.exchange.to_wire(file, compress, origin) - - def to_digestable(self, origin=None): - return struct.pack("!H", self.preference) + \ - self.exchange.to_digestable(origin) + self.exchange.to_wire(file, compress, origin, canonicalize) @classmethod - def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None): - (preference, ) = struct.unpack('!H', wire[current: current + 2]) - current += 2 - rdlen -= 2 - (exchange, cused) = dns.name.from_wire(wire[: current + rdlen], - current) - if cused != rdlen: - raise dns.exception.FormError - if origin is not None: - exchange = exchange.relativize(origin) + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + preference = parser.get_uint16() + exchange = parser.get_name(origin) return cls(rdclass, rdtype, preference, exchange) - def choose_relativity(self, origin=None, relativize=True): - self.exchange = self.exchange.choose_relativity(origin, relativize) - class UncompressedMX(MXBase): @@ -83,13 +65,8 @@ class UncompressedMX(MXBase): is not compressed when converted to DNS wire format, and whose digestable form is not downcased.""" - def to_wire(self, file, compress=None, origin=None): - super(UncompressedMX, self).to_wire(file, None, origin) - - def to_digestable(self, origin=None): - f = BytesIO() - self.to_wire(f, None, origin) - return f.getvalue() + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): + super()._to_wire(file, None, origin, False) class UncompressedDowncasingMX(MXBase): @@ -97,5 +74,5 @@ class UncompressedDowncasingMX(MXBase): """Base class for rdata that is like an MX record, but whose name is not compressed when convert to DNS wire format.""" - def to_wire(self, file, compress=None, origin=None): - super(UncompressedDowncasingMX, self).to_wire(file, None, origin) + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): + super()._to_wire(file, None, origin, canonicalize) diff --git a/lib/dns/rdtypes/nsbase.py b/lib/dns/rdtypes/nsbase.py index 79333a14..93d3ee53 100644 --- a/lib/dns/rdtypes/nsbase.py +++ b/lib/dns/rdtypes/nsbase.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its @@ -15,8 +17,6 @@ """NS-like base classes.""" -from io import BytesIO - import dns.exception import dns.rdata import dns.name @@ -24,47 +24,33 @@ import dns.name class NSBase(dns.rdata.Rdata): - """Base class for rdata that is like an NS record. - - @ivar target: the target name of the rdata - @type target: dns.name.Name object""" + """Base class for rdata that is like an NS record.""" __slots__ = ['target'] def __init__(self, rdclass, rdtype, target): - super(NSBase, self).__init__(rdclass, rdtype) - self.target = target + super().__init__(rdclass, rdtype) + object.__setattr__(self, 'target', target) def to_text(self, origin=None, relativize=True, **kw): target = self.target.choose_relativity(origin, relativize) return str(target) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True): - target = tok.get_name() - target = target.choose_relativity(origin, relativize) + def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, + relativize_to=None): + target = tok.get_name(origin, relativize, relativize_to) tok.get_eol() return cls(rdclass, rdtype, target) - def to_wire(self, file, compress=None, origin=None): - self.target.to_wire(file, compress, origin) - - def to_digestable(self, origin=None): - return self.target.to_digestable(origin) + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): + self.target.to_wire(file, compress, origin, canonicalize) @classmethod - def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None): - (target, cused) = dns.name.from_wire(wire[: current + rdlen], - current) - if cused != rdlen: - raise dns.exception.FormError - if origin is not None: - target = target.relativize(origin) + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + target = parser.get_name(origin) return cls(rdclass, rdtype, target) - def choose_relativity(self, origin=None, relativize=True): - self.target = self.target.choose_relativity(origin, relativize) - class UncompressedNS(NSBase): @@ -72,10 +58,5 @@ class UncompressedNS(NSBase): is not compressed when convert to DNS wire format, and whose digestable form is not downcased.""" - def to_wire(self, file, compress=None, origin=None): - super(UncompressedNS, self).to_wire(file, None, origin) - - def to_digestable(self, origin=None): - f = BytesIO() - self.to_wire(f, None, origin) - return f.getvalue() + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): + self.target.to_wire(file, None, origin, False) diff --git a/lib/dns/rdtypes/txtbase.py b/lib/dns/rdtypes/txtbase.py index 54d7e6f0..ad0093da 100644 --- a/lib/dns/rdtypes/txtbase.py +++ b/lib/dns/rdtypes/txtbase.py @@ -1,4 +1,6 @@ -# Copyright (C) 2006, 2007, 2009-2011 Nominum, Inc. +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2006-2017 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its # documentation for any purpose with or without fee is hereby granted, @@ -20,54 +22,61 @@ import struct import dns.exception import dns.rdata import dns.tokenizer -from dns._compat import binary_type class TXTBase(dns.rdata.Rdata): - """Base class for rdata that is like a TXT record - - @ivar strings: the text strings - @type strings: list of string - @see: RFC 1035""" + """Base class for rdata that is like a TXT record (see RFC 1035).""" __slots__ = ['strings'] def __init__(self, rdclass, rdtype, strings): - super(TXTBase, self).__init__(rdclass, rdtype) - if isinstance(strings, str): - strings = [strings] - self.strings = strings[:] + """Initialize a TXT-like rdata. + + *rdclass*, an ``int`` is the rdataclass of the Rdata. + + *rdtype*, an ``int`` is the rdatatype of the Rdata. + + *strings*, a tuple of ``bytes`` + """ + super().__init__(rdclass, rdtype) + if isinstance(strings, (bytes, str)): + strings = (strings,) + encoded_strings = [] + for string in strings: + if isinstance(string, str): + string = string.encode() + else: + string = dns.rdata._constify(string) + encoded_strings.append(string) + object.__setattr__(self, 'strings', tuple(encoded_strings)) def to_text(self, origin=None, relativize=True, **kw): txt = '' prefix = '' for s in self.strings: - txt += '%s"%s"' % (prefix, dns.rdata._escapify(s)) + txt += '{}"{}"'.format(prefix, dns.rdata._escapify(s)) prefix = ' ' return txt @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True): + def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, + relativize_to=None): strings = [] while 1: - token = tok.get().unescape() + token = tok.get().unescape_to_bytes() if token.is_eol_or_eof(): break if not (token.is_quoted_string() or token.is_identifier()): raise dns.exception.SyntaxError("expected a string") if len(token.value) > 255: raise dns.exception.SyntaxError("string too long") - value = token.value - if isinstance(value, binary_type): - strings.append(value) - else: - strings.append(value.encode()) + strings.append(token.value) if len(strings) == 0: raise dns.exception.UnexpectedEnd return cls(rdclass, rdtype, strings) - def to_wire(self, file, compress=None, origin=None): + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): for s in self.strings: l = len(s) assert l < 256 @@ -75,17 +84,9 @@ class TXTBase(dns.rdata.Rdata): file.write(s) @classmethod - def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None): + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): strings = [] - while rdlen > 0: - l = wire[current] - current += 1 - rdlen -= 1 - if l > rdlen: - raise dns.exception.FormError - s = wire[current: current + l].unwrap() - current += l - rdlen -= l + while parser.remaining() > 0: + s = parser.get_counted_bytes() strings.append(s) return cls(rdclass, rdtype, strings) - diff --git a/lib/dns/rdtypes/util.py b/lib/dns/rdtypes/util.py new file mode 100644 index 00000000..a63d1a0a --- /dev/null +++ b/lib/dns/rdtypes/util.py @@ -0,0 +1,166 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2006, 2007, 2009-2011 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +import struct + +import dns.exception +import dns.name +import dns.ipv4 +import dns.ipv6 + +class Gateway: + """A helper class for the IPSECKEY gateway and AMTRELAY relay fields""" + name = "" + + def __init__(self, type, gateway=None): + self.type = type + self.gateway = gateway + + def _invalid_type(self): + return f"invalid {self.name} type: {self.type}" + + def check(self): + if self.type == 0: + if self.gateway not in (".", None): + raise SyntaxError(f"invalid {self.name} for type 0") + self.gateway = None + elif self.type == 1: + # check that it's OK + dns.ipv4.inet_aton(self.gateway) + elif self.type == 2: + # check that it's OK + dns.ipv6.inet_aton(self.gateway) + elif self.type == 3: + if not isinstance(self.gateway, dns.name.Name): + raise SyntaxError(f"invalid {self.name}; not a name") + else: + raise SyntaxError(self._invalid_type()) + + def to_text(self, origin=None, relativize=True): + if self.type == 0: + return "." + elif self.type in (1, 2): + return self.gateway + elif self.type == 3: + return str(self.gateway.choose_relativity(origin, relativize)) + else: + raise ValueError(self._invalid_type()) + + def from_text(self, tok, origin=None, relativize=True, relativize_to=None): + if self.type in (0, 1, 2): + return tok.get_string() + elif self.type == 3: + return tok.get_name(origin, relativize, relativize_to) + else: + raise dns.exception.SyntaxError(self._invalid_type()) + + def to_wire(self, file, compress=None, origin=None, canonicalize=False): + if self.type == 0: + pass + elif self.type == 1: + file.write(dns.ipv4.inet_aton(self.gateway)) + elif self.type == 2: + file.write(dns.ipv6.inet_aton(self.gateway)) + elif self.type == 3: + self.gateway.to_wire(file, None, origin, False) + else: + raise ValueError(self._invalid_type()) + + def from_wire_parser(self, parser, origin=None): + if self.type == 0: + return None + elif self.type == 1: + return dns.ipv4.inet_ntoa(parser.get_bytes(4)) + elif self.type == 2: + return dns.ipv6.inet_ntoa(parser.get_bytes(16)) + elif self.type == 3: + return parser.get_name(origin) + else: + raise dns.exception.FormError(self._invalid_type()) + +class Bitmap: + """A helper class for the NSEC/NSEC3/CSYNC type bitmaps""" + type_name = "" + + def __init__(self, windows=None): + self.windows = windows + + def to_text(self): + text = "" + for (window, bitmap) in self.windows: + bits = [] + for (i, byte) in enumerate(bitmap): + for j in range(0, 8): + if byte & (0x80 >> j): + rdtype = window * 256 + i * 8 + j + bits.append(dns.rdatatype.to_text(rdtype)) + text += (' ' + ' '.join(bits)) + return text + + def from_text(self, tok): + rdtypes = [] + while True: + token = tok.get().unescape() + if token.is_eol_or_eof(): + break + rdtype = dns.rdatatype.from_text(token.value) + if rdtype == 0: + raise dns.exception.SyntaxError(f"{self.type_name} with bit 0") + rdtypes.append(rdtype) + rdtypes.sort() + window = 0 + octets = 0 + prior_rdtype = 0 + bitmap = bytearray(b'\0' * 32) + windows = [] + for rdtype in rdtypes: + if rdtype == prior_rdtype: + continue + prior_rdtype = rdtype + new_window = rdtype // 256 + if new_window != window: + if octets != 0: + windows.append((window, bitmap[0:octets])) + bitmap = bytearray(b'\0' * 32) + window = new_window + offset = rdtype % 256 + byte = offset // 8 + bit = offset % 8 + octets = byte + 1 + bitmap[byte] = bitmap[byte] | (0x80 >> bit) + if octets != 0: + windows.append((window, bitmap[0:octets])) + return windows + + def to_wire(self, file): + for (window, bitmap) in self.windows: + file.write(struct.pack('!BB', window, len(bitmap))) + file.write(bitmap) + + def from_wire_parser(self, parser): + windows = [] + last_window = -1 + while parser.remaining() > 0: + window = parser.get_uint8() + if window <= last_window: + raise dns.exception.FormError(f"bad {self.type_name} bitmap") + bitmap = parser.get_counted_bytes() + if len(bitmap) == 0 or len(bitmap) > 32: + raise dns.exception.FormError(f"bad {self.type_name} octets") + windows.append((window, bitmap)) + last_window = window + return windows diff --git a/lib/dns/renderer.py b/lib/dns/renderer.py index ddc277cd..72f0f7a8 100644 --- a/lib/dns/renderer.py +++ b/lib/dns/renderer.py @@ -1,4 +1,6 @@ -# Copyright (C) 2001-2007, 2009-2011 Nominum, Inc. +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2001-2017 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its # documentation for any purpose with or without fee is hereby granted, @@ -15,15 +17,14 @@ """Help for building DNS wire format messages""" -from io import BytesIO +import contextlib +import io import struct import random import time -import sys import dns.exception import dns.tsig -from ._compat import long QUESTION = 0 @@ -32,8 +33,7 @@ AUTHORITY = 2 ADDITIONAL = 3 -class Renderer(object): - +class Renderer: """Helper class for building DNS wire-format messages. Most applications can use the higher-level L{dns.message.Message} @@ -55,43 +55,29 @@ class Renderer(object): r.add_tsig(keyname, secret, 300, 1, 0, '', request_mac) wire = r.get_wire() - @ivar output: where rendering is written - @type output: BytesIO object - @ivar id: the message id - @type id: int - @ivar flags: the message flags - @type flags: int - @ivar max_size: the maximum size of the message - @type max_size: int - @ivar origin: the origin to use when rendering relative names - @type origin: dns.name.Name object - @ivar compress: the compression table - @type compress: dict - @ivar section: the section currently being rendered - @type section: int (dns.renderer.QUESTION, dns.renderer.ANSWER, - dns.renderer.AUTHORITY, or dns.renderer.ADDITIONAL) - @ivar counts: list of the number of RRs in each section - @type counts: int list of length 4 - @ivar mac: the MAC of the rendered message (if TSIG was used) - @type mac: string + output, an io.BytesIO, where rendering is written + + id: the message id + + flags: the message flags + + max_size: the maximum size of the message + + origin: the origin to use when rendering relative names + + compress: the compression table + + section: an int, the section currently being rendered + + counts: list of the number of RRs in each section + + mac: the MAC of the rendered message (if TSIG was used) """ def __init__(self, id=None, flags=0, max_size=65535, origin=None): - """Initialize a new renderer. + """Initialize a new renderer.""" - @param id: the message id - @type id: int - @param flags: the DNS message flags - @type flags: int - @param max_size: the maximum message size; the default is 65535. - If rendering results in a message greater than I{max_size}, - then L{dns.exception.TooBig} will be raised. - @type max_size: int - @param origin: the origin to use when rendering relative names - @type origin: dns.name.Name or None. - """ - - self.output = BytesIO() + self.output = io.BytesIO() if id is None: self.id = random.randint(0, 65535) else: @@ -106,12 +92,9 @@ class Renderer(object): self.mac = '' def _rollback(self, where): - """Truncate the output buffer at offset I{where}, and remove any + """Truncate the output buffer at offset *where*, and remove any compression table entries that pointed beyond the truncation point. - - @param where: the offset - @type where: int """ self.output.seek(where) @@ -129,9 +112,7 @@ class Renderer(object): Sections must be rendered order: QUESTION, ANSWER, AUTHORITY, ADDITIONAL. Sections may be empty. - @param section: the section - @type section: int - @raises dns.exception.FormError: an attempt was made to set + Raises dns.exception.FormError if an attempt was made to set a section value less than the current section. """ @@ -140,25 +121,21 @@ class Renderer(object): raise dns.exception.FormError self.section = section - def add_question(self, qname, rdtype, rdclass=dns.rdataclass.IN): - """Add a question to the message. + @contextlib.contextmanager + def _track_size(self): + start = self.output.tell() + yield start + if self.output.tell() > self.max_size: + self._rollback(start) + raise dns.exception.TooBig - @param qname: the question name - @type qname: dns.name.Name - @param rdtype: the question rdata type - @type rdtype: int - @param rdclass: the question rdata class - @type rdclass: int - """ + def add_question(self, qname, rdtype, rdclass=dns.rdataclass.IN): + """Add a question to the message.""" self._set_section(QUESTION) - before = self.output.tell() - qname.to_wire(self.output, self.compress, self.origin) - self.output.write(struct.pack("!HH", rdtype, rdclass)) - after = self.output.tell() - if after >= self.max_size: - self._rollback(before) - raise dns.exception.TooBig + with self._track_size(): + qname.to_wire(self.output, self.compress, self.origin) + self.output.write(struct.pack("!HH", rdtype, rdclass)) self.counts[QUESTION] += 1 def add_rrset(self, section, rrset, **kw): @@ -166,20 +143,11 @@ class Renderer(object): Any keyword arguments are passed on to the rdataset's to_wire() routine. - - @param section: the section - @type section: int - @param rrset: the rrset - @type rrset: dns.rrset.RRset object """ self._set_section(section) - before = self.output.tell() - n = rrset.to_wire(self.output, self.compress, self.origin, **kw) - after = self.output.tell() - if after >= self.max_size: - self._rollback(before) - raise dns.exception.TooBig + with self._track_size(): + n = rrset.to_wire(self.output, self.compress, self.origin, **kw) self.counts[section] += n def add_rdataset(self, section, name, rdataset, **kw): @@ -188,124 +156,79 @@ class Renderer(object): Any keyword arguments are passed on to the rdataset's to_wire() routine. - - @param section: the section - @type section: int - @param name: the owner name - @type name: dns.name.Name object - @param rdataset: the rdataset - @type rdataset: dns.rdataset.Rdataset object """ self._set_section(section) - before = self.output.tell() - n = rdataset.to_wire(name, self.output, self.compress, self.origin, - **kw) - after = self.output.tell() - if after >= self.max_size: - self._rollback(before) - raise dns.exception.TooBig + with self._track_size(): + n = rdataset.to_wire(name, self.output, self.compress, self.origin, + **kw) self.counts[section] += n def add_edns(self, edns, ednsflags, payload, options=None): - """Add an EDNS OPT record to the message. - - @param edns: The EDNS level to use. - @type edns: int - @param ednsflags: EDNS flag values. - @type ednsflags: int - @param payload: The EDNS sender's payload field, which is the maximum - size of UDP datagram the sender can handle. - @type payload: int - @param options: The EDNS options list - @type options: list of dns.edns.Option instances - @see: RFC 2671 - """ + """Add an EDNS OPT record to the message.""" # make sure the EDNS version in ednsflags agrees with edns - ednsflags &= long(0xFF00FFFF) + ednsflags &= 0xFF00FFFF ednsflags |= (edns << 16) - self._set_section(ADDITIONAL) - before = self.output.tell() - self.output.write(struct.pack('!BHHIH', 0, dns.rdatatype.OPT, payload, - ednsflags, 0)) - if options is not None: - lstart = self.output.tell() - for opt in options: - stuff = struct.pack("!HH", opt.otype, 0) - self.output.write(stuff) - start = self.output.tell() - opt.to_wire(self.output) - end = self.output.tell() - assert end - start < 65536 - self.output.seek(start - 2) - stuff = struct.pack("!H", end - start) - self.output.write(stuff) - self.output.seek(0, 2) - lend = self.output.tell() - assert lend - lstart < 65536 - self.output.seek(lstart - 2) - stuff = struct.pack("!H", lend - lstart) - self.output.write(stuff) - self.output.seek(0, 2) - after = self.output.tell() - if after >= self.max_size: - self._rollback(before) - raise dns.exception.TooBig - self.counts[ADDITIONAL] += 1 + opt = dns.message.Message._make_opt(ednsflags, payload, options) + self.add_rrset(ADDITIONAL, opt) def add_tsig(self, keyname, secret, fudge, id, tsig_error, other_data, request_mac, algorithm=dns.tsig.default_algorithm): - """Add a TSIG signature to the message. + """Add a TSIG signature to the message.""" - @param keyname: the TSIG key name - @type keyname: dns.name.Name object - @param secret: the secret to use - @type secret: string - @param fudge: TSIG time fudge - @type fudge: int - @param id: the message id to encode in the tsig signature - @type id: int - @param tsig_error: TSIG error code; default is 0. - @type tsig_error: int - @param other_data: TSIG other data. - @type other_data: string - @param request_mac: This message is a response to the request which - had the specified MAC. - @type request_mac: string - @param algorithm: the TSIG algorithm to use - @type algorithm: dns.name.Name object - """ - - self._set_section(ADDITIONAL) - before = self.output.tell() s = self.output.getvalue() - (tsig_rdata, self.mac, ctx) = dns.tsig.sign(s, - keyname, - secret, - int(time.time()), - fudge, - id, - tsig_error, - other_data, - request_mac, - algorithm=algorithm) - keyname.to_wire(self.output, self.compress, self.origin) - self.output.write(struct.pack('!HHIH', dns.rdatatype.TSIG, - dns.rdataclass.ANY, 0, 0)) - rdata_start = self.output.tell() - self.output.write(tsig_rdata) + + if isinstance(secret, dns.tsig.Key): + key = secret + else: + key = dns.tsig.Key(keyname, secret, algorithm) + tsig = dns.message.Message._make_tsig(keyname, algorithm, 0, fudge, + b'', id, tsig_error, other_data) + (tsig, _) = dns.tsig.sign(s, key, tsig[0], int(time.time()), + request_mac) + self._write_tsig(tsig, keyname) + + def add_multi_tsig(self, ctx, keyname, secret, fudge, id, tsig_error, + other_data, request_mac, + algorithm=dns.tsig.default_algorithm): + """Add a TSIG signature to the message. Unlike add_tsig(), this can be + used for a series of consecutive DNS envelopes, e.g. for a zone + transfer over TCP [RFC2845, 4.4]. + + For the first message in the sequence, give ctx=None. For each + subsequent message, give the ctx that was returned from the + add_multi_tsig() call for the previous message.""" + + s = self.output.getvalue() + + if isinstance(secret, dns.tsig.Key): + key = secret + else: + key = dns.tsig.Key(keyname, secret, algorithm) + tsig = dns.message.Message._make_tsig(keyname, algorithm, 0, fudge, + b'', id, tsig_error, other_data) + (tsig, ctx) = dns.tsig.sign(s, key, tsig[0], int(time.time()), + request_mac, ctx, True) + self._write_tsig(tsig, keyname) + return ctx + + def _write_tsig(self, tsig, keyname): + self._set_section(ADDITIONAL) + with self._track_size(): + keyname.to_wire(self.output, self.compress, self.origin) + self.output.write(struct.pack('!HHIH', dns.rdatatype.TSIG, + dns.rdataclass.ANY, 0, 0)) + rdata_start = self.output.tell() + tsig.to_wire(self.output) + after = self.output.tell() - assert after - rdata_start < 65536 - if after >= self.max_size: - self._rollback(before) - raise dns.exception.TooBig self.output.seek(rdata_start - 2) self.output.write(struct.pack('!H', after - rdata_start)) self.counts[ADDITIONAL] += 1 self.output.seek(10) self.output.write(struct.pack('!H', self.counts[ADDITIONAL])) - self.output.seek(0, 2) + self.output.seek(0, io.SEEK_END) def write_header(self): """Write the DNS message header. @@ -319,12 +242,9 @@ class Renderer(object): self.output.write(struct.pack('!HHHHHH', self.id, self.flags, self.counts[0], self.counts[1], self.counts[2], self.counts[3])) - self.output.seek(0, 2) + self.output.seek(0, io.SEEK_END) def get_wire(self): - """Return the wire format message. - - @rtype: string - """ + """Return the wire format message.""" return self.output.getvalue() diff --git a/lib/dns/resolver.py b/lib/dns/resolver.py index bccb430d..4f630e4d 100644 --- a/lib/dns/resolver.py +++ b/lib/dns/resolver.py @@ -1,4 +1,6 @@ -# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2017 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its # documentation for any purpose with or without fee is hereby granted, @@ -13,23 +15,22 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. -"""DNS stub resolver. - -@var default_resolver: The default resolver object -@type default_resolver: dns.resolver.Resolver object""" - +"""DNS stub resolver.""" +from urllib.parse import urlparse +import contextlib import socket import sys import time import random - +import warnings try: import threading as _threading -except ImportError: - import dummy_threading as _threading +except ImportError: # pragma: no cover + import dummy_threading as _threading # type: ignore import dns.exception import dns.flags +import dns.inet import dns.ipv4 import dns.ipv6 import dns.message @@ -40,36 +41,92 @@ import dns.rdataclass import dns.rdatatype import dns.reversename import dns.tsig -from ._compat import xrange, string_types if sys.platform == 'win32': - try: - import winreg as _winreg - except ImportError: - import _winreg + import winreg # pragma: no cover class NXDOMAIN(dns.exception.DNSException): - """The DNS query name does not exist.""" - supp_kwargs = set(['qname']) + supp_kwargs = {'qnames', 'responses'} + fmt = None # we have our own __str__ implementation + + def _check_kwargs(self, qnames, responses=None): + if not isinstance(qnames, (list, tuple, set)): + raise AttributeError("qnames must be a list, tuple or set") + if len(qnames) == 0: + raise AttributeError("qnames must contain at least one element") + if responses is None: + responses = {} + elif not isinstance(responses, dict): + raise AttributeError("responses must be a dict(qname=response)") + kwargs = dict(qnames=qnames, responses=responses) + return kwargs def __str__(self): - if 'qname' not in self.kwargs: - return super(NXDOMAIN, self).__str__() + if 'qnames' not in self.kwargs: + return super().__str__() + qnames = self.kwargs['qnames'] + if len(qnames) > 1: + msg = 'None of DNS query names exist' + else: + msg = 'The DNS query name does not exist' + qnames = ', '.join(map(str, qnames)) + return "{}: {}".format(msg, qnames) - qname = self.kwargs['qname'] - msg = self.__doc__[:-1] - if isinstance(qname, (list, set)): - if len(qname) > 1: - msg = 'None of DNS query names exist' - qname = list(map(str, qname)) - else: - qname = qname[0] - return "%s: %s" % (msg, (str(qname))) + @property + def canonical_name(self): + """Return the unresolved canonical name.""" + if 'qnames' not in self.kwargs: + raise TypeError("parametrized exception required") + IN = dns.rdataclass.IN + CNAME = dns.rdatatype.CNAME + cname = None + for qname in self.kwargs['qnames']: + response = self.kwargs['responses'][qname] + for answer in response.answer: + if answer.rdtype != CNAME or answer.rdclass != IN: + continue + cname = answer[0].target.to_text() + if cname is not None: + return dns.name.from_text(cname) + return self.kwargs['qnames'][0] + + def __add__(self, e_nx): + """Augment by results from another NXDOMAIN exception.""" + qnames0 = list(self.kwargs.get('qnames', [])) + responses0 = dict(self.kwargs.get('responses', {})) + responses1 = e_nx.kwargs.get('responses', {}) + for qname1 in e_nx.kwargs.get('qnames', []): + if qname1 not in qnames0: + qnames0.append(qname1) + if qname1 in responses1: + responses0[qname1] = responses1[qname1] + return NXDOMAIN(qnames=qnames0, responses=responses0) + + def qnames(self): + """All of the names that were tried. + + Returns a list of ``dns.name.Name``. + """ + return self.kwargs['qnames'] + + def responses(self): + """A map from queried names to their NXDOMAIN responses. + + Returns a dict mapping a ``dns.name.Name`` to a + ``dns.message.Message``. + """ + return self.kwargs['responses'] + + def response(self, qname): + """The response for query *qname*. + + Returns a ``dns.message.Message``. + """ + return self.kwargs['responses'][qname] class YXDOMAIN(dns.exception.DNSException): - """The DNS query name is too long after DNAME substitution.""" # The definition of the Timeout exception has moved from here to the @@ -80,95 +137,78 @@ Timeout = dns.exception.Timeout class NoAnswer(dns.exception.DNSException): - """The DNS response does not contain an answer to the question.""" fmt = 'The DNS response does not contain an answer ' + \ 'to the question: {query}' - supp_kwargs = set(['response']) + supp_kwargs = {'response'} def _fmt_kwargs(self, **kwargs): - return super(NoAnswer, self)._fmt_kwargs( - query=kwargs['response'].question) + return super()._fmt_kwargs(query=kwargs['response'].question) class NoNameservers(dns.exception.DNSException): - """All nameservers failed to answer the query. - @param errors: list of servers and respective errors - @type errors: [(server ip address, any object convertible to string)] + errors: list of servers and respective errors + The type of errors is + [(server IP address, any object convertible to string)]. Non-empty errors list will add explanatory message () """ msg = "All nameservers failed to answer the query." fmt = "%s {query}: {errors}" % msg[:-1] - supp_kwargs = set(['request', 'errors']) + supp_kwargs = {'request', 'errors'} def _fmt_kwargs(self, **kwargs): srv_msgs = [] for err in kwargs['errors']: - srv_msgs.append('Server %s %s port %s answered %s' % (err[0], + srv_msgs.append('Server {} {} port {} answered {}'.format(err[0], 'TCP' if err[1] else 'UDP', err[2], err[3])) - return super(NoNameservers, self)._fmt_kwargs( - query=kwargs['request'].question, errors='; '.join(srv_msgs)) + return super()._fmt_kwargs(query=kwargs['request'].question, + errors='; '.join(srv_msgs)) class NotAbsolute(dns.exception.DNSException): - """An absolute domain name is required but a relative name was provided.""" class NoRootSOA(dns.exception.DNSException): - """There is no SOA RR at the DNS root name. This should never happen!""" class NoMetaqueries(dns.exception.DNSException): - """DNS metaqueries are not allowed.""" +class NoResolverConfiguration(dns.exception.DNSException): + """Resolver configuration could not be read or specified no nameservers.""" -class Answer(object): - - """DNS stub resolver answer +class Answer: + """DNS stub resolver answer. Instances of this class bundle up the result of a successful DNS resolution. For convenience, the answer object implements much of the sequence - protocol, forwarding to its rrset. E.g. "for a in answer" is - equivalent to "for a in answer.rrset", "answer[i]" is equivalent - to "answer.rrset[i]", and "answer[i:j]" is equivalent to - "answer.rrset[i:j]". + protocol, forwarding to its ``rrset`` attribute. E.g. + ``for a in answer`` is equivalent to ``for a in answer.rrset``. + ``answer[i]`` is equivalent to ``answer.rrset[i]``, and + ``answer[i:j]`` is equivalent to ``answer.rrset[i:j]``. Note that CNAMEs or DNAMEs in the response may mean that answer - node's name might not be the query name. - - @ivar qname: The query name - @type qname: dns.name.Name object - @ivar rdtype: The query type - @type rdtype: int - @ivar rdclass: The query class - @type rdclass: int - @ivar response: The response message - @type response: dns.message.Message object - @ivar rrset: The answer - @type rrset: dns.rrset.RRset object - @ivar expiration: The time when the answer expires - @type expiration: float (seconds since the epoch) - @ivar canonical_name: The canonical name of the query name - @type canonical_name: dns.name.Name object + RRset's name might not be the query name. """ - def __init__(self, qname, rdtype, rdclass, response, - raise_on_no_answer=True): + def __init__(self, qname, rdtype, rdclass, response, nameserver=None, + port=None): self.qname = qname self.rdtype = rdtype self.rdclass = rdclass self.response = response + self.nameserver = nameserver + self.port = port min_ttl = -1 rrset = None - for count in xrange(0, 15): + for count in range(0, 15): try: rrset = response.find_rrset(response.answer, qname, rdclass, rdtype) @@ -189,12 +229,8 @@ class Answer(object): break continue except KeyError: - if raise_on_no_answer: - raise NoAnswer(response=response) - if raise_on_no_answer: - raise NoAnswer(response=response) - if rrset is None and raise_on_no_answer: - raise NoAnswer(response=response) + # Exit the chaining loop + break self.canonical_name = qname self.rrset = rrset if rrset is None: @@ -216,7 +252,7 @@ class Answer(object): break self.expiration = time.time() + min_ttl - def __getattr__(self, attr): + def __getattr__(self, attr): # pragma: no cover if attr == 'name': return self.rrset.name elif attr == 'ttl': @@ -231,44 +267,28 @@ class Answer(object): raise AttributeError(attr) def __len__(self): - return len(self.rrset) + return self.rrset and len(self.rrset) or 0 def __iter__(self): - return iter(self.rrset) + return self.rrset and iter(self.rrset) or iter(tuple()) def __getitem__(self, i): + if self.rrset is None: + raise IndexError return self.rrset[i] def __delitem__(self, i): + if self.rrset is None: + raise IndexError del self.rrset[i] - def __getslice__(self, i, j): - return self.rrset[i:j] - def __delslice__(self, i, j): - del self.rrset[i:j] - - -class Cache(object): - - """Simple DNS answer cache. - - @ivar data: A dictionary of cached data - @type data: dict - @ivar cleaning_interval: The number of seconds between cleanings. The - default is 300 (5 minutes). - @type cleaning_interval: float - @ivar next_cleaning: The time the cache should next be cleaned (in seconds - since the epoch.) - @type next_cleaning: float - """ +class Cache: + """Simple thread-safe DNS answer cache.""" def __init__(self, cleaning_interval=300.0): - """Initialize a DNS cache. - - @param cleaning_interval: the number of seconds between periodic - cleanings. The default is 300.0 - @type cleaning_interval: float. + """*cleaning_interval*, a ``float`` is the number of seconds between + periodic cleanings. """ self.data = {} @@ -291,66 +311,57 @@ class Cache(object): self.next_cleaning = now + self.cleaning_interval def get(self, key): - """Get the answer associated with I{key}. Returns None if - no answer is cached for the key. - @param key: the key - @type key: (dns.name.Name, int, int) tuple whose values are the - query name, rdtype, and rdclass. - @rtype: dns.resolver.Answer object or None + """Get the answer associated with *key*. + + Returns None if no answer is cached for the key. + + *key*, a ``(dns.name.Name, int, int)`` tuple whose values are the + query name, rdtype, and rdclass respectively. + + Returns a ``dns.resolver.Answer`` or ``None``. """ - try: - self.lock.acquire() + with self.lock: self._maybe_clean() v = self.data.get(key) if v is None or v.expiration <= time.time(): return None return v - finally: - self.lock.release() def put(self, key, value): """Associate key and value in the cache. - @param key: the key - @type key: (dns.name.Name, int, int) tuple whose values are the - query name, rdtype, and rdclass. - @param value: The answer being cached - @type value: dns.resolver.Answer object + + *key*, a ``(dns.name.Name, int, int)`` tuple whose values are the + query name, rdtype, and rdclass respectively. + + *value*, a ``dns.resolver.Answer``, the answer. """ - try: - self.lock.acquire() + with self.lock: self._maybe_clean() self.data[key] = value - finally: - self.lock.release() def flush(self, key=None): """Flush the cache. - If I{key} is specified, only that item is flushed. Otherwise + If *key* is not ``None``, only that item is flushed. Otherwise the entire cache is flushed. - @param key: the key to flush - @type key: (dns.name.Name, int, int) tuple or None + *key*, a ``(dns.name.Name, int, int)`` tuple whose values are the + query name, rdtype, and rdclass respectively. """ - try: - self.lock.acquire() + with self.lock: if key is not None: if key in self.data: del self.data[key] else: self.data = {} self.next_cleaning = time.time() + self.cleaning_interval - finally: - self.lock.release() -class LRUCacheNode(object): - - """LRUCache node. - """ +class LRUCacheNode: + """LRUCache node.""" def __init__(self, key, value): self.key = key @@ -358,12 +369,6 @@ class LRUCacheNode(object): self.prev = self self.next = self - def link_before(self, node): - self.prev = node.prev - self.next = node - node.prev.next = self - node.prev = self - def link_after(self, node): self.prev = node self.next = node.next @@ -375,34 +380,26 @@ class LRUCacheNode(object): self.prev.next = self.next -class LRUCache(object): - - """Bounded least-recently-used DNS answer cache. +class LRUCache: + """Thread-safe, bounded, least-recently-used DNS answer cache. This cache is better than the simple cache (above) if you're running a web crawler or other process that does a lot of resolutions. The LRUCache has a maximum number of nodes, and when it is full, the least-recently used node is removed to make space for a new one. - - @ivar data: A dictionary of cached data - @type data: dict - @ivar sentinel: sentinel node for circular doubly linked list of nodes - @type sentinel: LRUCacheNode object - @ivar max_size: The maximum number of nodes - @type max_size: int """ def __init__(self, max_size=100000): - """Initialize a DNS cache. - - @param max_size: The maximum number of nodes to cache; the default is - 100000. Must be > 1. - @type max_size: int + """*max_size*, an ``int``, is the maximum number of nodes to cache; + it must be greater than 0. """ + self.data = {} self.set_max_size(max_size) self.sentinel = LRUCacheNode(None, None) + self.sentinel.prev = self.sentinel + self.sentinel.next = self.sentinel self.lock = _threading.Lock() def set_max_size(self, max_size): @@ -411,15 +408,17 @@ class LRUCache(object): self.max_size = max_size def get(self, key): - """Get the answer associated with I{key}. Returns None if - no answer is cached for the key. - @param key: the key - @type key: (dns.name.Name, int, int) tuple whose values are the - query name, rdtype, and rdclass. - @rtype: dns.resolver.Answer object or None + """Get the answer associated with *key*. + + Returns None if no answer is cached for the key. + + *key*, a ``(dns.name.Name, int, int)`` tuple whose values are the + query name, rdtype, and rdclass respectively. + + Returns a ``dns.resolver.Answer`` or ``None``. """ - try: - self.lock.acquire() + + with self.lock: node = self.data.get(key) if node is None: return None @@ -431,19 +430,17 @@ class LRUCache(object): return None node.link_after(self.sentinel) return node.value - finally: - self.lock.release() def put(self, key, value): """Associate key and value in the cache. - @param key: the key - @type key: (dns.name.Name, int, int) tuple whose values are the - query name, rdtype, and rdclass. - @param value: The answer being cached - @type value: dns.resolver.Answer object + + *key*, a ``(dns.name.Name, int, int)`` tuple whose values are the + query name, rdtype, and rdclass respectively. + + *value*, a ``dns.resolver.Answer``, the answer. """ - try: - self.lock.acquire() + + with self.lock: node = self.data.get(key) if node is not None: node.unlink() @@ -455,20 +452,18 @@ class LRUCache(object): node = LRUCacheNode(key, value) node.link_after(self.sentinel) self.data[key] = node - finally: - self.lock.release() def flush(self, key=None): """Flush the cache. - If I{key} is specified, only that item is flushed. Otherwise + If *key* is not ``None``, only that item is flushed. Otherwise the entire cache is flushed. - @param key: the key to flush - @type key: (dns.name.Name, int, int) tuple or None + *key*, a ``(dns.name.Name, int, int)`` tuple whose values are the + query name, rdtype, and rdclass respectively. """ - try: - self.lock.acquire() + + with self.lock: if key is not None: node = self.data.get(key) if node is not None: @@ -478,81 +473,230 @@ class LRUCache(object): node = self.sentinel.next while node != self.sentinel: next = node.next - node.prev = None - node.next = None + node.unlink() node = next self.data = {} - finally: - self.lock.release() +class _Resolution: + """Helper class for dns.resolver.Resolver.resolve(). -class Resolver(object): + All of the "business logic" of resolution is encapsulated in this + class, allowing us to have multiple resolve() implementations + using different I/O schemes without copying all of the + complicated logic. - """DNS stub resolver - - @ivar domain: The domain of this host - @type domain: dns.name.Name object - @ivar nameservers: A list of nameservers to query. Each nameserver is - a string which contains the IP address of a nameserver. - @type nameservers: list of strings - @ivar search: The search list. If the query name is a relative name, - the resolver will construct an absolute query name by appending the search - names one by one to the query name. - @type search: list of dns.name.Name objects - @ivar port: The port to which to send queries. The default is 53. - @type port: int - @ivar timeout: The number of seconds to wait for a response from a - server, before timing out. - @type timeout: float - @ivar lifetime: The total number of seconds to spend trying to get an - answer to the question. If the lifetime expires, a Timeout exception - will occur. - @type lifetime: float - @ivar keyring: The TSIG keyring to use. The default is None. - @type keyring: dict - @ivar keyname: The TSIG keyname to use. The default is None. - @type keyname: dns.name.Name object - @ivar keyalgorithm: The TSIG key algorithm to use. The default is - dns.tsig.default_algorithm. - @type keyalgorithm: string - @ivar edns: The EDNS level to use. The default is -1, no Edns. - @type edns: int - @ivar ednsflags: The EDNS flags - @type ednsflags: int - @ivar payload: The EDNS payload size. The default is 0. - @type payload: int - @ivar flags: The message flags to use. The default is None (i.e. not - overwritten) - @type flags: int - @ivar cache: The cache to use. The default is None. - @type cache: dns.resolver.Cache object - @ivar retry_servfail: should we retry a nameserver if it says SERVFAIL? - The default is 'false'. - @type retry_servfail: bool + This class is a "friend" to dns.resolver.Resolver and manipulates + resolver data structures directly. """ - def __init__(self, filename='/etc/resolv.conf', configure=True): - """Initialize a resolver instance. + def __init__(self, resolver, qname, rdtype, rdclass, tcp, + raise_on_no_answer, search): + if isinstance(qname, str): + qname = dns.name.from_text(qname, None) + rdtype = dns.rdatatype.RdataType.make(rdtype) + if dns.rdatatype.is_metatype(rdtype): + raise NoMetaqueries + rdclass = dns.rdataclass.RdataClass.make(rdclass) + if dns.rdataclass.is_metaclass(rdclass): + raise NoMetaqueries + self.resolver = resolver + self.qnames_to_try = resolver._get_qnames_to_try(qname, search) + self.qnames = self.qnames_to_try[:] + self.rdtype = rdtype + self.rdclass = rdclass + self.tcp = tcp + self.raise_on_no_answer = raise_on_no_answer + self.nxdomain_responses = {} + # + # Initialize other things to help analysis tools + self.qname = dns.name.empty + self.nameservers = [] + self.current_nameservers = [] + self.errors = [] + self.nameserver = None + self.port = 0 + self.tcp_attempt = False + self.retry_with_tcp = False + self.request = None + self.backoff = 0 - @param filename: The filename of a configuration file in - standard /etc/resolv.conf format. This parameter is meaningful - only when I{configure} is true and the platform is POSIX. - @type filename: string or file object - @param configure: If True (the default), the resolver instance - is configured in the normal fashion for the operating system - the resolver is running on. (I.e. a /etc/resolv.conf file on - POSIX systems and from the registry on Windows systems.) - @type configure: bool""" + def next_request(self): + """Get the next request to send, and check the cache. + + Returns a (request, answer) tuple. At most one of request or + answer will not be None. + """ + + # We return a tuple instead of Union[Message,Answer] as it lets + # the caller avoid isinstance(). + + while len(self.qnames) > 0: + self.qname = self.qnames.pop(0) + + # Do we know the answer? + if self.resolver.cache: + answer = self.resolver.cache.get((self.qname, self.rdtype, + self.rdclass)) + if answer is not None: + if answer.rrset is None and self.raise_on_no_answer: + raise NoAnswer(response=answer.response) + else: + return (None, answer) + answer = self.resolver.cache.get((self.qname, + dns.rdatatype.ANY, + self.rdclass)) + if answer is not None and \ + answer.response.rcode() == dns.rcode.NXDOMAIN: + # cached NXDOMAIN; record it and continue to next + # name. + self.nxdomain_responses[self.qname] = answer.response + continue + + # Build the request + request = dns.message.make_query(self.qname, self.rdtype, + self.rdclass) + if self.resolver.keyname is not None: + request.use_tsig(self.resolver.keyring, self.resolver.keyname, + algorithm=self.resolver.keyalgorithm) + request.use_edns(self.resolver.edns, self.resolver.ednsflags, + self.resolver.payload) + if self.resolver.flags is not None: + request.flags = self.resolver.flags + + self.nameservers = self.resolver.nameservers[:] + if self.resolver.rotate: + random.shuffle(self.nameservers) + self.current_nameservers = self.nameservers[:] + self.errors = [] + self.nameserver = None + self.tcp_attempt = False + self.retry_with_tcp = False + self.request = request + self.backoff = 0.10 + + return (request, None) + + # + # We've tried everything and only gotten NXDOMAINs. (We know + # it's only NXDOMAINs as anything else would have returned + # before now.) + # + raise NXDOMAIN(qnames=self.qnames_to_try, + responses=self.nxdomain_responses) + + def next_nameserver(self): + if self.retry_with_tcp: + assert self.nameserver is not None + self.tcp_attempt = True + self.retry_with_tcp = False + return (self.nameserver, self.port, True, 0) + + backoff = 0 + if not self.current_nameservers: + if len(self.nameservers) == 0: + # Out of things to try! + raise NoNameservers(request=self.request, errors=self.errors) + self.current_nameservers = self.nameservers[:] + backoff = self.backoff + self.backoff = min(self.backoff * 2, 2) + + self.nameserver = self.current_nameservers.pop(0) + self.port = self.resolver.nameserver_ports.get(self.nameserver, + self.resolver.port) + self.tcp_attempt = self.tcp + return (self.nameserver, self.port, self.tcp_attempt, backoff) + + def query_result(self, response, ex): + # + # returns an (answer: Answer, end_loop: bool) tuple. + # + if ex: + # Exception during I/O or from_wire() + assert response is None + self.errors.append((self.nameserver, self.tcp_attempt, self.port, + ex, response)) + if isinstance(ex, dns.exception.FormError) or \ + isinstance(ex, EOFError) or \ + isinstance(ex, OSError) or \ + isinstance(ex, NotImplementedError): + # This nameserver is no good, take it out of the mix. + self.nameservers.remove(self.nameserver) + elif isinstance(ex, dns.message.Truncated): + if self.tcp_attempt: + # Truncation with TCP is no good! + self.nameservers.remove(self.nameserver) + else: + self.retry_with_tcp = True + return (None, False) + # We got an answer! + assert response is not None + rcode = response.rcode() + if rcode == dns.rcode.NOERROR: + answer = Answer(self.qname, self.rdtype, self.rdclass, response, + self.nameserver, self.port) + if self.resolver.cache: + self.resolver.cache.put((self.qname, self.rdtype, + self.rdclass), answer) + if answer.rrset is None and self.raise_on_no_answer: + raise NoAnswer(response=answer.response) + return (answer, True) + elif rcode == dns.rcode.NXDOMAIN: + self.nxdomain_responses[self.qname] = response + # Make next_nameserver() return None, so caller breaks its + # inner loop and calls next_request(). + if self.resolver.cache: + answer = Answer(self.qname, dns.rdatatype.ANY, + dns.rdataclass.IN, response) + self.resolver.cache.put((self.qname, + dns.rdatatype.ANY, + self.rdclass), answer) + + return (None, True) + elif rcode == dns.rcode.YXDOMAIN: + yex = YXDOMAIN() + self.errors.append((self.nameserver, self.tcp_attempt, + self.port, yex, response)) + raise yex + else: + # + # We got a response, but we're not happy with the + # rcode in it. + # + if rcode != dns.rcode.SERVFAIL or not self.resolver.retry_servfail: + self.nameservers.remove(self.nameserver) + self.errors.append((self.nameserver, self.tcp_attempt, self.port, + dns.rcode.to_text(rcode), response)) + return (None, False) + +class Resolver: + """DNS stub resolver.""" + + # We initialize in reset() + # + # pylint: disable=attribute-defined-outside-init + + def __init__(self, filename='/etc/resolv.conf', configure=True): + """*filename*, a ``str`` or file object, specifying a file + in standard /etc/resolv.conf format. This parameter is meaningful + only when *configure* is true and the platform is POSIX. + + *configure*, a ``bool``. If True (the default), the resolver + instance is configured in the normal fashion for the operating + system the resolver is running on. (I.e. by reading a + /etc/resolv.conf file on POSIX systems and from the registry + on Windows systems.) + """ self.reset() if configure: if sys.platform == 'win32': - self.read_registry() + self.read_registry() # pragma: no cover elif filename: self.read_resolv_conf(filename) def reset(self): """Reset all resolver configuration to the defaults.""" + self.domain = \ dns.name.Name(dns.name.from_text(socket.gethostname())[1:]) if len(self.domain) == 0: @@ -561,8 +705,9 @@ class Resolver(object): self.nameserver_ports = {} self.port = 53 self.search = [] + self.use_search_by_default = False self.timeout = 2.0 - self.lifetime = 30.0 + self.lifetime = 5.0 self.keyring = None self.keyname = None self.keyalgorithm = dns.tsig.default_algorithm @@ -573,23 +718,33 @@ class Resolver(object): self.flags = None self.retry_servfail = False self.rotate = False + self.ndots = None def read_resolv_conf(self, f): - """Process f as a file in the /etc/resolv.conf format. If f is - a string, it is used as the name of the file to open; otherwise it - is treated as the file itself.""" - if isinstance(f, string_types): - try: - f = open(f, 'r') - except IOError: - # /etc/resolv.conf doesn't exist, can't be read, etc. - # We'll just use the default resolver configuration. - self.nameservers = ['127.0.0.1'] - return - want_close = True - else: - want_close = False - try: + """Process *f* as a file in the /etc/resolv.conf format. If f is + a ``str``, it is used as the name of the file to open; otherwise it + is treated as the file itself. + + Interprets the following items: + + - nameserver - name server IP address + + - domain - local domain name + + - search - search list for host-name lookup + + - options - supported options are rotate, timeout, edns0, and ndots + + """ + + with contextlib.ExitStack() as stack: + if isinstance(f, str): + try: + f = stack.enter_context(open(f)) + except OSError: + # /etc/resolv.conf doesn't exist, can't be read, etc. + raise NoResolverConfiguration + for l in f: if len(l) == 0 or l[0] == '#' or l[0] == ';': continue @@ -607,13 +762,23 @@ class Resolver(object): for suffix in tokens[1:]: self.search.append(dns.name.from_text(suffix)) elif tokens[0] == 'options': - if 'rotate' in tokens[1:]: - self.rotate = True - finally: - if want_close: - f.close() + for opt in tokens[1:]: + if opt == 'rotate': + self.rotate = True + elif opt == 'edns0': + self.use_edns(0, 0, 0) + elif 'timeout' in opt: + try: + self.timeout = int(opt.split(':')[1]) + except (ValueError, IndexError): + pass + elif 'ndots' in opt: + try: + self.ndots = int(opt.split(':')[1]) + except (ValueError, IndexError): + pass if len(self.nameservers) == 0: - self.nameservers.append('127.0.0.1') + raise NoResolverConfiguration def _determine_split_char(self, entry): # @@ -621,9 +786,9 @@ class Resolver(object): # delimiter in between ' ' and ',' (and vice-versa) in various # versions of windows. # - if entry.find(' ') >= 0: + if entry.find(' ') >= 0: # pragma: no cover split_char = ' ' - elif entry.find(',') >= 0: + elif entry.find(',') >= 0: # pragma: no cover split_char = ',' else: # probably a singleton; treat as a space-separated list. @@ -631,7 +796,6 @@ class Resolver(object): return split_char def _config_win32_nameservers(self, nameservers): - """Configure a NameServer registry entry.""" # we call str() on nameservers to convert it from unicode to ascii nameservers = str(nameservers) split_char = self._determine_split_char(nameservers) @@ -640,13 +804,11 @@ class Resolver(object): if ns not in self.nameservers: self.nameservers.append(ns) - def _config_win32_domain(self, domain): - """Configure a Domain registry entry.""" + def _config_win32_domain(self, domain): # pragma: no cover # we call str() on domain to convert it from unicode to ascii self.domain = dns.name.from_text(str(domain)) - def _config_win32_search(self, search): - """Configure a Search registry entry.""" + def _config_win32_search(self, search): # pragma: no cover # we call str() on search to convert it from unicode to ascii search = str(search) split_char = self._determine_split_char(search) @@ -655,86 +817,78 @@ class Resolver(object): if s not in self.search: self.search.append(dns.name.from_text(s)) - def _config_win32_fromkey(self, key): - """Extract DNS info from a registry key.""" + def _config_win32_fromkey(self, key, always_try_domain): try: - servers, rtype = _winreg.QueryValueEx(key, 'NameServer') - except WindowsError: + servers, rtype = winreg.QueryValueEx(key, 'NameServer') + except WindowsError: # pylint: disable=undefined-variable servers = None if servers: self._config_win32_nameservers(servers) + if servers or always_try_domain: try: - dom, rtype = _winreg.QueryValueEx(key, 'Domain') + dom, rtype = winreg.QueryValueEx(key, 'Domain') if dom: self._config_win32_domain(dom) - except WindowsError: + except WindowsError: # pragma: no cover pass else: try: - servers, rtype = _winreg.QueryValueEx(key, 'DhcpNameServer') - except WindowsError: + servers, rtype = winreg.QueryValueEx(key, 'DhcpNameServer') + except WindowsError: # pragma: no cover servers = None - if servers: + if servers: # pragma: no cover self._config_win32_nameservers(servers) try: - dom, rtype = _winreg.QueryValueEx(key, 'DhcpDomain') - if dom: + dom, rtype = winreg.QueryValueEx(key, 'DhcpDomain') + if dom: # pragma: no cover self._config_win32_domain(dom) - except WindowsError: + except WindowsError: # pragma: no cover pass try: - search, rtype = _winreg.QueryValueEx(key, 'SearchList') - except WindowsError: + search, rtype = winreg.QueryValueEx(key, 'SearchList') + except WindowsError: # pylint: disable=undefined-variable search = None - if search: + if search: # pragma: no cover self._config_win32_search(search) def read_registry(self): """Extract resolver configuration from the Windows registry.""" - lm = _winreg.ConnectRegistry(None, _winreg.HKEY_LOCAL_MACHINE) - want_scan = False + + lm = winreg.ConnectRegistry(None, winreg.HKEY_LOCAL_MACHINE) try: + tcp_params = winreg.OpenKey(lm, + r'SYSTEM\CurrentControlSet' + r'\Services\Tcpip\Parameters') try: - # XP, 2000 - tcp_params = _winreg.OpenKey(lm, - r'SYSTEM\CurrentControlSet' - r'\Services\Tcpip\Parameters') - want_scan = True - except EnvironmentError: - # ME - tcp_params = _winreg.OpenKey(lm, - r'SYSTEM\CurrentControlSet' - r'\Services\VxD\MSTCP') - try: - self._config_win32_fromkey(tcp_params) + self._config_win32_fromkey(tcp_params, True) finally: tcp_params.Close() - if want_scan: - interfaces = _winreg.OpenKey(lm, - r'SYSTEM\CurrentControlSet' - r'\Services\Tcpip\Parameters' - r'\Interfaces') - try: - i = 0 - while True: + interfaces = winreg.OpenKey(lm, + r'SYSTEM\CurrentControlSet' + r'\Services\Tcpip\Parameters' + r'\Interfaces') + try: + i = 0 + while True: + try: + guid = winreg.EnumKey(interfaces, i) + i += 1 + key = winreg.OpenKey(interfaces, guid) + if not self._win32_is_nic_enabled(lm, guid, key): + continue try: - guid = _winreg.EnumKey(interfaces, i) - i += 1 - key = _winreg.OpenKey(interfaces, guid) - if not self._win32_is_nic_enabled(lm, guid, key): - continue - try: - self._config_win32_fromkey(key) - finally: - key.Close() - except EnvironmentError: - break - finally: - interfaces.Close() + self._config_win32_fromkey(key, False) + finally: + key.Close() + except EnvironmentError: # pragma: no cover + break + finally: + interfaces.Close() finally: lm.Close() - def _win32_is_nic_enabled(self, lm, guid, interface_key): + def _win32_is_nic_enabled(self, lm, guid, + interface_key): # Look in the Windows Registry to determine whether the network # interface corresponding to the given guid is enabled. # @@ -743,7 +897,7 @@ class Resolver(object): try: # This hard-coded location seems to be consistent, at least # from Windows 2000 through Vista. - connection_key = _winreg.OpenKey( + connection_key = winreg.OpenKey( lm, r'SYSTEM\CurrentControlSet\Control\Network' r'\{4D36E972-E325-11CE-BFC1-08002BE10318}' @@ -751,45 +905,36 @@ class Resolver(object): try: # The PnpInstanceID points to a key inside Enum - (pnp_id, ttype) = _winreg.QueryValueEx( + (pnp_id, ttype) = winreg.QueryValueEx( connection_key, 'PnpInstanceID') - if ttype != _winreg.REG_SZ: + if ttype != winreg.REG_SZ: # pragma: no cover raise ValueError - device_key = _winreg.OpenKey( + device_key = winreg.OpenKey( lm, r'SYSTEM\CurrentControlSet\Enum\%s' % pnp_id) try: # Get ConfigFlags for this device - (flags, ttype) = _winreg.QueryValueEx( + (flags, ttype) = winreg.QueryValueEx( device_key, 'ConfigFlags') - if ttype != _winreg.REG_DWORD: + if ttype != winreg.REG_DWORD: # pragma: no cover raise ValueError # Based on experimentation, bit 0x1 indicates that the # device is disabled. - return not (flags & 0x1) + return not flags & 0x1 finally: device_key.Close() finally: connection_key.Close() - except (EnvironmentError, ValueError): - # Pre-vista, enabled interfaces seem to have a non-empty - # NTEContextList; this was how dnspython detected enabled - # nics before the code above was contributed. We've retained - # the old method since we don't know if the code above works - # on Windows 95/98/ME. - try: - (nte, ttype) = _winreg.QueryValueEx(interface_key, - 'NTEContextList') - return nte is not None - except WindowsError: - return False + except Exception: # pragma: no cover + return False - def _compute_timeout(self, start): + def _compute_timeout(self, start, lifetime=None): + lifetime = self.lifetime if lifetime is None else lifetime now = time.time() duration = now - start if duration < 0: @@ -801,235 +946,195 @@ class Resolver(object): # happen, e.g. under vmware with older linux kernels. # Pretend it didn't happen. now = start - if duration >= self.lifetime: + if duration >= lifetime: raise Timeout(timeout=duration) - return min(self.lifetime - duration, self.timeout) + return min(lifetime - duration, self.timeout) - def query(self, qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN, - tcp=False, source=None, raise_on_no_answer=True, source_port=0): - """Query nameservers to find the answer to the question. - - The I{qname}, I{rdtype}, and I{rdclass} parameters may be objects - of the appropriate type, or strings that can be converted into objects - of the appropriate type. E.g. For I{rdtype} the integer 2 and the - the string 'NS' both mean to query for records with DNS rdata type NS. - - @param qname: the query name - @type qname: dns.name.Name object or string - @param rdtype: the query type - @type rdtype: int or string - @param rdclass: the query class - @type rdclass: int or string - @param tcp: use TCP to make the query (default is False). - @type tcp: bool - @param source: bind to this IP address (defaults to machine default - IP). - @type source: IP address in dotted quad notation - @param raise_on_no_answer: raise NoAnswer if there's no answer - (defaults is True). - @type raise_on_no_answer: bool - @param source_port: The port from which to send the message. - The default is 0. - @type source_port: int - @rtype: dns.resolver.Answer instance - @raises Timeout: no answers could be found in the specified lifetime - @raises NXDOMAIN: the query name does not exist - @raises YXDOMAIN: the query name is too long after DNAME substitution - @raises NoAnswer: the response did not contain an answer and - raise_on_no_answer is True. - @raises NoNameservers: no non-broken nameservers are available to - answer the question.""" - - if isinstance(qname, string_types): - qname = dns.name.from_text(qname, None) - if isinstance(rdtype, string_types): - rdtype = dns.rdatatype.from_text(rdtype) - if dns.rdatatype.is_metatype(rdtype): - raise NoMetaqueries - if isinstance(rdclass, string_types): - rdclass = dns.rdataclass.from_text(rdclass) - if dns.rdataclass.is_metaclass(rdclass): - raise NoMetaqueries + def _get_qnames_to_try(self, qname, search): + # This is a separate method so we can unit test the search + # rules without requiring the Internet. + if search is None: + search = self.use_search_by_default qnames_to_try = [] if qname.is_absolute(): qnames_to_try.append(qname) else: - if len(qname) > 1: + if len(qname) > 1 or not search: qnames_to_try.append(qname.concatenate(dns.name.root)) - if self.search: + if search and self.search: for suffix in self.search: - qnames_to_try.append(qname.concatenate(suffix)) - else: + if self.ndots is None or len(qname.labels) >= self.ndots: + qnames_to_try.append(qname.concatenate(suffix)) + elif search: qnames_to_try.append(qname.concatenate(self.domain)) - all_nxdomain = True + return qnames_to_try + + def resolve(self, qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN, + tcp=False, source=None, raise_on_no_answer=True, source_port=0, + lifetime=None, search=None): + """Query nameservers to find the answer to the question. + + The *qname*, *rdtype*, and *rdclass* parameters may be objects + of the appropriate type, or strings that can be converted into objects + of the appropriate type. + + *qname*, a ``dns.name.Name`` or ``str``, the query name. + + *rdtype*, an ``int`` or ``str``, the query type. + + *rdclass*, an ``int`` or ``str``, the query class. + + *tcp*, a ``bool``. If ``True``, use TCP to make the query. + + *source*, a ``str`` or ``None``. If not ``None``, bind to this IP + address when making queries. + + *raise_on_no_answer*, a ``bool``. If ``True``, raise + ``dns.resolver.NoAnswer`` if there's no answer to the question. + + *source_port*, an ``int``, the port from which to send the message. + + *lifetime*, a ``float``, how many seconds a query should run + before timing out. + + *search*, a ``bool`` or ``None``, determines whether the + search list configured in the system's resolver configuration + are used for relative names, and whether the resolver's domain + may be added to relative names. The default is ``None``, + which causes the value of the resolver's + ``use_search_by_default`` attribute to be used. + + Raises ``dns.exception.Timeout`` if no answers could be found + in the specified lifetime. + + Raises ``dns.resolver.NXDOMAIN`` if the query name does not exist. + + Raises ``dns.resolver.YXDOMAIN`` if the query name is too long after + DNAME substitution. + + Raises ``dns.resolver.NoAnswer`` if *raise_on_no_answer* is + ``True`` and the query name exists but has no RRset of the + desired type and class. + + Raises ``dns.resolver.NoNameservers`` if no non-broken + nameservers are available to answer the question. + + Returns a ``dns.resolver.Answer`` instance. + + """ + + resolution = _Resolution(self, qname, rdtype, rdclass, tcp, + raise_on_no_answer, search) start = time.time() - for qname in qnames_to_try: - if self.cache: - answer = self.cache.get((qname, rdtype, rdclass)) - if answer is not None: - if answer.rrset is None and raise_on_no_answer: - raise NoAnswer - else: - return answer - request = dns.message.make_query(qname, rdtype, rdclass) - if self.keyname is not None: - request.use_tsig(self.keyring, self.keyname, - algorithm=self.keyalgorithm) - request.use_edns(self.edns, self.ednsflags, self.payload) - if self.flags is not None: - request.flags = self.flags - response = None - # - # make a copy of the servers list so we can alter it later. - # - nameservers = self.nameservers[:] - errors = [] - if self.rotate: - random.shuffle(nameservers) - backoff = 0.10 - while response is None: - if len(nameservers) == 0: - raise NoNameservers(request=request, errors=errors) - for nameserver in nameservers[:]: - timeout = self._compute_timeout(start) - port = self.nameserver_ports.get(nameserver, self.port) - try: - tcp_attempt = tcp + while True: + (request, answer) = resolution.next_request() + # Note we need to say "if answer is not None" and not just + # "if answer" because answer implements __len__, and python + # will call that. We want to return if we have an answer + # object, including in cases where its length is 0. + if answer is not None: + # cache hit! + return answer + done = False + while not done: + (nameserver, port, tcp, backoff) = resolution.next_nameserver() + if backoff: + time.sleep(backoff) + timeout = self._compute_timeout(start, lifetime) + try: + if dns.inet.is_address(nameserver): if tcp: response = dns.query.tcp(request, nameserver, - timeout, port, + timeout=timeout, + port=port, source=source, source_port=source_port) else: - response = dns.query.udp(request, nameserver, - timeout, port, + response = dns.query.udp(request, + nameserver, + timeout=timeout, + port=port, source=source, - source_port=source_port) - if response.flags & dns.flags.TC: - # Response truncated; retry with TCP. - tcp_attempt = True - timeout = self._compute_timeout(start) - response = \ - dns.query.tcp(request, nameserver, - timeout, port, - source=source, - source_port=source_port) - except (socket.error, dns.exception.Timeout) as ex: - # - # Communication failure or timeout. Go to the - # next server - # - errors.append((nameserver, tcp_attempt, port, ex, - response)) - response = None - continue - except dns.query.UnexpectedSource as ex: - # - # Who knows? Keep going. - # - errors.append((nameserver, tcp_attempt, port, ex, - response)) - response = None - continue - except dns.exception.FormError as ex: - # - # We don't understand what this server is - # saying. Take it out of the mix and - # continue. - # - nameservers.remove(nameserver) - errors.append((nameserver, tcp_attempt, port, ex, - response)) - response = None - continue - except EOFError as ex: - # - # We're using TCP and they hung up on us. - # Probably they don't support TCP (though - # they're supposed to!). Take it out of the - # mix and continue. - # - nameservers.remove(nameserver) - errors.append((nameserver, tcp_attempt, port, ex, - response)) - response = None - continue - rcode = response.rcode() - if rcode == dns.rcode.YXDOMAIN: - ex = YXDOMAIN() - errors.append((nameserver, tcp_attempt, port, ex, - response)) - raise ex - if rcode == dns.rcode.NOERROR or \ - rcode == dns.rcode.NXDOMAIN: - break - # - # We got a response, but we're not happy with the - # rcode in it. Remove the server from the mix if - # the rcode isn't SERVFAIL. - # - if rcode != dns.rcode.SERVFAIL or not self.retry_servfail: - nameservers.remove(nameserver) - errors.append((nameserver, tcp_attempt, port, - dns.rcode.to_text(rcode), response)) - response = None - if response is not None: - break - # - # All nameservers failed! - # - if len(nameservers) > 0: - # - # But we still have servers to try. Sleep a bit - # so we don't pound them! - # - timeout = self._compute_timeout(start) - sleep_time = min(timeout, backoff) - backoff *= 2 - time.sleep(sleep_time) - if response.rcode() == dns.rcode.NXDOMAIN: - continue - all_nxdomain = False - break - if all_nxdomain: - raise NXDOMAIN(qname=qnames_to_try) - answer = Answer(qname, rdtype, rdclass, response, - raise_on_no_answer) - if self.cache: - self.cache.put((qname, rdtype, rdclass), answer) - return answer + source_port=source_port, + raise_on_truncation=True) + else: + protocol = urlparse(nameserver).scheme + if protocol != 'https': + raise NotImplementedError + response = dns.query.https(request, nameserver, + timeout=timeout) + except Exception as ex: + (_, done) = resolution.query_result(None, ex) + continue + (answer, done) = resolution.query_result(response, None) + # Note we need to say "if answer is not None" and not just + # "if answer" because answer implements __len__, and python + # will call that. We want to return if we have an answer + # object, including in cases where its length is 0. + if answer is not None: + return answer + + def query(self, qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN, + tcp=False, source=None, raise_on_no_answer=True, source_port=0, + lifetime=None): # pragma: no cover + """Query nameservers to find the answer to the question. + + This method calls resolve() with ``search=True``, and is + provided for backwards compatbility with prior versions of + dnspython. See the documentation for the resolve() method for + further details. + """ + warnings.warn('please use dns.resolver.Resolver.resolve() instead', + DeprecationWarning, stacklevel=2) + return self.resolve(qname, rdtype, rdclass, tcp, source, + raise_on_no_answer, source_port, lifetime, + True) + + def resolve_address(self, ipaddr, *args, **kwargs): + """Use a resolver to run a reverse query for PTR records. + + This utilizes the resolve() method to perform a PTR lookup on the + specified IP address. + + *ipaddr*, a ``str``, the IPv4 or IPv6 address you want to get + the PTR record for. + + All other arguments that can be passed to the resolve() function + except for rdtype and rdclass are also supported by this + function. + """ + + return self.resolve(dns.reversename.from_address(ipaddr), + rdtype=dns.rdatatype.PTR, + rdclass=dns.rdataclass.IN, + *args, **kwargs) def use_tsig(self, keyring, keyname=None, algorithm=dns.tsig.default_algorithm): - """Add a TSIG signature to the query. + """Add a TSIG signature to each query. + + The parameters are passed to ``dns.message.Message.use_tsig()``; + see its documentation for details. + """ - @param keyring: The TSIG keyring to use; defaults to None. - @type keyring: dict - @param keyname: The name of the TSIG key to use; defaults to None. - The key must be defined in the keyring. If a keyring is specified - but a keyname is not, then the key used will be the first key in the - keyring. Note that the order of keys in a dictionary is not defined, - so applications should supply a keyname when a keyring is used, unless - they know the keyring contains only one key. - @param algorithm: The TSIG key algorithm to use. The default - is dns.tsig.default_algorithm. - @type algorithm: string""" self.keyring = keyring - if keyname is None: - self.keyname = list(self.keyring.keys())[0] - else: - self.keyname = keyname + self.keyname = keyname self.keyalgorithm = algorithm def use_edns(self, edns, ednsflags, payload): - """Configure Edns. + """Configure EDNS behavior. - @param edns: The EDNS level to use. The default is -1, no Edns. - @type edns: int - @param ednsflags: The EDNS flags - @type ednsflags: int - @param payload: The EDNS payload size. The default is 0. - @type payload: int""" + *edns*, an ``int``, is the EDNS level to use. Specifying + ``None``, ``False``, or ``-1`` means "do not use EDNS", and in this case + the other parameters are ignored. Specifying ``True`` is + equivalent to specifying 0, i.e. "use EDNS0". + + *ednsflags*, an ``int``, the EDNS flag values. + + *payload*, an ``int``, is the EDNS sender's payload field, which is the + maximum size of UDP datagram the sender can handle. I.e. how big + a response to this message can be. + """ if edns is None: edns = -1 @@ -1038,50 +1143,116 @@ class Resolver(object): self.payload = payload def set_flags(self, flags): - """Overrides the default flags with your own + """Overrides the default flags with your own. + + *flags*, an ``int``, the message flags to use. + """ - @param flags: The flags to overwrite the default with - @type flags: int""" self.flags = flags + @property + def nameservers(self): + return self._nameservers + + @nameservers.setter + def nameservers(self, nameservers): + """ + *nameservers*, a ``list`` of nameservers. + + Raises ``ValueError`` if *nameservers* is anything other than a + ``list``. + """ + if isinstance(nameservers, list): + self._nameservers = nameservers + else: + raise ValueError('nameservers must be a list' + ' (not a {})'.format(type(nameservers))) + +#: The default resolver. default_resolver = None def get_default_resolver(): """Get the default resolver, initializing it if necessary.""" - global default_resolver if default_resolver is None: - default_resolver = Resolver() + reset_default_resolver() return default_resolver -def query(qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN, - tcp=False, source=None, raise_on_no_answer=True, - source_port=0): +def reset_default_resolver(): + """Re-initialize default resolver. + + Note that the resolver configuration (i.e. /etc/resolv.conf on UNIX + systems) will be re-read immediately. + """ + + global default_resolver + default_resolver = Resolver() + + +def resolve(qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN, + tcp=False, source=None, raise_on_no_answer=True, + source_port=0, lifetime=None, search=None): """Query nameservers to find the answer to the question. This is a convenience function that uses the default resolver object to make the query. - @see: L{dns.resolver.Resolver.query} for more information on the - parameters.""" - return get_default_resolver().query(qname, rdtype, rdclass, tcp, source, - raise_on_no_answer, source_port) + + See ``dns.resolver.Resolver.resolve`` for more information on the + parameters. + """ + + return get_default_resolver().resolve(qname, rdtype, rdclass, tcp, source, + raise_on_no_answer, source_port, + lifetime, search) + +def query(qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN, + tcp=False, source=None, raise_on_no_answer=True, + source_port=0, lifetime=None): # pragma: no cover + """Query nameservers to find the answer to the question. + + This method calls resolve() with ``search=True``, and is + provided for backwards compatbility with prior versions of + dnspython. See the documentation for the resolve() method for + further details. + """ + warnings.warn('please use dns.resolver.resolve() instead', + DeprecationWarning, stacklevel=2) + return resolve(qname, rdtype, rdclass, tcp, source, + raise_on_no_answer, source_port, lifetime, + True) + + +def resolve_address(ipaddr, *args, **kwargs): + """Use a resolver to run a reverse query for PTR records. + + See ``dns.resolver.Resolver.resolve_address`` for more information on the + parameters. + """ + + return get_default_resolver().resolve_address(ipaddr, *args, **kwargs) def zone_for_name(name, rdclass=dns.rdataclass.IN, tcp=False, resolver=None): """Find the name of the zone which contains the specified name. - @param name: the query name - @type name: absolute dns.name.Name object or string - @param rdclass: The query class - @type rdclass: int - @param tcp: use TCP to make the query (default is False). - @type tcp: bool - @param resolver: the resolver to use - @type resolver: dns.resolver.Resolver object or None - @rtype: dns.name.Name""" + *name*, an absolute ``dns.name.Name`` or ``str``, the query name. - if isinstance(name, string_types): + *rdclass*, an ``int``, the query class. + + *tcp*, a ``bool``. If ``True``, use TCP to make the query. + + *resolver*, a ``dns.resolver.Resolver`` or ``None``, the resolver to use. + If ``None``, the default resolver is used. + + Raises ``dns.resolver.NoRootSOA`` if there is no SOA RR at the DNS + root. (This is only likely to happen if you're using non-default + root servers in your network and they are misconfigured.) + + Returns a ``dns.name.Name``. + """ + + if isinstance(name, str): name = dns.name.from_text(name, dns.name.root) if resolver is None: resolver = get_default_resolver() @@ -1089,7 +1260,7 @@ def zone_for_name(name, rdclass=dns.rdataclass.IN, tcp=False, resolver=None): raise NotAbsolute(name) while 1: try: - answer = resolver.query(name, dns.rdatatype.SOA, rdclass, tcp) + answer = resolver.resolve(name, dns.rdatatype.SOA, rdclass, tcp) if answer.rrset.name == name: return name # otherwise we were CNAMEd or DNAMEd and need to look higher @@ -1121,63 +1292,73 @@ _original_gethostbyaddr = socket.gethostbyaddr def _getaddrinfo(host=None, service=None, family=socket.AF_UNSPEC, socktype=0, proto=0, flags=0): + if flags & socket.AI_NUMERICHOST != 0: + # Short circuit directly into the system's getaddrinfo(). We're + # not adding any value in this case, and this avoids infinite loops + # because dns.query.* needs to call getaddrinfo() for IPv6 scoping + # reasons. We will also do this short circuit below if we + # discover that the host is an address literal. + return _original_getaddrinfo(host, service, family, socktype, proto, + flags) if flags & (socket.AI_ADDRCONFIG | socket.AI_V4MAPPED) != 0: - raise NotImplementedError + # Not implemented. We raise a gaierror as opposed to a + # NotImplementedError as it helps callers handle errors more + # appropriately. [Issue #316] + # + # We raise EAI_FAIL as opposed to EAI_SYSTEM because there is + # no EAI_SYSTEM on Windows [Issue #416]. We didn't go for + # EAI_BADFLAGS as the flags aren't bad, we just don't + # implement them. + raise socket.gaierror(socket.EAI_FAIL, + 'Non-recoverable failure in name resolution') if host is None and service is None: - raise socket.gaierror(socket.EAI_NONAME) + raise socket.gaierror(socket.EAI_NONAME, 'Name or service not known') v6addrs = [] v4addrs = [] canonical_name = None + # Is host None or an address literal? If so, use the system's + # getaddrinfo(). + if host is None: + return _original_getaddrinfo(host, service, family, socktype, + proto, flags) try: - # Is host None or a V6 address literal? - if host is None: - canonical_name = 'localhost' - if flags & socket.AI_PASSIVE != 0: - v6addrs.append('::') - v4addrs.append('0.0.0.0') - else: - v6addrs.append('::1') - v4addrs.append('127.0.0.1') - else: - parts = host.split('%') - if len(parts) == 2: - ahost = parts[0] - else: - ahost = host - addr = dns.ipv6.inet_aton(ahost) - v6addrs.append(host) - canonical_name = host - except: - try: - # Is it a V4 address literal? - addr = dns.ipv4.inet_aton(host) - v4addrs.append(host) - canonical_name = host - except: - if flags & socket.AI_NUMERICHOST == 0: - try: - if family == socket.AF_INET6 or family == socket.AF_UNSPEC: - v6 = _resolver.query(host, dns.rdatatype.AAAA, - raise_on_no_answer=False) - # Note that setting host ensures we query the same name - # for A as we did for AAAA. - host = v6.qname - canonical_name = v6.canonical_name.to_text(True) - if v6.rrset is not None: - for rdata in v6.rrset: - v6addrs.append(rdata.address) - if family == socket.AF_INET or family == socket.AF_UNSPEC: - v4 = _resolver.query(host, dns.rdatatype.A, - raise_on_no_answer=False) - host = v4.qname - canonical_name = v4.canonical_name.to_text(True) - if v4.rrset is not None: - for rdata in v4.rrset: - v4addrs.append(rdata.address) - except dns.resolver.NXDOMAIN: - raise socket.gaierror(socket.EAI_NONAME) - except: - raise socket.gaierror(socket.EAI_SYSTEM) + # We don't care about the result of af_for_address(), we're just + # calling it so it raises an exception if host is not an IPv4 or + # IPv6 address. + dns.inet.af_for_address(host) + return _original_getaddrinfo(host, service, family, socktype, + proto, flags) + except Exception: + pass + # Something needs resolution! + try: + if family == socket.AF_INET6 or family == socket.AF_UNSPEC: + v6 = _resolver.resolve(host, dns.rdatatype.AAAA, + raise_on_no_answer=False) + # Note that setting host ensures we query the same name + # for A as we did for AAAA. + host = v6.qname + canonical_name = v6.canonical_name.to_text(True) + if v6.rrset is not None: + for rdata in v6.rrset: + v6addrs.append(rdata.address) + if family == socket.AF_INET or family == socket.AF_UNSPEC: + v4 = _resolver.resolve(host, dns.rdatatype.A, + raise_on_no_answer=False) + host = v4.qname + canonical_name = v4.canonical_name.to_text(True) + if v4.rrset is not None: + for rdata in v4.rrset: + v4addrs.append(rdata.address) + except dns.resolver.NXDOMAIN: + raise socket.gaierror(socket.EAI_NONAME, 'Name or service not known') + except Exception as e: + print(e) + # We raise EAI_AGAIN here as the failure may be temporary + # (e.g. a timeout) and EAI_SYSTEM isn't defined on Windows. + # [Issue #416] + raise socket.gaierror(socket.EAI_AGAIN, + 'Temporary failure in name resolution') port = None try: # Is it a port literal? @@ -1185,14 +1366,14 @@ def _getaddrinfo(host=None, service=None, family=socket.AF_UNSPEC, socktype=0, port = 0 else: port = int(service) - except: + except Exception: if flags & socket.AI_NUMERICSERV == 0: try: port = socket.getservbyname(service) - except: + except Exception: pass if port is None: - raise socket.gaierror(socket.EAI_NONAME) + raise socket.gaierror(socket.EAI_NONAME, 'Name or service not known') tuples = [] if socktype == 0: socktypes = [socket.SOCK_DGRAM, socket.SOCK_STREAM] @@ -1215,7 +1396,7 @@ def _getaddrinfo(host=None, service=None, family=socket.AF_UNSPEC, socktype=0, tuples.append((socket.AF_INET, socktype, proto, cname, (addr, port))) if len(tuples) == 0: - raise socket.gaierror(socket.EAI_NONAME) + raise socket.gaierror(socket.EAI_NONAME, 'Name or service not known') return tuples @@ -1240,11 +1421,12 @@ def _getnameinfo(sockaddr, flags=0): qname = dns.reversename.from_address(addr) if flags & socket.NI_NUMERICHOST == 0: try: - answer = _resolver.query(qname, 'PTR') + answer = _resolver.resolve(qname, 'PTR') hostname = answer.rrset[0].target.to_text(True) except (dns.resolver.NXDOMAIN, dns.resolver.NoAnswer): if flags & socket.NI_NAMEREQD: - raise socket.gaierror(socket.EAI_NONAME) + raise socket.gaierror(socket.EAI_NONAME, + 'Name or service not known') hostname = addr if scope is not None: hostname += '%' + str(scope) @@ -1263,9 +1445,12 @@ def _getfqdn(name=None): if name is None: name = socket.gethostname() try: - return _getnameinfo(_getaddrinfo(name, 80)[0][4])[0] - except: - return name + (name, _, _) = _gethostbyaddr(name) + # Python's version checks aliases too, but our gethostbyname + # ignores them, so we do so here as well. + except Exception: + pass + return name def _gethostbyname(name): @@ -1289,7 +1474,12 @@ def _gethostbyaddr(ip): dns.ipv6.inet_aton(ip) sockaddr = (ip, 80, 0, 0) family = socket.AF_INET6 - except: + except Exception: + try: + dns.ipv4.inet_aton(ip) + except Exception: + raise socket.gaierror(socket.EAI_NONAME, + 'Name or service not known') sockaddr = (ip, 80) family = socket.AF_INET (name, port) = _getnameinfo(sockaddr, socket.NI_NAMEREQD) @@ -1298,8 +1488,15 @@ def _gethostbyaddr(ip): tuples = _getaddrinfo(name, 0, family, socket.SOCK_STREAM, socket.SOL_TCP, socket.AI_CANONNAME) canonical = tuples[0][3] + # We only want to include an address from the tuples if it's the + # same as the one we asked about. We do this comparison in binary + # to avoid any differences in text representations. + bin_ip = dns.inet.inet_pton(family, ip) for item in tuples: - addresses.append(item[4][0]) + addr = item[4][0] + bin_addr = dns.inet.inet_pton(family, addr) + if bin_ip == bin_addr: + addresses.append(addr) # XXX we just ignore aliases return (canonical, aliases, addresses) @@ -1315,9 +1512,9 @@ def override_system_resolver(resolver=None): The resolver to use may be specified; if it's not, the default resolver will be used. - @param resolver: the resolver to use - @type resolver: dns.resolver.Resolver object or None + resolver, a ``dns.resolver.Resolver`` or ``None``, the resolver to use. """ + if resolver is None: resolver = get_default_resolver() global _resolver @@ -1331,8 +1528,8 @@ def override_system_resolver(resolver=None): def restore_system_resolver(): - """Undo the effects of override_system_resolver(). - """ + """Undo the effects of prior override_system_resolver().""" + global _resolver _resolver = None socket.getaddrinfo = _original_getaddrinfo diff --git a/lib/dns/reversename.py b/lib/dns/reversename.py index a27e7050..e0beb03d 100644 --- a/lib/dns/reversename.py +++ b/lib/dns/reversename.py @@ -1,4 +1,6 @@ -# Copyright (C) 2006, 2007, 2009-2011 Nominum, Inc. +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2006-2017 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its # documentation for any purpose with or without fee is hereby granted, @@ -13,16 +15,9 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. -"""DNS Reverse Map Names. - -@var ipv4_reverse_domain: The DNS IPv4 reverse-map domain, in-addr.arpa. -@type ipv4_reverse_domain: dns.name.Name object -@var ipv6_reverse_domain: The DNS IPv6 reverse-map domain, ip6.arpa. -@type ipv6_reverse_domain: dns.name.Name object -""" +"""DNS Reverse Map Names.""" import binascii -import sys import dns.name import dns.ipv6 @@ -32,58 +27,74 @@ ipv4_reverse_domain = dns.name.from_text('in-addr.arpa.') ipv6_reverse_domain = dns.name.from_text('ip6.arpa.') -def from_address(text): +def from_address(text, v4_origin=ipv4_reverse_domain, + v6_origin=ipv6_reverse_domain): """Convert an IPv4 or IPv6 address in textual form into a Name object whose value is the reverse-map domain name of the address. - @param text: an IPv4 or IPv6 address in textual form (e.g. '127.0.0.1', - '::1') - @type text: str - @rtype: dns.name.Name object + + *text*, a ``str``, is an IPv4 or IPv6 address in textual form + (e.g. '127.0.0.1', '::1') + + *v4_origin*, a ``dns.name.Name`` to append to the labels corresponding to + the address if the address is an IPv4 address, instead of the default + (in-addr.arpa.) + + *v6_origin*, a ``dns.name.Name`` to append to the labels corresponding to + the address if the address is an IPv6 address, instead of the default + (ip6.arpa.) + + Raises ``dns.exception.SyntaxError`` if the address is badly formed. + + Returns a ``dns.name.Name``. """ + try: v6 = dns.ipv6.inet_aton(text) if dns.ipv6.is_mapped(v6): - if sys.version_info >= (3,): - parts = ['%d' % byte for byte in v6[12:]] - else: - parts = ['%d' % ord(byte) for byte in v6[12:]] - origin = ipv4_reverse_domain + parts = ['%d' % byte for byte in v6[12:]] + origin = v4_origin else: parts = [x for x in str(binascii.hexlify(v6).decode())] - origin = ipv6_reverse_domain - except: + origin = v6_origin + except Exception: parts = ['%d' % - byte for byte in bytearray(dns.ipv4.inet_aton(text))] - origin = ipv4_reverse_domain - parts.reverse() - return dns.name.from_text('.'.join(parts), origin=origin) + byte for byte in dns.ipv4.inet_aton(text)] + origin = v4_origin + return dns.name.from_text('.'.join(reversed(parts)), origin=origin) -def to_address(name): +def to_address(name, v4_origin=ipv4_reverse_domain, + v6_origin=ipv6_reverse_domain): """Convert a reverse map domain name into textual address form. - @param name: an IPv4 or IPv6 address in reverse-map form. - @type name: dns.name.Name object - @rtype: str + + *name*, a ``dns.name.Name``, an IPv4 or IPv6 address in reverse-map name + form. + + *v4_origin*, a ``dns.name.Name`` representing the top-level domain for + IPv4 addresses, instead of the default (in-addr.arpa.) + + *v6_origin*, a ``dns.name.Name`` representing the top-level domain for + IPv4 addresses, instead of the default (ip6.arpa.) + + Raises ``dns.exception.SyntaxError`` if the name does not have a + reverse-map form. + + Returns a ``str``. """ - if name.is_subdomain(ipv4_reverse_domain): - name = name.relativize(ipv4_reverse_domain) - labels = list(name.labels) - labels.reverse() - text = b'.'.join(labels) - # run through inet_aton() to check syntax and make pretty. + + if name.is_subdomain(v4_origin): + name = name.relativize(v4_origin) + text = b'.'.join(reversed(name.labels)) + # run through inet_ntoa() to check syntax and make pretty. return dns.ipv4.inet_ntoa(dns.ipv4.inet_aton(text)) - elif name.is_subdomain(ipv6_reverse_domain): - name = name.relativize(ipv6_reverse_domain) - labels = list(name.labels) - labels.reverse() + elif name.is_subdomain(v6_origin): + name = name.relativize(v6_origin) + labels = list(reversed(name.labels)) parts = [] - i = 0 - l = len(labels) - while i < l: + for i in range(0, len(labels), 4): parts.append(b''.join(labels[i:i + 4])) - i += 4 text = b':'.join(parts) - # run through inet_aton() to check syntax and make pretty. + # run through inet_ntoa() to check syntax and make pretty. return dns.ipv6.inet_ntoa(dns.ipv6.inet_aton(text)) else: raise dns.exception.SyntaxError('unknown reverse-map address family') diff --git a/lib/dns/rrset.py b/lib/dns/rrset.py index 6ad71da8..68136f40 100644 --- a/lib/dns/rrset.py +++ b/lib/dns/rrset.py @@ -1,4 +1,6 @@ -# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2017 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its # documentation for any purpose with or without fee is hereby granted, @@ -20,7 +22,6 @@ import dns.name import dns.rdataset import dns.rdataclass import dns.renderer -from ._compat import string_types class RRset(dns.rdataset.Rdataset): @@ -40,12 +41,12 @@ class RRset(dns.rdataset.Rdataset): deleting=None): """Create a new RRset.""" - super(RRset, self).__init__(rdclass, rdtype, covers) + super().__init__(rdclass, rdtype, covers) self.name = name self.deleting = deleting def _clone(self): - obj = super(RRset, self)._clone() + obj = super()._clone() obj.name = self.name obj.deleting = self.deleting return obj @@ -61,27 +62,25 @@ class RRset(dns.rdataset.Rdataset): dtext = '' return '' + dns.rdatatype.to_text(self.rdtype) + ctext + dtext + \ + ' RRset: ' + self._rdata_repr() + '>' def __str__(self): return self.to_text() def __eq__(self, other): - """Two RRsets are equal if they have the same name and the same - rdataset - - @rtype: bool""" if not isinstance(other, RRset): return False if self.name != other.name: return False - return super(RRset, self).__eq__(other) + return super().__eq__(other) def match(self, name, rdclass, rdtype, covers, deleting=None): - """Returns True if this rrset matches the specified class, type, - covers, and deletion state.""" + """Returns ``True`` if this rrset matches the specified class, type, + covers, and deletion state. + """ - if not super(RRset, self).match(rdclass, rdtype, covers): + if not super().match(rdclass, rdtype, covers): return False if self.name != name or self.deleting != deleting: return False @@ -90,52 +89,63 @@ class RRset(dns.rdataset.Rdataset): def to_text(self, origin=None, relativize=True, **kw): """Convert the RRset into DNS master file format. - @see: L{dns.name.Name.choose_relativity} for more information - on how I{origin} and I{relativize} determine the way names + See ``dns.name.Name.choose_relativity`` for more information + on how *origin* and *relativize* determine the way names are emitted. Any additional keyword arguments are passed on to the rdata - to_text() method. + ``to_text()`` method. - @param origin: The origin for relative names, or None. - @type origin: dns.name.Name object - @param relativize: True if names should names be relativized - @type relativize: bool""" + *origin*, a ``dns.name.Name`` or ``None``, the origin for relative + names. - return super(RRset, self).to_text(self.name, origin, relativize, - self.deleting, **kw) + *relativize*, a ``bool``. If ``True``, names will be relativized + to *origin*. + """ + + return super().to_text(self.name, origin, relativize, + self.deleting, **kw) def to_wire(self, file, compress=None, origin=None, **kw): - """Convert the RRset to wire format.""" + """Convert the RRset to wire format. - return super(RRset, self).to_wire(self.name, file, compress, origin, - self.deleting, **kw) + All keyword arguments are passed to ``dns.rdataset.to_wire()``; see + that function for details. + + Returns an ``int``, the number of records emitted. + """ + + return super().to_wire(self.name, file, compress, origin, + self.deleting, **kw) def to_rdataset(self): """Convert an RRset into an Rdataset. - @rtype: dns.rdataset.Rdataset object + Returns a ``dns.rdataset.Rdataset``. """ return dns.rdataset.from_rdata_list(self.ttl, list(self)) -def from_text_list(name, ttl, rdclass, rdtype, text_rdatas): +def from_text_list(name, ttl, rdclass, rdtype, text_rdatas, + idna_codec=None): """Create an RRset with the specified name, TTL, class, and type, and with the specified list of rdatas in text format. - @rtype: dns.rrset.RRset object + *idna_codec*, a ``dns.name.IDNACodec``, specifies the IDNA + encoder/decoder to use; if ``None``, the default IDNA 2003 + encoder/decoder is used. + + Returns a ``dns.rrset.RRset`` object. """ - if isinstance(name, string_types): - name = dns.name.from_text(name, None) - if isinstance(rdclass, string_types): - rdclass = dns.rdataclass.from_text(rdclass) - if isinstance(rdtype, string_types): - rdtype = dns.rdatatype.from_text(rdtype) + if isinstance(name, str): + name = dns.name.from_text(name, None, idna_codec=idna_codec) + rdclass = dns.rdataclass.RdataClass.make(rdclass) + rdtype = dns.rdatatype.RdataType.make(rdtype) r = RRset(name, rdclass, rdtype) r.update_ttl(ttl) for t in text_rdatas: - rd = dns.rdata.from_text(r.rdclass, r.rdtype, t) + rd = dns.rdata.from_text(r.rdclass, r.rdtype, t, idna_codec=idna_codec) r.add(rd) return r @@ -144,21 +154,26 @@ def from_text(name, ttl, rdclass, rdtype, *text_rdatas): """Create an RRset with the specified name, TTL, class, and type and with the specified rdatas in text format. - @rtype: dns.rrset.RRset object + Returns a ``dns.rrset.RRset`` object. """ return from_text_list(name, ttl, rdclass, rdtype, text_rdatas) -def from_rdata_list(name, ttl, rdatas): +def from_rdata_list(name, ttl, rdatas, idna_codec=None): """Create an RRset with the specified name and TTL, and with the specified list of rdata objects. - @rtype: dns.rrset.RRset object + *idna_codec*, a ``dns.name.IDNACodec``, specifies the IDNA + encoder/decoder to use; if ``None``, the default IDNA 2003 + encoder/decoder is used. + + Returns a ``dns.rrset.RRset`` object. + """ - if isinstance(name, string_types): - name = dns.name.from_text(name, None) + if isinstance(name, str): + name = dns.name.from_text(name, None, idna_codec=idna_codec) if len(rdatas) == 0: raise ValueError("rdata list must not be empty") @@ -175,7 +190,7 @@ def from_rdata(name, ttl, *rdatas): """Create an RRset with the specified name and TTL, and with the specified rdata objects. - @rtype: dns.rrset.RRset object + Returns a ``dns.rrset.RRset`` object. """ return from_rdata_list(name, ttl, rdatas) diff --git a/lib/dns/serial.py b/lib/dns/serial.py new file mode 100644 index 00000000..b0474151 --- /dev/null +++ b/lib/dns/serial.py @@ -0,0 +1,117 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +"""Serial Number Arthimetic from RFC 1982""" + +class Serial: + def __init__(self, value, bits=32): + self.value = value % 2 ** bits + self.bits = bits + + def __repr__(self): + return f'dns.serial.Serial({self.value}, {self.bits})' + + def __eq__(self, other): + if isinstance(other, int): + other = Serial(other, self.bits) + elif not isinstance(other, Serial) or other.bits != self.bits: + return NotImplemented + return self.value == other.value + + def __ne__(self, other): + if isinstance(other, int): + other = Serial(other, self.bits) + elif not isinstance(other, Serial) or other.bits != self.bits: + return NotImplemented + return self.value != other.value + + def __lt__(self, other): + if isinstance(other, int): + other = Serial(other, self.bits) + elif not isinstance(other, Serial) or other.bits != self.bits: + return NotImplemented + if self.value < other.value and \ + other.value - self.value < 2 ** (self.bits - 1): + return True + elif self.value > other.value and \ + self.value - other.value > 2 ** (self.bits - 1): + return True + else: + return False + + def __le__(self, other): + return self == other or self < other + + def __gt__(self, other): + if isinstance(other, int): + other = Serial(other, self.bits) + elif not isinstance(other, Serial) or other.bits != self.bits: + return NotImplemented + if self.value < other.value and \ + other.value - self.value > 2 ** (self.bits - 1): + return True + elif self.value > other.value and \ + self.value - other.value < 2 ** (self.bits - 1): + return True + else: + return False + + def __ge__(self, other): + return self == other or self > other + + def __add__(self, other): + v = self.value + if isinstance(other, Serial): + delta = other.value + elif isinstance(other, int): + delta = other + else: + raise ValueError + if abs(delta) > (2 ** (self.bits - 1) - 1): + raise ValueError + v += delta + v = v % 2 ** self.bits + return Serial(v, self.bits) + + def __iadd__(self, other): + v = self.value + if isinstance(other, Serial): + delta = other.value + elif isinstance(other, int): + delta = other + else: + raise ValueError + if abs(delta) > (2 ** (self.bits - 1) - 1): + raise ValueError + v += delta + v = v % 2 ** self.bits + self.value = v + return self + + def __sub__(self, other): + v = self.value + if isinstance(other, Serial): + delta = other.value + elif isinstance(other, int): + delta = other + else: + raise ValueError + if abs(delta) > (2 ** (self.bits - 1) - 1): + raise ValueError + v -= delta + v = v % 2 ** self.bits + return Serial(v, self.bits) + + def __isub__(self, other): + v = self.value + if isinstance(other, Serial): + delta = other.value + elif isinstance(other, int): + delta = other + else: + raise ValueError + if abs(delta) > (2 ** (self.bits - 1) - 1): + raise ValueError + v -= delta + v = v % 2 ** self.bits + self.value = v + return self diff --git a/lib/dns/set.py b/lib/dns/set.py index f13af5f3..0982d787 100644 --- a/lib/dns/set.py +++ b/lib/dns/set.py @@ -1,4 +1,6 @@ -# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2017 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its # documentation for any purpose with or without fee is hereby granted, @@ -13,52 +15,61 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. -"""A simple Set class.""" +import itertools +import sys +if sys.version_info >= (3, 7): + odict = dict +else: + from collections import OrderedDict as odict # pragma: no cover -class Set(object): +class Set: """A simple set class. - Sets are not in Python until 2.3, and rdata are not immutable so - we cannot use sets.Set anyway. This class implements subset of - the 2.3 Set interface using a list as the container. - - @ivar items: A list of the items which are in the set - @type items: list""" + This class was originally used to deal with sets being missing in + ancient versions of python, but dnspython will continue to use it + as these sets are based on lists and are thus indexable, and this + ability is widely used in dnspython applications. + """ __slots__ = ['items'] def __init__(self, items=None): """Initialize the set. - @param items: the initial set of items - @type items: any iterable or None + *items*, an iterable or ``None``, the initial set of items. """ - self.items = [] + self.items = odict() if items is not None: for item in items: self.add(item) def __repr__(self): - return "dns.simpleset.Set(%s)" % repr(self.items) + return "dns.set.Set(%s)" % repr(list(self.items.keys())) def add(self, item): - """Add an item to the set.""" + """Add an item to the set. + """ + if item not in self.items: - self.items.append(item) + self.items[item] = None def remove(self, item): - """Remove an item from the set.""" - self.items.remove(item) + """Remove an item from the set. + """ + + try: + del self.items[item] + except KeyError: + raise ValueError def discard(self, item): - """Remove an item from the set if present.""" - try: - self.items.remove(item) - except ValueError: - pass + """Remove an item from the set if present. + """ + + self.items.pop(item, None) def _clone(self): """Make a (shallow) copy of the set. @@ -75,23 +86,26 @@ class Set(object): cls = self.__class__ obj = cls.__new__(cls) - obj.items = list(self.items) + obj.items = self.items.copy() return obj def __copy__(self): - """Make a (shallow) copy of the set.""" + """Make a (shallow) copy of the set. + """ + return self._clone() def copy(self): - """Make a (shallow) copy of the set.""" + """Make a (shallow) copy of the set. + """ + return self._clone() def union_update(self, other): """Update the set, adding any elements from other which are not already in the set. - @param other: the collection of items with which to update the set - @type other: Set object """ + if not isinstance(other, Set): raise ValueError('other must be a Set instance') if self is other: @@ -102,9 +116,8 @@ class Set(object): def intersection_update(self, other): """Update the set, removing any elements from other which are not in both sets. - @param other: the collection of items with which to update the set - @type other: Set object """ + if not isinstance(other, Set): raise ValueError('other must be a Set instance') if self is other: @@ -113,28 +126,25 @@ class Set(object): # the list without breaking the iterator. for item in list(self.items): if item not in other.items: - self.items.remove(item) + del self.items[item] def difference_update(self, other): """Update the set, removing any elements from other which are in the set. - @param other: the collection of items with which to update the set - @type other: Set object """ + if not isinstance(other, Set): raise ValueError('other must be a Set instance') if self is other: - self.items = [] + self.items.clear() else: for item in other.items: self.discard(item) def union(self, other): - """Return a new set which is the union of I{self} and I{other}. + """Return a new set which is the union of ``self`` and ``other``. - @param other: the other set - @type other: Set object - @rtype: the same type as I{self} + Returns the same Set type as this set. """ obj = self._clone() @@ -142,11 +152,10 @@ class Set(object): return obj def intersection(self, other): - """Return a new set which is the intersection of I{self} and I{other}. + """Return a new set which is the intersection of ``self`` and + ``other``. - @param other: the other set - @type other: Set object - @rtype: the same type as I{self} + Returns the same Set type as this set. """ obj = self._clone() @@ -154,12 +163,10 @@ class Set(object): return obj def difference(self, other): - """Return a new set which I{self} - I{other}, i.e. the items - in I{self} which are not also in I{other}. + """Return a new set which ``self`` - ``other``, i.e. the items + in ``self`` which are not also in ``other``. - @param other: the other set - @type other: Set object - @rtype: the same type as I{self} + Returns the same Set type as this set. """ obj = self._clone() @@ -197,25 +204,26 @@ class Set(object): def update(self, other): """Update the set, adding any elements from other which are not already in the set. - @param other: the collection of items with which to update the set - @type other: any iterable type""" + + *other*, the collection of items with which to update the set, which + may be any iterable type. + """ + for item in other: self.add(item) def clear(self): """Make the set empty.""" - self.items = [] + self.items.clear() def __eq__(self, other): - # Yes, this is inefficient but the sets we're dealing with are - # usually quite small, so it shouldn't hurt too much. - for item in self.items: - if item not in other.items: + if odict == dict: + return self.items == other.items + else: + # We don't want an ordered comparison. + if len(self.items) != len(other.items): return False - for item in other.items: - if item not in self.items: - return False - return True + return all(elt in other.items for elt in self.items) def __ne__(self, other): return not self.__eq__(other) @@ -227,21 +235,22 @@ class Set(object): return iter(self.items) def __getitem__(self, i): - return self.items[i] + if isinstance(i, slice): + return list(itertools.islice(self.items, i.start, i.stop, i.step)) + else: + return next(itertools.islice(self.items, i, i + 1)) def __delitem__(self, i): - del self.items[i] - - def __getslice__(self, i, j): - return self.items[i:j] - - def __delslice__(self, i, j): - del self.items[i:j] + if isinstance(i, slice): + for elt in list(self[i]): + del self.items[elt] + else: + del self.items[self[i]] def issubset(self, other): - """Is I{self} a subset of I{other}? + """Is this set a subset of *other*? - @rtype: bool + Returns a ``bool``. """ if not isinstance(other, Set): @@ -252,9 +261,9 @@ class Set(object): return True def issuperset(self, other): - """Is I{self} a superset of I{other}? + """Is this set a superset of *other*? - @rtype: bool + Returns a ``bool``. """ if not isinstance(other, Set): diff --git a/lib/dns/tokenizer.py b/lib/dns/tokenizer.py index e5b09adf..3e5d2ba9 100644 --- a/lib/dns/tokenizer.py +++ b/lib/dns/tokenizer.py @@ -1,4 +1,6 @@ -# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2017 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its # documentation for any purpose with or without fee is hereby granted, @@ -15,24 +17,15 @@ """Tokenize DNS master file format""" -from io import StringIO +import io import sys import dns.exception import dns.name import dns.ttl -from ._compat import long, text_type, binary_type -_DELIMITERS = { - ' ': True, - '\t': True, - '\n': True, - ';': True, - '(': True, - ')': True, - '"': True} - -_QUOTING_DELIMITERS = {'"': True} +_DELIMITERS = {' ', '\t', '\n', ';', '(', ')', '"'} +_QUOTING_DELIMITERS = {'"'} EOF = 0 EOL = 1 @@ -44,32 +37,20 @@ DELIMITER = 6 class UngetBufferFull(dns.exception.DNSException): - """An attempt was made to unget a token when the unget buffer was full.""" -class Token(object): - +class Token: """A DNS master file format token. - @ivar ttype: The token type - @type ttype: int - @ivar value: The token value - @type value: string - @ivar has_escape: Does the token value contain escapes? - @type has_escape: bool + ttype: The token type + value: The token value + has_escape: Does the token value contain escapes? """ def __init__(self, ttype, value='', has_escape=False): - """Initialize a token instance. + """Initialize a token instance.""" - @param ttype: The token type - @type ttype: int - @param value: The token value - @type value: string - @param has_escape: Does the token value contain escapes? - @type has_escape: bool - """ self.ttype = ttype self.value = value self.has_escape = has_escape @@ -92,11 +73,11 @@ class Token(object): def is_comment(self): return self.ttype == COMMENT - def is_delimiter(self): + def is_delimiter(self): # pragma: no cover (we don't return delimiters yet) return self.ttype == DELIMITER def is_eol_or_eof(self): - return (self.ttype == EOL or self.ttype == EOF) + return self.ttype == EOL or self.ttype == EOF def __eq__(self, other): if not isinstance(other, Token): @@ -142,72 +123,120 @@ class Token(object): unescaped += c return Token(self.ttype, unescaped) - # compatibility for old-style tuple tokens - - def __len__(self): - return 2 - - def __iter__(self): - return iter((self.ttype, self.value)) - - def __getitem__(self, i): - if i == 0: - return self.ttype - elif i == 1: - return self.value - else: - raise IndexError + def unescape_to_bytes(self): + # We used to use unescape() for TXT-like records, but this + # caused problems as we'd process DNS escapes into Unicode code + # points instead of byte values, and then a to_text() of the + # processed data would not equal the original input. For + # example, \226 in the TXT record would have a to_text() of + # \195\162 because we applied UTF-8 encoding to Unicode code + # point 226. + # + # We now apply escapes while converting directly to bytes, + # avoiding this double encoding. + # + # This code also handles cases where the unicode input has + # non-ASCII code-points in it by converting it to UTF-8. TXT + # records aren't defined for Unicode, but this is the best we + # can do to preserve meaning. For example, + # + # foo\u200bbar + # + # (where \u200b is Unicode code point 0x200b) will be treated + # as if the input had been the UTF-8 encoding of that string, + # namely: + # + # foo\226\128\139bar + # + unescaped = b'' + l = len(self.value) + i = 0 + while i < l: + c = self.value[i] + i += 1 + if c == '\\': + if i >= l: + raise dns.exception.UnexpectedEnd + c = self.value[i] + i += 1 + if c.isdigit(): + if i >= l: + raise dns.exception.UnexpectedEnd + c2 = self.value[i] + i += 1 + if i >= l: + raise dns.exception.UnexpectedEnd + c3 = self.value[i] + i += 1 + if not (c2.isdigit() and c3.isdigit()): + raise dns.exception.SyntaxError + unescaped += b'%c' % (int(c) * 100 + int(c2) * 10 + int(c3)) + else: + # Note that as mentioned above, if c is a Unicode + # code point outside of the ASCII range, then this + # += is converting that code point to its UTF-8 + # encoding and appending multiple bytes to + # unescaped. + unescaped += c.encode() + else: + unescaped += c.encode() + return Token(self.ttype, bytes(unescaped)) -class Tokenizer(object): - +class Tokenizer: """A DNS master file format tokenizer. - A token is a (type, value) tuple, where I{type} is an int, and - I{value} is a string. The valid types are EOF, EOL, WHITESPACE, - IDENTIFIER, QUOTED_STRING, COMMENT, and DELIMITER. + A token object is basically a (type, value) tuple. The valid + types are EOF, EOL, WHITESPACE, IDENTIFIER, QUOTED_STRING, + COMMENT, and DELIMITER. - @ivar file: The file to tokenize - @type file: file - @ivar ungotten_char: The most recently ungotten character, or None. - @type ungotten_char: string - @ivar ungotten_token: The most recently ungotten token, or None. - @type ungotten_token: (int, string) token tuple - @ivar multiline: The current multiline level. This value is increased + file: The file to tokenize + + ungotten_char: The most recently ungotten character, or None. + + ungotten_token: The most recently ungotten token, or None. + + multiline: The current multiline level. This value is increased by one every time a '(' delimiter is read, and decreased by one every time a ')' delimiter is read. - @type multiline: int - @ivar quoting: This variable is true if the tokenizer is currently + + quoting: This variable is true if the tokenizer is currently reading a quoted string. - @type quoting: bool - @ivar eof: This variable is true if the tokenizer has encountered EOF. - @type eof: bool - @ivar delimiters: The current delimiter dictionary. - @type delimiters: dict - @ivar line_number: The current line number - @type line_number: int - @ivar filename: A filename that will be returned by the L{where} method. - @type filename: string + + eof: This variable is true if the tokenizer has encountered EOF. + + delimiters: The current delimiter dictionary. + + line_number: The current line number + + filename: A filename that will be returned by the where() method. + + idna_codec: A dns.name.IDNACodec, specifies the IDNA + encoder/decoder. If None, the default IDNA 2003 + encoder/decoder is used. """ - def __init__(self, f=sys.stdin, filename=None): + def __init__(self, f=sys.stdin, filename=None, idna_codec=None): """Initialize a tokenizer instance. - @param f: The file to tokenize. The default is sys.stdin. + f: The file to tokenize. The default is sys.stdin. This parameter may also be a string, in which case the tokenizer will take its input from the contents of the string. - @type f: file or string - @param filename: the name of the filename that the L{where} method + + filename: the name of the filename that the where() method will return. - @type filename: string + + idna_codec: A dns.name.IDNACodec, specifies the IDNA + encoder/decoder. If None, the default IDNA 2003 + encoder/decoder is used. """ - if isinstance(f, text_type): - f = StringIO(f) + if isinstance(f, str): + f = io.StringIO(f) if filename is None: filename = '' - elif isinstance(f, binary_type): - f = StringIO(f.decode()) + elif isinstance(f, bytes): + f = io.StringIO(f.decode()) if filename is None: filename = '' else: @@ -225,10 +254,12 @@ class Tokenizer(object): self.delimiters = _DELIMITERS self.line_number = 1 self.filename = filename + if idna_codec is None: + idna_codec = dns.name.IDNA_2003 + self.idna_codec = idna_codec def _get_char(self): """Read a character from input. - @rtype: string """ if self.ungotten_char is None: @@ -248,7 +279,7 @@ class Tokenizer(object): def where(self): """Return the current location in the input. - @rtype: (string, int) tuple. The first item is the filename of + Returns a (string, int) tuple. The first item is the filename of the input, the second is the current line number. """ @@ -261,13 +292,13 @@ class Tokenizer(object): an error to try to unget a character when the unget buffer is not empty. - @param c: the character to unget - @type c: string - @raises UngetBufferFull: there is already an ungotten char + c: the character to unget + raises UngetBufferFull: there is already an ungotten char """ if self.ungotten_char is not None: - raise UngetBufferFull + # this should never happen! + raise UngetBufferFull # pragma: no cover self.ungotten_char = c def skip_whitespace(self): @@ -278,7 +309,7 @@ class Tokenizer(object): If the tokenizer is in multiline mode, then newlines are whitespace. - @rtype: int + Returns the number of characters skipped. """ skipped = 0 @@ -293,15 +324,17 @@ class Tokenizer(object): def get(self, want_leading=False, want_comment=False): """Get the next token. - @param want_leading: If True, return a WHITESPACE token if the + want_leading: If True, return a WHITESPACE token if the first character read is whitespace. The default is False. - @type want_leading: bool - @param want_comment: If True, return a COMMENT token if the + + want_comment: If True, return a COMMENT token if the first token read is a comment. The default is False. - @type want_comment: bool - @rtype: Token object - @raises dns.exception.UnexpectedEnd: input ended prematurely - @raises dns.exception.SyntaxError: input was badly formed + + Raises dns.exception.UnexpectedEnd: input ended prematurely + + Raises dns.exception.SyntaxError: input was badly formed + + Returns a Token. """ if self.ungotten_token is not None: @@ -332,7 +365,7 @@ class Tokenizer(object): self.skip_whitespace() continue elif c == ')': - if not self.multiline > 0: + if self.multiline <= 0: raise dns.exception.SyntaxError self.multiline -= 1 self.skip_whitespace() @@ -379,23 +412,8 @@ class Tokenizer(object): else: self._unget_char(c) break - elif self.quoting: - if c == '\\': - c = self._get_char() - if c == '': - raise dns.exception.UnexpectedEnd - if c.isdigit(): - c2 = self._get_char() - if c2 == '': - raise dns.exception.UnexpectedEnd - c3 = self._get_char() - if c == '': - raise dns.exception.UnexpectedEnd - if not (c2.isdigit() and c3.isdigit()): - raise dns.exception.SyntaxError - c = chr(int(c) * 100 + int(c2) * 10 + int(c3)) - elif c == '\n': - raise dns.exception.SyntaxError('newline in quoted string') + elif self.quoting and c == '\n': + raise dns.exception.SyntaxError('newline in quoted string') elif c == '\\': # # It's an escape. Put it and the next character into @@ -420,9 +438,9 @@ class Tokenizer(object): an error to try to unget a token when the unget buffer is not empty. - @param token: the token to unget - @type token: Token object - @raises UngetBufferFull: there is already an ungotten token + token: the token to unget + + Raises UngetBufferFull: there is already an ungotten token """ if self.ungotten_token is not None: @@ -431,7 +449,8 @@ class Tokenizer(object): def next(self): """Return the next item in an iteration. - @rtype: (int, string) + + Returns a Token. """ token = self.get() @@ -446,11 +465,12 @@ class Tokenizer(object): # Helpers - def get_int(self): - """Read the next token and interpret it as an integer. + def get_int(self, base=10): + """Read the next token and interpret it as an unsigned integer. - @raises dns.exception.SyntaxError: - @rtype: int + Raises dns.exception.SyntaxError if not an unsigned integer. + + Returns an int. """ token = self.get().unescape() @@ -458,14 +478,15 @@ class Tokenizer(object): raise dns.exception.SyntaxError('expecting an identifier') if not token.value.isdigit(): raise dns.exception.SyntaxError('expecting an integer') - return int(token.value) + return int(token.value, base) def get_uint8(self): """Read the next token and interpret it as an 8-bit unsigned integer. - @raises dns.exception.SyntaxError: - @rtype: int + Raises dns.exception.SyntaxError if not an 8-bit unsigned integer. + + Returns an int. """ value = self.get_int() @@ -474,56 +495,63 @@ class Tokenizer(object): '%d is not an unsigned 8-bit integer' % value) return value - def get_uint16(self): + def get_uint16(self, base=10): """Read the next token and interpret it as a 16-bit unsigned integer. - @raises dns.exception.SyntaxError: - @rtype: int + Raises dns.exception.SyntaxError if not a 16-bit unsigned integer. + + Returns an int. """ - value = self.get_int() + value = self.get_int(base=base) if value < 0 or value > 65535: - raise dns.exception.SyntaxError( - '%d is not an unsigned 16-bit integer' % value) + if base == 8: + raise dns.exception.SyntaxError( + '%o is not an octal unsigned 16-bit integer' % value) + else: + raise dns.exception.SyntaxError( + '%d is not an unsigned 16-bit integer' % value) return value - def get_uint32(self): + def get_uint32(self, base=10): """Read the next token and interpret it as a 32-bit unsigned integer. - @raises dns.exception.SyntaxError: - @rtype: int + Raises dns.exception.SyntaxError if not a 32-bit unsigned integer. + + Returns an int. """ - token = self.get().unescape() - if not token.is_identifier(): - raise dns.exception.SyntaxError('expecting an identifier') - if not token.value.isdigit(): - raise dns.exception.SyntaxError('expecting an integer') - value = long(token.value) - if value < 0 or value > long(4294967296): + value = self.get_int(base=base) + if value < 0 or value > 4294967295: raise dns.exception.SyntaxError( '%d is not an unsigned 32-bit integer' % value) return value - def get_string(self, origin=None): + def get_string(self, max_length=None): """Read the next token and interpret it as a string. - @raises dns.exception.SyntaxError: - @rtype: string + Raises dns.exception.SyntaxError if not a string. + Raises dns.exception.SyntaxError if token value length + exceeds max_length (if specified). + + Returns a string. """ token = self.get().unescape() if not (token.is_identifier() or token.is_quoted_string()): raise dns.exception.SyntaxError('expecting a string') + if max_length and len(token.value) > max_length: + raise dns.exception.SyntaxError("string too long") return token.value - def get_identifier(self, origin=None): - """Read the next token and raise an exception if it is not an identifier. + def get_identifier(self): + """Read the next token, which should be an identifier. - @raises dns.exception.SyntaxError: - @rtype: string + Raises dns.exception.SyntaxError if not an identifier. + + Returns a string. """ token = self.get().unescape() @@ -531,23 +559,53 @@ class Tokenizer(object): raise dns.exception.SyntaxError('expecting an identifier') return token.value - def get_name(self, origin=None): - """Read the next token and interpret it as a DNS name. + def concatenate_remaining_identifiers(self): + """Read the remaining tokens on the line, which should be identifiers. - @raises dns.exception.SyntaxError: - @rtype: dns.name.Name object""" + Raises dns.exception.SyntaxError if a token is seen that is not an + identifier. - token = self.get() + Returns a string containing a concatenation of the remaining + identifiers. + """ + s = "" + while True: + token = self.get().unescape() + if token.is_eol_or_eof(): + break + if not token.is_identifier(): + raise dns.exception.SyntaxError + s += token.value + return s + + def as_name(self, token, origin=None, relativize=False, relativize_to=None): + """Try to interpret the token as a DNS name. + + Raises dns.exception.SyntaxError if not a name. + + Returns a dns.name.Name. + """ if not token.is_identifier(): raise dns.exception.SyntaxError('expecting an identifier') - return dns.name.from_text(token.value, origin) + name = dns.name.from_text(token.value, origin, self.idna_codec) + return name.choose_relativity(relativize_to or origin, relativize) + + def get_name(self, origin=None, relativize=False, relativize_to=None): + """Read the next token and interpret it as a DNS name. + + Raises dns.exception.SyntaxError if not a name. + + Returns a dns.name.Name. + """ + + token = self.get() + return self.as_name(token, origin, relativize, relativize_to) def get_eol(self): """Read the next token and raise an exception if it isn't EOL or EOF. - @raises dns.exception.SyntaxError: - @rtype: string + Returns a string. """ token = self.get() @@ -558,6 +616,14 @@ class Tokenizer(object): return token.value def get_ttl(self): + """Read the next token and interpret it as a DNS TTL. + + Raises dns.exception.SyntaxError or dns.ttl.BadTTL if not an + identifier or badly formed. + + Returns an int. + """ + token = self.get().unescape() if not token.is_identifier(): raise dns.exception.SyntaxError('expecting an identifier') diff --git a/lib/dns/tsig.py b/lib/dns/tsig.py index 92ce8603..8f34fe67 100644 --- a/lib/dns/tsig.py +++ b/lib/dns/tsig.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2001-2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its @@ -15,15 +17,15 @@ """DNS TSIG support.""" +import base64 +import hashlib import hmac import struct -import sys import dns.exception -import dns.hash import dns.rdataclass import dns.name -from ._compat import long, string_types +import dns.rcode class BadTime(dns.exception.DNSException): @@ -35,6 +37,16 @@ class BadSignature(dns.exception.DNSException): """The TSIG signature fails to verify.""" +class BadKey(dns.exception.DNSException): + + """The TSIG record owner name does not match the key.""" + + +class BadAlgorithm(dns.exception.DNSException): + + """The TSIG algorithm does not match the key.""" + + class PeerError(dns.exception.DNSException): """Base class for all TSIG errors generated by the remote peer""" @@ -69,83 +81,88 @@ HMAC_SHA384 = dns.name.from_text("hmac-sha384") HMAC_SHA512 = dns.name.from_text("hmac-sha512") _hashes = { - HMAC_SHA224: 'SHA224', - HMAC_SHA256: 'SHA256', - HMAC_SHA384: 'SHA384', - HMAC_SHA512: 'SHA512', - HMAC_SHA1: 'SHA1', - HMAC_MD5: 'MD5', + HMAC_SHA224: hashlib.sha224, + HMAC_SHA256: hashlib.sha256, + HMAC_SHA384: hashlib.sha384, + HMAC_SHA512: hashlib.sha512, + HMAC_SHA1: hashlib.sha1, + HMAC_MD5: hashlib.md5, } -default_algorithm = HMAC_MD5 - -BADSIG = 16 -BADKEY = 17 -BADTIME = 18 -BADTRUNC = 22 +default_algorithm = HMAC_SHA256 -def sign(wire, keyname, secret, time, fudge, original_id, error, - other_data, request_mac, ctx=None, multi=False, first=True, - algorithm=default_algorithm): - """Return a (tsig_rdata, mac, ctx) tuple containing the HMAC TSIG rdata - for the input parameters, the HMAC MAC calculated by applying the - TSIG signature algorithm, and the TSIG digest context. - @rtype: (string, string, hmac.HMAC object) +def _digest(wire, key, rdata, time=None, request_mac=None, ctx=None, + multi=None): + """Return a context containing the TSIG rdata for the input parameters + @rtype: hmac.HMAC object @raises ValueError: I{other_data} is too long @raises NotImplementedError: I{algorithm} is not supported """ - (algorithm_name, digestmod) = get_algorithm(algorithm) + first = not (ctx and multi) if first: - ctx = hmac.new(secret, digestmod=digestmod) - ml = len(request_mac) - if ml > 0: - ctx.update(struct.pack('!H', ml)) + ctx = get_context(key) + if request_mac: + ctx.update(struct.pack('!H', len(request_mac))) ctx.update(request_mac) - id = struct.pack('!H', original_id) - ctx.update(id) + ctx.update(struct.pack('!H', rdata.original_id)) ctx.update(wire[2:]) if first: - ctx.update(keyname.to_digestable()) + ctx.update(key.name.to_digestable()) ctx.update(struct.pack('!H', dns.rdataclass.ANY)) ctx.update(struct.pack('!I', 0)) - long_time = time + long(0) - upper_time = (long_time >> 32) & long(0xffff) - lower_time = long_time & long(0xffffffff) - time_mac = struct.pack('!HIH', upper_time, lower_time, fudge) - pre_mac = algorithm_name + time_mac - ol = len(other_data) - if ol > 65535: + if time is None: + time = rdata.time_signed + upper_time = (time >> 32) & 0xffff + lower_time = time & 0xffffffff + time_encoded = struct.pack('!HIH', upper_time, lower_time, rdata.fudge) + other_len = len(rdata.other) + if other_len > 65535: raise ValueError('TSIG Other Data is > 65535 bytes') - post_mac = struct.pack('!HH', error, ol) + other_data if first: - ctx.update(pre_mac) - ctx.update(post_mac) + ctx.update(key.algorithm.to_digestable() + time_encoded) + ctx.update(struct.pack('!HH', rdata.error, other_len) + rdata.other) else: - ctx.update(time_mac) - mac = ctx.digest() - mpack = struct.pack('!H', len(mac)) - tsig_rdata = pre_mac + mpack + mac + id + post_mac + ctx.update(time_encoded) + return ctx + + +def _maybe_start_digest(key, mac, multi): + """If this is the first message in a multi-message sequence, + start a new context. + @rtype: hmac.HMAC object + """ if multi: - ctx = hmac.new(secret, digestmod=digestmod) - ml = len(mac) - ctx.update(struct.pack('!H', ml)) + ctx = get_context(key) + ctx.update(struct.pack('!H', len(mac))) ctx.update(mac) + return ctx else: - ctx = None - return (tsig_rdata, mac, ctx) + return None -def hmac_md5(wire, keyname, secret, time, fudge, original_id, error, - other_data, request_mac, ctx=None, multi=False, first=True, - algorithm=default_algorithm): - return sign(wire, keyname, secret, time, fudge, original_id, error, - other_data, request_mac, ctx, multi, first, algorithm) +def sign(wire, key, rdata, time=None, request_mac=None, ctx=None, multi=False): + """Return a (tsig_rdata, mac, ctx) tuple containing the HMAC TSIG rdata + for the input parameters, the HMAC MAC calculated by applying the + TSIG signature algorithm, and the TSIG digest context. + @rtype: (string, hmac.HMAC object) + @raises ValueError: I{other_data} is too long + @raises NotImplementedError: I{algorithm} is not supported + """ + + ctx = _digest(wire, key, rdata, time, request_mac, ctx, multi) + mac = ctx.digest() + tsig = dns.rdtypes.ANY.TSIG.TSIG(dns.rdataclass.ANY, dns.rdatatype.TSIG, + key.algorithm, time, rdata.fudge, mac, + rdata.original_id, rdata.error, + rdata.other) + + return (tsig, _maybe_start_digest(key, mac, multi)) -def validate(wire, keyname, secret, now, request_mac, tsig_start, tsig_rdata, - tsig_rdlen, ctx=None, multi=False, first=True): +def validate(wire, key, owner, rdata, now, request_mac, tsig_start, ctx=None, + multi=False): """Validate the specified TSIG rdata against the other input parameters. @raises FormError: The TSIG is badly formed. @@ -159,75 +176,59 @@ def validate(wire, keyname, secret, now, request_mac, tsig_start, tsig_rdata, raise dns.exception.FormError adcount -= 1 new_wire = wire[0:10] + struct.pack("!H", adcount) + wire[12:tsig_start] - current = tsig_rdata - (aname, used) = dns.name.from_wire(wire, current) - current = current + used - (upper_time, lower_time, fudge, mac_size) = \ - struct.unpack("!HIHH", wire[current:current + 10]) - time = ((upper_time + long(0)) << 32) + (lower_time + long(0)) - current += 10 - mac = wire[current:current + mac_size] - current += mac_size - (original_id, error, other_size) = \ - struct.unpack("!HHH", wire[current:current + 6]) - current += 6 - other_data = wire[current:current + other_size] - current += other_size - if current != tsig_rdata + tsig_rdlen: - raise dns.exception.FormError - if error != 0: - if error == BADSIG: + if rdata.error != 0: + if rdata.error == dns.rcode.BADSIG: raise PeerBadSignature - elif error == BADKEY: + elif rdata.error == dns.rcode.BADKEY: raise PeerBadKey - elif error == BADTIME: + elif rdata.error == dns.rcode.BADTIME: raise PeerBadTime - elif error == BADTRUNC: + elif rdata.error == dns.rcode.BADTRUNC: raise PeerBadTruncation else: - raise PeerError('unknown TSIG error code %d' % error) - time_low = time - fudge - time_high = time + fudge - if now < time_low or now > time_high: + raise PeerError('unknown TSIG error code %d' % rdata.error) + if abs(rdata.time_signed - now) > rdata.fudge: raise BadTime - (junk, our_mac, ctx) = sign(new_wire, keyname, secret, time, fudge, - original_id, error, other_data, - request_mac, ctx, multi, first, aname) - if (our_mac != mac): + if key.name != owner: + raise BadKey + if key.algorithm != rdata.algorithm: + raise BadAlgorithm + ctx = _digest(new_wire, key, rdata, None, request_mac, ctx, multi) + mac = ctx.digest() + if not hmac.compare_digest(mac, rdata.mac): raise BadSignature - return ctx + return _maybe_start_digest(key, mac, multi) -def get_algorithm(algorithm): - """Returns the wire format string and the hash module to use for the - specified TSIG algorithm +def get_context(key): + """Returns an HMAC context foe the specified key. - @rtype: (string, hash constructor) + @rtype: HMAC context @raises NotImplementedError: I{algorithm} is not supported """ - if isinstance(algorithm, string_types): - algorithm = dns.name.from_text(algorithm) - try: - return (algorithm.to_digestable(), dns.hash.hashes[_hashes[algorithm]]) + digestmod = _hashes[key.algorithm] except KeyError: - raise NotImplementedError("TSIG algorithm " + str(algorithm) + - " is not supported") + raise NotImplementedError(f"TSIG algorithm {key.algorithm} " + + "is not supported") + return hmac.new(key.secret, digestmod=digestmod) -def get_algorithm_and_mac(wire, tsig_rdata, tsig_rdlen): - """Return the tsig algorithm for the specified tsig_rdata - @raises FormError: The TSIG is badly formed. - """ - current = tsig_rdata - (aname, used) = dns.name.from_wire(wire, current) - current = current + used - (upper_time, lower_time, fudge, mac_size) = \ - struct.unpack("!HIHH", wire[current:current + 10]) - current += 10 - mac = wire[current:current + mac_size] - current += mac_size - if current > tsig_rdata + tsig_rdlen: - raise dns.exception.FormError - return (aname, mac) +class Key: + def __init__(self, name, secret, algorithm=default_algorithm): + if isinstance(name, str): + name = dns.name.from_text(name) + self.name = name + if isinstance(secret, str): + secret = base64.decodebytes(secret.encode()) + self.secret = secret + if isinstance(algorithm, str): + algorithm = dns.name.from_text(algorithm) + self.algorithm = algorithm + + def __eq__(self, other): + return (isinstance(other, Key) and + self.name == other.name and + self.secret == other.secret and + self.algorithm == other.algorithm) diff --git a/lib/dns/tsigkeyring.py b/lib/dns/tsigkeyring.py index 295bac14..4afee57f 100644 --- a/lib/dns/tsigkeyring.py +++ b/lib/dns/tsigkeyring.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its @@ -21,26 +23,37 @@ import dns.name def from_text(textring): - """Convert a dictionary containing (textual DNS name, base64 secret) pairs - into a binary keyring which has (dns.name.Name, binary secret) pairs. + """Convert a dictionary containing (textual DNS name, base64 secret) + pairs into a binary keyring which has (dns.name.Name, bytes) pairs, or + a dictionary containing (textual DNS name, (algorithm, base64 secret)) + pairs into a binary keyring which has (dns.name.Name, dns.tsig.Key) pairs. @rtype: dict""" keyring = {} - for keytext in textring: - keyname = dns.name.from_text(keytext) - secret = base64.decodestring(textring[keytext]) - keyring[keyname] = secret + for (name, value) in textring.items(): + name = dns.name.from_text(name) + if isinstance(value, str): + keyring[name] = dns.tsig.Key(name, value).secret + else: + (algorithm, secret) = value + keyring[name] = dns.tsig.Key(name, secret, algorithm) return keyring def to_text(keyring): - """Convert a dictionary containing (dns.name.Name, binary secret) pairs - into a text keyring which has (textual DNS name, base64 secret) pairs. + """Convert a dictionary containing (dns.name.Name, dns.tsig.Key) pairs + into a text keyring which has (textual DNS name, (textual algorithm, + base64 secret)) pairs, or a dictionary containing (dns.name.Name, bytes) + pairs into a text keyring which has (textual DNS name, base64 secret) pairs. @rtype: dict""" textring = {} - for keyname in keyring: - keytext = keyname.to_text() - secret = base64.encodestring(keyring[keyname]) - textring[keytext] = secret + def b64encode(secret): + return base64.encodebytes(secret).decode().rstrip() + for (name, key) in keyring.items(): + name = name.to_text() + if isinstance(key, bytes): + textring[name] = b64encode(key) + else: + textring[name] = (key.algorithm.to_text(), b64encode(key.secret)) return textring diff --git a/lib/dns/ttl.py b/lib/dns/ttl.py index a27d8251..55ae5e16 100644 --- a/lib/dns/ttl.py +++ b/lib/dns/ttl.py @@ -1,4 +1,6 @@ -# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2017 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its # documentation for any purpose with or without fee is hereby granted, @@ -16,11 +18,9 @@ """DNS TTL conversion.""" import dns.exception -from ._compat import long class BadTTL(dns.exception.SyntaxError): - """DNS TTL value is not well-formed.""" @@ -29,33 +29,34 @@ def from_text(text): The BIND 8 units syntax for TTLs (e.g. '1w6d4h3m10s') is supported. - @param text: the textual TTL - @type text: string - @raises dns.ttl.BadTTL: the TTL is not well-formed - @rtype: int + *text*, a ``str``, the textual TTL. + + Raises ``dns.ttl.BadTTL`` if the TTL is not well-formed. + + Returns an ``int``. """ if text.isdigit(): - total = long(text) + total = int(text) else: if not text[0].isdigit(): raise BadTTL - total = long(0) - current = long(0) + total = 0 + current = 0 for c in text: if c.isdigit(): current *= 10 - current += long(c) + current += int(c) else: c = c.lower() if c == 'w': - total += current * long(604800) + total += current * 604800 elif c == 'd': - total += current * long(86400) + total += current * 86400 elif c == 'h': - total += current * long(3600) + total += current * 3600 elif c == 'm': - total += current * long(60) + total += current * 60 elif c == 's': total += current else: @@ -63,6 +64,6 @@ def from_text(text): current = 0 if not current == 0: raise BadTTL("trailing integer") - if total < long(0) or total > long(2147483647): + if total < 0 or total > 2147483647: raise BadTTL("TTL should be between 0 and 2^31 - 1 (inclusive)") return total diff --git a/lib/dns/update.py b/lib/dns/update.py index 59728d98..8e796504 100644 --- a/lib/dns/update.py +++ b/lib/dns/update.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its @@ -23,62 +25,99 @@ import dns.rdata import dns.rdataclass import dns.rdataset import dns.tsig -from ._compat import string_types -class Update(dns.message.Message): +class UpdateSection(dns.enum.IntEnum): + """Update sections""" + ZONE = 0 + PREREQ = 1 + UPDATE = 2 + ADDITIONAL = 3 - def __init__(self, zone, rdclass=dns.rdataclass.IN, keyring=None, - keyname=None, keyalgorithm=dns.tsig.default_algorithm): + @classmethod + def _maximum(cls): + return 3 + +globals().update(UpdateSection.__members__) + + +class UpdateMessage(dns.message.Message): + + _section_enum = UpdateSection + + def __init__(self, zone=None, rdclass=dns.rdataclass.IN, keyring=None, + keyname=None, keyalgorithm=dns.tsig.default_algorithm, + id=None): """Initialize a new DNS Update object. - @param zone: The zone which is being updated. - @type zone: A dns.name.Name or string - @param rdclass: The class of the zone; defaults to dns.rdataclass.IN. - @type rdclass: An int designating the class, or a string whose value - is the name of a class. - @param keyring: The TSIG keyring to use; defaults to None. - @type keyring: dict - @param keyname: The name of the TSIG key to use; defaults to None. - The key must be defined in the keyring. If a keyring is specified - but a keyname is not, then the key used will be the first key in the - keyring. Note that the order of keys in a dictionary is not defined, - so applications should supply a keyname when a keyring is used, unless - they know the keyring contains only one key. - @type keyname: dns.name.Name or string - @param keyalgorithm: The TSIG algorithm to use; defaults to - dns.tsig.default_algorithm. Constants for TSIG algorithms are defined - in dns.tsig, and the currently implemented algorithms are - HMAC_MD5, HMAC_SHA1, HMAC_SHA224, HMAC_SHA256, HMAC_SHA384, and - HMAC_SHA512. - @type keyalgorithm: string + See the documentation of the Message class for a complete + description of the keyring dictionary. + + *zone*, a ``dns.name.Name``, ``str``, or ``None``, the zone + which is being updated. ``None`` should only be used by dnspython's + message constructors, as a zone is required for the convenience + methods like ``add()``, ``replace()``, etc. + + *rdclass*, an ``int`` or ``str``, the class of the zone. + + The *keyring*, *keyname*, and *keyalgorithm* parameters are passed to + ``use_tsig()``; see its documentation for details. """ - super(Update, self).__init__() + super().__init__(id=id) self.flags |= dns.opcode.to_flags(dns.opcode.UPDATE) - if isinstance(zone, string_types): + if isinstance(zone, str): zone = dns.name.from_text(zone) self.origin = zone - if isinstance(rdclass, string_types): - rdclass = dns.rdataclass.from_text(rdclass) + rdclass = dns.rdataclass.RdataClass.make(rdclass) self.zone_rdclass = rdclass - self.find_rrset(self.question, self.origin, rdclass, dns.rdatatype.SOA, - create=True, force_unique=True) + if self.origin: + self.find_rrset(self.zone, self.origin, rdclass, dns.rdatatype.SOA, + create=True, force_unique=True) if keyring is not None: self.use_tsig(keyring, keyname, algorithm=keyalgorithm) + @property + def zone(self): + """The zone section.""" + return self.sections[0] + + @zone.setter + def zone(self, v): + self.sections[0] = v + + @property + def prerequisite(self): + """The prerequisite section.""" + return self.sections[1] + + @prerequisite.setter + def prerequisite(self, v): + self.sections[1] = v + + @property + def update(self): + """The update section.""" + return self.sections[2] + + @update.setter + def update(self, v): + self.sections[2] = v + def _add_rr(self, name, ttl, rd, deleting=None, section=None): """Add a single RR to the update section.""" if section is None: - section = self.authority + section = self.update covers = rd.covers() rrset = self.find_rrset(section, name, self.zone_rdclass, rd.rdtype, covers, deleting, True, True) rrset.add(rd, ttl) def _add(self, replace, section, name, *args): - """Add records. The first argument is the replace mode. If - false, RRs are added to an existing RRset; if true, the RRset + """Add records. + + *replace* is the replacement mode. If ``False``, + RRs are added to an existing RRset; if ``True``, the RRset is replaced with the specified contents. The second argument is the section to add to. The third argument is always a name. The other arguments can be: @@ -87,9 +126,10 @@ class Update(dns.message.Message): - ttl, rdata... - - ttl, rdtype, string...""" + - ttl, rdtype, string... + """ - if isinstance(name, string_types): + if isinstance(name, str): name = dns.name.from_text(name, None) if isinstance(args[0], dns.rdataset.Rdataset): for rds in args: @@ -106,9 +146,7 @@ class Update(dns.message.Message): for rd in args: self._add_rr(name, ttl, rd, section=section) else: - rdtype = args.pop(0) - if isinstance(rdtype, string_types): - rdtype = dns.rdatatype.from_text(rdtype) + rdtype = dns.rdatatype.RdataType.make(args.pop(0)) if replace: self.delete(name, rdtype) for s in args: @@ -117,32 +155,39 @@ class Update(dns.message.Message): self._add_rr(name, ttl, rd, section=section) def add(self, name, *args): - """Add records. The first argument is always a name. The other + """Add records. + + The first argument is always a name. The other arguments can be: - rdataset... - ttl, rdata... - - ttl, rdtype, string...""" - self._add(False, self.authority, name, *args) + - ttl, rdtype, string... + """ + + self._add(False, self.update, name, *args) def delete(self, name, *args): - """Delete records. The first argument is always a name. The other + """Delete records. + + The first argument is always a name. The other arguments can be: - - I{nothing} + - *empty* - rdataset... - rdata... - - rdtype, [string...]""" + - rdtype, [string...] + """ - if isinstance(name, string_types): + if isinstance(name, str): name = dns.name.from_text(name, None) if len(args) == 0: - self.find_rrset(self.authority, name, dns.rdataclass.ANY, + self.find_rrset(self.update, name, dns.rdataclass.ANY, dns.rdatatype.ANY, dns.rdatatype.NONE, dns.rdatatype.ANY, True, True) elif isinstance(args[0], dns.rdataset.Rdataset): @@ -155,11 +200,9 @@ class Update(dns.message.Message): for rd in args: self._add_rr(name, 0, rd, dns.rdataclass.NONE) else: - rdtype = args.pop(0) - if isinstance(rdtype, string_types): - rdtype = dns.rdatatype.from_text(rdtype) + rdtype = dns.rdatatype.RdataType.make(args.pop(0)) if len(args) == 0: - self.find_rrset(self.authority, name, + self.find_rrset(self.update, name, self.zone_rdclass, rdtype, dns.rdatatype.NONE, dns.rdataclass.ANY, @@ -171,7 +214,9 @@ class Update(dns.message.Message): self._add_rr(name, 0, rd, dns.rdataclass.NONE) def replace(self, name, *args): - """Replace records. The first argument is always a name. The other + """Replace records. + + The first argument is always a name. The other arguments can be: - rdataset... @@ -181,26 +226,30 @@ class Update(dns.message.Message): - ttl, rdtype, string... Note that if you want to replace the entire node, you should do - a delete of the name followed by one or more calls to add.""" + a delete of the name followed by one or more calls to add. + """ - self._add(True, self.authority, name, *args) + self._add(True, self.update, name, *args) def present(self, name, *args): """Require that an owner name (and optionally an rdata type, or specific rdataset) exists as a prerequisite to the - execution of the update. The first argument is always a name. + execution of the update. + + The first argument is always a name. The other arguments can be: - rdataset... - rdata... - - rdtype, string...""" + - rdtype, string... + """ - if isinstance(name, string_types): + if isinstance(name, str): name = dns.name.from_text(name, None) if len(args) == 0: - self.find_rrset(self.answer, name, + self.find_rrset(self.prerequisite, name, dns.rdataclass.ANY, dns.rdatatype.ANY, dns.rdatatype.NONE, None, True, True) @@ -211,12 +260,10 @@ class Update(dns.message.Message): # Add a 0 TTL args = list(args) args.insert(0, 0) - self._add(False, self.answer, name, *args) + self._add(False, self.prerequisite, name, *args) else: - rdtype = args[0] - if isinstance(rdtype, string_types): - rdtype = dns.rdatatype.from_text(rdtype) - self.find_rrset(self.answer, name, + rdtype = dns.rdatatype.RdataType.make(args[0]) + self.find_rrset(self.prerequisite, name, dns.rdataclass.ANY, rdtype, dns.rdatatype.NONE, None, True, True) @@ -225,25 +272,41 @@ class Update(dns.message.Message): """Require that an owner name (and optionally an rdata type) does not exist as a prerequisite to the execution of the update.""" - if isinstance(name, string_types): + if isinstance(name, str): name = dns.name.from_text(name, None) if rdtype is None: - self.find_rrset(self.answer, name, + self.find_rrset(self.prerequisite, name, dns.rdataclass.NONE, dns.rdatatype.ANY, dns.rdatatype.NONE, None, True, True) else: - if isinstance(rdtype, string_types): - rdtype = dns.rdatatype.from_text(rdtype) - self.find_rrset(self.answer, name, + rdtype = dns.rdatatype.RdataType.make(rdtype) + self.find_rrset(self.prerequisite, name, dns.rdataclass.NONE, rdtype, dns.rdatatype.NONE, None, True, True) - def to_wire(self, origin=None, max_size=65535): - """Return a string containing the update in DNS compressed wire - format. - @rtype: string""" - if origin is None: - origin = self.origin - return super(Update, self).to_wire(origin, max_size) + def _get_one_rr_per_rrset(self, value): + # Updates are always one_rr_per_rrset + return True + + def _parse_rr_header(self, section, name, rdclass, rdtype): + deleting = None + empty = False + if section == UpdateSection.ZONE: + if dns.rdataclass.is_metaclass(rdclass) or \ + rdtype != dns.rdatatype.SOA or \ + self.zone: + raise dns.exception.FormError + else: + if not self.zone: + raise dns.exception.FormError + if rdclass in (dns.rdataclass.ANY, dns.rdataclass.NONE): + deleting = rdclass + rdclass = self.zone[0].rdclass + empty = (deleting == dns.rdataclass.ANY or + section == UpdateSection.PREREQ) + return (rdclass, rdtype, deleting, empty) + +# backwards compatibility +Update = UpdateMessage diff --git a/lib/dns/version.py b/lib/dns/version.py index 3d97f696..0b7c1d13 100644 --- a/lib/dns/version.py +++ b/lib/dns/version.py @@ -1,4 +1,6 @@ -# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2017 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its # documentation for any purpose with or without fee is hereby granted, @@ -15,20 +17,30 @@ """dnspython release version information.""" -MAJOR = 1 -MINOR = 14 +#: MAJOR +MAJOR = 2 +#: MINOR +MINOR = 0 +#: MICRO MICRO = 0 +#: RELEASELEVEL RELEASELEVEL = 0x0f +#: SERIAL SERIAL = 0 -if RELEASELEVEL == 0x0f: +if RELEASELEVEL == 0x0f: # pragma: no cover + #: version version = '%d.%d.%d' % (MAJOR, MINOR, MICRO) -elif RELEASELEVEL == 0x00: - version = '%d.%d.%dx%d' % \ +elif RELEASELEVEL == 0x00: # pragma: no cover + version = '%d.%d.%ddev%d' % \ (MAJOR, MINOR, MICRO, SERIAL) -else: +elif RELEASELEVEL == 0x0c: # pragma: no cover + version = '%d.%d.%drc%d' % \ + (MAJOR, MINOR, MICRO, SERIAL) +else: # pragma: no cover version = '%d.%d.%d%x%d' % \ (MAJOR, MINOR, MICRO, RELEASELEVEL, SERIAL) +#: hexversion hexversion = MAJOR << 24 | MINOR << 16 | MICRO << 8 | RELEASELEVEL << 4 | \ SERIAL diff --git a/lib/dns/wire.py b/lib/dns/wire.py new file mode 100644 index 00000000..a3149605 --- /dev/null +++ b/lib/dns/wire.py @@ -0,0 +1,82 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +import contextlib +import struct + +import dns.exception +import dns.name + +class Parser: + def __init__(self, wire, current=0): + self.wire = wire + self.current = 0 + self.end = len(self.wire) + if current: + self.seek(current) + self.furthest = current + + def remaining(self): + return self.end - self.current + + def get_bytes(self, size): + if size > self.remaining(): + raise dns.exception.FormError + output = self.wire[self.current:self.current + size] + self.current += size + self.furthest = max(self.furthest, self.current) + return output + + def get_counted_bytes(self, length_size=1): + length = int.from_bytes(self.get_bytes(length_size), 'big') + return self.get_bytes(length) + + def get_remaining(self): + return self.get_bytes(self.remaining()) + + def get_uint8(self): + return struct.unpack('!B', self.get_bytes(1))[0] + + def get_uint16(self): + return struct.unpack('!H', self.get_bytes(2))[0] + + def get_uint32(self): + return struct.unpack('!I', self.get_bytes(4))[0] + + def get_struct(self, format): + return struct.unpack(format, self.get_bytes(struct.calcsize(format))) + + def get_name(self, origin=None): + name = dns.name.from_wire_parser(self) + if origin: + name = name.relativize(origin) + return name + + def seek(self, where): + # Note that seeking to the end is OK! (If you try to read + # after such a seek, you'll get an exception as expected.) + if where < 0 or where > self.end: + raise dns.exception.FormError + self.current = where + + @contextlib.contextmanager + def restrict_to(self, size): + if size > self.remaining(): + raise dns.exception.FormError + saved_end = self.end + try: + self.end = self.current + size + yield + # We make this check here and not in the finally as we + # don't want to raise if we're already raising for some + # other reason. + if self.current != self.end: + raise dns.exception.FormError + finally: + self.end = saved_end + + @contextlib.contextmanager + def restore_furthest(self): + try: + yield None + finally: + self.current = self.furthest diff --git a/lib/dns/wiredata.py b/lib/dns/wiredata.py deleted file mode 100644 index b381f7b9..00000000 --- a/lib/dns/wiredata.py +++ /dev/null @@ -1,84 +0,0 @@ -# Copyright (C) 2011 Nominum, Inc. -# -# Permission to use, copy, modify, and distribute this software and its -# documentation for any purpose with or without fee is hereby granted, -# provided that the above copyright notice and this permission notice -# appear in all copies. -# -# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES -# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF -# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR -# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES -# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN -# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT -# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. - -"""DNS Wire Data Helper""" - - -import dns.exception -from ._compat import binary_type, string_types - -# Figure out what constant python passes for an unspecified slice bound. -# It's supposed to be sys.maxint, yet on 64-bit windows sys.maxint is 2^31 - 1 -# but Python uses 2^63 - 1 as the constant. Rather than making pointless -# extra comparisons, duplicating code, or weakening WireData, we just figure -# out what constant Python will use. - - -class _SliceUnspecifiedBound(str): - - def __getslice__(self, i, j): - return j - -_unspecified_bound = _SliceUnspecifiedBound('')[1:] - - -class WireData(binary_type): - # WireData is a string with stricter slicing - - def __getitem__(self, key): - try: - if isinstance(key, slice): - return WireData(super(WireData, self).__getitem__(key)) - return bytearray(self.unwrap())[key] - except IndexError: - raise dns.exception.FormError - - def __getslice__(self, i, j): - try: - if j == _unspecified_bound: - # handle the case where the right bound is unspecified - j = len(self) - if i < 0 or j < 0: - raise dns.exception.FormError - # If it's not an empty slice, access left and right bounds - # to make sure they're valid - if i != j: - super(WireData, self).__getitem__(i) - super(WireData, self).__getitem__(j - 1) - return WireData(super(WireData, self).__getslice__(i, j)) - except IndexError: - raise dns.exception.FormError - - def __iter__(self): - i = 0 - while 1: - try: - yield self[i] - i += 1 - except dns.exception.FormError: - raise StopIteration - - def unwrap(self): - return binary_type(self) - - -def maybe_wrap(wire): - if isinstance(wire, WireData): - return wire - elif isinstance(wire, binary_type): - return WireData(wire) - elif isinstance(wire, string_types): - return WireData(wire.encode()) - raise ValueError("unhandled type %s" % type(wire)) diff --git a/lib/dns/zone.py b/lib/dns/zone.py index ae099bd8..e8413c08 100644 --- a/lib/dns/zone.py +++ b/lib/dns/zone.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its @@ -15,11 +17,11 @@ """DNS Zones.""" -from __future__ import generators - -import sys +import contextlib +import io +import os import re -from io import BytesIO +import sys import dns.exception import dns.name @@ -27,11 +29,11 @@ import dns.node import dns.rdataclass import dns.rdatatype import dns.rdata +import dns.rdtypes.ANY.SOA import dns.rrset import dns.tokenizer import dns.ttl import dns.grange -from ._compat import string_types, text_type class BadZone(dns.exception.DNSException): @@ -54,28 +56,16 @@ class UnknownOrigin(BadZone): """The DNS zone's origin is unknown.""" -class Zone(object): +class Zone: """A DNS zone. - A Zone is a mapping from names to nodes. The zone object may be - treated like a Python dictionary, e.g. zone[name] will retrieve - the node associated with that name. The I{name} may be a - dns.name.Name object, or it may be a string. In the either case, + A ``Zone`` is a mapping from names to nodes. The zone object may be + treated like a Python dictionary, e.g. ``zone[name]`` will retrieve + the node associated with that name. The *name* may be a + ``dns.name.Name object``, or it may be a string. In either case, if the name is relative it is treated as relative to the origin of the zone. - - @ivar rdclass: The zone's rdata class; the default is class IN. - @type rdclass: int - @ivar origin: The origin of the zone. - @type origin: dns.name.Name object - @ivar nodes: A dictionary mapping the names of nodes in the zone to the - nodes themselves. - @type nodes: dict - @ivar relativize: should names in the zone be relativized? - @type relativize: bool - @cvar node_factory: the factory used to create a new node - @type node_factory: class or callable """ node_factory = dns.node.Node @@ -85,13 +75,18 @@ class Zone(object): def __init__(self, origin, rdclass=dns.rdataclass.IN, relativize=True): """Initialize a zone object. - @param origin: The origin of the zone. - @type origin: dns.name.Name object - @param rdclass: The zone's rdata class; the default is class IN. - @type rdclass: int""" + *origin* is the origin of the zone. It may be a ``dns.name.Name``, + a ``str``, or ``None``. If ``None``, then the zone's origin will + be set by the first ``$ORIGIN`` line in a masterfile. + + *rdclass*, an ``int``, the zone's rdata class; the default is class IN. + + *relativize*, a ``bool``, determine's whether domain names are + relativized to the zone's origin. The default is ``True``. + """ if origin is not None: - if isinstance(origin, string_types): + if isinstance(origin, str): origin = dns.name.from_text(origin) elif not isinstance(origin, dns.name.Name): raise ValueError("origin parameter must be convertible to a " @@ -106,7 +101,8 @@ class Zone(object): def __eq__(self, other): """Two zones are equal if they have the same origin, class, and nodes. - @rtype: bool + + Returns a ``bool``. """ if not isinstance(other, Zone): @@ -119,13 +115,14 @@ class Zone(object): def __ne__(self, other): """Are two zones not equal? - @rtype: bool + + Returns a ``bool``. """ return not self.__eq__(other) def _validate_name(self, name): - if isinstance(name, string_types): + if isinstance(name, str): name = dns.name.from_text(name, None) elif not isinstance(name, dns.name.Name): raise KeyError("name parameter must be convertible to a DNS name") @@ -150,24 +147,16 @@ class Zone(object): del self.nodes[key] def __iter__(self): - return self.nodes.iterkeys() - - def iterkeys(self): - return self.nodes.iterkeys() + return self.nodes.__iter__() def keys(self): - return self.nodes.keys() - - def itervalues(self): - return self.nodes.itervalues() + return self.nodes.keys() # pylint: disable=dict-keys-not-iterating def values(self): - return self.nodes.values() + return self.nodes.values() # pylint: disable=dict-values-not-iterating def items(self): - return self.nodes.items() - - iteritems = items + return self.nodes.items() # pylint: disable=dict-items-not-iterating def get(self, key): key = self._validate_name(key) @@ -179,12 +168,18 @@ class Zone(object): def find_node(self, name, create=False): """Find a node in the zone, possibly creating it. - @param name: the name of the node to find - @type name: dns.name.Name object or string - @param create: should the node be created if it doesn't exist? - @type create: bool - @raises KeyError: the name is not known and create was not specified. - @rtype: dns.node.Node object + *name*: the name of the node to find. + The value may be a ``dns.name.Name`` or a ``str``. If absolute, the + name must be a subdomain of the zone's origin. If ``zone.relativize`` + is ``True``, then the name will be relativized. + + *create*, a ``bool``. If true, the node will be created if it does + not exist. + + Raises ``KeyError`` if the name is not known and create was + not specified, or if the name was not a subdomain of the origin. + + Returns a ``dns.node.Node``. """ name = self._validate_name(name) @@ -199,15 +194,22 @@ class Zone(object): def get_node(self, name, create=False): """Get a node in the zone, possibly creating it. - This method is like L{find_node}, except it returns None instead + This method is like ``find_node()``, except it returns None instead of raising an exception if the node does not exist and creation has not been requested. - @param name: the name of the node to find - @type name: dns.name.Name object or string - @param create: should the node be created if it doesn't exist? - @type create: bool - @rtype: dns.node.Node object or None + *name*: the name of the node to find. + The value may be a ``dns.name.Name`` or a ``str``. If absolute, the + name must be a subdomain of the zone's origin. If ``zone.relativize`` + is ``True``, then the name will be relativized. + + *create*, a ``bool``. If true, the node will be created if it does + not exist. + + Raises ``KeyError`` if the name is not known and create was + not specified, or if the name was not a subdomain of the origin. + + Returns a ``dns.node.Node`` or ``None``. """ try: @@ -219,6 +221,11 @@ class Zone(object): def delete_node(self, name): """Delete the specified node if it exists. + *name*: the name of the node to find. + The value may be a ``dns.name.Name`` or a ``str``. If absolute, the + name must be a subdomain of the zone's origin. If ``zone.relativize`` + is ``True``, then the name will be relativized. + It is not an error if the node does not exist. """ @@ -228,65 +235,82 @@ class Zone(object): def find_rdataset(self, name, rdtype, covers=dns.rdatatype.NONE, create=False): - """Look for rdata with the specified name and type in the zone, + """Look for an rdataset with the specified name and type in the zone, and return an rdataset encapsulating it. - The I{name}, I{rdtype}, and I{covers} parameters may be - strings, in which case they will be converted to their proper - type. - The rdataset returned is not a copy; changes to it will change the zone. KeyError is raised if the name or type are not found. - Use L{get_rdataset} if you want to have None returned instead. - @param name: the owner name to look for - @type name: DNS.name.Name object or string - @param rdtype: the rdata type desired - @type rdtype: int or string - @param covers: the covered type (defaults to None) - @type covers: int or string - @param create: should the node and rdataset be created if they do not - exist? - @type create: bool - @raises KeyError: the node or rdata could not be found - @rtype: dns.rrset.RRset object + *name*: the name of the node to find. + The value may be a ``dns.name.Name`` or a ``str``. If absolute, the + name must be a subdomain of the zone's origin. If ``zone.relativize`` + is ``True``, then the name will be relativized. + + *rdtype*, an ``int`` or ``str``, the rdata type desired. + + *covers*, an ``int`` or ``str`` or ``None``, the covered type. + Usually this value is ``dns.rdatatype.NONE``, but if the + rdtype is ``dns.rdatatype.SIG`` or ``dns.rdatatype.RRSIG``, + then the covers value will be the rdata type the SIG/RRSIG + covers. The library treats the SIG and RRSIG types as if they + were a family of types, e.g. RRSIG(A), RRSIG(NS), RRSIG(SOA). + This makes RRSIGs much easier to work with than if RRSIGs + covering different rdata types were aggregated into a single + RRSIG rdataset. + + *create*, a ``bool``. If true, the node will be created if it does + not exist. + + Raises ``KeyError`` if the name is not known and create was + not specified, or if the name was not a subdomain of the origin. + + Returns a ``dns.rdataset.Rdataset``. """ name = self._validate_name(name) - if isinstance(rdtype, string_types): - rdtype = dns.rdatatype.from_text(rdtype) - if isinstance(covers, string_types): - covers = dns.rdatatype.from_text(covers) + rdtype = dns.rdatatype.RdataType.make(rdtype) + if covers is not None: + covers = dns.rdatatype.RdataType.make(covers) node = self.find_node(name, create) return node.find_rdataset(self.rdclass, rdtype, covers, create) def get_rdataset(self, name, rdtype, covers=dns.rdatatype.NONE, create=False): - """Look for rdata with the specified name and type in the zone, - and return an rdataset encapsulating it. + """Look for an rdataset with the specified name and type in the zone. - The I{name}, I{rdtype}, and I{covers} parameters may be - strings, in which case they will be converted to their proper - type. + This method is like ``find_rdataset()``, except it returns None instead + of raising an exception if the rdataset does not exist and creation + has not been requested. The rdataset returned is not a copy; changes to it will change the zone. - None is returned if the name or type are not found. - Use L{find_rdataset} if you want to have KeyError raised instead. + *name*: the name of the node to find. + The value may be a ``dns.name.Name`` or a ``str``. If absolute, the + name must be a subdomain of the zone's origin. If ``zone.relativize`` + is ``True``, then the name will be relativized. - @param name: the owner name to look for - @type name: DNS.name.Name object or string - @param rdtype: the rdata type desired - @type rdtype: int or string - @param covers: the covered type (defaults to None) - @type covers: int or string - @param create: should the node and rdataset be created if they do not - exist? - @type create: bool - @rtype: dns.rrset.RRset object + *rdtype*, an ``int`` or ``str``, the rdata type desired. + + *covers*, an ``int`` or ``str`` or ``None``, the covered type. + Usually this value is ``dns.rdatatype.NONE``, but if the + rdtype is ``dns.rdatatype.SIG`` or ``dns.rdatatype.RRSIG``, + then the covers value will be the rdata type the SIG/RRSIG + covers. The library treats the SIG and RRSIG types as if they + were a family of types, e.g. RRSIG(A), RRSIG(NS), RRSIG(SOA). + This makes RRSIGs much easier to work with than if RRSIGs + covering different rdata types were aggregated into a single + RRSIG rdataset. + + *create*, a ``bool``. If true, the node will be created if it does + not exist. + + Raises ``KeyError`` if the name is not known and create was + not specified, or if the name was not a subdomain of the origin. + + Returns a ``dns.rdataset.Rdataset`` or ``None``. """ try: @@ -296,12 +320,8 @@ class Zone(object): return rdataset def delete_rdataset(self, name, rdtype, covers=dns.rdatatype.NONE): - """Delete the rdataset matching I{rdtype} and I{covers}, if it - exists at the node specified by I{name}. - - The I{name}, I{rdtype}, and I{covers} parameters may be - strings, in which case they will be converted to their proper - type. + """Delete the rdataset matching *rdtype* and *covers*, if it + exists at the node specified by *name*. It is not an error if the node does not exist, or if there is no matching rdataset at the node. @@ -309,19 +329,28 @@ class Zone(object): If the node has no rdatasets after the deletion, it will itself be deleted. - @param name: the owner name to look for - @type name: DNS.name.Name object or string - @param rdtype: the rdata type desired - @type rdtype: int or string - @param covers: the covered type (defaults to None) - @type covers: int or string + *name*: the name of the node to find. + The value may be a ``dns.name.Name`` or a ``str``. If absolute, the + name must be a subdomain of the zone's origin. If ``zone.relativize`` + is ``True``, then the name will be relativized. + + *rdtype*, an ``int`` or ``str``, the rdata type desired. + + *covers*, an ``int`` or ``str`` or ``None``, the covered type. + Usually this value is ``dns.rdatatype.NONE``, but if the + rdtype is ``dns.rdatatype.SIG`` or ``dns.rdatatype.RRSIG``, + then the covers value will be the rdata type the SIG/RRSIG + covers. The library treats the SIG and RRSIG types as if they + were a family of types, e.g. RRSIG(A), RRSIG(NS), RRSIG(SOA). + This makes RRSIGs much easier to work with than if RRSIGs + covering different rdata types were aggregated into a single + RRSIG rdataset. """ name = self._validate_name(name) - if isinstance(rdtype, string_types): - rdtype = dns.rdatatype.from_text(rdtype) - if isinstance(covers, string_types): - covers = dns.rdatatype.from_text(covers) + rdtype = dns.rdatatype.RdataType.make(rdtype) + if covers is not None: + covers = dns.rdatatype.RdataType.make(covers) node = self.get_node(name) if node is not None: node.delete_rdataset(self.rdclass, rdtype, covers) @@ -333,16 +362,18 @@ class Zone(object): It is not an error if there is no rdataset matching I{replacement}. - Ownership of the I{replacement} object is transferred to the zone; - in other words, this method does not store a copy of I{replacement} - at the node, it stores I{replacement} itself. + Ownership of the *replacement* object is transferred to the zone; + in other words, this method does not store a copy of *replacement* + at the node, it stores *replacement* itself. - If the I{name} node does not exist, it is created. + If the node does not exist, it is created. - @param name: the owner name - @type name: DNS.name.Name object or string - @param replacement: the replacement rdataset - @type replacement: dns.rdataset.Rdataset + *name*: the name of the node to find. + The value may be a ``dns.name.Name`` or a ``str``. If absolute, the + name must be a subdomain of the zone's origin. If ``zone.relativize`` + is ``True``, then the name will be relativized. + + *replacement*, a ``dns.rdataset.Rdataset``, the replacement rdataset. """ if replacement.rdclass != self.rdclass: @@ -351,71 +382,89 @@ class Zone(object): node.replace_rdataset(replacement) def find_rrset(self, name, rdtype, covers=dns.rdatatype.NONE): - """Look for rdata with the specified name and type in the zone, + """Look for an rdataset with the specified name and type in the zone, and return an RRset encapsulating it. - The I{name}, I{rdtype}, and I{covers} parameters may be - strings, in which case they will be converted to their proper - type. - This method is less efficient than the similar - L{find_rdataset} because it creates an RRset instead of + ``find_rdataset()`` because it creates an RRset instead of returning the matching rdataset. It may be more convenient for some uses since it returns an object which binds the owner - name to the rdata. + name to the rdataset. This method may not be used to create new nodes or rdatasets; - use L{find_rdataset} instead. + use ``find_rdataset`` instead. - KeyError is raised if the name or type are not found. - Use L{get_rrset} if you want to have None returned instead. + *name*: the name of the node to find. + The value may be a ``dns.name.Name`` or a ``str``. If absolute, the + name must be a subdomain of the zone's origin. If ``zone.relativize`` + is ``True``, then the name will be relativized. - @param name: the owner name to look for - @type name: DNS.name.Name object or string - @param rdtype: the rdata type desired - @type rdtype: int or string - @param covers: the covered type (defaults to None) - @type covers: int or string - @raises KeyError: the node or rdata could not be found - @rtype: dns.rrset.RRset object + *rdtype*, an ``int`` or ``str``, the rdata type desired. + + *covers*, an ``int`` or ``str`` or ``None``, the covered type. + Usually this value is ``dns.rdatatype.NONE``, but if the + rdtype is ``dns.rdatatype.SIG`` or ``dns.rdatatype.RRSIG``, + then the covers value will be the rdata type the SIG/RRSIG + covers. The library treats the SIG and RRSIG types as if they + were a family of types, e.g. RRSIG(A), RRSIG(NS), RRSIG(SOA). + This makes RRSIGs much easier to work with than if RRSIGs + covering different rdata types were aggregated into a single + RRSIG rdataset. + + *create*, a ``bool``. If true, the node will be created if it does + not exist. + + Raises ``KeyError`` if the name is not known and create was + not specified, or if the name was not a subdomain of the origin. + + Returns a ``dns.rrset.RRset`` or ``None``. """ name = self._validate_name(name) - if isinstance(rdtype, string_types): - rdtype = dns.rdatatype.from_text(rdtype) - if isinstance(covers, string_types): - covers = dns.rdatatype.from_text(covers) + rdtype = dns.rdatatype.RdataType.make(rdtype) + if covers is not None: + covers = dns.rdatatype.RdataType.make(covers) rdataset = self.nodes[name].find_rdataset(self.rdclass, rdtype, covers) rrset = dns.rrset.RRset(name, self.rdclass, rdtype, covers) rrset.update(rdataset) return rrset def get_rrset(self, name, rdtype, covers=dns.rdatatype.NONE): - """Look for rdata with the specified name and type in the zone, + """Look for an rdataset with the specified name and type in the zone, and return an RRset encapsulating it. - The I{name}, I{rdtype}, and I{covers} parameters may be - strings, in which case they will be converted to their proper - type. - - This method is less efficient than the similar L{get_rdataset} + This method is less efficient than the similar ``get_rdataset()`` because it creates an RRset instead of returning the matching rdataset. It may be more convenient for some uses since it - returns an object which binds the owner name to the rdata. + returns an object which binds the owner name to the rdataset. This method may not be used to create new nodes or rdatasets; - use L{find_rdataset} instead. + use ``get_rdataset()`` instead. - None is returned if the name or type are not found. - Use L{find_rrset} if you want to have KeyError raised instead. + *name*: the name of the node to find. + The value may be a ``dns.name.Name`` or a ``str``. If absolute, the + name must be a subdomain of the zone's origin. If ``zone.relativize`` + is ``True``, then the name will be relativized. - @param name: the owner name to look for - @type name: DNS.name.Name object or string - @param rdtype: the rdata type desired - @type rdtype: int or string - @param covers: the covered type (defaults to None) - @type covers: int or string - @rtype: dns.rrset.RRset object + *rdtype*, an ``int`` or ``str``, the rdata type desired. + + *covers*, an ``int`` or ``str`` or ``None``, the covered type. + Usually this value is ``dns.rdatatype.NONE``, but if the + rdtype is ``dns.rdatatype.SIG`` or ``dns.rdatatype.RRSIG``, + then the covers value will be the rdata type the SIG/RRSIG + covers. The library treats the SIG and RRSIG types as if they + were a family of types, e.g. RRSIG(A), RRSIG(NS), RRSIG(SOA). + This makes RRSIGs much easier to work with than if RRSIGs + covering different rdata types were aggregated into a single + RRSIG rdataset. + + *create*, a ``bool``. If true, the node will be created if it does + not exist. + + Raises ``KeyError`` if the name is not known and create was + not specified, or if the name was not a subdomain of the origin. + + Returns a ``dns.rrset.RRset`` or ``None``. """ try: @@ -427,21 +476,27 @@ class Zone(object): def iterate_rdatasets(self, rdtype=dns.rdatatype.ANY, covers=dns.rdatatype.NONE): """Return a generator which yields (name, rdataset) tuples for - all rdatasets in the zone which have the specified I{rdtype} - and I{covers}. If I{rdtype} is dns.rdatatype.ANY, the default, + all rdatasets in the zone which have the specified *rdtype* + and *covers*. If *rdtype* is ``dns.rdatatype.ANY``, the default, then all rdatasets will be matched. - @param rdtype: int or string - @type rdtype: int or string - @param covers: the covered type (defaults to None) - @type covers: int or string + *rdtype*, an ``int`` or ``str``, the rdata type desired. + + *covers*, an ``int`` or ``str`` or ``None``, the covered type. + Usually this value is ``dns.rdatatype.NONE``, but if the + rdtype is ``dns.rdatatype.SIG`` or ``dns.rdatatype.RRSIG``, + then the covers value will be the rdata type the SIG/RRSIG + covers. The library treats the SIG and RRSIG types as if they + were a family of types, e.g. RRSIG(A), RRSIG(NS), RRSIG(SOA). + This makes RRSIGs much easier to work with than if RRSIGs + covering different rdata types were aggregated into a single + RRSIG rdataset. """ - if isinstance(rdtype, string_types): - rdtype = dns.rdatatype.from_text(rdtype) - if isinstance(covers, string_types): - covers = dns.rdatatype.from_text(covers) - for (name, node) in self.iteritems(): + rdtype = dns.rdatatype.RdataType.make(rdtype) + if covers is not None: + covers = dns.rdatatype.RdataType.make(covers) + for (name, node) in self.items(): for rds in node: if rdtype == dns.rdatatype.ANY or \ (rds.rdtype == rdtype and rds.covers == covers): @@ -450,21 +505,27 @@ class Zone(object): def iterate_rdatas(self, rdtype=dns.rdatatype.ANY, covers=dns.rdatatype.NONE): """Return a generator which yields (name, ttl, rdata) tuples for - all rdatas in the zone which have the specified I{rdtype} - and I{covers}. If I{rdtype} is dns.rdatatype.ANY, the default, + all rdatas in the zone which have the specified *rdtype* + and *covers*. If *rdtype* is ``dns.rdatatype.ANY``, the default, then all rdatas will be matched. - @param rdtype: int or string - @type rdtype: int or string - @param covers: the covered type (defaults to None) - @type covers: int or string + *rdtype*, an ``int`` or ``str``, the rdata type desired. + + *covers*, an ``int`` or ``str`` or ``None``, the covered type. + Usually this value is ``dns.rdatatype.NONE``, but if the + rdtype is ``dns.rdatatype.SIG`` or ``dns.rdatatype.RRSIG``, + then the covers value will be the rdata type the SIG/RRSIG + covers. The library treats the SIG and RRSIG types as if they + were a family of types, e.g. RRSIG(A), RRSIG(NS), RRSIG(SOA). + This makes RRSIGs much easier to work with than if RRSIGs + covering different rdata types were aggregated into a single + RRSIG rdataset. """ - if isinstance(rdtype, string_types): - rdtype = dns.rdatatype.from_text(rdtype) - if isinstance(covers, string_types): - covers = dns.rdatatype.from_text(covers) - for (name, node) in self.iteritems(): + rdtype = dns.rdatatype.RdataType.make(rdtype) + if covers is not None: + covers = dns.rdatatype.RdataType.make(covers) + for (name, node) in self.items(): for rds in node: if rdtype == dns.rdatatype.ANY or \ (rds.rdtype == rdtype and rds.covers == covers): @@ -474,70 +535,83 @@ class Zone(object): def to_file(self, f, sorted=True, relativize=True, nl=None): """Write a zone to a file. - @param f: file or string. If I{f} is a string, it is treated + *f*, a file or `str`. If *f* is a string, it is treated as the name of a file to open. - @param sorted: if True, the file will be written with the - names sorted in DNSSEC order from least to greatest. Otherwise - the names will be written in whatever order they happen to have - in the zone's dictionary. - @param relativize: if True, domain names in the output will be - relativized to the zone's origin (if possible). - @type relativize: bool - @param nl: The end of line string. If not specified, the - output will use the platform's native end-of-line marker (i.e. - LF on POSIX, CRLF on Windows, CR on Macintosh). - @type nl: string or None + + *sorted*, a ``bool``. If True, the default, then the file + will be written with the names sorted in DNSSEC order from + least to greatest. Otherwise the names will be written in + whatever order they happen to have in the zone's dictionary. + + *relativize*, a ``bool``. If True, the default, then domain + names in the output will be relativized to the zone's origin + if possible. + + *nl*, a ``str`` or None. The end of line string. If not + ``None``, the output will use the platform's native + end-of-line marker (i.e. LF on POSIX, CRLF on Windows). """ - str_type = string_types + with contextlib.ExitStack() as stack: + if isinstance(f, str): + f = stack.enter_context(open(f, 'wb')) - if nl is None: - opts = 'wb' - else: - opts = 'wb' + # must be in this way, f.encoding may contain None, or even + # attribute may not be there + file_enc = getattr(f, 'encoding', None) + if file_enc is None: + file_enc = 'utf-8' + + if nl is None: + # binary mode, '\n' is not enough + nl_b = os.linesep.encode(file_enc) + nl = '\n' + elif isinstance(nl, str): + nl_b = nl.encode(file_enc) + else: + nl_b = nl + nl = nl.decode() - if isinstance(f, str_type): - f = open(f, opts) - want_close = True - else: - want_close = False - try: if sorted: names = list(self.keys()) names.sort() else: - names = self.iterkeys() + names = self.keys() for n in names: l = self[n].to_text(n, origin=self.origin, relativize=relativize) - if isinstance(l, text_type): - l = l.encode() - if nl is None: - f.write(l) - f.write('\n') + if isinstance(l, str): + l_b = l.encode(file_enc) else: + l_b = l + l = l.decode() + + try: + f.write(l_b) + f.write(nl_b) + except TypeError: # textual mode f.write(l) f.write(nl) - finally: - if want_close: - f.close() def to_text(self, sorted=True, relativize=True, nl=None): """Return a zone's text as though it were written to a file. - @param sorted: if True, the file will be written with the - names sorted in DNSSEC order from least to greatest. Otherwise - the names will be written in whatever order they happen to have - in the zone's dictionary. - @param relativize: if True, domain names in the output will be - relativized to the zone's origin (if possible). - @type relativize: bool - @param nl: The end of line string. If not specified, the - output will use the platform's native end-of-line marker (i.e. - LF on POSIX, CRLF on Windows, CR on Macintosh). - @type nl: string or None + *sorted*, a ``bool``. If True, the default, then the file + will be written with the names sorted in DNSSEC order from + least to greatest. Otherwise the names will be written in + whatever order they happen to have in the zone's dictionary. + + *relativize*, a ``bool``. If True, the default, then domain + names in the output will be relativized to the zone's origin + if possible. + + *nl*, a ``str`` or None. The end of line string. If not + ``None``, the output will use the platform's native + end-of-line marker (i.e. LF on POSIX, CRLF on Windows). + + Returns a ``str``. """ - temp_buffer = BytesIO() + temp_buffer = io.StringIO() self.to_file(temp_buffer, sorted, relativize, nl) return_value = temp_buffer.getvalue() temp_buffer.close() @@ -546,9 +620,11 @@ class Zone(object): def check_origin(self): """Do some simple checking of the zone's origin. - @raises dns.zone.NoSOA: there is no SOA RR - @raises dns.zone.NoNS: there is no NS RRset - @raises KeyError: there is no origin node + Raises ``dns.zone.NoSOA`` if there is no SOA RRset. + + Raises ``dns.zone.NoNS`` if there is no NS RRset. + + Raises ``KeyError`` if there is no origin node. """ if self.relativize: name = dns.name.empty @@ -560,14 +636,20 @@ class Zone(object): raise NoNS -class _MasterReader(object): +class _MasterReader: """Read a DNS master file @ivar tok: The tokenizer @type tok: dns.tokenizer.Tokenizer object - @ivar ttl: The default TTL - @type ttl: int + @ivar last_ttl: The last seen explicit TTL for an RR + @type last_ttl: int + @ivar last_ttl_known: Has last TTL been detected + @type last_ttl_known: bool + @ivar default_ttl: The default TTL from a $TTL directive or SOA RR + @type default_ttl: int + @ivar default_ttl_known: Has default TTL been detected + @type default_ttl_known: bool @ivar last_name: The last name read @type last_name: dns.name.Name object @ivar current_origin: The current origin @@ -577,8 +659,8 @@ class _MasterReader(object): @ivar zone: the zone @type zone: dns.zone.Zone object @ivar saved_state: saved reader state (used when processing $INCLUDE) - @type saved_state: list of (tokenizer, current_origin, last_name, file) - tuples. + @type saved_state: list of (tokenizer, current_origin, last_name, file, + last_ttl, last_ttl_known, default_ttl, default_ttl_known) tuples. @ivar current_file: the file object of the $INCLUDed file being parsed (None if no $INCLUDE is active). @ivar allow_include: is $INCLUDE allowed? @@ -590,12 +672,15 @@ class _MasterReader(object): def __init__(self, tok, origin, rdclass, relativize, zone_factory=Zone, allow_include=False, check_origin=True): - if isinstance(origin, string_types): + if isinstance(origin, str): origin = dns.name.from_text(origin) self.tok = tok self.current_origin = origin self.relativize = relativize - self.ttl = 0 + self.last_ttl = 0 + self.last_ttl_known = False + self.default_ttl = 0 + self.default_ttl_known = False self.last_name = self.current_origin self.zone = zone_factory(origin, rdclass, relativize=relativize) self.saved_state = [] @@ -616,8 +701,7 @@ class _MasterReader(object): raise UnknownOrigin token = self.tok.get(want_leading=True) if not token.is_whitespace(): - self.last_name = dns.name.from_text( - token.value, self.current_origin) + self.last_name = self.tok.as_name(token, self.current_origin) else: token = self.tok.get() if token.is_eol_or_eof(): @@ -633,14 +717,22 @@ class _MasterReader(object): token = self.tok.get() if not token.is_identifier(): raise dns.exception.SyntaxError + # TTL + ttl = None try: ttl = dns.ttl.from_text(token.value) + self.last_ttl = ttl + self.last_ttl_known = True token = self.tok.get() if not token.is_identifier(): raise dns.exception.SyntaxError except dns.ttl.BadTTL: - ttl = self.ttl + if self.default_ttl_known: + ttl = self.default_ttl + elif self.last_ttl_known: + ttl = self.last_ttl + # Class try: rdclass = dns.rdataclass.from_text(token.value) @@ -648,15 +740,15 @@ class _MasterReader(object): if not token.is_identifier(): raise dns.exception.SyntaxError except dns.exception.SyntaxError: - raise dns.exception.SyntaxError - except: + raise + except Exception: rdclass = self.zone.rdclass if rdclass != self.zone.rdclass: raise dns.exception.SyntaxError("RR class is not zone's class") # Type try: rdtype = dns.rdatatype.from_text(token.value) - except: + except Exception: raise dns.exception.SyntaxError( "unknown rdatatype '%s'" % token.value) n = self.zone.nodes.get(name) @@ -665,12 +757,12 @@ class _MasterReader(object): self.zone.nodes[name] = n try: rd = dns.rdata.from_text(rdclass, rdtype, self.tok, - self.current_origin, False) + self.current_origin, self.relativize, + self.zone.origin) except dns.exception.SyntaxError: # Catch and reraise. - (ty, va) = sys.exc_info()[:2] - raise va - except: + raise + except Exception: # All exceptions that occur in the processing of rdata # are treated as syntax errors. This is not strictly # correct, but it is correct almost all of the time. @@ -678,9 +770,23 @@ class _MasterReader(object): # helpful filename:line info. (ty, va) = sys.exc_info()[:2] raise dns.exception.SyntaxError( - "caught exception %s: %s" % (str(ty), str(va))) + "caught exception {}: {}".format(str(ty), str(va))) + + if not self.default_ttl_known and rdtype == dns.rdatatype.SOA: + # The pre-RFC2308 and pre-BIND9 behavior inherits the zone default + # TTL from the SOA minttl if no $TTL statement is present before the + # SOA is parsed. + self.default_ttl = rd.minimum + self.default_ttl_known = True + if ttl is None: + # if we didn't have a TTL on the SOA, set it! + ttl = rd.minimum + + # TTL check. We had to wait until now to do this as the SOA RR's + # own TTL can be inferred from its minimum. + if ttl is None: + raise dns.exception.SyntaxError("Missing default TTL value") - rd.choose_relativity(self.zone.origin, self.relativize) covers = rd.covers() rds = n.find_rdataset(rdclass, rdtype, covers, True) rds.add(rd, ttl) @@ -688,9 +794,9 @@ class _MasterReader(object): def _parse_modify(self, side): # Here we catch everything in '{' '}' in a group so we can replace it # with ''. - is_generate1 = re.compile("^.*\$({(\+|-?)(\d+),(\d+),(.)}).*$") - is_generate2 = re.compile("^.*\$({(\+|-?)(\d+)}).*$") - is_generate3 = re.compile("^.*\$({(\+|-?)(\d+),(\d+)}).*$") + is_generate1 = re.compile(r"^.*\$({(\+|-?)(\d+),(\d+),(.)}).*$") + is_generate2 = re.compile(r"^.*\$({(\+|-?)(\d+)}).*$") + is_generate3 = re.compile(r"^.*\$({(\+|-?)(\d+),(\d+)}).*$") # Sometimes there are modifiers in the hostname. These come after # the dollar sign. They are in the form: ${offset[,width[,base]]}. # Make names @@ -708,10 +814,9 @@ class _MasterReader(object): base = 'd' g3 = is_generate3.match(side) if g3: - mod, sign, offset, width = g1.groups() + mod, sign, offset, width = g3.groups() if sign == '': sign = '+' - width = g1.groups()[2] base = 'd' if not (g1 or g2 or g3): @@ -722,7 +827,7 @@ class _MasterReader(object): base = 'd' if base != 'd': - raise NotImplemented + raise NotImplementedError() return mod, sign, offset, width, base @@ -740,7 +845,7 @@ class _MasterReader(object): token = self.tok.get() if not token.is_identifier(): raise dns.exception.SyntaxError - except: + except Exception: raise dns.exception.SyntaxError # lhs (required) @@ -749,17 +854,24 @@ class _MasterReader(object): token = self.tok.get() if not token.is_identifier(): raise dns.exception.SyntaxError - except: + except Exception: raise dns.exception.SyntaxError # TTL try: ttl = dns.ttl.from_text(token.value) + self.last_ttl = ttl + self.last_ttl_known = True token = self.tok.get() if not token.is_identifier(): raise dns.exception.SyntaxError except dns.ttl.BadTTL: - ttl = self.ttl + if not (self.last_ttl_known or self.default_ttl_known): + raise dns.exception.SyntaxError("Missing default TTL value") + if self.default_ttl_known: + ttl = self.default_ttl + elif self.last_ttl_known: + ttl = self.last_ttl # Class try: rdclass = dns.rdataclass.from_text(token.value) @@ -768,7 +880,7 @@ class _MasterReader(object): raise dns.exception.SyntaxError except dns.exception.SyntaxError: raise dns.exception.SyntaxError - except: + except Exception: rdclass = self.zone.rdclass if rdclass != self.zone.rdclass: raise dns.exception.SyntaxError("RR class is not zone's class") @@ -778,38 +890,36 @@ class _MasterReader(object): token = self.tok.get() if not token.is_identifier(): raise dns.exception.SyntaxError - except: + except Exception: raise dns.exception.SyntaxError("unknown rdatatype '%s'" % token.value) - # lhs (required) - try: - rhs = token.value - except: - raise dns.exception.SyntaxError + # rhs (required) + rhs = token.value lmod, lsign, loffset, lwidth, lbase = self._parse_modify(lhs) rmod, rsign, roffset, rwidth, rbase = self._parse_modify(rhs) for i in range(start, stop + 1, step): # +1 because bind is inclusive and python is exclusive - if lsign == u'+': + if lsign == '+': lindex = i + int(loffset) - elif lsign == u'-': + elif lsign == '-': lindex = i - int(loffset) - if rsign == u'-': + if rsign == '-': rindex = i - int(roffset) - elif rsign == u'+': + elif rsign == '+': rindex = i + int(roffset) lzfindex = str(lindex).zfill(int(lwidth)) rzfindex = str(rindex).zfill(int(rwidth)) - name = lhs.replace(u'$%s' % (lmod), lzfindex) - rdata = rhs.replace(u'$%s' % (rmod), rzfindex) + name = lhs.replace('$%s' % (lmod), lzfindex) + rdata = rhs.replace('$%s' % (rmod), rzfindex) - self.last_name = dns.name.from_text(name, self.current_origin) + self.last_name = dns.name.from_text(name, self.current_origin, + self.tok.idna_codec) name = self.last_name if not name.is_subdomain(self.zone.origin): self._eat_line() @@ -823,12 +933,12 @@ class _MasterReader(object): self.zone.nodes[name] = n try: rd = dns.rdata.from_text(rdclass, rdtype, rdata, - self.current_origin, False) + self.current_origin, self.relativize, + self.zone.origin) except dns.exception.SyntaxError: # Catch and reraise. - (ty, va) = sys.exc_info()[:2] - raise va - except: + raise + except Exception: # All exceptions that occur in the processing of rdata # are treated as syntax errors. This is not strictly # correct, but it is correct almost all of the time. @@ -838,7 +948,6 @@ class _MasterReader(object): raise dns.exception.SyntaxError("caught exception %s: %s" % (str(ty), str(va))) - rd.choose_relativity(self.zone.origin, self.relativize) covers = rd.covers() rds = n.find_rdataset(rdclass, rdtype, covers, True) rds.add(rd, ttl) @@ -861,7 +970,10 @@ class _MasterReader(object): self.current_origin, self.last_name, self.current_file, - self.ttl) = self.saved_state.pop(-1) + self.last_ttl, + self.last_ttl_known, + self.default_ttl, + self.default_ttl_known) = self.saved_state.pop(-1) continue break elif token.is_eol(): @@ -869,27 +981,29 @@ class _MasterReader(object): elif token.is_comment(): self.tok.get_eol() continue - elif token.value[0] == u'$': + elif token.value[0] == '$': c = token.value.upper() - if c == u'$TTL': + if c == '$TTL': token = self.tok.get() if not token.is_identifier(): raise dns.exception.SyntaxError("bad $TTL") - self.ttl = dns.ttl.from_text(token.value) + self.default_ttl = dns.ttl.from_text(token.value) + self.default_ttl_known = True self.tok.get_eol() - elif c == u'$ORIGIN': + elif c == '$ORIGIN': self.current_origin = self.tok.get_name() self.tok.get_eol() if self.zone.origin is None: self.zone.origin = self.current_origin - elif c == u'$INCLUDE' and self.allow_include: + elif c == '$INCLUDE' and self.allow_include: token = self.tok.get() filename = token.value token = self.tok.get() if token.is_identifier(): new_origin =\ dns.name.from_text(token.value, - self.current_origin) + self.current_origin, + self.tok.idna_codec) self.tok.get_eol() elif not token.is_eol_or_eof(): raise dns.exception.SyntaxError( @@ -900,12 +1014,15 @@ class _MasterReader(object): self.current_origin, self.last_name, self.current_file, - self.ttl)) + self.last_ttl, + self.last_ttl_known, + self.default_ttl, + self.default_ttl_known)) self.current_file = open(filename, 'r') self.tok = dns.tokenizer.Tokenizer(self.current_file, filename) self.current_origin = new_origin - elif c == u'$GENERATE': + elif c == '$GENERATE': self._generate_line() else: raise dns.exception.SyntaxError( @@ -917,8 +1034,10 @@ class _MasterReader(object): (filename, line_number) = self.tok.where() if detail is None: detail = "syntax error" - raise dns.exception.SyntaxError( + ex = dns.exception.SyntaxError( "%s:%d: %s" % (filename, line_number, detail)) + tb = sys.exc_info()[2] + raise ex.with_traceback(tb) from None # Now that we're done reading, do some basic checking of the zone. if self.check_origin: @@ -927,32 +1046,46 @@ class _MasterReader(object): def from_text(text, origin=None, rdclass=dns.rdataclass.IN, relativize=True, zone_factory=Zone, filename=None, - allow_include=False, check_origin=True): + allow_include=False, check_origin=True, idna_codec=None): """Build a zone object from a master file format string. - @param text: the master file format input - @type text: string. - @param origin: The origin of the zone; if not specified, the first - $ORIGIN statement in the master file will determine the origin of the - zone. - @type origin: dns.name.Name object or string - @param rdclass: The zone's rdata class; the default is class IN. - @type rdclass: int - @param relativize: should names be relativized? The default is True - @type relativize: bool - @param zone_factory: The zone factory to use - @type zone_factory: function returning a Zone - @param filename: The filename to emit when describing where an error - occurred; the default is ''. - @type filename: string - @param allow_include: is $INCLUDE allowed? - @type allow_include: bool - @param check_origin: should sanity checks of the origin node be done? - The default is True. - @type check_origin: bool - @raises dns.zone.NoSOA: No SOA RR was found at the zone origin - @raises dns.zone.NoNS: No NS RRset was found at the zone origin - @rtype: dns.zone.Zone object + *text*, a ``str``, the master file format input. + + *origin*, a ``dns.name.Name``, a ``str``, or ``None``. The origin + of the zone; if not specified, the first ``$ORIGIN`` statement in the + masterfile will determine the origin of the zone. + + *rdclass*, an ``int``, the zone's rdata class; the default is class IN. + + *relativize*, a ``bool``, determine's whether domain names are + relativized to the zone's origin. The default is ``True``. + + *zone_factory*, the zone factory to use or ``None``. If ``None``, then + ``dns.zone.Zone`` will be used. The value may be any class or callable + that returns a subclass of ``dns.zone.Zone``. + + *filename*, a ``str`` or ``None``, the filename to emit when + describing where an error occurred; the default is ``''``. + + *allow_include*, a ``bool``. If ``True``, the default, then ``$INCLUDE`` + directives are permitted. If ``False``, then encoutering a ``$INCLUDE`` + will raise a ``SyntaxError`` exception. + + *check_origin*, a ``bool``. If ``True``, the default, then sanity + checks of the origin node will be made by calling the zone's + ``check_origin()`` method. + + *idna_codec*, a ``dns.name.IDNACodec``, specifies the IDNA + encoder/decoder. If ``None``, the default IDNA 2003 encoder/decoder + is used. + + Raises ``dns.zone.NoSOA`` if there is no SOA RRset. + + Raises ``dns.zone.NoNS`` if there is no NS RRset. + + Raises ``KeyError`` if there is no origin node. + + Returns a subclass of ``dns.zone.Zone``. """ # 'text' can also be a file, but we don't publish that fact @@ -961,7 +1094,7 @@ def from_text(text, origin=None, rdclass=dns.rdataclass.IN, if filename is None: filename = '' - tok = dns.tokenizer.Tokenizer(text, filename) + tok = dns.tokenizer.Tokenizer(text, filename, idna_codec=idna_codec) reader = _MasterReader(tok, origin, rdclass, relativize, zone_factory, allow_include=allow_include, check_origin=check_origin) @@ -974,69 +1107,77 @@ def from_file(f, origin=None, rdclass=dns.rdataclass.IN, allow_include=True, check_origin=True): """Read a master file and build a zone object. - @param f: file or string. If I{f} is a string, it is treated + *f*, a file or ``str``. If *f* is a string, it is treated as the name of a file to open. - @param origin: The origin of the zone; if not specified, the first - $ORIGIN statement in the master file will determine the origin of the - zone. - @type origin: dns.name.Name object or string - @param rdclass: The zone's rdata class; the default is class IN. - @type rdclass: int - @param relativize: should names be relativized? The default is True - @type relativize: bool - @param zone_factory: The zone factory to use - @type zone_factory: function returning a Zone - @param filename: The filename to emit when describing where an error - occurred; the default is '', or the value of I{f} if I{f} is a - string. - @type filename: string - @param allow_include: is $INCLUDE allowed? - @type allow_include: bool - @param check_origin: should sanity checks of the origin node be done? - The default is True. - @type check_origin: bool - @raises dns.zone.NoSOA: No SOA RR was found at the zone origin - @raises dns.zone.NoNS: No NS RRset was found at the zone origin - @rtype: dns.zone.Zone object + + *origin*, a ``dns.name.Name``, a ``str``, or ``None``. The origin + of the zone; if not specified, the first ``$ORIGIN`` statement in the + masterfile will determine the origin of the zone. + + *rdclass*, an ``int``, the zone's rdata class; the default is class IN. + + *relativize*, a ``bool``, determine's whether domain names are + relativized to the zone's origin. The default is ``True``. + + *zone_factory*, the zone factory to use or ``None``. If ``None``, then + ``dns.zone.Zone`` will be used. The value may be any class or callable + that returns a subclass of ``dns.zone.Zone``. + + *filename*, a ``str`` or ``None``, the filename to emit when + describing where an error occurred; the default is ``''``. + + *allow_include*, a ``bool``. If ``True``, the default, then ``$INCLUDE`` + directives are permitted. If ``False``, then encoutering a ``$INCLUDE`` + will raise a ``SyntaxError`` exception. + + *check_origin*, a ``bool``. If ``True``, the default, then sanity + checks of the origin node will be made by calling the zone's + ``check_origin()`` method. + + *idna_codec*, a ``dns.name.IDNACodec``, specifies the IDNA + encoder/decoder. If ``None``, the default IDNA 2003 encoder/decoder + is used. + + Raises ``dns.zone.NoSOA`` if there is no SOA RRset. + + Raises ``dns.zone.NoNS`` if there is no NS RRset. + + Raises ``KeyError`` if there is no origin node. + + Returns a subclass of ``dns.zone.Zone``. """ - str_type = string_types - opts = 'rU' - - if isinstance(f, str_type): - if filename is None: - filename = f - f = open(f, opts) - want_close = True - else: - if filename is None: - filename = '' - want_close = False - - try: - z = from_text(f, origin, rdclass, relativize, zone_factory, - filename, allow_include, check_origin) - finally: - if want_close: - f.close() - return z + with contextlib.ExitStack() as stack: + if isinstance(f, str): + if filename is None: + filename = f + f = stack.enter_context(open(f)) + return from_text(f, origin, rdclass, relativize, zone_factory, + filename, allow_include, check_origin) def from_xfr(xfr, zone_factory=Zone, relativize=True, check_origin=True): """Convert the output of a zone transfer generator into a zone object. - @param xfr: The xfr generator - @type xfr: generator of dns.message.Message objects - @param relativize: should names be relativized? The default is True. + *xfr*, a generator of ``dns.message.Message`` objects, typically + ``dns.query.xfr()``. + + *relativize*, a ``bool``, determine's whether domain names are + relativized to the zone's origin. The default is ``True``. It is essential that the relativize setting matches the one specified - to dns.query.xfr(). - @type relativize: bool - @param check_origin: should sanity checks of the origin node be done? - The default is True. - @type check_origin: bool - @raises dns.zone.NoSOA: No SOA RR was found at the zone origin - @raises dns.zone.NoNS: No NS RRset was found at the zone origin - @rtype: dns.zone.Zone object + to the generator. + + *check_origin*, a ``bool``. If ``True``, the default, then sanity + checks of the origin node will be made by calling the zone's + ``check_origin()`` method. + + Raises ``dns.zone.NoSOA`` if there is no SOA RRset. + + Raises ``dns.zone.NoNS`` if there is no NS RRset. + + Raises ``KeyError`` if there is no origin node. + + Returns a subclass of ``dns.zone.Zone``. """ z = None @@ -1057,7 +1198,6 @@ def from_xfr(xfr, zone_factory=Zone, relativize=True, check_origin=True): rrset.covers, True) zrds.update_ttl(rrset.ttl) for rd in rrset: - rd.choose_relativity(z.origin, relativize) zrds.add(rd) if check_origin: z.check_origin()