diff --git a/lib/dns/__init__.py b/lib/dns/__init__.py index 0473ca17..9abdf018 100644 --- a/lib/dns/__init__.py +++ b/lib/dns/__init__.py @@ -18,49 +18,52 @@ """dnspython DNS toolkit""" __all__ = [ - 'asyncbackend', - 'asyncquery', - 'asyncresolver', - 'dnssec', - 'e164', - 'edns', - 'entropy', - 'exception', - 'flags', - 'immutable', - 'inet', - 'ipv4', - 'ipv6', - 'message', - 'name', - 'namedict', - 'node', - 'opcode', - 'query', - 'rcode', - 'rdata', - 'rdataclass', - 'rdataset', - 'rdatatype', - 'renderer', - 'resolver', - 'reversename', - 'rrset', - 'serial', - 'set', - 'tokenizer', - 'transaction', - 'tsig', - 'tsigkeyring', - 'ttl', - 'rdtypes', - 'update', - 'version', - 'versioned', - 'wire', - 'xfr', - 'zone', - 'zonefile', + "asyncbackend", + "asyncquery", + "asyncresolver", + "dnssec", + "dnssectypes", + "e164", + "edns", + "entropy", + "exception", + "flags", + "immutable", + "inet", + "ipv4", + "ipv6", + "message", + "name", + "namedict", + "node", + "opcode", + "query", + "quic", + "rcode", + "rdata", + "rdataclass", + "rdataset", + "rdatatype", + "renderer", + "resolver", + "reversename", + "rrset", + "serial", + "set", + "tokenizer", + "transaction", + "tsig", + "tsigkeyring", + "ttl", + "rdtypes", + "update", + "version", + "versioned", + "wire", + "xfr", + "zone", + "zonetypes", + "zonefile", ] from dns.version import version as __version__ # noqa diff --git a/lib/dns/_asyncbackend.py b/lib/dns/_asyncbackend.py index 1f3a8287..ff24604f 100644 --- a/lib/dns/_asyncbackend.py +++ b/lib/dns/_asyncbackend.py @@ -3,6 +3,7 @@ # This is a nullcontext for both sync and async. 3.7 has a nullcontext, # but it is only for sync use. + class NullContext: def __init__(self, enter_result=None): self.enter_result = enter_result @@ -23,6 +24,7 @@ class NullContext: # These are declared here so backends can import them without creating # circular dependencies with dns.asyncbackend. + class Socket: # pragma: no cover async def close(self): pass @@ -41,6 +43,9 @@ class Socket: # pragma: no cover class DatagramSocket(Socket): # pragma: no cover + def __init__(self, family: int): + self.family = family + async def sendto(self, what, destination, timeout): raise NotImplementedError @@ -56,14 +61,25 @@ class StreamSocket(Socket): # pragma: no cover raise NotImplementedError -class Backend: # pragma: no cover +class Backend: # pragma: no cover def name(self): - return 'unknown' + return "unknown" - async def make_socket(self, af, socktype, proto=0, - source=None, destination=None, timeout=None, - ssl_context=None, server_hostname=None): + async def make_socket( + self, + af, + socktype, + proto=0, + source=None, + destination=None, + timeout=None, + ssl_context=None, + server_hostname=None, + ): raise NotImplementedError def datagram_connection_required(self): return False + + async def sleep(self, interval): + raise NotImplementedError diff --git a/lib/dns/_asyncio_backend.py b/lib/dns/_asyncio_backend.py index d737d13c..82a06249 100644 --- a/lib/dns/_asyncio_backend.py +++ b/lib/dns/_asyncio_backend.py @@ -10,7 +10,8 @@ import dns._asyncbackend import dns.exception -_is_win32 = sys.platform == 'win32' +_is_win32 = sys.platform == "win32" + def _get_running_loop(): try: @@ -30,7 +31,6 @@ class _DatagramProtocol: def datagram_received(self, data, addr): if self.recvfrom and not self.recvfrom.done(): self.recvfrom.set_result((data, addr)) - self.recvfrom = None def error_received(self, exc): # pragma: no cover if self.recvfrom and not self.recvfrom.done(): @@ -56,30 +56,34 @@ async def _maybe_wait_for(awaitable, timeout): class DatagramSocket(dns._asyncbackend.DatagramSocket): def __init__(self, family, transport, protocol): - self.family = family + super().__init__(family) self.transport = transport self.protocol = protocol async def sendto(self, what, destination, timeout): # pragma: no cover # no timeout for asyncio sendto self.transport.sendto(what, destination) + return len(what) async def recvfrom(self, size, timeout): # ignore size as there's no way I know to tell protocol about it done = _get_running_loop().create_future() - assert self.protocol.recvfrom is None - self.protocol.recvfrom = done - await _maybe_wait_for(done, timeout) - return done.result() + try: + assert self.protocol.recvfrom is None + self.protocol.recvfrom = done + await _maybe_wait_for(done, timeout) + return done.result() + finally: + self.protocol.recvfrom = None async def close(self): self.protocol.close() async def getpeername(self): - return self.transport.get_extra_info('peername') + return self.transport.get_extra_info("peername") async def getsockname(self): - return self.transport.get_extra_info('sockname') + return self.transport.get_extra_info("sockname") class StreamSocket(dns._asyncbackend.StreamSocket): @@ -93,8 +97,7 @@ class StreamSocket(dns._asyncbackend.StreamSocket): return await _maybe_wait_for(self.writer.drain(), timeout) async def recv(self, size, timeout): - return await _maybe_wait_for(self.reader.read(size), - timeout) + return await _maybe_wait_for(self.reader.read(size), timeout) async def close(self): self.writer.close() @@ -104,43 +107,64 @@ class StreamSocket(dns._asyncbackend.StreamSocket): pass async def getpeername(self): - return self.writer.get_extra_info('peername') + return self.writer.get_extra_info("peername") async def getsockname(self): - return self.writer.get_extra_info('sockname') + return self.writer.get_extra_info("sockname") class Backend(dns._asyncbackend.Backend): def name(self): - return 'asyncio' + return "asyncio" - async def make_socket(self, af, socktype, proto=0, - source=None, destination=None, timeout=None, - ssl_context=None, server_hostname=None): - if destination is None and socktype == socket.SOCK_DGRAM and \ - _is_win32: - raise NotImplementedError('destinationless datagram sockets ' - 'are not supported by asyncio ' - 'on Windows') + async def make_socket( + self, + af, + socktype, + proto=0, + source=None, + destination=None, + timeout=None, + ssl_context=None, + server_hostname=None, + ): + if destination is None and socktype == socket.SOCK_DGRAM and _is_win32: + raise NotImplementedError( + "destinationless datagram sockets " + "are not supported by asyncio " + "on Windows" + ) loop = _get_running_loop() if socktype == socket.SOCK_DGRAM: transport, protocol = await loop.create_datagram_endpoint( - _DatagramProtocol, source, family=af, - proto=proto, remote_addr=destination) + _DatagramProtocol, + source, + family=af, + proto=proto, + remote_addr=destination, + ) return DatagramSocket(af, transport, protocol) elif socktype == socket.SOCK_STREAM: + if destination is None: + # This shouldn't happen, but we check to make code analysis software + # happier. + raise ValueError("destination required for stream sockets") (r, w) = await _maybe_wait_for( - asyncio.open_connection(destination[0], - destination[1], - ssl=ssl_context, - family=af, - proto=proto, - local_addr=source, - server_hostname=server_hostname), - timeout) + asyncio.open_connection( + destination[0], + destination[1], + ssl=ssl_context, + family=af, + proto=proto, + local_addr=source, + server_hostname=server_hostname, + ), + timeout, + ) return StreamSocket(af, r, w) - raise NotImplementedError('unsupported socket ' + - f'type {socktype}') # pragma: no cover + raise NotImplementedError( + "unsupported socket " + f"type {socktype}" + ) # pragma: no cover async def sleep(self, interval): await asyncio.sleep(interval) diff --git a/lib/dns/_curio_backend.py b/lib/dns/_curio_backend.py index 6fa7b3a1..765d6471 100644 --- a/lib/dns/_curio_backend.py +++ b/lib/dns/_curio_backend.py @@ -26,18 +26,20 @@ _lltuple = dns.inet.low_level_address_tuple class DatagramSocket(dns._asyncbackend.DatagramSocket): def __init__(self, socket): + super().__init__(socket.family) self.socket = socket - self.family = socket.family async def sendto(self, what, destination, timeout): async with _maybe_timeout(timeout): return await self.socket.sendto(what, destination) - raise dns.exception.Timeout(timeout=timeout) # pragma: no cover + raise dns.exception.Timeout( + timeout=timeout + ) # pragma: no cover lgtm[py/unreachable-statement] async def recvfrom(self, size, timeout): async with _maybe_timeout(timeout): return await self.socket.recvfrom(size) - raise dns.exception.Timeout(timeout=timeout) + raise dns.exception.Timeout(timeout=timeout) # lgtm[py/unreachable-statement] async def close(self): await self.socket.close() @@ -57,12 +59,12 @@ class StreamSocket(dns._asyncbackend.StreamSocket): async def sendall(self, what, timeout): async with _maybe_timeout(timeout): return await self.socket.sendall(what) - raise dns.exception.Timeout(timeout=timeout) + raise dns.exception.Timeout(timeout=timeout) # lgtm[py/unreachable-statement] async def recv(self, size, timeout): async with _maybe_timeout(timeout): return await self.socket.recv(size) - raise dns.exception.Timeout(timeout=timeout) + raise dns.exception.Timeout(timeout=timeout) # lgtm[py/unreachable-statement] async def close(self): await self.socket.close() @@ -76,11 +78,19 @@ class StreamSocket(dns._asyncbackend.StreamSocket): class Backend(dns._asyncbackend.Backend): def name(self): - return 'curio' + return "curio" - async def make_socket(self, af, socktype, proto=0, - source=None, destination=None, timeout=None, - ssl_context=None, server_hostname=None): + async def make_socket( + self, + af, + socktype, + proto=0, + source=None, + destination=None, + timeout=None, + ssl_context=None, + server_hostname=None, + ): if socktype == socket.SOCK_DGRAM: s = curio.socket.socket(af, socktype, proto) try: @@ -96,13 +106,17 @@ class Backend(dns._asyncbackend.Backend): else: source_addr = None async with _maybe_timeout(timeout): - s = await curio.open_connection(destination[0], destination[1], - ssl=ssl_context, - source_addr=source_addr, - server_hostname=server_hostname) + s = await curio.open_connection( + destination[0], + destination[1], + ssl=ssl_context, + source_addr=source_addr, + server_hostname=server_hostname, + ) return StreamSocket(s) - raise NotImplementedError('unsupported socket ' + - f'type {socktype}') # pragma: no cover + raise NotImplementedError( + "unsupported socket " + f"type {socktype}" + ) # pragma: no cover async def sleep(self, interval): await curio.sleep(interval) diff --git a/lib/dns/_immutable_attr.py b/lib/dns/_immutable_attr.py deleted file mode 100644 index 4d89be90..00000000 --- a/lib/dns/_immutable_attr.py +++ /dev/null @@ -1,84 +0,0 @@ -# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license - -# This implementation of the immutable decorator is for python 3.6, -# which doesn't have Context Variables. This implementation is somewhat -# costly for classes with slots, as it adds a __dict__ to them. - - -import inspect - - -class _Immutable: - """Immutable mixin class""" - - # Note we MUST NOT have __slots__ as that causes - # - # TypeError: multiple bases have instance lay-out conflict - # - # when we get mixed in with another class with slots. When we - # get mixed into something with slots, it effectively adds __dict__ to - # the slots of the other class, which allows attribute setting to work, - # albeit at the cost of the dictionary. - - def __setattr__(self, name, value): - if not hasattr(self, '_immutable_init') or \ - self._immutable_init is not self: - raise TypeError("object doesn't support attribute assignment") - else: - super().__setattr__(name, value) - - def __delattr__(self, name): - if not hasattr(self, '_immutable_init') or \ - self._immutable_init is not self: - raise TypeError("object doesn't support attribute assignment") - else: - super().__delattr__(name) - - -def _immutable_init(f): - def nf(*args, **kwargs): - try: - # Are we already initializing an immutable class? - previous = args[0]._immutable_init - except AttributeError: - # We are the first! - previous = None - object.__setattr__(args[0], '_immutable_init', args[0]) - try: - # call the actual __init__ - f(*args, **kwargs) - finally: - if not previous: - # If we started the initialization, establish immutability - # by removing the attribute that allows mutation - object.__delattr__(args[0], '_immutable_init') - nf.__signature__ = inspect.signature(f) - return nf - - -def immutable(cls): - if _Immutable in cls.__mro__: - # Some ancestor already has the mixin, so just make sure we keep - # following the __init__ protocol. - cls.__init__ = _immutable_init(cls.__init__) - if hasattr(cls, '__setstate__'): - cls.__setstate__ = _immutable_init(cls.__setstate__) - ncls = cls - else: - # Mixin the Immutable class and follow the __init__ protocol. - class ncls(_Immutable, cls): - - @_immutable_init - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - if hasattr(cls, '__setstate__'): - @_immutable_init - def __setstate__(self, *args, **kwargs): - super().__setstate__(*args, **kwargs) - - # make ncls have the same name and module as cls - ncls.__name__ = cls.__name__ - ncls.__qualname__ = cls.__qualname__ - ncls.__module__ = cls.__module__ - return ncls diff --git a/lib/dns/_immutable_ctx.py b/lib/dns/_immutable_ctx.py index ececdbeb..63c0a2d3 100644 --- a/lib/dns/_immutable_ctx.py +++ b/lib/dns/_immutable_ctx.py @@ -8,7 +8,7 @@ import contextvars import inspect -_in__init__ = contextvars.ContextVar('_immutable_in__init__', default=False) +_in__init__ = contextvars.ContextVar("_immutable_in__init__", default=False) class _Immutable: @@ -41,6 +41,7 @@ def _immutable_init(f): f(*args, **kwargs) finally: _in__init__.reset(previous) + nf.__signature__ = inspect.signature(f) return nf @@ -50,7 +51,7 @@ def immutable(cls): # Some ancestor already has the mixin, so just make sure we keep # following the __init__ protocol. cls.__init__ = _immutable_init(cls.__init__) - if hasattr(cls, '__setstate__'): + if hasattr(cls, "__setstate__"): cls.__setstate__ = _immutable_init(cls.__setstate__) ncls = cls else: @@ -63,7 +64,8 @@ def immutable(cls): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - if hasattr(cls, '__setstate__'): + if hasattr(cls, "__setstate__"): + @_immutable_init def __setstate__(self, *args, **kwargs): super().__setstate__(*args, **kwargs) diff --git a/lib/dns/_trio_backend.py b/lib/dns/_trio_backend.py index a00d4a4e..b0c02103 100644 --- a/lib/dns/_trio_backend.py +++ b/lib/dns/_trio_backend.py @@ -26,18 +26,20 @@ _lltuple = dns.inet.low_level_address_tuple class DatagramSocket(dns._asyncbackend.DatagramSocket): def __init__(self, socket): + super().__init__(socket.family) self.socket = socket - self.family = socket.family async def sendto(self, what, destination, timeout): with _maybe_timeout(timeout): return await self.socket.sendto(what, destination) - raise dns.exception.Timeout(timeout=timeout) # pragma: no cover + raise dns.exception.Timeout( + timeout=timeout + ) # pragma: no cover lgtm[py/unreachable-statement] async def recvfrom(self, size, timeout): with _maybe_timeout(timeout): return await self.socket.recvfrom(size) - raise dns.exception.Timeout(timeout=timeout) + raise dns.exception.Timeout(timeout=timeout) # lgtm[py/unreachable-statement] async def close(self): self.socket.close() @@ -58,12 +60,12 @@ class StreamSocket(dns._asyncbackend.StreamSocket): async def sendall(self, what, timeout): with _maybe_timeout(timeout): return await self.stream.send_all(what) - raise dns.exception.Timeout(timeout=timeout) + raise dns.exception.Timeout(timeout=timeout) # lgtm[py/unreachable-statement] async def recv(self, size, timeout): with _maybe_timeout(timeout): return await self.stream.receive_some(size) - raise dns.exception.Timeout(timeout=timeout) + raise dns.exception.Timeout(timeout=timeout) # lgtm[py/unreachable-statement] async def close(self): await self.stream.aclose() @@ -83,11 +85,19 @@ class StreamSocket(dns._asyncbackend.StreamSocket): class Backend(dns._asyncbackend.Backend): def name(self): - return 'trio' + return "trio" - async def make_socket(self, af, socktype, proto=0, source=None, - destination=None, timeout=None, - ssl_context=None, server_hostname=None): + async def make_socket( + self, + af, + socktype, + proto=0, + source=None, + destination=None, + timeout=None, + ssl_context=None, + server_hostname=None, + ): s = trio.socket.socket(af, socktype, proto) stream = None try: @@ -103,19 +113,20 @@ class Backend(dns._asyncbackend.Backend): return DatagramSocket(s) elif socktype == socket.SOCK_STREAM: stream = trio.SocketStream(s) - s = None tls = False if ssl_context: tls = True try: - stream = trio.SSLStream(stream, ssl_context, - server_hostname=server_hostname) + stream = trio.SSLStream( + stream, ssl_context, server_hostname=server_hostname + ) except Exception: # pragma: no cover await stream.aclose() raise return StreamSocket(af, stream, tls) - raise NotImplementedError('unsupported socket ' + - f'type {socktype}') # pragma: no cover + raise NotImplementedError( + "unsupported socket " + f"type {socktype}" + ) # pragma: no cover async def sleep(self, interval): await trio.sleep(interval) diff --git a/lib/dns/asyncbackend.py b/lib/dns/asyncbackend.py index 089d3d35..c7565a99 100644 --- a/lib/dns/asyncbackend.py +++ b/lib/dns/asyncbackend.py @@ -1,26 +1,33 @@ # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license +from typing import Dict + import dns.exception # pylint: disable=unused-import -from dns._asyncbackend import Socket, DatagramSocket, \ - StreamSocket, Backend # noqa: +from dns._asyncbackend import ( + Socket, + DatagramSocket, + StreamSocket, + Backend, +) # noqa: F401 lgtm[py/unused-import] # pylint: enable=unused-import _default_backend = None -_backends = {} +_backends: Dict[str, Backend] = {} # Allow sniffio import to be disabled for testing purposes _no_sniffio = False + class AsyncLibraryNotFoundError(dns.exception.DNSException): pass -def get_backend(name): +def get_backend(name: str) -> Backend: """Get the specified asynchronous backend. *name*, a ``str``, the name of the backend. Currently the "trio", @@ -32,22 +39,25 @@ def get_backend(name): backend = _backends.get(name) if backend: return backend - if name == 'trio': + if name == "trio": import dns._trio_backend + backend = dns._trio_backend.Backend() - elif name == 'curio': + elif name == "curio": import dns._curio_backend + backend = dns._curio_backend.Backend() - elif name == 'asyncio': + elif name == "asyncio": import dns._asyncio_backend + backend = dns._asyncio_backend.Backend() else: - raise NotImplementedError(f'unimplemented async backend {name}') + raise NotImplementedError(f"unimplemented async backend {name}") _backends[name] = backend return backend -def sniff(): +def sniff() -> str: """Attempt to determine the in-use asynchronous I/O library by using the ``sniffio`` module if it is available. @@ -59,35 +69,32 @@ def sniff(): if _no_sniffio: raise ImportError import sniffio + try: return sniffio.current_async_library() except sniffio.AsyncLibraryNotFoundError: - raise AsyncLibraryNotFoundError('sniffio cannot determine ' + - 'async library') + raise AsyncLibraryNotFoundError( + "sniffio cannot determine " + "async library" + ) except ImportError: import asyncio + try: asyncio.get_running_loop() - return 'asyncio' + return "asyncio" except RuntimeError: - raise AsyncLibraryNotFoundError('no async library detected') - except AttributeError: # pragma: no cover - # we have to check current_task on 3.6 - if not asyncio.Task.current_task(): - raise AsyncLibraryNotFoundError('no async library detected') - return 'asyncio' + raise AsyncLibraryNotFoundError("no async library detected") -def get_default_backend(): - """Get the default backend, initializing it if necessary. - """ +def get_default_backend() -> Backend: + """Get the default backend, initializing it if necessary.""" if _default_backend: return _default_backend return set_default_backend(sniff()) -def set_default_backend(name): +def set_default_backend(name: str) -> Backend: """Set the default backend. It's not normally necessary to call this method, as diff --git a/lib/dns/asyncbackend.pyi b/lib/dns/asyncbackend.pyi deleted file mode 100644 index 1ec9d32b..00000000 --- a/lib/dns/asyncbackend.pyi +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license - -class Backend: - ... - -def get_backend(name: str) -> Backend: - ... -def sniff() -> str: - ... -def get_default_backend() -> Backend: - ... -def set_default_backend(name: str) -> Backend: - ... diff --git a/lib/dns/asyncquery.py b/lib/dns/asyncquery.py index 4ec97fb7..459c611d 100644 --- a/lib/dns/asyncquery.py +++ b/lib/dns/asyncquery.py @@ -17,7 +17,10 @@ """Talk to a DNS server.""" +from typing import Any, Dict, Optional, Tuple, Union + import base64 +import contextlib import socket import struct import time @@ -27,12 +30,24 @@ import dns.exception import dns.inet import dns.name import dns.message +import dns.quic import dns.rcode import dns.rdataclass import dns.rdatatype +import dns.transaction -from dns.query import _compute_times, _matches_destination, BadResponse, ssl, \ - UDPMode, _have_httpx, _have_http2, NoDOH +from dns._asyncbackend import NullContext +from dns.query import ( + _compute_times, + _matches_destination, + BadResponse, + ssl, + UDPMode, + _have_httpx, + _have_http2, + NoDOH, + NoDOQ, +) if _have_httpx: import httpx @@ -47,11 +62,11 @@ def _source_tuple(af, address, port): if address or port: if address is None: if af == socket.AF_INET: - address = '0.0.0.0' + address = "0.0.0.0" elif af == socket.AF_INET6: - address = '::' + address = "::" else: - raise NotImplementedError(f'unknown address family {af}') + raise NotImplementedError(f"unknown address family {af}") return (address, port) else: return None @@ -66,7 +81,12 @@ def _timeout(expiration, now=None): return None -async def send_udp(sock, what, destination, expiration=None): +async def send_udp( + sock: dns.asyncbackend.DatagramSocket, + what: Union[dns.message.Message, bytes], + destination: Any, + expiration: Optional[float] = None, +) -> Tuple[int, float]: """Send a DNS message to the specified UDP socket. *sock*, a ``dns.asyncbackend.DatagramSocket``. @@ -78,7 +98,8 @@ async def send_udp(sock, what, destination, expiration=None): *expiration*, a ``float`` or ``None``, the absolute time at which a timeout exception should be raised. If ``None``, no timeout will - occur. + occur. The expiration value is meaningless for the asyncio backend, as + asyncio's transport sendto() never blocks. Returns an ``(int, float)`` tuple of bytes sent and the sent time. """ @@ -90,35 +111,61 @@ async def send_udp(sock, what, destination, expiration=None): return (n, sent_time) -async def receive_udp(sock, destination=None, expiration=None, - ignore_unexpected=False, one_rr_per_rrset=False, - keyring=None, request_mac=b'', ignore_trailing=False, - raise_on_truncation=False): +async def receive_udp( + sock: dns.asyncbackend.DatagramSocket, + destination: Optional[Any] = None, + expiration: Optional[float] = None, + ignore_unexpected: bool = False, + one_rr_per_rrset: bool = False, + keyring: Optional[Dict[dns.name.Name, dns.tsig.Key]] = None, + request_mac: Optional[bytes] = b"", + ignore_trailing: bool = False, + raise_on_truncation: bool = False, +) -> Any: """Read a DNS message from a UDP socket. *sock*, a ``dns.asyncbackend.DatagramSocket``. See :py:func:`dns.query.receive_udp()` for the documentation of the other - parameters, exceptions, and return type of this method. + parameters, and exceptions. + + Returns a ``(dns.message.Message, float, tuple)`` tuple of the received message, the + received time, and the address where the message arrived from. """ - wire = b'' + wire = b"" while 1: (wire, from_address) = await sock.recvfrom(65535, _timeout(expiration)) - if _matches_destination(sock.family, from_address, destination, - ignore_unexpected): + if _matches_destination( + sock.family, from_address, destination, ignore_unexpected + ): break received_time = time.time() - r = dns.message.from_wire(wire, keyring=keyring, request_mac=request_mac, - one_rr_per_rrset=one_rr_per_rrset, - ignore_trailing=ignore_trailing, - raise_on_truncation=raise_on_truncation) + r = dns.message.from_wire( + wire, + keyring=keyring, + request_mac=request_mac, + one_rr_per_rrset=one_rr_per_rrset, + ignore_trailing=ignore_trailing, + raise_on_truncation=raise_on_truncation, + ) return (r, received_time, from_address) -async def udp(q, where, timeout=None, port=53, source=None, source_port=0, - ignore_unexpected=False, one_rr_per_rrset=False, - ignore_trailing=False, raise_on_truncation=False, sock=None, - backend=None): + +async def udp( + q: dns.message.Message, + where: str, + timeout: Optional[float] = None, + port: int = 53, + source: Optional[str] = None, + source_port: int = 0, + ignore_unexpected: bool = False, + one_rr_per_rrset: bool = False, + ignore_trailing: bool = False, + raise_on_truncation: bool = False, + sock: Optional[dns.asyncbackend.DatagramSocket] = None, + backend: Optional[dns.asyncbackend.Backend] = None, +) -> dns.message.Message: """Return the response obtained after sending a query via UDP. *sock*, a ``dns.asyncbackend.DatagramSocket``, or ``None``, @@ -134,42 +181,52 @@ async def udp(q, where, timeout=None, port=53, source=None, source_port=0, """ wire = q.to_wire() (begin_time, expiration) = _compute_times(timeout) - s = None - # After 3.6 is no longer supported, this can use an AsyncExitStack. - try: - af = dns.inet.af_for_address(where) - destination = _lltuple((where, port), af) - if sock: - s = sock + af = dns.inet.af_for_address(where) + destination = _lltuple((where, port), af) + if sock: + cm: contextlib.AbstractAsyncContextManager = NullContext(sock) + else: + if not backend: + backend = dns.asyncbackend.get_default_backend() + stuple = _source_tuple(af, source, source_port) + if backend.datagram_connection_required(): + dtuple = (where, port) else: - if not backend: - backend = dns.asyncbackend.get_default_backend() - stuple = _source_tuple(af, source, source_port) - if backend.datagram_connection_required(): - dtuple = (where, port) - else: - dtuple = None - s = await backend.make_socket(af, socket.SOCK_DGRAM, 0, stuple, - dtuple) + dtuple = None + cm = await backend.make_socket(af, socket.SOCK_DGRAM, 0, stuple, dtuple) + async with cm as s: await send_udp(s, wire, destination, expiration) - (r, received_time, _) = await receive_udp(s, destination, expiration, - ignore_unexpected, - one_rr_per_rrset, - q.keyring, q.mac, - ignore_trailing, - raise_on_truncation) + (r, received_time, _) = await receive_udp( + s, + destination, + expiration, + ignore_unexpected, + one_rr_per_rrset, + q.keyring, + q.mac, + ignore_trailing, + raise_on_truncation, + ) r.time = received_time - begin_time if not q.is_response(r): raise BadResponse return r - finally: - if not sock and s: - await s.close() -async def udp_with_fallback(q, where, timeout=None, port=53, source=None, - source_port=0, ignore_unexpected=False, - one_rr_per_rrset=False, ignore_trailing=False, - udp_sock=None, tcp_sock=None, backend=None): + +async def udp_with_fallback( + q: dns.message.Message, + where: str, + timeout: Optional[float] = None, + port: int = 53, + source: Optional[str] = None, + source_port: int = 0, + ignore_unexpected: bool = False, + one_rr_per_rrset: bool = False, + ignore_trailing: bool = False, + udp_sock: Optional[dns.asyncbackend.DatagramSocket] = None, + tcp_sock: Optional[dns.asyncbackend.StreamSocket] = None, + backend: Optional[dns.asyncbackend.Backend] = None, +) -> Tuple[dns.message.Message, bool]: """Return the response to the query, trying UDP first and falling back to TCP if UDP results in a truncated response. @@ -191,18 +248,42 @@ async def udp_with_fallback(q, where, timeout=None, port=53, source=None, method. """ try: - response = await udp(q, where, timeout, port, source, source_port, - ignore_unexpected, one_rr_per_rrset, - ignore_trailing, True, udp_sock, backend) + response = await udp( + q, + where, + timeout, + port, + source, + source_port, + ignore_unexpected, + one_rr_per_rrset, + ignore_trailing, + True, + udp_sock, + backend, + ) return (response, False) except dns.message.Truncated: - response = await tcp(q, where, timeout, port, source, source_port, - one_rr_per_rrset, ignore_trailing, tcp_sock, - backend) + response = await tcp( + q, + where, + timeout, + port, + source, + source_port, + one_rr_per_rrset, + ignore_trailing, + tcp_sock, + backend, + ) return (response, True) -async def send_tcp(sock, what, expiration=None): +async def send_tcp( + sock: dns.asyncbackend.StreamSocket, + what: Union[dns.message.Message, bytes], + expiration: Optional[float] = None, +) -> Tuple[int, float]: """Send a DNS message to the specified TCP socket. *sock*, a ``dns.asyncbackend.StreamSocket``. @@ -212,12 +293,14 @@ async def send_tcp(sock, what, expiration=None): """ if isinstance(what, dns.message.Message): - what = what.to_wire() - l = len(what) + wire = what.to_wire() + else: + wire = what + l = len(wire) # copying the wire into tcpmsg is inefficient, but lets us # avoid writev() or doing a short write that would get pushed # onto the net - tcpmsg = struct.pack("!H", l) + what + tcpmsg = struct.pack("!H", l) + wire sent_time = time.time() await sock.sendall(tcpmsg, _timeout(expiration, sent_time)) return (len(tcpmsg), sent_time) @@ -227,18 +310,24 @@ async def _read_exactly(sock, count, expiration): """Read the specified number of bytes from stream. Keep trying until we either get the desired amount, or we hit EOF. """ - s = b'' + s = b"" while count > 0: n = await sock.recv(count, _timeout(expiration)) - if n == b'': + if n == b"": raise EOFError count = count - len(n) s = s + n return s -async def receive_tcp(sock, expiration=None, one_rr_per_rrset=False, - keyring=None, request_mac=b'', ignore_trailing=False): +async def receive_tcp( + sock: dns.asyncbackend.StreamSocket, + expiration: Optional[float] = None, + one_rr_per_rrset: bool = False, + keyring: Optional[Dict[dns.name.Name, dns.tsig.Key]] = None, + request_mac: Optional[bytes] = b"", + ignore_trailing: bool = False, +) -> Tuple[dns.message.Message, float]: """Read a DNS message from a TCP socket. *sock*, a ``dns.asyncbackend.StreamSocket``. @@ -251,15 +340,28 @@ async def receive_tcp(sock, expiration=None, one_rr_per_rrset=False, (l,) = struct.unpack("!H", ldata) wire = await _read_exactly(sock, l, expiration) received_time = time.time() - r = dns.message.from_wire(wire, keyring=keyring, request_mac=request_mac, - one_rr_per_rrset=one_rr_per_rrset, - ignore_trailing=ignore_trailing) + r = dns.message.from_wire( + wire, + keyring=keyring, + request_mac=request_mac, + one_rr_per_rrset=one_rr_per_rrset, + ignore_trailing=ignore_trailing, + ) return (r, received_time) -async def tcp(q, where, timeout=None, port=53, source=None, source_port=0, - one_rr_per_rrset=False, ignore_trailing=False, sock=None, - backend=None): +async def tcp( + q: dns.message.Message, + where: str, + timeout: Optional[float] = None, + port: int = 53, + source: Optional[str] = None, + source_port: int = 0, + one_rr_per_rrset: bool = False, + ignore_trailing: bool = False, + sock: Optional[dns.asyncbackend.StreamSocket] = None, + backend: Optional[dns.asyncbackend.Backend] = None, +) -> dns.message.Message: """Return the response obtained after sending a query via TCP. *sock*, a ``dns.asyncbacket.StreamSocket``, or ``None``, the @@ -276,41 +378,48 @@ async def tcp(q, where, timeout=None, port=53, source=None, source_port=0, wire = q.to_wire() (begin_time, expiration) = _compute_times(timeout) - s = None - # After 3.6 is no longer supported, this can use an AsyncExitStack. - try: - if sock: - # Verify that the socket is connected, as if it's not connected, - # it's not writable, and the polling in send_tcp() will time out or - # hang forever. - await sock.getpeername() - s = sock - else: - # These are simple (address, port) pairs, not - # family-dependent tuples you pass to lowlevel socket - # code. - af = dns.inet.af_for_address(where) - stuple = _source_tuple(af, source, source_port) - dtuple = (where, port) - if not backend: - backend = dns.asyncbackend.get_default_backend() - s = await backend.make_socket(af, socket.SOCK_STREAM, 0, stuple, - dtuple, timeout) + if sock: + # Verify that the socket is connected, as if it's not connected, + # it's not writable, and the polling in send_tcp() will time out or + # hang forever. + await sock.getpeername() + cm: contextlib.AbstractAsyncContextManager = NullContext(sock) + else: + # These are simple (address, port) pairs, not family-dependent tuples + # you pass to low-level socket code. + af = dns.inet.af_for_address(where) + stuple = _source_tuple(af, source, source_port) + dtuple = (where, port) + if not backend: + backend = dns.asyncbackend.get_default_backend() + cm = await backend.make_socket( + af, socket.SOCK_STREAM, 0, stuple, dtuple, timeout + ) + async with cm as s: await send_tcp(s, wire, expiration) - (r, received_time) = await receive_tcp(s, expiration, one_rr_per_rrset, - q.keyring, q.mac, - ignore_trailing) + (r, received_time) = await receive_tcp( + s, expiration, one_rr_per_rrset, q.keyring, q.mac, ignore_trailing + ) r.time = received_time - begin_time if not q.is_response(r): raise BadResponse return r - finally: - if not sock and s: - await s.close() -async def tls(q, where, timeout=None, port=853, source=None, source_port=0, - one_rr_per_rrset=False, ignore_trailing=False, sock=None, - backend=None, ssl_context=None, server_hostname=None): + +async def tls( + q: dns.message.Message, + where: str, + timeout: Optional[float] = None, + port: int = 853, + source: Optional[str] = None, + source_port: int = 0, + one_rr_per_rrset: bool = False, + ignore_trailing: bool = False, + sock: Optional[dns.asyncbackend.StreamSocket] = None, + backend: Optional[dns.asyncbackend.Backend] = None, + ssl_context: Optional[ssl.SSLContext] = None, + server_hostname: Optional[str] = None, +) -> dns.message.Message: """Return the response obtained after sending a query via TLS. *sock*, an ``asyncbackend.StreamSocket``, or ``None``, the socket @@ -326,11 +435,14 @@ async def tls(q, where, timeout=None, port=853, source=None, source_port=0, See :py:func:`dns.query.tls()` for the documentation of the other parameters, exceptions, and return type of this method. """ - # After 3.6 is no longer supported, this can use an AsyncExitStack. (begin_time, expiration) = _compute_times(timeout) - if not sock: + if sock: + cm: contextlib.AbstractAsyncContextManager = NullContext(sock) + else: if ssl_context is None: - ssl_context = ssl.create_default_context() + # See the comment about ssl.create_default_context() in query.py + ssl_context = ssl.create_default_context() # lgtm[py/insecure-protocol] + ssl_context.minimum_version = ssl.TLSVersion.TLSv1_2 if server_hostname is None: ssl_context.check_hostname = False else: @@ -341,25 +453,49 @@ async def tls(q, where, timeout=None, port=853, source=None, source_port=0, dtuple = (where, port) if not backend: backend = dns.asyncbackend.get_default_backend() - s = await backend.make_socket(af, socket.SOCK_STREAM, 0, stuple, - dtuple, timeout, ssl_context, - server_hostname) - else: - s = sock - try: + cm = await backend.make_socket( + af, + socket.SOCK_STREAM, + 0, + stuple, + dtuple, + timeout, + ssl_context, + server_hostname, + ) + async with cm as s: timeout = _timeout(expiration) - response = await tcp(q, where, timeout, port, source, source_port, - one_rr_per_rrset, ignore_trailing, s, backend) + response = await tcp( + q, + where, + timeout, + port, + source, + source_port, + one_rr_per_rrset, + ignore_trailing, + s, + backend, + ) end_time = time.time() response.time = end_time - begin_time return response - finally: - if not sock and s: - await s.close() -async def https(q, where, timeout=None, port=443, source=None, source_port=0, - one_rr_per_rrset=False, ignore_trailing=False, client=None, - path='/dns-query', post=True, verify=True): + +async def https( + q: dns.message.Message, + where: str, + timeout: Optional[float] = None, + port: int = 443, + source: Optional[str] = None, + source_port: int = 0, # pylint: disable=W0613 + one_rr_per_rrset: bool = False, + ignore_trailing: bool = False, + client: Optional["httpx.AsyncClient"] = None, + path: str = "/dns-query", + post: bool = True, + verify: Union[bool, str] = True, +) -> dns.message.Message: """Return the response obtained after sending a query via DNS-over-HTTPS. *client*, a ``httpx.AsyncClient``. If provided, the client to use for @@ -373,7 +509,7 @@ async def https(q, where, timeout=None, port=443, source=None, source_port=0, """ if not _have_httpx: - raise NoDOH('httpx is not available.') # pragma: no cover + raise NoDOH("httpx is not available.") # pragma: no cover wire = q.to_wire() try: @@ -381,65 +517,78 @@ async def https(q, where, timeout=None, port=443, source=None, source_port=0, except ValueError: af = None transport = None - headers = { - "accept": "application/dns-message" - } + headers = {"accept": "application/dns-message"} if af is not None: if af == socket.AF_INET: - url = 'https://{}:{}{}'.format(where, port, path) + url = "https://{}:{}{}".format(where, port, path) elif af == socket.AF_INET6: - url = 'https://[{}]:{}{}'.format(where, port, path) + url = "https://[{}]:{}{}".format(where, port, path) else: url = where if source is not None: transport = httpx.AsyncHTTPTransport(local_address=source[0]) - # After 3.6 is no longer supported, this can use an AsyncExitStack - client_to_close = None - try: - if not client: - client = httpx.AsyncClient(http1=True, http2=_have_http2, - verify=verify, transport=transport) - client_to_close = client + if client: + cm: contextlib.AbstractAsyncContextManager = NullContext(client) + else: + cm = httpx.AsyncClient( + http1=True, http2=_have_http2, verify=verify, transport=transport + ) + async with cm as the_client: # see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH # GET and POST examples if post: - headers.update({ - "content-type": "application/dns-message", - "content-length": str(len(wire)) - }) - response = await client.post(url, headers=headers, content=wire, - timeout=timeout) + headers.update( + { + "content-type": "application/dns-message", + "content-length": str(len(wire)), + } + ) + response = await the_client.post( + url, headers=headers, content=wire, timeout=timeout + ) else: wire = base64.urlsafe_b64encode(wire).rstrip(b"=") - wire = wire.decode() # httpx does a repr() if we give it bytes - response = await client.get(url, headers=headers, timeout=timeout, - params={"dns": wire}) - finally: - if client_to_close: - await client.aclose() + twire = wire.decode() # httpx does a repr() if we give it bytes + response = await the_client.get( + url, headers=headers, timeout=timeout, params={"dns": twire} + ) # see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH # status codes if response.status_code < 200 or response.status_code > 299: - raise ValueError('{} responded with status code {}' - '\nResponse body: {}'.format(where, - response.status_code, - response.content)) - r = dns.message.from_wire(response.content, - keyring=q.keyring, - request_mac=q.request_mac, - one_rr_per_rrset=one_rr_per_rrset, - ignore_trailing=ignore_trailing) - r.time = response.elapsed + raise ValueError( + "{} responded with status code {}" + "\nResponse body: {!r}".format( + where, response.status_code, response.content + ) + ) + r = dns.message.from_wire( + response.content, + keyring=q.keyring, + request_mac=q.request_mac, + one_rr_per_rrset=one_rr_per_rrset, + ignore_trailing=ignore_trailing, + ) + r.time = response.elapsed.total_seconds() if not q.is_response(r): raise BadResponse return r -async def inbound_xfr(where, txn_manager, query=None, - port=53, timeout=None, lifetime=None, source=None, - source_port=0, udp_mode=UDPMode.NEVER, backend=None): + +async def inbound_xfr( + where: str, + txn_manager: dns.transaction.TransactionManager, + query: Optional[dns.message.Message] = None, + port: int = 53, + timeout: Optional[float] = None, + lifetime: Optional[float] = None, + source: Optional[str] = None, + source_port: int = 0, + udp_mode: UDPMode = UDPMode.NEVER, + backend: Optional[dns.asyncbackend.Backend] = None, +) -> None: """Conduct an inbound transfer and apply it via a transaction from the txn_manager. @@ -472,42 +621,48 @@ async def inbound_xfr(where, txn_manager, query=None, is_udp = False if not backend: backend = dns.asyncbackend.get_default_backend() - s = await backend.make_socket(af, sock_type, 0, stuple, dtuple, - _timeout(expiration)) + s = await backend.make_socket( + af, sock_type, 0, stuple, dtuple, _timeout(expiration) + ) async with s: if is_udp: await s.sendto(wire, dtuple, _timeout(expiration)) else: tcpmsg = struct.pack("!H", len(wire)) + wire await s.sendall(tcpmsg, expiration) - with dns.xfr.Inbound(txn_manager, rdtype, serial, - is_udp) as inbound: + with dns.xfr.Inbound(txn_manager, rdtype, serial, is_udp) as inbound: done = False tsig_ctx = None while not done: (_, mexpiration) = _compute_times(timeout) - if mexpiration is None or \ - (expiration is not None and mexpiration > expiration): + if mexpiration is None or ( + expiration is not None and mexpiration > expiration + ): mexpiration = expiration if is_udp: destination = _lltuple((where, port), af) while True: timeout = _timeout(mexpiration) - (rwire, from_address) = await s.recvfrom(65535, - timeout) - if _matches_destination(af, from_address, - destination, True): + (rwire, from_address) = await s.recvfrom(65535, timeout) + if _matches_destination( + af, from_address, destination, True + ): break else: ldata = await _read_exactly(s, 2, mexpiration) (l,) = struct.unpack("!H", ldata) rwire = await _read_exactly(s, l, mexpiration) - is_ixfr = (rdtype == dns.rdatatype.IXFR) - r = dns.message.from_wire(rwire, keyring=query.keyring, - request_mac=query.mac, xfr=True, - origin=origin, tsig_ctx=tsig_ctx, - multi=(not is_udp), - one_rr_per_rrset=is_ixfr) + is_ixfr = rdtype == dns.rdatatype.IXFR + r = dns.message.from_wire( + rwire, + keyring=query.keyring, + request_mac=query.mac, + xfr=True, + origin=origin, + tsig_ctx=tsig_ctx, + multi=(not is_udp), + one_rr_per_rrset=is_ixfr, + ) try: done = inbound.process_message(r) except dns.xfr.UseTCP: @@ -521,3 +676,62 @@ async def inbound_xfr(where, txn_manager, query=None, tsig_ctx = r.tsig_ctx if not retry and query.keyring and not r.had_tsig: raise dns.exception.FormError("missing TSIG") + + +async def quic( + q: dns.message.Message, + where: str, + timeout: Optional[float] = None, + port: int = 853, + source: Optional[str] = None, + source_port: int = 0, + one_rr_per_rrset: bool = False, + ignore_trailing: bool = False, + connection: Optional[dns.quic.AsyncQuicConnection] = None, + verify: Union[bool, str] = True, + backend: Optional[dns.asyncbackend.Backend] = None, +) -> dns.message.Message: + """Return the response obtained after sending an asynchronous query via + DNS-over-QUIC. + + *backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``, + the default, then dnspython will use the default backend. + + See :py:func:`dns.query.quic()` for the documentation of the other + parameters, exceptions, and return type of this method. + """ + + if not dns.quic.have_quic: + raise NoDOQ("DNS-over-QUIC is not available.") # pragma: no cover + + q.id = 0 + wire = q.to_wire() + the_connection: dns.quic.AsyncQuicConnection + if connection: + cfactory = dns.quic.null_factory + mfactory = dns.quic.null_factory + the_connection = connection + else: + (cfactory, mfactory) = dns.quic.factories_for_backend(backend) + + async with cfactory() as context: + async with mfactory(context, verify_mode=verify) as the_manager: + if not connection: + the_connection = the_manager.connect(where, port, source, source_port) + start = time.time() + stream = await the_connection.make_stream() + async with stream: + await stream.send(wire, True) + wire = await stream.receive(timeout) + finish = time.time() + r = dns.message.from_wire( + wire, + keyring=q.keyring, + request_mac=q.request_mac, + one_rr_per_rrset=one_rr_per_rrset, + ignore_trailing=ignore_trailing, + ) + r.time = max(finish - start, 0.0) + if not q.is_response(r): + raise BadResponse + return r diff --git a/lib/dns/asyncquery.pyi b/lib/dns/asyncquery.pyi deleted file mode 100644 index a03434c2..00000000 --- a/lib/dns/asyncquery.pyi +++ /dev/null @@ -1,43 +0,0 @@ -from typing import Optional, Union, Dict, Generator, Any -from . import tsig, rdatatype, rdataclass, name, message, asyncbackend - -# If the ssl import works, then -# -# error: Name 'ssl' already defined (by an import) -# -# is expected and can be ignored. -try: - import ssl -except ImportError: - class ssl: # type: ignore - SSLContext : Dict = {} - -async def udp(q : message.Message, where : str, - timeout : Optional[float] = None, port=53, - source : Optional[str] = None, source_port : Optional[int] = 0, - ignore_unexpected : Optional[bool] = False, - one_rr_per_rrset : Optional[bool] = False, - ignore_trailing : Optional[bool] = False, - sock : Optional[asyncbackend.DatagramSocket] = None, - backend : Optional[asyncbackend.Backend] = None) -> message.Message: - pass - -async def tcp(q : message.Message, where : str, timeout : float = None, port=53, - af : Optional[int] = None, source : Optional[str] = None, - source_port : Optional[int] = 0, - one_rr_per_rrset : Optional[bool] = False, - ignore_trailing : Optional[bool] = False, - sock : Optional[asyncbackend.StreamSocket] = None, - backend : Optional[asyncbackend.Backend] = None) -> message.Message: - pass - -async def tls(q : message.Message, where : str, - timeout : Optional[float] = None, port=53, - source : Optional[str] = None, source_port : Optional[int] = 0, - one_rr_per_rrset : Optional[bool] = False, - ignore_trailing : Optional[bool] = False, - sock : Optional[asyncbackend.StreamSocket] = None, - backend : Optional[asyncbackend.Backend] = None, - ssl_context: Optional[ssl.SSLContext] = None, - server_hostname: Optional[str] = None) -> message.Message: - pass diff --git a/lib/dns/asyncresolver.py b/lib/dns/asyncresolver.py index ed29deed..506530e2 100644 --- a/lib/dns/asyncresolver.py +++ b/lib/dns/asyncresolver.py @@ -17,13 +17,18 @@ """Asynchronous DNS stub resolver.""" +from typing import Any, Dict, Optional, Union + import time import dns.asyncbackend import dns.asyncquery import dns.exception +import dns.name import dns.query -import dns.resolver +import dns.rdataclass +import dns.rdatatype +import dns.resolver # lgtm[py/import-and-import-from] # import some resolver symbols for brevity from dns.resolver import NXDOMAIN, NoAnswer, NotAbsolute, NoRootSOA @@ -37,11 +42,19 @@ _tcp = dns.asyncquery.tcp class Resolver(dns.resolver.BaseResolver): """Asynchronous DNS stub resolver.""" - async def resolve(self, qname, rdtype=dns.rdatatype.A, - rdclass=dns.rdataclass.IN, - tcp=False, source=None, raise_on_no_answer=True, - source_port=0, lifetime=None, search=None, - backend=None): + async def resolve( + self, + qname: Union[dns.name.Name, str], + rdtype: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.A, + rdclass: Union[dns.rdataclass.RdataClass, str] = dns.rdataclass.IN, + tcp: bool = False, + source: Optional[str] = None, + raise_on_no_answer: bool = True, + source_port: int = 0, + lifetime: Optional[float] = None, + search: Optional[bool] = None, + backend: Optional[dns.asyncbackend.Backend] = None, + ) -> dns.resolver.Answer: """Query nameservers asynchronously to find the answer to the question. *backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``, @@ -52,8 +65,9 @@ class Resolver(dns.resolver.BaseResolver): type of this method. """ - resolution = dns.resolver._Resolution(self, qname, rdtype, rdclass, tcp, - raise_on_no_answer, search) + resolution = dns.resolver._Resolution( + self, qname, rdtype, rdclass, tcp, raise_on_no_answer, search + ) if not backend: backend = dns.asyncbackend.get_default_backend() start = time.time() @@ -66,30 +80,40 @@ class Resolver(dns.resolver.BaseResolver): if answer is not None: # cache hit! return answer + assert request is not None # needed for type checking done = False while not done: (nameserver, port, tcp, backoff) = resolution.next_nameserver() if backoff: await backend.sleep(backoff) - timeout = self._compute_timeout(start, lifetime, - resolution.errors) + timeout = self._compute_timeout(start, lifetime, resolution.errors) try: if dns.inet.is_address(nameserver): if tcp: - response = await _tcp(request, nameserver, - timeout, port, - source, source_port, - backend=backend) + response = await _tcp( + request, + nameserver, + timeout, + port, + source, + source_port, + backend=backend, + ) else: - response = await _udp(request, nameserver, - timeout, port, - source, source_port, - raise_on_truncation=True, - backend=backend) + response = await _udp( + request, + nameserver, + timeout, + port, + source, + source_port, + raise_on_truncation=True, + backend=backend, + ) else: - response = await dns.asyncquery.https(request, - nameserver, - timeout=timeout) + response = await dns.asyncquery.https( + request, nameserver, timeout=timeout + ) except Exception as ex: (_, done) = resolution.query_result(None, ex) continue @@ -101,7 +125,9 @@ class Resolver(dns.resolver.BaseResolver): if answer is not None: return answer - async def resolve_address(self, ipaddr, *args, **kwargs): + async def resolve_address( + self, ipaddr: str, *args: Any, **kwargs: Any + ) -> dns.resolver.Answer: """Use an asynchronous resolver to run a reverse query for PTR records. @@ -116,15 +142,20 @@ class Resolver(dns.resolver.BaseResolver): function. """ - - return await self.resolve(dns.reversename.from_address(ipaddr), - rdtype=dns.rdatatype.PTR, - rdclass=dns.rdataclass.IN, - *args, **kwargs) + # We make a modified kwargs for type checking happiness, as otherwise + # we get a legit warning about possibly having rdtype and rdclass + # in the kwargs more than once. + modified_kwargs: Dict[str, Any] = {} + modified_kwargs.update(kwargs) + modified_kwargs["rdtype"] = dns.rdatatype.PTR + modified_kwargs["rdclass"] = dns.rdataclass.IN + return await self.resolve( + dns.reversename.from_address(ipaddr), *args, **modified_kwargs + ) # pylint: disable=redefined-outer-name - async def canonical_name(self, name): + async def canonical_name(self, name: Union[dns.name.Name, str]) -> dns.name.Name: """Determine the canonical name of *name*. The canonical name is the name the resolver uses for queries @@ -149,14 +180,15 @@ class Resolver(dns.resolver.BaseResolver): default_resolver = None -def get_default_resolver(): +def get_default_resolver() -> Resolver: """Get the default asynchronous resolver, initializing it if necessary.""" if default_resolver is None: reset_default_resolver() + assert default_resolver is not None return default_resolver -def reset_default_resolver(): +def reset_default_resolver() -> None: """Re-initialize default asynchronous resolver. Note that the resolver configuration (i.e. /etc/resolv.conf on UNIX @@ -167,9 +199,18 @@ def reset_default_resolver(): default_resolver = Resolver() -async def resolve(qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN, - tcp=False, source=None, raise_on_no_answer=True, - source_port=0, lifetime=None, search=None, backend=None): +async def resolve( + qname: Union[dns.name.Name, str], + rdtype: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.A, + rdclass: Union[dns.rdataclass.RdataClass, str] = dns.rdataclass.IN, + tcp: bool = False, + source: Optional[str] = None, + raise_on_no_answer: bool = True, + source_port: int = 0, + lifetime: Optional[float] = None, + search: Optional[bool] = None, + backend: Optional[dns.asyncbackend.Backend] = None, +) -> dns.resolver.Answer: """Query nameservers asynchronously to find the answer to the question. This is a convenience function that uses the default resolver @@ -179,13 +220,23 @@ async def resolve(qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN, information on the parameters. """ - return await get_default_resolver().resolve(qname, rdtype, rdclass, tcp, - source, raise_on_no_answer, - source_port, lifetime, search, - backend) + return await get_default_resolver().resolve( + qname, + rdtype, + rdclass, + tcp, + source, + raise_on_no_answer, + source_port, + lifetime, + search, + backend, + ) -async def resolve_address(ipaddr, *args, **kwargs): +async def resolve_address( + ipaddr: str, *args: Any, **kwargs: Any +) -> dns.resolver.Answer: """Use a resolver to run a reverse query for PTR records. See :py:func:`dns.asyncresolver.Resolver.resolve_address` for more @@ -194,7 +245,8 @@ async def resolve_address(ipaddr, *args, **kwargs): return await get_default_resolver().resolve_address(ipaddr, *args, **kwargs) -async def canonical_name(name): + +async def canonical_name(name: Union[dns.name.Name, str]) -> dns.name.Name: """Determine the canonical name of *name*. See :py:func:`dns.resolver.Resolver.canonical_name` for more @@ -203,8 +255,14 @@ async def canonical_name(name): return await get_default_resolver().canonical_name(name) -async def zone_for_name(name, rdclass=dns.rdataclass.IN, tcp=False, - resolver=None, backend=None): + +async def zone_for_name( + name: Union[dns.name.Name, str], + rdclass: dns.rdataclass.RdataClass = dns.rdataclass.IN, + tcp: bool = False, + resolver: Optional[Resolver] = None, + backend: Optional[dns.asyncbackend.Backend] = None, +) -> dns.name.Name: """Find the name of the zone which contains the specified name. See :py:func:`dns.resolver.Resolver.zone_for_name` for more @@ -219,8 +277,10 @@ async def zone_for_name(name, rdclass=dns.rdataclass.IN, tcp=False, raise NotAbsolute(name) while True: try: - answer = await resolver.resolve(name, dns.rdatatype.SOA, rdclass, - tcp, backend=backend) + answer = await resolver.resolve( + name, dns.rdatatype.SOA, rdclass, tcp, backend=backend + ) + assert answer.rrset is not None if answer.rrset.name == name: return name # otherwise we were CNAMEd or DNAMEd and need to look higher diff --git a/lib/dns/asyncresolver.pyi b/lib/dns/asyncresolver.pyi deleted file mode 100644 index 92759d29..00000000 --- a/lib/dns/asyncresolver.pyi +++ /dev/null @@ -1,26 +0,0 @@ -from typing import Union, Optional, List, Any, Dict -from . import exception, rdataclass, name, rdatatype, asyncbackend - -async def resolve(qname : str, rdtype : Union[int,str] = 0, - rdclass : Union[int,str] = 0, - tcp=False, source=None, raise_on_no_answer=True, - source_port=0, lifetime : Optional[float]=None, - search : Optional[bool]=None, - backend : Optional[asyncbackend.Backend]=None): - ... -async def resolve_address(self, ipaddr: str, - *args: Any, **kwargs: Optional[Dict]): - ... - -class Resolver: - def __init__(self, filename : Optional[str] = '/etc/resolv.conf', - configure : Optional[bool] = True): - self.nameservers : List[str] - async def resolve(self, qname : str, rdtype : Union[int,str] = rdatatype.A, - rdclass : Union[int,str] = rdataclass.IN, - tcp : bool = False, source : Optional[str] = None, - raise_on_no_answer=True, source_port : int = 0, - lifetime : Optional[float]=None, - search : Optional[bool]=None, - backend : Optional[asyncbackend.Backend]=None): - ... diff --git a/lib/dns/dnssec.py b/lib/dns/dnssec.py index dee4e618..5dc26223 100644 --- a/lib/dns/dnssec.py +++ b/lib/dns/dnssec.py @@ -17,12 +17,17 @@ """Common DNSSEC-related functions and constants.""" +from typing import Any, cast, Dict, List, Optional, Set, Tuple, Union + import hashlib +import math import struct import time import base64 +from datetime import datetime + +from dns.dnssectypes import Algorithm, DSDigest, NSEC3Hash -import dns.enum import dns.exception import dns.name import dns.node @@ -30,41 +35,47 @@ import dns.rdataset import dns.rdata import dns.rdatatype import dns.rdataclass +import dns.rrset +from dns.rdtypes.ANY.CDNSKEY import CDNSKEY +from dns.rdtypes.ANY.CDS import CDS +from dns.rdtypes.ANY.DNSKEY import DNSKEY +from dns.rdtypes.ANY.DS import DS +from dns.rdtypes.ANY.RRSIG import RRSIG, sigtime_to_posixtime +from dns.rdtypes.dnskeybase import Flag class UnsupportedAlgorithm(dns.exception.DNSException): """The DNSSEC algorithm is not supported.""" +class AlgorithmKeyMismatch(UnsupportedAlgorithm): + """The DNSSEC algorithm is not supported for the given key type.""" + + class ValidationFailure(dns.exception.DNSException): """The DNSSEC signature is invalid.""" -class Algorithm(dns.enum.IntEnum): - RSAMD5 = 1 - DH = 2 - DSA = 3 - ECC = 4 - RSASHA1 = 5 - DSANSEC3SHA1 = 6 - RSASHA1NSEC3SHA1 = 7 - RSASHA256 = 8 - RSASHA512 = 10 - ECCGOST = 12 - ECDSAP256SHA256 = 13 - ECDSAP384SHA384 = 14 - ED25519 = 15 - ED448 = 16 - INDIRECT = 252 - PRIVATEDNS = 253 - PRIVATEOID = 254 - - @classmethod - def _maximum(cls): - return 255 +class DeniedByPolicy(dns.exception.DNSException): + """Denied by DNSSEC policy.""" -def algorithm_from_text(text): +PublicKey = Union[ + "rsa.RSAPublicKey", + "ec.EllipticCurvePublicKey", + "ed25519.Ed25519PublicKey", + "ed448.Ed448PublicKey", +] + +PrivateKey = Union[ + "rsa.RSAPrivateKey", + "ec.EllipticCurvePrivateKey", + "ed25519.Ed25519PrivateKey", + "ed448.Ed448PrivateKey", +] + + +def algorithm_from_text(text: str) -> Algorithm: """Convert text into a DNSSEC algorithm value. *text*, a ``str``, the text to convert to into an algorithm value. @@ -75,10 +86,10 @@ def algorithm_from_text(text): return Algorithm.from_text(text) -def algorithm_to_text(value): +def algorithm_to_text(value: Union[Algorithm, int]) -> str: """Convert a DNSSEC algorithm value to text - *value*, an ``int`` a DNSSEC algorithm. + *value*, a ``dns.dnssec.Algorithm``. Returns a ``str``, the name of a DNSSEC algorithm. """ @@ -86,7 +97,21 @@ def algorithm_to_text(value): return Algorithm.to_text(value) -def key_id(key): +def to_timestamp(value: Union[datetime, str, float, int]) -> int: + """Convert various format to a timestamp""" + if isinstance(value, datetime): + return int(value.timestamp()) + elif isinstance(value, str): + return sigtime_to_posixtime(value) + elif isinstance(value, float): + return int(value) + elif isinstance(value, int): + return value + else: + raise TypeError("Unsupported timestamp type") + + +def key_id(key: Union[DNSKEY, CDNSKEY]) -> int: """Return the key id (a 16-bit number) for the specified key. *key*, a ``dns.rdtypes.ANY.DNSKEY.DNSKEY`` @@ -100,50 +125,116 @@ def key_id(key): else: total = 0 for i in range(len(rdata) // 2): - total += (rdata[2 * i] << 8) + \ - rdata[2 * i + 1] + total += (rdata[2 * i] << 8) + rdata[2 * i + 1] if len(rdata) % 2 != 0: total += rdata[len(rdata) - 1] << 8 - total += ((total >> 16) & 0xffff) - return total & 0xffff - -class DSDigest(dns.enum.IntEnum): - """DNSSEC Delegation Signer Digest Algorithm""" - - SHA1 = 1 - SHA256 = 2 - SHA384 = 4 - - @classmethod - def _maximum(cls): - return 255 + total += (total >> 16) & 0xFFFF + return total & 0xFFFF -def make_ds(name, key, algorithm, origin=None): +class Policy: + def __init__(self): + pass + + def ok_to_sign(self, _: DNSKEY) -> bool: # pragma: no cover + return False + + def ok_to_validate(self, _: DNSKEY) -> bool: # pragma: no cover + return False + + def ok_to_create_ds(self, _: DSDigest) -> bool: # pragma: no cover + return False + + def ok_to_validate_ds(self, _: DSDigest) -> bool: # pragma: no cover + return False + + +class SimpleDeny(Policy): + def __init__(self, deny_sign, deny_validate, deny_create_ds, deny_validate_ds): + super().__init__() + self._deny_sign = deny_sign + self._deny_validate = deny_validate + self._deny_create_ds = deny_create_ds + self._deny_validate_ds = deny_validate_ds + + def ok_to_sign(self, key: DNSKEY) -> bool: + return key.algorithm not in self._deny_sign + + def ok_to_validate(self, key: DNSKEY) -> bool: + return key.algorithm not in self._deny_validate + + def ok_to_create_ds(self, algorithm: DSDigest) -> bool: + return algorithm not in self._deny_create_ds + + def ok_to_validate_ds(self, algorithm: DSDigest) -> bool: + return algorithm not in self._deny_validate_ds + + +rfc_8624_policy = SimpleDeny( + {Algorithm.RSAMD5, Algorithm.DSA, Algorithm.DSANSEC3SHA1, Algorithm.ECCGOST}, + {Algorithm.RSAMD5, Algorithm.DSA, Algorithm.DSANSEC3SHA1}, + {DSDigest.NULL, DSDigest.SHA1, DSDigest.GOST}, + {DSDigest.NULL}, +) + +allow_all_policy = SimpleDeny(set(), set(), set(), set()) + + +default_policy = rfc_8624_policy + + +def make_ds( + name: Union[dns.name.Name, str], + key: dns.rdata.Rdata, + algorithm: Union[DSDigest, str], + origin: Optional[dns.name.Name] = None, + policy: Optional[Policy] = None, + validating: bool = False, +) -> DS: """Create a DS record for a DNSSEC key. *name*, a ``dns.name.Name`` or ``str``, the owner name of the DS record. - *key*, a ``dns.rdtypes.ANY.DNSKEY.DNSKEY``, the key the DS is about. + *key*, a ``dns.rdtypes.ANY.DNSKEY.DNSKEY`` or ``dns.rdtypes.ANY.DNSKEY.CDNSKEY``, + the key the DS is about. *algorithm*, a ``str`` or ``int`` specifying the hash algorithm. The currently supported hashes are "SHA1", "SHA256", and "SHA384". Case does not matter for these strings. - *origin*, a ``dns.name.Name`` or ``None``. If `key` is a relative name, + *origin*, a ``dns.name.Name`` or ``None``. If *key* is a relative name, then it will be made absolute using the specified origin. + *policy*, a ``dns.dnssec.Policy`` or ``None``. If ``None``, the default policy, + ``dns.dnssec.default_policy`` is used; this policy defaults to that of RFC 8624. + + *validating*, a ``bool``. If ``True``, then policy is checked in + validating mode, i.e. "Is it ok to validate using this digest algorithm?". + Otherwise the policy is checked in creating mode, i.e. "Is it ok to create a DS with + this digest algorithm?". + Raises ``UnsupportedAlgorithm`` if the algorithm is unknown. + Raises ``DeniedByPolicy`` if the algorithm is denied by policy. + Returns a ``dns.rdtypes.ANY.DS.DS`` """ + if policy is None: + policy = default_policy try: if isinstance(algorithm, str): algorithm = DSDigest[algorithm.upper()] except Exception: raise UnsupportedAlgorithm('unsupported algorithm "%s"' % algorithm) - + if validating: + check = policy.ok_to_validate_ds + else: + check = policy.ok_to_create_ds + if not check(algorithm): + raise DeniedByPolicy + if not isinstance(key, (DNSKEY, CDNSKEY)): + raise ValueError("key is not a DNSKEY/CDNSKEY") if algorithm == DSDigest.SHA1: dshash = hashlib.sha1() elif algorithm == DSDigest.SHA256: @@ -155,17 +246,58 @@ def make_ds(name, key, algorithm, origin=None): if isinstance(name, str): name = dns.name.from_text(name, origin) - dshash.update(name.canonicalize().to_wire()) + wire = name.canonicalize().to_wire() + assert wire is not None + dshash.update(wire) dshash.update(key.to_wire(origin=origin)) digest = dshash.digest() - dsrdata = struct.pack("!HBB", key_id(key), key.algorithm, algorithm) + \ - digest - return dns.rdata.from_wire(dns.rdataclass.IN, dns.rdatatype.DS, dsrdata, 0, - len(dsrdata)) + dsrdata = struct.pack("!HBB", key_id(key), key.algorithm, algorithm) + digest + ds = dns.rdata.from_wire( + dns.rdataclass.IN, dns.rdatatype.DS, dsrdata, 0, len(dsrdata) + ) + return cast(DS, ds) -def _find_candidate_keys(keys, rrsig): +def make_cds( + name: Union[dns.name.Name, str], + key: dns.rdata.Rdata, + algorithm: Union[DSDigest, str], + origin: Optional[dns.name.Name] = None, +) -> CDS: + """Create a CDS record for a DNSSEC key. + + *name*, a ``dns.name.Name`` or ``str``, the owner name of the DS record. + + *key*, a ``dns.rdtypes.ANY.DNSKEY.DNSKEY`` or ``dns.rdtypes.ANY.DNSKEY.CDNSKEY``, + the key the DS is about. + + *algorithm*, a ``str`` or ``int`` specifying the hash algorithm. + The currently supported hashes are "SHA1", "SHA256", and "SHA384". Case + does not matter for these strings. + + *origin*, a ``dns.name.Name`` or ``None``. If *key* is a relative name, + then it will be made absolute using the specified origin. + + Raises ``UnsupportedAlgorithm`` if the algorithm is unknown. + + Returns a ``dns.rdtypes.ANY.DS.CDS`` + """ + + ds = make_ds(name, key, algorithm, origin) + return CDS( + rdclass=ds.rdclass, + rdtype=dns.rdatatype.CDS, + key_tag=ds.key_tag, + algorithm=ds.algorithm, + digest_type=ds.digest_type, + digest=ds.digest, + ) + + +def _find_candidate_keys( + keys: Dict[dns.name.Name, Union[dns.rdataset.Rdataset, dns.node.Node]], rrsig: RRSIG +) -> Optional[List[DNSKEY]]: value = keys.get(rrsig.signer) if isinstance(value, dns.node.Node): rdataset = value.get_rdataset(dns.rdataclass.IN, dns.rdatatype.DNSKEY) @@ -173,54 +305,94 @@ def _find_candidate_keys(keys, rrsig): rdataset = value if rdataset is None: return None - return [rd for rd in rdataset if - rd.algorithm == rrsig.algorithm and key_id(rd) == rrsig.key_tag] + return [ + cast(DNSKEY, rd) + for rd in rdataset + if rd.algorithm == rrsig.algorithm and key_id(rd) == rrsig.key_tag + ] -def _is_rsa(algorithm): - return algorithm in (Algorithm.RSAMD5, Algorithm.RSASHA1, - Algorithm.RSASHA1NSEC3SHA1, Algorithm.RSASHA256, - Algorithm.RSASHA512) +def _is_rsa(algorithm: int) -> bool: + return algorithm in ( + Algorithm.RSAMD5, + Algorithm.RSASHA1, + Algorithm.RSASHA1NSEC3SHA1, + Algorithm.RSASHA256, + Algorithm.RSASHA512, + ) -def _is_dsa(algorithm): +def _is_dsa(algorithm: int) -> bool: return algorithm in (Algorithm.DSA, Algorithm.DSANSEC3SHA1) -def _is_ecdsa(algorithm): +def _is_ecdsa(algorithm: int) -> bool: return algorithm in (Algorithm.ECDSAP256SHA256, Algorithm.ECDSAP384SHA384) -def _is_eddsa(algorithm): +def _is_eddsa(algorithm: int) -> bool: return algorithm in (Algorithm.ED25519, Algorithm.ED448) -def _is_gost(algorithm): +def _is_gost(algorithm: int) -> bool: return algorithm == Algorithm.ECCGOST -def _is_md5(algorithm): +def _is_md5(algorithm: int) -> bool: return algorithm == Algorithm.RSAMD5 -def _is_sha1(algorithm): - return algorithm in (Algorithm.DSA, Algorithm.RSASHA1, - Algorithm.DSANSEC3SHA1, Algorithm.RSASHA1NSEC3SHA1) +def _is_sha1(algorithm: int) -> bool: + return algorithm in ( + Algorithm.DSA, + Algorithm.RSASHA1, + Algorithm.DSANSEC3SHA1, + Algorithm.RSASHA1NSEC3SHA1, + ) -def _is_sha256(algorithm): +def _is_sha256(algorithm: int) -> bool: return algorithm in (Algorithm.RSASHA256, Algorithm.ECDSAP256SHA256) -def _is_sha384(algorithm): +def _is_sha384(algorithm: int) -> bool: return algorithm == Algorithm.ECDSAP384SHA384 -def _is_sha512(algorithm): +def _is_sha512(algorithm: int) -> bool: return algorithm == Algorithm.RSASHA512 -def _make_hash(algorithm): +def _ensure_algorithm_key_combination(algorithm: int, key: PublicKey) -> None: + """Ensure algorithm is valid for key type, throwing an exception on + mismatch.""" + if isinstance(key, rsa.RSAPublicKey): + if _is_rsa(algorithm): + return + raise AlgorithmKeyMismatch('algorithm "%s" not valid for RSA key' % algorithm) + if isinstance(key, dsa.DSAPublicKey): + if _is_dsa(algorithm): + return + raise AlgorithmKeyMismatch('algorithm "%s" not valid for DSA key' % algorithm) + if isinstance(key, ec.EllipticCurvePublicKey): + if _is_ecdsa(algorithm): + return + raise AlgorithmKeyMismatch('algorithm "%s" not valid for ECDSA key' % algorithm) + if isinstance(key, ed25519.Ed25519PublicKey): + if algorithm == Algorithm.ED25519: + return + raise AlgorithmKeyMismatch( + 'algorithm "%s" not valid for ED25519 key' % algorithm + ) + if isinstance(key, ed448.Ed448PublicKey): + if algorithm == Algorithm.ED448: + return + raise AlgorithmKeyMismatch('algorithm "%s" not valid for ED448 key' % algorithm) + + raise TypeError("unsupported key type") + + +def _make_hash(algorithm: int) -> Any: if _is_md5(algorithm): return hashes.MD5() if _is_sha1(algorithm): @@ -236,33 +408,45 @@ def _make_hash(algorithm): if algorithm == Algorithm.ED448: return hashes.SHAKE256(114) - raise ValidationFailure('unknown hash for algorithm %u' % algorithm) + raise ValidationFailure("unknown hash for algorithm %u" % algorithm) -def _bytes_to_long(b): - return int.from_bytes(b, 'big') +def _bytes_to_long(b: bytes) -> int: + return int.from_bytes(b, "big") -def _validate_signature(sig, data, key, chosen_hash): +def _get_rrname_rdataset( + rrset: Union[dns.rrset.RRset, Tuple[dns.name.Name, dns.rdataset.Rdataset]], +) -> Tuple[dns.name.Name, dns.rdataset.Rdataset]: + if isinstance(rrset, tuple): + return rrset[0], rrset[1] + else: + return rrset.name, rrset + + +def _validate_signature(sig: bytes, data: bytes, key: DNSKEY, chosen_hash: Any) -> None: + keyptr: bytes if _is_rsa(key.algorithm): + # we ignore because mypy is confused and thinks key.key is a str for unknown + # reasons. keyptr = key.key - (bytes_,) = struct.unpack('!B', keyptr[0:1]) + (bytes_,) = struct.unpack("!B", keyptr[0:1]) keyptr = keyptr[1:] if bytes_ == 0: - (bytes_,) = struct.unpack('!H', keyptr[0:2]) + (bytes_,) = struct.unpack("!H", keyptr[0:2]) keyptr = keyptr[2:] rsa_e = keyptr[0:bytes_] rsa_n = keyptr[bytes_:] try: - public_key = rsa.RSAPublicNumbers( - _bytes_to_long(rsa_e), - _bytes_to_long(rsa_n)).public_key(default_backend()) + rsa_public_key = rsa.RSAPublicNumbers( + _bytes_to_long(rsa_e), _bytes_to_long(rsa_n) + ).public_key(default_backend()) except ValueError: - raise ValidationFailure('invalid public key') - public_key.verify(sig, data, padding.PKCS1v15(), chosen_hash) + raise ValidationFailure("invalid public key") + rsa_public_key.verify(sig, data, padding.PKCS1v15(), chosen_hash) elif _is_dsa(key.algorithm): keyptr = key.key - (t,) = struct.unpack('!B', keyptr[0:1]) + (t,) = struct.unpack("!B", keyptr[0:1]) keyptr = keyptr[1:] octets = 64 + t * 8 dsa_q = keyptr[0:20] @@ -273,17 +457,18 @@ def _validate_signature(sig, data, key, chosen_hash): keyptr = keyptr[octets:] dsa_y = keyptr[0:octets] try: - public_key = dsa.DSAPublicNumbers( + dsa_public_key = dsa.DSAPublicNumbers( # type: ignore _bytes_to_long(dsa_y), dsa.DSAParameterNumbers( - _bytes_to_long(dsa_p), - _bytes_to_long(dsa_q), - _bytes_to_long(dsa_g))).public_key(default_backend()) + _bytes_to_long(dsa_p), _bytes_to_long(dsa_q), _bytes_to_long(dsa_g) + ), + ).public_key(default_backend()) except ValueError: - raise ValidationFailure('invalid public key') - public_key.verify(sig, data, chosen_hash) + raise ValidationFailure("invalid public key") + dsa_public_key.verify(sig, data, chosen_hash) elif _is_ecdsa(key.algorithm): keyptr = key.key + curve: Any if key.algorithm == Algorithm.ECDSAP256SHA256: curve = ec.SECP256R1() octets = 32 @@ -291,35 +476,43 @@ def _validate_signature(sig, data, key, chosen_hash): curve = ec.SECP384R1() octets = 48 ecdsa_x = keyptr[0:octets] - ecdsa_y = keyptr[octets:octets * 2] + ecdsa_y = keyptr[octets : octets * 2] try: - public_key = ec.EllipticCurvePublicNumbers( - curve=curve, - x=_bytes_to_long(ecdsa_x), - y=_bytes_to_long(ecdsa_y)).public_key(default_backend()) + ecdsa_public_key = ec.EllipticCurvePublicNumbers( + curve=curve, x=_bytes_to_long(ecdsa_x), y=_bytes_to_long(ecdsa_y) + ).public_key(default_backend()) except ValueError: - raise ValidationFailure('invalid public key') - public_key.verify(sig, data, ec.ECDSA(chosen_hash)) + raise ValidationFailure("invalid public key") + ecdsa_public_key.verify(sig, data, ec.ECDSA(chosen_hash)) elif _is_eddsa(key.algorithm): keyptr = key.key + loader: Any if key.algorithm == Algorithm.ED25519: loader = ed25519.Ed25519PublicKey else: loader = ed448.Ed448PublicKey try: - public_key = loader.from_public_bytes(keyptr) + eddsa_public_key = loader.from_public_bytes(keyptr) except ValueError: - raise ValidationFailure('invalid public key') - public_key.verify(sig, data) + raise ValidationFailure("invalid public key") + eddsa_public_key.verify(sig, data) elif _is_gost(key.algorithm): raise UnsupportedAlgorithm( - 'algorithm "%s" not supported by dnspython' % - algorithm_to_text(key.algorithm)) + 'algorithm "%s" not supported by dnspython' + % algorithm_to_text(key.algorithm) + ) else: - raise ValidationFailure('unknown algorithm %u' % key.algorithm) + raise ValidationFailure("unknown algorithm %u" % key.algorithm) -def _validate_rrsig(rrset, rrsig, keys, origin=None, now=None): +def _validate_rrsig( + rrset: Union[dns.rrset.RRset, Tuple[dns.name.Name, dns.rdataset.Rdataset]], + rrsig: RRSIG, + keys: Dict[dns.name.Name, Union[dns.node.Node, dns.rdataset.Rdataset]], + origin: Optional[dns.name.Name] = None, + now: Optional[float] = None, + policy: Optional[Policy] = None, +) -> None: """Validate an RRset against a single signature rdata, throwing an exception if validation is not successful. @@ -337,10 +530,13 @@ def _validate_rrsig(rrset, rrsig, keys, origin=None, now=None): *origin*, a ``dns.name.Name`` or ``None``, the origin to use for relative names. - *now*, an ``int`` or ``None``, the time, in seconds since the epoch, to + *now*, a ``float`` or ``None``, the time, in seconds since the epoch, to use as the current time when validating. If ``None``, the actual current time is used. + *policy*, a ``dns.dnssec.Policy`` or ``None``. If ``None``, the default policy, + ``dns.dnssec.default_policy`` is used; this policy defaults to that of RFC 8624. + Raises ``ValidationFailure`` if the signature is expired, not yet valid, the public key is invalid, the algorithm is unknown, the verification fails, etc. @@ -349,34 +545,24 @@ def _validate_rrsig(rrset, rrsig, keys, origin=None, now=None): dnspython but not implemented. """ - if isinstance(origin, str): - origin = dns.name.from_text(origin, dns.name.root) + if policy is None: + policy = default_policy candidate_keys = _find_candidate_keys(keys, rrsig) if candidate_keys is None: - raise ValidationFailure('unknown key') - - # For convenience, allow the rrset to be specified as a (name, - # rdataset) tuple as well as a proper rrset - if isinstance(rrset, tuple): - rrname = rrset[0] - rdataset = rrset[1] - else: - rrname = rrset.name - rdataset = rrset + raise ValidationFailure("unknown key") if now is None: now = time.time() if rrsig.expiration < now: - raise ValidationFailure('expired') + raise ValidationFailure("expired") if rrsig.inception > now: - raise ValidationFailure('not yet valid') + raise ValidationFailure("not yet valid") if _is_dsa(rrsig.algorithm): sig_r = rrsig.signature[1:21] sig_s = rrsig.signature[21:] - sig = utils.encode_dss_signature(_bytes_to_long(sig_r), - _bytes_to_long(sig_s)) + sig = utils.encode_dss_signature(_bytes_to_long(sig_r), _bytes_to_long(sig_s)) elif _is_ecdsa(rrsig.algorithm): if rrsig.algorithm == Algorithm.ECDSAP256SHA256: octets = 32 @@ -384,37 +570,16 @@ def _validate_rrsig(rrset, rrsig, keys, origin=None, now=None): octets = 48 sig_r = rrsig.signature[0:octets] sig_s = rrsig.signature[octets:] - sig = utils.encode_dss_signature(_bytes_to_long(sig_r), - _bytes_to_long(sig_s)) + sig = utils.encode_dss_signature(_bytes_to_long(sig_r), _bytes_to_long(sig_s)) else: sig = rrsig.signature - data = b'' - data += rrsig.to_wire(origin=origin)[:18] - data += rrsig.signer.to_digestable(origin) - - # Derelativize the name before considering labels. - rrname = rrname.derelativize(origin) - - if len(rrname) - 1 < rrsig.labels: - raise ValidationFailure('owner name longer than RRSIG labels') - elif rrsig.labels < len(rrname) - 1: - suffix = rrname.split(rrsig.labels + 1)[1] - rrname = dns.name.from_text('*', suffix) - rrnamebuf = rrname.to_digestable() - rrfixed = struct.pack('!HHI', rdataset.rdtype, rdataset.rdclass, - rrsig.original_ttl) - rdatas = [rdata.to_digestable(origin) for rdata in rdataset] - for rdata in sorted(rdatas): - data += rrnamebuf - data += rrfixed - rrlen = struct.pack('!H', len(rdata)) - data += rrlen - data += rdata - + data = _make_rrsig_signature_data(rrset, rrsig, origin) chosen_hash = _make_hash(rrsig.algorithm) for candidate_key in candidate_keys: + if not policy.ok_to_validate(candidate_key): + continue try: _validate_signature(sig, data, candidate_key, chosen_hash) return @@ -422,10 +587,17 @@ def _validate_rrsig(rrset, rrsig, keys, origin=None, now=None): # this happens on an individual validation failure continue # nothing verified -- raise failure: - raise ValidationFailure('verify failure') + raise ValidationFailure("verify failure") -def _validate(rrset, rrsigset, keys, origin=None, now=None): +def _validate( + rrset: Union[dns.rrset.RRset, Tuple[dns.name.Name, dns.rdataset.Rdataset]], + rrsigset: Union[dns.rrset.RRset, Tuple[dns.name.Name, dns.rdataset.Rdataset]], + keys: Dict[dns.name.Name, Union[dns.node.Node, dns.rdataset.Rdataset]], + origin: Optional[dns.name.Name] = None, + now: Optional[float] = None, + policy: Optional[Policy] = None, +) -> None: """Validate an RRset against a signature RRset, throwing an exception if none of the signatures validate. @@ -449,11 +621,17 @@ def _validate(rrset, rrsigset, keys, origin=None, now=None): use as the current time when validating. If ``None``, the actual current time is used. + *policy*, a ``dns.dnssec.Policy`` or ``None``. If ``None``, the default policy, + ``dns.dnssec.default_policy`` is used; this policy defaults to that of RFC 8624. + Raises ``ValidationFailure`` if the signature is expired, not yet valid, the public key is invalid, the algorithm is unknown, the verification fails, etc. """ + if policy is None: + policy = default_policy + if isinstance(origin, str): origin = dns.name.from_text(origin, dns.name.root) @@ -475,24 +653,367 @@ def _validate(rrset, rrsigset, keys, origin=None, now=None): raise ValidationFailure("owner names do not match") for rrsig in rrsigrdataset: + if not isinstance(rrsig, RRSIG): + raise ValidationFailure("expected an RRSIG") try: - _validate_rrsig(rrset, rrsig, keys, origin, now) + _validate_rrsig(rrset, rrsig, keys, origin, now, policy) return except (ValidationFailure, UnsupportedAlgorithm): pass raise ValidationFailure("no RRSIGs validated") -class NSEC3Hash(dns.enum.IntEnum): - """NSEC3 hash algorithm""" +def _sign( + rrset: Union[dns.rrset.RRset, Tuple[dns.name.Name, dns.rdataset.Rdataset]], + private_key: PrivateKey, + signer: dns.name.Name, + dnskey: DNSKEY, + inception: Optional[Union[datetime, str, int, float]] = None, + expiration: Optional[Union[datetime, str, int, float]] = None, + lifetime: Optional[int] = None, + verify: bool = False, + policy: Optional[Policy] = None, +) -> RRSIG: + """Sign RRset using private key. - SHA1 = 1 + *rrset*, the RRset to validate. This can be a + ``dns.rrset.RRset`` or a (``dns.name.Name``, ``dns.rdataset.Rdataset``) + tuple. - @classmethod - def _maximum(cls): - return 255 + *private_key*, the private key to use for signing, a + ``cryptography.hazmat.primitives.asymmetric`` private key class applicable + for DNSSEC. -def nsec3_hash(domain, salt, iterations, algorithm): + *signer*, a ``dns.name.Name``, the Signer's name. + + *dnskey*, a ``DNSKEY`` matching ``private_key``. + + *inception*, a ``datetime``, ``str``, ``int``, ``float`` or ``None``, the + signature inception time. If ``None``, the current time is used. If a ``str``, the + format is "YYYYMMDDHHMMSS" or alternatively the number of seconds since the UNIX + epoch in text form; this is the same the RRSIG rdata's text form. + Values of type `int` or `float` are interpreted as seconds since the UNIX epoch. + + *expiration*, a ``datetime``, ``str``, ``int``, ``float`` or ``None``, the signature + expiration time. If ``None``, the expiration time will be the inception time plus + the value of the *lifetime* parameter. See the description of *inception* above + for how the various parameter types are interpreted. + + *lifetime*, an ``int`` or ``None``, the signature lifetime in seconds. This + parameter is only meaningful if *expiration* is ``None``. + + *verify*, a ``bool``. If set to ``True``, the signer will verify signatures + after they are created; the default is ``False``. + + *policy*, a ``dns.dnssec.Policy`` or ``None``. If ``None``, the default policy, + ``dns.dnssec.default_policy`` is used; this policy defaults to that of RFC 8624. + + Raises ``DeniedByPolicy`` if the signature is denied by policy. + """ + + if policy is None: + policy = default_policy + if not policy.ok_to_sign(dnskey): + raise DeniedByPolicy + + if isinstance(rrset, tuple): + rdclass = rrset[1].rdclass + rdtype = rrset[1].rdtype + rrname = rrset[0] + original_ttl = rrset[1].ttl + else: + rdclass = rrset.rdclass + rdtype = rrset.rdtype + rrname = rrset.name + original_ttl = rrset.ttl + + if inception is not None: + rrsig_inception = to_timestamp(inception) + else: + rrsig_inception = int(time.time()) + + if expiration is not None: + rrsig_expiration = to_timestamp(expiration) + elif lifetime is not None: + rrsig_expiration = int(time.time()) + lifetime + else: + raise ValueError("expiration or lifetime must be specified") + + rrsig_template = RRSIG( + rdclass=rdclass, + rdtype=dns.rdatatype.RRSIG, + type_covered=rdtype, + algorithm=dnskey.algorithm, + labels=len(rrname) - 1, + original_ttl=original_ttl, + expiration=rrsig_expiration, + inception=rrsig_inception, + key_tag=key_id(dnskey), + signer=signer, + signature=b"", + ) + + data = dns.dnssec._make_rrsig_signature_data(rrset, rrsig_template) + chosen_hash = _make_hash(rrsig_template.algorithm) + signature = None + + if isinstance(private_key, rsa.RSAPrivateKey): + if not _is_rsa(dnskey.algorithm): + raise ValueError("Invalid DNSKEY algorithm for RSA key") + signature = private_key.sign(data, padding.PKCS1v15(), chosen_hash) + if verify: + private_key.public_key().verify( + signature, data, padding.PKCS1v15(), chosen_hash + ) + elif isinstance(private_key, dsa.DSAPrivateKey): + if not _is_dsa(dnskey.algorithm): + raise ValueError("Invalid DNSKEY algorithm for DSA key") + public_dsa_key = private_key.public_key() + if public_dsa_key.key_size > 1024: + raise ValueError("DSA key size overflow") + der_signature = private_key.sign(data, chosen_hash) + if verify: + public_dsa_key.verify(der_signature, data, chosen_hash) + dsa_r, dsa_s = utils.decode_dss_signature(der_signature) + dsa_t = (public_dsa_key.key_size // 8 - 64) // 8 + octets = 20 + signature = ( + struct.pack("!B", dsa_t) + + int.to_bytes(dsa_r, length=octets, byteorder="big") + + int.to_bytes(dsa_s, length=octets, byteorder="big") + ) + elif isinstance(private_key, ec.EllipticCurvePrivateKey): + if not _is_ecdsa(dnskey.algorithm): + raise ValueError("Invalid DNSKEY algorithm for EC key") + der_signature = private_key.sign(data, ec.ECDSA(chosen_hash)) + if verify: + private_key.public_key().verify(der_signature, data, ec.ECDSA(chosen_hash)) + if dnskey.algorithm == Algorithm.ECDSAP256SHA256: + octets = 32 + else: + octets = 48 + dsa_r, dsa_s = utils.decode_dss_signature(der_signature) + signature = int.to_bytes(dsa_r, length=octets, byteorder="big") + int.to_bytes( + dsa_s, length=octets, byteorder="big" + ) + elif isinstance(private_key, ed25519.Ed25519PrivateKey): + if dnskey.algorithm != Algorithm.ED25519: + raise ValueError("Invalid DNSKEY algorithm for ED25519 key") + signature = private_key.sign(data) + if verify: + private_key.public_key().verify(signature, data) + elif isinstance(private_key, ed448.Ed448PrivateKey): + if dnskey.algorithm != Algorithm.ED448: + raise ValueError("Invalid DNSKEY algorithm for ED448 key") + signature = private_key.sign(data) + if verify: + private_key.public_key().verify(signature, data) + else: + raise TypeError("Unsupported key algorithm") + + return cast(RRSIG, rrsig_template.replace(signature=signature)) + + +def _make_rrsig_signature_data( + rrset: Union[dns.rrset.RRset, Tuple[dns.name.Name, dns.rdataset.Rdataset]], + rrsig: RRSIG, + origin: Optional[dns.name.Name] = None, +) -> bytes: + """Create signature rdata. + + *rrset*, the RRset to sign/validate. This can be a + ``dns.rrset.RRset`` or a (``dns.name.Name``, ``dns.rdataset.Rdataset``) + tuple. + + *rrsig*, a ``dns.rdata.Rdata``, the signature to validate, or the + signature template used when signing. + + *origin*, a ``dns.name.Name`` or ``None``, the origin to use for relative + names. + + Raises ``UnsupportedAlgorithm`` if the algorithm is recognized by + dnspython but not implemented. + """ + + if isinstance(origin, str): + origin = dns.name.from_text(origin, dns.name.root) + + signer = rrsig.signer + if not signer.is_absolute(): + if origin is None: + raise ValidationFailure("relative RR name without an origin specified") + signer = signer.derelativize(origin) + + # For convenience, allow the rrset to be specified as a (name, + # rdataset) tuple as well as a proper rrset + rrname, rdataset = _get_rrname_rdataset(rrset) + + data = b"" + data += rrsig.to_wire(origin=signer)[:18] + data += rrsig.signer.to_digestable(signer) + + # Derelativize the name before considering labels. + if not rrname.is_absolute(): + if origin is None: + raise ValidationFailure("relative RR name without an origin specified") + rrname = rrname.derelativize(origin) + + if len(rrname) - 1 < rrsig.labels: + raise ValidationFailure("owner name longer than RRSIG labels") + elif rrsig.labels < len(rrname) - 1: + suffix = rrname.split(rrsig.labels + 1)[1] + rrname = dns.name.from_text("*", suffix) + rrnamebuf = rrname.to_digestable() + rrfixed = struct.pack("!HHI", rdataset.rdtype, rdataset.rdclass, rrsig.original_ttl) + rdatas = [rdata.to_digestable(origin) for rdata in rdataset] + for rdata in sorted(rdatas): + data += rrnamebuf + data += rrfixed + rrlen = struct.pack("!H", len(rdata)) + data += rrlen + data += rdata + + return data + + +def _make_dnskey( + public_key: PublicKey, + algorithm: Union[int, str], + flags: int = Flag.ZONE, + protocol: int = 3, +) -> DNSKEY: + """Convert a public key to DNSKEY Rdata + + *public_key*, the public key to convert, a + ``cryptography.hazmat.primitives.asymmetric`` public key class applicable + for DNSSEC. + + *algorithm*, a ``str`` or ``int`` specifying the DNSKEY algorithm. + + *flags*: DNSKEY flags field as an integer. + + *protocol*: DNSKEY protocol field as an integer. + + Raises ``ValueError`` if the specified key algorithm parameters are not + unsupported, ``TypeError`` if the key type is unsupported, + `UnsupportedAlgorithm` if the algorithm is unknown and + `AlgorithmKeyMismatch` if the algorithm does not match the key type. + + Return DNSKEY ``Rdata``. + """ + + def encode_rsa_public_key(public_key: "rsa.RSAPublicKey") -> bytes: + """Encode a public key per RFC 3110, section 2.""" + pn = public_key.public_numbers() + _exp_len = math.ceil(int.bit_length(pn.e) / 8) + exp = int.to_bytes(pn.e, length=_exp_len, byteorder="big") + if _exp_len > 255: + exp_header = b"\0" + struct.pack("!H", _exp_len) + else: + exp_header = struct.pack("!B", _exp_len) + if pn.n.bit_length() < 512 or pn.n.bit_length() > 4096: + raise ValueError("unsupported RSA key length") + return exp_header + exp + pn.n.to_bytes((pn.n.bit_length() + 7) // 8, "big") + + def encode_dsa_public_key(public_key: "dsa.DSAPublicKey") -> bytes: + """Encode a public key per RFC 2536, section 2.""" + pn = public_key.public_numbers() + dsa_t = (public_key.key_size // 8 - 64) // 8 + if dsa_t > 8: + raise ValueError("unsupported DSA key size") + octets = 64 + dsa_t * 8 + res = struct.pack("!B", dsa_t) + res += pn.parameter_numbers.q.to_bytes(20, "big") + res += pn.parameter_numbers.p.to_bytes(octets, "big") + res += pn.parameter_numbers.g.to_bytes(octets, "big") + res += pn.y.to_bytes(octets, "big") + return res + + def encode_ecdsa_public_key(public_key: "ec.EllipticCurvePublicKey") -> bytes: + """Encode a public key per RFC 6605, section 4.""" + pn = public_key.public_numbers() + if isinstance(public_key.curve, ec.SECP256R1): + return pn.x.to_bytes(32, "big") + pn.y.to_bytes(32, "big") + elif isinstance(public_key.curve, ec.SECP384R1): + return pn.x.to_bytes(48, "big") + pn.y.to_bytes(48, "big") + else: + raise ValueError("unsupported ECDSA curve") + + the_algorithm = Algorithm.make(algorithm) + + _ensure_algorithm_key_combination(the_algorithm, public_key) + + if isinstance(public_key, rsa.RSAPublicKey): + key_bytes = encode_rsa_public_key(public_key) + elif isinstance(public_key, dsa.DSAPublicKey): + key_bytes = encode_dsa_public_key(public_key) + elif isinstance(public_key, ec.EllipticCurvePublicKey): + key_bytes = encode_ecdsa_public_key(public_key) + elif isinstance(public_key, ed25519.Ed25519PublicKey): + key_bytes = public_key.public_bytes( + encoding=serialization.Encoding.Raw, format=serialization.PublicFormat.Raw + ) + elif isinstance(public_key, ed448.Ed448PublicKey): + key_bytes = public_key.public_bytes( + encoding=serialization.Encoding.Raw, format=serialization.PublicFormat.Raw + ) + else: + raise TypeError("unsupported key algorithm") + + return DNSKEY( + rdclass=dns.rdataclass.IN, + rdtype=dns.rdatatype.DNSKEY, + flags=flags, + protocol=protocol, + algorithm=the_algorithm, + key=key_bytes, + ) + + +def _make_cdnskey( + public_key: PublicKey, + algorithm: Union[int, str], + flags: int = Flag.ZONE, + protocol: int = 3, +) -> CDNSKEY: + """Convert a public key to CDNSKEY Rdata + + *public_key*, the public key to convert, a + ``cryptography.hazmat.primitives.asymmetric`` public key class applicable + for DNSSEC. + + *algorithm*, a ``str`` or ``int`` specifying the DNSKEY algorithm. + + *flags*: DNSKEY flags field as an integer. + + *protocol*: DNSKEY protocol field as an integer. + + Raises ``ValueError`` if the specified key algorithm parameters are not + unsupported, ``TypeError`` if the key type is unsupported, + `UnsupportedAlgorithm` if the algorithm is unknown and + `AlgorithmKeyMismatch` if the algorithm does not match the key type. + + Return CDNSKEY ``Rdata``. + """ + + dnskey = _make_dnskey(public_key, algorithm, flags, protocol) + + return CDNSKEY( + rdclass=dnskey.rdclass, + rdtype=dns.rdatatype.CDNSKEY, + flags=dnskey.flags, + protocol=dnskey.protocol, + algorithm=dnskey.algorithm, + key=dnskey.key, + ) + + +def nsec3_hash( + domain: Union[dns.name.Name, str], + salt: Optional[Union[str, bytes]], + iterations: int, + algorithm: Union[int, str], +) -> str: """ Calculate the NSEC3 hash, according to https://tools.ietf.org/html/rfc5155#section-5 @@ -523,18 +1044,20 @@ def nsec3_hash(domain, salt, iterations, algorithm): if algorithm != NSEC3Hash.SHA1: raise ValueError("Wrong hash algorithm (only SHA1 is supported)") - salt_encoded = salt if salt is None: - salt_encoded = b'' + salt_encoded = b"" elif isinstance(salt, str): if len(salt) % 2 == 0: salt_encoded = bytes.fromhex(salt) else: raise ValueError("Invalid salt length") + else: + salt_encoded = salt if not isinstance(domain, dns.name.Name): domain = dns.name.from_text(domain) domain_encoded = domain.canonicalize().to_wire() + assert domain_encoded is not None digest = hashlib.sha1(domain_encoded + salt_encoded).digest() for _ in range(iterations): @@ -546,15 +1069,163 @@ def nsec3_hash(domain, salt, iterations, algorithm): return output +def make_ds_rdataset( + rrset: Union[dns.rrset.RRset, Tuple[dns.name.Name, dns.rdataset.Rdataset]], + algorithms: Set[Union[DSDigest, str]], + origin: Optional[dns.name.Name] = None, +) -> dns.rdataset.Rdataset: + """Create a DS record from DNSKEY/CDNSKEY/CDS. + + *rrset*, the RRset to create DS Rdataset for. This can be a + ``dns.rrset.RRset`` or a (``dns.name.Name``, ``dns.rdataset.Rdataset``) + tuple. + + *algorithms*, a set of ``str`` or ``int`` specifying the hash algorithms. + The currently supported hashes are "SHA1", "SHA256", and "SHA384". Case + does not matter for these strings. If the RRset is a CDS, only digest + algorithms matching algorithms are accepted. + + *origin*, a ``dns.name.Name`` or ``None``. If `key` is a relative name, + then it will be made absolute using the specified origin. + + Raises ``UnsupportedAlgorithm`` if any of the algorithms are unknown and + ``ValueError`` if the given RRset is not usable. + + Returns a ``dns.rdataset.Rdataset`` + """ + + rrname, rdataset = _get_rrname_rdataset(rrset) + + if rdataset.rdtype not in ( + dns.rdatatype.DNSKEY, + dns.rdatatype.CDNSKEY, + dns.rdatatype.CDS, + ): + raise ValueError("rrset not a DNSKEY/CDNSKEY/CDS") + + _algorithms = set() + for algorithm in algorithms: + try: + if isinstance(algorithm, str): + algorithm = DSDigest[algorithm.upper()] + except Exception: + raise UnsupportedAlgorithm('unsupported algorithm "%s"' % algorithm) + _algorithms.add(algorithm) + + if rdataset.rdtype == dns.rdatatype.CDS: + res = [] + for rdata in cds_rdataset_to_ds_rdataset(rdataset): + if rdata.digest_type in _algorithms: + res.append(rdata) + if len(res) == 0: + raise ValueError("no acceptable CDS rdata found") + return dns.rdataset.from_rdata_list(rdataset.ttl, res) + + res = [] + for algorithm in _algorithms: + res.extend(dnskey_rdataset_to_cds_rdataset(rrname, rdataset, algorithm, origin)) + return dns.rdataset.from_rdata_list(rdataset.ttl, res) + + +def cds_rdataset_to_ds_rdataset( + rdataset: dns.rdataset.Rdataset, +) -> dns.rdataset.Rdataset: + """Create a CDS record from DS. + + *rdataset*, a ``dns.rdataset.Rdataset``, to create DS Rdataset for. + + Raises ``ValueError`` if the rdataset is not CDS. + + Returns a ``dns.rdataset.Rdataset`` + """ + + if rdataset.rdtype != dns.rdatatype.CDS: + raise ValueError("rdataset not a CDS") + res = [] + for rdata in rdataset: + res.append( + CDS( + rdclass=rdata.rdclass, + rdtype=dns.rdatatype.DS, + key_tag=rdata.key_tag, + algorithm=rdata.algorithm, + digest_type=rdata.digest_type, + digest=rdata.digest, + ) + ) + return dns.rdataset.from_rdata_list(rdataset.ttl, res) + + +def dnskey_rdataset_to_cds_rdataset( + name: Union[dns.name.Name, str], + rdataset: dns.rdataset.Rdataset, + algorithm: Union[DSDigest, str], + origin: Optional[dns.name.Name] = None, +) -> dns.rdataset.Rdataset: + """Create a CDS record from DNSKEY/CDNSKEY. + + *name*, a ``dns.name.Name`` or ``str``, the owner name of the CDS record. + + *rdataset*, a ``dns.rdataset.Rdataset``, to create DS Rdataset for. + + *algorithm*, a ``str`` or ``int`` specifying the hash algorithm. + The currently supported hashes are "SHA1", "SHA256", and "SHA384". Case + does not matter for these strings. + + *origin*, a ``dns.name.Name`` or ``None``. If `key` is a relative name, + then it will be made absolute using the specified origin. + + Raises ``UnsupportedAlgorithm`` if the algorithm is unknown or + ``ValueError`` if the rdataset is not DNSKEY/CDNSKEY. + + Returns a ``dns.rdataset.Rdataset`` + """ + + if rdataset.rdtype not in (dns.rdatatype.DNSKEY, dns.rdatatype.CDNSKEY): + raise ValueError("rdataset not a DNSKEY/CDNSKEY") + res = [] + for rdata in rdataset: + res.append(make_cds(name, rdata, algorithm, origin)) + return dns.rdataset.from_rdata_list(rdataset.ttl, res) + + +def dnskey_rdataset_to_cdnskey_rdataset( + rdataset: dns.rdataset.Rdataset, +) -> dns.rdataset.Rdataset: + """Create a CDNSKEY record from DNSKEY. + + *rdataset*, a ``dns.rdataset.Rdataset``, to create CDNSKEY Rdataset for. + + Returns a ``dns.rdataset.Rdataset`` + """ + + if rdataset.rdtype != dns.rdatatype.DNSKEY: + raise ValueError("rdataset not a DNSKEY") + res = [] + for rdata in rdataset: + res.append( + CDNSKEY( + rdclass=rdataset.rdclass, + rdtype=rdataset.rdtype, + flags=rdata.flags, + protocol=rdata.protocol, + algorithm=rdata.algorithm, + key=rdata.key, + ) + ) + return dns.rdataset.from_rdata_list(rdataset.ttl, res) + + def _need_pyca(*args, **kwargs): - raise ImportError("DNSSEC validation requires " + - "python cryptography") # pragma: no cover + raise ImportError( + "DNSSEC validation requires " + "python cryptography" + ) # pragma: no cover try: from cryptography.exceptions import InvalidSignature from cryptography.hazmat.backends import default_backend - from cryptography.hazmat.primitives import hashes + from cryptography.hazmat.primitives import hashes, serialization from cryptography.hazmat.primitives.asymmetric import padding from cryptography.hazmat.primitives.asymmetric import utils from cryptography.hazmat.primitives.asymmetric import dsa @@ -565,10 +1236,16 @@ try: except ImportError: # pragma: no cover validate = _need_pyca validate_rrsig = _need_pyca + sign = _need_pyca + make_dnskey = _need_pyca + make_cdnskey = _need_pyca _have_pyca = False else: - validate = _validate # type: ignore - validate_rrsig = _validate_rrsig # type: ignore + validate = _validate # type: ignore + validate_rrsig = _validate_rrsig # type: ignore + sign = _sign + make_dnskey = _make_dnskey + make_cdnskey = _make_cdnskey _have_pyca = True ### BEGIN generated Algorithm constants diff --git a/lib/dns/dnssec.pyi b/lib/dns/dnssec.pyi deleted file mode 100644 index e126f9b8..00000000 --- a/lib/dns/dnssec.pyi +++ /dev/null @@ -1,21 +0,0 @@ -from typing import Union, Dict, Tuple, Optional -from . import rdataset, rrset, exception, name, rdtypes, rdata, node -import dns.rdtypes.ANY.DS as DS -import dns.rdtypes.ANY.DNSKEY as DNSKEY - -_have_pyca : bool - -def validate_rrsig(rrset : Union[Tuple[name.Name, rdataset.Rdataset], rrset.RRset], rrsig : rdata.Rdata, keys : Dict[name.Name, Union[node.Node, rdataset.Rdataset]], origin : Optional[name.Name] = None, now : Optional[int] = None) -> None: - ... - -def validate(rrset: Union[Tuple[name.Name, rdataset.Rdataset], rrset.RRset], rrsigset : Union[Tuple[name.Name, rdataset.Rdataset], rrset.RRset], keys : Dict[name.Name, Union[node.Node, rdataset.Rdataset]], origin=None, now=None) -> None: - ... - -class ValidationFailure(exception.DNSException): - ... - -def make_ds(name : name.Name, key : DNSKEY.DNSKEY, algorithm : str, origin : Optional[name.Name] = None) -> DS.DS: - ... - -def nsec3_hash(domain: str, salt: Optional[Union[str, bytes]], iterations: int, algo: int) -> str: - ... diff --git a/lib/dns/dnssectypes.py b/lib/dns/dnssectypes.py new file mode 100644 index 00000000..02131e0a --- /dev/null +++ b/lib/dns/dnssectypes.py @@ -0,0 +1,71 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2017 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +"""Common DNSSEC-related types.""" + +# This is a separate file to avoid import circularity between dns.dnssec and +# the implementations of the DS and DNSKEY types. + +import dns.enum + + +class Algorithm(dns.enum.IntEnum): + RSAMD5 = 1 + DH = 2 + DSA = 3 + ECC = 4 + RSASHA1 = 5 + DSANSEC3SHA1 = 6 + RSASHA1NSEC3SHA1 = 7 + RSASHA256 = 8 + RSASHA512 = 10 + ECCGOST = 12 + ECDSAP256SHA256 = 13 + ECDSAP384SHA384 = 14 + ED25519 = 15 + ED448 = 16 + INDIRECT = 252 + PRIVATEDNS = 253 + PRIVATEOID = 254 + + @classmethod + def _maximum(cls): + return 255 + + +class DSDigest(dns.enum.IntEnum): + """DNSSEC Delegation Signer Digest Algorithm""" + + NULL = 0 + SHA1 = 1 + SHA256 = 2 + GOST = 3 + SHA384 = 4 + + @classmethod + def _maximum(cls): + return 255 + + +class NSEC3Hash(dns.enum.IntEnum): + """NSEC3 hash algorithm""" + + SHA1 = 1 + + @classmethod + def _maximum(cls): + return 255 diff --git a/lib/dns/e164.py b/lib/dns/e164.py index 83731b2c..453736d4 100644 --- a/lib/dns/e164.py +++ b/lib/dns/e164.py @@ -17,15 +17,19 @@ """DNS E.164 helpers.""" +from typing import Iterable, Optional, Union + import dns.exception import dns.name import dns.resolver #: The public E.164 domain. -public_enum_domain = dns.name.from_text('e164.arpa.') +public_enum_domain = dns.name.from_text("e164.arpa.") -def from_e164(text, origin=public_enum_domain): +def from_e164( + text: str, origin: Optional[dns.name.Name] = public_enum_domain +) -> dns.name.Name: """Convert an E.164 number in textual form into a Name object whose value is the ENUM domain name for that number. @@ -42,10 +46,14 @@ def from_e164(text, origin=public_enum_domain): parts = [d for d in text if d.isdigit()] parts.reverse() - return dns.name.from_text('.'.join(parts), origin=origin) + return dns.name.from_text(".".join(parts), origin=origin) -def to_e164(name, origin=public_enum_domain, want_plus_prefix=True): +def to_e164( + name: dns.name.Name, + origin: Optional[dns.name.Name] = public_enum_domain, + want_plus_prefix: bool = True, +) -> str: """Convert an ENUM domain name into an E.164 number. Note that dnspython does not have any information about preferred @@ -69,15 +77,19 @@ def to_e164(name, origin=public_enum_domain, want_plus_prefix=True): name = name.relativize(origin) dlabels = [d for d in name.labels if d.isdigit() and len(d) == 1] if len(dlabels) != len(name.labels): - raise dns.exception.SyntaxError('non-digit labels in ENUM domain name') + raise dns.exception.SyntaxError("non-digit labels in ENUM domain name") dlabels.reverse() - text = b''.join(dlabels) + text = b"".join(dlabels) if want_plus_prefix: - text = b'+' + text + text = b"+" + text return text.decode() -def query(number, domains, resolver=None): +def query( + number: str, + domains: Iterable[Union[dns.name.Name, str]], + resolver: Optional[dns.resolver.Resolver] = None, +) -> dns.resolver.Answer: """Look for NAPTR RRs for the specified number in the specified domains. e.g. lookup('16505551212', ['e164.dnspython.org.', 'e164.arpa.']) @@ -98,7 +110,7 @@ def query(number, domains, resolver=None): domain = dns.name.from_text(domain) qname = dns.e164.from_e164(number, domain) try: - return resolver.resolve(qname, 'NAPTR') + return resolver.resolve(qname, "NAPTR") except dns.resolver.NXDOMAIN as e: e_nx += e raise e_nx diff --git a/lib/dns/e164.pyi b/lib/dns/e164.pyi deleted file mode 100644 index 37a99fed..00000000 --- a/lib/dns/e164.pyi +++ /dev/null @@ -1,10 +0,0 @@ -from typing import Optional, Iterable -from . import name, resolver -def from_e164(text : str, origin=name.Name(".")) -> name.Name: - ... - -def to_e164(name : name.Name, origin : Optional[name.Name] = None, want_plus_prefix=True) -> str: - ... - -def query(number : str, domains : Iterable[str], resolver : Optional[resolver.Resolver] = None) -> resolver.Answer: - ... diff --git a/lib/dns/edns.py b/lib/dns/edns.py index 9d7e909d..64436cde 100644 --- a/lib/dns/edns.py +++ b/lib/dns/edns.py @@ -17,6 +17,8 @@ """EDNS Options""" +from typing import Any, Dict, Optional, Union + import math import socket import struct @@ -24,6 +26,7 @@ import struct import dns.enum import dns.inet import dns.rdata +import dns.wire class OptionType(dns.enum.IntEnum): @@ -59,14 +62,14 @@ class Option: """Base class for all EDNS option types.""" - def __init__(self, otype): + def __init__(self, otype: Union[OptionType, str]): """Initialize an option. - *otype*, an ``int``, is the option type. + *otype*, a ``dns.edns.OptionType``, is the option type. """ self.otype = OptionType.make(otype) - def to_wire(self, file=None): + def to_wire(self, file: Optional[Any] = None) -> Optional[bytes]: """Convert an option to wire format. Returns a ``bytes`` or ``None``. @@ -75,10 +78,10 @@ class Option: raise NotImplementedError # pragma: no cover @classmethod - def from_wire_parser(cls, otype, parser): + def from_wire_parser(cls, otype: OptionType, parser: "dns.wire.Parser") -> "Option": """Build an EDNS option object from wire format. - *otype*, an ``int``, is the option type. + *otype*, a ``dns.edns.OptionType``, is the option type. *parser*, a ``dns.wire.Parser``, the parser, which should be restructed to the option length. @@ -115,26 +118,22 @@ class Option: return self._cmp(other) != 0 def __lt__(self, other): - if not isinstance(other, Option) or \ - self.otype != other.otype: + if not isinstance(other, Option) or self.otype != other.otype: return NotImplemented return self._cmp(other) < 0 def __le__(self, other): - if not isinstance(other, Option) or \ - self.otype != other.otype: + if not isinstance(other, Option) or self.otype != other.otype: return NotImplemented return self._cmp(other) <= 0 def __ge__(self, other): - if not isinstance(other, Option) or \ - self.otype != other.otype: + if not isinstance(other, Option) or self.otype != other.otype: return NotImplemented return self._cmp(other) >= 0 def __gt__(self, other): - if not isinstance(other, Option) or \ - self.otype != other.otype: + if not isinstance(other, Option) or self.otype != other.otype: return NotImplemented return self._cmp(other) > 0 @@ -142,7 +141,7 @@ class Option: return self.to_text() -class GenericOption(Option): +class GenericOption(Option): # lgtm[py/missing-equals] """Generic Option Class @@ -150,28 +149,31 @@ class GenericOption(Option): implementation. """ - def __init__(self, otype, data): + def __init__(self, otype: Union[OptionType, str], data: Union[bytes, str]): super().__init__(otype) self.data = dns.rdata.Rdata._as_bytes(data, True) - def to_wire(self, file=None): + def to_wire(self, file: Optional[Any] = None) -> Optional[bytes]: if file: file.write(self.data) + return None else: return self.data - def to_text(self): + def to_text(self) -> str: return "Generic %d" % self.otype @classmethod - def from_wire_parser(cls, otype, parser): + def from_wire_parser( + cls, otype: Union[OptionType, str], parser: "dns.wire.Parser" + ) -> Option: return cls(otype, parser.get_remaining()) -class ECSOption(Option): +class ECSOption(Option): # lgtm[py/missing-equals] """EDNS Client Subnet (ECS, RFC7871)""" - def __init__(self, address, srclen=None, scopelen=0): + def __init__(self, address: str, srclen: Optional[int] = None, scopelen: int = 0): """*address*, a ``str``, is the client address information. *srclen*, an ``int``, the source prefix length, which is the @@ -200,8 +202,9 @@ class ECSOption(Option): srclen = dns.rdata.Rdata._as_int(srclen, 0, 32) scopelen = dns.rdata.Rdata._as_int(scopelen, 0, 32) else: # pragma: no cover (this will never happen) - raise ValueError('Bad address family') + raise ValueError("Bad address family") + assert srclen is not None self.address = address self.srclen = srclen self.scopelen = scopelen @@ -214,16 +217,14 @@ class ECSOption(Option): self.addrdata = addrdata[:nbytes] nbits = srclen % 8 if nbits != 0: - last = struct.pack('B', - ord(self.addrdata[-1:]) & (0xff << (8 - nbits))) + last = struct.pack("B", ord(self.addrdata[-1:]) & (0xFF << (8 - nbits))) self.addrdata = self.addrdata[:-1] + last - def to_text(self): - return "ECS {}/{} scope/{}".format(self.address, self.srclen, - self.scopelen) + def to_text(self) -> str: + return "ECS {}/{} scope/{}".format(self.address, self.srclen, self.scopelen) @staticmethod - def from_text(text): + def from_text(text: str) -> Option: """Convert a string into a `dns.edns.ECSOption` *text*, a `str`, the text form of the option. @@ -246,7 +247,7 @@ class ECSOption(Option): >>> # it understands results from `dns.edns.ECSOption.to_text()` >>> dns.edns.ECSOption.from_text('ECS 1.2.3.4/24/32') """ - optional_prefix = 'ECS' + optional_prefix = "ECS" tokens = text.split() ecs_text = None if len(tokens) == 1: @@ -257,47 +258,53 @@ class ECSOption(Option): ecs_text = tokens[1] else: raise ValueError('could not parse ECS from "{}"'.format(text)) - n_slashes = ecs_text.count('/') + n_slashes = ecs_text.count("/") if n_slashes == 1: - address, srclen = ecs_text.split('/') - scope = 0 + address, tsrclen = ecs_text.split("/") + tscope = "0" elif n_slashes == 2: - address, srclen, scope = ecs_text.split('/') + address, tsrclen, tscope = ecs_text.split("/") else: raise ValueError('could not parse ECS from "{}"'.format(text)) try: - scope = int(scope) + scope = int(tscope) except ValueError: - raise ValueError('invalid scope ' + - '"{}": scope must be an integer'.format(scope)) + raise ValueError( + "invalid scope " + '"{}": scope must be an integer'.format(tscope) + ) try: - srclen = int(srclen) + srclen = int(tsrclen) except ValueError: - raise ValueError('invalid srclen ' + - '"{}": srclen must be an integer'.format(srclen)) + raise ValueError( + "invalid srclen " + '"{}": srclen must be an integer'.format(tsrclen) + ) return ECSOption(address, srclen, scope) - def to_wire(self, file=None): - value = (struct.pack('!HBB', self.family, self.srclen, self.scopelen) + - self.addrdata) + def to_wire(self, file: Optional[Any] = None) -> Optional[bytes]: + value = ( + struct.pack("!HBB", self.family, self.srclen, self.scopelen) + self.addrdata + ) if file: file.write(value) + return None else: return value @classmethod - def from_wire_parser(cls, otype, parser): - family, src, scope = parser.get_struct('!HBB') + def from_wire_parser( + cls, otype: Union[OptionType, str], parser: "dns.wire.Parser" + ) -> Option: + family, src, scope = parser.get_struct("!HBB") addrlen = int(math.ceil(src / 8.0)) prefix = parser.get_bytes(addrlen) if family == 1: pad = 4 - addrlen - addr = dns.ipv4.inet_ntoa(prefix + b'\x00' * pad) + addr = dns.ipv4.inet_ntoa(prefix + b"\x00" * pad) elif family == 2: pad = 16 - addrlen - addr = dns.ipv6.inet_ntoa(prefix + b'\x00' * pad) + addr = dns.ipv6.inet_ntoa(prefix + b"\x00" * pad) else: - raise ValueError('unsupported family') + raise ValueError("unsupported family") return cls(addr, src, scope) @@ -334,10 +341,10 @@ class EDECode(dns.enum.IntEnum): return 65535 -class EDEOption(Option): +class EDEOption(Option): # lgtm[py/missing-equals] """Extended DNS Error (EDE, RFC8914)""" - def __init__(self, code, text=None): + def __init__(self, code: Union[EDECode, str], text: Optional[str] = None): """*code*, a ``dns.edns.EDECode`` or ``str``, the info code of the extended error. @@ -349,49 +356,50 @@ class EDEOption(Option): self.code = EDECode.make(code) if text is not None and not isinstance(text, str): - raise ValueError('text must be string or None') - - self.code = code + raise ValueError("text must be string or None") self.text = text - def to_text(self): - output = f'EDE {self.code}' + def to_text(self) -> str: + output = f"EDE {self.code}" if self.text is not None: - output += f': {self.text}' + output += f": {self.text}" return output - def to_wire(self, file=None): - value = struct.pack('!H', self.code) + def to_wire(self, file: Optional[Any] = None) -> Optional[bytes]: + value = struct.pack("!H", self.code) if self.text is not None: - value += self.text.encode('utf8') + value += self.text.encode("utf8") if file: file.write(value) + return None else: return value @classmethod - def from_wire_parser(cls, otype, parser): - code = parser.get_uint16() + def from_wire_parser( + cls, otype: Union[OptionType, str], parser: "dns.wire.Parser" + ) -> Option: + the_code = EDECode.make(parser.get_uint16()) text = parser.get_remaining() if text: if text[-1] == 0: # text MAY be null-terminated text = text[:-1] - text = text.decode('utf8') + btext = text.decode("utf8") else: - text = None + btext = None - return cls(code, text) + return cls(the_code, btext) -_type_to_class = { +_type_to_class: Dict[OptionType, Any] = { OptionType.ECS: ECSOption, OptionType.EDE: EDEOption, } -def get_option_class(otype): +def get_option_class(otype: OptionType) -> Any: """Return the class for the specified option type. The GenericOption class is used if a more specific class is not @@ -404,7 +412,9 @@ def get_option_class(otype): return cls -def option_from_wire_parser(otype, parser): +def option_from_wire_parser( + otype: Union[OptionType, str], parser: "dns.wire.Parser" +) -> Option: """Build an EDNS option object from wire format. *otype*, an ``int``, is the option type. @@ -414,12 +424,14 @@ def option_from_wire_parser(otype, parser): Returns an instance of a subclass of ``dns.edns.Option``. """ - cls = get_option_class(otype) - otype = OptionType.make(otype) + the_otype = OptionType.make(otype) + cls = get_option_class(the_otype) return cls.from_wire_parser(otype, parser) -def option_from_wire(otype, wire, current, olen): +def option_from_wire( + otype: Union[OptionType, str], wire: bytes, current: int, olen: int +) -> Option: """Build an EDNS option object from wire format. *otype*, an ``int``, is the option type. @@ -437,7 +449,8 @@ def option_from_wire(otype, wire, current, olen): with parser.restrict_to(olen): return option_from_wire_parser(otype, parser) -def register_type(implementation, otype): + +def register_type(implementation: Any, otype: OptionType) -> None: """Register the implementation of an option type. *implementation*, a ``class``, is a subclass of ``dns.edns.Option``. @@ -447,6 +460,7 @@ def register_type(implementation, otype): _type_to_class[otype] = implementation + ### BEGIN generated OptionType constants NSID = OptionType.NSID diff --git a/lib/dns/entropy.py b/lib/dns/entropy.py index 086bba78..5e1f5e23 100644 --- a/lib/dns/entropy.py +++ b/lib/dns/entropy.py @@ -15,14 +15,13 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. +from typing import Any, Optional + import os import hashlib import random +import threading import time -try: - import threading as _threading -except ImportError: # pragma: no cover - import dummy_threading as _threading # type: ignore class EntropyPool: @@ -32,51 +31,51 @@ class EntropyPool: # leaving this code doesn't hurt anything as the library code # is used if present. - def __init__(self, seed=None): + def __init__(self, seed: Optional[bytes] = None): self.pool_index = 0 - self.digest = None + self.digest: Optional[bytearray] = None self.next_byte = 0 - self.lock = _threading.Lock() + self.lock = threading.Lock() self.hash = hashlib.sha1() self.hash_len = 20 - self.pool = bytearray(b'\0' * self.hash_len) + self.pool = bytearray(b"\0" * self.hash_len) if seed is not None: - self._stir(bytearray(seed)) + self._stir(seed) self.seeded = True self.seed_pid = os.getpid() else: self.seeded = False self.seed_pid = 0 - def _stir(self, entropy): + def _stir(self, entropy: bytes) -> None: for c in entropy: if self.pool_index == self.hash_len: self.pool_index = 0 - b = c & 0xff + b = c & 0xFF self.pool[self.pool_index] ^= b self.pool_index += 1 - def stir(self, entropy): + def stir(self, entropy: bytes) -> None: with self.lock: self._stir(entropy) - def _maybe_seed(self): + def _maybe_seed(self) -> None: if not self.seeded or self.seed_pid != os.getpid(): try: seed = os.urandom(16) except Exception: # pragma: no cover try: - with open('/dev/urandom', 'rb', 0) as r: + with open("/dev/urandom", "rb", 0) as r: seed = r.read(16) except Exception: - seed = str(time.time()) + seed = str(time.time()).encode() self.seeded = True self.seed_pid = os.getpid() self.digest = None seed = bytearray(seed) self._stir(seed) - def random_8(self): + def random_8(self) -> int: with self.lock: self._maybe_seed() if self.digest is None or self.next_byte == self.hash_len: @@ -88,16 +87,16 @@ class EntropyPool: self.next_byte += 1 return value - def random_16(self): + def random_16(self) -> int: return self.random_8() * 256 + self.random_8() - def random_32(self): + def random_32(self) -> int: return self.random_16() * 65536 + self.random_16() - def random_between(self, first, last): + def random_between(self, first: int, last: int) -> int: size = last - first + 1 if size > 4294967296: - raise ValueError('too big') + raise ValueError("too big") if size > 65536: rand = self.random_32 max = 4294967295 @@ -109,20 +108,24 @@ class EntropyPool: max = 255 return first + size * rand() // (max + 1) + pool = EntropyPool() +system_random: Optional[Any] try: system_random = random.SystemRandom() except Exception: # pragma: no cover system_random = None -def random_16(): + +def random_16() -> int: if system_random is not None: return system_random.randrange(0, 65536) else: return pool.random_16() -def between(first, last): + +def between(first: int, last: int) -> int: if system_random is not None: return system_random.randrange(first, last + 1) else: diff --git a/lib/dns/entropy.pyi b/lib/dns/entropy.pyi deleted file mode 100644 index 818f805a..00000000 --- a/lib/dns/entropy.pyi +++ /dev/null @@ -1,10 +0,0 @@ -from typing import Optional -from random import SystemRandom - -system_random : Optional[SystemRandom] - -def random_16() -> int: - pass - -def between(first: int, last: int) -> int: - pass diff --git a/lib/dns/enum.py b/lib/dns/enum.py index b822dd51..b5a4aed8 100644 --- a/lib/dns/enum.py +++ b/lib/dns/enum.py @@ -17,6 +17,7 @@ import enum + class IntEnum(enum.IntEnum): @classmethod def _check_value(cls, value): @@ -32,9 +33,12 @@ class IntEnum(enum.IntEnum): return cls[text] except KeyError: pass + value = cls._extra_from_text(text) + if value: + return value prefix = cls._prefix() - if text.startswith(prefix) and text[len(prefix):].isdigit(): - value = int(text[len(prefix):]) + if text.startswith(prefix) and text[len(prefix) :].isdigit(): + value = int(text[len(prefix) :]) cls._check_value(value) try: return cls(value) @@ -46,9 +50,13 @@ class IntEnum(enum.IntEnum): def to_text(cls, value): cls._check_value(value) try: - return cls(value).name + text = cls(value).name except ValueError: - return f"{cls._prefix()}{value}" + text = None + text = cls._extra_to_text(value, text) + if text is None: + text = f"{cls._prefix()}{value}" + return text @classmethod def make(cls, value): @@ -83,7 +91,15 @@ class IntEnum(enum.IntEnum): @classmethod def _prefix(cls): - return '' + return "" + + @classmethod + def _extra_from_text(cls, text): # pylint: disable=W0613 + return None + + @classmethod + def _extra_to_text(cls, value, current_text): # pylint: disable=W0613 + return current_text @classmethod def _unknown_exception_class(cls): diff --git a/lib/dns/exception.py b/lib/dns/exception.py index 08393821..4b1481d1 100644 --- a/lib/dns/exception.py +++ b/lib/dns/exception.py @@ -21,6 +21,10 @@ Dnspython modules may also define their own exceptions, which will always be subclasses of ``DNSException``. """ + +from typing import Optional, Set + + class DNSException(Exception): """Abstract base class shared by all dnspython exceptions. @@ -44,14 +48,15 @@ class DNSException(Exception): and ``fmt`` class variables to get nice parametrized messages. """ - msg = None # non-parametrized message - supp_kwargs = set() # accepted parameters for _fmt_kwargs (sanity check) - fmt = None # message parametrized with results from _fmt_kwargs + msg: Optional[str] = None # non-parametrized message + supp_kwargs: Set[str] = set() # accepted parameters for _fmt_kwargs (sanity check) + fmt: Optional[str] = None # message parametrized with results from _fmt_kwargs def __init__(self, *args, **kwargs): self._check_params(*args, **kwargs) if kwargs: - self.kwargs = self._check_kwargs(**kwargs) + # This call to a virtual method from __init__ is ok in our usage + self.kwargs = self._check_kwargs(**kwargs) # lgtm[py/init-calls-subclass] self.msg = str(self) else: self.kwargs = dict() # defined but empty for old mode exceptions @@ -68,14 +73,15 @@ class DNSException(Exception): For sanity we do not allow to mix old and new behavior.""" if args or kwargs: - assert bool(args) != bool(kwargs), \ - 'keyword arguments are mutually exclusive with positional args' + assert bool(args) != bool( + kwargs + ), "keyword arguments are mutually exclusive with positional args" def _check_kwargs(self, **kwargs): if kwargs: - assert set(kwargs.keys()) == self.supp_kwargs, \ - 'following set of keyword args is required: %s' % ( - self.supp_kwargs) + assert ( + set(kwargs.keys()) == self.supp_kwargs + ), "following set of keyword args is required: %s" % (self.supp_kwargs) return kwargs def _fmt_kwargs(self, **kwargs): @@ -124,9 +130,15 @@ class TooBig(DNSException): class Timeout(DNSException): """The DNS operation timed out.""" - supp_kwargs = {'timeout'} + + supp_kwargs = {"timeout"} fmt = "The DNS operation timed out after {timeout:.3f} seconds" + # We do this as otherwise mypy complains about unexpected keyword argument + # idna_exception + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + class ExceptionWrapper: def __init__(self, exception_class): @@ -136,7 +148,6 @@ class ExceptionWrapper: return self def __exit__(self, exc_type, exc_val, exc_tb): - if exc_type is not None and not isinstance(exc_val, - self.exception_class): + if exc_type is not None and not isinstance(exc_val, self.exception_class): raise self.exception_class(str(exc_val)) from exc_val return False diff --git a/lib/dns/exception.pyi b/lib/dns/exception.pyi deleted file mode 100644 index dc571264..00000000 --- a/lib/dns/exception.pyi +++ /dev/null @@ -1,12 +0,0 @@ -from typing import Set, Optional, Dict - -class DNSException(Exception): - supp_kwargs : Set[str] - kwargs : Optional[Dict] - fmt : Optional[str] - -class SyntaxError(DNSException): ... -class FormError(DNSException): ... -class Timeout(DNSException): ... -class TooBig(DNSException): ... -class UnexpectedEnd(SyntaxError): ... diff --git a/lib/dns/flags.py b/lib/dns/flags.py index 96522879..b21b8e3b 100644 --- a/lib/dns/flags.py +++ b/lib/dns/flags.py @@ -17,10 +17,13 @@ """DNS Message Flags.""" +from typing import Any + import enum # Standard DNS flags + class Flag(enum.IntFlag): #: Query Response QR = 0x8000 @@ -40,12 +43,13 @@ class Flag(enum.IntFlag): # EDNS flags + class EDNSFlag(enum.IntFlag): #: DNSSEC answer OK DO = 0x8000 -def _from_text(text, enum_class): +def _from_text(text: str, enum_class: Any) -> int: flags = 0 tokens = text.split() for t in tokens: @@ -53,15 +57,15 @@ def _from_text(text, enum_class): return flags -def _to_text(flags, enum_class): +def _to_text(flags: int, enum_class: Any) -> str: text_flags = [] for k, v in enum_class.__members__.items(): if flags & v != 0: text_flags.append(k) - return ' '.join(text_flags) + return " ".join(text_flags) -def from_text(text): +def from_text(text: str) -> int: """Convert a space-separated list of flag text values into a flags value. @@ -71,7 +75,7 @@ def from_text(text): return _from_text(text, Flag) -def to_text(flags): +def to_text(flags: int) -> str: """Convert a flags value into a space-separated list of flag text values. @@ -81,7 +85,7 @@ def to_text(flags): return _to_text(flags, Flag) -def edns_from_text(text): +def edns_from_text(text: str) -> int: """Convert a space-separated list of EDNS flag text values into a EDNS flags value. @@ -91,7 +95,7 @@ def edns_from_text(text): return _from_text(text, EDNSFlag) -def edns_to_text(flags): +def edns_to_text(flags: int) -> str: """Convert an EDNS flags value into a space-separated list of EDNS flag text values. @@ -100,6 +104,7 @@ def edns_to_text(flags): return _to_text(flags, EDNSFlag) + ### BEGIN generated Flag constants QR = Flag.QR diff --git a/lib/dns/grange.py b/lib/dns/grange.py index 112ede47..3a52278f 100644 --- a/lib/dns/grange.py +++ b/lib/dns/grange.py @@ -17,9 +17,12 @@ """DNS GENERATE range conversion.""" +from typing import Tuple + import dns -def from_text(text): + +def from_text(text: str) -> Tuple[int, int, int]: """Convert the text form of a range in a ``$GENERATE`` statement to an integer. @@ -31,22 +34,22 @@ def from_text(text): start = -1 stop = -1 step = 1 - cur = '' + cur = "" state = 0 # state 0 1 2 # x - y / z - if text and text[0] == '-': + if text and text[0] == "-": raise dns.exception.SyntaxError("Start cannot be a negative number") for c in text: - if c == '-' and state == 0: + if c == "-" and state == 0: start = int(cur) - cur = '' + cur = "" state = 1 - elif c == '/': + elif c == "/": stop = int(cur) - cur = '' + cur = "" state = 2 elif c.isdigit(): cur += c @@ -64,6 +67,6 @@ def from_text(text): assert step >= 1 assert start >= 0 if start > stop: - raise dns.exception.SyntaxError('start must be <= stop') + raise dns.exception.SyntaxError("start must be <= stop") return (start, stop, step) diff --git a/lib/dns/immutable.py b/lib/dns/immutable.py index db7abbcc..38fbe597 100644 --- a/lib/dns/immutable.py +++ b/lib/dns/immutable.py @@ -1,32 +1,25 @@ # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license -import collections.abc -import sys +from typing import Any -# pylint: disable=unused-import -if sys.version_info >= (3, 7): - odict = dict - from dns._immutable_ctx import immutable -else: - # pragma: no cover - from collections import OrderedDict as odict - from dns._immutable_attr import immutable # noqa -# pylint: enable=unused-import +import collections.abc + +from dns._immutable_ctx import immutable @immutable -class Dict(collections.abc.Mapping): - def __init__(self, dictionary, no_copy=False): +class Dict(collections.abc.Mapping): # lgtm[py/missing-equals] + def __init__(self, dictionary: Any, no_copy: bool = False): """Make an immutable dictionary from the specified dictionary. If *no_copy* is `True`, then *dictionary* will be wrapped instead of copied. Only set this if you are sure there will be no external references to the dictionary. """ - if no_copy and isinstance(dictionary, odict): + if no_copy and isinstance(dictionary, dict): self._odict = dictionary else: - self._odict = odict(dictionary) + self._odict = dict(dictionary) self._hash = None def __getitem__(self, key): @@ -37,7 +30,7 @@ class Dict(collections.abc.Mapping): h = 0 for key in sorted(self._odict.keys()): h ^= hash(key) - object.__setattr__(self, '_hash', h) + object.__setattr__(self, "_hash", h) # this does return an int, but pylint doesn't figure that out return self._hash @@ -48,7 +41,7 @@ class Dict(collections.abc.Mapping): return iter(self._odict) -def constify(o): +def constify(o: Any) -> Any: """ Convert mutable types to immutable types. """ @@ -63,7 +56,7 @@ def constify(o): if isinstance(o, list): return tuple(constify(elt) for elt in o) if isinstance(o, dict): - cdict = odict() + cdict = dict() for k, v in o.items(): cdict[k] = constify(v) return Dict(cdict, True) diff --git a/lib/dns/inet.py b/lib/dns/inet.py index d3bdc64c..11180c96 100644 --- a/lib/dns/inet.py +++ b/lib/dns/inet.py @@ -17,6 +17,8 @@ """Generic Internet address helper functions.""" +from typing import Any, Optional, Tuple + import socket import dns.ipv4 @@ -30,7 +32,7 @@ AF_INET = socket.AF_INET AF_INET6 = socket.AF_INET6 -def inet_pton(family, text): +def inet_pton(family: int, text: str) -> bytes: """Convert the textual form of a network address into its binary form. *family* is an ``int``, the address family. @@ -51,7 +53,7 @@ def inet_pton(family, text): raise NotImplementedError -def inet_ntop(family, address): +def inet_ntop(family: int, address: bytes) -> str: """Convert the binary form of a network address into its textual form. *family* is an ``int``, the address family. @@ -72,7 +74,7 @@ def inet_ntop(family, address): raise NotImplementedError -def af_for_address(text): +def af_for_address(text: str) -> int: """Determine the address family of a textual-form network address. *text*, a ``str``, the textual address. @@ -94,7 +96,7 @@ def af_for_address(text): raise ValueError -def is_multicast(text): +def is_multicast(text: str) -> bool: """Is the textual-form network address a multicast address? *text*, a ``str``, the textual address. @@ -116,7 +118,7 @@ def is_multicast(text): raise ValueError -def is_address(text): +def is_address(text: str) -> bool: """Is the specified string an IPv4 or IPv6 address? *text*, a ``str``, the textual address. @@ -135,7 +137,9 @@ def is_address(text): return False -def low_level_address_tuple(high_tuple, af=None): +def low_level_address_tuple( + high_tuple: Tuple[str, int], af: Optional[int] = None +) -> Any: """Given a "high-level" address tuple, i.e. an (address, port) return the appropriate "low-level" address tuple suitable for use in socket calls. @@ -143,7 +147,6 @@ def low_level_address_tuple(high_tuple, af=None): If an *af* other than ``None`` is provided, it is assumed the address in the high-level tuple is valid and has that af. If af is ``None``, then af_for_address will be called. - """ address, port = high_tuple if af is None: @@ -151,13 +154,13 @@ def low_level_address_tuple(high_tuple, af=None): if af == AF_INET: return (address, port) elif af == AF_INET6: - i = address.find('%') + i = address.find("%") if i < 0: # no scope, shortcut! return (address, port, 0, 0) # try to avoid getaddrinfo() addrpart = address[:i] - scope = address[i + 1:] + scope = address[i + 1 :] if scope.isdigit(): return (addrpart, port, 0, int(scope)) try: @@ -167,4 +170,4 @@ def low_level_address_tuple(high_tuple, af=None): ((*_, tup), *_) = socket.getaddrinfo(address, port, flags=ai_flags) return tup else: - raise NotImplementedError(f'unknown address family {af}') + raise NotImplementedError(f"unknown address family {af}") diff --git a/lib/dns/inet.pyi b/lib/dns/inet.pyi deleted file mode 100644 index 6d9dcc70..00000000 --- a/lib/dns/inet.pyi +++ /dev/null @@ -1,4 +0,0 @@ -from typing import Union -from socket import AddressFamily - -AF_INET6 : Union[int, AddressFamily] diff --git a/lib/dns/ipv4.py b/lib/dns/ipv4.py index e1f38d3d..b8e148f3 100644 --- a/lib/dns/ipv4.py +++ b/lib/dns/ipv4.py @@ -17,11 +17,14 @@ """IPv4 helper functions.""" +from typing import Union + import struct import dns.exception -def inet_ntoa(address): + +def inet_ntoa(address: bytes) -> str: """Convert an IPv4 address in binary form to text form. *address*, a ``bytes``, the IPv4 address in binary form. @@ -31,30 +34,32 @@ def inet_ntoa(address): if len(address) != 4: raise dns.exception.SyntaxError - return ('%u.%u.%u.%u' % (address[0], address[1], - address[2], address[3])) + return "%u.%u.%u.%u" % (address[0], address[1], address[2], address[3]) -def inet_aton(text): + +def inet_aton(text: Union[str, bytes]) -> bytes: """Convert an IPv4 address in text form to binary form. - *text*, a ``str``, the IPv4 address in textual form. + *text*, a ``str`` or ``bytes``, the IPv4 address in textual form. Returns a ``bytes``. """ if not isinstance(text, bytes): - text = text.encode() - parts = text.split(b'.') + btext = text.encode() + else: + btext = text + parts = btext.split(b".") if len(parts) != 4: raise dns.exception.SyntaxError for part in parts: if not part.isdigit(): raise dns.exception.SyntaxError - if len(part) > 1 and part[0] == ord('0'): + if len(part) > 1 and part[0] == ord("0"): # No leading zeros raise dns.exception.SyntaxError try: b = [int(part) for part in parts] - return struct.pack('BBBB', *b) + return struct.pack("BBBB", *b) except Exception: raise dns.exception.SyntaxError diff --git a/lib/dns/ipv6.py b/lib/dns/ipv6.py index 0db6fcfa..fbd49623 100644 --- a/lib/dns/ipv6.py +++ b/lib/dns/ipv6.py @@ -17,15 +17,18 @@ """IPv6 helper functions.""" +from typing import List, Union + import re import binascii import dns.exception import dns.ipv4 -_leading_zero = re.compile(r'0+([0-9a-f]+)') +_leading_zero = re.compile(r"0+([0-9a-f]+)") -def inet_ntoa(address): + +def inet_ntoa(address: bytes) -> str: """Convert an IPv6 address in binary form to text form. *address*, a ``bytes``, the IPv6 address in binary form. @@ -41,7 +44,7 @@ def inet_ntoa(address): i = 0 l = len(hex) while i < l: - chunk = hex[i:i + 4].decode() + chunk = hex[i : i + 4].decode() # strip leading zeros. we do this with an re instead of # with lstrip() because lstrip() didn't support chars until # python 2.2.2 @@ -58,7 +61,7 @@ def inet_ntoa(address): start = -1 last_was_zero = False for i in range(8): - if chunks[i] != '0': + if chunks[i] != "0": if last_was_zero: end = i current_len = end - start @@ -76,27 +79,30 @@ def inet_ntoa(address): best_start = start best_len = current_len if best_len > 1: - if best_start == 0 and \ - (best_len == 6 or - best_len == 5 and chunks[5] == 'ffff'): + if best_start == 0 and (best_len == 6 or best_len == 5 and chunks[5] == "ffff"): # We have an embedded IPv4 address if best_len == 6: - prefix = '::' + prefix = "::" else: - prefix = '::ffff:' - hex = prefix + dns.ipv4.inet_ntoa(address[12:]) + prefix = "::ffff:" + thex = prefix + dns.ipv4.inet_ntoa(address[12:]) else: - hex = ':'.join(chunks[:best_start]) + '::' + \ - ':'.join(chunks[best_start + best_len:]) + thex = ( + ":".join(chunks[:best_start]) + + "::" + + ":".join(chunks[best_start + best_len :]) + ) else: - hex = ':'.join(chunks) - return hex + thex = ":".join(chunks) + return thex -_v4_ending = re.compile(br'(.*):(\d+\.\d+\.\d+\.\d+)$') -_colon_colon_start = re.compile(br'::.*') -_colon_colon_end = re.compile(br'.*::$') -def inet_aton(text, ignore_scope=False): +_v4_ending = re.compile(rb"(.*):(\d+\.\d+\.\d+\.\d+)$") +_colon_colon_start = re.compile(rb"::.*") +_colon_colon_end = re.compile(rb".*::$") + + +def inet_aton(text: Union[str, bytes], ignore_scope: bool = False) -> bytes: """Convert an IPv6 address in text form to binary form. *text*, a ``str``, the IPv6 address in textual form. @@ -111,82 +117,88 @@ def inet_aton(text, ignore_scope=False): # Our aim here is not something fast; we just want something that works. # if not isinstance(text, bytes): - text = text.encode() + btext = text.encode() + else: + btext = text if ignore_scope: - parts = text.split(b'%') + parts = btext.split(b"%") l = len(parts) if l == 2: - text = parts[0] + btext = parts[0] elif l > 2: raise dns.exception.SyntaxError - if text == b'': + if btext == b"": raise dns.exception.SyntaxError - elif text.endswith(b':') and not text.endswith(b'::'): + elif btext.endswith(b":") and not btext.endswith(b"::"): raise dns.exception.SyntaxError - elif text.startswith(b':') and not text.startswith(b'::'): + elif btext.startswith(b":") and not btext.startswith(b"::"): raise dns.exception.SyntaxError - elif text == b'::': - text = b'0::' + elif btext == b"::": + btext = b"0::" # # Get rid of the icky dot-quad syntax if we have it. # - m = _v4_ending.match(text) + m = _v4_ending.match(btext) if m is not None: b = dns.ipv4.inet_aton(m.group(2)) - text = ("{}:{:02x}{:02x}:{:02x}{:02x}".format(m.group(1).decode(), - b[0], b[1], b[2], - b[3])).encode() + btext = ( + "{}:{:02x}{:02x}:{:02x}{:02x}".format( + m.group(1).decode(), b[0], b[1], b[2], b[3] + ) + ).encode() # # Try to turn '::' into ':'; if no match try to # turn '::' into ':' # - m = _colon_colon_start.match(text) + m = _colon_colon_start.match(btext) if m is not None: - text = text[1:] + btext = btext[1:] else: - m = _colon_colon_end.match(text) + m = _colon_colon_end.match(btext) if m is not None: - text = text[:-1] + btext = btext[:-1] # # Now canonicalize into 8 chunks of 4 hex digits each # - chunks = text.split(b':') + chunks = btext.split(b":") l = len(chunks) if l > 8: raise dns.exception.SyntaxError seen_empty = False - canonical = [] + canonical: List[bytes] = [] for c in chunks: - if c == b'': + if c == b"": if seen_empty: raise dns.exception.SyntaxError seen_empty = True for _ in range(0, 8 - l + 1): - canonical.append(b'0000') + canonical.append(b"0000") else: lc = len(c) if lc > 4: raise dns.exception.SyntaxError if lc != 4: - c = (b'0' * (4 - lc)) + c + c = (b"0" * (4 - lc)) + c canonical.append(c) if l < 8 and not seen_empty: raise dns.exception.SyntaxError - text = b''.join(canonical) + btext = b"".join(canonical) # # Finally we can go to binary. # try: - return binascii.unhexlify(text) + return binascii.unhexlify(btext) except (binascii.Error, TypeError): raise dns.exception.SyntaxError -_mapped_prefix = b'\x00' * 10 + b'\xff\xff' -def is_mapped(address): +_mapped_prefix = b"\x00" * 10 + b"\xff\xff" + + +def is_mapped(address: bytes) -> bool: """Is the specified address a mapped IPv4 address? *address*, a ``bytes`` is an IPv6 address in binary form. diff --git a/lib/dns/message.py b/lib/dns/message.py index c2751a90..8250db3b 100644 --- a/lib/dns/message.py +++ b/lib/dns/message.py @@ -17,6 +17,8 @@ """DNS Messages""" +from typing import Any, Dict, List, Optional, Tuple, Union + import contextlib import io import time @@ -71,14 +73,19 @@ class UnknownTSIGKey(dns.exception.DNSException): class Truncated(dns.exception.DNSException): """The truncated flag is set.""" - supp_kwargs = {'message'} + supp_kwargs = {"message"} + + # We do this as otherwise mypy complains about unexpected keyword argument + # idna_exception + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) def message(self): """As much of the message as could be processed. Returns a ``dns.message.Message``. """ - return self.kwargs['message'] + return self.kwargs["message"] class NotQueryResponse(dns.exception.DNSException): @@ -92,12 +99,14 @@ class ChainTooLong(dns.exception.DNSException): class AnswerForNXDOMAIN(dns.exception.DNSException): """The rcode is NXDOMAIN but an answer was found.""" + class NoPreviousName(dns.exception.SyntaxError): """No previous name was known.""" class MessageSection(dns.enum.IntEnum): """Message sections""" + QUESTION = 0 ANSWER = 1 AUTHORITY = 2 @@ -109,7 +118,7 @@ class MessageSection(dns.enum.IntEnum): class MessageError: - def __init__(self, exception, offset): + def __init__(self, exception: Exception, offset: int): self.exception = exception self.offset = offset @@ -117,32 +126,46 @@ class MessageError: DEFAULT_EDNS_PAYLOAD = 1232 MAX_CHAIN = 16 +IndexKeyType = Tuple[ + int, + dns.name.Name, + dns.rdataclass.RdataClass, + dns.rdatatype.RdataType, + Optional[dns.rdatatype.RdataType], + Optional[dns.rdataclass.RdataClass], +] +IndexType = Dict[IndexKeyType, dns.rrset.RRset] +SectionType = Union[int, List[dns.rrset.RRset]] + + class Message: """A DNS message.""" _section_enum = MessageSection - def __init__(self, id=None): + def __init__(self, id: Optional[int] = None): if id is None: self.id = dns.entropy.random_16() else: self.id = id self.flags = 0 - self.sections = [[], [], [], []] - self.opt = None + self.sections: List[List[dns.rrset.RRset]] = [[], [], [], []] + self.opt: Optional[dns.rrset.RRset] = None self.request_payload = 0 - self.keyring = None - self.tsig = None - self.request_mac = b'' + self.pad = 0 + self.keyring: Any = None + self.tsig: Optional[dns.rrset.RRset] = None + self.request_mac = b"" self.xfr = False - self.origin = None - self.tsig_ctx = None - self.index = {} - self.errors = [] + self.origin: Optional[dns.name.Name] = None + self.tsig_ctx: Optional[Any] = None + self.index: IndexType = {} + self.errors: List[MessageError] = [] + self.time = 0.0 @property - def question(self): - """ The question section.""" + def question(self) -> List[dns.rrset.RRset]: + """The question section.""" return self.sections[0] @question.setter @@ -150,8 +173,8 @@ class Message: self.sections[0] = v @property - def answer(self): - """ The answer section.""" + def answer(self) -> List[dns.rrset.RRset]: + """The answer section.""" return self.sections[1] @answer.setter @@ -159,8 +182,8 @@ class Message: self.sections[1] = v @property - def authority(self): - """ The authority section.""" + def authority(self) -> List[dns.rrset.RRset]: + """The authority section.""" return self.sections[2] @authority.setter @@ -168,8 +191,8 @@ class Message: self.sections[2] = v @property - def additional(self): - """ The additional data section.""" + def additional(self) -> List[dns.rrset.RRset]: + """The additional data section.""" return self.sections[3] @additional.setter @@ -177,12 +200,17 @@ class Message: self.sections[3] = v def __repr__(self): - return '' + return "" def __str__(self): return self.to_text() - def to_text(self, origin=None, relativize=True, **kw): + def to_text( + self, + origin: Optional[dns.name.Name] = None, + relativize: bool = True, + **kw: Dict[str, Any], + ) -> str: """Convert the message to text. The *origin*, *relativize*, and any other keyword @@ -192,23 +220,22 @@ class Message: """ s = io.StringIO() - s.write('id %d\n' % self.id) - s.write('opcode %s\n' % dns.opcode.to_text(self.opcode())) - s.write('rcode %s\n' % dns.rcode.to_text(self.rcode())) - s.write('flags %s\n' % dns.flags.to_text(self.flags)) + s.write("id %d\n" % self.id) + s.write("opcode %s\n" % dns.opcode.to_text(self.opcode())) + s.write("rcode %s\n" % dns.rcode.to_text(self.rcode())) + s.write("flags %s\n" % dns.flags.to_text(self.flags)) if self.edns >= 0: - s.write('edns %s\n' % self.edns) + s.write("edns %s\n" % self.edns) if self.ednsflags != 0: - s.write('eflags %s\n' % - dns.flags.edns_to_text(self.ednsflags)) - s.write('payload %d\n' % self.payload) + s.write("eflags %s\n" % dns.flags.edns_to_text(self.ednsflags)) + s.write("payload %d\n" % self.payload) for opt in self.options: - s.write('option %s\n' % opt.to_text()) + s.write("option %s\n" % opt.to_text()) for (name, which) in self._section_enum.__members__.items(): - s.write(f';{name}\n') + s.write(f";{name}\n") for rrset in self.section_from_number(which): s.write(rrset.to_text(origin, relativize, **kw)) - s.write('\n') + s.write("\n") # # We strip off the final \n so the caller can print the result without # doing weird things to get around eccentricities in Python print @@ -242,20 +269,25 @@ class Message: def __ne__(self, other): return not self.__eq__(other) - def is_response(self, other): + def is_response(self, other: "Message") -> bool: """Is *other*, also a ``dns.message.Message``, a response to this message? Returns a ``bool``. """ - if other.flags & dns.flags.QR == 0 or \ - self.id != other.id or \ - dns.opcode.from_flags(self.flags) != \ - dns.opcode.from_flags(other.flags): + if ( + other.flags & dns.flags.QR == 0 + or self.id != other.id + or dns.opcode.from_flags(self.flags) != dns.opcode.from_flags(other.flags) + ): return False - if other.rcode() in {dns.rcode.FORMERR, dns.rcode.SERVFAIL, - dns.rcode.NOTIMP, dns.rcode.REFUSED}: + if other.rcode() in { + dns.rcode.FORMERR, + dns.rcode.SERVFAIL, + dns.rcode.NOTIMP, + dns.rcode.REFUSED, + }: # We don't check the question section in these cases if # the other question section is empty, even though they # still really ought to have a question section. @@ -275,7 +307,7 @@ class Message: return False return True - def section_number(self, section): + def section_number(self, section: List[dns.rrset.RRset]) -> int: """Return the "section number" of the specified section for use in indexing. @@ -289,9 +321,9 @@ class Message: for i, our_section in enumerate(self.sections): if section is our_section: return self._section_enum(i) - raise ValueError('unknown section') + raise ValueError("unknown section") - def section_from_number(self, number): + def section_from_number(self, number: int) -> List[dns.rrset.RRset]: """Return the section list associated with the specified section number. @@ -306,9 +338,17 @@ class Message: section = self._section_enum.make(number) return self.sections[section] - def find_rrset(self, section, name, rdclass, rdtype, - covers=dns.rdatatype.NONE, deleting=None, create=False, - force_unique=False): + def find_rrset( + self, + section: SectionType, + name: dns.name.Name, + rdclass: dns.rdataclass.RdataClass, + rdtype: dns.rdatatype.RdataType, + covers: dns.rdatatype.RdataType = dns.rdatatype.NONE, + deleting: Optional[dns.rdataclass.RdataClass] = None, + create: bool = False, + force_unique: bool = False, + ) -> dns.rrset.RRset: """Find the RRset with the given attributes in the specified section. *section*, an ``int`` section number, or one of the section @@ -346,9 +386,10 @@ class Message: if isinstance(section, int): section_number = section - section = self.section_from_number(section_number) + the_section = self.section_from_number(section_number) else: section_number = self.section_number(section) + the_section = section key = (section_number, name, rdclass, rdtype, covers, deleting) if not force_unique: if self.index is not None: @@ -356,21 +397,28 @@ class Message: if rrset is not None: return rrset else: - for rrset in section: - if rrset.full_match(name, rdclass, rdtype, covers, - deleting): + for rrset in the_section: + if rrset.full_match(name, rdclass, rdtype, covers, deleting): return rrset if not create: raise KeyError rrset = dns.rrset.RRset(name, rdclass, rdtype, covers, deleting) - section.append(rrset) + the_section.append(rrset) if self.index is not None: self.index[key] = rrset return rrset - def get_rrset(self, section, name, rdclass, rdtype, - covers=dns.rdatatype.NONE, deleting=None, create=False, - force_unique=False): + def get_rrset( + self, + section: SectionType, + name: dns.name.Name, + rdclass: dns.rdataclass.RdataClass, + rdtype: dns.rdatatype.RdataType, + covers: dns.rdatatype.RdataType = dns.rdatatype.NONE, + deleting: Optional[dns.rdataclass.RdataClass] = None, + create: bool = False, + force_unique: bool = False, + ) -> Optional[dns.rrset.RRset]: """Get the RRset with the given attributes in the specified section. If the RRset is not found, None is returned. @@ -406,14 +454,53 @@ class Message: """ try: - rrset = self.find_rrset(section, name, rdclass, rdtype, covers, - deleting, create, force_unique) + rrset = self.find_rrset( + section, name, rdclass, rdtype, covers, deleting, create, force_unique + ) except KeyError: rrset = None return rrset - def to_wire(self, origin=None, max_size=0, multi=False, tsig_ctx=None, - **kw): + def _compute_opt_reserve(self) -> int: + """Compute the size required for the OPT RR, padding excluded""" + if not self.opt: + return 0 + # 1 byte for the root name, 10 for the standard RR fields + size = 11 + # This would be more efficient if options had a size() method, but we won't + # worry about that for now. We also don't worry if there is an existing padding + # option, as it is unlikely and probably harmless, as the worst case is that we + # may add another, and this seems to be legal. + for option in self.opt[0].options: + wire = option.to_wire() + # We add 4 here to account for the option type and length + size += len(wire) + 4 + if self.pad: + # Padding will be added, so again add the option type and length. + size += 4 + return size + + def _compute_tsig_reserve(self) -> int: + """Compute the size required for the TSIG RR""" + # This would be more efficient if TSIGs had a size method, but we won't + # worry about for now. Also, we can't really cope with the potential + # compressibility of the TSIG owner name, so we estimate with the uncompressed + # size. We will disable compression when TSIG and padding are both is active + # so that the padding comes out right. + if not self.tsig: + return 0 + f = io.BytesIO() + self.tsig.to_wire(f) + return len(f.getvalue()) + + def to_wire( + self, + origin: Optional[dns.name.Name] = None, + max_size: int = 0, + multi: bool = False, + tsig_ctx: Optional[Any] = None, + **kw: Dict[str, Any], + ) -> bytes: """Return a string containing the message in DNS compressed wire format. @@ -451,25 +538,32 @@ class Message: elif max_size > 65535: max_size = 65535 r = dns.renderer.Renderer(self.id, self.flags, max_size, origin) + opt_reserve = self._compute_opt_reserve() + r.reserve(opt_reserve) + tsig_reserve = self._compute_tsig_reserve() + r.reserve(tsig_reserve) for rrset in self.question: r.add_question(rrset.name, rrset.rdtype, rrset.rdclass) for rrset in self.answer: r.add_rrset(dns.renderer.ANSWER, rrset, **kw) for rrset in self.authority: r.add_rrset(dns.renderer.AUTHORITY, rrset, **kw) - if self.opt is not None: - r.add_rrset(dns.renderer.ADDITIONAL, self.opt) for rrset in self.additional: r.add_rrset(dns.renderer.ADDITIONAL, rrset, **kw) + r.release_reserved() + if self.opt is not None: + r.add_opt(self.opt, self.pad, opt_reserve, tsig_reserve) r.write_header() if self.tsig is not None: - (new_tsig, ctx) = dns.tsig.sign(r.get_wire(), - self.keyring, - self.tsig[0], - int(time.time()), - self.request_mac, - tsig_ctx, - multi) + (new_tsig, ctx) = dns.tsig.sign( + r.get_wire(), + self.keyring, + self.tsig[0], + int(time.time()), + self.request_mac, + tsig_ctx, + multi, + ) self.tsig.clear() self.tsig.add(new_tsig) r.add_rrset(dns.renderer.ADDITIONAL, self.tsig) @@ -479,16 +573,32 @@ class Message: return r.get_wire() @staticmethod - def _make_tsig(keyname, algorithm, time_signed, fudge, mac, original_id, - error, other): - tsig = dns.rdtypes.ANY.TSIG.TSIG(dns.rdataclass.ANY, dns.rdatatype.TSIG, - algorithm, time_signed, fudge, mac, - original_id, error, other) + def _make_tsig( + keyname, algorithm, time_signed, fudge, mac, original_id, error, other + ): + tsig = dns.rdtypes.ANY.TSIG.TSIG( + dns.rdataclass.ANY, + dns.rdatatype.TSIG, + algorithm, + time_signed, + fudge, + mac, + original_id, + error, + other, + ) return dns.rrset.from_rdata(keyname, 0, tsig) - def use_tsig(self, keyring, keyname=None, fudge=300, - original_id=None, tsig_error=0, other_data=b'', - algorithm=dns.tsig.default_algorithm): + def use_tsig( + self, + keyring: Any, + keyname: Optional[Union[dns.name.Name, str]] = None, + fudge: int = 300, + original_id: Optional[int] = None, + tsig_error: int = 0, + other_data: bytes = b"", + algorithm: Union[dns.name.Name, str] = dns.tsig.default_algorithm, + ) -> None: """When sending, a TSIG signature using the specified key should be added. @@ -522,7 +632,7 @@ class Message: *other_data*, a ``bytes``, the TSIG other data. - *algorithm*, a ``dns.name.Name``, the TSIG algorithm to use. This is + *algorithm*, a ``dns.name.Name`` or ``str``, the TSIG algorithm to use. This is only used if *keyring* is a ``dict``, and the key entry is a ``bytes``. """ @@ -542,68 +652,84 @@ class Message: self.keyring = key if original_id is None: original_id = self.id - self.tsig = self._make_tsig(keyname, self.keyring.algorithm, 0, fudge, - b'', original_id, tsig_error, other_data) + self.tsig = self._make_tsig( + keyname, + self.keyring.algorithm, + 0, + fudge, + b"\x00" * dns.tsig.mac_sizes[self.keyring.algorithm], + original_id, + tsig_error, + other_data, + ) @property - def keyname(self): + def keyname(self) -> Optional[dns.name.Name]: if self.tsig: return self.tsig.name else: return None @property - def keyalgorithm(self): + def keyalgorithm(self) -> Optional[dns.name.Name]: if self.tsig: return self.tsig[0].algorithm else: return None @property - def mac(self): + def mac(self) -> Optional[bytes]: if self.tsig: return self.tsig[0].mac else: return None @property - def tsig_error(self): + def tsig_error(self) -> Optional[int]: if self.tsig: return self.tsig[0].error else: return None @property - def had_tsig(self): + def had_tsig(self) -> bool: return bool(self.tsig) @staticmethod def _make_opt(flags=0, payload=DEFAULT_EDNS_PAYLOAD, options=None): - opt = dns.rdtypes.ANY.OPT.OPT(payload, dns.rdatatype.OPT, - options or ()) + opt = dns.rdtypes.ANY.OPT.OPT(payload, dns.rdatatype.OPT, options or ()) return dns.rrset.from_rdata(dns.name.root, int(flags), opt) - def use_edns(self, edns=0, ednsflags=0, payload=DEFAULT_EDNS_PAYLOAD, - request_payload=None, options=None): + def use_edns( + self, + edns: Optional[Union[int, bool]] = 0, + ednsflags: int = 0, + payload: int = DEFAULT_EDNS_PAYLOAD, + request_payload: Optional[int] = None, + options: Optional[List[dns.edns.Option]] = None, + pad: int = 0, + ) -> None: """Configure EDNS behavior. - *edns*, an ``int``, is the EDNS level to use. Specifying - ``None``, ``False``, or ``-1`` means "do not use EDNS", and in this case - the other parameters are ignored. Specifying ``True`` is - equivalent to specifying 0, i.e. "use EDNS0". + *edns*, an ``int``, is the EDNS level to use. Specifying ``None``, ``False``, + or ``-1`` means "do not use EDNS", and in this case the other parameters are + ignored. Specifying ``True`` is equivalent to specifying 0, i.e. "use EDNS0". *ednsflags*, an ``int``, the EDNS flag values. - *payload*, an ``int``, is the EDNS sender's payload field, which is the - maximum size of UDP datagram the sender can handle. I.e. how big - a response to this message can be. + *payload*, an ``int``, is the EDNS sender's payload field, which is the maximum + size of UDP datagram the sender can handle. I.e. how big a response to this + message can be. - *request_payload*, an ``int``, is the EDNS payload size to use when - sending this message. If not specified, defaults to the value of - *payload*. + *request_payload*, an ``int``, is the EDNS payload size to use when sending this + message. If not specified, defaults to the value of *payload*. - *options*, a list of ``dns.edns.Option`` objects or ``None``, the EDNS - options. + *options*, a list of ``dns.edns.Option`` objects or ``None``, the EDNS options. + + *pad*, a non-negative ``int``. If 0, the default, do not pad; otherwise add + padding bytes to make the message size a multiple of *pad*. Note that if + padding is non-zero, an EDNS PADDING option will always be added to the + message. """ if edns is None or edns is False: @@ -616,23 +742,24 @@ class Message: else: # make sure the EDNS version in ednsflags agrees with edns ednsflags &= 0xFF00FFFF - ednsflags |= (edns << 16) + ednsflags |= edns << 16 if options is None: options = [] self.opt = self._make_opt(ednsflags, payload, options) if request_payload is None: request_payload = payload self.request_payload = request_payload + self.pad = pad @property - def edns(self): + def edns(self) -> int: if self.opt: - return (self.ednsflags & 0xff0000) >> 16 + return (self.ednsflags & 0xFF0000) >> 16 else: return -1 @property - def ednsflags(self): + def ednsflags(self) -> int: if self.opt: return self.opt.ttl else: @@ -646,20 +773,20 @@ class Message: self.opt = self._make_opt(v) @property - def payload(self): + def payload(self) -> int: if self.opt: return self.opt[0].payload else: return 0 @property - def options(self): + def options(self) -> Tuple: if self.opt: return self.opt[0].options else: return () - def want_dnssec(self, wanted=True): + def want_dnssec(self, wanted: bool = True) -> None: """Enable or disable 'DNSSEC desired' flag in requests. *wanted*, a ``bool``. If ``True``, then DNSSEC data is @@ -673,17 +800,17 @@ class Message: elif self.opt: self.ednsflags &= ~dns.flags.DO - def rcode(self): + def rcode(self) -> dns.rcode.Rcode: """Return the rcode. - Returns an ``int``. + Returns a ``dns.rcode.Rcode``. """ return dns.rcode.from_flags(int(self.flags), int(self.ednsflags)) - def set_rcode(self, rcode): + def set_rcode(self, rcode: dns.rcode.Rcode) -> None: """Set the rcode. - *rcode*, an ``int``, is the rcode to set. + *rcode*, a ``dns.rcode.Rcode``, is the rcode to set. """ (value, evalue) = dns.rcode.to_flags(rcode) self.flags &= 0xFFF0 @@ -691,17 +818,17 @@ class Message: self.ednsflags &= 0x00FFFFFF self.ednsflags |= evalue - def opcode(self): + def opcode(self) -> dns.opcode.Opcode: """Return the opcode. - Returns an ``int``. + Returns a ``dns.opcode.Opcode``. """ return dns.opcode.from_flags(int(self.flags)) - def set_opcode(self, opcode): + def set_opcode(self, opcode: dns.opcode.Opcode) -> None: """Set the opcode. - *opcode*, an ``int``, is the opcode to set. + *opcode*, a ``dns.opcode.Opcode``, is the opcode to set. """ self.flags &= 0x87FF self.flags |= dns.opcode.to_flags(opcode) @@ -717,16 +844,20 @@ class Message: # pylint: enable=unused-argument - def _parse_special_rr_header(self, section, count, position, - name, rdclass, rdtype): + def _parse_special_rr_header(self, section, count, position, name, rdclass, rdtype): if rdtype == dns.rdatatype.OPT: - if section != MessageSection.ADDITIONAL or self.opt or \ - name != dns.name.root: + if ( + section != MessageSection.ADDITIONAL + or self.opt + or name != dns.name.root + ): raise BadEDNS elif rdtype == dns.rdatatype.TSIG: - if section != MessageSection.ADDITIONAL or \ - rdclass != dns.rdatatype.ANY or \ - position != count - 1: + if ( + section != MessageSection.ADDITIONAL + or rdclass != dns.rdatatype.ANY + or position != count - 1 + ): raise BadTSIG return (rdclass, rdtype, None, False) @@ -738,7 +869,7 @@ class ChainingResult: exist. The ``canonical_name`` attribute is the canonical name after all - chaining has been applied (this is the name as ``rrset.name`` in cases + chaining has been applied (this is the same name as ``rrset.name`` in cases where rrset is not ``None``). The ``minimum_ttl`` attribute is the minimum TTL, i.e. the TTL to @@ -749,7 +880,14 @@ class ChainingResult: The ``cnames`` attribute is a list of all the CNAME RRSets followed to get to the canonical name. """ - def __init__(self, canonical_name, answer, minimum_ttl, cnames): + + def __init__( + self, + canonical_name: dns.name.Name, + answer: Optional[dns.rrset.RRset], + minimum_ttl: int, + cnames: List[dns.rrset.RRset], + ): self.canonical_name = canonical_name self.answer = answer self.minimum_ttl = minimum_ttl @@ -757,7 +895,7 @@ class ChainingResult: class QueryMessage(Message): - def resolve_chaining(self): + def resolve_chaining(self) -> ChainingResult: """Follow the CNAME chain in the response to determine the answer RRset. @@ -785,16 +923,17 @@ class QueryMessage(Message): cnames = [] while count < MAX_CHAIN: try: - answer = self.find_rrset(self.answer, qname, question.rdclass, - question.rdtype) + answer = self.find_rrset( + self.answer, qname, question.rdclass, question.rdtype + ) min_ttl = min(min_ttl, answer.ttl) break except KeyError: if question.rdtype != dns.rdatatype.CNAME: try: - crrset = self.find_rrset(self.answer, qname, - question.rdclass, - dns.rdatatype.CNAME) + crrset = self.find_rrset( + self.answer, qname, question.rdclass, dns.rdatatype.CNAME + ) cnames.append(crrset) min_ttl = min(min_ttl, crrset.ttl) for rd in crrset: @@ -819,9 +958,9 @@ class QueryMessage(Message): # Look for an SOA RR whose owner name is a superdomain # of qname. try: - srrset = self.find_rrset(self.authority, auname, - question.rdclass, - dns.rdatatype.SOA) + srrset = self.find_rrset( + self.authority, auname, question.rdclass, dns.rdatatype.SOA + ) min_ttl = min(min_ttl, srrset.ttl, srrset[0].minimum) break except KeyError: @@ -831,7 +970,7 @@ class QueryMessage(Message): break return ChainingResult(qname, answer, min_ttl, cnames) - def canonical_name(self): + def canonical_name(self) -> dns.name.Name: """Return the canonical name of the first name in the question section. @@ -885,9 +1024,17 @@ class _WireReader: raising them. """ - def __init__(self, wire, initialize_message, question_only=False, - one_rr_per_rrset=False, ignore_trailing=False, - keyring=None, multi=False, continue_on_error=False): + def __init__( + self, + wire, + initialize_message, + question_only=False, + one_rr_per_rrset=False, + ignore_trailing=False, + keyring=None, + multi=False, + continue_on_error=False, + ): self.parser = dns.wire.Parser(wire) self.message = None self.initialize_message = initialize_message @@ -903,16 +1050,17 @@ class _WireReader: """Read the next *qcount* records from the wire data and add them to the question section. """ - + assert self.message is not None section = self.message.sections[section_number] for _ in range(qcount): qname = self.parser.get_name(self.message.origin) - (rdtype, rdclass) = self.parser.get_struct('!HH') - (rdclass, rdtype, _, _) = \ - self.message._parse_rr_header(section_number, qname, rdclass, - rdtype) - self.message.find_rrset(section, qname, rdclass, rdtype, - create=True, force_unique=True) + (rdtype, rdclass) = self.parser.get_struct("!HH") + (rdclass, rdtype, _, _) = self.message._parse_rr_header( + section_number, qname, rdclass, rdtype + ) + self.message.find_rrset( + section, qname, rdclass, rdtype, create=True, force_unique=True + ) def _add_error(self, e): self.errors.append(MessageError(e, self.parser.current)) @@ -924,7 +1072,7 @@ class _WireReader: section_number: the section of the message to which to add records count: the number of records to read """ - + assert self.message is not None section = self.message.sections[section_number] force_unique = self.one_rr_per_rrset for i in range(count): @@ -934,18 +1082,22 @@ class _WireReader: name = absolute_name.relativize(self.message.origin) else: name = absolute_name - (rdtype, rdclass, ttl, rdlen) = self.parser.get_struct('!HHIH') + (rdtype, rdclass, ttl, rdlen) = self.parser.get_struct("!HHIH") if rdtype in (dns.rdatatype.OPT, dns.rdatatype.TSIG): - (rdclass, rdtype, deleting, empty) = \ - self.message._parse_special_rr_header(section_number, - count, i, name, - rdclass, rdtype) + ( + rdclass, + rdtype, + deleting, + empty, + ) = self.message._parse_special_rr_header( + section_number, count, i, name, rdclass, rdtype + ) else: - (rdclass, rdtype, deleting, empty) = \ - self.message._parse_rr_header(section_number, - name, rdclass, rdtype) + (rdclass, rdtype, deleting, empty) = self.message._parse_rr_header( + section_number, name, rdclass, rdtype + ) + rdata_start = self.parser.current try: - rdata_start = self.parser.current if empty: if rdlen > 0: raise dns.exception.FormError @@ -953,9 +1105,9 @@ class _WireReader: covers = dns.rdatatype.NONE else: with self.parser.restrict_to(rdlen): - rd = dns.rdata.from_wire_parser(rdclass, rdtype, - self.parser, - self.message.origin) + rd = dns.rdata.from_wire_parser( + rdclass, rdtype, self.parser, self.message.origin + ) covers = rd.covers() if self.message.xfr and rdtype == dns.rdatatype.SOA: force_unique = True @@ -963,8 +1115,7 @@ class _WireReader: self.message.opt = dns.rrset.from_rdata(name, ttl, rd) elif rdtype == dns.rdatatype.TSIG: if self.keyring is None: - raise UnknownTSIGKey('got signed message without ' - 'keyring') + raise UnknownTSIGKey("got signed message without keyring") if isinstance(self.keyring, dict): key = self.keyring.get(absolute_name) if isinstance(key, bytes): @@ -976,25 +1127,31 @@ class _WireReader: if key is None: raise UnknownTSIGKey("key '%s' unknown" % name) self.message.keyring = key - self.message.tsig_ctx = \ - dns.tsig.validate(self.parser.wire, - key, - absolute_name, - rd, - int(time.time()), - self.message.request_mac, - rr_start, - self.message.tsig_ctx, - self.multi) - self.message.tsig = dns.rrset.from_rdata(absolute_name, 0, - rd) + self.message.tsig_ctx = dns.tsig.validate( + self.parser.wire, + key, + absolute_name, + rd, + int(time.time()), + self.message.request_mac, + rr_start, + self.message.tsig_ctx, + self.multi, + ) + self.message.tsig = dns.rrset.from_rdata(absolute_name, 0, rd) else: - rrset = self.message.find_rrset(section, name, - rdclass, rdtype, covers, - deleting, True, - force_unique) + rrset = self.message.find_rrset( + section, + name, + rdclass, + rdtype, + covers, + deleting, + True, + force_unique, + ) if rd is not None: - if ttl > 0x7fffffff: + if ttl > 0x7FFFFFFF: ttl = 0 rrset.add(rd, ttl) except Exception as e: @@ -1010,14 +1167,16 @@ class _WireReader: if self.parser.remaining() < 12: raise ShortHeader - (id, flags, qcount, ancount, aucount, adcount) = \ - self.parser.get_struct('!HHHHHH') + (id, flags, qcount, ancount, aucount, adcount) = self.parser.get_struct( + "!HHHHHH" + ) factory = _message_factory_from_opcode(dns.opcode.from_flags(flags)) self.message = factory(id=id) self.message.flags = dns.flags.Flag(flags) self.initialize_message(self.message) - self.one_rr_per_rrset = \ - self.message._get_one_rr_per_rrset(self.one_rr_per_rrset) + self.one_rr_per_rrset = self.message._get_one_rr_per_rrset( + self.one_rr_per_rrset + ) try: self._get_question(MessageSection.QUESTION, qcount) if self.question_only: @@ -1027,8 +1186,7 @@ class _WireReader: self._get_section(MessageSection.ADDITIONAL, adcount) if not self.ignore_trailing and self.parser.remaining() != 0: raise TrailingJunk - if self.multi and self.message.tsig_ctx and \ - not self.message.had_tsig: + if self.multi and self.message.tsig_ctx and not self.message.had_tsig: self.message.tsig_ctx.update(self.parser.wire) except Exception as e: if self.continue_on_error: @@ -1038,83 +1196,103 @@ class _WireReader: return self.message -def from_wire(wire, keyring=None, request_mac=b'', xfr=False, origin=None, - tsig_ctx=None, multi=False, - question_only=False, one_rr_per_rrset=False, - ignore_trailing=False, raise_on_truncation=False, - continue_on_error=False): +def from_wire( + wire: bytes, + keyring: Optional[Any] = None, + request_mac: Optional[bytes] = b"", + xfr: bool = False, + origin: Optional[dns.name.Name] = None, + tsig_ctx: Optional[Union[dns.tsig.HMACTSig, dns.tsig.GSSTSig]] = None, + multi: bool = False, + question_only: bool = False, + one_rr_per_rrset: bool = False, + ignore_trailing: bool = False, + raise_on_truncation: bool = False, + continue_on_error: bool = False, +) -> Message: """Convert a DNS wire format message into a message object. - *keyring*, a ``dns.tsig.Key`` or ``dict``, the key or keyring to use if the - message is signed. + *keyring*, a ``dns.tsig.Key`` or ``dict``, the key or keyring to use if the message + is signed. - *request_mac*, a ``bytes``. If the message is a response to a TSIG-signed - request, *request_mac* should be set to the MAC of that request. + *request_mac*, a ``bytes`` or ``None``. If the message is a response to a + TSIG-signed request, *request_mac* should be set to the MAC of that request. - *xfr*, a ``bool``, should be set to ``True`` if this message is part of a - zone transfer. + *xfr*, a ``bool``, should be set to ``True`` if this message is part of a zone + transfer. *origin*, a ``dns.name.Name`` or ``None``. If the message is part of a zone - transfer, *origin* should be the origin name of the zone. If not ``None``, - names will be relativized to the origin. + transfer, *origin* should be the origin name of the zone. If not ``None``, names + will be relativized to the origin. - *tsig_ctx*, a ``dns.tsig.HMACTSig`` or ``dns.tsig.GSSTSig`` object, the - ongoing TSIG context, used when validating zone transfers. + *tsig_ctx*, a ``dns.tsig.HMACTSig`` or ``dns.tsig.GSSTSig`` object, the ongoing TSIG + context, used when validating zone transfers. - *multi*, a ``bool``, should be set to ``True`` if this message is part of a - multiple message sequence. + *multi*, a ``bool``, should be set to ``True`` if this message is part of a multiple + message sequence. - *question_only*, a ``bool``. If ``True``, read only up to the end of the - question section. + *question_only*, a ``bool``. If ``True``, read only up to the end of the question + section. - *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own - RRset. + *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own RRset. - *ignore_trailing*, a ``bool``. If ``True``, ignore trailing junk at end of - the message. + *ignore_trailing*, a ``bool``. If ``True``, ignore trailing junk at end of the + message. - *raise_on_truncation*, a ``bool``. If ``True``, raise an exception if the - TC bit is set. + *raise_on_truncation*, a ``bool``. If ``True``, raise an exception if the TC bit is + set. - *continue_on_error*, a ``bool``. If ``True``, try to continue parsing even - if errors occur. Erroneous rdata will be ignored. Errors will be - accumulated as a list of MessageError objects in the message's ``errors`` - attribute. This option is recommended only for DNS analysis tools, or for - use in a server as part of an error handling path. The default is - ``False``. + *continue_on_error*, a ``bool``. If ``True``, try to continue parsing even if + errors occur. Erroneous rdata will be ignored. Errors will be accumulated as a + list of MessageError objects in the message's ``errors`` attribute. This option is + recommended only for DNS analysis tools, or for use in a server as part of an error + handling path. The default is ``False``. - Raises ``dns.message.ShortHeader`` if the message is less than 12 octets - long. + Raises ``dns.message.ShortHeader`` if the message is less than 12 octets long. - Raises ``dns.message.TrailingJunk`` if there were octets in the message past - the end of the proper DNS message, and *ignore_trailing* is ``False``. + Raises ``dns.message.TrailingJunk`` if there were octets in the message past the end + of the proper DNS message, and *ignore_trailing* is ``False``. Raises ``dns.message.BadEDNS`` if an OPT record was in the wrong section, or occurred more than once. - Raises ``dns.message.BadTSIG`` if a TSIG record was not the last record of - the additional data section. + Raises ``dns.message.BadTSIG`` if a TSIG record was not the last record of the + additional data section. - Raises ``dns.message.Truncated`` if the TC flag is set and - *raise_on_truncation* is ``True``. + Raises ``dns.message.Truncated`` if the TC flag is set and *raise_on_truncation* is + ``True``. Returns a ``dns.message.Message``. """ + # We permit None for request_mac solely for backwards compatibility + if request_mac is None: + request_mac = b"" + def initialize_message(message): message.request_mac = request_mac message.xfr = xfr message.origin = origin message.tsig_ctx = tsig_ctx - reader = _WireReader(wire, initialize_message, question_only, - one_rr_per_rrset, ignore_trailing, keyring, multi, - continue_on_error) + reader = _WireReader( + wire, + initialize_message, + question_only, + one_rr_per_rrset, + ignore_trailing, + keyring, + multi, + continue_on_error, + ) try: m = reader.read() except dns.exception.FormError: - if reader.message and (reader.message.flags & dns.flags.TC) and \ - raise_on_truncation: + if ( + reader.message + and (reader.message.flags & dns.flags.TC) + and raise_on_truncation + ): raise Truncated(message=reader.message) else: raise @@ -1142,8 +1320,15 @@ class _TextReader: relativize_to: the origin to relativize to. """ - def __init__(self, text, idna_codec, one_rr_per_rrset=False, - origin=None, relativize=True, relativize_to=None): + def __init__( + self, + text, + idna_codec, + one_rr_per_rrset=False, + origin=None, + relativize=True, + relativize_to=None, + ): self.message = None self.tok = dns.tokenizer.Tokenizer(text, idna_codec=idna_codec) self.last_name = None @@ -1164,19 +1349,19 @@ class _TextReader: token = self.tok.get() what = token.value - if what == 'id': + if what == "id": self.id = self.tok.get_int() - elif what == 'flags': + elif what == "flags": while True: token = self.tok.get() if not token.is_identifier(): self.tok.unget(token) break self.flags = self.flags | dns.flags.from_text(token.value) - elif what == 'edns': + elif what == "edns": self.edns = self.tok.get_int() self.ednsflags = self.ednsflags | (self.edns << 16) - elif what == 'eflags': + elif what == "eflags": if self.edns < 0: self.edns = 0 while True: @@ -1184,17 +1369,16 @@ class _TextReader: if not token.is_identifier(): self.tok.unget(token) break - self.ednsflags = self.ednsflags | \ - dns.flags.edns_from_text(token.value) - elif what == 'payload': + self.ednsflags = self.ednsflags | dns.flags.edns_from_text(token.value) + elif what == "payload": self.payload = self.tok.get_int() if self.edns < 0: self.edns = 0 - elif what == 'opcode': + elif what == "opcode": text = self.tok.get_string() self.opcode = dns.opcode.from_text(text) self.flags = self.flags | dns.opcode.to_flags(self.opcode) - elif what == 'rcode': + elif what == "rcode": text = self.tok.get_string() self.rcode = dns.rcode.from_text(text) else: @@ -1207,9 +1391,9 @@ class _TextReader: section = self.message.sections[section_number] token = self.tok.get(want_leading=True) if not token.is_whitespace(): - self.last_name = self.tok.as_name(token, self.message.origin, - self.relativize, - self.relativize_to) + self.last_name = self.tok.as_name( + token, self.message.origin, self.relativize, self.relativize_to + ) name = self.last_name if name is None: raise NoPreviousName @@ -1228,10 +1412,12 @@ class _TextReader: rdclass = dns.rdataclass.IN # Type rdtype = dns.rdatatype.from_text(token.value) - (rdclass, rdtype, _, _) = \ - self.message._parse_rr_header(section_number, name, rdclass, rdtype) - self.message.find_rrset(section, name, rdclass, rdtype, create=True, - force_unique=True) + (rdclass, rdtype, _, _) = self.message._parse_rr_header( + section_number, name, rdclass, rdtype + ) + self.message.find_rrset( + section, name, rdclass, rdtype, create=True, force_unique=True + ) self.tok.get_eol() def _rr_line(self, section_number): @@ -1243,9 +1429,9 @@ class _TextReader: # Name token = self.tok.get(want_leading=True) if not token.is_whitespace(): - self.last_name = self.tok.as_name(token, self.message.origin, - self.relativize, - self.relativize_to) + self.last_name = self.tok.as_name( + token, self.message.origin, self.relativize, self.relativize_to + ) name = self.last_name if name is None: raise NoPreviousName @@ -1274,8 +1460,9 @@ class _TextReader: rdclass = dns.rdataclass.IN # Type rdtype = dns.rdatatype.from_text(token.value) - (rdclass, rdtype, deleting, empty) = \ - self.message._parse_rr_header(section_number, name, rdclass, rdtype) + (rdclass, rdtype, deleting, empty) = self.message._parse_rr_header( + section_number, name, rdclass, rdtype + ) token = self.tok.get() if empty and not token.is_eol_or_eof(): raise dns.exception.SyntaxError @@ -1283,16 +1470,28 @@ class _TextReader: raise dns.exception.UnexpectedEnd if not token.is_eol_or_eof(): self.tok.unget(token) - rd = dns.rdata.from_text(rdclass, rdtype, self.tok, - self.message.origin, self.relativize, - self.relativize_to) + rd = dns.rdata.from_text( + rdclass, + rdtype, + self.tok, + self.message.origin, + self.relativize, + self.relativize_to, + ) covers = rd.covers() else: rd = None covers = dns.rdatatype.NONE - rrset = self.message.find_rrset(section, name, - rdclass, rdtype, covers, - deleting, True, self.one_rr_per_rrset) + rrset = self.message.find_rrset( + section, + name, + rdclass, + rdtype, + covers, + deleting, + True, + self.one_rr_per_rrset, + ) if rd is not None: rrset.add(rd, ttl) @@ -1320,7 +1519,7 @@ class _TextReader: break if token.is_comment(): u = token.value.upper() - if u == 'HEADER': + if u == "HEADER": line_method = self._header_line if self.message: @@ -1335,8 +1534,9 @@ class _TextReader: # use the one we just created. if not self.message: self.message = message - self.one_rr_per_rrset = \ - message._get_one_rr_per_rrset(self.one_rr_per_rrset) + self.one_rr_per_rrset = message._get_one_rr_per_rrset( + self.one_rr_per_rrset + ) if section_number == MessageSection.QUESTION: line_method = self._question_line else: @@ -1353,8 +1553,14 @@ class _TextReader: return self.message -def from_text(text, idna_codec=None, one_rr_per_rrset=False, - origin=None, relativize=True, relativize_to=None): +def from_text( + text: str, + idna_codec: Optional[dns.name.IDNACodec] = None, + one_rr_per_rrset: bool = False, + origin: Optional[dns.name.Name] = None, + relativize: bool = True, + relativize_to: Optional[dns.name.Name] = None, +) -> Message: """Convert the text format message into a message object. The reader stops after reading the first blank line in the input to @@ -1389,12 +1595,17 @@ def from_text(text, idna_codec=None, one_rr_per_rrset=False, # since it's an implementation detail. The official file # interface is from_file(). - reader = _TextReader(text, idna_codec, one_rr_per_rrset, origin, - relativize, relativize_to) + reader = _TextReader( + text, idna_codec, one_rr_per_rrset, origin, relativize, relativize_to + ) return reader.read() -def from_file(f, idna_codec=None, one_rr_per_rrset=False): +def from_file( + f: Any, + idna_codec: Optional[dns.name.IDNACodec] = None, + one_rr_per_rrset: bool = False, +) -> Message: """Read the next text format message from the specified file. Message blocks are separated by a single blank line. @@ -1416,16 +1627,30 @@ def from_file(f, idna_codec=None, one_rr_per_rrset=False): Returns a ``dns.message.Message object`` """ - with contextlib.ExitStack() as stack: - if isinstance(f, str): - f = stack.enter_context(open(f)) + if isinstance(f, str): + cm: contextlib.AbstractContextManager = open(f) + else: + cm = contextlib.nullcontext(f) + with cm as f: return from_text(f, idna_codec, one_rr_per_rrset) + assert False # for mypy lgtm[py/unreachable-statement] -def make_query(qname, rdtype, rdclass=dns.rdataclass.IN, use_edns=None, - want_dnssec=False, ednsflags=None, payload=None, - request_payload=None, options=None, idna_codec=None, - id=None, flags=dns.flags.RD): +def make_query( + qname: Union[dns.name.Name, str], + rdtype: Union[dns.rdatatype.RdataType, str], + rdclass: Union[dns.rdataclass.RdataClass, str] = dns.rdataclass.IN, + use_edns: Optional[Union[int, bool]] = None, + want_dnssec: bool = False, + ednsflags: Optional[int] = None, + payload: Optional[int] = None, + request_payload: Optional[int] = None, + options: Optional[List[dns.edns.Option]] = None, + idna_codec: Optional[dns.name.IDNACodec] = None, + id: Optional[int] = None, + flags: int = dns.flags.RD, + pad: int = 0, +) -> QueryMessage: """Make a query message. The query name, type, and class may all be specified either @@ -1473,39 +1698,51 @@ def make_query(qname, rdtype, rdclass=dns.rdataclass.IN, use_edns=None, *flags*, an ``int``, the desired query flags. The default is ``dns.flags.RD``. + *pad*, a non-negative ``int``. If 0, the default, do not pad; otherwise add + padding bytes to make the message size a multiple of *pad*. Note that if + padding is non-zero, an EDNS PADDING option will always be added to the + message. + Returns a ``dns.message.QueryMessage`` """ if isinstance(qname, str): qname = dns.name.from_text(qname, idna_codec=idna_codec) - rdtype = dns.rdatatype.RdataType.make(rdtype) - rdclass = dns.rdataclass.RdataClass.make(rdclass) + the_rdtype = dns.rdatatype.RdataType.make(rdtype) + the_rdclass = dns.rdataclass.RdataClass.make(rdclass) m = QueryMessage(id=id) m.flags = dns.flags.Flag(flags) - m.find_rrset(m.question, qname, rdclass, rdtype, create=True, - force_unique=True) + m.find_rrset( + m.question, qname, the_rdclass, the_rdtype, create=True, force_unique=True + ) # only pass keywords on to use_edns if they have been set to a # non-None value. Setting a field will turn EDNS on if it hasn't # been configured. - kwargs = {} + kwargs: Dict[str, Any] = {} if ednsflags is not None: - kwargs['ednsflags'] = ednsflags + kwargs["ednsflags"] = ednsflags if payload is not None: - kwargs['payload'] = payload + kwargs["payload"] = payload if request_payload is not None: - kwargs['request_payload'] = request_payload + kwargs["request_payload"] = request_payload if options is not None: - kwargs['options'] = options + kwargs["options"] = options if kwargs and use_edns is None: use_edns = 0 - kwargs['edns'] = use_edns + kwargs["edns"] = use_edns + kwargs["pad"] = pad m.use_edns(**kwargs) m.want_dnssec(want_dnssec) return m -def make_response(query, recursion_available=False, our_payload=8192, - fudge=300, tsig_error=0): +def make_response( + query: Message, + recursion_available: bool = False, + our_payload: int = 8192, + fudge: int = 300, + tsig_error: int = 0, +) -> Message: """Make a message which is a response for the specified query. The message returned is really a response skeleton; it has all of the infrastructure required of a response, but none of the @@ -1532,7 +1769,7 @@ def make_response(query, recursion_available=False, our_payload=8192, """ if query.flags & dns.flags.QR: - raise dns.exception.FormError('specified query message is not a query') + raise dns.exception.FormError("specified query message is not a query") factory = _message_factory_from_opcode(query.opcode()) response = factory(id=query.id) response.flags = dns.flags.QR | (query.flags & dns.flags.RD) @@ -1543,11 +1780,19 @@ def make_response(query, recursion_available=False, our_payload=8192, if query.edns >= 0: response.use_edns(0, 0, our_payload, query.payload) if query.had_tsig: - response.use_tsig(query.keyring, query.keyname, fudge, None, - tsig_error, b'', query.keyalgorithm) + response.use_tsig( + query.keyring, + query.keyname, + fudge, + None, + tsig_error, + b"", + query.keyalgorithm, + ) response.request_mac = query.mac return response + ### BEGIN generated MessageSection constants QUESTION = MessageSection.QUESTION diff --git a/lib/dns/message.pyi b/lib/dns/message.pyi deleted file mode 100644 index 252a4118..00000000 --- a/lib/dns/message.pyi +++ /dev/null @@ -1,47 +0,0 @@ -from typing import Optional, Dict, List, Tuple, Union -from . import name, rrset, tsig, rdatatype, entropy, edns, rdataclass, rcode -import hmac - -class Message: - def to_wire(self, origin : Optional[name.Name]=None, max_size=0, **kw) -> bytes: - ... - def find_rrset(self, section : List[rrset.RRset], name : name.Name, rdclass : int, rdtype : int, - covers=rdatatype.NONE, deleting : Optional[int]=None, create=False, - force_unique=False) -> rrset.RRset: - ... - def __init__(self, id : Optional[int] =None) -> None: - self.id : int - self.flags = 0 - self.sections : List[List[rrset.RRset]] = [[], [], [], []] - self.opt : rrset.RRset = None - self.request_payload = 0 - self.keyring = None - self.tsig : rrset.RRset = None - self.request_mac = b'' - self.xfr = False - self.origin = None - self.tsig_ctx = None - self.index : Dict[Tuple[rrset.RRset, name.Name, int, int, Union[int,str], int], rrset.RRset] = {} - - def is_response(self, other : Message) -> bool: - ... - - def set_rcode(self, rcode : rcode.Rcode): - ... - -def from_text(a : str, idna_codec : Optional[name.IDNACodec] = None) -> Message: - ... - -def from_wire(wire, keyring : Optional[Dict[name.Name,bytes]] = None, request_mac = b'', xfr=False, origin=None, - tsig_ctx : Optional[Union[dns.tsig.HMACTSig, dns.tsig.GSSTSig]] = None, multi=False, - question_only=False, one_rr_per_rrset=False, - ignore_trailing=False) -> Message: - ... -def make_response(query : Message, recursion_available=False, our_payload=8192, - fudge=300) -> Message: - ... - -def make_query(qname : Union[name.Name,str], rdtype : Union[str,int], rdclass : Union[int,str] =rdataclass.IN, use_edns : Optional[bool] = None, - want_dnssec=False, ednsflags : Optional[int] = None, payload : Optional[int] = None, - request_payload : Optional[int] = None, options : Optional[List[edns.Option]] = None) -> Message: - ... diff --git a/lib/dns/name.py b/lib/dns/name.py index 8905d70f..612af021 100644 --- a/lib/dns/name.py +++ b/lib/dns/name.py @@ -18,32 +18,61 @@ """DNS Names. """ +from typing import Any, Dict, Iterable, Optional, Tuple, Union + import copy import struct -import encodings.idna # type: ignore +import encodings.idna # type: ignore + try: - import idna # type: ignore + import idna # type: ignore + have_idna_2008 = True except ImportError: # pragma: no cover have_idna_2008 = False +import dns.enum import dns.wire import dns.exception import dns.immutable -# fullcompare() result values -#: The compared names have no relationship to each other. -NAMERELN_NONE = 0 -#: the first name is a superdomain of the second. -NAMERELN_SUPERDOMAIN = 1 -#: The first name is a subdomain of the second. -NAMERELN_SUBDOMAIN = 2 -#: The compared names are equal. -NAMERELN_EQUAL = 3 -#: The compared names have a common ancestor. -NAMERELN_COMMONANCESTOR = 4 +CompressType = Dict["Name", int] + + +class NameRelation(dns.enum.IntEnum): + """Name relation result from fullcompare().""" + + # This is an IntEnum for backwards compatibility in case anyone + # has hardwired the constants. + + #: The compared names have no relationship to each other. + NONE = 0 + #: the first name is a superdomain of the second. + SUPERDOMAIN = 1 + #: The first name is a subdomain of the second. + SUBDOMAIN = 2 + #: The compared names are equal. + EQUAL = 3 + #: The compared names have a common ancestor. + COMMONANCESTOR = 4 + + @classmethod + def _maximum(cls): + return cls.COMMONANCESTOR + + @classmethod + def _short_name(cls): + return cls.__name__ + + +# Backwards compatibility +NAMERELN_NONE = NameRelation.NONE +NAMERELN_SUPERDOMAIN = NameRelation.SUPERDOMAIN +NAMERELN_SUBDOMAIN = NameRelation.SUBDOMAIN +NAMERELN_EQUAL = NameRelation.EQUAL +NAMERELN_COMMONANCESTOR = NameRelation.COMMONANCESTOR class EmptyLabel(dns.exception.SyntaxError): @@ -84,6 +113,7 @@ class NoParent(dns.exception.DNSException): """An attempt was made to get the parent of the root name or the empty name.""" + class NoIDNA2008(dns.exception.DNSException): """IDNA 2008 processing was requested but the idna module is not available.""" @@ -92,9 +122,47 @@ class NoIDNA2008(dns.exception.DNSException): class IDNAException(dns.exception.DNSException): """IDNA processing raised an exception.""" - supp_kwargs = {'idna_exception'} + supp_kwargs = {"idna_exception"} fmt = "IDNA processing exception: {idna_exception}" + # We do this as otherwise mypy complains about unexpected keyword argument + # idna_exception + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + +_escaped = b'"().;\\@$' +_escaped_text = '"().;\\@$' + + +def _escapify(label: Union[bytes, str]) -> str: + """Escape the characters in label which need it. + @returns: the escaped string + @rtype: string""" + if isinstance(label, bytes): + # Ordinary DNS label mode. Escape special characters and values + # < 0x20 or > 0x7f. + text = "" + for c in label: + if c in _escaped: + text += "\\" + chr(c) + elif c > 0x20 and c < 0x7F: + text += chr(c) + else: + text += "\\%03d" % c + return text + + # Unicode label mode. Escape only special characters and values < 0x20 + text = "" + for uc in label: + if uc in _escaped_text: + text += "\\" + uc + elif uc <= "\x20": + text += "\\%03d" % ord(uc) + else: + text += uc + return text + class IDNACodec: """Abstract base class for IDNA encoder/decoders.""" @@ -102,26 +170,28 @@ class IDNACodec: def __init__(self): pass - def is_idna(self, label): - return label.lower().startswith(b'xn--') + def is_idna(self, label: bytes) -> bool: + return label.lower().startswith(b"xn--") - def encode(self, label): + def encode(self, label: str) -> bytes: raise NotImplementedError # pragma: no cover - def decode(self, label): + def decode(self, label: bytes) -> str: # We do not apply any IDNA policy on decode. if self.is_idna(label): try: - label = label[4:].decode('punycode') + slabel = label[4:].decode("punycode") + return _escapify(slabel) except Exception as e: raise IDNAException(idna_exception=e) - return _escapify(label) + else: + return _escapify(label) class IDNA2003Codec(IDNACodec): """IDNA 2003 encoder/decoder.""" - def __init__(self, strict_decode=False): + def __init__(self, strict_decode: bool = False): """Initialize the IDNA 2003 encoder/decoder. *strict_decode* is a ``bool``. If `True`, then IDNA2003 checking @@ -132,22 +202,22 @@ class IDNA2003Codec(IDNACodec): super().__init__() self.strict_decode = strict_decode - def encode(self, label): + def encode(self, label: str) -> bytes: """Encode *label*.""" - if label == '': - return b'' + if label == "": + return b"" try: return encodings.idna.ToASCII(label) except UnicodeError: raise LabelTooLong - def decode(self, label): + def decode(self, label: bytes) -> str: """Decode *label*.""" if not self.strict_decode: return super().decode(label) - if label == b'': - return '' + if label == b"": + return "" try: return _escapify(encodings.idna.ToUnicode(label)) except Exception as e: @@ -155,16 +225,20 @@ class IDNA2003Codec(IDNACodec): class IDNA2008Codec(IDNACodec): - """IDNA 2008 encoder/decoder. - """ + """IDNA 2008 encoder/decoder.""" - def __init__(self, uts_46=False, transitional=False, - allow_pure_ascii=False, strict_decode=False): + def __init__( + self, + uts_46: bool = False, + transitional: bool = False, + allow_pure_ascii: bool = False, + strict_decode: bool = False, + ): """Initialize the IDNA 2008 encoder/decoder. *uts_46* is a ``bool``. If True, apply Unicode IDNA compatibility processing as described in Unicode Technical - Standard #46 (http://unicode.org/reports/tr46/). + Standard #46 (https://unicode.org/reports/tr46/). If False, do not apply the mapping. The default is False. *transitional* is a ``bool``: If True, use the @@ -188,11 +262,11 @@ class IDNA2008Codec(IDNACodec): self.allow_pure_ascii = allow_pure_ascii self.strict_decode = strict_decode - def encode(self, label): - if label == '': - return b'' + def encode(self, label: str) -> bytes: + if label == "": + return b"" if self.allow_pure_ascii and is_all_ascii(label): - encoded = label.encode('ascii') + encoded = label.encode("ascii") if len(encoded) > 63: raise LabelTooLong return encoded @@ -203,16 +277,16 @@ class IDNA2008Codec(IDNACodec): label = idna.uts46_remap(label, False, self.transitional) return idna.alabel(label) except idna.IDNAError as e: - if e.args[0] == 'Label too long': + if e.args[0] == "Label too long": raise LabelTooLong else: raise IDNAException(idna_exception=e) - def decode(self, label): + def decode(self, label: bytes) -> str: if not self.strict_decode: return super().decode(label) - if label == b'': - return '' + if label == b"": + return "" if not have_idna_2008: raise NoIDNA2008 try: @@ -223,8 +297,6 @@ class IDNA2008Codec(IDNACodec): except (idna.IDNAError, UnicodeError) as e: raise IDNAException(idna_exception=e) -_escaped = b'"().;\\@$' -_escaped_text = '"().;\\@$' IDNA_2003_Practical = IDNA2003Codec(False) IDNA_2003_Strict = IDNA2003Codec(True) @@ -235,35 +307,8 @@ IDNA_2008_Strict = IDNA2008Codec(False, False, False, True) IDNA_2008_Transitional = IDNA2008Codec(True, True, False, False) IDNA_2008 = IDNA_2008_Practical -def _escapify(label): - """Escape the characters in label which need it. - @returns: the escaped string - @rtype: string""" - if isinstance(label, bytes): - # Ordinary DNS label mode. Escape special characters and values - # < 0x20 or > 0x7f. - text = '' - for c in label: - if c in _escaped: - text += '\\' + chr(c) - elif c > 0x20 and c < 0x7F: - text += chr(c) - else: - text += '\\%03d' % c - return text - # Unicode label mode. Escape only special characters and values < 0x20 - text = '' - for c in label: - if c in _escaped_text: - text += '\\' + c - elif c <= '\x20': - text += '\\%03d' % ord(c) - else: - text += c - return text - -def _validate_labels(labels): +def _validate_labels(labels: Tuple[bytes, ...]) -> None: """Check for empty labels in the middle of a label sequence, labels that are too long, and for too many labels. @@ -284,7 +329,7 @@ def _validate_labels(labels): total += ll + 1 if ll > 63: raise LabelTooLong - if i < 0 and label == b'': + if i < 0 and label == b"": i = j j += 1 if total > 255: @@ -293,7 +338,7 @@ def _validate_labels(labels): raise EmptyLabel -def _maybe_convert_to_binary(label): +def _maybe_convert_to_binary(label: Union[bytes, str]) -> bytes: """If label is ``str``, convert it to ``bytes``. If it is already ``bytes`` just return it. @@ -316,14 +361,13 @@ class Name: of the class are immutable. """ - __slots__ = ['labels'] + __slots__ = ["labels"] - def __init__(self, labels): - """*labels* is any iterable whose values are ``str`` or ``bytes``. - """ + def __init__(self, labels: Iterable[Union[bytes, str]]): + """*labels* is any iterable whose values are ``str`` or ``bytes``.""" - labels = [_maybe_convert_to_binary(x) for x in labels] - self.labels = tuple(labels) + blabels = [_maybe_convert_to_binary(x) for x in labels] + self.labels = tuple(blabels) _validate_labels(self.labels) def __copy__(self): @@ -334,29 +378,29 @@ class Name: def __getstate__(self): # Names can be pickled - return {'labels': self.labels} + return {"labels": self.labels} def __setstate__(self, state): - super().__setattr__('labels', state['labels']) + super().__setattr__("labels", state["labels"]) _validate_labels(self.labels) - def is_absolute(self): + def is_absolute(self) -> bool: """Is the most significant label of this name the root label? Returns a ``bool``. """ - return len(self.labels) > 0 and self.labels[-1] == b'' + return len(self.labels) > 0 and self.labels[-1] == b"" - def is_wild(self): + def is_wild(self) -> bool: """Is this name wild? (I.e. Is the least significant label '*'?) Returns a ``bool``. """ - return len(self.labels) > 0 and self.labels[0] == b'*' + return len(self.labels) > 0 and self.labels[0] == b"*" - def __hash__(self): + def __hash__(self) -> int: """Return a case-insensitive hash of the name. Returns an ``int``. @@ -368,14 +412,14 @@ class Name: h += (h << 3) + c return h - def fullcompare(self, other): + def fullcompare(self, other: "Name") -> Tuple[NameRelation, int, int]: """Compare two names, returning a 3-tuple ``(relation, order, nlabels)``. *relation* describes the relation ship between the names, - and is one of: ``dns.name.NAMERELN_NONE``, - ``dns.name.NAMERELN_SUPERDOMAIN``, ``dns.name.NAMERELN_SUBDOMAIN``, - ``dns.name.NAMERELN_EQUAL``, or ``dns.name.NAMERELN_COMMONANCESTOR``. + and is one of: ``dns.name.NameRelation.NONE``, + ``dns.name.NameRelation.SUPERDOMAIN``, ``dns.name.NameRelation.SUBDOMAIN``, + ``dns.name.NameRelation.EQUAL``, or ``dns.name.NameRelation.COMMONANCESTOR``. *order* is < 0 if *self* < *other*, > 0 if *self* > *other*, and == 0 if *self* == *other*. A relative name is always less than an @@ -404,9 +448,9 @@ class Name: oabs = other.is_absolute() if sabs != oabs: if sabs: - return (NAMERELN_NONE, 1, 0) + return (NameRelation.NONE, 1, 0) else: - return (NAMERELN_NONE, -1, 0) + return (NameRelation.NONE, -1, 0) l1 = len(self.labels) l2 = len(other.labels) ldiff = l1 - l2 @@ -417,7 +461,7 @@ class Name: order = 0 nlabels = 0 - namereln = NAMERELN_NONE + namereln = NameRelation.NONE while l > 0: l -= 1 l1 -= 1 @@ -427,52 +471,52 @@ class Name: if label1 < label2: order = -1 if nlabels > 0: - namereln = NAMERELN_COMMONANCESTOR + namereln = NameRelation.COMMONANCESTOR return (namereln, order, nlabels) elif label1 > label2: order = 1 if nlabels > 0: - namereln = NAMERELN_COMMONANCESTOR + namereln = NameRelation.COMMONANCESTOR return (namereln, order, nlabels) nlabels += 1 order = ldiff if ldiff < 0: - namereln = NAMERELN_SUPERDOMAIN + namereln = NameRelation.SUPERDOMAIN elif ldiff > 0: - namereln = NAMERELN_SUBDOMAIN + namereln = NameRelation.SUBDOMAIN else: - namereln = NAMERELN_EQUAL + namereln = NameRelation.EQUAL return (namereln, order, nlabels) - def is_subdomain(self, other): + def is_subdomain(self, other: "Name") -> bool: """Is self a subdomain of other? Note that the notion of subdomain includes equality, e.g. - "dnpython.org" is a subdomain of itself. + "dnspython.org" is a subdomain of itself. Returns a ``bool``. """ (nr, _, _) = self.fullcompare(other) - if nr == NAMERELN_SUBDOMAIN or nr == NAMERELN_EQUAL: + if nr == NameRelation.SUBDOMAIN or nr == NameRelation.EQUAL: return True return False - def is_superdomain(self, other): + def is_superdomain(self, other: "Name") -> bool: """Is self a superdomain of other? Note that the notion of superdomain includes equality, e.g. - "dnpython.org" is a superdomain of itself. + "dnspython.org" is a superdomain of itself. Returns a ``bool``. """ (nr, _, _) = self.fullcompare(other) - if nr == NAMERELN_SUPERDOMAIN or nr == NAMERELN_EQUAL: + if nr == NameRelation.SUPERDOMAIN or nr == NameRelation.EQUAL: return True return False - def canonicalize(self): + def canonicalize(self) -> "Name": """Return a name which is equal to the current name, but is in DNSSEC canonical form. """ @@ -516,12 +560,12 @@ class Name: return NotImplemented def __repr__(self): - return '' + return "" def __str__(self): return self.to_text(False) - def to_text(self, omit_final_dot=False): + def to_text(self, omit_final_dot: bool = False) -> str: """Convert name to DNS text format. *omit_final_dot* is a ``bool``. If True, don't emit the final @@ -532,17 +576,19 @@ class Name: """ if len(self.labels) == 0: - return '@' - if len(self.labels) == 1 and self.labels[0] == b'': - return '.' + return "@" + if len(self.labels) == 1 and self.labels[0] == b"": + return "." if omit_final_dot and self.is_absolute(): l = self.labels[:-1] else: l = self.labels - s = '.'.join(map(_escapify, l)) + s = ".".join(map(_escapify, l)) return s - def to_unicode(self, omit_final_dot=False, idna_codec=None): + def to_unicode( + self, omit_final_dot: bool = False, idna_codec: Optional[IDNACodec] = None + ) -> str: """Convert name to Unicode text format. IDN ACE labels are converted to Unicode. @@ -561,18 +607,18 @@ class Name: """ if len(self.labels) == 0: - return '@' - if len(self.labels) == 1 and self.labels[0] == b'': - return '.' + return "@" + if len(self.labels) == 1 and self.labels[0] == b"": + return "." if omit_final_dot and self.is_absolute(): l = self.labels[:-1] else: l = self.labels if idna_codec is None: idna_codec = IDNA_2003_Practical - return '.'.join([idna_codec.decode(x) for x in l]) + return ".".join([idna_codec.decode(x) for x in l]) - def to_digestable(self, origin=None): + def to_digestable(self, origin: Optional["Name"] = None) -> bytes: """Convert name to a format suitable for digesting in hashes. The name is canonicalized and converted to uncompressed wire @@ -589,10 +635,17 @@ class Name: Returns a ``bytes``. """ - return self.to_wire(origin=origin, canonicalize=True) + digest = self.to_wire(origin=origin, canonicalize=True) + assert digest is not None + return digest - def to_wire(self, file=None, compress=None, origin=None, - canonicalize=False): + def to_wire( + self, + file: Optional[Any] = None, + compress: Optional[CompressType] = None, + origin: Optional["Name"] = None, + canonicalize: bool = False, + ) -> Optional[bytes]: """Convert name to wire format, possibly compressing it. *file* is the file where the name is emitted (typically an @@ -638,6 +691,7 @@ class Name: out += label return bytes(out) + labels: Iterable[bytes] if not self.is_absolute(): if origin is None or not origin.is_absolute(): raise NeedAbsoluteNameOrOrigin @@ -654,24 +708,25 @@ class Name: else: pos = None if pos is not None: - value = 0xc000 + pos - s = struct.pack('!H', value) + value = 0xC000 + pos + s = struct.pack("!H", value) file.write(s) break else: if compress is not None and len(n) > 1: pos = file.tell() - if pos <= 0x3fff: + if pos <= 0x3FFF: compress[n] = pos l = len(label) - file.write(struct.pack('!B', l)) + file.write(struct.pack("!B", l)) if l > 0: if canonicalize: file.write(label.lower()) else: file.write(label) + return None - def __len__(self): + def __len__(self) -> int: """The length of the name (in labels). Returns an ``int``. @@ -688,7 +743,7 @@ class Name: def __sub__(self, other): return self.relativize(other) - def split(self, depth): + def split(self, depth: int) -> Tuple["Name", "Name"]: """Split a name into a prefix and suffix names at the specified depth. *depth* is an ``int`` specifying the number of labels in the suffix @@ -705,11 +760,10 @@ class Name: elif depth == l: return (dns.name.empty, self) elif depth < 0 or depth > l: - raise ValueError( - 'depth must be >= 0 and <= the length of the name') - return (Name(self[: -depth]), Name(self[-depth:])) + raise ValueError("depth must be >= 0 and <= the length of the name") + return (Name(self[:-depth]), Name(self[-depth:])) - def concatenate(self, other): + def concatenate(self, other: "Name") -> "Name": """Return a new name which is the concatenation of self and other. Raises ``dns.name.AbsoluteConcatenation`` if the name is @@ -724,7 +778,7 @@ class Name: labels.extend(list(other.labels)) return Name(labels) - def relativize(self, origin): + def relativize(self, origin: "Name") -> "Name": """If the name is a subdomain of *origin*, return a new name which is the name relative to origin. Otherwise return the name. @@ -740,7 +794,7 @@ class Name: else: return self - def derelativize(self, origin): + def derelativize(self, origin: "Name") -> "Name": """If the name is a relative name, return a new name which is the concatenation of the name and origin. Otherwise return the name. @@ -756,7 +810,9 @@ class Name: else: return self - def choose_relativity(self, origin=None, relativize=True): + def choose_relativity( + self, origin: Optional["Name"] = None, relativize: bool = True + ) -> "Name": """Return a name with the relativity desired by the caller. If *origin* is ``None``, then the name is returned. @@ -775,7 +831,7 @@ class Name: else: return self - def parent(self): + def parent(self) -> "Name": """Return the parent of the name. For example, the parent of ``www.dnspython.org.`` is ``dnspython.org``. @@ -790,13 +846,17 @@ class Name: raise NoParent return Name(self.labels[1:]) + #: The root name, '.' -root = Name([b'']) +root = Name([b""]) #: The empty name. empty = Name([]) -def from_unicode(text, origin=root, idna_codec=None): + +def from_unicode( + text: str, origin: Optional[Name] = root, idna_codec: Optional[IDNACodec] = None +) -> Name: """Convert unicode text into a Name object. Labels are encoded in IDN ACE form according to rules specified by @@ -819,17 +879,17 @@ def from_unicode(text, origin=root, idna_codec=None): if not (origin is None or isinstance(origin, Name)): raise ValueError("origin must be a Name or None") labels = [] - label = '' + label = "" escaping = False edigits = 0 total = 0 if idna_codec is None: idna_codec = IDNA_2003 - if text == '@': - text = '' + if text == "@": + text = "" if text: - if text in ['.', '\u3002', '\uff0e', '\uff61']: - return Name([b'']) # no Unicode "u" on this constant! + if text in [".", "\u3002", "\uff0e", "\uff61"]: + return Name([b""]) # no Unicode "u" on this constant! for c in text: if escaping: if edigits == 0: @@ -848,12 +908,12 @@ def from_unicode(text, origin=root, idna_codec=None): if edigits == 3: escaping = False label += chr(total) - elif c in ['.', '\u3002', '\uff0e', '\uff61']: + elif c in [".", "\u3002", "\uff0e", "\uff61"]: if len(label) == 0: raise EmptyLabel labels.append(idna_codec.encode(label)) - label = '' - elif c == '\\': + label = "" + elif c == "\\": escaping = True edigits = 0 total = 0 @@ -864,22 +924,28 @@ def from_unicode(text, origin=root, idna_codec=None): if len(label) > 0: labels.append(idna_codec.encode(label)) else: - labels.append(b'') + labels.append(b"") - if (len(labels) == 0 or labels[-1] != b'') and origin is not None: + if (len(labels) == 0 or labels[-1] != b"") and origin is not None: labels.extend(list(origin.labels)) return Name(labels) -def is_all_ascii(text): + +def is_all_ascii(text: str) -> bool: for c in text: - if ord(c) > 0x7f: + if ord(c) > 0x7F: return False return True -def from_text(text, origin=root, idna_codec=None): + +def from_text( + text: Union[bytes, str], + origin: Optional[Name] = root, + idna_codec: Optional[IDNACodec] = None, +) -> Name: """Convert text into a Name object. - *text*, a ``str``, is the text to convert into a name. + *text*, a ``bytes`` or ``str``, is the text to convert into a name. *origin*, a ``dns.name.Name``, specifies the origin to append to non-absolute names. The default is the root name. @@ -903,23 +969,23 @@ def from_text(text, origin=root, idna_codec=None): # # then it's still "all ASCII" even though the domain name has # codepoints > 127. - text = text.encode('ascii') + text = text.encode("ascii") if not isinstance(text, bytes): raise ValueError("input to from_text() must be a string") if not (origin is None or isinstance(origin, Name)): raise ValueError("origin must be a Name or None") labels = [] - label = b'' + label = b"" escaping = False edigits = 0 total = 0 - if text == b'@': - text = b'' + if text == b"@": + text = b"" if text: - if text == b'.': - return Name([b'']) + if text == b".": + return Name([b""]) for c in text: - byte_ = struct.pack('!B', c) + byte_ = struct.pack("!B", c) if escaping: if edigits == 0: if byte_.isdigit(): @@ -936,13 +1002,13 @@ def from_text(text, origin=root, idna_codec=None): edigits += 1 if edigits == 3: escaping = False - label += struct.pack('!B', total) - elif byte_ == b'.': + label += struct.pack("!B", total) + elif byte_ == b".": if len(label) == 0: raise EmptyLabel labels.append(label) - label = b'' - elif byte_ == b'\\': + label = b"" + elif byte_ == b"\\": escaping = True edigits = 0 total = 0 @@ -953,13 +1019,16 @@ def from_text(text, origin=root, idna_codec=None): if len(label) > 0: labels.append(label) else: - labels.append(b'') - if (len(labels) == 0 or labels[-1] != b'') and origin is not None: + labels.append(b"") + if (len(labels) == 0 or labels[-1] != b"") and origin is not None: labels.extend(list(origin.labels)) return Name(labels) -def from_wire_parser(parser): +# we need 'dns.wire.Parser' quoted as dns.name and dns.wire depend on each other. + + +def from_wire_parser(parser: "dns.wire.Parser") -> Name: """Convert possibly compressed wire format into a Name. *parser* is a dns.wire.Parser. @@ -980,7 +1049,7 @@ def from_wire_parser(parser): if count < 64: labels.append(parser.get_bytes(count)) elif count >= 192: - current = (count & 0x3f) * 256 + parser.get_uint8() + current = (count & 0x3F) * 256 + parser.get_uint8() if current >= biggest_pointer: raise BadPointer biggest_pointer = current @@ -988,11 +1057,11 @@ def from_wire_parser(parser): else: raise BadLabelType count = parser.get_uint8() - labels.append(b'') + labels.append(b"") return Name(labels) -def from_wire(message, current): +def from_wire(message: bytes, current: int) -> Tuple[Name, int]: """Convert possibly compressed wire format into a Name. *message* is a ``bytes`` containing an entire DNS message in DNS diff --git a/lib/dns/name.pyi b/lib/dns/name.pyi deleted file mode 100644 index c48d4bd1..00000000 --- a/lib/dns/name.pyi +++ /dev/null @@ -1,40 +0,0 @@ -from typing import Optional, Union, Tuple, Iterable, List - -have_idna_2008: bool - -class Name: - def is_subdomain(self, o : Name) -> bool: ... - def is_superdomain(self, o : Name) -> bool: ... - def __init__(self, labels : Iterable[Union[bytes,str]]) -> None: - self.labels : List[bytes] - def is_absolute(self) -> bool: ... - def is_wild(self) -> bool: ... - def fullcompare(self, other) -> Tuple[int,int,int]: ... - def canonicalize(self) -> Name: ... - def __eq__(self, other) -> bool: ... - def __ne__(self, other) -> bool: ... - def __lt__(self, other : Name) -> bool: ... - def __le__(self, other : Name) -> bool: ... - def __ge__(self, other : Name) -> bool: ... - def __gt__(self, other : Name) -> bool: ... - def to_text(self, omit_final_dot=False) -> str: ... - def to_unicode(self, omit_final_dot=False, idna_codec=None) -> str: ... - def to_digestable(self, origin=None) -> bytes: ... - def to_wire(self, file=None, compress=None, origin=None, - canonicalize=False) -> Optional[bytes]: ... - def __add__(self, other : Name) -> Name: ... - def __sub__(self, other : Name) -> Name: ... - def split(self, depth) -> List[Tuple[str,str]]: ... - def concatenate(self, other : Name) -> Name: ... - def relativize(self, origin) -> Name: ... - def derelativize(self, origin) -> Name: ... - def choose_relativity(self, origin : Optional[Name] = None, relativize=True) -> Name: ... - def parent(self) -> Name: ... - -class IDNACodec: - pass - -def from_text(text, origin : Optional[Name] = Name('.'), idna_codec : Optional[IDNACodec] = None) -> Name: - ... - -empty : Name diff --git a/lib/dns/namedict.py b/lib/dns/namedict.py index ec0750ce..ca8b1978 100644 --- a/lib/dns/namedict.py +++ b/lib/dns/namedict.py @@ -27,7 +27,8 @@ """DNS name dictionary""" -from collections.abc import MutableMapping +# pylint seems to be confused about this one! +from collections.abc import MutableMapping # pylint: disable=no-name-in-module import dns.name @@ -62,7 +63,7 @@ class NameDict(MutableMapping): def __setitem__(self, key, value): if not isinstance(key, dns.name.Name): - raise ValueError('NameDict key must be a name') + raise ValueError("NameDict key must be a name") self.__store[key] = value self.__update_max_depth(key) diff --git a/lib/dns/node.py b/lib/dns/node.py index 63ce008b..22bbe7cb 100644 --- a/lib/dns/node.py +++ b/lib/dns/node.py @@ -17,12 +17,17 @@ """DNS nodes. A node is a set of rdatasets.""" +from typing import Any, Dict, Optional + import enum import io import dns.immutable +import dns.name +import dns.rdataclass import dns.rdataset import dns.rdatatype +import dns.rrset import dns.renderer @@ -32,26 +37,28 @@ _cname_types = { # "neutral" types can coexist with a CNAME and thus are not "other data" _neutral_types = { - dns.rdatatype.NSEC, # RFC 4035 section 2.5 + dns.rdatatype.NSEC, # RFC 4035 section 2.5 dns.rdatatype.NSEC3, # This is not likely to happen, but not impossible! - dns.rdatatype.KEY, # RFC 4035 section 2.5, RFC 3007 + dns.rdatatype.KEY, # RFC 4035 section 2.5, RFC 3007 } + def _matches_type_or_its_signature(rdtypes, rdtype, covers): - return rdtype in rdtypes or \ - (rdtype == dns.rdatatype.RRSIG and covers in rdtypes) + return rdtype in rdtypes or (rdtype == dns.rdatatype.RRSIG and covers in rdtypes) @enum.unique class NodeKind(enum.Enum): - """Rdatasets in nodes - """ - REGULAR = 0 # a.k.a "other data" + """Rdatasets in nodes""" + + REGULAR = 0 # a.k.a "other data" NEUTRAL = 1 CNAME = 2 @classmethod - def classify(cls, rdtype, covers): + def classify( + cls, rdtype: dns.rdatatype.RdataType, covers: dns.rdatatype.RdataType + ) -> "NodeKind": if _matches_type_or_its_signature(_cname_types, rdtype, covers): return NodeKind.CNAME elif _matches_type_or_its_signature(_neutral_types, rdtype, covers): @@ -60,7 +67,7 @@ class NodeKind(enum.Enum): return NodeKind.REGULAR @classmethod - def classify_rdataset(cls, rdataset): + def classify_rdataset(cls, rdataset: dns.rdataset.Rdataset) -> "NodeKind": return cls.classify(rdataset.rdtype, rdataset.covers) @@ -81,19 +88,19 @@ class Node: deleted. """ - __slots__ = ['rdatasets'] + __slots__ = ["rdatasets"] def __init__(self): # the set of rdatasets, represented as a list. self.rdatasets = [] - def to_text(self, name, **kw): + def to_text(self, name: dns.name.Name, **kw: Dict[str, Any]) -> str: """Convert a node to text format. Each rdataset at the node is printed. Any keyword arguments to this method are passed on to the rdataset's to_text() method. - *name*, a ``dns.name.Name`` or ``str``, the owner name of the + *name*, a ``dns.name.Name``, the owner name of the rdatasets. Returns a ``str``. @@ -103,12 +110,12 @@ class Node: s = io.StringIO() for rds in self.rdatasets: if len(rds) > 0: - s.write(rds.to_text(name, **kw)) - s.write('\n') + s.write(rds.to_text(name, **kw)) # type: ignore[arg-type] + s.write("\n") return s.getvalue()[:-1] def __repr__(self): - return '' + return "" def __eq__(self, other): # @@ -144,27 +151,36 @@ class Node: if len(self.rdatasets) > 0: kind = NodeKind.classify_rdataset(rdataset) if kind == NodeKind.CNAME: - self.rdatasets = [rds for rds in self.rdatasets if - NodeKind.classify_rdataset(rds) != - NodeKind.REGULAR] + self.rdatasets = [ + rds + for rds in self.rdatasets + if NodeKind.classify_rdataset(rds) != NodeKind.REGULAR + ] elif kind == NodeKind.REGULAR: - self.rdatasets = [rds for rds in self.rdatasets if - NodeKind.classify_rdataset(rds) != - NodeKind.CNAME] + self.rdatasets = [ + rds + for rds in self.rdatasets + if NodeKind.classify_rdataset(rds) != NodeKind.CNAME + ] # Otherwise the rdataset is NodeKind.NEUTRAL and we do not need to # edit self.rdatasets. self.rdatasets.append(rdataset) - def find_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE, - create=False): + def find_rdataset( + self, + rdclass: dns.rdataclass.RdataClass, + rdtype: dns.rdatatype.RdataType, + covers: dns.rdatatype.RdataType = dns.rdatatype.NONE, + create: bool = False, + ) -> dns.rdataset.Rdataset: """Find an rdataset matching the specified properties in the current node. - *rdclass*, an ``int``, the class of the rdataset. + *rdclass*, a ``dns.rdataclass.RdataClass``, the class of the rdataset. - *rdtype*, an ``int``, the type of the rdataset. + *rdtype*, a ``dns.rdatatype.RdataType``, the type of the rdataset. - *covers*, an ``int`` or ``None``, the covered type. + *covers*, a ``dns.rdatatype.RdataType``, the covered type. Usually this value is ``dns.rdatatype.NONE``, but if the rdtype is ``dns.rdatatype.SIG`` or ``dns.rdatatype.RRSIG``, then the covers value will be the rdata type the SIG/RRSIG @@ -191,8 +207,13 @@ class Node: self._append_rdataset(rds) return rds - def get_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE, - create=False): + def get_rdataset( + self, + rdclass: dns.rdataclass.RdataClass, + rdtype: dns.rdatatype.RdataType, + covers: dns.rdatatype.RdataType = dns.rdatatype.NONE, + create: bool = False, + ) -> Optional[dns.rdataset.Rdataset]: """Get an rdataset matching the specified properties in the current node. @@ -223,7 +244,12 @@ class Node: rds = None return rds - def delete_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE): + def delete_rdataset( + self, + rdclass: dns.rdataclass.RdataClass, + rdtype: dns.rdatatype.RdataType, + covers: dns.rdatatype.RdataType = dns.rdatatype.NONE, + ) -> None: """Delete the rdataset matching the specified properties in the current node. @@ -240,7 +266,7 @@ class Node: if rds is not None: self.rdatasets.remove(rds) - def replace_rdataset(self, replacement): + def replace_rdataset(self, replacement: dns.rdataset.Rdataset) -> None: """Replace an rdataset. It is not an error if there is no rdataset matching *replacement*. @@ -256,16 +282,17 @@ class Node: """ if not isinstance(replacement, dns.rdataset.Rdataset): - raise ValueError('replacement is not an rdataset') + raise ValueError("replacement is not an rdataset") if isinstance(replacement, dns.rrset.RRset): # RRsets are not good replacements as the match() method # is not compatible. replacement = replacement.to_rdataset() - self.delete_rdataset(replacement.rdclass, replacement.rdtype, - replacement.covers) + self.delete_rdataset( + replacement.rdclass, replacement.rdtype, replacement.covers + ) self._append_rdataset(replacement) - def classify(self): + def classify(self) -> NodeKind: """Classify a node. A node which contains a CNAME or RRSIG(CNAME) is a @@ -286,7 +313,7 @@ class Node: return kind return NodeKind.NEUTRAL - def is_immutable(self): + def is_immutable(self) -> bool: return False @@ -298,23 +325,38 @@ class ImmutableNode(Node): [dns.rdataset.ImmutableRdataset(rds) for rds in node.rdatasets] ) - def find_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE, - create=False): + def find_rdataset( + self, + rdclass: dns.rdataclass.RdataClass, + rdtype: dns.rdatatype.RdataType, + covers: dns.rdatatype.RdataType = dns.rdatatype.NONE, + create: bool = False, + ) -> dns.rdataset.Rdataset: if create: raise TypeError("immutable") return super().find_rdataset(rdclass, rdtype, covers, False) - def get_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE, - create=False): + def get_rdataset( + self, + rdclass: dns.rdataclass.RdataClass, + rdtype: dns.rdatatype.RdataType, + covers: dns.rdatatype.RdataType = dns.rdatatype.NONE, + create: bool = False, + ) -> Optional[dns.rdataset.Rdataset]: if create: raise TypeError("immutable") return super().get_rdataset(rdclass, rdtype, covers, False) - def delete_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE): + def delete_rdataset( + self, + rdclass: dns.rdataclass.RdataClass, + rdtype: dns.rdatatype.RdataType, + covers: dns.rdatatype.RdataType = dns.rdatatype.NONE, + ) -> None: raise TypeError("immutable") - def replace_rdataset(self, replacement): + def replace_rdataset(self, replacement: dns.rdataset.Rdataset) -> None: raise TypeError("immutable") - def is_immutable(self): + def is_immutable(self) -> bool: return True diff --git a/lib/dns/node.pyi b/lib/dns/node.pyi deleted file mode 100644 index 0997edf9..00000000 --- a/lib/dns/node.pyi +++ /dev/null @@ -1,17 +0,0 @@ -from typing import List, Optional, Union -from . import rdataset, rdatatype, name -class Node: - def __init__(self): - self.rdatasets : List[rdataset.Rdataset] - def to_text(self, name : Union[str,name.Name], **kw) -> str: - ... - def find_rdataset(self, rdclass : int, rdtype : int, covers=rdatatype.NONE, - create=False) -> rdataset.Rdataset: - ... - def get_rdataset(self, rdclass : int, rdtype : int, covers=rdatatype.NONE, - create=False) -> Optional[rdataset.Rdataset]: - ... - def delete_rdataset(self, rdclass : int, rdtype : int, covers=rdatatype.NONE): - ... - def replace_rdataset(self, replacement : rdataset.Rdataset) -> None: - ... diff --git a/lib/dns/opcode.py b/lib/dns/opcode.py index 5cf6143c..78b43d2c 100644 --- a/lib/dns/opcode.py +++ b/lib/dns/opcode.py @@ -20,6 +20,7 @@ import dns.enum import dns.exception + class Opcode(dns.enum.IntEnum): #: Query QUERY = 0 @@ -45,7 +46,7 @@ class UnknownOpcode(dns.exception.DNSException): """An DNS opcode is unknown.""" -def from_text(text): +def from_text(text: str) -> Opcode: """Convert text into an opcode. *text*, a ``str``, the textual opcode @@ -58,7 +59,7 @@ def from_text(text): return Opcode.from_text(text) -def from_flags(flags): +def from_flags(flags: int) -> Opcode: """Extract an opcode from DNS message flags. *flags*, an ``int``, the DNS flags. @@ -66,10 +67,10 @@ def from_flags(flags): Returns an ``int``. """ - return (flags & 0x7800) >> 11 + return Opcode((flags & 0x7800) >> 11) -def to_flags(value): +def to_flags(value: Opcode) -> int: """Convert an opcode to a value suitable for ORing into DNS message flags. @@ -81,7 +82,7 @@ def to_flags(value): return (value << 11) & 0x7800 -def to_text(value): +def to_text(value: Opcode) -> str: """Convert an opcode to text. *value*, an ``int`` the opcode value, @@ -94,7 +95,7 @@ def to_text(value): return Opcode.to_text(value) -def is_update(flags): +def is_update(flags: int) -> bool: """Is the opcode in flags UPDATE? *flags*, an ``int``, the DNS message flags. @@ -104,6 +105,7 @@ def is_update(flags): return from_flags(flags) == Opcode.UPDATE + ### BEGIN generated Opcode constants QUERY = Opcode.QUERY diff --git a/lib/dns/query.py b/lib/dns/query.py index 6d924b5f..b4cd69f7 100644 --- a/lib/dns/query.py +++ b/lib/dns/query.py @@ -17,6 +17,9 @@ """Talk to a DNS server.""" +from typing import Any, Dict, Optional, Tuple, Union + +import base64 import contextlib import enum import errno @@ -25,23 +28,26 @@ import selectors import socket import struct import time -import base64 import urllib.parse import dns.exception import dns.inet import dns.name import dns.message +import dns.quic import dns.rcode import dns.rdataclass import dns.rdatatype import dns.serial +import dns.transaction +import dns.tsig import dns.xfr try: import requests from requests_toolbelt.adapters.source import SourceAddressAdapter from requests_toolbelt.adapters.host_header_ssl import HostHeaderSSLAdapter + _have_requests = True except ImportError: # pragma: no cover _have_requests = False @@ -50,6 +56,7 @@ _have_httpx = False _have_http2 = False try: import httpx + _have_httpx = True try: # See if http2 support is available. @@ -65,24 +72,30 @@ have_doh = _have_requests or _have_httpx try: import ssl except ImportError: # pragma: no cover - class ssl: # type: ignore + class ssl: # type: ignore class WantReadException(Exception): pass class WantWriteException(Exception): pass + class SSLContext: + pass + class SSLSocket: pass - def create_default_context(self, *args, **kwargs): - raise Exception('no ssl support') + @classmethod + def create_default_context(cls, *args, **kwargs): + raise Exception("no ssl support") + # Function used to create a socket. Can be overridden if needed in special # situations. socket_factory = socket.socket + class UnexpectedSource(dns.exception.DNSException): """A DNS query response came from an unexpected address or port.""" @@ -96,6 +109,11 @@ class NoDOH(dns.exception.DNSException): available.""" +class NoDOQ(dns.exception.DNSException): + """DNS over QUIC (DOQ) was requested but the aioquic module is not + available.""" + + # for backwards compatibility TransferError = dns.xfr.TransferError @@ -143,13 +161,17 @@ def _set_selector_class(selector_class): _selector_class = selector_class -if hasattr(selectors, 'PollSelector'): + +if hasattr(selectors, "PollSelector"): # Prefer poll() on platforms that support it because it has no # limits on the maximum value of a file descriptor (plus it will # be more efficient for high values). - _selector_class = selectors.PollSelector + # + # We ignore typing here as we can't say _selector_class is Any + # on python < 3.8 due to a bug. + _selector_class = selectors.PollSelector # type: ignore else: - _selector_class = selectors.SelectSelector # pragma: no cover + _selector_class = selectors.SelectSelector # type: ignore def _wait_for_readable(s, expiration): @@ -177,18 +199,20 @@ def _matches_destination(af, from_address, destination, ignore_unexpected): # sent to destination. if not destination: return True - if _addresses_equal(af, from_address, destination) or \ - (dns.inet.is_multicast(destination[0]) and - from_address[1:] == destination[1:]): + if _addresses_equal(af, from_address, destination) or ( + dns.inet.is_multicast(destination[0]) and from_address[1:] == destination[1:] + ): return True elif ignore_unexpected: return False - raise UnexpectedSource(f'got a response from {from_address} instead of ' - f'{destination}') + raise UnexpectedSource( + f"got a response from {from_address} instead of " f"{destination}" + ) -def _destination_and_source(where, port, source, source_port, - where_must_be_address=True): +def _destination_and_source( + where, port, source, source_port, where_must_be_address=True +): # Apply defaults and compute destination and source tuples # suitable for use in connect(), sendto(), or bind(). af = None @@ -205,8 +229,9 @@ def _destination_and_source(where, port, source, source_port, if af: # We know the destination af, so source had better agree! if saf != af: - raise ValueError('different address families for source ' + - 'and destination') + raise ValueError( + "different address families for source " + "and destination" + ) else: # We didn't know the destination af, but we know the source, # so that's our af. @@ -216,12 +241,11 @@ def _destination_and_source(where, port, source, source_port, # need to return a source, and we need to use the appropriate # wildcard address as the address. if af == socket.AF_INET: - source = '0.0.0.0' + source = "0.0.0.0" elif af == socket.AF_INET6: - source = '::' + source = "::" else: - raise ValueError('source_port specified but address family is ' - 'unknown') + raise ValueError("source_port specified but address family is unknown") # Convert high-level (address, port) tuples into low-level address # tuples. if destination: @@ -230,6 +254,7 @@ def _destination_and_source(where, port, source, source_port, source = dns.inet.low_level_address_tuple((source, source_port), af) return (af, destination, source) + def _make_socket(af, type, source, ssl_context=None, server_hostname=None): s = socket_factory(af, type) try: @@ -237,81 +262,98 @@ def _make_socket(af, type, source, ssl_context=None, server_hostname=None): if source is not None: s.bind(source) if ssl_context: - return ssl_context.wrap_socket(s, do_handshake_on_connect=False, - server_hostname=server_hostname) + # LGTM gets a false positive here, as our default context is OK + return ssl_context.wrap_socket( + s, + do_handshake_on_connect=False, # lgtm[py/insecure-protocol] + server_hostname=server_hostname, + ) else: return s except Exception: s.close() raise -def https(q, where, timeout=None, port=443, source=None, source_port=0, - one_rr_per_rrset=False, ignore_trailing=False, - session=None, path='/dns-query', post=True, - bootstrap_address=None, verify=True): + +def https( + q: dns.message.Message, + where: str, + timeout: Optional[float] = None, + port: int = 443, + source: Optional[str] = None, + source_port: int = 0, + one_rr_per_rrset: bool = False, + ignore_trailing: bool = False, + session: Optional[Any] = None, + path: str = "/dns-query", + post: bool = True, + bootstrap_address: Optional[str] = None, + verify: Union[bool, str] = True, +) -> dns.message.Message: """Return the response obtained after sending a query via DNS-over-HTTPS. *q*, a ``dns.message.Message``, the query to send. - *where*, a ``str``, the nameserver IP address or the full URL. If an IP - address is given, the URL will be constructed using the following schema: + *where*, a ``str``, the nameserver IP address or the full URL. If an IP address is + given, the URL will be constructed using the following schema: https://:/. - *timeout*, a ``float`` or ``None``, the number of seconds to - wait before the query times out. If ``None``, the default, wait forever. + *timeout*, a ``float`` or ``None``, the number of seconds to wait before the query + times out. If ``None``, the default, wait forever. *port*, a ``int``, the port to send the query to. The default is 443. - *source*, a ``str`` containing an IPv4 or IPv6 address, specifying - the source address. The default is the wildcard address. + *source*, a ``str`` containing an IPv4 or IPv6 address, specifying the source + address. The default is the wildcard address. - *source_port*, an ``int``, the port from which to send the message. - The default is 0. + *source_port*, an ``int``, the port from which to send the message. The default is + 0. - *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own - RRset. + *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own RRset. - *ignore_trailing*, a ``bool``. If ``True``, ignore trailing - junk at end of the received message. + *ignore_trailing*, a ``bool``. If ``True``, ignore trailing junk at end of the + received message. - *session*, an ``httpx.Client`` or ``requests.session.Session``. If - provided, the client/session to use to send the queries. + *session*, an ``httpx.Client`` or ``requests.session.Session``. If provided, the + client/session to use to send the queries. *path*, a ``str``. If *where* is an IP address, then *path* will be used to construct the URL to send the DNS query to. *post*, a ``bool``. If ``True``, the default, POST method will be used. - *bootstrap_address*, a ``str``, the IP address to use to bypass the - system's DNS resolver. + *bootstrap_address*, a ``str``, the IP address to use to bypass the system's DNS + resolver. - *verify*, a ``str``, containing a path to a certificate file or directory. + *verify*, a ``bool`` or ``str``. If a ``True``, then TLS certificate verification + of the server is done using the default CA bundle; if ``False``, then no + verification is done; if a `str` then it specifies the path to a certificate file or + directory which will be used for verification. Returns a ``dns.message.Message``. """ if not have_doh: - raise NoDOH('Neither httpx nor requests is available.') # pragma: no cover + raise NoDOH("Neither httpx nor requests is available.") # pragma: no cover _httpx_ok = _have_httpx wire = q.to_wire() - (af, _, source) = _destination_and_source(where, port, source, source_port, - False) + (af, _, source) = _destination_and_source(where, port, source, source_port, False) transport_adapter = None transport = None - headers = { - "accept": "application/dns-message" - } + headers = {"accept": "application/dns-message"} if af is not None: if af == socket.AF_INET: - url = 'https://{}:{}{}'.format(where, port, path) + url = "https://{}:{}{}".format(where, port, path) elif af == socket.AF_INET6: - url = 'https://[{}]:{}{}'.format(where, port, path) + url = "https://[{}]:{}{}".format(where, port, path) elif bootstrap_address is not None: _httpx_ok = False split_url = urllib.parse.urlsplit(where) - headers['Host'] = split_url.hostname + if split_url.hostname is None: + raise ValueError("DoH URL has no hostname") + headers["Host"] = split_url.hostname url = where.replace(split_url.hostname, bootstrap_address) if _have_requests: transport_adapter = HostHeaderSSLAdapter() @@ -321,7 +363,7 @@ def https(q, where, timeout=None, port=443, source=None, source_port=0, # set source port and source address if _have_httpx: if source_port == 0: - transport = httpx.HTTPTransport(local_address=source[0]) + transport = httpx.HTTPTransport(local_address=source[0], verify=verify) else: _httpx_ok = False if _have_requests: @@ -333,70 +375,83 @@ def https(q, where, timeout=None, port=443, source=None, source_port=0, else: _is_httpx = False if _is_httpx and not _httpx_ok: - raise NoDOH('Session is httpx, but httpx cannot be used for ' - 'the requested operation.') + raise NoDOH( + "Session is httpx, but httpx cannot be used for " + "the requested operation." + ) else: _is_httpx = _httpx_ok if not _httpx_ok and not _have_requests: - raise NoDOH('Cannot use httpx for this operation, and ' - 'requests is not available.') + raise NoDOH( + "Cannot use httpx for this operation, and requests is not available." + ) - with contextlib.ExitStack() as stack: - if not session: - if _is_httpx: - session = stack.enter_context(httpx.Client(http1=True, - http2=_have_http2, - verify=verify, - transport=transport)) - else: - session = stack.enter_context(requests.sessions.Session()) - - if transport_adapter: + if session: + cm: contextlib.AbstractContextManager = contextlib.nullcontext(session) + elif _is_httpx: + cm = httpx.Client( + http1=True, http2=_have_http2, verify=verify, transport=transport + ) + else: + cm = requests.sessions.Session() + with cm as session: + if transport_adapter and not _is_httpx: session.mount(url, transport_adapter) # see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH # GET and POST examples if post: - headers.update({ - "content-type": "application/dns-message", - "content-length": str(len(wire)) - }) + headers.update( + { + "content-type": "application/dns-message", + "content-length": str(len(wire)), + } + ) if _is_httpx: - response = session.post(url, headers=headers, content=wire, - timeout=timeout) + response = session.post( + url, headers=headers, content=wire, timeout=timeout + ) else: - response = session.post(url, headers=headers, data=wire, - timeout=timeout, verify=verify) + response = session.post( + url, headers=headers, data=wire, timeout=timeout, verify=verify + ) else: wire = base64.urlsafe_b64encode(wire).rstrip(b"=") if _is_httpx: - wire = wire.decode() # httpx does a repr() if we give it bytes - response = session.get(url, headers=headers, - timeout=timeout, - params={"dns": wire}) + twire = wire.decode() # httpx does a repr() if we give it bytes + response = session.get( + url, headers=headers, timeout=timeout, params={"dns": twire} + ) else: - response = session.get(url, headers=headers, - timeout=timeout, verify=verify, - params={"dns": wire}) + response = session.get( + url, + headers=headers, + timeout=timeout, + verify=verify, + params={"dns": wire}, + ) # see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH # status codes if response.status_code < 200 or response.status_code > 299: - raise ValueError('{} responded with status code {}' - '\nResponse body: {}'.format(where, - response.status_code, - response.content)) - r = dns.message.from_wire(response.content, - keyring=q.keyring, - request_mac=q.request_mac, - one_rr_per_rrset=one_rr_per_rrset, - ignore_trailing=ignore_trailing) - r.time = response.elapsed + raise ValueError( + "{} responded with status code {}" + "\nResponse body: {}".format(where, response.status_code, response.content) + ) + r = dns.message.from_wire( + response.content, + keyring=q.keyring, + request_mac=q.request_mac, + one_rr_per_rrset=one_rr_per_rrset, + ignore_trailing=ignore_trailing, + ) + r.time = response.elapsed.total_seconds() if not q.is_response(r): raise BadResponse return r + def _udp_recv(sock, max_size, expiration): """Reads a datagram from the socket. A Timeout exception will be raised if the operation is not completed @@ -424,7 +479,12 @@ def _udp_send(sock, data, destination, expiration): _wait_for_writable(sock, expiration) -def send_udp(sock, what, destination, expiration=None): +def send_udp( + sock: Any, + what: Union[dns.message.Message, bytes], + destination: Any, + expiration: Optional[float] = None, +) -> Tuple[int, float]: """Send a DNS message to the specified UDP socket. *sock*, a ``socket``. @@ -448,10 +508,17 @@ def send_udp(sock, what, destination, expiration=None): return (n, sent_time) -def receive_udp(sock, destination=None, expiration=None, - ignore_unexpected=False, one_rr_per_rrset=False, - keyring=None, request_mac=b'', ignore_trailing=False, - raise_on_truncation=False): +def receive_udp( + sock: Any, + destination: Optional[Any] = None, + expiration: Optional[float] = None, + ignore_unexpected: bool = False, + one_rr_per_rrset: bool = False, + keyring: Optional[Dict[dns.name.Name, dns.tsig.Key]] = None, + request_mac: Optional[bytes] = b"", + ignore_trailing: bool = False, + raise_on_truncation: bool = False, +) -> Any: """Read a DNS message from a UDP socket. *sock*, a ``socket``. @@ -473,7 +540,7 @@ def receive_udp(sock, destination=None, expiration=None, *keyring*, a ``dict``, the keyring to use for TSIG. - *request_mac*, a ``bytes``, the MAC of the request (for TSIG). + *request_mac*, a ``bytes`` or ``None``, the MAC of the request (for TSIG). *ignore_trailing*, a ``bool``. If ``True``, ignore trailing junk at end of the received message. @@ -493,25 +560,41 @@ def receive_udp(sock, destination=None, expiration=None, the message arrived from. """ - wire = b'' + wire = b"" while True: (wire, from_address) = _udp_recv(sock, 65535, expiration) - if _matches_destination(sock.family, from_address, destination, - ignore_unexpected): + if _matches_destination( + sock.family, from_address, destination, ignore_unexpected + ): break received_time = time.time() - r = dns.message.from_wire(wire, keyring=keyring, request_mac=request_mac, - one_rr_per_rrset=one_rr_per_rrset, - ignore_trailing=ignore_trailing, - raise_on_truncation=raise_on_truncation) + r = dns.message.from_wire( + wire, + keyring=keyring, + request_mac=request_mac, + one_rr_per_rrset=one_rr_per_rrset, + ignore_trailing=ignore_trailing, + raise_on_truncation=raise_on_truncation, + ) if destination: return (r, received_time) else: return (r, received_time, from_address) -def udp(q, where, timeout=None, port=53, source=None, source_port=0, - ignore_unexpected=False, one_rr_per_rrset=False, ignore_trailing=False, - raise_on_truncation=False, sock=None): + +def udp( + q: dns.message.Message, + where: str, + timeout: Optional[float] = None, + port: int = 53, + source: Optional[str] = None, + source_port: int = 0, + ignore_unexpected: bool = False, + one_rr_per_rrset: bool = False, + ignore_trailing: bool = False, + raise_on_truncation: bool = False, + sock: Optional[Any] = None, +) -> dns.message.Message: """Return the response obtained after sending a query via UDP. *q*, a ``dns.message.Message``, the query to send @@ -551,28 +634,49 @@ def udp(q, where, timeout=None, port=53, source=None, source_port=0, """ wire = q.to_wire() - (af, destination, source) = _destination_and_source(where, port, - source, source_port) + (af, destination, source) = _destination_and_source( + where, port, source, source_port + ) (begin_time, expiration) = _compute_times(timeout) - with contextlib.ExitStack() as stack: - if sock: - s = sock - else: - s = stack.enter_context(_make_socket(af, socket.SOCK_DGRAM, source)) + if sock: + cm: contextlib.AbstractContextManager = contextlib.nullcontext(sock) + else: + cm = _make_socket(af, socket.SOCK_DGRAM, source) + with cm as s: send_udp(s, wire, destination, expiration) - (r, received_time) = receive_udp(s, destination, expiration, - ignore_unexpected, one_rr_per_rrset, - q.keyring, q.mac, ignore_trailing, - raise_on_truncation) + (r, received_time) = receive_udp( + s, + destination, + expiration, + ignore_unexpected, + one_rr_per_rrset, + q.keyring, + q.mac, + ignore_trailing, + raise_on_truncation, + ) r.time = received_time - begin_time if not q.is_response(r): raise BadResponse return r + assert ( + False # help mypy figure out we can't get here lgtm[py/unreachable-statement] + ) -def udp_with_fallback(q, where, timeout=None, port=53, source=None, - source_port=0, ignore_unexpected=False, - one_rr_per_rrset=False, ignore_trailing=False, - udp_sock=None, tcp_sock=None): + +def udp_with_fallback( + q: dns.message.Message, + where: str, + timeout: Optional[float] = None, + port: int = 53, + source: Optional[str] = None, + source_port: int = 0, + ignore_unexpected: bool = False, + one_rr_per_rrset: bool = False, + ignore_trailing: bool = False, + udp_sock: Optional[Any] = None, + tcp_sock: Optional[Any] = None, +) -> Tuple[dns.message.Message, bool]: """Return the response to the query, trying UDP first and falling back to TCP if UDP results in a truncated response. @@ -616,26 +720,46 @@ def udp_with_fallback(q, where, timeout=None, port=53, source=None, if and only if TCP was used. """ try: - response = udp(q, where, timeout, port, source, source_port, - ignore_unexpected, one_rr_per_rrset, - ignore_trailing, True, udp_sock) + response = udp( + q, + where, + timeout, + port, + source, + source_port, + ignore_unexpected, + one_rr_per_rrset, + ignore_trailing, + True, + udp_sock, + ) return (response, False) except dns.message.Truncated: - response = tcp(q, where, timeout, port, source, source_port, - one_rr_per_rrset, ignore_trailing, tcp_sock) + response = tcp( + q, + where, + timeout, + port, + source, + source_port, + one_rr_per_rrset, + ignore_trailing, + tcp_sock, + ) return (response, True) + def _net_read(sock, count, expiration): """Read the specified number of bytes from sock. Keep trying until we either get the desired amount, or we hit EOF. A Timeout exception will be raised if the operation is not completed by the expiration time. """ - s = b'' + s = b"" while count > 0: try: n = sock.recv(count) - if n == b'': + if n == b"": raise EOFError count -= len(n) s += n @@ -662,7 +786,11 @@ def _net_write(sock, data, expiration): _wait_for_readable(sock, expiration) -def send_tcp(sock, what, expiration=None): +def send_tcp( + sock: Any, + what: Union[dns.message.Message, bytes], + expiration: Optional[float] = None, +) -> Tuple[int, float]: """Send a DNS message to the specified TCP socket. *sock*, a ``socket``. @@ -677,18 +805,27 @@ def send_tcp(sock, what, expiration=None): """ if isinstance(what, dns.message.Message): - what = what.to_wire() - l = len(what) + wire = what.to_wire() + else: + wire = what + l = len(wire) # copying the wire into tcpmsg is inefficient, but lets us # avoid writev() or doing a short write that would get pushed # onto the net - tcpmsg = struct.pack("!H", l) + what + tcpmsg = struct.pack("!H", l) + wire sent_time = time.time() _net_write(sock, tcpmsg, expiration) return (len(tcpmsg), sent_time) -def receive_tcp(sock, expiration=None, one_rr_per_rrset=False, - keyring=None, request_mac=b'', ignore_trailing=False): + +def receive_tcp( + sock: Any, + expiration: Optional[float] = None, + one_rr_per_rrset: bool = False, + keyring: Optional[Dict[dns.name.Name, dns.tsig.Key]] = None, + request_mac: Optional[bytes] = b"", + ignore_trailing: bool = False, +) -> Tuple[dns.message.Message, float]: """Read a DNS message from a TCP socket. *sock*, a ``socket``. @@ -702,7 +839,7 @@ def receive_tcp(sock, expiration=None, one_rr_per_rrset=False, *keyring*, a ``dict``, the keyring to use for TSIG. - *request_mac*, a ``bytes``, the MAC of the request (for TSIG). + *request_mac*, a ``bytes`` or ``None``, the MAC of the request (for TSIG). *ignore_trailing*, a ``bool``. If ``True``, ignore trailing junk at end of the received message. @@ -718,11 +855,16 @@ def receive_tcp(sock, expiration=None, one_rr_per_rrset=False, (l,) = struct.unpack("!H", ldata) wire = _net_read(sock, l, expiration) received_time = time.time() - r = dns.message.from_wire(wire, keyring=keyring, request_mac=request_mac, - one_rr_per_rrset=one_rr_per_rrset, - ignore_trailing=ignore_trailing) + r = dns.message.from_wire( + wire, + keyring=keyring, + request_mac=request_mac, + one_rr_per_rrset=one_rr_per_rrset, + ignore_trailing=ignore_trailing, + ) return (r, received_time) + def _connect(s, address, expiration): err = s.connect_ex(address) if err == 0: @@ -734,8 +876,17 @@ def _connect(s, address, expiration): raise OSError(err, os.strerror(err)) -def tcp(q, where, timeout=None, port=53, source=None, source_port=0, - one_rr_per_rrset=False, ignore_trailing=False, sock=None): +def tcp( + q: dns.message.Message, + where: str, + timeout: Optional[float] = None, + port: int = 53, + source: Optional[str] = None, + source_port: int = 0, + one_rr_per_rrset: bool = False, + ignore_trailing: bool = False, + sock: Optional[Any] = None, +) -> dns.message.Message: """Return the response obtained after sending a query via TCP. *q*, a ``dns.message.Message``, the query to send @@ -770,23 +921,27 @@ def tcp(q, where, timeout=None, port=53, source=None, source_port=0, wire = q.to_wire() (begin_time, expiration) = _compute_times(timeout) - with contextlib.ExitStack() as stack: - if sock: - s = sock - else: - (af, destination, source) = _destination_and_source(where, port, - source, - source_port) - s = stack.enter_context(_make_socket(af, socket.SOCK_STREAM, - source)) + if sock: + cm: contextlib.AbstractContextManager = contextlib.nullcontext(sock) + else: + (af, destination, source) = _destination_and_source( + where, port, source, source_port + ) + cm = _make_socket(af, socket.SOCK_STREAM, source) + with cm as s: + if not sock: _connect(s, destination, expiration) send_tcp(s, wire, expiration) - (r, received_time) = receive_tcp(s, expiration, one_rr_per_rrset, - q.keyring, q.mac, ignore_trailing) + (r, received_time) = receive_tcp( + s, expiration, one_rr_per_rrset, q.keyring, q.mac, ignore_trailing + ) r.time = received_time - begin_time if not q.is_response(r): raise BadResponse return r + assert ( + False # help mypy figure out we can't get here lgtm[py/unreachable-statement] + ) def _tls_handshake(s, expiration): @@ -800,9 +955,19 @@ def _tls_handshake(s, expiration): _wait_for_writable(s, expiration) -def tls(q, where, timeout=None, port=853, source=None, source_port=0, - one_rr_per_rrset=False, ignore_trailing=False, sock=None, - ssl_context=None, server_hostname=None): +def tls( + q: dns.message.Message, + where: str, + timeout: Optional[float] = None, + port: int = 853, + source: Optional[str] = None, + source_port: int = 0, + one_rr_per_rrset: bool = False, + ignore_trailing: bool = False, + sock: Optional[ssl.SSLSocket] = None, + ssl_context: Optional[ssl.SSLContext] = None, + server_hostname: Optional[str] = None, +) -> dns.message.Message: """Return the response obtained after sending a query via TLS. *q*, a ``dns.message.Message``, the query to send @@ -849,35 +1014,148 @@ def tls(q, where, timeout=None, port=853, source=None, source_port=0, # # If a socket was provided, there's no special TLS handling needed. # - return tcp(q, where, timeout, port, source, source_port, - one_rr_per_rrset, ignore_trailing, sock) + return tcp( + q, + where, + timeout, + port, + source, + source_port, + one_rr_per_rrset, + ignore_trailing, + sock, + ) wire = q.to_wire() (begin_time, expiration) = _compute_times(timeout) - (af, destination, source) = _destination_and_source(where, port, - source, source_port) + (af, destination, source) = _destination_and_source( + where, port, source, source_port + ) if ssl_context is None and not sock: ssl_context = ssl.create_default_context() + ssl_context.minimum_version = ssl.TLSVersion.TLSv1_2 if server_hostname is None: ssl_context.check_hostname = False - with _make_socket(af, socket.SOCK_STREAM, source, ssl_context=ssl_context, - server_hostname=server_hostname) as s: + with _make_socket( + af, + socket.SOCK_STREAM, + source, + ssl_context=ssl_context, + server_hostname=server_hostname, + ) as s: _connect(s, destination, expiration) _tls_handshake(s, expiration) send_tcp(s, wire, expiration) - (r, received_time) = receive_tcp(s, expiration, one_rr_per_rrset, - q.keyring, q.mac, ignore_trailing) + (r, received_time) = receive_tcp( + s, expiration, one_rr_per_rrset, q.keyring, q.mac, ignore_trailing + ) r.time = received_time - begin_time if not q.is_response(r): raise BadResponse return r + assert ( + False # help mypy figure out we can't get here lgtm[py/unreachable-statement] + ) -def xfr(where, zone, rdtype=dns.rdatatype.AXFR, rdclass=dns.rdataclass.IN, - timeout=None, port=53, keyring=None, keyname=None, relativize=True, - lifetime=None, source=None, source_port=0, serial=0, - use_udp=False, keyalgorithm=dns.tsig.default_algorithm): +def quic( + q: dns.message.Message, + where: str, + timeout: Optional[float] = None, + port: int = 853, + source: Optional[str] = None, + source_port: int = 0, + one_rr_per_rrset: bool = False, + ignore_trailing: bool = False, + connection: Optional[dns.quic.SyncQuicConnection] = None, + verify: Union[bool, str] = True, +) -> dns.message.Message: + """Return the response obtained after sending a query via DNS-over-QUIC. + + *q*, a ``dns.message.Message``, the query to send. + + *where*, a ``str``, the nameserver IP address. + + *timeout*, a ``float`` or ``None``, the number of seconds to wait before the query + times out. If ``None``, the default, wait forever. + + *port*, a ``int``, the port to send the query to. The default is 853. + + *source*, a ``str`` containing an IPv4 or IPv6 address, specifying the source + address. The default is the wildcard address. + + *source_port*, an ``int``, the port from which to send the message. The default is + 0. + + *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own RRset. + + *ignore_trailing*, a ``bool``. If ``True``, ignore trailing junk at end of the + received message. + + *connection*, a ``dns.quic.SyncQuicConnection``. If provided, the + connection to use to send the query. + + *verify*, a ``bool`` or ``str``. If a ``True``, then TLS certificate verification + of the server is done using the default CA bundle; if ``False``, then no + verification is done; if a `str` then it specifies the path to a certificate file or + directory which will be used for verification. + + Returns a ``dns.message.Message``. + """ + + if not dns.quic.have_quic: + raise NoDOQ("DNS-over-QUIC is not available.") # pragma: no cover + + q.id = 0 + wire = q.to_wire() + the_connection: dns.quic.SyncQuicConnection + the_manager: dns.quic.SyncQuicManager + if connection: + manager: contextlib.AbstractContextManager = contextlib.nullcontext(None) + the_connection = connection + else: + manager = dns.quic.SyncQuicManager(verify_mode=verify) + the_manager = manager # for type checking happiness + + with manager: + if not connection: + the_connection = the_manager.connect(where, port, source, source_port) + start = time.time() + with the_connection.make_stream() as stream: + stream.send(wire, True) + wire = stream.receive(timeout) + finish = time.time() + r = dns.message.from_wire( + wire, + keyring=q.keyring, + request_mac=q.request_mac, + one_rr_per_rrset=one_rr_per_rrset, + ignore_trailing=ignore_trailing, + ) + r.time = max(finish - start, 0.0) + if not q.is_response(r): + raise BadResponse + return r + + +def xfr( + where: str, + zone: Union[dns.name.Name, str], + rdtype: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.AXFR, + rdclass: Union[dns.rdataclass.RdataClass, str] = dns.rdataclass.IN, + timeout: Optional[float] = None, + port: int = 53, + keyring: Optional[Dict[dns.name.Name, dns.tsig.Key]] = None, + keyname: Optional[Union[dns.name.Name, str]] = None, + relativize: bool = True, + lifetime: Optional[float] = None, + source: Optional[str] = None, + source_port: int = 0, + serial: int = 0, + use_udp: bool = False, + keyalgorithm: Union[dns.name.Name, str] = dns.tsig.default_algorithm, +) -> Any: """Return a generator for the responses to a zone transfer. *where*, a ``str`` containing an IPv4 or IPv6 address, where @@ -935,16 +1213,16 @@ def xfr(where, zone, rdtype=dns.rdatatype.AXFR, rdclass=dns.rdataclass.IN, rdtype = dns.rdatatype.RdataType.make(rdtype) q = dns.message.make_query(zone, rdtype, rdclass) if rdtype == dns.rdatatype.IXFR: - rrset = dns.rrset.from_text(zone, 0, 'IN', 'SOA', - '. . %u 0 0 0 0' % serial) + rrset = dns.rrset.from_text(zone, 0, "IN", "SOA", ". . %u 0 0 0 0" % serial) q.authority.append(rrset) if keyring is not None: q.use_tsig(keyring, keyname, algorithm=keyalgorithm) wire = q.to_wire() - (af, destination, source) = _destination_and_source(where, port, - source, source_port) + (af, destination, source) = _destination_and_source( + where, port, source, source_port + ) if use_udp and rdtype != dns.rdatatype.IXFR: - raise ValueError('cannot do a UDP AXFR') + raise ValueError("cannot do a UDP AXFR") sock_type = socket.SOCK_DGRAM if use_udp else socket.SOCK_STREAM with _make_socket(af, sock_type, source) as s: (_, expiration) = _compute_times(lifetime) @@ -968,8 +1246,9 @@ def xfr(where, zone, rdtype=dns.rdatatype.AXFR, rdclass=dns.rdataclass.IN, tsig_ctx = None while not done: (_, mexpiration) = _compute_times(timeout) - if mexpiration is None or \ - (expiration is not None and mexpiration > expiration): + if mexpiration is None or ( + expiration is not None and mexpiration > expiration + ): mexpiration = expiration if use_udp: (wire, _) = _udp_recv(s, 65535, mexpiration) @@ -977,11 +1256,17 @@ def xfr(where, zone, rdtype=dns.rdatatype.AXFR, rdclass=dns.rdataclass.IN, ldata = _net_read(s, 2, mexpiration) (l,) = struct.unpack("!H", ldata) wire = _net_read(s, l, mexpiration) - is_ixfr = (rdtype == dns.rdatatype.IXFR) - r = dns.message.from_wire(wire, keyring=q.keyring, - request_mac=q.mac, xfr=True, - origin=origin, tsig_ctx=tsig_ctx, - multi=True, one_rr_per_rrset=is_ixfr) + is_ixfr = rdtype == dns.rdatatype.IXFR + r = dns.message.from_wire( + wire, + keyring=q.keyring, + request_mac=q.mac, + xfr=True, + origin=origin, + tsig_ctx=tsig_ctx, + multi=True, + one_rr_per_rrset=is_ixfr, + ) rcode = r.rcode() if rcode != dns.rcode.NOERROR: raise TransferError(rcode) @@ -989,8 +1274,7 @@ def xfr(where, zone, rdtype=dns.rdatatype.AXFR, rdclass=dns.rdataclass.IN, answer_index = 0 if soa_rrset is None: if not r.answer or r.answer[0].name != oname: - raise dns.exception.FormError( - "No answer or RRset not for qname") + raise dns.exception.FormError("No answer or RRset not for qname") rrset = r.answer[0] if rrset.rdtype != dns.rdatatype.SOA: raise dns.exception.FormError("first RRset is not an SOA") @@ -1014,8 +1298,7 @@ def xfr(where, zone, rdtype=dns.rdatatype.AXFR, rdclass=dns.rdataclass.IN, if rrset.rdtype == dns.rdatatype.SOA and rrset.name == oname: if expecting_SOA: if rrset[0].serial != serial: - raise dns.exception.FormError( - "IXFR base serial mismatch") + raise dns.exception.FormError("IXFR base serial mismatch") expecting_SOA = False elif rdtype == dns.rdatatype.IXFR: delete_mode = not delete_mode @@ -1024,9 +1307,10 @@ def xfr(where, zone, rdtype=dns.rdatatype.AXFR, rdclass=dns.rdataclass.IN, # finished. If this is an IXFR we also check that we're # seeing the record in the expected part of the response. # - if rrset == soa_rrset and \ - (rdtype == dns.rdatatype.AXFR or - (rdtype == dns.rdatatype.IXFR and delete_mode)): + if rrset == soa_rrset and ( + rdtype == dns.rdatatype.AXFR + or (rdtype == dns.rdatatype.IXFR and delete_mode) + ): done = True elif expecting_SOA: # @@ -1048,14 +1332,23 @@ class UDPMode(enum.IntEnum): TRY_FIRST means "try to use UDP but fall back to TCP if needed" ONLY means "raise ``dns.xfr.UseTCP`` if trying UDP does not succeed" """ + NEVER = 0 TRY_FIRST = 1 ONLY = 2 -def inbound_xfr(where, txn_manager, query=None, - port=53, timeout=None, lifetime=None, source=None, - source_port=0, udp_mode=UDPMode.NEVER): +def inbound_xfr( + where: str, + txn_manager: dns.transaction.TransactionManager, + query: Optional[dns.message.Message] = None, + port: int = 53, + timeout: Optional[float] = None, + lifetime: Optional[float] = None, + source: Optional[str] = None, + source_port: int = 0, + udp_mode: UDPMode = UDPMode.NEVER, +) -> None: """Conduct an inbound transfer and apply it via a transaction from the txn_manager. @@ -1100,8 +1393,9 @@ def inbound_xfr(where, txn_manager, query=None, is_ixfr = rdtype == dns.rdatatype.IXFR origin = txn_manager.from_wire_origin() wire = query.to_wire() - (af, destination, source) = _destination_and_source(where, port, - source, source_port) + (af, destination, source) = _destination_and_source( + where, port, source, source_port + ) (_, expiration) = _compute_times(lifetime) retry = True while retry: @@ -1119,14 +1413,14 @@ def inbound_xfr(where, txn_manager, query=None, else: tcpmsg = struct.pack("!H", len(wire)) + wire _net_write(s, tcpmsg, expiration) - with dns.xfr.Inbound(txn_manager, rdtype, serial, - is_udp) as inbound: + with dns.xfr.Inbound(txn_manager, rdtype, serial, is_udp) as inbound: done = False tsig_ctx = None while not done: (_, mexpiration) = _compute_times(timeout) - if mexpiration is None or \ - (expiration is not None and mexpiration > expiration): + if mexpiration is None or ( + expiration is not None and mexpiration > expiration + ): mexpiration = expiration if is_udp: (rwire, _) = _udp_recv(s, 65535, mexpiration) @@ -1134,11 +1428,16 @@ def inbound_xfr(where, txn_manager, query=None, ldata = _net_read(s, 2, mexpiration) (l,) = struct.unpack("!H", ldata) rwire = _net_read(s, l, mexpiration) - r = dns.message.from_wire(rwire, keyring=query.keyring, - request_mac=query.mac, xfr=True, - origin=origin, tsig_ctx=tsig_ctx, - multi=(not is_udp), - one_rr_per_rrset=is_ixfr) + r = dns.message.from_wire( + rwire, + keyring=query.keyring, + request_mac=query.mac, + xfr=True, + origin=origin, + tsig_ctx=tsig_ctx, + multi=(not is_udp), + one_rr_per_rrset=is_ixfr, + ) try: done = inbound.process_message(r) except dns.xfr.UseTCP: diff --git a/lib/dns/query.pyi b/lib/dns/query.pyi deleted file mode 100644 index a22e229f..00000000 --- a/lib/dns/query.pyi +++ /dev/null @@ -1,64 +0,0 @@ -from typing import Optional, Union, Dict, Generator, Any -from . import tsig, rdatatype, rdataclass, name, message -from requests.sessions import Session - -import socket - -# If the ssl import works, then -# -# error: Name 'ssl' already defined (by an import) -# -# is expected and can be ignored. -try: - import ssl -except ImportError: - class ssl: # type: ignore - SSLContext : Dict = {} - -have_doh: bool - -def https(q : message.Message, where: str, timeout : Optional[float] = None, - port : Optional[int] = 443, source : Optional[str] = None, - source_port : Optional[int] = 0, - session: Optional[Session] = None, - path : Optional[str] = '/dns-query', post : Optional[bool] = True, - bootstrap_address : Optional[str] = None, - verify : Optional[bool] = True) -> message.Message: - pass - -def tcp(q : message.Message, where : str, timeout : float = None, port=53, - af : Optional[int] = None, source : Optional[str] = None, - source_port : Optional[int] = 0, - one_rr_per_rrset : Optional[bool] = False, - ignore_trailing : Optional[bool] = False, - sock : Optional[socket.socket] = None) -> message.Message: - pass - -def xfr(where : None, zone : Union[name.Name,str], rdtype=rdatatype.AXFR, - rdclass=rdataclass.IN, - timeout : Optional[float] = None, port=53, - keyring : Optional[Dict[name.Name, bytes]] = None, - keyname : Union[str,name.Name]= None, relativize=True, - lifetime : Optional[float] = None, - source : Optional[str] = None, source_port=0, serial=0, - use_udp : Optional[bool] = False, - keyalgorithm=tsig.default_algorithm) \ - -> Generator[Any,Any,message.Message]: - pass - -def udp(q : message.Message, where : str, timeout : Optional[float] = None, - port=53, source : Optional[str] = None, source_port : Optional[int] = 0, - ignore_unexpected : Optional[bool] = False, - one_rr_per_rrset : Optional[bool] = False, - ignore_trailing : Optional[bool] = False, - sock : Optional[socket.socket] = None) -> message.Message: - pass - -def tls(q : message.Message, where : str, timeout : Optional[float] = None, - port=53, source : Optional[str] = None, source_port : Optional[int] = 0, - one_rr_per_rrset : Optional[bool] = False, - ignore_trailing : Optional[bool] = False, - sock : Optional[socket.socket] = None, - ssl_context: Optional[ssl.SSLContext] = None, - server_hostname: Optional[str] = None) -> message.Message: - pass diff --git a/lib/dns/quic/__init__.py b/lib/dns/quic/__init__.py new file mode 100644 index 00000000..f48ecf57 --- /dev/null +++ b/lib/dns/quic/__init__.py @@ -0,0 +1,74 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +try: + import aioquic.quic.configuration # type: ignore + + import dns.asyncbackend + from dns._asyncbackend import NullContext + from dns.quic._sync import SyncQuicManager, SyncQuicConnection, SyncQuicStream + from dns.quic._asyncio import ( + AsyncioQuicManager, + AsyncioQuicConnection, + AsyncioQuicStream, + ) + from dns.quic._common import AsyncQuicConnection, AsyncQuicManager + + have_quic = True + + def null_factory( + *args, # pylint: disable=unused-argument + **kwargs # pylint: disable=unused-argument + ): + return NullContext(None) + + def _asyncio_manager_factory( + context, *args, **kwargs # pylint: disable=unused-argument + ): + return AsyncioQuicManager(*args, **kwargs) + + # We have a context factory and a manager factory as for trio we need to have + # a nursery. + + _async_factories = {"asyncio": (null_factory, _asyncio_manager_factory)} + + try: + import trio + from dns.quic._trio import ( # pylint: disable=ungrouped-imports + TrioQuicManager, + TrioQuicConnection, + TrioQuicStream, + ) + + def _trio_context_factory(): + return trio.open_nursery() + + def _trio_manager_factory(context, *args, **kwargs): + return TrioQuicManager(context, *args, **kwargs) + + _async_factories["trio"] = (_trio_context_factory, _trio_manager_factory) + except ImportError: + pass + + def factories_for_backend(backend=None): + if backend is None: + backend = dns.asyncbackend.get_default_backend() + return _async_factories[backend.name()] + +except ImportError: + have_quic = False + + from typing import Any + + class AsyncQuicStream: # type: ignore + pass + + class AsyncQuicConnection: # type: ignore + async def make_stream(self) -> Any: + raise NotImplementedError + + class SyncQuicStream: # type: ignore + pass + + class SyncQuicConnection: # type: ignore + def make_stream(self) -> Any: + raise NotImplementedError diff --git a/lib/dns/quic/_asyncio.py b/lib/dns/quic/_asyncio.py new file mode 100644 index 00000000..0a2e220d --- /dev/null +++ b/lib/dns/quic/_asyncio.py @@ -0,0 +1,206 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +import asyncio +import socket +import ssl +import struct +import time + +import aioquic.quic.configuration # type: ignore +import aioquic.quic.connection # type: ignore +import aioquic.quic.events # type: ignore +import dns.inet +import dns.asyncbackend + +from dns.quic._common import ( + BaseQuicStream, + AsyncQuicConnection, + AsyncQuicManager, + QUIC_MAX_DATAGRAM, +) + + +class AsyncioQuicStream(BaseQuicStream): + def __init__(self, connection, stream_id): + super().__init__(connection, stream_id) + self._wake_up = asyncio.Condition() + + async def _wait_for_wake_up(self): + async with self._wake_up: + await self._wake_up.wait() + + async def wait_for(self, amount, expiration): + timeout = self._timeout_from_expiration(expiration) + while True: + if self._buffer.have(amount): + return + self._expecting = amount + try: + await asyncio.wait_for(self._wait_for_wake_up(), timeout) + except Exception: + pass + self._expecting = 0 + + async def receive(self, timeout=None): + expiration = self._expiration_from_timeout(timeout) + await self.wait_for(2, expiration) + (size,) = struct.unpack("!H", self._buffer.get(2)) + await self.wait_for(size, expiration) + return self._buffer.get(size) + + async def send(self, datagram, is_end=False): + data = self._encapsulate(datagram) + await self._connection.write(self._stream_id, data, is_end) + + async def _add_input(self, data, is_end): + if self._common_add_input(data, is_end): + async with self._wake_up: + self._wake_up.notify() + + async def close(self): + self._close() + + # Streams are async context managers + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self.close() + async with self._wake_up: + self._wake_up.notify() + return False + + +class AsyncioQuicConnection(AsyncQuicConnection): + def __init__(self, connection, address, port, source, source_port, manager=None): + super().__init__(connection, address, port, source, source_port, manager) + self._socket = None + self._handshake_complete = asyncio.Event() + self._socket_created = asyncio.Event() + self._wake_timer = asyncio.Condition() + self._receiver_task = None + self._sender_task = None + + async def _receiver(self): + try: + af = dns.inet.af_for_address(self._address) + backend = dns.asyncbackend.get_backend("asyncio") + self._socket = await backend.make_socket( + af, socket.SOCK_DGRAM, 0, self._source, self._peer + ) + self._socket_created.set() + async with self._socket: + while not self._done: + (datagram, address) = await self._socket.recvfrom( + QUIC_MAX_DATAGRAM, None + ) + if address[0] != self._peer[0] or address[1] != self._peer[1]: + continue + self._connection.receive_datagram( + datagram, self._peer[0], time.time() + ) + # Wake up the timer in case the sender is sleeping, as there may be + # stuff to send now. + async with self._wake_timer: + self._wake_timer.notify_all() + except Exception: + pass + + async def _wait_for_wake_timer(self): + async with self._wake_timer: + await self._wake_timer.wait() + + async def _sender(self): + await self._socket_created.wait() + while not self._done: + datagrams = self._connection.datagrams_to_send(time.time()) + for (datagram, address) in datagrams: + assert address == self._peer[0] + await self._socket.sendto(datagram, self._peer, None) + (expiration, interval) = self._get_timer_values() + try: + await asyncio.wait_for(self._wait_for_wake_timer(), interval) + except Exception: + pass + self._handle_timer(expiration) + await self._handle_events() + + async def _handle_events(self): + count = 0 + while True: + event = self._connection.next_event() + if event is None: + return + if isinstance(event, aioquic.quic.events.StreamDataReceived): + stream = self._streams.get(event.stream_id) + if stream: + await stream._add_input(event.data, event.end_stream) + elif isinstance(event, aioquic.quic.events.HandshakeCompleted): + self._handshake_complete.set() + elif isinstance( + event, aioquic.quic.events.ConnectionTerminated + ) or isinstance(event, aioquic.quic.events.StreamReset): + self._done = True + self._receiver_task.cancel() + count += 1 + if count > 10: + # yield + count = 0 + await asyncio.sleep(0) + + async def write(self, stream, data, is_end=False): + self._connection.send_stream_data(stream, data, is_end) + async with self._wake_timer: + self._wake_timer.notify_all() + + def run(self): + if self._closed: + return + self._receiver_task = asyncio.Task(self._receiver()) + self._sender_task = asyncio.Task(self._sender()) + + async def make_stream(self): + await self._handshake_complete.wait() + stream_id = self._connection.get_next_available_stream_id(False) + stream = AsyncioQuicStream(self, stream_id) + self._streams[stream_id] = stream + return stream + + async def close(self): + if not self._closed: + self._manager.closed(self._peer[0], self._peer[1]) + self._closed = True + self._connection.close() + async with self._wake_timer: + self._wake_timer.notify_all() + try: + await self._receiver_task + except asyncio.CancelledError: + pass + try: + await self._sender_task + except asyncio.CancelledError: + pass + + +class AsyncioQuicManager(AsyncQuicManager): + def __init__(self, conf=None, verify_mode=ssl.CERT_REQUIRED): + super().__init__(conf, verify_mode, AsyncioQuicConnection) + + def connect(self, address, port=853, source=None, source_port=0): + (connection, start) = self._connect(address, port, source, source_port) + if start: + connection.run() + return connection + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + # Copy the itertor into a list as exiting things will mutate the connections + # table. + connections = list(self._connections.values()) + for connection in connections: + await connection.close() + return False diff --git a/lib/dns/quic/_common.py b/lib/dns/quic/_common.py new file mode 100644 index 00000000..d8f6f7fd --- /dev/null +++ b/lib/dns/quic/_common.py @@ -0,0 +1,180 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +import socket +import struct +import time + +from typing import Any + +import aioquic.quic.configuration # type: ignore +import aioquic.quic.connection # type: ignore +import dns.inet + + +QUIC_MAX_DATAGRAM = 2048 + + +class UnexpectedEOF(Exception): + pass + + +class Buffer: + def __init__(self): + self._buffer = b"" + self._seen_end = False + + def put(self, data, is_end): + if self._seen_end: + return + self._buffer += data + if is_end: + self._seen_end = True + + def have(self, amount): + if len(self._buffer) >= amount: + return True + if self._seen_end: + raise UnexpectedEOF + return False + + def seen_end(self): + return self._seen_end + + def get(self, amount): + assert self.have(amount) + data = self._buffer[:amount] + self._buffer = self._buffer[amount:] + return data + + +class BaseQuicStream: + def __init__(self, connection, stream_id): + self._connection = connection + self._stream_id = stream_id + self._buffer = Buffer() + self._expecting = 0 + + def id(self): + return self._stream_id + + def _expiration_from_timeout(self, timeout): + if timeout is not None: + expiration = time.time() + timeout + else: + expiration = None + return expiration + + def _timeout_from_expiration(self, expiration): + if expiration is not None: + timeout = max(expiration - time.time(), 0.0) + else: + timeout = None + return timeout + + # Subclass must implement receive() as sync / async and which returns a message + # or raises UnexpectedEOF. + + def _encapsulate(self, datagram): + l = len(datagram) + return struct.pack("!H", l) + datagram + + def _common_add_input(self, data, is_end): + self._buffer.put(data, is_end) + return self._expecting > 0 and self._buffer.have(self._expecting) + + def _close(self): + self._connection.close_stream(self._stream_id) + self._buffer.put(b"", True) # send EOF in case we haven't seen it. + + +class BaseQuicConnection: + def __init__( + self, connection, address, port, source=None, source_port=0, manager=None + ): + self._done = False + self._connection = connection + self._address = address + self._port = port + self._closed = False + self._manager = manager + self._streams = {} + self._af = dns.inet.af_for_address(address) + self._peer = dns.inet.low_level_address_tuple((address, port)) + if source is None and source_port != 0: + if self._af == socket.AF_INET: + source = "0.0.0.0" + elif self._af == socket.AF_INET6: + source = "::" + else: + raise NotImplementedError + if source: + self._source = (source, source_port) + else: + self._source = None + + def close_stream(self, stream_id): + del self._streams[stream_id] + + def _get_timer_values(self, closed_is_special=True): + now = time.time() + expiration = self._connection.get_timer() + if expiration is None: + expiration = now + 3600 # arbitrary "big" value + interval = max(expiration - now, 0) + if self._closed and closed_is_special: + # lower sleep interval to avoid a race in the closing process + # which can lead to higher latency closing due to sleeping when + # we have events. + interval = min(interval, 0.05) + return (expiration, interval) + + def _handle_timer(self, expiration): + now = time.time() + if expiration <= now: + self._connection.handle_timer(now) + + +class AsyncQuicConnection(BaseQuicConnection): + async def make_stream(self) -> Any: + pass + + +class BaseQuicManager: + def __init__(self, conf, verify_mode, connection_factory): + self._connections = {} + self._connection_factory = connection_factory + if conf is None: + verify_path = None + if isinstance(verify_mode, str): + verify_path = verify_mode + verify_mode = True + conf = aioquic.quic.configuration.QuicConfiguration( + alpn_protocols=["doq", "doq-i03"], + verify_mode=verify_mode, + ) + if verify_path is not None: + conf.load_verify_locations(verify_path) + self._conf = conf + + def _connect(self, address, port=853, source=None, source_port=0): + connection = self._connections.get((address, port)) + if connection is not None: + return (connection, False) + qconn = aioquic.quic.connection.QuicConnection(configuration=self._conf) + qconn.connect(address, time.time()) + connection = self._connection_factory( + qconn, address, port, source, source_port, self + ) + self._connections[(address, port)] = connection + return (connection, True) + + def closed(self, address, port): + try: + del self._connections[(address, port)] + except KeyError: + pass + + +class AsyncQuicManager(BaseQuicManager): + def connect(self, address, port=853, source=None, source_port=0): + raise NotImplementedError diff --git a/lib/dns/quic/_sync.py b/lib/dns/quic/_sync.py new file mode 100644 index 00000000..be005ba9 --- /dev/null +++ b/lib/dns/quic/_sync.py @@ -0,0 +1,214 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +import socket +import ssl +import selectors +import struct +import threading +import time + +import aioquic.quic.configuration # type: ignore +import aioquic.quic.connection # type: ignore +import aioquic.quic.events # type: ignore +import dns.inet + +from dns.quic._common import ( + BaseQuicStream, + BaseQuicConnection, + BaseQuicManager, + QUIC_MAX_DATAGRAM, +) + +# Avoid circularity with dns.query +if hasattr(selectors, "PollSelector"): + _selector_class = selectors.PollSelector # type: ignore +else: + _selector_class = selectors.SelectSelector # type: ignore + + +class SyncQuicStream(BaseQuicStream): + def __init__(self, connection, stream_id): + super().__init__(connection, stream_id) + self._wake_up = threading.Condition() + self._lock = threading.Lock() + + def wait_for(self, amount, expiration): + timeout = self._timeout_from_expiration(expiration) + while True: + with self._lock: + if self._buffer.have(amount): + return + self._expecting = amount + with self._wake_up: + self._wake_up.wait(timeout) + self._expecting = 0 + + def receive(self, timeout=None): + expiration = self._expiration_from_timeout(timeout) + self.wait_for(2, expiration) + with self._lock: + (size,) = struct.unpack("!H", self._buffer.get(2)) + self.wait_for(size, expiration) + with self._lock: + return self._buffer.get(size) + + def send(self, datagram, is_end=False): + data = self._encapsulate(datagram) + self._connection.write(self._stream_id, data, is_end) + + def _add_input(self, data, is_end): + if self._common_add_input(data, is_end): + with self._wake_up: + self._wake_up.notify() + + def close(self): + with self._lock: + self._close() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + with self._wake_up: + self._wake_up.notify() + return False + + +class SyncQuicConnection(BaseQuicConnection): + def __init__(self, connection, address, port, source, source_port, manager): + super().__init__(connection, address, port, source, source_port, manager) + self._socket = socket.socket(self._af, socket.SOCK_DGRAM, 0) + self._socket.connect(self._peer) + (self._send_wakeup, self._receive_wakeup) = socket.socketpair() + self._receive_wakeup.setblocking(False) + self._socket.setblocking(False) + if self._source is not None: + try: + self._socket.bind( + dns.inet.low_level_address_tuple(self._source, self._af) + ) + except Exception: + self._socket.close() + raise + self._handshake_complete = threading.Event() + self._worker_thread = None + self._lock = threading.Lock() + + def _read(self): + count = 0 + while count < 10: + count += 1 + try: + datagram = self._socket.recv(QUIC_MAX_DATAGRAM) + except BlockingIOError: + return + with self._lock: + self._connection.receive_datagram(datagram, self._peer[0], time.time()) + + def _drain_wakeup(self): + while True: + try: + self._receive_wakeup.recv(32) + except BlockingIOError: + return + + def _worker(self): + sel = _selector_class() + sel.register(self._socket, selectors.EVENT_READ, self._read) + sel.register(self._receive_wakeup, selectors.EVENT_READ, self._drain_wakeup) + while not self._done: + (expiration, interval) = self._get_timer_values(False) + items = sel.select(interval) + for (key, _) in items: + key.data() + with self._lock: + self._handle_timer(expiration) + datagrams = self._connection.datagrams_to_send(time.time()) + for (datagram, _) in datagrams: + try: + self._socket.send(datagram) + except BlockingIOError: + # we let QUIC handle any lossage + pass + self._handle_events() + + def _handle_events(self): + while True: + with self._lock: + event = self._connection.next_event() + if event is None: + return + if isinstance(event, aioquic.quic.events.StreamDataReceived): + with self._lock: + stream = self._streams.get(event.stream_id) + if stream: + stream._add_input(event.data, event.end_stream) + elif isinstance(event, aioquic.quic.events.HandshakeCompleted): + self._handshake_complete.set() + elif isinstance( + event, aioquic.quic.events.ConnectionTerminated + ) or isinstance(event, aioquic.quic.events.StreamReset): + with self._lock: + self._done = True + + def write(self, stream, data, is_end=False): + with self._lock: + self._connection.send_stream_data(stream, data, is_end) + self._send_wakeup.send(b"\x01") + + def run(self): + if self._closed: + return + self._worker_thread = threading.Thread(target=self._worker) + self._worker_thread.start() + + def make_stream(self): + self._handshake_complete.wait() + with self._lock: + stream_id = self._connection.get_next_available_stream_id(False) + stream = SyncQuicStream(self, stream_id) + self._streams[stream_id] = stream + return stream + + def close_stream(self, stream_id): + with self._lock: + super().close_stream(stream_id) + + def close(self): + with self._lock: + if self._closed: + return + self._manager.closed(self._peer[0], self._peer[1]) + self._closed = True + self._connection.close() + self._send_wakeup.send(b"\x01") + self._worker_thread.join() + + +class SyncQuicManager(BaseQuicManager): + def __init__(self, conf=None, verify_mode=ssl.CERT_REQUIRED): + super().__init__(conf, verify_mode, SyncQuicConnection) + self._lock = threading.Lock() + + def connect(self, address, port=853, source=None, source_port=0): + with self._lock: + (connection, start) = self._connect(address, port, source, source_port) + if start: + connection.run() + return connection + + def closed(self, address, port): + with self._lock: + super().closed(address, port) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + # Copy the itertor into a list as exiting things will mutate the connections + # table. + connections = list(self._connections.values()) + for connection in connections: + connection.close() + return False diff --git a/lib/dns/quic/_trio.py b/lib/dns/quic/_trio.py new file mode 100644 index 00000000..1e47a5a6 --- /dev/null +++ b/lib/dns/quic/_trio.py @@ -0,0 +1,170 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +import socket +import ssl +import struct +import time + +import aioquic.quic.configuration # type: ignore +import aioquic.quic.connection # type: ignore +import aioquic.quic.events # type: ignore +import trio + +import dns.inet +from dns._asyncbackend import NullContext +from dns.quic._common import ( + BaseQuicStream, + AsyncQuicConnection, + AsyncQuicManager, + QUIC_MAX_DATAGRAM, +) + + +class TrioQuicStream(BaseQuicStream): + def __init__(self, connection, stream_id): + super().__init__(connection, stream_id) + self._wake_up = trio.Condition() + + async def wait_for(self, amount): + while True: + if self._buffer.have(amount): + return + self._expecting = amount + async with self._wake_up: + await self._wake_up.wait() + self._expecting = 0 + + async def receive(self, timeout=None): + if timeout is None: + context = NullContext(None) + else: + context = trio.move_on_after(timeout) + with context: + await self.wait_for(2) + (size,) = struct.unpack("!H", self._buffer.get(2)) + await self.wait_for(size) + return self._buffer.get(size) + + async def send(self, datagram, is_end=False): + data = self._encapsulate(datagram) + await self._connection.write(self._stream_id, data, is_end) + + async def _add_input(self, data, is_end): + if self._common_add_input(data, is_end): + async with self._wake_up: + self._wake_up.notify() + + async def close(self): + self._close() + + # Streams are async context managers + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self.close() + async with self._wake_up: + self._wake_up.notify() + return False + + +class TrioQuicConnection(AsyncQuicConnection): + def __init__(self, connection, address, port, source, source_port, manager=None): + super().__init__(connection, address, port, source, source_port, manager) + self._socket = trio.socket.socket(self._af, socket.SOCK_DGRAM, 0) + if self._source: + trio.socket.bind(dns.inet.low_level_address_tuple(self._source, self._af)) + self._handshake_complete = trio.Event() + self._run_done = trio.Event() + self._worker_scope = None + + async def _worker(self): + await self._socket.connect(self._peer) + while not self._done: + (expiration, interval) = self._get_timer_values(False) + with trio.CancelScope( + deadline=trio.current_time() + interval + ) as self._worker_scope: + datagram = await self._socket.recv(QUIC_MAX_DATAGRAM) + self._connection.receive_datagram(datagram, self._peer[0], time.time()) + self._worker_scope = None + self._handle_timer(expiration) + datagrams = self._connection.datagrams_to_send(time.time()) + for (datagram, _) in datagrams: + await self._socket.send(datagram) + await self._handle_events() + + async def _handle_events(self): + count = 0 + while True: + event = self._connection.next_event() + if event is None: + return + if isinstance(event, aioquic.quic.events.StreamDataReceived): + stream = self._streams.get(event.stream_id) + if stream: + await stream._add_input(event.data, event.end_stream) + elif isinstance(event, aioquic.quic.events.HandshakeCompleted): + self._handshake_complete.set() + elif isinstance( + event, aioquic.quic.events.ConnectionTerminated + ) or isinstance(event, aioquic.quic.events.StreamReset): + self._done = True + self._socket.close() + count += 1 + if count > 10: + # yield + count = 0 + await trio.sleep(0) + + async def write(self, stream, data, is_end=False): + self._connection.send_stream_data(stream, data, is_end) + if self._worker_scope is not None: + self._worker_scope.cancel() + + async def run(self): + if self._closed: + return + async with trio.open_nursery() as nursery: + nursery.start_soon(self._worker) + self._run_done.set() + + async def make_stream(self): + await self._handshake_complete.wait() + stream_id = self._connection.get_next_available_stream_id(False) + stream = TrioQuicStream(self, stream_id) + self._streams[stream_id] = stream + return stream + + async def close(self): + if not self._closed: + self._manager.closed(self._peer[0], self._peer[1]) + self._closed = True + self._connection.close() + if self._worker_scope is not None: + self._worker_scope.cancel() + await self._run_done.wait() + + +class TrioQuicManager(AsyncQuicManager): + def __init__(self, nursery, conf=None, verify_mode=ssl.CERT_REQUIRED): + super().__init__(conf, verify_mode, TrioQuicConnection) + self._nursery = nursery + + def connect(self, address, port=853, source=None, source_port=0): + (connection, start) = self._connect(address, port, source, source_port) + if start: + self._nursery.start_soon(connection.run) + return connection + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + # Copy the itertor into a list as exiting things will mutate the connections + # table. + connections = list(self._connections.values()) + for connection in connections: + await connection.close() + return False diff --git a/lib/dns/rcode.py b/lib/dns/rcode.py index 49fee695..8e6386f8 100644 --- a/lib/dns/rcode.py +++ b/lib/dns/rcode.py @@ -17,9 +17,12 @@ """DNS Result Codes.""" +from typing import Tuple + import dns.enum import dns.exception + class Rcode(dns.enum.IntEnum): #: No error NOERROR = 0 @@ -77,20 +80,20 @@ class UnknownRcode(dns.exception.DNSException): """A DNS rcode is unknown.""" -def from_text(text): +def from_text(text: str) -> Rcode: """Convert text into an rcode. *text*, a ``str``, the textual rcode or an integer in textual form. Raises ``dns.rcode.UnknownRcode`` if the rcode mnemonic is unknown. - Returns an ``int``. + Returns a ``dns.rcode.Rcode``. """ return Rcode.from_text(text) -def from_flags(flags, ednsflags): +def from_flags(flags: int, ednsflags: int) -> Rcode: """Return the rcode value encoded by flags and ednsflags. *flags*, an ``int``, the DNS flags field. @@ -99,17 +102,17 @@ def from_flags(flags, ednsflags): Raises ``ValueError`` if rcode is < 0 or > 4095 - Returns an ``int``. + Returns a ``dns.rcode.Rcode``. """ - value = (flags & 0x000f) | ((ednsflags >> 20) & 0xff0) - return value + value = (flags & 0x000F) | ((ednsflags >> 20) & 0xFF0) + return Rcode.make(value) -def to_flags(value): +def to_flags(value: Rcode) -> Tuple[int, int]: """Return a (flags, ednsflags) tuple which encodes the rcode. - *value*, an ``int``, the rcode. + *value*, a ``dns.rcode.Rcode``, the rcode. Raises ``ValueError`` if rcode is < 0 or > 4095. @@ -117,16 +120,16 @@ def to_flags(value): """ if value < 0 or value > 4095: - raise ValueError('rcode must be >= 0 and <= 4095') - v = value & 0xf - ev = (value & 0xff0) << 20 + raise ValueError("rcode must be >= 0 and <= 4095") + v = value & 0xF + ev = (value & 0xFF0) << 20 return (v, ev) -def to_text(value, tsig=False): +def to_text(value: Rcode, tsig: bool = False) -> str: """Convert rcode into text. - *value*, an ``int``, the rcode. + *value*, a ``dns.rcode.Rcode``, the rcode. Raises ``ValueError`` if rcode is < 0 or > 4095. @@ -134,9 +137,10 @@ def to_text(value, tsig=False): """ if tsig and value == Rcode.BADVERS: - return 'BADSIG' + return "BADSIG" return Rcode.to_text(value) + ### BEGIN generated Rcode constants NOERROR = Rcode.NOERROR diff --git a/lib/dns/rdata.py b/lib/dns/rdata.py index 6b5b5c5a..1dd6ed90 100644 --- a/lib/dns/rdata.py +++ b/lib/dns/rdata.py @@ -17,6 +17,8 @@ """DNS rdata.""" +from typing import Any, Dict, Optional, Tuple, Union + from importlib import import_module import base64 import binascii @@ -55,21 +57,22 @@ class NoRelativeRdataOrdering(dns.exception.DNSException): """ -def _wordbreak(data, chunksize=_chunksize, separator=b' '): +def _wordbreak(data, chunksize=_chunksize, separator=b" "): """Break a binary string into chunks of chunksize characters separated by a space. """ if not chunksize: return data.decode() - return separator.join([data[i:i + chunksize] - for i - in range(0, len(data), chunksize)]).decode() + return separator.join( + [data[i : i + chunksize] for i in range(0, len(data), chunksize)] + ).decode() # pylint: disable=unused-argument -def _hexify(data, chunksize=_chunksize, separator=b' ', **kw): + +def _hexify(data, chunksize=_chunksize, separator=b" ", **kw): """Convert a binary string into its hex encoding, broken up into chunks of chunksize characters separated by a separator. """ @@ -77,17 +80,19 @@ def _hexify(data, chunksize=_chunksize, separator=b' ', **kw): return _wordbreak(binascii.hexlify(data), chunksize, separator) -def _base64ify(data, chunksize=_chunksize, separator=b' ', **kw): +def _base64ify(data, chunksize=_chunksize, separator=b" ", **kw): """Convert a binary string into its base64 encoding, broken up into chunks of chunksize characters separated by a separator. """ return _wordbreak(base64.b64encode(data), chunksize, separator) + # pylint: enable=unused-argument __escaped = b'"\\' + def _escapify(qstring): """Escape the characters in a quoted string which need it.""" @@ -96,14 +101,14 @@ def _escapify(qstring): if not isinstance(qstring, bytearray): qstring = bytearray(qstring) - text = '' + text = "" for c in qstring: if c in __escaped: - text += '\\' + chr(c) + text += "\\" + chr(c) elif c >= 0x20 and c < 0x7F: text += chr(c) else: - text += '\\%03d' % c + text += "\\%03d" % c return text @@ -114,9 +119,10 @@ def _truncate_bitmap(what): for i in range(len(what) - 1, -1, -1): if what[i] != 0: - return what[0: i + 1] + return what[0 : i + 1] return what[0:1] + # So we don't have to edit all the rdata classes... _constify = dns.immutable.constify @@ -125,7 +131,7 @@ _constify = dns.immutable.constify class Rdata: """Base class for all DNS rdata types.""" - __slots__ = ['rdclass', 'rdtype', 'rdcomment'] + __slots__ = ["rdclass", "rdtype", "rdcomment"] def __init__(self, rdclass, rdtype): """Initialize an rdata. @@ -140,8 +146,9 @@ class Rdata: self.rdcomment = None def _get_all_slots(self): - return itertools.chain.from_iterable(getattr(cls, '__slots__', []) - for cls in self.__class__.__mro__) + return itertools.chain.from_iterable( + getattr(cls, "__slots__", []) for cls in self.__class__.__mro__ + ) def __getstate__(self): # We used to try to do a tuple of all slots here, but it @@ -160,12 +167,12 @@ class Rdata: def __setstate__(self, state): for slot, val in state.items(): object.__setattr__(self, slot, val) - if not hasattr(self, 'rdcomment'): + if not hasattr(self, "rdcomment"): # Pickled rdata from 2.0.x might not have a rdcomment, so add # it if needed. - object.__setattr__(self, 'rdcomment', None) + object.__setattr__(self, "rdcomment", None) - def covers(self): + def covers(self) -> dns.rdatatype.RdataType: """Return the type a Rdata covers. DNS SIG/RRSIG rdatas apply to a specific type; this type is @@ -174,12 +181,12 @@ class Rdata: creating rdatasets, allowing the rdataset to contain only RRSIGs of a particular type, e.g. RRSIG(NS). - Returns an ``int``. + Returns a ``dns.rdatatype.RdataType``. """ return dns.rdatatype.NONE - def extended_rdatatype(self): + def extended_rdatatype(self) -> int: """Return a 32-bit type value, the least significant 16 bits of which are the ordinary DNS type, and the upper 16 bits of which are the "covered" type, if any. @@ -189,7 +196,12 @@ class Rdata: return self.covers() << 16 | self.rdtype - def to_text(self, origin=None, relativize=True, **kw): + def to_text( + self, + origin: Optional[dns.name.Name] = None, + relativize: bool = True, + **kw: Dict[str, Any] + ) -> str: """Convert an rdata to text format. Returns a ``str``. @@ -197,11 +209,22 @@ class Rdata: raise NotImplementedError # pragma: no cover - def _to_wire(self, file, compress=None, origin=None, canonicalize=False): + def _to_wire( + self, + file: Optional[Any], + compress: Optional[dns.name.CompressType] = None, + origin: Optional[dns.name.Name] = None, + canonicalize: bool = False, + ) -> bytes: raise NotImplementedError # pragma: no cover - def to_wire(self, file=None, compress=None, origin=None, - canonicalize=False): + def to_wire( + self, + file: Optional[Any] = None, + compress: Optional[dns.name.CompressType] = None, + origin: Optional[dns.name.Name] = None, + canonicalize: bool = False, + ) -> bytes: """Convert an rdata to wire format. Returns a ``bytes`` or ``None``. @@ -214,15 +237,18 @@ class Rdata: self._to_wire(f, compress, origin, canonicalize) return f.getvalue() - def to_generic(self, origin=None): + def to_generic( + self, origin: Optional[dns.name.Name] = None + ) -> "dns.rdata.GenericRdata": """Creates a dns.rdata.GenericRdata equivalent of this rdata. Returns a ``dns.rdata.GenericRdata``. """ - return dns.rdata.GenericRdata(self.rdclass, self.rdtype, - self.to_wire(origin=origin)) + return dns.rdata.GenericRdata( + self.rdclass, self.rdtype, self.to_wire(origin=origin) + ) - def to_digestable(self, origin=None): + def to_digestable(self, origin: Optional[dns.name.Name] = None) -> bytes: """Convert rdata to a format suitable for digesting in hashes. This is also the DNSSEC canonical form. @@ -234,12 +260,19 @@ class Rdata: def __repr__(self): covers = self.covers() if covers == dns.rdatatype.NONE: - ctext = '' + ctext = "" else: - ctext = '(' + dns.rdatatype.to_text(covers) + ')' - return '' + ctext = "(" + dns.rdatatype.to_text(covers) + ")" + return ( + "" + ) def __str__(self): return self.to_text() @@ -320,27 +353,39 @@ class Rdata: return not self.__eq__(other) def __lt__(self, other): - if not isinstance(other, Rdata) or \ - self.rdclass != other.rdclass or self.rdtype != other.rdtype: + if ( + not isinstance(other, Rdata) + or self.rdclass != other.rdclass + or self.rdtype != other.rdtype + ): return NotImplemented return self._cmp(other) < 0 def __le__(self, other): - if not isinstance(other, Rdata) or \ - self.rdclass != other.rdclass or self.rdtype != other.rdtype: + if ( + not isinstance(other, Rdata) + or self.rdclass != other.rdclass + or self.rdtype != other.rdtype + ): return NotImplemented return self._cmp(other) <= 0 def __ge__(self, other): - if not isinstance(other, Rdata) or \ - self.rdclass != other.rdclass or self.rdtype != other.rdtype: + if ( + not isinstance(other, Rdata) + or self.rdclass != other.rdclass + or self.rdtype != other.rdtype + ): return NotImplemented return self._cmp(other) >= 0 def __gt__(self, other): - if not isinstance(other, Rdata) or \ - self.rdclass != other.rdclass or self.rdtype != other.rdtype: + if ( + not isinstance(other, Rdata) + or self.rdclass != other.rdclass + or self.rdtype != other.rdtype + ): return NotImplemented return self._cmp(other) > 0 @@ -348,15 +393,28 @@ class Rdata: return hash(self.to_digestable(dns.name.root)) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, + rdclass: dns.rdataclass.RdataClass, + rdtype: dns.rdatatype.RdataType, + tok: dns.tokenizer.Tokenizer, + origin: Optional[dns.name.Name] = None, + relativize: bool = True, + relativize_to: Optional[dns.name.Name] = None, + ) -> "Rdata": raise NotImplementedError # pragma: no cover @classmethod - def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + def from_wire_parser( + cls, + rdclass: dns.rdataclass.RdataClass, + rdtype: dns.rdatatype.RdataType, + parser: dns.wire.Parser, + origin: Optional[dns.name.Name] = None, + ) -> "Rdata": raise NotImplementedError # pragma: no cover - def replace(self, **kwargs): + def replace(self, **kwargs: Any) -> "Rdata": """ Create a new Rdata instance based on the instance replace was invoked on. It is possible to pass different parameters to @@ -369,19 +427,25 @@ class Rdata: """ # Get the constructor parameters. - parameters = inspect.signature(self.__init__).parameters + parameters = inspect.signature(self.__init__).parameters # type: ignore # Ensure that all of the arguments correspond to valid fields. # Don't allow rdclass or rdtype to be changed, though. for key in kwargs: - if key == 'rdcomment': + if key == "rdcomment": continue if key not in parameters: - raise AttributeError("'{}' object has no attribute '{}'" - .format(self.__class__.__name__, key)) - if key in ('rdclass', 'rdtype'): - raise AttributeError("Cannot overwrite '{}' attribute '{}'" - .format(self.__class__.__name__, key)) + raise AttributeError( + "'{}' object has no attribute '{}'".format( + self.__class__.__name__, key + ) + ) + if key in ("rdclass", "rdtype"): + raise AttributeError( + "Cannot overwrite '{}' attribute '{}'".format( + self.__class__.__name__, key + ) + ) # Construct the parameter list. For each field, use the value in # kwargs if present, and the current value otherwise. @@ -391,9 +455,9 @@ class Rdata: rd = self.__class__(*args) # The comment is not set in the constructor, so give it special # handling. - rdcomment = kwargs.get('rdcomment', self.rdcomment) + rdcomment = kwargs.get("rdcomment", self.rdcomment) if rdcomment is not None: - object.__setattr__(rd, 'rdcomment', rdcomment) + object.__setattr__(rd, "rdcomment", rdcomment) return rd # Type checking and conversion helpers. These are class methods as @@ -408,18 +472,26 @@ class Rdata: return dns.rdatatype.RdataType.make(value) @classmethod - def _as_bytes(cls, value, encode=False, max_length=None, empty_ok=True): + def _as_bytes( + cls, + value: Any, + encode: bool = False, + max_length: Optional[int] = None, + empty_ok: bool = True, + ) -> bytes: if encode and isinstance(value, str): - value = value.encode() + bvalue = value.encode() elif isinstance(value, bytearray): - value = bytes(value) - elif not isinstance(value, bytes): - raise ValueError('not bytes') - if max_length is not None and len(value) > max_length: - raise ValueError('too long') - if not empty_ok and len(value) == 0: - raise ValueError('empty bytes not allowed') - return value + bvalue = bytes(value) + elif isinstance(value, bytes): + bvalue = value + else: + raise ValueError("not bytes") + if max_length is not None and len(bvalue) > max_length: + raise ValueError("too long") + if not empty_ok and len(bvalue) == 0: + raise ValueError("empty bytes not allowed") + return bvalue @classmethod def _as_name(cls, value): @@ -429,49 +501,49 @@ class Rdata: if isinstance(value, str): return dns.name.from_text(value) elif not isinstance(value, dns.name.Name): - raise ValueError('not a name') + raise ValueError("not a name") return value @classmethod def _as_uint8(cls, value): if not isinstance(value, int): - raise ValueError('not an integer') + raise ValueError("not an integer") if value < 0 or value > 255: - raise ValueError('not a uint8') + raise ValueError("not a uint8") return value @classmethod def _as_uint16(cls, value): if not isinstance(value, int): - raise ValueError('not an integer') + raise ValueError("not an integer") if value < 0 or value > 65535: - raise ValueError('not a uint16') + raise ValueError("not a uint16") return value @classmethod def _as_uint32(cls, value): if not isinstance(value, int): - raise ValueError('not an integer') + raise ValueError("not an integer") if value < 0 or value > 4294967295: - raise ValueError('not a uint32') + raise ValueError("not a uint32") return value @classmethod def _as_uint48(cls, value): if not isinstance(value, int): - raise ValueError('not an integer') + raise ValueError("not an integer") if value < 0 or value > 281474976710655: - raise ValueError('not a uint48') + raise ValueError("not a uint48") return value @classmethod def _as_int(cls, value, low=None, high=None): if not isinstance(value, int): - raise ValueError('not an integer') + raise ValueError("not an integer") if low is not None and value < low: - raise ValueError('value too small') + raise ValueError("value too small") if high is not None and value > high: - raise ValueError('value too large') + raise ValueError("value too large") return value @classmethod @@ -483,7 +555,7 @@ class Rdata: elif isinstance(value, bytes): return dns.ipv4.inet_ntoa(value) else: - raise ValueError('not an IPv4 address') + raise ValueError("not an IPv4 address") @classmethod def _as_ipv6_address(cls, value): @@ -494,14 +566,14 @@ class Rdata: elif isinstance(value, bytes): return dns.ipv6.inet_ntoa(value) else: - raise ValueError('not an IPv6 address') + raise ValueError("not an IPv6 address") @classmethod def _as_bool(cls, value): if isinstance(value, bool): return value else: - raise ValueError('not a boolean') + raise ValueError("not a boolean") @classmethod def _as_ttl(cls, value): @@ -510,7 +582,7 @@ class Rdata: elif isinstance(value, str): return dns.ttl.from_text(value) else: - raise ValueError('not a TTL') + raise ValueError("not a TTL") @classmethod def _as_tuple(cls, value, as_value): @@ -532,6 +604,7 @@ class Rdata: return items +@dns.immutable.immutable class GenericRdata(Rdata): """Generic Rdata Class @@ -540,28 +613,32 @@ class GenericRdata(Rdata): implementation. It implements the DNS "unknown RRs" scheme. """ - __slots__ = ['data'] + __slots__ = ["data"] def __init__(self, rdclass, rdtype, data): super().__init__(rdclass, rdtype) - object.__setattr__(self, 'data', data) + self.data = data - def to_text(self, origin=None, relativize=True, **kw): - return r'\# %d ' % len(self.data) + _hexify(self.data, **kw) + def to_text( + self, + origin: Optional[dns.name.Name] = None, + relativize: bool = True, + **kw: Dict[str, Any] + ) -> str: + return r"\# %d " % len(self.data) + _hexify(self.data, **kw) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): token = tok.get() - if not token.is_identifier() or token.value != r'\#': - raise dns.exception.SyntaxError( - r'generic rdata does not start with \#') + if not token.is_identifier() or token.value != r"\#": + raise dns.exception.SyntaxError(r"generic rdata does not start with \#") length = tok.get_int() hex = tok.concatenate_remaining_identifiers(True).encode() data = binascii.unhexlify(hex) if len(data) != length: - raise dns.exception.SyntaxError( - 'generic rdata hex data has wrong length') + raise dns.exception.SyntaxError("generic rdata hex data has wrong length") return cls(rdclass, rdtype, data) def _to_wire(self, file, compress=None, origin=None, canonicalize=False): @@ -571,8 +648,12 @@ class GenericRdata(Rdata): def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): return cls(rdclass, rdtype, parser.get_remaining()) -_rdata_classes = {} -_module_prefix = 'dns.rdtypes' + +_rdata_classes: Dict[ + Tuple[dns.rdataclass.RdataClass, dns.rdatatype.RdataType], Any +] = {} +_module_prefix = "dns.rdtypes" + def get_rdata_class(rdclass, rdtype): cls = _rdata_classes.get((rdclass, rdtype)) @@ -581,16 +662,16 @@ def get_rdata_class(rdclass, rdtype): if not cls: rdclass_text = dns.rdataclass.to_text(rdclass) rdtype_text = dns.rdatatype.to_text(rdtype) - rdtype_text = rdtype_text.replace('-', '_') + rdtype_text = rdtype_text.replace("-", "_") try: - mod = import_module('.'.join([_module_prefix, - rdclass_text, rdtype_text])) + mod = import_module( + ".".join([_module_prefix, rdclass_text, rdtype_text]) + ) cls = getattr(mod, rdtype_text) _rdata_classes[(rdclass, rdtype)] = cls except ImportError: try: - mod = import_module('.'.join([_module_prefix, - 'ANY', rdtype_text])) + mod = import_module(".".join([_module_prefix, "ANY", rdtype_text])) cls = getattr(mod, rdtype_text) _rdata_classes[(dns.rdataclass.ANY, rdtype)] = cls _rdata_classes[(rdclass, rdtype)] = cls @@ -602,8 +683,15 @@ def get_rdata_class(rdclass, rdtype): return cls -def from_text(rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None, idna_codec=None): +def from_text( + rdclass: Union[dns.rdataclass.RdataClass, str], + rdtype: Union[dns.rdatatype.RdataType, str], + tok: Union[dns.tokenizer.Tokenizer, str], + origin: Optional[dns.name.Name] = None, + relativize: bool = True, + relativize_to: Optional[dns.name.Name] = None, + idna_codec: Optional[dns.name.IDNACodec] = None, +) -> Rdata: """Build an rdata object from text format. This function attempts to dynamically load a class which @@ -617,9 +705,9 @@ def from_text(rdclass, rdtype, tok, origin=None, relativize=True, If *tok* is a ``str``, then a tokenizer is created and the string is used as its input. - *rdclass*, an ``int``, the rdataclass. + *rdclass*, a ``dns.rdataclass.RdataClass`` or ``str``, the rdataclass. - *rdtype*, an ``int``, the rdatatype. + *rdtype*, a ``dns.rdatatype.RdataType`` or ``str``, the rdatatype. *tok*, a ``dns.tokenizer.Tokenizer`` or a ``str``. @@ -651,17 +739,18 @@ def from_text(rdclass, rdtype, tok, origin=None, relativize=True, # peek at first token token = tok.get() tok.unget(token) - if token.is_identifier() and \ - token.value == r'\#': + if token.is_identifier() and token.value == r"\#": # # Known type using the generic syntax. Extract the # wire form from the generic syntax, and then run # from_wire on it. # - grdata = GenericRdata.from_text(rdclass, rdtype, tok, origin, - relativize, relativize_to) - rdata = from_wire(rdclass, rdtype, grdata.data, 0, - len(grdata.data), origin) + grdata = GenericRdata.from_text( + rdclass, rdtype, tok, origin, relativize, relativize_to + ) + rdata = from_wire( + rdclass, rdtype, grdata.data, 0, len(grdata.data), origin + ) # # If this comparison isn't equal, then there must have been # compressed names in the wire format, which is an error, @@ -669,19 +758,27 @@ def from_text(rdclass, rdtype, tok, origin=None, relativize=True, # rwire = rdata.to_wire() if rwire != grdata.data: - raise dns.exception.SyntaxError('compressed data in ' - 'generic syntax form ' - 'of known rdatatype') + raise dns.exception.SyntaxError( + "compressed data in " + "generic syntax form " + "of known rdatatype" + ) if rdata is None: - rdata = cls.from_text(rdclass, rdtype, tok, origin, relativize, - relativize_to) + rdata = cls.from_text( + rdclass, rdtype, tok, origin, relativize, relativize_to + ) token = tok.get_eol_as_token() if token.comment is not None: - object.__setattr__(rdata, 'rdcomment', token.comment) + object.__setattr__(rdata, "rdcomment", token.comment) return rdata -def from_wire_parser(rdclass, rdtype, parser, origin=None): +def from_wire_parser( + rdclass: Union[dns.rdataclass.RdataClass, str], + rdtype: Union[dns.rdatatype.RdataType, str], + parser: dns.wire.Parser, + origin: Optional[dns.name.Name] = None, +) -> Rdata: """Build an rdata object from wire format This function attempts to dynamically load a class which @@ -692,9 +789,9 @@ def from_wire_parser(rdclass, rdtype, parser, origin=None): Once a class is chosen, its from_wire() class method is called with the parameters to this function. - *rdclass*, an ``int``, the rdataclass. + *rdclass*, a ``dns.rdataclass.RdataClass`` or ``str``, the rdataclass. - *rdtype*, an ``int``, the rdatatype. + *rdtype*, a ``dns.rdatatype.RdataType`` or ``str``, the rdatatype. *parser*, a ``dns.wire.Parser``, the parser, which should be restricted to the rdata length. @@ -712,7 +809,14 @@ def from_wire_parser(rdclass, rdtype, parser, origin=None): return cls.from_wire_parser(rdclass, rdtype, parser, origin) -def from_wire(rdclass, rdtype, wire, current, rdlen, origin=None): +def from_wire( + rdclass: Union[dns.rdataclass.RdataClass, str], + rdtype: Union[dns.rdatatype.RdataType, str], + wire: bytes, + current: int, + rdlen: int, + origin: Optional[dns.name.Name] = None, +) -> Rdata: """Build an rdata object from wire format This function attempts to dynamically load a class which @@ -746,13 +850,21 @@ def from_wire(rdclass, rdtype, wire, current, rdlen, origin=None): class RdatatypeExists(dns.exception.DNSException): """DNS rdatatype already exists.""" - supp_kwargs = {'rdclass', 'rdtype'} - fmt = "The rdata type with class {rdclass:d} and rdtype {rdtype:d} " + \ - "already exists." + + supp_kwargs = {"rdclass", "rdtype"} + fmt = ( + "The rdata type with class {rdclass:d} and rdtype {rdtype:d} " + + "already exists." + ) -def register_type(implementation, rdtype, rdtype_text, is_singleton=False, - rdclass=dns.rdataclass.IN): +def register_type( + implementation: Any, + rdtype: int, + rdtype_text: str, + is_singleton: bool = False, + rdclass: dns.rdataclass.RdataClass = dns.rdataclass.IN, +) -> None: """Dynamically register a module to handle an rdatatype. *implementation*, a module implementing the type in the usual dnspython @@ -769,14 +881,16 @@ def register_type(implementation, rdtype, rdtype_text, is_singleton=False, it applies to all classes. """ - existing_cls = get_rdata_class(rdclass, rdtype) - if existing_cls != GenericRdata or dns.rdatatype.is_metatype(rdtype): - raise RdatatypeExists(rdclass=rdclass, rdtype=rdtype) + the_rdtype = dns.rdatatype.RdataType.make(rdtype) + existing_cls = get_rdata_class(rdclass, the_rdtype) + if existing_cls != GenericRdata or dns.rdatatype.is_metatype(the_rdtype): + raise RdatatypeExists(rdclass=rdclass, rdtype=the_rdtype) try: - if dns.rdatatype.RdataType(rdtype).name != rdtype_text: - raise RdatatypeExists(rdclass=rdclass, rdtype=rdtype) + if dns.rdatatype.RdataType(the_rdtype).name != rdtype_text: + raise RdatatypeExists(rdclass=rdclass, rdtype=the_rdtype) except ValueError: pass - _rdata_classes[(rdclass, rdtype)] = getattr(implementation, - rdtype_text.replace('-', '_')) - dns.rdatatype.register_type(rdtype, rdtype_text, is_singleton) + _rdata_classes[(rdclass, the_rdtype)] = getattr( + implementation, rdtype_text.replace("-", "_") + ) + dns.rdatatype.register_type(the_rdtype, rdtype_text, is_singleton) diff --git a/lib/dns/rdata.pyi b/lib/dns/rdata.pyi deleted file mode 100644 index f394791f..00000000 --- a/lib/dns/rdata.pyi +++ /dev/null @@ -1,19 +0,0 @@ -from typing import Dict, Tuple, Any, Optional, BinaryIO -from .name import Name, IDNACodec -class Rdata: - def __init__(self): - self.address : str - def to_wire(self, file : Optional[BinaryIO], compress : Optional[Dict[Name,int]], origin : Optional[Name], canonicalize : Optional[bool]) -> Optional[bytes]: - ... - @classmethod - def from_text(cls, rdclass : int, rdtype : int, tok, origin=None, relativize=True): - ... -_rdata_modules : Dict[Tuple[Any,Rdata],Any] - -def from_text(rdclass : int, rdtype : int, tok : Optional[str], origin : Optional[Name] = None, - relativize : bool = True, relativize_to : Optional[Name] = None, - idna_codec : Optional[IDNACodec] = None): - ... - -def from_wire(rdclass : int, rdtype : int, wire : bytes, current : int, rdlen : int, origin : Optional[Name] = None): - ... diff --git a/lib/dns/rdataclass.py b/lib/dns/rdataclass.py index 41bba693..89b85a79 100644 --- a/lib/dns/rdataclass.py +++ b/lib/dns/rdataclass.py @@ -20,8 +20,10 @@ import dns.enum import dns.exception + class RdataClass(dns.enum.IntEnum): """DNS Rdata Class""" + RESERVED0 = 0 IN = 1 INTERNET = IN @@ -56,7 +58,7 @@ class UnknownRdataclass(dns.exception.DNSException): """A DNS class is unknown.""" -def from_text(text): +def from_text(text: str) -> RdataClass: """Convert text into a DNS rdata class value. The input text can be a defined DNS RR class mnemonic or @@ -68,13 +70,13 @@ def from_text(text): Raises ``ValueError`` if the rdata class value is not >= 0 and <= 65535. - Returns an ``int``. + Returns a ``dns.rdataclass.RdataClass``. """ return RdataClass.from_text(text) -def to_text(value): +def to_text(value: RdataClass) -> str: """Convert a DNS rdata class value to text. If the value has a known mnemonic, it will be used, otherwise the @@ -88,18 +90,19 @@ def to_text(value): return RdataClass.to_text(value) -def is_metaclass(rdclass): +def is_metaclass(rdclass: RdataClass) -> bool: """True if the specified class is a metaclass. The currently defined metaclasses are ANY and NONE. - *rdclass* is an ``int``. + *rdclass* is a ``dns.rdataclass.RdataClass``. """ if rdclass in _metaclasses: return True return False + ### BEGIN generated RdataClass constants RESERVED0 = RdataClass.RESERVED0 diff --git a/lib/dns/rdataset.py b/lib/dns/rdataset.py index 579bc964..c0ede425 100644 --- a/lib/dns/rdataset.py +++ b/lib/dns/rdataset.py @@ -17,16 +17,20 @@ """DNS rdatasets (an rdataset is a set of rdatas of a given type and class)""" +from typing import Any, cast, Collection, Dict, List, Optional, Union + import io import random import struct import dns.exception import dns.immutable +import dns.name import dns.rdatatype import dns.rdataclass import dns.rdata import dns.set +import dns.ttl # define SimpleSet here for backwards compatibility SimpleSet = dns.set.Set @@ -45,24 +49,30 @@ class Rdataset(dns.set.Set): """A DNS rdataset.""" - __slots__ = ['rdclass', 'rdtype', 'covers', 'ttl'] + __slots__ = ["rdclass", "rdtype", "covers", "ttl"] - def __init__(self, rdclass, rdtype, covers=dns.rdatatype.NONE, ttl=0): + def __init__( + self, + rdclass: dns.rdataclass.RdataClass, + rdtype: dns.rdatatype.RdataType, + covers: dns.rdatatype.RdataType = dns.rdatatype.NONE, + ttl: int = 0, + ): """Create a new rdataset of the specified class and type. - *rdclass*, an ``int``, the rdataclass. + *rdclass*, a ``dns.rdataclass.RdataClass``, the rdataclass. - *rdtype*, an ``int``, the rdatatype. + *rdtype*, an ``dns.rdatatype.RdataType``, the rdatatype. - *covers*, an ``int``, the covered rdatatype. + *covers*, an ``dns.rdatatype.RdataType``, the covered rdatatype. *ttl*, an ``int``, the TTL. """ super().__init__() self.rdclass = rdclass - self.rdtype = rdtype - self.covers = covers + self.rdtype: dns.rdatatype.RdataType = rdtype + self.covers: dns.rdatatype.RdataType = covers self.ttl = ttl def _clone(self): @@ -73,7 +83,7 @@ class Rdataset(dns.set.Set): obj.ttl = self.ttl return obj - def update_ttl(self, ttl): + def update_ttl(self, ttl: int) -> None: """Perform TTL minimization. Set the TTL of the rdataset to be the lesser of the set's current @@ -88,7 +98,9 @@ class Rdataset(dns.set.Set): elif ttl < self.ttl: self.ttl = ttl - def add(self, rd, ttl=None): # pylint: disable=arguments-differ + def add( # pylint: disable=arguments-differ,arguments-renamed + self, rd: dns.rdata.Rdata, ttl: Optional[int] = None + ) -> None: """Add the specified rdata to the rdataset. If the optional *ttl* parameter is supplied, then @@ -115,8 +127,7 @@ class Rdataset(dns.set.Set): raise IncompatibleTypes if ttl is not None: self.update_ttl(ttl) - if self.rdtype == dns.rdatatype.RRSIG or \ - self.rdtype == dns.rdatatype.SIG: + if self.rdtype == dns.rdatatype.RRSIG or self.rdtype == dns.rdatatype.SIG: covers = rd.covers() if len(self) == 0 and self.covers == dns.rdatatype.NONE: self.covers = covers @@ -147,19 +158,26 @@ class Rdataset(dns.set.Set): def _rdata_repr(self): def maybe_truncate(s): if len(s) > 100: - return s[:100] + '...' + return s[:100] + "..." return s - return '[%s]' % ', '.join('<%s>' % maybe_truncate(str(rr)) - for rr in self) + + return "[%s]" % ", ".join("<%s>" % maybe_truncate(str(rr)) for rr in self) def __repr__(self): if self.covers == 0: - ctext = '' + ctext = "" else: - ctext = '(' + dns.rdatatype.to_text(self.covers) + ')' - return '' + ctext = "(" + dns.rdatatype.to_text(self.covers) + ")" + return ( + "" + ) def __str__(self): return self.to_text() @@ -167,17 +185,26 @@ class Rdataset(dns.set.Set): def __eq__(self, other): if not isinstance(other, Rdataset): return False - if self.rdclass != other.rdclass or \ - self.rdtype != other.rdtype or \ - self.covers != other.covers: + if ( + self.rdclass != other.rdclass + or self.rdtype != other.rdtype + or self.covers != other.covers + ): return False return super().__eq__(other) def __ne__(self, other): return not self.__eq__(other) - def to_text(self, name=None, origin=None, relativize=True, - override_rdclass=None, want_comments=False, **kw): + def to_text( + self, + name: Optional[dns.name.Name] = None, + origin: Optional[dns.name.Name] = None, + relativize: bool = True, + override_rdclass: Optional[dns.rdataclass.RdataClass] = None, + want_comments: bool = False, + **kw: Dict[str, Any], + ) -> str: """Convert the rdataset into DNS zone file format. See ``dns.name.Name.choose_relativity`` for more information @@ -206,10 +233,10 @@ class Rdataset(dns.set.Set): if name is not None: name = name.choose_relativity(origin, relativize) ntext = str(name) - pad = ' ' + pad = " " else: - ntext = '' - pad = '' + ntext = "" + pad = "" s = io.StringIO() if override_rdclass is not None: rdclass = override_rdclass @@ -221,28 +248,46 @@ class Rdataset(dns.set.Set): # some dynamic updates, so we don't need to print out the TTL # (which is meaningless anyway). # - s.write('{}{}{} {}\n'.format(ntext, pad, - dns.rdataclass.to_text(rdclass), - dns.rdatatype.to_text(self.rdtype))) + s.write( + "{}{}{} {}\n".format( + ntext, + pad, + dns.rdataclass.to_text(rdclass), + dns.rdatatype.to_text(self.rdtype), + ) + ) else: for rd in self: - extra = '' + extra = "" if want_comments: if rd.rdcomment: - extra = f' ;{rd.rdcomment}' - s.write('%s%s%d %s %s %s%s\n' % - (ntext, pad, self.ttl, dns.rdataclass.to_text(rdclass), - dns.rdatatype.to_text(self.rdtype), - rd.to_text(origin=origin, relativize=relativize, - **kw), - extra)) + extra = f" ;{rd.rdcomment}" + s.write( + "%s%s%d %s %s %s%s\n" + % ( + ntext, + pad, + self.ttl, + dns.rdataclass.to_text(rdclass), + dns.rdatatype.to_text(self.rdtype), + rd.to_text(origin=origin, relativize=relativize, **kw), + extra, + ) + ) # # We strip off the final \n for the caller's convenience in printing # return s.getvalue()[:-1] - def to_wire(self, name, file, compress=None, origin=None, - override_rdclass=None, want_shuffle=True): + def to_wire( + self, + name: dns.name.Name, + file: Any, + compress: Optional[dns.name.CompressType] = None, + origin: Optional[dns.name.Name] = None, + override_rdclass: Optional[dns.rdataclass.RdataClass] = None, + want_shuffle: bool = True, + ) -> int: """Convert the rdataset to wire format. *name*, a ``dns.name.Name`` is the owner name to use. @@ -279,6 +324,7 @@ class Rdataset(dns.set.Set): file.write(stuff) return 1 else: + l: Union[Rdataset, List[dns.rdata.Rdata]] if want_shuffle: l = list(self) random.shuffle(l) @@ -286,8 +332,7 @@ class Rdataset(dns.set.Set): l = self for rd in l: name.to_wire(file, compress, origin) - stuff = struct.pack("!HHIH", self.rdtype, rdclass, - self.ttl, 0) + stuff = struct.pack("!HHIH", self.rdtype, rdclass, self.ttl, 0) file.write(stuff) start = file.tell() rd.to_wire(file, compress, origin) @@ -299,17 +344,20 @@ class Rdataset(dns.set.Set): file.seek(0, io.SEEK_END) return len(self) - def match(self, rdclass, rdtype, covers): + def match( + self, + rdclass: dns.rdataclass.RdataClass, + rdtype: dns.rdatatype.RdataType, + covers: dns.rdatatype.RdataType, + ) -> bool: """Returns ``True`` if this rdataset matches the specified class, type, and covers. """ - if self.rdclass == rdclass and \ - self.rdtype == rdtype and \ - self.covers == covers: + if self.rdclass == rdclass and self.rdtype == rdtype and self.covers == covers: return True return False - def processing_order(self): + def processing_order(self) -> List[dns.rdata.Rdata]: """Return rdatas in a valid processing order according to the type's specification. For example, MX records are in preference order from lowest to highest preferences, with items of the same preference @@ -325,51 +373,56 @@ class Rdataset(dns.set.Set): @dns.immutable.immutable -class ImmutableRdataset(Rdataset): +class ImmutableRdataset(Rdataset): # lgtm[py/missing-equals] """An immutable DNS rdataset.""" _clone_class = Rdataset - def __init__(self, rdataset): + def __init__(self, rdataset: Rdataset): """Create an immutable rdataset from the specified rdataset.""" - super().__init__(rdataset.rdclass, rdataset.rdtype, rdataset.covers, - rdataset.ttl) + super().__init__( + rdataset.rdclass, rdataset.rdtype, rdataset.covers, rdataset.ttl + ) self.items = dns.immutable.Dict(rdataset.items) def update_ttl(self, ttl): - raise TypeError('immutable') + raise TypeError("immutable") def add(self, rd, ttl=None): - raise TypeError('immutable') + raise TypeError("immutable") def union_update(self, other): - raise TypeError('immutable') + raise TypeError("immutable") def intersection_update(self, other): - raise TypeError('immutable') + raise TypeError("immutable") def update(self, other): - raise TypeError('immutable') + raise TypeError("immutable") def __delitem__(self, i): - raise TypeError('immutable') + raise TypeError("immutable") - def __ior__(self, other): - raise TypeError('immutable') + # lgtm complains about these not raising ArithmeticError, but there is + # precedent for overrides of these methods in other classes to raise + # TypeError, and it seems like the better exception. - def __iand__(self, other): - raise TypeError('immutable') + def __ior__(self, other): # lgtm[py/unexpected-raise-in-special-method] + raise TypeError("immutable") - def __iadd__(self, other): - raise TypeError('immutable') + def __iand__(self, other): # lgtm[py/unexpected-raise-in-special-method] + raise TypeError("immutable") - def __isub__(self, other): - raise TypeError('immutable') + def __iadd__(self, other): # lgtm[py/unexpected-raise-in-special-method] + raise TypeError("immutable") + + def __isub__(self, other): # lgtm[py/unexpected-raise-in-special-method] + raise TypeError("immutable") def clear(self): - raise TypeError('immutable') + raise TypeError("immutable") def __copy__(self): return ImmutableRdataset(super().copy()) @@ -386,9 +439,20 @@ class ImmutableRdataset(Rdataset): def difference(self, other): return ImmutableRdataset(super().difference(other)) + def symmetric_difference(self, other): + return ImmutableRdataset(super().symmetric_difference(other)) -def from_text_list(rdclass, rdtype, ttl, text_rdatas, idna_codec=None, - origin=None, relativize=True, relativize_to=None): + +def from_text_list( + rdclass: Union[dns.rdataclass.RdataClass, str], + rdtype: Union[dns.rdatatype.RdataType, str], + ttl: int, + text_rdatas: Collection[str], + idna_codec: Optional[dns.name.IDNACodec] = None, + origin: Optional[dns.name.Name] = None, + relativize: bool = True, + relativize_to: Optional[dns.name.Name] = None, +) -> Rdataset: """Create an rdataset with the specified class, type, and TTL, and with the specified list of rdatas in text format. @@ -407,28 +471,34 @@ def from_text_list(rdclass, rdtype, ttl, text_rdatas, idna_codec=None, Returns a ``dns.rdataset.Rdataset`` object. """ - rdclass = dns.rdataclass.RdataClass.make(rdclass) - rdtype = dns.rdatatype.RdataType.make(rdtype) - r = Rdataset(rdclass, rdtype) + the_rdclass = dns.rdataclass.RdataClass.make(rdclass) + the_rdtype = dns.rdatatype.RdataType.make(rdtype) + r = Rdataset(the_rdclass, the_rdtype) r.update_ttl(ttl) for t in text_rdatas: - rd = dns.rdata.from_text(r.rdclass, r.rdtype, t, origin, relativize, - relativize_to, idna_codec) + rd = dns.rdata.from_text( + r.rdclass, r.rdtype, t, origin, relativize, relativize_to, idna_codec + ) r.add(rd) return r -def from_text(rdclass, rdtype, ttl, *text_rdatas): +def from_text( + rdclass: Union[dns.rdataclass.RdataClass, str], + rdtype: Union[dns.rdatatype.RdataType, str], + ttl: int, + *text_rdatas: Any, +) -> Rdataset: """Create an rdataset with the specified class, type, and TTL, and with the specified rdatas in text format. Returns a ``dns.rdataset.Rdataset`` object. """ - return from_text_list(rdclass, rdtype, ttl, text_rdatas) + return from_text_list(rdclass, rdtype, ttl, cast(Collection[str], text_rdatas)) -def from_rdata_list(ttl, rdatas): +def from_rdata_list(ttl: int, rdatas: Collection[dns.rdata.Rdata]) -> Rdataset: """Create an rdataset with the specified TTL, and with the specified list of rdata objects. @@ -443,14 +513,15 @@ def from_rdata_list(ttl, rdatas): r = Rdataset(rd.rdclass, rd.rdtype) r.update_ttl(ttl) r.add(rd) + assert r is not None return r -def from_rdata(ttl, *rdatas): +def from_rdata(ttl: int, *rdatas: Any) -> Rdataset: """Create an rdataset with the specified TTL, and with the specified rdata objects. Returns a ``dns.rdataset.Rdataset`` object. """ - return from_rdata_list(ttl, rdatas) + return from_rdata_list(ttl, cast(Collection[dns.rdata.Rdata], rdatas)) diff --git a/lib/dns/rdataset.pyi b/lib/dns/rdataset.pyi deleted file mode 100644 index a7bbf2d4..00000000 --- a/lib/dns/rdataset.pyi +++ /dev/null @@ -1,58 +0,0 @@ -from typing import Optional, Dict, List, Union -from io import BytesIO -from . import exception, name, set, rdatatype, rdata, rdataset - -class DifferingCovers(exception.DNSException): - """An attempt was made to add a DNS SIG/RRSIG whose covered type - is not the same as that of the other rdatas in the rdataset.""" - - -class IncompatibleTypes(exception.DNSException): - """An attempt was made to add DNS RR data of an incompatible type.""" - - -class Rdataset(set.Set): - def __init__(self, rdclass, rdtype, covers=rdatatype.NONE, ttl=0): - self.rdclass : int = rdclass - self.rdtype : int = rdtype - self.covers : int = covers - self.ttl : int = ttl - - def update_ttl(self, ttl : int) -> None: - ... - - def add(self, rd : rdata.Rdata, ttl : Optional[int] =None): - ... - - def union_update(self, other : Rdataset): - ... - - def intersection_update(self, other : Rdataset): - ... - - def update(self, other : Rdataset): - ... - - def to_text(self, name : Optional[name.Name] =None, origin : Optional[name.Name] =None, relativize=True, - override_rdclass : Optional[int] =None, **kw) -> bytes: - ... - - def to_wire(self, name : Optional[name.Name], file : BytesIO, compress : Optional[Dict[name.Name, int]] = None, origin : Optional[name.Name] = None, - override_rdclass : Optional[int] = None, want_shuffle=True) -> int: - ... - - def match(self, rdclass : int, rdtype : int, covers : int) -> bool: - ... - - -def from_text_list(rdclass : Union[int,str], rdtype : Union[int,str], ttl : int, text_rdatas : str, idna_codec : Optional[name.IDNACodec] = None) -> rdataset.Rdataset: - ... - -def from_text(rdclass : Union[int,str], rdtype : Union[int,str], ttl : int, *text_rdatas : str) -> rdataset.Rdataset: - ... - -def from_rdata_list(ttl : int, rdatas : List[rdata.Rdata]) -> rdataset.Rdataset: - ... - -def from_rdata(ttl : int, *rdatas : List[rdata.Rdata]) -> rdataset.Rdataset: - ... diff --git a/lib/dns/rdatatype.py b/lib/dns/rdatatype.py index 9499c7b9..e6c58186 100644 --- a/lib/dns/rdatatype.py +++ b/lib/dns/rdatatype.py @@ -17,11 +17,15 @@ """DNS Rdata Types.""" +from typing import Dict + import dns.enum import dns.exception + class RdataType(dns.enum.IntEnum): """DNS Rdata Type""" + TYPE0 = 0 NONE = 0 A = 1 @@ -116,24 +120,47 @@ class RdataType(dns.enum.IntEnum): def _prefix(cls): return "TYPE" + @classmethod + def _extra_from_text(cls, text): + if text.find("-") >= 0: + try: + return cls[text.replace("-", "_")] + except KeyError: + pass + return _registered_by_text.get(text) + + @classmethod + def _extra_to_text(cls, value, current_text): + if current_text is None: + return _registered_by_value.get(value) + if current_text.find("_") >= 0: + return current_text.replace("_", "-") + return current_text + @classmethod def _unknown_exception_class(cls): return UnknownRdatatype -_registered_by_text = {} -_registered_by_value = {} + +_registered_by_text: Dict[str, RdataType] = {} +_registered_by_value: Dict[RdataType, str] = {} _metatypes = {RdataType.OPT} -_singletons = {RdataType.SOA, RdataType.NXT, RdataType.DNAME, - RdataType.NSEC, RdataType.CNAME} +_singletons = { + RdataType.SOA, + RdataType.NXT, + RdataType.DNAME, + RdataType.NSEC, + RdataType.CNAME, +} class UnknownRdatatype(dns.exception.DNSException): """DNS resource record type is unknown.""" -def from_text(text): +def from_text(text: str) -> RdataType: """Convert text into a DNS rdata type value. The input text can be a defined DNS RR type mnemonic or @@ -145,20 +172,13 @@ def from_text(text): Raises ``ValueError`` if the rdata type value is not >= 0 and <= 65535. - Returns an ``int``. + Returns a ``dns.rdatatype.RdataType``. """ - text = text.upper().replace('-', '_') - try: - return RdataType.from_text(text) - except UnknownRdatatype: - registered_type = _registered_by_text.get(text) - if registered_type: - return registered_type - raise + return RdataType.from_text(text) -def to_text(value): +def to_text(value: RdataType) -> str: """Convert a DNS rdata type value to text. If the value has a known mnemonic, it will be used, otherwise the @@ -169,18 +189,13 @@ def to_text(value): Returns a ``str``. """ - text = RdataType.to_text(value) - if text.startswith("TYPE"): - registered_text = _registered_by_value.get(value) - if registered_text: - text = registered_text - return text.replace('_', '-') + return RdataType.to_text(value) -def is_metatype(rdtype): +def is_metatype(rdtype: RdataType) -> bool: """True if the specified type is a metatype. - *rdtype* is an ``int``. + *rdtype* is a ``dns.rdatatype.RdataType``. The currently defined metatypes are TKEY, TSIG, IXFR, AXFR, MAILA, MAILB, ANY, and OPT. @@ -191,7 +206,7 @@ def is_metatype(rdtype): return (256 > rdtype >= 128) or rdtype in _metatypes -def is_singleton(rdtype): +def is_singleton(rdtype: RdataType) -> bool: """Is the specified type a singleton type? Singleton types can only have a single rdata in an rdataset, or a single @@ -209,11 +224,14 @@ def is_singleton(rdtype): return True return False + # pylint: disable=redefined-outer-name -def register_type(rdtype, rdtype_text, is_singleton=False): +def register_type( + rdtype: RdataType, rdtype_text: str, is_singleton: bool = False +) -> None: """Dynamically register an rdatatype. - *rdtype*, an ``int``, the rdatatype to register. + *rdtype*, a ``dns.rdatatype.RdataType``, the rdatatype to register. *rdtype_text*, a ``str``, the textual form of the rdatatype. @@ -226,6 +244,7 @@ def register_type(rdtype, rdtype_text, is_singleton=False): if is_singleton: _singletons.add(rdtype) + ### BEGIN generated RdataType constants TYPE0 = RdataType.TYPE0 diff --git a/lib/dns/rdtypes/ANY/AMTRELAY.py b/lib/dns/rdtypes/ANY/AMTRELAY.py index 9f093dee..dfe7abc3 100644 --- a/lib/dns/rdtypes/ANY/AMTRELAY.py +++ b/lib/dns/rdtypes/ANY/AMTRELAY.py @@ -23,7 +23,7 @@ import dns.rdtypes.util class Relay(dns.rdtypes.util.Gateway): - name = 'AMTRELAY relay' + name = "AMTRELAY relay" @property def relay(self): @@ -37,10 +37,11 @@ class AMTRELAY(dns.rdata.Rdata): # see: RFC 8777 - __slots__ = ['precedence', 'discovery_optional', 'relay_type', 'relay'] + __slots__ = ["precedence", "discovery_optional", "relay_type", "relay"] - def __init__(self, rdclass, rdtype, precedence, discovery_optional, - relay_type, relay): + def __init__( + self, rdclass, rdtype, precedence, discovery_optional, relay_type, relay + ): super().__init__(rdclass, rdtype) relay = Relay(relay_type, relay) self.precedence = self._as_uint8(precedence) @@ -50,37 +51,42 @@ class AMTRELAY(dns.rdata.Rdata): def to_text(self, origin=None, relativize=True, **kw): relay = Relay(self.relay_type, self.relay).to_text(origin, relativize) - return '%d %d %d %s' % (self.precedence, self.discovery_optional, - self.relay_type, relay) + return "%d %d %d %s" % ( + self.precedence, + self.discovery_optional, + self.relay_type, + relay, + ) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): precedence = tok.get_uint8() discovery_optional = tok.get_uint8() if discovery_optional > 1: - raise dns.exception.SyntaxError('expecting 0 or 1') + raise dns.exception.SyntaxError("expecting 0 or 1") discovery_optional = bool(discovery_optional) relay_type = tok.get_uint8() - if relay_type > 0x7f: - raise dns.exception.SyntaxError('expecting an integer <= 127') - relay = Relay.from_text(relay_type, tok, origin, relativize, - relativize_to) - return cls(rdclass, rdtype, precedence, discovery_optional, relay_type, - relay.relay) + if relay_type > 0x7F: + raise dns.exception.SyntaxError("expecting an integer <= 127") + relay = Relay.from_text(relay_type, tok, origin, relativize, relativize_to) + return cls( + rdclass, rdtype, precedence, discovery_optional, relay_type, relay.relay + ) def _to_wire(self, file, compress=None, origin=None, canonicalize=False): relay_type = self.relay_type | (self.discovery_optional << 7) header = struct.pack("!BB", self.precedence, relay_type) file.write(header) - Relay(self.relay_type, self.relay).to_wire(file, compress, origin, - canonicalize) + Relay(self.relay_type, self.relay).to_wire(file, compress, origin, canonicalize) @classmethod def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): - (precedence, relay_type) = parser.get_struct('!BB') + (precedence, relay_type) = parser.get_struct("!BB") discovery_optional = bool(relay_type >> 7) - relay_type &= 0x7f + relay_type &= 0x7F relay = Relay.from_wire_parser(relay_type, parser, origin) - return cls(rdclass, rdtype, precedence, discovery_optional, relay_type, - relay.relay) + return cls( + rdclass, rdtype, precedence, discovery_optional, relay_type, relay.relay + ) diff --git a/lib/dns/rdtypes/ANY/CAA.py b/lib/dns/rdtypes/ANY/CAA.py index c86b45ea..8afb538c 100644 --- a/lib/dns/rdtypes/ANY/CAA.py +++ b/lib/dns/rdtypes/ANY/CAA.py @@ -30,7 +30,7 @@ class CAA(dns.rdata.Rdata): # see: RFC 6844 - __slots__ = ['flags', 'tag', 'value'] + __slots__ = ["flags", "tag", "value"] def __init__(self, rdclass, rdtype, flags, tag, value): super().__init__(rdclass, rdtype) @@ -41,23 +41,26 @@ class CAA(dns.rdata.Rdata): self.value = self._as_bytes(value) def to_text(self, origin=None, relativize=True, **kw): - return '%u %s "%s"' % (self.flags, - dns.rdata._escapify(self.tag), - dns.rdata._escapify(self.value)) + return '%u %s "%s"' % ( + self.flags, + dns.rdata._escapify(self.tag), + dns.rdata._escapify(self.value), + ) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): flags = tok.get_uint8() tag = tok.get_string().encode() value = tok.get_string().encode() return cls(rdclass, rdtype, flags, tag, value) def _to_wire(self, file, compress=None, origin=None, canonicalize=False): - file.write(struct.pack('!B', self.flags)) + file.write(struct.pack("!B", self.flags)) l = len(self.tag) assert l < 256 - file.write(struct.pack('!B', l)) + file.write(struct.pack("!B", l)) file.write(self.tag) file.write(self.value) diff --git a/lib/dns/rdtypes/ANY/CDNSKEY.py b/lib/dns/rdtypes/ANY/CDNSKEY.py index 14b19417..869523fb 100644 --- a/lib/dns/rdtypes/ANY/CDNSKEY.py +++ b/lib/dns/rdtypes/ANY/CDNSKEY.py @@ -15,13 +15,19 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. -import dns.rdtypes.dnskeybase +import dns.rdtypes.dnskeybase # lgtm[py/import-and-import-from] import dns.immutable # pylint: disable=unused-import -from dns.rdtypes.dnskeybase import SEP, REVOKE, ZONE # noqa: F401 +from dns.rdtypes.dnskeybase import ( + SEP, + REVOKE, + ZONE, +) # noqa: F401 lgtm[py/unused-import] + # pylint: enable=unused-import + @dns.immutable.immutable class CDNSKEY(dns.rdtypes.dnskeybase.DNSKEYBase): diff --git a/lib/dns/rdtypes/ANY/CERT.py b/lib/dns/rdtypes/ANY/CERT.py index f35ce3ad..1b0cbeca 100644 --- a/lib/dns/rdtypes/ANY/CERT.py +++ b/lib/dns/rdtypes/ANY/CERT.py @@ -20,34 +20,34 @@ import base64 import dns.exception import dns.immutable -import dns.dnssec +import dns.dnssectypes import dns.rdata import dns.tokenizer _ctype_by_value = { - 1: 'PKIX', - 2: 'SPKI', - 3: 'PGP', - 4: 'IPKIX', - 5: 'ISPKI', - 6: 'IPGP', - 7: 'ACPKIX', - 8: 'IACPKIX', - 253: 'URI', - 254: 'OID', + 1: "PKIX", + 2: "SPKI", + 3: "PGP", + 4: "IPKIX", + 5: "ISPKI", + 6: "IPGP", + 7: "ACPKIX", + 8: "IACPKIX", + 253: "URI", + 254: "OID", } _ctype_by_name = { - 'PKIX': 1, - 'SPKI': 2, - 'PGP': 3, - 'IPKIX': 4, - 'ISPKI': 5, - 'IPGP': 6, - 'ACPKIX': 7, - 'IACPKIX': 8, - 'URI': 253, - 'OID': 254, + "PKIX": 1, + "SPKI": 2, + "PGP": 3, + "IPKIX": 4, + "ISPKI": 5, + "IPGP": 6, + "ACPKIX": 7, + "IACPKIX": 8, + "URI": 253, + "OID": 254, } @@ -72,10 +72,11 @@ class CERT(dns.rdata.Rdata): # see RFC 4398 - __slots__ = ['certificate_type', 'key_tag', 'algorithm', 'certificate'] + __slots__ = ["certificate_type", "key_tag", "algorithm", "certificate"] - def __init__(self, rdclass, rdtype, certificate_type, key_tag, algorithm, - certificate): + def __init__( + self, rdclass, rdtype, certificate_type, key_tag, algorithm, certificate + ): super().__init__(rdclass, rdtype) self.certificate_type = self._as_uint16(certificate_type) self.key_tag = self._as_uint16(key_tag) @@ -84,24 +85,28 @@ class CERT(dns.rdata.Rdata): def to_text(self, origin=None, relativize=True, **kw): certificate_type = _ctype_to_text(self.certificate_type) - return "%s %d %s %s" % (certificate_type, self.key_tag, - dns.dnssec.algorithm_to_text(self.algorithm), - dns.rdata._base64ify(self.certificate, **kw)) + return "%s %d %s %s" % ( + certificate_type, + self.key_tag, + dns.dnssectypes.Algorithm.to_text(self.algorithm), + dns.rdata._base64ify(self.certificate, **kw), + ) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): certificate_type = _ctype_from_text(tok.get_string()) key_tag = tok.get_uint16() - algorithm = dns.dnssec.algorithm_from_text(tok.get_string()) + algorithm = dns.dnssectypes.Algorithm.from_text(tok.get_string()) b64 = tok.concatenate_remaining_identifiers().encode() certificate = base64.b64decode(b64) - return cls(rdclass, rdtype, certificate_type, key_tag, - algorithm, certificate) + return cls(rdclass, rdtype, certificate_type, key_tag, algorithm, certificate) def _to_wire(self, file, compress=None, origin=None, canonicalize=False): - prefix = struct.pack("!HHB", self.certificate_type, self.key_tag, - self.algorithm) + prefix = struct.pack( + "!HHB", self.certificate_type, self.key_tag, self.algorithm + ) file.write(prefix) file.write(self.certificate) @@ -109,5 +114,4 @@ class CERT(dns.rdata.Rdata): def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): (certificate_type, key_tag, algorithm) = parser.get_struct("!HHB") certificate = parser.get_remaining() - return cls(rdclass, rdtype, certificate_type, key_tag, algorithm, - certificate) + return cls(rdclass, rdtype, certificate_type, key_tag, algorithm, certificate) diff --git a/lib/dns/rdtypes/ANY/CSYNC.py b/lib/dns/rdtypes/ANY/CSYNC.py index 979028ae..f819c08c 100644 --- a/lib/dns/rdtypes/ANY/CSYNC.py +++ b/lib/dns/rdtypes/ANY/CSYNC.py @@ -27,7 +27,7 @@ import dns.rdtypes.util @dns.immutable.immutable class Bitmap(dns.rdtypes.util.Bitmap): - type_name = 'CSYNC' + type_name = "CSYNC" @dns.immutable.immutable @@ -35,7 +35,7 @@ class CSYNC(dns.rdata.Rdata): """CSYNC record""" - __slots__ = ['serial', 'flags', 'windows'] + __slots__ = ["serial", "flags", "windows"] def __init__(self, rdclass, rdtype, serial, flags, windows): super().__init__(rdclass, rdtype) @@ -47,18 +47,19 @@ class CSYNC(dns.rdata.Rdata): def to_text(self, origin=None, relativize=True, **kw): text = Bitmap(self.windows).to_text() - return '%d %d%s' % (self.serial, self.flags, text) + return "%d %d%s" % (self.serial, self.flags, text) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): serial = tok.get_uint32() flags = tok.get_uint16() bitmap = Bitmap.from_text(tok) return cls(rdclass, rdtype, serial, flags, bitmap) def _to_wire(self, file, compress=None, origin=None, canonicalize=False): - file.write(struct.pack('!IH', self.serial, self.flags)) + file.write(struct.pack("!IH", self.serial, self.flags)) Bitmap(self.windows).to_wire(file) @classmethod diff --git a/lib/dns/rdtypes/ANY/DNSKEY.py b/lib/dns/rdtypes/ANY/DNSKEY.py index e69a7c19..50fa05b7 100644 --- a/lib/dns/rdtypes/ANY/DNSKEY.py +++ b/lib/dns/rdtypes/ANY/DNSKEY.py @@ -15,13 +15,19 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. -import dns.rdtypes.dnskeybase +import dns.rdtypes.dnskeybase # lgtm[py/import-and-import-from] import dns.immutable # pylint: disable=unused-import -from dns.rdtypes.dnskeybase import SEP, REVOKE, ZONE # noqa: F401 +from dns.rdtypes.dnskeybase import ( + SEP, + REVOKE, + ZONE, +) # noqa: F401 lgtm[py/unused-import] + # pylint: enable=unused-import + @dns.immutable.immutable class DNSKEY(dns.rdtypes.dnskeybase.DNSKEYBase): diff --git a/lib/dns/rdtypes/ANY/GPOS.py b/lib/dns/rdtypes/ANY/GPOS.py index 29fa8f8b..30aab321 100644 --- a/lib/dns/rdtypes/ANY/GPOS.py +++ b/lib/dns/rdtypes/ANY/GPOS.py @@ -26,19 +26,19 @@ import dns.tokenizer def _validate_float_string(what): if len(what) == 0: raise dns.exception.FormError - if what[0] == b'-'[0] or what[0] == b'+'[0]: + if what[0] == b"-"[0] or what[0] == b"+"[0]: what = what[1:] if what.isdigit(): return try: - (left, right) = what.split(b'.') + (left, right) = what.split(b".") except ValueError: raise dns.exception.FormError - if left == b'' and right == b'': + if left == b"" and right == b"": raise dns.exception.FormError - if not left == b'' and not left.decode().isdigit(): + if not left == b"" and not left.decode().isdigit(): raise dns.exception.FormError - if not right == b'' and not right.decode().isdigit(): + if not right == b"" and not right.decode().isdigit(): raise dns.exception.FormError @@ -49,18 +49,15 @@ class GPOS(dns.rdata.Rdata): # see: RFC 1712 - __slots__ = ['latitude', 'longitude', 'altitude'] + __slots__ = ["latitude", "longitude", "altitude"] def __init__(self, rdclass, rdtype, latitude, longitude, altitude): super().__init__(rdclass, rdtype) - if isinstance(latitude, float) or \ - isinstance(latitude, int): + if isinstance(latitude, float) or isinstance(latitude, int): latitude = str(latitude) - if isinstance(longitude, float) or \ - isinstance(longitude, int): + if isinstance(longitude, float) or isinstance(longitude, int): longitude = str(longitude) - if isinstance(altitude, float) or \ - isinstance(altitude, int): + if isinstance(altitude, float) or isinstance(altitude, int): altitude = str(altitude) latitude = self._as_bytes(latitude, True, 255) longitude = self._as_bytes(longitude, True, 255) @@ -73,19 +70,20 @@ class GPOS(dns.rdata.Rdata): self.altitude = altitude flat = self.float_latitude if flat < -90.0 or flat > 90.0: - raise dns.exception.FormError('bad latitude') + raise dns.exception.FormError("bad latitude") flong = self.float_longitude if flong < -180.0 or flong > 180.0: - raise dns.exception.FormError('bad longitude') + raise dns.exception.FormError("bad longitude") def to_text(self, origin=None, relativize=True, **kw): - return '{} {} {}'.format(self.latitude.decode(), - self.longitude.decode(), - self.altitude.decode()) + return "{} {} {}".format( + self.latitude.decode(), self.longitude.decode(), self.altitude.decode() + ) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): latitude = tok.get_string() longitude = tok.get_string() altitude = tok.get_string() @@ -94,15 +92,15 @@ class GPOS(dns.rdata.Rdata): def _to_wire(self, file, compress=None, origin=None, canonicalize=False): l = len(self.latitude) assert l < 256 - file.write(struct.pack('!B', l)) + file.write(struct.pack("!B", l)) file.write(self.latitude) l = len(self.longitude) assert l < 256 - file.write(struct.pack('!B', l)) + file.write(struct.pack("!B", l)) file.write(self.longitude) l = len(self.altitude) assert l < 256 - file.write(struct.pack('!B', l)) + file.write(struct.pack("!B", l)) file.write(self.altitude) @classmethod diff --git a/lib/dns/rdtypes/ANY/HINFO.py b/lib/dns/rdtypes/ANY/HINFO.py index cd049693..513c155a 100644 --- a/lib/dns/rdtypes/ANY/HINFO.py +++ b/lib/dns/rdtypes/ANY/HINFO.py @@ -30,7 +30,7 @@ class HINFO(dns.rdata.Rdata): # see: RFC 1035 - __slots__ = ['cpu', 'os'] + __slots__ = ["cpu", "os"] def __init__(self, rdclass, rdtype, cpu, os): super().__init__(rdclass, rdtype) @@ -38,12 +38,14 @@ class HINFO(dns.rdata.Rdata): self.os = self._as_bytes(os, True, 255) def to_text(self, origin=None, relativize=True, **kw): - return '"{}" "{}"'.format(dns.rdata._escapify(self.cpu), - dns.rdata._escapify(self.os)) + return '"{}" "{}"'.format( + dns.rdata._escapify(self.cpu), dns.rdata._escapify(self.os) + ) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): cpu = tok.get_string(max_length=255) os = tok.get_string(max_length=255) return cls(rdclass, rdtype, cpu, os) @@ -51,11 +53,11 @@ class HINFO(dns.rdata.Rdata): def _to_wire(self, file, compress=None, origin=None, canonicalize=False): l = len(self.cpu) assert l < 256 - file.write(struct.pack('!B', l)) + file.write(struct.pack("!B", l)) file.write(self.cpu) l = len(self.os) assert l < 256 - file.write(struct.pack('!B', l)) + file.write(struct.pack("!B", l)) file.write(self.os) @classmethod diff --git a/lib/dns/rdtypes/ANY/HIP.py b/lib/dns/rdtypes/ANY/HIP.py index e887359b..01fec822 100644 --- a/lib/dns/rdtypes/ANY/HIP.py +++ b/lib/dns/rdtypes/ANY/HIP.py @@ -32,7 +32,7 @@ class HIP(dns.rdata.Rdata): # see: RFC 5205 - __slots__ = ['hit', 'algorithm', 'key', 'servers'] + __slots__ = ["hit", "algorithm", "key", "servers"] def __init__(self, rdclass, rdtype, hit, algorithm, key, servers): super().__init__(rdclass, rdtype) @@ -43,18 +43,19 @@ class HIP(dns.rdata.Rdata): def to_text(self, origin=None, relativize=True, **kw): hit = binascii.hexlify(self.hit).decode() - key = base64.b64encode(self.key).replace(b'\n', b'').decode() - text = '' + key = base64.b64encode(self.key).replace(b"\n", b"").decode() + text = "" servers = [] for server in self.servers: servers.append(server.choose_relativity(origin, relativize)) if len(servers) > 0: - text += (' ' + ' '.join((x.to_unicode() for x in servers))) - return '%u %s %s%s' % (self.algorithm, hit, key, text) + text += " " + " ".join((x.to_unicode() for x in servers)) + return "%u %s %s%s" % (self.algorithm, hit, key, text) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): algorithm = tok.get_uint8() hit = binascii.unhexlify(tok.get_string().encode()) key = base64.b64decode(tok.get_string().encode()) @@ -75,7 +76,7 @@ class HIP(dns.rdata.Rdata): @classmethod def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): - (lh, algorithm, lk) = parser.get_struct('!BBH') + (lh, algorithm, lk) = parser.get_struct("!BBH") hit = parser.get_bytes(lh) key = parser.get_bytes(lk) servers = [] diff --git a/lib/dns/rdtypes/ANY/ISDN.py b/lib/dns/rdtypes/ANY/ISDN.py index b9a49adb..536a35d6 100644 --- a/lib/dns/rdtypes/ANY/ISDN.py +++ b/lib/dns/rdtypes/ANY/ISDN.py @@ -30,7 +30,7 @@ class ISDN(dns.rdata.Rdata): # see: RFC 1183 - __slots__ = ['address', 'subaddress'] + __slots__ = ["address", "subaddress"] def __init__(self, rdclass, rdtype, address, subaddress): super().__init__(rdclass, rdtype) @@ -39,31 +39,33 @@ class ISDN(dns.rdata.Rdata): def to_text(self, origin=None, relativize=True, **kw): if self.subaddress: - return '"{}" "{}"'.format(dns.rdata._escapify(self.address), - dns.rdata._escapify(self.subaddress)) + return '"{}" "{}"'.format( + dns.rdata._escapify(self.address), dns.rdata._escapify(self.subaddress) + ) else: return '"%s"' % dns.rdata._escapify(self.address) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): address = tok.get_string() tokens = tok.get_remaining(max_tokens=1) if len(tokens) >= 1: subaddress = tokens[0].unescape().value else: - subaddress = '' + subaddress = "" return cls(rdclass, rdtype, address, subaddress) def _to_wire(self, file, compress=None, origin=None, canonicalize=False): l = len(self.address) assert l < 256 - file.write(struct.pack('!B', l)) + file.write(struct.pack("!B", l)) file.write(self.address) l = len(self.subaddress) if l > 0: assert l < 256 - file.write(struct.pack('!B', l)) + file.write(struct.pack("!B", l)) file.write(self.subaddress) @classmethod @@ -72,5 +74,5 @@ class ISDN(dns.rdata.Rdata): if parser.remaining() > 0: subaddress = parser.get_counted_bytes() else: - subaddress = b'' + subaddress = b"" return cls(rdclass, rdtype, address, subaddress) diff --git a/lib/dns/rdtypes/ANY/L32.py b/lib/dns/rdtypes/ANY/L32.py index 47eff958..14be01f9 100644 --- a/lib/dns/rdtypes/ANY/L32.py +++ b/lib/dns/rdtypes/ANY/L32.py @@ -3,6 +3,7 @@ import struct import dns.immutable +import dns.rdata @dns.immutable.immutable @@ -12,7 +13,7 @@ class L32(dns.rdata.Rdata): # see: rfc6742.txt - __slots__ = ['preference', 'locator32'] + __slots__ = ["preference", "locator32"] def __init__(self, rdclass, rdtype, preference, locator32): super().__init__(rdclass, rdtype) @@ -20,17 +21,18 @@ class L32(dns.rdata.Rdata): self.locator32 = self._as_ipv4_address(locator32) def to_text(self, origin=None, relativize=True, **kw): - return f'{self.preference} {self.locator32}' + return f"{self.preference} {self.locator32}" @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): preference = tok.get_uint16() nodeid = tok.get_identifier() return cls(rdclass, rdtype, preference, nodeid) def _to_wire(self, file, compress=None, origin=None, canonicalize=False): - file.write(struct.pack('!H', self.preference)) + file.write(struct.pack("!H", self.preference)) file.write(dns.ipv4.inet_aton(self.locator32)) @classmethod diff --git a/lib/dns/rdtypes/ANY/L64.py b/lib/dns/rdtypes/ANY/L64.py index aab36a82..d083d403 100644 --- a/lib/dns/rdtypes/ANY/L64.py +++ b/lib/dns/rdtypes/ANY/L64.py @@ -13,33 +13,33 @@ class L64(dns.rdata.Rdata): # see: rfc6742.txt - __slots__ = ['preference', 'locator64'] + __slots__ = ["preference", "locator64"] def __init__(self, rdclass, rdtype, preference, locator64): super().__init__(rdclass, rdtype) self.preference = self._as_uint16(preference) if isinstance(locator64, bytes): if len(locator64) != 8: - raise ValueError('invalid locator64') - self.locator64 = dns.rdata._hexify(locator64, 4, b':') + raise ValueError("invalid locator64") + self.locator64 = dns.rdata._hexify(locator64, 4, b":") else: - dns.rdtypes.util.parse_formatted_hex(locator64, 4, 4, ':') + dns.rdtypes.util.parse_formatted_hex(locator64, 4, 4, ":") self.locator64 = locator64 def to_text(self, origin=None, relativize=True, **kw): - return f'{self.preference} {self.locator64}' + return f"{self.preference} {self.locator64}" @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): preference = tok.get_uint16() locator64 = tok.get_identifier() return cls(rdclass, rdtype, preference, locator64) def _to_wire(self, file, compress=None, origin=None, canonicalize=False): - file.write(struct.pack('!H', self.preference)) - file.write(dns.rdtypes.util.parse_formatted_hex(self.locator64, - 4, 4, ':')) + file.write(struct.pack("!H", self.preference)) + file.write(dns.rdtypes.util.parse_formatted_hex(self.locator64, 4, 4, ":")) @classmethod def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): diff --git a/lib/dns/rdtypes/ANY/LOC.py b/lib/dns/rdtypes/ANY/LOC.py index c9398994..52c97532 100644 --- a/lib/dns/rdtypes/ANY/LOC.py +++ b/lib/dns/rdtypes/ANY/LOC.py @@ -93,15 +93,15 @@ def _decode_size(what, desc): def _check_coordinate_list(value, low, high): if value[0] < low or value[0] > high: - raise ValueError(f'not in range [{low}, {high}]') + raise ValueError(f"not in range [{low}, {high}]") if value[1] < 0 or value[1] > 59: - raise ValueError('bad minutes value') + raise ValueError("bad minutes value") if value[2] < 0 or value[2] > 59: - raise ValueError('bad seconds value') + raise ValueError("bad seconds value") if value[3] < 0 or value[3] > 999: - raise ValueError('bad milliseconds value') + raise ValueError("bad milliseconds value") if value[4] != 1 and value[4] != -1: - raise ValueError('bad hemisphere value') + raise ValueError("bad hemisphere value") @dns.immutable.immutable @@ -111,12 +111,26 @@ class LOC(dns.rdata.Rdata): # see: RFC 1876 - __slots__ = ['latitude', 'longitude', 'altitude', 'size', - 'horizontal_precision', 'vertical_precision'] + __slots__ = [ + "latitude", + "longitude", + "altitude", + "size", + "horizontal_precision", + "vertical_precision", + ] - def __init__(self, rdclass, rdtype, latitude, longitude, altitude, - size=_default_size, hprec=_default_hprec, - vprec=_default_vprec): + def __init__( + self, + rdclass, + rdtype, + latitude, + longitude, + altitude, + size=_default_size, + hprec=_default_hprec, + vprec=_default_vprec, + ): """Initialize a LOC record instance. The parameters I{latitude} and I{longitude} may be either a 4-tuple @@ -145,34 +159,44 @@ class LOC(dns.rdata.Rdata): def to_text(self, origin=None, relativize=True, **kw): if self.latitude[4] > 0: - lat_hemisphere = 'N' + lat_hemisphere = "N" else: - lat_hemisphere = 'S' + lat_hemisphere = "S" if self.longitude[4] > 0: - long_hemisphere = 'E' + long_hemisphere = "E" else: - long_hemisphere = 'W' + long_hemisphere = "W" text = "%d %d %d.%03d %s %d %d %d.%03d %s %0.2fm" % ( - self.latitude[0], self.latitude[1], - self.latitude[2], self.latitude[3], lat_hemisphere, - self.longitude[0], self.longitude[1], self.longitude[2], - self.longitude[3], long_hemisphere, - self.altitude / 100.0 + self.latitude[0], + self.latitude[1], + self.latitude[2], + self.latitude[3], + lat_hemisphere, + self.longitude[0], + self.longitude[1], + self.longitude[2], + self.longitude[3], + long_hemisphere, + self.altitude / 100.0, ) # do not print default values - if self.size != _default_size or \ - self.horizontal_precision != _default_hprec or \ - self.vertical_precision != _default_vprec: + if ( + self.size != _default_size + or self.horizontal_precision != _default_hprec + or self.vertical_precision != _default_vprec + ): text += " {:0.2f}m {:0.2f}m {:0.2f}m".format( - self.size / 100.0, self.horizontal_precision / 100.0, - self.vertical_precision / 100.0 + self.size / 100.0, + self.horizontal_precision / 100.0, + self.vertical_precision / 100.0, ) return text @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): latitude = [0, 0, 0, 0, 1] longitude = [0, 0, 0, 0, 1] size = _default_size @@ -184,16 +208,14 @@ class LOC(dns.rdata.Rdata): if t.isdigit(): latitude[1] = int(t) t = tok.get_string() - if '.' in t: - (seconds, milliseconds) = t.split('.') + if "." in t: + (seconds, milliseconds) = t.split(".") if not seconds.isdigit(): - raise dns.exception.SyntaxError( - 'bad latitude seconds value') + raise dns.exception.SyntaxError("bad latitude seconds value") latitude[2] = int(seconds) l = len(milliseconds) if l == 0 or l > 3 or not milliseconds.isdigit(): - raise dns.exception.SyntaxError( - 'bad latitude milliseconds value') + raise dns.exception.SyntaxError("bad latitude milliseconds value") if l == 1: m = 100 elif l == 2: @@ -205,26 +227,24 @@ class LOC(dns.rdata.Rdata): elif t.isdigit(): latitude[2] = int(t) t = tok.get_string() - if t == 'S': + if t == "S": latitude[4] = -1 - elif t != 'N': - raise dns.exception.SyntaxError('bad latitude hemisphere value') + elif t != "N": + raise dns.exception.SyntaxError("bad latitude hemisphere value") longitude[0] = tok.get_int() t = tok.get_string() if t.isdigit(): longitude[1] = int(t) t = tok.get_string() - if '.' in t: - (seconds, milliseconds) = t.split('.') + if "." in t: + (seconds, milliseconds) = t.split(".") if not seconds.isdigit(): - raise dns.exception.SyntaxError( - 'bad longitude seconds value') + raise dns.exception.SyntaxError("bad longitude seconds value") longitude[2] = int(seconds) l = len(milliseconds) if l == 0 or l > 3 or not milliseconds.isdigit(): - raise dns.exception.SyntaxError( - 'bad longitude milliseconds value') + raise dns.exception.SyntaxError("bad longitude milliseconds value") if l == 1: m = 100 elif l == 2: @@ -236,64 +256,75 @@ class LOC(dns.rdata.Rdata): elif t.isdigit(): longitude[2] = int(t) t = tok.get_string() - if t == 'W': + if t == "W": longitude[4] = -1 - elif t != 'E': - raise dns.exception.SyntaxError('bad longitude hemisphere value') + elif t != "E": + raise dns.exception.SyntaxError("bad longitude hemisphere value") t = tok.get_string() - if t[-1] == 'm': - t = t[0: -1] - altitude = float(t) * 100.0 # m -> cm + if t[-1] == "m": + t = t[0:-1] + altitude = float(t) * 100.0 # m -> cm tokens = tok.get_remaining(max_tokens=3) if len(tokens) >= 1: value = tokens[0].unescape().value - if value[-1] == 'm': - value = value[0: -1] - size = float(value) * 100.0 # m -> cm + if value[-1] == "m": + value = value[0:-1] + size = float(value) * 100.0 # m -> cm if len(tokens) >= 2: value = tokens[1].unescape().value - if value[-1] == 'm': - value = value[0: -1] - hprec = float(value) * 100.0 # m -> cm + if value[-1] == "m": + value = value[0:-1] + hprec = float(value) * 100.0 # m -> cm if len(tokens) >= 3: value = tokens[2].unescape().value - if value[-1] == 'm': - value = value[0: -1] - vprec = float(value) * 100.0 # m -> cm + if value[-1] == "m": + value = value[0:-1] + vprec = float(value) * 100.0 # m -> cm # Try encoding these now so we raise if they are bad _encode_size(size, "size") _encode_size(hprec, "horizontal precision") _encode_size(vprec, "vertical precision") - return cls(rdclass, rdtype, latitude, longitude, altitude, - size, hprec, vprec) + return cls(rdclass, rdtype, latitude, longitude, altitude, size, hprec, vprec) def _to_wire(self, file, compress=None, origin=None, canonicalize=False): - milliseconds = (self.latitude[0] * 3600000 + - self.latitude[1] * 60000 + - self.latitude[2] * 1000 + - self.latitude[3]) * self.latitude[4] + milliseconds = ( + self.latitude[0] * 3600000 + + self.latitude[1] * 60000 + + self.latitude[2] * 1000 + + self.latitude[3] + ) * self.latitude[4] latitude = 0x80000000 + milliseconds - milliseconds = (self.longitude[0] * 3600000 + - self.longitude[1] * 60000 + - self.longitude[2] * 1000 + - self.longitude[3]) * self.longitude[4] + milliseconds = ( + self.longitude[0] * 3600000 + + self.longitude[1] * 60000 + + self.longitude[2] * 1000 + + self.longitude[3] + ) * self.longitude[4] longitude = 0x80000000 + milliseconds altitude = int(self.altitude) + 10000000 size = _encode_size(self.size, "size") hprec = _encode_size(self.horizontal_precision, "horizontal precision") vprec = _encode_size(self.vertical_precision, "vertical precision") - wire = struct.pack("!BBBBIII", 0, size, hprec, vprec, latitude, - longitude, altitude) + wire = struct.pack( + "!BBBBIII", 0, size, hprec, vprec, latitude, longitude, altitude + ) file.write(wire) @classmethod def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): - (version, size, hprec, vprec, latitude, longitude, altitude) = \ - parser.get_struct("!BBBBIII") + ( + version, + size, + hprec, + vprec, + latitude, + longitude, + altitude, + ) = parser.get_struct("!BBBBIII") if version != 0: raise dns.exception.FormError("LOC version not zero") if latitude < _MIN_LATITUDE or latitude > _MAX_LATITUDE: @@ -312,8 +343,7 @@ class LOC(dns.rdata.Rdata): size = _decode_size(size, "size") hprec = _decode_size(hprec, "horizontal precision") vprec = _decode_size(vprec, "vertical precision") - return cls(rdclass, rdtype, latitude, longitude, altitude, - size, hprec, vprec) + return cls(rdclass, rdtype, latitude, longitude, altitude, size, hprec, vprec) @property def float_latitude(self): diff --git a/lib/dns/rdtypes/ANY/LP.py b/lib/dns/rdtypes/ANY/LP.py index b6a2e36c..8a7c5125 100644 --- a/lib/dns/rdtypes/ANY/LP.py +++ b/lib/dns/rdtypes/ANY/LP.py @@ -3,6 +3,7 @@ import struct import dns.immutable +import dns.rdata @dns.immutable.immutable @@ -12,7 +13,7 @@ class LP(dns.rdata.Rdata): # see: rfc6742.txt - __slots__ = ['preference', 'fqdn'] + __slots__ = ["preference", "fqdn"] def __init__(self, rdclass, rdtype, preference, fqdn): super().__init__(rdclass, rdtype) @@ -21,17 +22,18 @@ class LP(dns.rdata.Rdata): def to_text(self, origin=None, relativize=True, **kw): fqdn = self.fqdn.choose_relativity(origin, relativize) - return '%d %s' % (self.preference, fqdn) + return "%d %s" % (self.preference, fqdn) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): preference = tok.get_uint16() fqdn = tok.get_name(origin, relativize, relativize_to) return cls(rdclass, rdtype, preference, fqdn) def _to_wire(self, file, compress=None, origin=None, canonicalize=False): - file.write(struct.pack('!H', self.preference)) + file.write(struct.pack("!H", self.preference)) self.fqdn.to_wire(file, compress, origin, canonicalize) @classmethod diff --git a/lib/dns/rdtypes/ANY/NID.py b/lib/dns/rdtypes/ANY/NID.py index 74951bbf..ad54aca3 100644 --- a/lib/dns/rdtypes/ANY/NID.py +++ b/lib/dns/rdtypes/ANY/NID.py @@ -13,32 +13,33 @@ class NID(dns.rdata.Rdata): # see: rfc6742.txt - __slots__ = ['preference', 'nodeid'] + __slots__ = ["preference", "nodeid"] def __init__(self, rdclass, rdtype, preference, nodeid): super().__init__(rdclass, rdtype) self.preference = self._as_uint16(preference) if isinstance(nodeid, bytes): if len(nodeid) != 8: - raise ValueError('invalid nodeid') - self.nodeid = dns.rdata._hexify(nodeid, 4, b':') + raise ValueError("invalid nodeid") + self.nodeid = dns.rdata._hexify(nodeid, 4, b":") else: - dns.rdtypes.util.parse_formatted_hex(nodeid, 4, 4, ':') + dns.rdtypes.util.parse_formatted_hex(nodeid, 4, 4, ":") self.nodeid = nodeid def to_text(self, origin=None, relativize=True, **kw): - return f'{self.preference} {self.nodeid}' + return f"{self.preference} {self.nodeid}" @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): preference = tok.get_uint16() nodeid = tok.get_identifier() return cls(rdclass, rdtype, preference, nodeid) def _to_wire(self, file, compress=None, origin=None, canonicalize=False): - file.write(struct.pack('!H', self.preference)) - file.write(dns.rdtypes.util.parse_formatted_hex(self.nodeid, 4, 4, ':')) + file.write(struct.pack("!H", self.preference)) + file.write(dns.rdtypes.util.parse_formatted_hex(self.nodeid, 4, 4, ":")) @classmethod def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): diff --git a/lib/dns/rdtypes/ANY/NSEC.py b/lib/dns/rdtypes/ANY/NSEC.py index dc31f4c4..7af7b77f 100644 --- a/lib/dns/rdtypes/ANY/NSEC.py +++ b/lib/dns/rdtypes/ANY/NSEC.py @@ -25,7 +25,7 @@ import dns.rdtypes.util @dns.immutable.immutable class Bitmap(dns.rdtypes.util.Bitmap): - type_name = 'NSEC' + type_name = "NSEC" @dns.immutable.immutable @@ -33,7 +33,7 @@ class NSEC(dns.rdata.Rdata): """NSEC record""" - __slots__ = ['next', 'windows'] + __slots__ = ["next", "windows"] def __init__(self, rdclass, rdtype, next, windows): super().__init__(rdclass, rdtype) @@ -45,11 +45,12 @@ class NSEC(dns.rdata.Rdata): def to_text(self, origin=None, relativize=True, **kw): next = self.next.choose_relativity(origin, relativize) text = Bitmap(self.windows).to_text() - return '{}{}'.format(next, text) + return "{}{}".format(next, text) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): next = tok.get_name(origin, relativize, relativize_to) windows = Bitmap.from_text(tok) return cls(rdclass, rdtype, next, windows) diff --git a/lib/dns/rdtypes/ANY/NSEC3.py b/lib/dns/rdtypes/ANY/NSEC3.py index 14242bda..6eae16e0 100644 --- a/lib/dns/rdtypes/ANY/NSEC3.py +++ b/lib/dns/rdtypes/ANY/NSEC3.py @@ -26,10 +26,12 @@ import dns.rdatatype import dns.rdtypes.util -b32_hex_to_normal = bytes.maketrans(b'0123456789ABCDEFGHIJKLMNOPQRSTUV', - b'ABCDEFGHIJKLMNOPQRSTUVWXYZ234567') -b32_normal_to_hex = bytes.maketrans(b'ABCDEFGHIJKLMNOPQRSTUVWXYZ234567', - b'0123456789ABCDEFGHIJKLMNOPQRSTUV') +b32_hex_to_normal = bytes.maketrans( + b"0123456789ABCDEFGHIJKLMNOPQRSTUV", b"ABCDEFGHIJKLMNOPQRSTUVWXYZ234567" +) +b32_normal_to_hex = bytes.maketrans( + b"ABCDEFGHIJKLMNOPQRSTUVWXYZ234567", b"0123456789ABCDEFGHIJKLMNOPQRSTUV" +) # hash algorithm constants SHA1 = 1 @@ -40,7 +42,7 @@ OPTOUT = 1 @dns.immutable.immutable class Bitmap(dns.rdtypes.util.Bitmap): - type_name = 'NSEC3' + type_name = "NSEC3" @dns.immutable.immutable @@ -48,10 +50,11 @@ class NSEC3(dns.rdata.Rdata): """NSEC3 record""" - __slots__ = ['algorithm', 'flags', 'iterations', 'salt', 'next', 'windows'] + __slots__ = ["algorithm", "flags", "iterations", "salt", "next", "windows"] - def __init__(self, rdclass, rdtype, algorithm, flags, iterations, salt, - next, windows): + def __init__( + self, rdclass, rdtype, algorithm, flags, iterations, salt, next, windows + ): super().__init__(rdclass, rdtype) self.algorithm = self._as_uint8(algorithm) self.flags = self._as_uint8(flags) @@ -63,38 +66,41 @@ class NSEC3(dns.rdata.Rdata): self.windows = tuple(windows.windows) def to_text(self, origin=None, relativize=True, **kw): - next = base64.b32encode(self.next).translate( - b32_normal_to_hex).lower().decode() - if self.salt == b'': - salt = '-' + next = base64.b32encode(self.next).translate(b32_normal_to_hex).lower().decode() + if self.salt == b"": + salt = "-" else: salt = binascii.hexlify(self.salt).decode() text = Bitmap(self.windows).to_text() - return '%u %u %u %s %s%s' % (self.algorithm, self.flags, - self.iterations, salt, next, text) + return "%u %u %u %s %s%s" % ( + self.algorithm, + self.flags, + self.iterations, + salt, + next, + text, + ) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): algorithm = tok.get_uint8() flags = tok.get_uint8() iterations = tok.get_uint16() salt = tok.get_string() - if salt == '-': - salt = b'' + if salt == "-": + salt = b"" else: - salt = binascii.unhexlify(salt.encode('ascii')) - next = tok.get_string().encode( - 'ascii').upper().translate(b32_hex_to_normal) + salt = binascii.unhexlify(salt.encode("ascii")) + next = tok.get_string().encode("ascii").upper().translate(b32_hex_to_normal) next = base64.b32decode(next) bitmap = Bitmap.from_text(tok) - return cls(rdclass, rdtype, algorithm, flags, iterations, salt, next, - bitmap) + return cls(rdclass, rdtype, algorithm, flags, iterations, salt, next, bitmap) def _to_wire(self, file, compress=None, origin=None, canonicalize=False): l = len(self.salt) - file.write(struct.pack("!BBHB", self.algorithm, self.flags, - self.iterations, l)) + file.write(struct.pack("!BBHB", self.algorithm, self.flags, self.iterations, l)) file.write(self.salt) l = len(self.next) file.write(struct.pack("!B", l)) @@ -103,9 +109,8 @@ class NSEC3(dns.rdata.Rdata): @classmethod def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): - (algorithm, flags, iterations) = parser.get_struct('!BBH') + (algorithm, flags, iterations) = parser.get_struct("!BBH") salt = parser.get_counted_bytes() next = parser.get_counted_bytes() bitmap = Bitmap.from_wire_parser(parser) - return cls(rdclass, rdtype, algorithm, flags, iterations, salt, next, - bitmap) + return cls(rdclass, rdtype, algorithm, flags, iterations, salt, next, bitmap) diff --git a/lib/dns/rdtypes/ANY/NSEC3PARAM.py b/lib/dns/rdtypes/ANY/NSEC3PARAM.py index 299bf6ed..1b7269a0 100644 --- a/lib/dns/rdtypes/ANY/NSEC3PARAM.py +++ b/lib/dns/rdtypes/ANY/NSEC3PARAM.py @@ -28,7 +28,7 @@ class NSEC3PARAM(dns.rdata.Rdata): """NSEC3PARAM record""" - __slots__ = ['algorithm', 'flags', 'iterations', 'salt'] + __slots__ = ["algorithm", "flags", "iterations", "salt"] def __init__(self, rdclass, rdtype, algorithm, flags, iterations, salt): super().__init__(rdclass, rdtype) @@ -38,34 +38,33 @@ class NSEC3PARAM(dns.rdata.Rdata): self.salt = self._as_bytes(salt, True, 255) def to_text(self, origin=None, relativize=True, **kw): - if self.salt == b'': - salt = '-' + if self.salt == b"": + salt = "-" else: salt = binascii.hexlify(self.salt).decode() - return '%u %u %u %s' % (self.algorithm, self.flags, self.iterations, - salt) + return "%u %u %u %s" % (self.algorithm, self.flags, self.iterations, salt) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): algorithm = tok.get_uint8() flags = tok.get_uint8() iterations = tok.get_uint16() salt = tok.get_string() - if salt == '-': - salt = '' + if salt == "-": + salt = "" else: salt = binascii.unhexlify(salt.encode()) return cls(rdclass, rdtype, algorithm, flags, iterations, salt) def _to_wire(self, file, compress=None, origin=None, canonicalize=False): l = len(self.salt) - file.write(struct.pack("!BBHB", self.algorithm, self.flags, - self.iterations, l)) + file.write(struct.pack("!BBHB", self.algorithm, self.flags, self.iterations, l)) file.write(self.salt) @classmethod def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): - (algorithm, flags, iterations) = parser.get_struct('!BBH') + (algorithm, flags, iterations) = parser.get_struct("!BBH") salt = parser.get_counted_bytes() return cls(rdclass, rdtype, algorithm, flags, iterations, salt) diff --git a/lib/dns/rdtypes/ANY/OPENPGPKEY.py b/lib/dns/rdtypes/ANY/OPENPGPKEY.py index dcfa028d..e5e25727 100644 --- a/lib/dns/rdtypes/ANY/OPENPGPKEY.py +++ b/lib/dns/rdtypes/ANY/OPENPGPKEY.py @@ -22,6 +22,7 @@ import dns.immutable import dns.rdata import dns.tokenizer + @dns.immutable.immutable class OPENPGPKEY(dns.rdata.Rdata): @@ -37,8 +38,9 @@ class OPENPGPKEY(dns.rdata.Rdata): return dns.rdata._base64ify(self.key, chunksize=None, **kw) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): b64 = tok.concatenate_remaining_identifiers().encode() key = base64.b64decode(b64) return cls(rdclass, rdtype, key) diff --git a/lib/dns/rdtypes/ANY/OPT.py b/lib/dns/rdtypes/ANY/OPT.py index 69b8fe75..36d4c7c6 100644 --- a/lib/dns/rdtypes/ANY/OPT.py +++ b/lib/dns/rdtypes/ANY/OPT.py @@ -26,12 +26,13 @@ import dns.rdata # We don't implement from_text, and that's ok. # pylint: disable=abstract-method + @dns.immutable.immutable class OPT(dns.rdata.Rdata): """OPT record""" - __slots__ = ['options'] + __slots__ = ["options"] def __init__(self, rdclass, rdtype, options): """Initialize an OPT rdata. @@ -45,10 +46,12 @@ class OPT(dns.rdata.Rdata): """ super().__init__(rdclass, rdtype) + def as_option(option): if not isinstance(option, dns.edns.Option): - raise ValueError('option is not a dns.edns.option') + raise ValueError("option is not a dns.edns.option") return option + self.options = self._as_tuple(options, as_option) def _to_wire(self, file, compress=None, origin=None, canonicalize=False): @@ -58,13 +61,13 @@ class OPT(dns.rdata.Rdata): file.write(owire) def to_text(self, origin=None, relativize=True, **kw): - return ' '.join(opt.to_text() for opt in self.options) + return " ".join(opt.to_text() for opt in self.options) @classmethod def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): options = [] while parser.remaining() > 0: - (otype, olen) = parser.get_struct('!HH') + (otype, olen) = parser.get_struct("!HH") with parser.restrict_to(olen): opt = dns.edns.option_from_wire_parser(otype, parser) options.append(opt) diff --git a/lib/dns/rdtypes/ANY/RP.py b/lib/dns/rdtypes/ANY/RP.py index a4e2297d..c0c316b5 100644 --- a/lib/dns/rdtypes/ANY/RP.py +++ b/lib/dns/rdtypes/ANY/RP.py @@ -28,7 +28,7 @@ class RP(dns.rdata.Rdata): # see: RFC 1183 - __slots__ = ['mbox', 'txt'] + __slots__ = ["mbox", "txt"] def __init__(self, rdclass, rdtype, mbox, txt): super().__init__(rdclass, rdtype) @@ -41,8 +41,9 @@ class RP(dns.rdata.Rdata): return "{} {}".format(str(mbox), str(txt)) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): mbox = tok.get_name(origin, relativize, relativize_to) txt = tok.get_name(origin, relativize, relativize_to) return cls(rdclass, rdtype, mbox, txt) diff --git a/lib/dns/rdtypes/ANY/RRSIG.py b/lib/dns/rdtypes/ANY/RRSIG.py index d050ccc6..3d5ad0f3 100644 --- a/lib/dns/rdtypes/ANY/RRSIG.py +++ b/lib/dns/rdtypes/ANY/RRSIG.py @@ -20,7 +20,7 @@ import calendar import struct import time -import dns.dnssec +import dns.dnssectypes import dns.immutable import dns.exception import dns.rdata @@ -43,12 +43,11 @@ def sigtime_to_posixtime(what): hour = int(what[8:10]) minute = int(what[10:12]) second = int(what[12:14]) - return calendar.timegm((year, month, day, hour, minute, second, - 0, 0, 0)) + return calendar.timegm((year, month, day, hour, minute, second, 0, 0, 0)) def posixtime_to_sigtime(what): - return time.strftime('%Y%m%d%H%M%S', time.gmtime(what)) + return time.strftime("%Y%m%d%H%M%S", time.gmtime(what)) @dns.immutable.immutable @@ -56,16 +55,35 @@ class RRSIG(dns.rdata.Rdata): """RRSIG record""" - __slots__ = ['type_covered', 'algorithm', 'labels', 'original_ttl', - 'expiration', 'inception', 'key_tag', 'signer', - 'signature'] + __slots__ = [ + "type_covered", + "algorithm", + "labels", + "original_ttl", + "expiration", + "inception", + "key_tag", + "signer", + "signature", + ] - def __init__(self, rdclass, rdtype, type_covered, algorithm, labels, - original_ttl, expiration, inception, key_tag, signer, - signature): + def __init__( + self, + rdclass, + rdtype, + type_covered, + algorithm, + labels, + original_ttl, + expiration, + inception, + key_tag, + signer, + signature, + ): super().__init__(rdclass, rdtype) self.type_covered = self._as_rdatatype(type_covered) - self.algorithm = dns.dnssec.Algorithm.make(algorithm) + self.algorithm = dns.dnssectypes.Algorithm.make(algorithm) self.labels = self._as_uint8(labels) self.original_ttl = self._as_ttl(original_ttl) self.expiration = self._as_uint32(expiration) @@ -78,7 +96,7 @@ class RRSIG(dns.rdata.Rdata): return self.type_covered def to_text(self, origin=None, relativize=True, **kw): - return '%s %d %d %d %s %s %d %s %s' % ( + return "%s %d %d %d %s %s %d %s %s" % ( dns.rdatatype.to_text(self.type_covered), self.algorithm, self.labels, @@ -87,14 +105,15 @@ class RRSIG(dns.rdata.Rdata): posixtime_to_sigtime(self.inception), self.key_tag, self.signer.choose_relativity(origin, relativize), - dns.rdata._base64ify(self.signature, **kw) + dns.rdata._base64ify(self.signature, **kw), ) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): type_covered = dns.rdatatype.from_text(tok.get_string()) - algorithm = dns.dnssec.algorithm_from_text(tok.get_string()) + algorithm = dns.dnssectypes.Algorithm.from_text(tok.get_string()) labels = tok.get_int() original_ttl = tok.get_ttl() expiration = sigtime_to_posixtime(tok.get_string()) @@ -103,22 +122,38 @@ class RRSIG(dns.rdata.Rdata): signer = tok.get_name(origin, relativize, relativize_to) b64 = tok.concatenate_remaining_identifiers().encode() signature = base64.b64decode(b64) - return cls(rdclass, rdtype, type_covered, algorithm, labels, - original_ttl, expiration, inception, key_tag, signer, - signature) + return cls( + rdclass, + rdtype, + type_covered, + algorithm, + labels, + original_ttl, + expiration, + inception, + key_tag, + signer, + signature, + ) def _to_wire(self, file, compress=None, origin=None, canonicalize=False): - header = struct.pack('!HBBIIIH', self.type_covered, - self.algorithm, self.labels, - self.original_ttl, self.expiration, - self.inception, self.key_tag) + header = struct.pack( + "!HBBIIIH", + self.type_covered, + self.algorithm, + self.labels, + self.original_ttl, + self.expiration, + self.inception, + self.key_tag, + ) file.write(header) self.signer.to_wire(file, None, origin, canonicalize) file.write(self.signature) @classmethod def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): - header = parser.get_struct('!HBBIIIH') + header = parser.get_struct("!HBBIIIH") signer = parser.get_name(origin) signature = parser.get_remaining() return cls(rdclass, rdtype, *header, signer, signature) diff --git a/lib/dns/rdtypes/ANY/SOA.py b/lib/dns/rdtypes/ANY/SOA.py index 7ce88652..6f6fe58b 100644 --- a/lib/dns/rdtypes/ANY/SOA.py +++ b/lib/dns/rdtypes/ANY/SOA.py @@ -30,11 +30,11 @@ class SOA(dns.rdata.Rdata): # see: RFC 1035 - __slots__ = ['mname', 'rname', 'serial', 'refresh', 'retry', 'expire', - 'minimum'] + __slots__ = ["mname", "rname", "serial", "refresh", "retry", "expire", "minimum"] - def __init__(self, rdclass, rdtype, mname, rname, serial, refresh, retry, - expire, minimum): + def __init__( + self, rdclass, rdtype, mname, rname, serial, refresh, retry, expire, minimum + ): super().__init__(rdclass, rdtype) self.mname = self._as_name(mname) self.rname = self._as_name(rname) @@ -47,13 +47,20 @@ class SOA(dns.rdata.Rdata): def to_text(self, origin=None, relativize=True, **kw): mname = self.mname.choose_relativity(origin, relativize) rname = self.rname.choose_relativity(origin, relativize) - return '%s %s %d %d %d %d %d' % ( - mname, rname, self.serial, self.refresh, self.retry, - self.expire, self.minimum) + return "%s %s %d %d %d %d %d" % ( + mname, + rname, + self.serial, + self.refresh, + self.retry, + self.expire, + self.minimum, + ) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): mname = tok.get_name(origin, relativize, relativize_to) rname = tok.get_name(origin, relativize, relativize_to) serial = tok.get_uint32() @@ -61,18 +68,20 @@ class SOA(dns.rdata.Rdata): retry = tok.get_ttl() expire = tok.get_ttl() minimum = tok.get_ttl() - return cls(rdclass, rdtype, mname, rname, serial, refresh, retry, - expire, minimum) + return cls( + rdclass, rdtype, mname, rname, serial, refresh, retry, expire, minimum + ) def _to_wire(self, file, compress=None, origin=None, canonicalize=False): self.mname.to_wire(file, compress, origin, canonicalize) self.rname.to_wire(file, compress, origin, canonicalize) - five_ints = struct.pack('!IIIII', self.serial, self.refresh, - self.retry, self.expire, self.minimum) + five_ints = struct.pack( + "!IIIII", self.serial, self.refresh, self.retry, self.expire, self.minimum + ) file.write(five_ints) @classmethod def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): mname = parser.get_name(origin) rname = parser.get_name(origin) - return cls(rdclass, rdtype, mname, rname, *parser.get_struct('!IIIII')) + return cls(rdclass, rdtype, mname, rname, *parser.get_struct("!IIIII")) diff --git a/lib/dns/rdtypes/ANY/SSHFP.py b/lib/dns/rdtypes/ANY/SSHFP.py index cc035195..58ffcbbc 100644 --- a/lib/dns/rdtypes/ANY/SSHFP.py +++ b/lib/dns/rdtypes/ANY/SSHFP.py @@ -30,10 +30,9 @@ class SSHFP(dns.rdata.Rdata): # See RFC 4255 - __slots__ = ['algorithm', 'fp_type', 'fingerprint'] + __slots__ = ["algorithm", "fp_type", "fingerprint"] - def __init__(self, rdclass, rdtype, algorithm, fp_type, - fingerprint): + def __init__(self, rdclass, rdtype, algorithm, fp_type, fingerprint): super().__init__(rdclass, rdtype) self.algorithm = self._as_uint8(algorithm) self.fp_type = self._as_uint8(fp_type) @@ -41,16 +40,17 @@ class SSHFP(dns.rdata.Rdata): def to_text(self, origin=None, relativize=True, **kw): kw = kw.copy() - chunksize = kw.pop('chunksize', 128) - return '%d %d %s' % (self.algorithm, - self.fp_type, - dns.rdata._hexify(self.fingerprint, - chunksize=chunksize, - **kw)) + chunksize = kw.pop("chunksize", 128) + return "%d %d %s" % ( + self.algorithm, + self.fp_type, + dns.rdata._hexify(self.fingerprint, chunksize=chunksize, **kw), + ) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): algorithm = tok.get_uint8() fp_type = tok.get_uint8() fingerprint = tok.concatenate_remaining_identifiers().encode() diff --git a/lib/dns/rdtypes/ANY/TKEY.py b/lib/dns/rdtypes/ANY/TKEY.py index 861fc4e3..070f03af 100644 --- a/lib/dns/rdtypes/ANY/TKEY.py +++ b/lib/dns/rdtypes/ANY/TKEY.py @@ -18,7 +18,6 @@ import base64 import struct -import dns.dnssec import dns.immutable import dns.exception import dns.rdata @@ -29,11 +28,28 @@ class TKEY(dns.rdata.Rdata): """TKEY Record""" - __slots__ = ['algorithm', 'inception', 'expiration', 'mode', 'error', - 'key', 'other'] + __slots__ = [ + "algorithm", + "inception", + "expiration", + "mode", + "error", + "key", + "other", + ] - def __init__(self, rdclass, rdtype, algorithm, inception, expiration, - mode, error, key, other=b''): + def __init__( + self, + rdclass, + rdtype, + algorithm, + inception, + expiration, + mode, + error, + key, + other=b"", + ): super().__init__(rdclass, rdtype) self.algorithm = self._as_name(algorithm) self.inception = self._as_uint32(inception) @@ -45,17 +61,23 @@ class TKEY(dns.rdata.Rdata): def to_text(self, origin=None, relativize=True, **kw): _algorithm = self.algorithm.choose_relativity(origin, relativize) - text = '%s %u %u %u %u %s' % (str(_algorithm), self.inception, - self.expiration, self.mode, self.error, - dns.rdata._base64ify(self.key, 0)) + text = "%s %u %u %u %u %s" % ( + str(_algorithm), + self.inception, + self.expiration, + self.mode, + self.error, + dns.rdata._base64ify(self.key, 0), + ) if len(self.other) > 0: - text += ' %s' % (dns.rdata._base64ify(self.other, 0)) + text += " %s" % (dns.rdata._base64ify(self.other, 0)) return text @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): algorithm = tok.get_name(relativize=False) inception = tok.get_uint32() expiration = tok.get_uint32() @@ -66,13 +88,15 @@ class TKEY(dns.rdata.Rdata): other_b64 = tok.concatenate_remaining_identifiers(True).encode() other = base64.b64decode(other_b64) - return cls(rdclass, rdtype, algorithm, inception, expiration, mode, - error, key, other) + return cls( + rdclass, rdtype, algorithm, inception, expiration, mode, error, key, other + ) def _to_wire(self, file, compress=None, origin=None, canonicalize=False): self.algorithm.to_wire(file, compress, origin) - file.write(struct.pack("!IIHH", self.inception, self.expiration, - self.mode, self.error)) + file.write( + struct.pack("!IIHH", self.inception, self.expiration, self.mode, self.error) + ) file.write(struct.pack("!H", len(self.key))) file.write(self.key) file.write(struct.pack("!H", len(self.other))) @@ -86,8 +110,9 @@ class TKEY(dns.rdata.Rdata): key = parser.get_counted_bytes(2) other = parser.get_counted_bytes(2) - return cls(rdclass, rdtype, algorithm, inception, expiration, mode, - error, key, other) + return cls( + rdclass, rdtype, algorithm, inception, expiration, mode, error, key, other + ) # Constants for the mode field - from RFC 2930: # 2.5 The Mode Field diff --git a/lib/dns/rdtypes/ANY/TSIG.py b/lib/dns/rdtypes/ANY/TSIG.py index b43a78f1..1ae87ebe 100644 --- a/lib/dns/rdtypes/ANY/TSIG.py +++ b/lib/dns/rdtypes/ANY/TSIG.py @@ -29,11 +29,28 @@ class TSIG(dns.rdata.Rdata): """TSIG record""" - __slots__ = ['algorithm', 'time_signed', 'fudge', 'mac', - 'original_id', 'error', 'other'] + __slots__ = [ + "algorithm", + "time_signed", + "fudge", + "mac", + "original_id", + "error", + "other", + ] - def __init__(self, rdclass, rdtype, algorithm, time_signed, fudge, mac, - original_id, error, other): + def __init__( + self, + rdclass, + rdtype, + algorithm, + time_signed, + fudge, + mac, + original_id, + error, + other, + ): """Initialize a TSIG rdata. *rdclass*, an ``int`` is the rdataclass of the Rdata. @@ -67,45 +84,60 @@ class TSIG(dns.rdata.Rdata): def to_text(self, origin=None, relativize=True, **kw): algorithm = self.algorithm.choose_relativity(origin, relativize) error = dns.rcode.to_text(self.error, True) - text = f"{algorithm} {self.time_signed} {self.fudge} " + \ - f"{len(self.mac)} {dns.rdata._base64ify(self.mac, 0)} " + \ - f"{self.original_id} {error} {len(self.other)}" + text = ( + f"{algorithm} {self.time_signed} {self.fudge} " + + f"{len(self.mac)} {dns.rdata._base64ify(self.mac, 0)} " + + f"{self.original_id} {error} {len(self.other)}" + ) if self.other: text += f" {dns.rdata._base64ify(self.other, 0)}" return text @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): algorithm = tok.get_name(relativize=False) time_signed = tok.get_uint48() fudge = tok.get_uint16() mac_len = tok.get_uint16() mac = base64.b64decode(tok.get_string()) if len(mac) != mac_len: - raise SyntaxError('invalid MAC') + raise SyntaxError("invalid MAC") original_id = tok.get_uint16() error = dns.rcode.from_text(tok.get_string()) other_len = tok.get_uint16() if other_len > 0: other = base64.b64decode(tok.get_string()) if len(other) != other_len: - raise SyntaxError('invalid other data') + raise SyntaxError("invalid other data") else: - other = b'' - return cls(rdclass, rdtype, algorithm, time_signed, fudge, mac, - original_id, error, other) + other = b"" + return cls( + rdclass, + rdtype, + algorithm, + time_signed, + fudge, + mac, + original_id, + error, + other, + ) def _to_wire(self, file, compress=None, origin=None, canonicalize=False): self.algorithm.to_wire(file, None, origin, False) - file.write(struct.pack('!HIHH', - (self.time_signed >> 32) & 0xffff, - self.time_signed & 0xffffffff, - self.fudge, - len(self.mac))) + file.write( + struct.pack( + "!HIHH", + (self.time_signed >> 32) & 0xFFFF, + self.time_signed & 0xFFFFFFFF, + self.fudge, + len(self.mac), + ) + ) file.write(self.mac) - file.write(struct.pack('!HHH', self.original_id, self.error, - len(self.other))) + file.write(struct.pack("!HHH", self.original_id, self.error, len(self.other))) file.write(self.other) @classmethod @@ -114,7 +146,16 @@ class TSIG(dns.rdata.Rdata): time_signed = parser.get_uint48() fudge = parser.get_uint16() mac = parser.get_counted_bytes(2) - (original_id, error) = parser.get_struct('!HH') + (original_id, error) = parser.get_struct("!HH") other = parser.get_counted_bytes(2) - return cls(rdclass, rdtype, algorithm, time_signed, fudge, mac, - original_id, error, other) + return cls( + rdclass, + rdtype, + algorithm, + time_signed, + fudge, + mac, + original_id, + error, + other, + ) diff --git a/lib/dns/rdtypes/ANY/URI.py b/lib/dns/rdtypes/ANY/URI.py index 524fa1ba..b4c95a3b 100644 --- a/lib/dns/rdtypes/ANY/URI.py +++ b/lib/dns/rdtypes/ANY/URI.py @@ -32,7 +32,7 @@ class URI(dns.rdata.Rdata): # see RFC 7553 - __slots__ = ['priority', 'weight', 'target'] + __slots__ = ["priority", "weight", "target"] def __init__(self, rdclass, rdtype, priority, weight, target): super().__init__(rdclass, rdtype) @@ -43,12 +43,12 @@ class URI(dns.rdata.Rdata): raise dns.exception.SyntaxError("URI target cannot be empty") def to_text(self, origin=None, relativize=True, **kw): - return '%d %d "%s"' % (self.priority, self.weight, - self.target.decode()) + return '%d %d "%s"' % (self.priority, self.weight, self.target.decode()) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): priority = tok.get_uint16() weight = tok.get_uint16() target = tok.get().unescape() @@ -63,10 +63,10 @@ class URI(dns.rdata.Rdata): @classmethod def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): - (priority, weight) = parser.get_struct('!HH') + (priority, weight) = parser.get_struct("!HH") target = parser.get_remaining() if len(target) == 0: - raise dns.exception.FormError('URI target may not be empty') + raise dns.exception.FormError("URI target may not be empty") return cls(rdclass, rdtype, priority, weight, target) def _processing_priority(self): diff --git a/lib/dns/rdtypes/ANY/X25.py b/lib/dns/rdtypes/ANY/X25.py index 4f7230c0..06c14534 100644 --- a/lib/dns/rdtypes/ANY/X25.py +++ b/lib/dns/rdtypes/ANY/X25.py @@ -30,7 +30,7 @@ class X25(dns.rdata.Rdata): # see RFC 1183 - __slots__ = ['address'] + __slots__ = ["address"] def __init__(self, rdclass, rdtype, address): super().__init__(rdclass, rdtype) @@ -40,15 +40,16 @@ class X25(dns.rdata.Rdata): return '"%s"' % dns.rdata._escapify(self.address) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): address = tok.get_string() return cls(rdclass, rdtype, address) def _to_wire(self, file, compress=None, origin=None, canonicalize=False): l = len(self.address) assert l < 256 - file.write(struct.pack('!B', l)) + file.write(struct.pack("!B", l)) file.write(self.address) @classmethod diff --git a/lib/dns/rdtypes/ANY/ZONEMD.py b/lib/dns/rdtypes/ANY/ZONEMD.py index 035f7b32..1f86ba49 100644 --- a/lib/dns/rdtypes/ANY/ZONEMD.py +++ b/lib/dns/rdtypes/ANY/ZONEMD.py @@ -6,7 +6,7 @@ import binascii import dns.immutable import dns.rdata import dns.rdatatype -import dns.zone +import dns.zonetypes @dns.immutable.immutable @@ -16,35 +16,38 @@ class ZONEMD(dns.rdata.Rdata): # See RFC 8976 - __slots__ = ['serial', 'scheme', 'hash_algorithm', 'digest'] + __slots__ = ["serial", "scheme", "hash_algorithm", "digest"] def __init__(self, rdclass, rdtype, serial, scheme, hash_algorithm, digest): super().__init__(rdclass, rdtype) self.serial = self._as_uint32(serial) - self.scheme = dns.zone.DigestScheme.make(scheme) - self.hash_algorithm = dns.zone.DigestHashAlgorithm.make(hash_algorithm) + self.scheme = dns.zonetypes.DigestScheme.make(scheme) + self.hash_algorithm = dns.zonetypes.DigestHashAlgorithm.make(hash_algorithm) self.digest = self._as_bytes(digest) if self.scheme == 0: # reserved, RFC 8976 Sec. 5.2 - raise ValueError('scheme 0 is reserved') + raise ValueError("scheme 0 is reserved") if self.hash_algorithm == 0: # reserved, RFC 8976 Sec. 5.3 - raise ValueError('hash_algorithm 0 is reserved') + raise ValueError("hash_algorithm 0 is reserved") - hasher = dns.zone._digest_hashers.get(self.hash_algorithm) + hasher = dns.zonetypes._digest_hashers.get(self.hash_algorithm) if hasher and hasher().digest_size != len(self.digest): - raise ValueError('digest length inconsistent with hash algorithm') + raise ValueError("digest length inconsistent with hash algorithm") def to_text(self, origin=None, relativize=True, **kw): kw = kw.copy() - chunksize = kw.pop('chunksize', 128) - return '%d %d %d %s' % (self.serial, self.scheme, self.hash_algorithm, - dns.rdata._hexify(self.digest, - chunksize=chunksize, - **kw)) + chunksize = kw.pop("chunksize", 128) + return "%d %d %d %s" % ( + self.serial, + self.scheme, + self.hash_algorithm, + dns.rdata._hexify(self.digest, chunksize=chunksize, **kw), + ) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): serial = tok.get_uint32() scheme = tok.get_uint8() hash_algorithm = tok.get_uint8() @@ -53,8 +56,7 @@ class ZONEMD(dns.rdata.Rdata): return cls(rdclass, rdtype, serial, scheme, hash_algorithm, digest) def _to_wire(self, file, compress=None, origin=None, canonicalize=False): - header = struct.pack("!IBB", self.serial, self.scheme, - self.hash_algorithm) + header = struct.pack("!IBB", self.serial, self.scheme, self.hash_algorithm) file.write(header) file.write(self.digest) diff --git a/lib/dns/rdtypes/ANY/__init__.py b/lib/dns/rdtypes/ANY/__init__.py index 2cadcde3..3824a0a0 100644 --- a/lib/dns/rdtypes/ANY/__init__.py +++ b/lib/dns/rdtypes/ANY/__init__.py @@ -18,51 +18,51 @@ """Class ANY (generic) rdata type classes.""" __all__ = [ - 'AFSDB', - 'AMTRELAY', - 'AVC', - 'CAA', - 'CDNSKEY', - 'CDS', - 'CERT', - 'CNAME', - 'CSYNC', - 'DLV', - 'DNAME', - 'DNSKEY', - 'DS', - 'EUI48', - 'EUI64', - 'GPOS', - 'HINFO', - 'HIP', - 'ISDN', - 'L32', - 'L64', - 'LOC', - 'LP', - 'MX', - 'NID', - 'NINFO', - 'NS', - 'NSEC', - 'NSEC3', - 'NSEC3PARAM', - 'OPENPGPKEY', - 'OPT', - 'PTR', - 'RP', - 'RRSIG', - 'RT', - 'SMIMEA', - 'SOA', - 'SPF', - 'SSHFP', - 'TKEY', - 'TLSA', - 'TSIG', - 'TXT', - 'URI', - 'X25', - 'ZONEMD', + "AFSDB", + "AMTRELAY", + "AVC", + "CAA", + "CDNSKEY", + "CDS", + "CERT", + "CNAME", + "CSYNC", + "DLV", + "DNAME", + "DNSKEY", + "DS", + "EUI48", + "EUI64", + "GPOS", + "HINFO", + "HIP", + "ISDN", + "L32", + "L64", + "LOC", + "LP", + "MX", + "NID", + "NINFO", + "NS", + "NSEC", + "NSEC3", + "NSEC3PARAM", + "OPENPGPKEY", + "OPT", + "PTR", + "RP", + "RRSIG", + "RT", + "SMIMEA", + "SOA", + "SPF", + "SSHFP", + "TKEY", + "TLSA", + "TSIG", + "TXT", + "URI", + "X25", + "ZONEMD", ] diff --git a/lib/dns/rdtypes/CH/A.py b/lib/dns/rdtypes/CH/A.py index 828701b4..9905c7c9 100644 --- a/lib/dns/rdtypes/CH/A.py +++ b/lib/dns/rdtypes/CH/A.py @@ -20,6 +20,7 @@ import struct import dns.rdtypes.mxbase import dns.immutable + @dns.immutable.immutable class A(dns.rdata.Rdata): @@ -28,7 +29,7 @@ class A(dns.rdata.Rdata): # domain: the domain of the address # address: the 16-bit address - __slots__ = ['domain', 'address'] + __slots__ = ["domain", "address"] def __init__(self, rdclass, rdtype, domain, address): super().__init__(rdclass, rdtype) @@ -37,11 +38,12 @@ class A(dns.rdata.Rdata): def to_text(self, origin=None, relativize=True, **kw): domain = self.domain.choose_relativity(origin, relativize) - return '%s %o' % (domain, self.address) + return "%s %o" % (domain, self.address) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): domain = tok.get_name(origin, relativize, relativize_to) address = tok.get_uint16(base=8) return cls(rdclass, rdtype, domain, address) diff --git a/lib/dns/rdtypes/CH/__init__.py b/lib/dns/rdtypes/CH/__init__.py index 7184a733..0760c26c 100644 --- a/lib/dns/rdtypes/CH/__init__.py +++ b/lib/dns/rdtypes/CH/__init__.py @@ -18,5 +18,5 @@ """Class CH rdata type classes.""" __all__ = [ - 'A', + "A", ] diff --git a/lib/dns/rdtypes/IN/A.py b/lib/dns/rdtypes/IN/A.py index 74b591ef..713d5eea 100644 --- a/lib/dns/rdtypes/IN/A.py +++ b/lib/dns/rdtypes/IN/A.py @@ -27,7 +27,7 @@ class A(dns.rdata.Rdata): """A record.""" - __slots__ = ['address'] + __slots__ = ["address"] def __init__(self, rdclass, rdtype, address): super().__init__(rdclass, rdtype) @@ -37,8 +37,9 @@ class A(dns.rdata.Rdata): return self.address @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): address = tok.get_identifier() return cls(rdclass, rdtype, address) diff --git a/lib/dns/rdtypes/IN/AAAA.py b/lib/dns/rdtypes/IN/AAAA.py index 2d3ec902..f8237b44 100644 --- a/lib/dns/rdtypes/IN/AAAA.py +++ b/lib/dns/rdtypes/IN/AAAA.py @@ -27,7 +27,7 @@ class AAAA(dns.rdata.Rdata): """AAAA record.""" - __slots__ = ['address'] + __slots__ = ["address"] def __init__(self, rdclass, rdtype, address): super().__init__(rdclass, rdtype) @@ -37,8 +37,9 @@ class AAAA(dns.rdata.Rdata): return self.address @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): address = tok.get_identifier() return cls(rdclass, rdtype, address) diff --git a/lib/dns/rdtypes/IN/APL.py b/lib/dns/rdtypes/IN/APL.py index ae94fb24..05e1689f 100644 --- a/lib/dns/rdtypes/IN/APL.py +++ b/lib/dns/rdtypes/IN/APL.py @@ -26,12 +26,13 @@ import dns.ipv6 import dns.rdata import dns.tokenizer + @dns.immutable.immutable class APLItem: """An APL list item.""" - __slots__ = ['family', 'negation', 'address', 'prefix'] + __slots__ = ["family", "negation", "address", "prefix"] def __init__(self, family, negation, address, prefix): self.family = dns.rdata.Rdata._as_uint16(family) @@ -67,12 +68,12 @@ class APLItem: if address[i] != 0: last = i + 1 break - address = address[0: last] + address = address[0:last] l = len(address) assert l < 128 if self.negation: l |= 0x80 - header = struct.pack('!HBB', self.family, self.prefix, l) + header = struct.pack("!HBB", self.family, self.prefix, l) file.write(header) file.write(address) @@ -84,32 +85,33 @@ class APL(dns.rdata.Rdata): # see: RFC 3123 - __slots__ = ['items'] + __slots__ = ["items"] def __init__(self, rdclass, rdtype, items): super().__init__(rdclass, rdtype) for item in items: if not isinstance(item, APLItem): - raise ValueError('item not an APLItem') + raise ValueError("item not an APLItem") self.items = tuple(items) def to_text(self, origin=None, relativize=True, **kw): - return ' '.join(map(str, self.items)) + return " ".join(map(str, self.items)) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): items = [] for token in tok.get_remaining(): item = token.unescape().value - if item[0] == '!': + if item[0] == "!": negation = True item = item[1:] else: negation = False - (family, rest) = item.split(':', 1) + (family, rest) = item.split(":", 1) family = int(family) - (address, prefix) = rest.split('/', 1) + (address, prefix) = rest.split("/", 1) prefix = int(prefix) item = APLItem(family, negation, address, prefix) items.append(item) @@ -125,7 +127,7 @@ class APL(dns.rdata.Rdata): items = [] while parser.remaining() > 0: - header = parser.get_struct('!HBB') + header = parser.get_struct("!HBB") afdlen = header[2] if afdlen > 127: negation = True @@ -136,16 +138,16 @@ class APL(dns.rdata.Rdata): l = len(address) if header[0] == 1: if l < 4: - address += b'\x00' * (4 - l) + address += b"\x00" * (4 - l) elif header[0] == 2: if l < 16: - address += b'\x00' * (16 - l) + address += b"\x00" * (16 - l) else: # # This isn't really right according to the RFC, but it # seems better than throwing an exception # - address = codecs.encode(address, 'hex_codec') + address = codecs.encode(address, "hex_codec") item = APLItem(header[0], negation, address, header[1]) items.append(item) return cls(rdclass, rdtype, items) diff --git a/lib/dns/rdtypes/IN/DHCID.py b/lib/dns/rdtypes/IN/DHCID.py index a9185989..65f85897 100644 --- a/lib/dns/rdtypes/IN/DHCID.py +++ b/lib/dns/rdtypes/IN/DHCID.py @@ -19,6 +19,7 @@ import base64 import dns.exception import dns.immutable +import dns.rdata @dns.immutable.immutable @@ -28,7 +29,7 @@ class DHCID(dns.rdata.Rdata): # see: RFC 4701 - __slots__ = ['data'] + __slots__ = ["data"] def __init__(self, rdclass, rdtype, data): super().__init__(rdclass, rdtype) @@ -38,8 +39,9 @@ class DHCID(dns.rdata.Rdata): return dns.rdata._base64ify(self.data, **kw) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): b64 = tok.concatenate_remaining_identifiers().encode() data = base64.b64decode(b64) return cls(rdclass, rdtype, data) diff --git a/lib/dns/rdtypes/IN/HTTPS.py b/lib/dns/rdtypes/IN/HTTPS.py index 6a67e8ed..7797fbaf 100644 --- a/lib/dns/rdtypes/IN/HTTPS.py +++ b/lib/dns/rdtypes/IN/HTTPS.py @@ -3,6 +3,7 @@ import dns.rdtypes.svcbbase import dns.immutable + @dns.immutable.immutable class HTTPS(dns.rdtypes.svcbbase.SVCBBase): """HTTPS record""" diff --git a/lib/dns/rdtypes/IN/IPSECKEY.py b/lib/dns/rdtypes/IN/IPSECKEY.py index d1d39438..1255739f 100644 --- a/lib/dns/rdtypes/IN/IPSECKEY.py +++ b/lib/dns/rdtypes/IN/IPSECKEY.py @@ -24,7 +24,8 @@ import dns.rdtypes.util class Gateway(dns.rdtypes.util.Gateway): - name = 'IPSECKEY gateway' + name = "IPSECKEY gateway" + @dns.immutable.immutable class IPSECKEY(dns.rdata.Rdata): @@ -33,10 +34,11 @@ class IPSECKEY(dns.rdata.Rdata): # see: RFC 4025 - __slots__ = ['precedence', 'gateway_type', 'algorithm', 'gateway', 'key'] + __slots__ = ["precedence", "gateway_type", "algorithm", "gateway", "key"] - def __init__(self, rdclass, rdtype, precedence, gateway_type, algorithm, - gateway, key): + def __init__( + self, rdclass, rdtype, precedence, gateway_type, algorithm, gateway, key + ): super().__init__(rdclass, rdtype) gateway = Gateway(gateway_type, gateway) self.precedence = self._as_uint8(precedence) @@ -46,38 +48,45 @@ class IPSECKEY(dns.rdata.Rdata): self.key = self._as_bytes(key) def to_text(self, origin=None, relativize=True, **kw): - gateway = Gateway(self.gateway_type, self.gateway).to_text(origin, - relativize) - return '%d %d %d %s %s' % (self.precedence, self.gateway_type, - self.algorithm, gateway, - dns.rdata._base64ify(self.key, **kw)) + gateway = Gateway(self.gateway_type, self.gateway).to_text(origin, relativize) + return "%d %d %d %s %s" % ( + self.precedence, + self.gateway_type, + self.algorithm, + gateway, + dns.rdata._base64ify(self.key, **kw), + ) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): precedence = tok.get_uint8() gateway_type = tok.get_uint8() algorithm = tok.get_uint8() - gateway = Gateway.from_text(gateway_type, tok, origin, relativize, - relativize_to) + gateway = Gateway.from_text( + gateway_type, tok, origin, relativize, relativize_to + ) b64 = tok.concatenate_remaining_identifiers().encode() key = base64.b64decode(b64) - return cls(rdclass, rdtype, precedence, gateway_type, algorithm, - gateway.gateway, key) + return cls( + rdclass, rdtype, precedence, gateway_type, algorithm, gateway.gateway, key + ) def _to_wire(self, file, compress=None, origin=None, canonicalize=False): - header = struct.pack("!BBB", self.precedence, self.gateway_type, - self.algorithm) + header = struct.pack("!BBB", self.precedence, self.gateway_type, self.algorithm) file.write(header) - Gateway(self.gateway_type, self.gateway).to_wire(file, compress, - origin, canonicalize) + Gateway(self.gateway_type, self.gateway).to_wire( + file, compress, origin, canonicalize + ) file.write(self.key) @classmethod def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): - header = parser.get_struct('!BBB') + header = parser.get_struct("!BBB") gateway_type = header[1] gateway = Gateway.from_wire_parser(gateway_type, parser, origin) key = parser.get_remaining() - return cls(rdclass, rdtype, header[0], gateway_type, header[2], - gateway.gateway, key) + return cls( + rdclass, rdtype, header[0], gateway_type, header[2], gateway.gateway, key + ) diff --git a/lib/dns/rdtypes/IN/NAPTR.py b/lib/dns/rdtypes/IN/NAPTR.py index b107974d..1f1f5a12 100644 --- a/lib/dns/rdtypes/IN/NAPTR.py +++ b/lib/dns/rdtypes/IN/NAPTR.py @@ -27,7 +27,7 @@ import dns.rdtypes.util def _write_string(file, s): l = len(s) assert l < 256 - file.write(struct.pack('!B', l)) + file.write(struct.pack("!B", l)) file.write(s) @@ -38,11 +38,11 @@ class NAPTR(dns.rdata.Rdata): # see: RFC 3403 - __slots__ = ['order', 'preference', 'flags', 'service', 'regexp', - 'replacement'] + __slots__ = ["order", "preference", "flags", "service", "regexp", "replacement"] - def __init__(self, rdclass, rdtype, order, preference, flags, service, - regexp, replacement): + def __init__( + self, rdclass, rdtype, order, preference, flags, service, regexp, replacement + ): super().__init__(rdclass, rdtype) self.flags = self._as_bytes(flags, True, 255) self.service = self._as_bytes(service, True, 255) @@ -53,24 +53,28 @@ class NAPTR(dns.rdata.Rdata): def to_text(self, origin=None, relativize=True, **kw): replacement = self.replacement.choose_relativity(origin, relativize) - return '%d %d "%s" "%s" "%s" %s' % \ - (self.order, self.preference, - dns.rdata._escapify(self.flags), - dns.rdata._escapify(self.service), - dns.rdata._escapify(self.regexp), - replacement) + return '%d %d "%s" "%s" "%s" %s' % ( + self.order, + self.preference, + dns.rdata._escapify(self.flags), + dns.rdata._escapify(self.service), + dns.rdata._escapify(self.regexp), + replacement, + ) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): order = tok.get_uint16() preference = tok.get_uint16() flags = tok.get_string() service = tok.get_string() regexp = tok.get_string() replacement = tok.get_name(origin, relativize, relativize_to) - return cls(rdclass, rdtype, order, preference, flags, service, - regexp, replacement) + return cls( + rdclass, rdtype, order, preference, flags, service, regexp, replacement + ) def _to_wire(self, file, compress=None, origin=None, canonicalize=False): two_ints = struct.pack("!HH", self.order, self.preference) @@ -82,14 +86,22 @@ class NAPTR(dns.rdata.Rdata): @classmethod def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): - (order, preference) = parser.get_struct('!HH') + (order, preference) = parser.get_struct("!HH") strings = [] for _ in range(3): s = parser.get_counted_bytes() strings.append(s) replacement = parser.get_name(origin) - return cls(rdclass, rdtype, order, preference, strings[0], strings[1], - strings[2], replacement) + return cls( + rdclass, + rdtype, + order, + preference, + strings[0], + strings[1], + strings[2], + replacement, + ) def _processing_priority(self): return (self.order, self.preference) diff --git a/lib/dns/rdtypes/IN/NSAP.py b/lib/dns/rdtypes/IN/NSAP.py index 23ae9b1a..be8581e6 100644 --- a/lib/dns/rdtypes/IN/NSAP.py +++ b/lib/dns/rdtypes/IN/NSAP.py @@ -30,7 +30,7 @@ class NSAP(dns.rdata.Rdata): # see: RFC 1706 - __slots__ = ['address'] + __slots__ = ["address"] def __init__(self, rdclass, rdtype, address): super().__init__(rdclass, rdtype) @@ -40,14 +40,15 @@ class NSAP(dns.rdata.Rdata): return "0x%s" % binascii.hexlify(self.address).decode() @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): address = tok.get_string() - if address[0:2] != '0x': - raise dns.exception.SyntaxError('string does not start with 0x') - address = address[2:].replace('.', '') + if address[0:2] != "0x": + raise dns.exception.SyntaxError("string does not start with 0x") + address = address[2:].replace(".", "") if len(address) % 2 != 0: - raise dns.exception.SyntaxError('hexstring has odd length') + raise dns.exception.SyntaxError("hexstring has odd length") address = binascii.unhexlify(address.encode()) return cls(rdclass, rdtype, address) diff --git a/lib/dns/rdtypes/IN/PX.py b/lib/dns/rdtypes/IN/PX.py index 113d409c..b2216d6b 100644 --- a/lib/dns/rdtypes/IN/PX.py +++ b/lib/dns/rdtypes/IN/PX.py @@ -31,7 +31,7 @@ class PX(dns.rdata.Rdata): # see: RFC 2163 - __slots__ = ['preference', 'map822', 'mapx400'] + __slots__ = ["preference", "map822", "mapx400"] def __init__(self, rdclass, rdtype, preference, map822, mapx400): super().__init__(rdclass, rdtype) @@ -42,11 +42,12 @@ class PX(dns.rdata.Rdata): def to_text(self, origin=None, relativize=True, **kw): map822 = self.map822.choose_relativity(origin, relativize) mapx400 = self.mapx400.choose_relativity(origin, relativize) - return '%d %s %s' % (self.preference, map822, mapx400) + return "%d %s %s" % (self.preference, map822, mapx400) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): preference = tok.get_uint16() map822 = tok.get_name(origin, relativize, relativize_to) mapx400 = tok.get_name(origin, relativize, relativize_to) diff --git a/lib/dns/rdtypes/IN/SRV.py b/lib/dns/rdtypes/IN/SRV.py index 5b5ff422..8b0b6bf7 100644 --- a/lib/dns/rdtypes/IN/SRV.py +++ b/lib/dns/rdtypes/IN/SRV.py @@ -31,7 +31,7 @@ class SRV(dns.rdata.Rdata): # see: RFC 2782 - __slots__ = ['priority', 'weight', 'port', 'target'] + __slots__ = ["priority", "weight", "port", "target"] def __init__(self, rdclass, rdtype, priority, weight, port, target): super().__init__(rdclass, rdtype) @@ -42,12 +42,12 @@ class SRV(dns.rdata.Rdata): def to_text(self, origin=None, relativize=True, **kw): target = self.target.choose_relativity(origin, relativize) - return '%d %d %d %s' % (self.priority, self.weight, self.port, - target) + return "%d %d %d %s" % (self.priority, self.weight, self.port, target) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): priority = tok.get_uint16() weight = tok.get_uint16() port = tok.get_uint16() @@ -61,7 +61,7 @@ class SRV(dns.rdata.Rdata): @classmethod def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): - (priority, weight, port) = parser.get_struct('!HHH') + (priority, weight, port) = parser.get_struct("!HHH") target = parser.get_name(origin) return cls(rdclass, rdtype, priority, weight, port, target) diff --git a/lib/dns/rdtypes/IN/SVCB.py b/lib/dns/rdtypes/IN/SVCB.py index 14838e16..9a1ad101 100644 --- a/lib/dns/rdtypes/IN/SVCB.py +++ b/lib/dns/rdtypes/IN/SVCB.py @@ -3,6 +3,7 @@ import dns.rdtypes.svcbbase import dns.immutable + @dns.immutable.immutable class SVCB(dns.rdtypes.svcbbase.SVCBBase): """SVCB record""" diff --git a/lib/dns/rdtypes/IN/WKS.py b/lib/dns/rdtypes/IN/WKS.py index 264e45d3..a671e203 100644 --- a/lib/dns/rdtypes/IN/WKS.py +++ b/lib/dns/rdtypes/IN/WKS.py @@ -23,13 +23,14 @@ import dns.immutable import dns.rdata try: - _proto_tcp = socket.getprotobyname('tcp') - _proto_udp = socket.getprotobyname('udp') + _proto_tcp = socket.getprotobyname("tcp") + _proto_udp = socket.getprotobyname("udp") except OSError: # Fall back to defaults in case /etc/protocols is unavailable. _proto_tcp = 6 _proto_udp = 17 + @dns.immutable.immutable class WKS(dns.rdata.Rdata): @@ -37,7 +38,7 @@ class WKS(dns.rdata.Rdata): # see: RFC 1035 - __slots__ = ['address', 'protocol', 'bitmap'] + __slots__ = ["address", "protocol", "bitmap"] def __init__(self, rdclass, rdtype, address, protocol, bitmap): super().__init__(rdclass, rdtype) @@ -51,12 +52,13 @@ class WKS(dns.rdata.Rdata): for j in range(0, 8): if byte & (0x80 >> j): bits.append(str(i * 8 + j)) - text = ' '.join(bits) - return '%s %d %s' % (self.address, self.protocol, text) + text = " ".join(bits) + return "%s %d %s" % (self.address, self.protocol, text) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): address = tok.get_string() protocol = tok.get_string() if protocol.isdigit(): @@ -87,7 +89,7 @@ class WKS(dns.rdata.Rdata): def _to_wire(self, file, compress=None, origin=None, canonicalize=False): file.write(dns.ipv4.inet_aton(self.address)) - protocol = struct.pack('!B', self.protocol) + protocol = struct.pack("!B", self.protocol) file.write(protocol) file.write(self.bitmap) diff --git a/lib/dns/rdtypes/IN/__init__.py b/lib/dns/rdtypes/IN/__init__.py index d51b99e7..dcec4dd2 100644 --- a/lib/dns/rdtypes/IN/__init__.py +++ b/lib/dns/rdtypes/IN/__init__.py @@ -18,18 +18,18 @@ """Class IN rdata type classes.""" __all__ = [ - 'A', - 'AAAA', - 'APL', - 'DHCID', - 'HTTPS', - 'IPSECKEY', - 'KX', - 'NAPTR', - 'NSAP', - 'NSAP_PTR', - 'PX', - 'SRV', - 'SVCB', - 'WKS', + "A", + "AAAA", + "APL", + "DHCID", + "HTTPS", + "IPSECKEY", + "KX", + "NAPTR", + "NSAP", + "NSAP_PTR", + "PX", + "SRV", + "SVCB", + "WKS", ] diff --git a/lib/dns/rdtypes/__init__.py b/lib/dns/rdtypes/__init__.py index c3af264e..3997f84c 100644 --- a/lib/dns/rdtypes/__init__.py +++ b/lib/dns/rdtypes/__init__.py @@ -18,16 +18,16 @@ """DNS rdata type classes""" __all__ = [ - 'ANY', - 'IN', - 'CH', - 'dnskeybase', - 'dsbase', - 'euibase', - 'mxbase', - 'nsbase', - 'svcbbase', - 'tlsabase', - 'txtbase', - 'util' + "ANY", + "IN", + "CH", + "dnskeybase", + "dsbase", + "euibase", + "mxbase", + "nsbase", + "svcbbase", + "tlsabase", + "txtbase", + "util", ] diff --git a/lib/dns/rdtypes/dnskeybase.py b/lib/dns/rdtypes/dnskeybase.py index 788bb2bf..1d17f70f 100644 --- a/lib/dns/rdtypes/dnskeybase.py +++ b/lib/dns/rdtypes/dnskeybase.py @@ -21,11 +21,12 @@ import struct import dns.exception import dns.immutable -import dns.dnssec +import dns.dnssectypes import dns.rdata # wildcard import -__all__ = ["SEP", "REVOKE", "ZONE"] # noqa: F822 +__all__ = ["SEP", "REVOKE", "ZONE"] # noqa: F822 + class Flag(enum.IntFlag): SEP = 0x0001 @@ -38,22 +39,27 @@ class DNSKEYBase(dns.rdata.Rdata): """Base class for rdata that is like a DNSKEY record""" - __slots__ = ['flags', 'protocol', 'algorithm', 'key'] + __slots__ = ["flags", "protocol", "algorithm", "key"] def __init__(self, rdclass, rdtype, flags, protocol, algorithm, key): super().__init__(rdclass, rdtype) self.flags = self._as_uint16(flags) self.protocol = self._as_uint8(protocol) - self.algorithm = dns.dnssec.Algorithm.make(algorithm) + self.algorithm = dns.dnssectypes.Algorithm.make(algorithm) self.key = self._as_bytes(key) def to_text(self, origin=None, relativize=True, **kw): - return '%d %d %d %s' % (self.flags, self.protocol, self.algorithm, - dns.rdata._base64ify(self.key, **kw)) + return "%d %d %d %s" % ( + self.flags, + self.protocol, + self.algorithm, + dns.rdata._base64ify(self.key, **kw), + ) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): flags = tok.get_uint16() protocol = tok.get_uint8() algorithm = tok.get_string() @@ -68,10 +74,10 @@ class DNSKEYBase(dns.rdata.Rdata): @classmethod def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): - header = parser.get_struct('!HBB') + header = parser.get_struct("!HBB") key = parser.get_remaining() - return cls(rdclass, rdtype, header[0], header[1], header[2], - key) + return cls(rdclass, rdtype, header[0], header[1], header[2], key) + ### BEGIN generated Flag constants diff --git a/lib/dns/rdtypes/dnskeybase.pyi b/lib/dns/rdtypes/dnskeybase.pyi deleted file mode 100644 index 1b999cfd..00000000 --- a/lib/dns/rdtypes/dnskeybase.pyi +++ /dev/null @@ -1,38 +0,0 @@ -from typing import Set, Any - -SEP : int -REVOKE : int -ZONE : int - -def flags_to_text_set(flags : int) -> Set[str]: - ... - -def flags_from_text_set(texts_set) -> int: - ... - -from .. import rdata - -class DNSKEYBase(rdata.Rdata): - def __init__(self, rdclass, rdtype, flags, protocol, algorithm, key): - self.flags : int - self.protocol : int - self.key : str - self.algorithm : int - - def to_text(self, origin : Any = None, relativize=True, **kw : Any): - ... - - @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): - ... - - def _to_wire(self, file, compress=None, origin=None, canonicalize=False): - ... - - @classmethod - def from_parser(cls, rdclass, rdtype, parser, origin=None): - ... - - def flags_to_text_set(self) -> Set[str]: - ... diff --git a/lib/dns/rdtypes/dsbase.py b/lib/dns/rdtypes/dsbase.py index 0c2e7471..b6032b0f 100644 --- a/lib/dns/rdtypes/dsbase.py +++ b/lib/dns/rdtypes/dsbase.py @@ -18,7 +18,7 @@ import struct import binascii -import dns.dnssec +import dns.dnssectypes import dns.immutable import dns.rdata import dns.rdatatype @@ -29,9 +29,10 @@ class DSBase(dns.rdata.Rdata): """Base class for rdata that is like a DS record""" - __slots__ = ['key_tag', 'algorithm', 'digest_type', 'digest'] + __slots__ = ["key_tag", "algorithm", "digest_type", "digest"] - # Digest types registry: https://www.iana.org/assignments/ds-rr-types/ds-rr-types.xhtml + # Digest types registry: + # https://www.iana.org/assignments/ds-rr-types/ds-rr-types.xhtml _digest_length_by_type = { 1: 20, # SHA-1, RFC 3658 Sec. 2.4 2: 32, # SHA-256, RFC 4509 Sec. 2.2 @@ -39,43 +40,42 @@ class DSBase(dns.rdata.Rdata): 4: 48, # SHA-384, RFC 6605 Sec. 2 } - def __init__(self, rdclass, rdtype, key_tag, algorithm, digest_type, - digest): + def __init__(self, rdclass, rdtype, key_tag, algorithm, digest_type, digest): super().__init__(rdclass, rdtype) self.key_tag = self._as_uint16(key_tag) - self.algorithm = dns.dnssec.Algorithm.make(algorithm) + self.algorithm = dns.dnssectypes.Algorithm.make(algorithm) self.digest_type = self._as_uint8(digest_type) self.digest = self._as_bytes(digest) try: if len(self.digest) != self._digest_length_by_type[self.digest_type]: - raise ValueError('digest length inconsistent with digest type') + raise ValueError("digest length inconsistent with digest type") except KeyError: if self.digest_type == 0: # reserved, RFC 3658 Sec. 2.4 - raise ValueError('digest type 0 is reserved') + raise ValueError("digest type 0 is reserved") def to_text(self, origin=None, relativize=True, **kw): kw = kw.copy() - chunksize = kw.pop('chunksize', 128) - return '%d %d %d %s' % (self.key_tag, self.algorithm, - self.digest_type, - dns.rdata._hexify(self.digest, - chunksize=chunksize, - **kw)) + chunksize = kw.pop("chunksize", 128) + return "%d %d %d %s" % ( + self.key_tag, + self.algorithm, + self.digest_type, + dns.rdata._hexify(self.digest, chunksize=chunksize, **kw), + ) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): key_tag = tok.get_uint16() algorithm = tok.get_string() digest_type = tok.get_uint8() digest = tok.concatenate_remaining_identifiers().encode() digest = binascii.unhexlify(digest) - return cls(rdclass, rdtype, key_tag, algorithm, digest_type, - digest) + return cls(rdclass, rdtype, key_tag, algorithm, digest_type, digest) def _to_wire(self, file, compress=None, origin=None, canonicalize=False): - header = struct.pack("!HBB", self.key_tag, self.algorithm, - self.digest_type) + header = struct.pack("!HBB", self.key_tag, self.algorithm, self.digest_type) file.write(header) file.write(self.digest) diff --git a/lib/dns/rdtypes/euibase.py b/lib/dns/rdtypes/euibase.py index 48b69bd3..e524aea9 100644 --- a/lib/dns/rdtypes/euibase.py +++ b/lib/dns/rdtypes/euibase.py @@ -27,7 +27,7 @@ class EUIBase(dns.rdata.Rdata): # see: rfc7043.txt - __slots__ = ['eui'] + __slots__ = ["eui"] # define these in subclasses # byte_len = 6 # 0123456789ab (in hex) # text_len = byte_len * 3 - 1 # 01-23-45-67-89-ab @@ -36,28 +36,30 @@ class EUIBase(dns.rdata.Rdata): super().__init__(rdclass, rdtype) self.eui = self._as_bytes(eui) if len(self.eui) != self.byte_len: - raise dns.exception.FormError('EUI%s rdata has to have %s bytes' - % (self.byte_len * 8, self.byte_len)) + raise dns.exception.FormError( + "EUI%s rdata has to have %s bytes" % (self.byte_len * 8, self.byte_len) + ) def to_text(self, origin=None, relativize=True, **kw): - return dns.rdata._hexify(self.eui, chunksize=2, separator=b'-', **kw) + return dns.rdata._hexify(self.eui, chunksize=2, separator=b"-", **kw) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): text = tok.get_string() if len(text) != cls.text_len: raise dns.exception.SyntaxError( - 'Input text must have %s characters' % cls.text_len) + "Input text must have %s characters" % cls.text_len + ) for i in range(2, cls.byte_len * 3 - 1, 3): - if text[i] != '-': - raise dns.exception.SyntaxError('Dash expected at position %s' - % i) - text = text.replace('-', '') + if text[i] != "-": + raise dns.exception.SyntaxError("Dash expected at position %s" % i) + text = text.replace("-", "") try: data = binascii.unhexlify(text.encode()) except (ValueError, TypeError) as ex: - raise dns.exception.SyntaxError('Hex decoding error: %s' % str(ex)) + raise dns.exception.SyntaxError("Hex decoding error: %s" % str(ex)) return cls(rdclass, rdtype, data) def _to_wire(self, file, compress=None, origin=None, canonicalize=False): diff --git a/lib/dns/rdtypes/mxbase.py b/lib/dns/rdtypes/mxbase.py index 56418234..b4b9b088 100644 --- a/lib/dns/rdtypes/mxbase.py +++ b/lib/dns/rdtypes/mxbase.py @@ -31,7 +31,7 @@ class MXBase(dns.rdata.Rdata): """Base class for rdata that is like an MX record.""" - __slots__ = ['preference', 'exchange'] + __slots__ = ["preference", "exchange"] def __init__(self, rdclass, rdtype, preference, exchange): super().__init__(rdclass, rdtype) @@ -40,11 +40,12 @@ class MXBase(dns.rdata.Rdata): def to_text(self, origin=None, relativize=True, **kw): exchange = self.exchange.choose_relativity(origin, relativize) - return '%d %s' % (self.preference, exchange) + return "%d %s" % (self.preference, exchange) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): preference = tok.get_uint16() exchange = tok.get_name(origin, relativize, relativize_to) return cls(rdclass, rdtype, preference, exchange) diff --git a/lib/dns/rdtypes/nsbase.py b/lib/dns/rdtypes/nsbase.py index b3e25506..ba7a2ab7 100644 --- a/lib/dns/rdtypes/nsbase.py +++ b/lib/dns/rdtypes/nsbase.py @@ -28,7 +28,7 @@ class NSBase(dns.rdata.Rdata): """Base class for rdata that is like an NS record.""" - __slots__ = ['target'] + __slots__ = ["target"] def __init__(self, rdclass, rdtype, target): super().__init__(rdclass, rdtype) @@ -39,8 +39,9 @@ class NSBase(dns.rdata.Rdata): return str(target) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): target = tok.get_name(origin, relativize, relativize_to) return cls(rdclass, rdtype, target) diff --git a/lib/dns/rdtypes/svcbbase.py b/lib/dns/rdtypes/svcbbase.py index 3362571c..8d6fb1c6 100644 --- a/lib/dns/rdtypes/svcbbase.py +++ b/lib/dns/rdtypes/svcbbase.py @@ -63,44 +63,48 @@ def _validate_key(key): if isinstance(key, bytes): # We decode to latin-1 so we get 0-255 as valid and do NOT interpret # UTF-8 sequences - key = key.decode('latin-1') + key = key.decode("latin-1") if isinstance(key, str): - if key.lower().startswith('key'): + if key.lower().startswith("key"): force_generic = True - if key[3:].startswith('0') and len(key) != 4: + if key[3:].startswith("0") and len(key) != 4: # key has leading zeros - raise ValueError('leading zeros in key') - key = key.replace('-', '_') + raise ValueError("leading zeros in key") + key = key.replace("-", "_") return (ParamKey.make(key), force_generic) + def key_to_text(key): - return ParamKey.to_text(key).replace('_', '-').lower() + return ParamKey.to_text(key).replace("_", "-").lower() + # Like rdata escapify, but escapes ',' too. _escaped = b'",\\' + def _escapify(qstring): - text = '' + text = "" for c in qstring: if c in _escaped: - text += '\\' + chr(c) + text += "\\" + chr(c) elif c >= 0x20 and c < 0x7F: text += chr(c) else: - text += '\\%03d' % c + text += "\\%03d" % c return text + def _unescape(value): - if value == '': + if value == "": return value - unescaped = b'' + unescaped = b"" l = len(value) i = 0 while i < l: c = value[i] i += 1 - if c == '\\': + if c == "\\": if i >= l: # pragma: no cover (can't happen via tokenizer get()) raise dns.exception.UnexpectedEnd c = value[i] @@ -119,7 +123,7 @@ def _unescape(value): codepoint = int(c) * 100 + int(c2) * 10 + int(c3) if codepoint > 255: raise dns.exception.SyntaxError - unescaped += b'%c' % (codepoint) + unescaped += b"%c" % (codepoint) continue unescaped += c.encode() return unescaped @@ -129,21 +133,21 @@ def _split(value): l = len(value) i = 0 items = [] - unescaped = b'' + unescaped = b"" while i < l: c = value[i] i += 1 - if c == ord('\\'): + if c == ord("\\"): if i >= l: # pragma: no cover (can't happen via tokenizer get()) raise dns.exception.UnexpectedEnd c = value[i] i += 1 - unescaped += b'%c' % (c) - elif c == ord(','): + unescaped += b"%c" % (c) + elif c == ord(","): items.append(unescaped) - unescaped = b'' + unescaped = b"" else: - unescaped += b'%c' % (c) + unescaped += b"%c" % (c) items.append(unescaped) return items @@ -159,8 +163,8 @@ class Param: @dns.immutable.immutable class GenericParam(Param): - """Generic SVCB parameter - """ + """Generic SVCB parameter""" + def __init__(self, value): self.value = dns.rdata.Rdata._as_bytes(value, True) @@ -198,19 +202,19 @@ class MandatoryParam(Param): prior_k = None for k in keys: if k == prior_k: - raise ValueError(f'duplicate key {k:d}') + raise ValueError(f"duplicate key {k:d}") prior_k = k if k == ParamKey.MANDATORY: - raise ValueError('listed the mandatory key as mandatory') + raise ValueError("listed the mandatory key as mandatory") self.keys = tuple(keys) @classmethod def from_value(cls, value): - keys = [k.encode() for k in value.split(',')] + keys = [k.encode() for k in value.split(",")] return cls(keys) def to_text(self): - return '"' + ','.join([key_to_text(key) for key in self.keys]) + '"' + return '"' + ",".join([key_to_text(key) for key in self.keys]) + '"' @classmethod def from_wire_parser(cls, parser, origin=None): # pylint: disable=W0613 @@ -219,28 +223,29 @@ class MandatoryParam(Param): while parser.remaining() > 0: key = parser.get_uint16() if key < last_key: - raise dns.exception.FormError('manadatory keys not ascending') + raise dns.exception.FormError("manadatory keys not ascending") last_key = key keys.append(key) return cls(keys) def to_wire(self, file, origin=None): # pylint: disable=W0613 for key in self.keys: - file.write(struct.pack('!H', key)) + file.write(struct.pack("!H", key)) @dns.immutable.immutable class ALPNParam(Param): def __init__(self, ids): self.ids = dns.rdata.Rdata._as_tuple( - ids, lambda x: dns.rdata.Rdata._as_bytes(x, True, 255, False)) + ids, lambda x: dns.rdata.Rdata._as_bytes(x, True, 255, False) + ) @classmethod def from_value(cls, value): return cls(_split(_unescape(value))) def to_text(self): - value = ','.join([_escapify(id) for id in self.ids]) + value = ",".join([_escapify(id) for id in self.ids]) return '"' + dns.rdata._escapify(value.encode()) + '"' @classmethod @@ -253,7 +258,7 @@ class ALPNParam(Param): def to_wire(self, file, origin=None): # pylint: disable=W0613 for id in self.ids: - file.write(struct.pack('!B', len(id))) + file.write(struct.pack("!B", len(id))) file.write(id) @@ -269,10 +274,10 @@ class NoDefaultALPNParam(Param): @classmethod def from_value(cls, value): - if value is None or value == '': + if value is None or value == "": return None else: - raise ValueError('no-default-alpn with non-empty value') + raise ValueError("no-default-alpn with non-empty value") def to_text(self): raise NotImplementedError # pragma: no cover @@ -306,22 +311,23 @@ class PortParam(Param): return cls(port) def to_wire(self, file, origin=None): # pylint: disable=W0613 - file.write(struct.pack('!H', self.port)) + file.write(struct.pack("!H", self.port)) @dns.immutable.immutable class IPv4HintParam(Param): def __init__(self, addresses): self.addresses = dns.rdata.Rdata._as_tuple( - addresses, dns.rdata.Rdata._as_ipv4_address) + addresses, dns.rdata.Rdata._as_ipv4_address + ) @classmethod def from_value(cls, value): - addresses = value.split(',') + addresses = value.split(",") return cls(addresses) def to_text(self): - return '"' + ','.join(self.addresses) + '"' + return '"' + ",".join(self.addresses) + '"' @classmethod def from_wire_parser(cls, parser, origin=None): # pylint: disable=W0613 @@ -340,15 +346,16 @@ class IPv4HintParam(Param): class IPv6HintParam(Param): def __init__(self, addresses): self.addresses = dns.rdata.Rdata._as_tuple( - addresses, dns.rdata.Rdata._as_ipv6_address) + addresses, dns.rdata.Rdata._as_ipv6_address + ) @classmethod def from_value(cls, value): - addresses = value.split(',') + addresses = value.split(",") return cls(addresses) def to_text(self): - return '"' + ','.join(self.addresses) + '"' + return '"' + ",".join(self.addresses) + '"' @classmethod def from_wire_parser(cls, parser, origin=None): # pylint: disable=W0613 @@ -370,13 +377,13 @@ class ECHParam(Param): @classmethod def from_value(cls, value): - if '\\' in value: - raise ValueError('escape in ECH value') + if "\\" in value: + raise ValueError("escape in ECH value") value = base64.b64decode(value.encode()) return cls(value) def to_text(self): - b64 = base64.b64encode(self.ech).decode('ascii') + b64 = base64.b64encode(self.ech).decode("ascii") return f'"{b64}"' @classmethod @@ -407,7 +414,7 @@ def _validate_and_define(params, key, value): emptiness = cls.emptiness() if value is None: if emptiness == Emptiness.NEVER: - raise SyntaxError('value cannot be empty') + raise SyntaxError("value cannot be empty") value = cls.from_value(value) else: if force_generic: @@ -422,9 +429,9 @@ class SVCBBase(dns.rdata.Rdata): """Base class for SVCB-like records""" - # see: draft-ietf-dnsop-svcb-https-01 + # see: draft-ietf-dnsop-svcb-https-11 - __slots__ = ['priority', 'target', 'params'] + __slots__ = ["priority", "target", "params"] def __init__(self, rdclass, rdtype, priority, target, params): super().__init__(rdclass, rdtype) @@ -433,7 +440,7 @@ class SVCBBase(dns.rdata.Rdata): for k, v in params.items(): k = ParamKey.make(k) if not isinstance(v, Param) and v is not None: - raise ValueError("not a Param") + raise ValueError(f"{k:d} not a Param") self.params = dns.immutable.Dict(params) # Make sure any parameter listed as mandatory is present in the # record. @@ -443,12 +450,11 @@ class SVCBBase(dns.rdata.Rdata): # Note we have to say "not in" as we have None as a value # so a get() and a not None test would be wrong. if key not in params: - raise ValueError(f'key {key:d} declared mandatory but not ' - 'present') + raise ValueError(f"key {key:d} declared mandatory but not present") # The no-default-alpn parameter requires the alpn parameter. if ParamKey.NO_DEFAULT_ALPN in params: if ParamKey.ALPN not in params: - raise ValueError('no-default-alpn present, but alpn missing') + raise ValueError("no-default-alpn present, but alpn missing") def to_text(self, origin=None, relativize=True, **kw): target = self.target.choose_relativity(origin, relativize) @@ -458,23 +464,24 @@ class SVCBBase(dns.rdata.Rdata): if value is None: params.append(key_to_text(key)) else: - kv = key_to_text(key) + '=' + value.to_text() + kv = key_to_text(key) + "=" + value.to_text() params.append(kv) if len(params) > 0: - space = ' ' + space = " " else: - space = '' - return '%d %s%s%s' % (self.priority, target, space, ' '.join(params)) + space = "" + return "%d %s%s%s" % (self.priority, target, space, " ".join(params)) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): priority = tok.get_uint16() target = tok.get_name(origin, relativize, relativize_to) if priority == 0: token = tok.get() if not token.is_eol_or_eof(): - raise SyntaxError('parameters in AliasMode') + raise SyntaxError("parameters in AliasMode") tok.unget(token) params = {} while True: @@ -483,20 +490,20 @@ class SVCBBase(dns.rdata.Rdata): tok.unget(token) break if token.ttype != dns.tokenizer.IDENTIFIER: - raise SyntaxError('parameter is not an identifier') - equals = token.value.find('=') + raise SyntaxError("parameter is not an identifier") + equals = token.value.find("=") if equals == len(token.value) - 1: # 'key=', so next token should be a quoted string without # any intervening whitespace. key = token.value[:-1] token = tok.get(want_leading=True) if token.ttype != dns.tokenizer.QUOTED_STRING: - raise SyntaxError('whitespace after =') + raise SyntaxError("whitespace after =") value = token.value elif equals > 0: # key=value key = token.value[:equals] - value = token.value[equals + 1:] + value = token.value[equals + 1 :] elif equals == 0: # =key raise SyntaxError('parameter cannot start with "="') @@ -532,13 +539,13 @@ class SVCBBase(dns.rdata.Rdata): priority = parser.get_uint16() target = parser.get_name(origin) if priority == 0 and parser.remaining() != 0: - raise dns.exception.FormError('parameters in AliasMode') + raise dns.exception.FormError("parameters in AliasMode") params = {} prior_key = -1 while parser.remaining() > 0: key = parser.get_uint16() if key < prior_key: - raise dns.exception.FormError('keys not in order') + raise dns.exception.FormError("keys not in order") prior_key = key vlen = parser.get_uint16() pcls = _class_for_key.get(key, GenericParam) diff --git a/lib/dns/rdtypes/tlsabase.py b/lib/dns/rdtypes/tlsabase.py index 786fca55..a3fdc354 100644 --- a/lib/dns/rdtypes/tlsabase.py +++ b/lib/dns/rdtypes/tlsabase.py @@ -30,10 +30,9 @@ class TLSABase(dns.rdata.Rdata): # see: RFC 6698 - __slots__ = ['usage', 'selector', 'mtype', 'cert'] + __slots__ = ["usage", "selector", "mtype", "cert"] - def __init__(self, rdclass, rdtype, usage, selector, - mtype, cert): + def __init__(self, rdclass, rdtype, usage, selector, mtype, cert): super().__init__(rdclass, rdtype) self.usage = self._as_uint8(usage) self.selector = self._as_uint8(selector) @@ -42,17 +41,18 @@ class TLSABase(dns.rdata.Rdata): def to_text(self, origin=None, relativize=True, **kw): kw = kw.copy() - chunksize = kw.pop('chunksize', 128) - return '%d %d %d %s' % (self.usage, - self.selector, - self.mtype, - dns.rdata._hexify(self.cert, - chunksize=chunksize, - **kw)) + chunksize = kw.pop("chunksize", 128) + return "%d %d %d %s" % ( + self.usage, + self.selector, + self.mtype, + dns.rdata._hexify(self.cert, chunksize=chunksize, **kw), + ) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): usage = tok.get_uint8() selector = tok.get_uint8() mtype = tok.get_uint8() diff --git a/lib/dns/rdtypes/txtbase.py b/lib/dns/rdtypes/txtbase.py index 68071ee0..d4cb9bb2 100644 --- a/lib/dns/rdtypes/txtbase.py +++ b/lib/dns/rdtypes/txtbase.py @@ -17,6 +17,8 @@ """TXT-like base class.""" +from typing import Any, Dict, Iterable, Optional, Tuple, Union + import struct import dns.exception @@ -30,9 +32,14 @@ class TXTBase(dns.rdata.Rdata): """Base class for rdata that is like a TXT record (see RFC 1035).""" - __slots__ = ['strings'] + __slots__ = ["strings"] - def __init__(self, rdclass, rdtype, strings): + def __init__( + self, + rdclass: dns.rdataclass.RdataClass, + rdtype: dns.rdatatype.RdataType, + strings: Iterable[Union[bytes, str]], + ): """Initialize a TXT-like rdata. *rdclass*, an ``int`` is the rdataclass of the Rdata. @@ -42,27 +49,41 @@ class TXTBase(dns.rdata.Rdata): *strings*, a tuple of ``bytes`` """ super().__init__(rdclass, rdtype) - self.strings = self._as_tuple(strings, - lambda x: self._as_bytes(x, True, 255)) + self.strings: Tuple[bytes] = self._as_tuple( + strings, lambda x: self._as_bytes(x, True, 255) + ) - def to_text(self, origin=None, relativize=True, **kw): - txt = '' - prefix = '' + def to_text( + self, + origin: Optional[dns.name.Name] = None, + relativize: bool = True, + **kw: Dict[str, Any] + ) -> str: + txt = "" + prefix = "" for s in self.strings: txt += '{}"{}"'.format(prefix, dns.rdata._escapify(s)) - prefix = ' ' + prefix = " " return txt @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, + rdclass: dns.rdataclass.RdataClass, + rdtype: dns.rdatatype.RdataType, + tok: dns.tokenizer.Tokenizer, + origin: Optional[dns.name.Name] = None, + relativize: bool = True, + relativize_to: Optional[dns.name.Name] = None, + ) -> dns.rdata.Rdata: strings = [] for token in tok.get_remaining(): token = token.unescape_to_bytes() # The 'if' below is always true in the current code, but we # are leaving this check in in case things change some day. - if not (token.is_quoted_string() or - token.is_identifier()): # pragma: no cover + if not ( + token.is_quoted_string() or token.is_identifier() + ): # pragma: no cover raise dns.exception.SyntaxError("expected a string") if len(token.value) > 255: raise dns.exception.SyntaxError("string too long") @@ -75,7 +96,7 @@ class TXTBase(dns.rdata.Rdata): for s in self.strings: l = len(s) assert l < 256 - file.write(struct.pack('!B', l)) + file.write(struct.pack("!B", l)) file.write(s) @classmethod diff --git a/lib/dns/rdtypes/txtbase.pyi b/lib/dns/rdtypes/txtbase.pyi deleted file mode 100644 index f8d5df98..00000000 --- a/lib/dns/rdtypes/txtbase.pyi +++ /dev/null @@ -1,12 +0,0 @@ -import typing -from .. import rdata - -class TXTBase(rdata.Rdata): - strings: typing.Tuple[bytes, ...] - - def __init__(self, rdclass: int, rdtype: int, strings: typing.Iterable[bytes]) -> None: - ... - def to_text(self, origin: typing.Any, relativize: bool, **kw: typing.Any) -> str: - ... -class TXT(TXTBase): - ... diff --git a/lib/dns/rdtypes/util.py b/lib/dns/rdtypes/util.py index 9bf8f7e9..74596f05 100644 --- a/lib/dns/rdtypes/util.py +++ b/lib/dns/rdtypes/util.py @@ -28,6 +28,7 @@ import dns.rdata class Gateway: """A helper class for the IPSECKEY gateway and AMTRELAY relay fields""" + name = "" def __init__(self, type, gateway=None): @@ -67,15 +68,17 @@ class Gateway: raise ValueError(self._invalid_type(self.type)) # pragma: no cover @classmethod - def from_text(cls, gateway_type, tok, origin=None, relativize=True, - relativize_to=None): + def from_text( + cls, gateway_type, tok, origin=None, relativize=True, relativize_to=None + ): if gateway_type in (0, 1, 2): gateway = tok.get_string() elif gateway_type == 3: gateway = tok.get_name(origin, relativize, relativize_to) else: raise dns.exception.SyntaxError( - cls._invalid_type(gateway_type)) # pragma: no cover + cls._invalid_type(gateway_type) + ) # pragma: no cover return cls(gateway_type, gateway) # pylint: disable=unused-argument @@ -90,6 +93,7 @@ class Gateway: self.gateway.to_wire(file, None, origin, False) else: raise ValueError(self._invalid_type(self.type)) # pragma: no cover + # pylint: enable=unused-argument @classmethod @@ -109,6 +113,7 @@ class Gateway: class Bitmap: """A helper class for the NSEC/NSEC3/CSYNC type bitmaps""" + type_name = "" def __init__(self, windows=None): @@ -136,7 +141,7 @@ class Bitmap: if byte & (0x80 >> j): rdtype = window * 256 + i * 8 + j bits.append(dns.rdatatype.to_text(rdtype)) - text += (' ' + ' '.join(bits)) + text += " " + " ".join(bits) return text @classmethod @@ -151,7 +156,7 @@ class Bitmap: window = 0 octets = 0 prior_rdtype = 0 - bitmap = bytearray(b'\0' * 32) + bitmap = bytearray(b"\0" * 32) windows = [] for rdtype in rdtypes: if rdtype == prior_rdtype: @@ -161,7 +166,7 @@ class Bitmap: if new_window != window: if octets != 0: windows.append((window, bytes(bitmap[0:octets]))) - bitmap = bytearray(b'\0' * 32) + bitmap = bytearray(b"\0" * 32) window = new_window offset = rdtype % 256 byte = offset // 8 @@ -174,7 +179,7 @@ class Bitmap: def to_wire(self, file): for (window, bitmap) in self.windows: - file.write(struct.pack('!BB', window, len(bitmap))) + file.write(struct.pack("!BB", window, len(bitmap))) file.write(bitmap) @classmethod @@ -193,6 +198,7 @@ def _priority_table(items): by_priority[rdata._processing_priority()].append(rdata) return by_priority + def priority_processing_order(iterable): items = list(iterable) if len(items) == 1: @@ -205,8 +211,10 @@ def priority_processing_order(iterable): ordered.extend(rdatas) return ordered + _no_weight = 0.1 + def weighted_processing_order(iterable): items = list(iterable) if len(items) == 1: @@ -215,8 +223,7 @@ def weighted_processing_order(iterable): ordered = [] for k in sorted(by_priority.keys()): rdatas = by_priority[k] - total = sum(rdata._processing_weight() or _no_weight - for rdata in rdatas) + total = sum(rdata._processing_weight() or _no_weight for rdata in rdatas) while len(rdatas) > 1: r = random.uniform(0, total) for (n, rdata) in enumerate(rdatas): @@ -230,15 +237,16 @@ def weighted_processing_order(iterable): ordered.append(rdatas[0]) return ordered + def parse_formatted_hex(formatted, num_chunks, chunk_size, separator): if len(formatted) != num_chunks * (chunk_size + 1) - 1: - raise ValueError('invalid formatted hex string') - value = b'' + raise ValueError("invalid formatted hex string") + value = b"" for _ in range(num_chunks): chunk = formatted[0:chunk_size] - value += int(chunk, 16).to_bytes(chunk_size // 2, 'big') + value += int(chunk, 16).to_bytes(chunk_size // 2, "big") formatted = formatted[chunk_size:] if len(formatted) > 0 and formatted[0] != separator: - raise ValueError('invalid formatted hex string') + raise ValueError("invalid formatted hex string") formatted = formatted[1:] return value diff --git a/lib/dns/renderer.py b/lib/dns/renderer.py index 4e4391cd..3c495f61 100644 --- a/lib/dns/renderer.py +++ b/lib/dns/renderer.py @@ -48,13 +48,17 @@ class Renderer: r.add_rrset(dns.renderer.ANSWER, rrset_1) r.add_rrset(dns.renderer.ANSWER, rrset_2) r.add_rrset(dns.renderer.AUTHORITY, ns_rrset) - r.add_edns(0, 0, 4096) r.add_rrset(dns.renderer.ADDITIONAL, ad_rrset_1) r.add_rrset(dns.renderer.ADDITIONAL, ad_rrset_2) + r.add_edns(0, 0, 4096) r.write_header() r.add_tsig(keyname, secret, 300, 1, 0, '', request_mac) wire = r.get_wire() + If padding is going to be used, then the OPT record MUST be + written after everything else in the additional section except for + the TSIG (if any). + output, an io.BytesIO, where rendering is written id: the message id @@ -88,8 +92,10 @@ class Renderer: self.compress = {} self.section = QUESTION self.counts = [0, 0, 0, 0] - self.output.write(b'\x00' * 12) - self.mac = '' + self.output.write(b"\x00" * 12) + self.mac = "" + self.reserved = 0 + self.was_padded = False def _rollback(self, where): """Truncate the output buffer at offset *where*, and remove any @@ -160,21 +166,52 @@ class Renderer: self._set_section(section) with self._track_size(): - n = rdataset.to_wire(name, self.output, self.compress, self.origin, - **kw) + n = rdataset.to_wire(name, self.output, self.compress, self.origin, **kw) self.counts[section] += n + def add_opt(self, opt, pad=0, opt_size=0, tsig_size=0): + """Add *opt* to the additional section, applying padding if desired. The + padding will take the specified precomputed OPT size and TSIG size into + account. + + Note that we don't have reliable way of knowing how big a GSS-TSIG digest + might be, so we we might not get an even multiple of the pad in that case.""" + if pad: + ttl = opt.ttl + assert opt_size >= 11 + opt_rdata = opt[0] + size_without_padding = self.output.tell() + opt_size + tsig_size + remainder = size_without_padding % pad + if remainder: + pad = b"\x00" * (pad - remainder) + else: + pad = b"" + options = list(opt_rdata.options) + options.append(dns.edns.GenericOption(dns.edns.OptionType.PADDING, pad)) + opt = dns.message.Message._make_opt(ttl, opt_rdata.rdclass, options) + self.was_padded = True + self.add_rrset(ADDITIONAL, opt) + def add_edns(self, edns, ednsflags, payload, options=None): """Add an EDNS OPT record to the message.""" # make sure the EDNS version in ednsflags agrees with edns ednsflags &= 0xFF00FFFF - ednsflags |= (edns << 16) + ednsflags |= edns << 16 opt = dns.message.Message._make_opt(ednsflags, payload, options) - self.add_rrset(ADDITIONAL, opt) + self.add_opt(opt) - def add_tsig(self, keyname, secret, fudge, id, tsig_error, other_data, - request_mac, algorithm=dns.tsig.default_algorithm): + def add_tsig( + self, + keyname, + secret, + fudge, + id, + tsig_error, + other_data, + request_mac, + algorithm=dns.tsig.default_algorithm, + ): """Add a TSIG signature to the message.""" s = self.output.getvalue() @@ -183,15 +220,24 @@ class Renderer: key = secret else: key = dns.tsig.Key(keyname, secret, algorithm) - tsig = dns.message.Message._make_tsig(keyname, algorithm, 0, fudge, - b'', id, tsig_error, other_data) - (tsig, _) = dns.tsig.sign(s, key, tsig[0], int(time.time()), - request_mac) + tsig = dns.message.Message._make_tsig( + keyname, algorithm, 0, fudge, b"", id, tsig_error, other_data + ) + (tsig, _) = dns.tsig.sign(s, key, tsig[0], int(time.time()), request_mac) self._write_tsig(tsig, keyname) - def add_multi_tsig(self, ctx, keyname, secret, fudge, id, tsig_error, - other_data, request_mac, - algorithm=dns.tsig.default_algorithm): + def add_multi_tsig( + self, + ctx, + keyname, + secret, + fudge, + id, + tsig_error, + other_data, + request_mac, + algorithm=dns.tsig.default_algorithm, + ): """Add a TSIG signature to the message. Unlike add_tsig(), this can be used for a series of consecutive DNS envelopes, e.g. for a zone transfer over TCP [RFC2845, 4.4]. @@ -206,28 +252,35 @@ class Renderer: key = secret else: key = dns.tsig.Key(keyname, secret, algorithm) - tsig = dns.message.Message._make_tsig(keyname, algorithm, 0, fudge, - b'', id, tsig_error, other_data) - (tsig, ctx) = dns.tsig.sign(s, key, tsig[0], int(time.time()), - request_mac, ctx, True) + tsig = dns.message.Message._make_tsig( + keyname, algorithm, 0, fudge, b"", id, tsig_error, other_data + ) + (tsig, ctx) = dns.tsig.sign( + s, key, tsig[0], int(time.time()), request_mac, ctx, True + ) self._write_tsig(tsig, keyname) return ctx def _write_tsig(self, tsig, keyname): + if self.was_padded: + compress = None + else: + compress = self.compress self._set_section(ADDITIONAL) with self._track_size(): - keyname.to_wire(self.output, self.compress, self.origin) - self.output.write(struct.pack('!HHIH', dns.rdatatype.TSIG, - dns.rdataclass.ANY, 0, 0)) + keyname.to_wire(self.output, compress, self.origin) + self.output.write( + struct.pack("!HHIH", dns.rdatatype.TSIG, dns.rdataclass.ANY, 0, 0) + ) rdata_start = self.output.tell() tsig.to_wire(self.output) after = self.output.tell() self.output.seek(rdata_start - 2) - self.output.write(struct.pack('!H', after - rdata_start)) + self.output.write(struct.pack("!H", after - rdata_start)) self.counts[ADDITIONAL] += 1 self.output.seek(10) - self.output.write(struct.pack('!H', self.counts[ADDITIONAL])) + self.output.write(struct.pack("!H", self.counts[ADDITIONAL])) self.output.seek(0, io.SEEK_END) def write_header(self): @@ -239,12 +292,34 @@ class Renderer: """ self.output.seek(0) - self.output.write(struct.pack('!HHHHHH', self.id, self.flags, - self.counts[0], self.counts[1], - self.counts[2], self.counts[3])) + self.output.write( + struct.pack( + "!HHHHHH", + self.id, + self.flags, + self.counts[0], + self.counts[1], + self.counts[2], + self.counts[3], + ) + ) self.output.seek(0, io.SEEK_END) def get_wire(self): """Return the wire format message.""" return self.output.getvalue() + + def reserve(self, size: int) -> None: + """Reserve *size* bytes.""" + if size < 0: + raise ValueError("reserved amount must be non-negative") + if size > self.max_size: + raise ValueError("cannot reserve more than the maximum size") + self.reserved += size + self.max_size -= size + + def release_reserved(self) -> None: + """Release the reserved bytes.""" + self.max_size += self.reserved + self.reserved = 0 diff --git a/lib/dns/resolver.py b/lib/dns/resolver.py index 7da7a613..a5b66c1d 100644 --- a/lib/dns/resolver.py +++ b/lib/dns/resolver.py @@ -16,19 +16,20 @@ # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. """DNS stub resolver.""" + +from typing import Any, Dict, List, Optional, Tuple, Union + from urllib.parse import urlparse import contextlib import socket import sys +import threading import time import random import warnings -try: - import threading as _threading -except ImportError: # pragma: no cover - import dummy_threading as _threading # type: ignore import dns.exception +import dns.edns import dns.flags import dns.inet import dns.ipv4 @@ -42,18 +43,24 @@ import dns.rdatatype import dns.reversename import dns.tsig -if sys.platform == 'win32': +if sys.platform == "win32": import dns.win32util + class NXDOMAIN(dns.exception.DNSException): """The DNS query name does not exist.""" - supp_kwargs = {'qnames', 'responses'} + + supp_kwargs = {"qnames", "responses"} fmt = None # we have our own __str__ implementation # pylint: disable=arguments-differ - def _check_kwargs(self, qnames, - responses=None): + # We do this as otherwise mypy complains about unexpected keyword argument + # idna_exception + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def _check_kwargs(self, qnames, responses=None): if not isinstance(qnames, (list, tuple, set)): raise AttributeError("qnames must be a list, tuple or set") if len(qnames) == 0: @@ -66,23 +73,23 @@ class NXDOMAIN(dns.exception.DNSException): return kwargs def __str__(self): - if 'qnames' not in self.kwargs: + if "qnames" not in self.kwargs: return super().__str__() - qnames = self.kwargs['qnames'] + qnames = self.kwargs["qnames"] if len(qnames) > 1: - msg = 'None of DNS query names exist' + msg = "None of DNS query names exist" else: - msg = 'The DNS query name does not exist' - qnames = ', '.join(map(str, qnames)) + msg = "The DNS query name does not exist" + qnames = ", ".join(map(str, qnames)) return "{}: {}".format(msg, qnames) @property def canonical_name(self): """Return the unresolved canonical name.""" - if 'qnames' not in self.kwargs: + if "qnames" not in self.kwargs: raise TypeError("parametrized exception required") - for qname in self.kwargs['qnames']: - response = self.kwargs['responses'][qname] + for qname in self.kwargs["qnames"]: + response = self.kwargs["responses"][qname] try: cname = response.canonical_name() if cname != qname: @@ -91,14 +98,14 @@ class NXDOMAIN(dns.exception.DNSException): # We can just eat this exception as it means there was # something wrong with the response. pass - return self.kwargs['qnames'][0] + return self.kwargs["qnames"][0] def __add__(self, e_nx): """Augment by results from another NXDOMAIN exception.""" - qnames0 = list(self.kwargs.get('qnames', [])) - responses0 = dict(self.kwargs.get('responses', {})) - responses1 = e_nx.kwargs.get('responses', {}) - for qname1 in e_nx.kwargs.get('qnames', []): + qnames0 = list(self.kwargs.get("qnames", [])) + responses0 = dict(self.kwargs.get("responses", {})) + responses1 = e_nx.kwargs.get("responses", {}) + for qname1 in e_nx.kwargs.get("qnames", []): if qname1 not in qnames0: qnames0.append(qname1) if qname1 in responses1: @@ -110,7 +117,7 @@ class NXDOMAIN(dns.exception.DNSException): Returns a list of ``dns.name.Name``. """ - return self.kwargs['qnames'] + return self.kwargs["qnames"] def responses(self): """A map from queried names to their NXDOMAIN responses. @@ -118,26 +125,34 @@ class NXDOMAIN(dns.exception.DNSException): Returns a dict mapping a ``dns.name.Name`` to a ``dns.message.Message``. """ - return self.kwargs['responses'] + return self.kwargs["responses"] def response(self, qname): """The response for query *qname*. Returns a ``dns.message.Message``. """ - return self.kwargs['responses'][qname] + return self.kwargs["responses"][qname] class YXDOMAIN(dns.exception.DNSException): """The DNS query name is too long after DNAME substitution.""" -def _errors_to_text(errors): +ErrorTuple = Tuple[ + Optional[str], bool, int, Union[Exception, str], Optional[dns.message.Message] +] + + +def _errors_to_text(errors: List[ErrorTuple]) -> List[str]: """Turn a resolution errors trace into a list of text.""" texts = [] for err in errors: - texts.append('Server {} {} port {} answered {}'.format(err[0], - 'TCP' if err[1] else 'UDP', err[2], err[3])) + texts.append( + "Server {} {} port {} answered {}".format( + err[0], "TCP" if err[1] else "UDP", err[2], err[3] + ) + ) return texts @@ -146,12 +161,18 @@ class LifetimeTimeout(dns.exception.Timeout): msg = "The resolution lifetime expired." fmt = "%s after {timeout:.3f} seconds: {errors}" % msg[:-1] - supp_kwargs = {'timeout', 'errors'} + supp_kwargs = {"timeout", "errors"} + + # We do this as otherwise mypy complains about unexpected keyword argument + # idna_exception + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) def _fmt_kwargs(self, **kwargs): - srv_msgs = _errors_to_text(kwargs['errors']) - return super()._fmt_kwargs(timeout=kwargs['timeout'], - errors='; '.join(srv_msgs)) + srv_msgs = _errors_to_text(kwargs["errors"]) + return super()._fmt_kwargs( + timeout=kwargs["timeout"], errors="; ".join(srv_msgs) + ) # We added more detail to resolution timeouts, but they are still @@ -162,15 +183,20 @@ Timeout = LifetimeTimeout class NoAnswer(dns.exception.DNSException): """The DNS response does not contain an answer to the question.""" - fmt = 'The DNS response does not contain an answer ' + \ - 'to the question: {query}' - supp_kwargs = {'response'} + + fmt = "The DNS response does not contain an answer " + "to the question: {query}" + supp_kwargs = {"response"} + + # We do this as otherwise mypy complains about unexpected keyword argument + # idna_exception + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) def _fmt_kwargs(self, **kwargs): - return super()._fmt_kwargs(query=kwargs['response'].question) + return super()._fmt_kwargs(query=kwargs["response"].question) def response(self): - return self.kwargs['response'] + return self.kwargs["response"] class NoNameservers(dns.exception.DNSException): @@ -184,12 +210,18 @@ class NoNameservers(dns.exception.DNSException): msg = "All nameservers failed to answer the query." fmt = "%s {query}: {errors}" % msg[:-1] - supp_kwargs = {'request', 'errors'} + supp_kwargs = {"request", "errors"} + + # We do this as otherwise mypy complains about unexpected keyword argument + # idna_exception + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) def _fmt_kwargs(self, **kwargs): - srv_msgs = _errors_to_text(kwargs['errors']) - return super()._fmt_kwargs(query=kwargs['request'].question, - errors='; '.join(srv_msgs)) + srv_msgs = _errors_to_text(kwargs["errors"]) + return super()._fmt_kwargs( + query=kwargs["request"].question, errors="; ".join(srv_msgs) + ) class NotAbsolute(dns.exception.DNSException): @@ -203,9 +235,11 @@ class NoRootSOA(dns.exception.DNSException): class NoMetaqueries(dns.exception.DNSException): """DNS metaqueries are not allowed.""" + class NoResolverConfiguration(dns.exception.DNSException): """Resolver configuration could not be read or specified no nameservers.""" + class Answer: """DNS stub resolver answer. @@ -222,8 +256,15 @@ class Answer: RRset's name might not be the query name. """ - def __init__(self, qname, rdtype, rdclass, response, nameserver=None, - port=None): + def __init__( + self, + qname: dns.name.Name, + rdtype: dns.rdatatype.RdataType, + rdclass: dns.rdataclass.RdataClass, + response: dns.message.QueryMessage, + nameserver: Optional[str] = None, + port: Optional[int] = None, + ): self.qname = qname self.rdtype = rdtype self.rdclass = rdclass @@ -238,15 +279,15 @@ class Answer: self.expiration = time.time() + self.chaining_result.minimum_ttl def __getattr__(self, attr): # pragma: no cover - if attr == 'name': + if attr == "name": return self.rrset.name - elif attr == 'ttl': + elif attr == "ttl": return self.rrset.ttl - elif attr == 'covers': + elif attr == "covers": return self.rrset.covers - elif attr == 'rdclass': + elif attr == "rdclass": return self.rrset.rdclass - elif attr == 'rdtype': + elif attr == "rdtype": return self.rrset.rdtype else: raise AttributeError(attr) @@ -269,8 +310,7 @@ class Answer: class CacheStatistics: - """Cache Statistics - """ + """Cache Statistics""" def __init__(self, hits=0, misses=0): self.hits = hits @@ -280,31 +320,31 @@ class CacheStatistics: self.hits = 0 self.misses = 0 - def clone(self): + def clone(self) -> "CacheStatistics": return CacheStatistics(self.hits, self.misses) class CacheBase: def __init__(self): - self.lock = _threading.Lock() + self.lock = threading.Lock() self.statistics = CacheStatistics() - def reset_statistics(self): + def reset_statistics(self) -> None: """Reset all statistics to zero.""" with self.lock: self.statistics.reset() - def hits(self): + def hits(self) -> int: """How many hits has the cache had?""" with self.lock: return self.statistics.hits - def misses(self): + def misses(self) -> int: """How many misses has the cache had?""" with self.lock: return self.statistics.misses - def get_statistics_snapshot(self): + def get_statistics_snapshot(self) -> CacheStatistics: """Return a consistent snapshot of all the statistics. If running with multiple threads, it's better to take a @@ -315,20 +355,23 @@ class CacheBase: return self.statistics.clone() +CacheKey = Tuple[dns.name.Name, dns.rdatatype.RdataType, dns.rdataclass.RdataClass] + + class Cache(CacheBase): """Simple thread-safe DNS answer cache.""" - def __init__(self, cleaning_interval=300.0): + def __init__(self, cleaning_interval: float = 300.0): """*cleaning_interval*, a ``float`` is the number of seconds between periodic cleanings. """ super().__init__() - self.data = {} + self.data: Dict[CacheKey, Answer] = {} self.cleaning_interval = cleaning_interval - self.next_cleaning = time.time() + self.cleaning_interval + self.next_cleaning: float = time.time() + self.cleaning_interval - def _maybe_clean(self): + def _maybe_clean(self) -> None: """Clean the cache if it's time to do so.""" now = time.time() @@ -342,13 +385,13 @@ class Cache(CacheBase): now = time.time() self.next_cleaning = now + self.cleaning_interval - def get(self, key): + def get(self, key: CacheKey) -> Optional[Answer]: """Get the answer associated with *key*. Returns None if no answer is cached for the key. - *key*, a ``(dns.name.Name, int, int)`` tuple whose values are the - query name, rdtype, and rdclass respectively. + *key*, a ``(dns.name.Name, dns.rdatatype.RdataType, dns.rdataclass.RdataClass)`` + tuple whose values are the query name, rdtype, and rdclass respectively. Returns a ``dns.resolver.Answer`` or ``None``. """ @@ -362,11 +405,11 @@ class Cache(CacheBase): self.statistics.hits += 1 return v - def put(self, key, value): + def put(self, key: CacheKey, value: Answer) -> None: """Associate key and value in the cache. - *key*, a ``(dns.name.Name, int, int)`` tuple whose values are the - query name, rdtype, and rdclass respectively. + *key*, a ``(dns.name.Name, dns.rdatatype.RdataType, dns.rdataclass.RdataClass)`` + tuple whose values are the query name, rdtype, and rdclass respectively. *value*, a ``dns.resolver.Answer``, the answer. """ @@ -375,14 +418,14 @@ class Cache(CacheBase): self._maybe_clean() self.data[key] = value - def flush(self, key=None): + def flush(self, key: Optional[CacheKey] = None) -> None: """Flush the cache. - If *key* is not ``None``, only that item is flushed. Otherwise - the entire cache is flushed. + If *key* is not ``None``, only that item is flushed. Otherwise the entire cache + is flushed. - *key*, a ``(dns.name.Name, int, int)`` tuple whose values are the - query name, rdtype, and rdclass respectively. + *key*, a ``(dns.name.Name, dns.rdatatype.RdataType, dns.rdataclass.RdataClass)`` + tuple whose values are the query name, rdtype, and rdclass respectively. """ with self.lock: @@ -425,30 +468,30 @@ class LRUCache(CacheBase): for a new one. """ - def __init__(self, max_size=100000): + def __init__(self, max_size: int = 100000): """*max_size*, an ``int``, is the maximum number of nodes to cache; it must be greater than 0. """ super().__init__() - self.data = {} + self.data: Dict[CacheKey, LRUCacheNode] = {} self.set_max_size(max_size) - self.sentinel = LRUCacheNode(None, None) + self.sentinel: LRUCacheNode = LRUCacheNode(None, None) self.sentinel.prev = self.sentinel self.sentinel.next = self.sentinel - def set_max_size(self, max_size): + def set_max_size(self, max_size: int) -> None: if max_size < 1: max_size = 1 self.max_size = max_size - def get(self, key): + def get(self, key: CacheKey) -> Optional[Answer]: """Get the answer associated with *key*. Returns None if no answer is cached for the key. - *key*, a ``(dns.name.Name, int, int)`` tuple whose values are the - query name, rdtype, and rdclass respectively. + *key*, a ``(dns.name.Name, dns.rdatatype.RdataType, dns.rdataclass.RdataClass)`` + tuple whose values are the query name, rdtype, and rdclass respectively. Returns a ``dns.resolver.Answer`` or ``None``. """ @@ -470,7 +513,7 @@ class LRUCache(CacheBase): node.hits += 1 return node.value - def get_hits_for_key(self, key): + def get_hits_for_key(self, key: CacheKey) -> int: """Return the number of cache hits associated with the specified key.""" with self.lock: node = self.data.get(key) @@ -479,11 +522,11 @@ class LRUCache(CacheBase): else: return node.hits - def put(self, key, value): + def put(self, key: CacheKey, value: Answer) -> None: """Associate key and value in the cache. - *key*, a ``(dns.name.Name, int, int)`` tuple whose values are the - query name, rdtype, and rdclass respectively. + *key*, a ``(dns.name.Name, dns.rdatatype.RdataType, dns.rdataclass.RdataClass)`` + tuple whose values are the query name, rdtype, and rdclass respectively. *value*, a ``dns.resolver.Answer``, the answer. """ @@ -494,21 +537,21 @@ class LRUCache(CacheBase): node.unlink() del self.data[node.key] while len(self.data) >= self.max_size: - node = self.sentinel.prev - node.unlink() - del self.data[node.key] + gnode = self.sentinel.prev + gnode.unlink() + del self.data[gnode.key] node = LRUCacheNode(key, value) node.link_after(self.sentinel) self.data[key] = node - def flush(self, key=None): + def flush(self, key: Optional[CacheKey] = None) -> None: """Flush the cache. - If *key* is not ``None``, only that item is flushed. Otherwise - the entire cache is flushed. + If *key* is not ``None``, only that item is flushed. Otherwise the entire cache + is flushed. - *key*, a ``(dns.name.Name, int, int)`` tuple whose values are the - query name, rdtype, and rdclass respectively. + *key*, a ``(dns.name.Name, dns.rdatatype.RdataType, dns.rdataclass.RdataClass)`` + tuple whose values are the query name, rdtype, and rdclass respectively. """ with self.lock: @@ -518,13 +561,14 @@ class LRUCache(CacheBase): node.unlink() del self.data[node.key] else: - node = self.sentinel.next - while node != self.sentinel: - next = node.next - node.unlink() - node = next + gnode = self.sentinel.next + while gnode != self.sentinel: + next = gnode.next + gnode.unlink() + gnode = next self.data = {} + class _Resolution: """Helper class for dns.resolver.Resolver.resolve(). @@ -537,38 +581,47 @@ class _Resolution: resolver data structures directly. """ - def __init__(self, resolver, qname, rdtype, rdclass, tcp, - raise_on_no_answer, search): + def __init__( + self, + resolver: "BaseResolver", + qname: Union[dns.name.Name, str], + rdtype: Union[dns.rdatatype.RdataType, str], + rdclass: Union[dns.rdataclass.RdataClass, str], + tcp: bool, + raise_on_no_answer: bool, + search: Optional[bool], + ): if isinstance(qname, str): qname = dns.name.from_text(qname, None) - rdtype = dns.rdatatype.RdataType.make(rdtype) - if dns.rdatatype.is_metatype(rdtype): + the_rdtype = dns.rdatatype.RdataType.make(rdtype) + if dns.rdatatype.is_metatype(the_rdtype): raise NoMetaqueries - rdclass = dns.rdataclass.RdataClass.make(rdclass) - if dns.rdataclass.is_metaclass(rdclass): + the_rdclass = dns.rdataclass.RdataClass.make(rdclass) + if dns.rdataclass.is_metaclass(the_rdclass): raise NoMetaqueries self.resolver = resolver self.qnames_to_try = resolver._get_qnames_to_try(qname, search) self.qnames = self.qnames_to_try[:] - self.rdtype = rdtype - self.rdclass = rdclass + self.rdtype = the_rdtype + self.rdclass = the_rdclass self.tcp = tcp self.raise_on_no_answer = raise_on_no_answer - self.nxdomain_responses = {} - # + self.nxdomain_responses: Dict[dns.name.Name, dns.message.QueryMessage] = {} # Initialize other things to help analysis tools self.qname = dns.name.empty - self.nameservers = [] - self.current_nameservers = [] - self.errors = [] - self.nameserver = None + self.nameservers: List[str] = [] + self.current_nameservers: List[str] = [] + self.errors: List[ErrorTuple] = [] + self.nameserver: Optional[str] = None self.port = 0 self.tcp_attempt = False self.retry_with_tcp = False - self.request = None - self.backoff = 0 + self.request: Optional[dns.message.QueryMessage] = None + self.backoff = 0.0 - def next_request(self): + def next_request( + self, + ) -> Tuple[Optional[dns.message.QueryMessage], Optional[Answer]]: """Get the next request to send, and check the cache. Returns a (request, answer) tuple. At most one of request or @@ -583,31 +636,37 @@ class _Resolution: # Do we know the answer? if self.resolver.cache: - answer = self.resolver.cache.get((self.qname, self.rdtype, - self.rdclass)) + answer = self.resolver.cache.get( + (self.qname, self.rdtype, self.rdclass) + ) if answer is not None: if answer.rrset is None and self.raise_on_no_answer: raise NoAnswer(response=answer.response) else: return (None, answer) - answer = self.resolver.cache.get((self.qname, - dns.rdatatype.ANY, - self.rdclass)) - if answer is not None and \ - answer.response.rcode() == dns.rcode.NXDOMAIN: + answer = self.resolver.cache.get( + (self.qname, dns.rdatatype.ANY, self.rdclass) + ) + if answer is not None and answer.response.rcode() == dns.rcode.NXDOMAIN: # cached NXDOMAIN; record it and continue to next # name. self.nxdomain_responses[self.qname] = answer.response continue # Build the request - request = dns.message.make_query(self.qname, self.rdtype, - self.rdclass) + request = dns.message.make_query(self.qname, self.rdtype, self.rdclass) if self.resolver.keyname is not None: - request.use_tsig(self.resolver.keyring, self.resolver.keyname, - algorithm=self.resolver.keyalgorithm) - request.use_edns(self.resolver.edns, self.resolver.ednsflags, - self.resolver.payload) + request.use_tsig( + self.resolver.keyring, + self.resolver.keyname, + algorithm=self.resolver.keyalgorithm, + ) + request.use_edns( + self.resolver.edns, + self.resolver.ednsflags, + self.resolver.payload, + options=self.resolver.ednsoptions, + ) if self.resolver.flags is not None: request.flags = self.resolver.flags @@ -629,17 +688,16 @@ class _Resolution: # it's only NXDOMAINs as anything else would have returned # before now.) # - raise NXDOMAIN(qnames=self.qnames_to_try, - responses=self.nxdomain_responses) + raise NXDOMAIN(qnames=self.qnames_to_try, responses=self.nxdomain_responses) - def next_nameserver(self): + def next_nameserver(self) -> Tuple[str, int, bool, float]: if self.retry_with_tcp: assert self.nameserver is not None self.tcp_attempt = True self.retry_with_tcp = False return (self.nameserver, self.port, True, 0) - backoff = 0 + backoff = 0.0 if not self.current_nameservers: if len(self.nameservers) == 0: # Out of things to try! @@ -649,24 +707,31 @@ class _Resolution: self.backoff = min(self.backoff * 2, 2) self.nameserver = self.current_nameservers.pop(0) - self.port = self.resolver.nameserver_ports.get(self.nameserver, - self.resolver.port) + self.port = self.resolver.nameserver_ports.get( + self.nameserver, self.resolver.port + ) self.tcp_attempt = self.tcp return (self.nameserver, self.port, self.tcp_attempt, backoff) - def query_result(self, response, ex): + def query_result( + self, response: Optional[dns.message.Message], ex: Optional[Exception] + ) -> Tuple[Optional[Answer], bool]: # # returns an (answer: Answer, end_loop: bool) tuple. # + assert self.nameserver is not None if ex: # Exception during I/O or from_wire() assert response is None - self.errors.append((self.nameserver, self.tcp_attempt, self.port, - ex, response)) - if isinstance(ex, dns.exception.FormError) or \ - isinstance(ex, EOFError) or \ - isinstance(ex, OSError) or \ - isinstance(ex, NotImplementedError): + self.errors.append( + (self.nameserver, self.tcp_attempt, self.port, ex, response) + ) + if ( + isinstance(ex, dns.exception.FormError) + or isinstance(ex, EOFError) + or isinstance(ex, OSError) + or isinstance(ex, NotImplementedError) + ): # This nameserver is no good, take it out of the mix. self.nameservers.remove(self.nameserver) elif isinstance(ex, dns.message.Truncated): @@ -678,20 +743,27 @@ class _Resolution: return (None, False) # We got an answer! assert response is not None + assert isinstance(response, dns.message.QueryMessage) rcode = response.rcode() if rcode == dns.rcode.NOERROR: try: - answer = Answer(self.qname, self.rdtype, self.rdclass, response, - self.nameserver, self.port) + answer = Answer( + self.qname, + self.rdtype, + self.rdclass, + response, + self.nameserver, + self.port, + ) except Exception as e: - self.errors.append((self.nameserver, self.tcp_attempt, - self.port, e, response)) + self.errors.append( + (self.nameserver, self.tcp_attempt, self.port, e, response) + ) # The nameserver is no good, take it out of the mix. self.nameservers.remove(self.nameserver) return (None, False) if self.resolver.cache: - self.resolver.cache.put((self.qname, self.rdtype, - self.rdclass), answer) + self.resolver.cache.put((self.qname, self.rdtype, self.rdclass), answer) if answer.rrset is None and self.raise_on_no_answer: raise NoAnswer(response=answer.response) return (answer, True) @@ -699,26 +771,29 @@ class _Resolution: # Further validate the response by making an Answer, even # if we aren't going to cache it. try: - answer = Answer(self.qname, dns.rdatatype.ANY, - dns.rdataclass.IN, response) + answer = Answer( + self.qname, dns.rdatatype.ANY, dns.rdataclass.IN, response + ) except Exception as e: - self.errors.append((self.nameserver, self.tcp_attempt, - self.port, e, response)) + self.errors.append( + (self.nameserver, self.tcp_attempt, self.port, e, response) + ) # The nameserver is no good, take it out of the mix. self.nameservers.remove(self.nameserver) return (None, False) self.nxdomain_responses[self.qname] = response if self.resolver.cache: - self.resolver.cache.put((self.qname, - dns.rdatatype.ANY, - self.rdclass), answer) + self.resolver.cache.put( + (self.qname, dns.rdatatype.ANY, self.rdclass), answer + ) # Make next_nameserver() return None, so caller breaks its # inner loop and calls next_request(). return (None, True) elif rcode == dns.rcode.YXDOMAIN: yex = YXDOMAIN() - self.errors.append((self.nameserver, self.tcp_attempt, - self.port, yex, response)) + self.errors.append( + (self.nameserver, self.tcp_attempt, self.port, yex, response) + ) raise yex else: # @@ -727,10 +802,18 @@ class _Resolution: # if rcode != dns.rcode.SERVFAIL or not self.resolver.retry_servfail: self.nameservers.remove(self.nameserver) - self.errors.append((self.nameserver, self.tcp_attempt, self.port, - dns.rcode.to_text(rcode), response)) + self.errors.append( + ( + self.nameserver, + self.tcp_attempt, + self.port, + dns.rcode.to_text(rcode), + response, + ) + ) return (None, False) + class BaseResolver: """DNS stub resolver.""" @@ -738,7 +821,27 @@ class BaseResolver: # # pylint: disable=attribute-defined-outside-init - def __init__(self, filename='/etc/resolv.conf', configure=True): + domain: dns.name.Name + nameserver_ports: Dict[str, int] + port: int + search: List[dns.name.Name] + use_search_by_default: bool + timeout: float + lifetime: float + keyring: Optional[Any] + keyname: Optional[Union[dns.name.Name, str]] + keyalgorithm: Union[dns.name.Name, str] + edns: int + ednsflags: int + ednsoptions: Optional[List[dns.edns.Option]] + payload: int + cache: Any + flags: Optional[int] + retry_servfail: bool + rotate: bool + ndots: Optional[int] + + def __init__(self, filename: str = "/etc/resolv.conf", configure: bool = True): """*filename*, a ``str`` or file object, specifying a file in standard /etc/resolv.conf format. This parameter is meaningful only when *configure* is true and the platform is POSIX. @@ -752,7 +855,7 @@ class BaseResolver: self.reset() if configure: - if sys.platform == 'win32': + if sys.platform == "win32": self.read_registry() elif filename: self.read_resolv_conf(filename) @@ -760,8 +863,7 @@ class BaseResolver: def reset(self): """Reset all resolver configuration to the defaults.""" - self.domain = \ - dns.name.Name(dns.name.from_text(socket.gethostname())[1:]) + self.domain = dns.name.Name(dns.name.from_text(socket.gethostname())[1:]) if len(self.domain) == 0: self.domain = dns.name.root self.nameservers = [] @@ -776,6 +878,7 @@ class BaseResolver: self.keyalgorithm = dns.tsig.default_algorithm self.edns = -1 self.ednsflags = 0 + self.ednsoptions = None self.payload = 0 self.cache = None self.flags = None @@ -783,7 +886,7 @@ class BaseResolver: self.rotate = False self.ndots = None - def read_resolv_conf(self, f): + def read_resolv_conf(self, f: Any) -> None: """Process *f* as a file in the /etc/resolv.conf format. If f is a ``str``, it is used as the name of the file to open; otherwise it is treated as the file itself. @@ -800,16 +903,17 @@ class BaseResolver: """ - with contextlib.ExitStack() as stack: - if isinstance(f, str): - try: - f = stack.enter_context(open(f)) - except OSError: - # /etc/resolv.conf doesn't exist, can't be read, etc. - raise NoResolverConfiguration(f'cannot open {f}') - + if isinstance(f, str): + try: + cm: contextlib.AbstractContextManager = open(f) + except OSError: + # /etc/resolv.conf doesn't exist, can't be read, etc. + raise NoResolverConfiguration(f"cannot open {f}") + else: + cm = contextlib.nullcontext(f) + with cm as f: for l in f: - if len(l) == 0 or l[0] == '#' or l[0] == ';': + if len(l) == 0 or l[0] == "#" or l[0] == ";": continue tokens = l.split() @@ -817,42 +921,42 @@ class BaseResolver: if len(tokens) < 2: continue - if tokens[0] == 'nameserver': + if tokens[0] == "nameserver": self.nameservers.append(tokens[1]) - elif tokens[0] == 'domain': + elif tokens[0] == "domain": self.domain = dns.name.from_text(tokens[1]) # domain and search are exclusive self.search = [] - elif tokens[0] == 'search': + elif tokens[0] == "search": # the last search wins self.search = [] for suffix in tokens[1:]: self.search.append(dns.name.from_text(suffix)) # We don't set domain as it is not used if # len(self.search) > 0 - elif tokens[0] == 'options': + elif tokens[0] == "options": for opt in tokens[1:]: - if opt == 'rotate': + if opt == "rotate": self.rotate = True - elif opt == 'edns0': + elif opt == "edns0": self.use_edns() - elif 'timeout' in opt: + elif "timeout" in opt: try: - self.timeout = int(opt.split(':')[1]) + self.timeout = int(opt.split(":")[1]) except (ValueError, IndexError): pass - elif 'ndots' in opt: + elif "ndots" in opt: try: - self.ndots = int(opt.split(':')[1]) + self.ndots = int(opt.split(":")[1]) except (ValueError, IndexError): pass if len(self.nameservers) == 0: - raise NoResolverConfiguration('no nameservers') + raise NoResolverConfiguration("no nameservers") - def read_registry(self): + def read_registry(self) -> None: """Extract resolver configuration from the Windows registry.""" try: - info = dns.win32util.get_dns_info() + info = dns.win32util.get_dns_info() # type: ignore if info.domain is not None: self.domain = info.domain self.nameservers = info.nameservers @@ -860,7 +964,12 @@ class BaseResolver: except AttributeError: raise NotImplementedError - def _compute_timeout(self, start, lifetime=None, errors=None): + def _compute_timeout( + self, + start: float, + lifetime: Optional[float] = None, + errors: Optional[List[ErrorTuple]] = None, + ) -> float: lifetime = self.lifetime if lifetime is None else lifetime now = time.time() duration = now - start @@ -874,12 +983,14 @@ class BaseResolver: # Time went backwards, but only a little. This can # happen, e.g. under vmware with older linux kernels. # Pretend it didn't happen. - now = start + duration = 0 if duration >= lifetime: raise LifetimeTimeout(timeout=duration, errors=errors) return min(lifetime - duration, self.timeout) - def _get_qnames_to_try(self, qname, search): + def _get_qnames_to_try( + self, qname: dns.name.Name, search: Optional[bool] + ) -> List[dns.name.Name]: # This is a separate method so we can unit test the search # rules without requiring the Internet. if search is None: @@ -918,8 +1029,12 @@ class BaseResolver: qnames_to_try.append(abs_qname) return qnames_to_try - def use_tsig(self, keyring, keyname=None, - algorithm=dns.tsig.default_algorithm): + def use_tsig( + self, + keyring: Any, + keyname: Optional[Union[dns.name.Name, str]] = None, + algorithm: Union[dns.name.Name, str] = dns.tsig.default_algorithm, + ) -> None: """Add a TSIG signature to each query. The parameters are passed to ``dns.message.Message.use_tsig()``; @@ -930,8 +1045,13 @@ class BaseResolver: self.keyname = keyname self.keyalgorithm = algorithm - def use_edns(self, edns=0, ednsflags=0, - payload=dns.message.DEFAULT_EDNS_PAYLOAD): + def use_edns( + self, + edns: Optional[Union[int, bool]] = 0, + ednsflags: int = 0, + payload: int = dns.message.DEFAULT_EDNS_PAYLOAD, + options: Optional[List[dns.edns.Option]] = None, + ) -> None: """Configure EDNS behavior. *edns*, an ``int``, is the EDNS level to use. Specifying @@ -944,6 +1064,9 @@ class BaseResolver: *payload*, an ``int``, is the EDNS sender's payload field, which is the maximum size of UDP datagram the sender can handle. I.e. how big a response to this message can be. + + *options*, a list of ``dns.edns.Option`` objects or ``None``, the EDNS + options. """ if edns is None or edns is False: @@ -953,8 +1076,9 @@ class BaseResolver: self.edns = edns self.ednsflags = ednsflags self.payload = payload + self.ednsoptions = options - def set_flags(self, flags): + def set_flags(self, flags: int) -> None: """Overrides the default flags with your own. *flags*, an ``int``, the message flags to use. @@ -963,11 +1087,11 @@ class BaseResolver: self.flags = flags @property - def nameservers(self): + def nameservers(self) -> List[str]: return self._nameservers @nameservers.setter - def nameservers(self, nameservers): + def nameservers(self, nameservers: List[str]) -> None: """ *nameservers*, a ``list`` of nameservers. @@ -978,23 +1102,35 @@ class BaseResolver: for nameserver in nameservers: if not dns.inet.is_address(nameserver): try: - if urlparse(nameserver).scheme != 'https': + if urlparse(nameserver).scheme != "https": raise NotImplementedError except Exception: - raise ValueError(f'nameserver {nameserver} is not an ' - 'IP address or valid https URL') + raise ValueError( + f"nameserver {nameserver} is not an " + "IP address or valid https URL" + ) self._nameservers = nameservers else: - raise ValueError('nameservers must be a list' - ' (not a {})'.format(type(nameservers))) + raise ValueError( + "nameservers must be a list (not a {})".format(type(nameservers)) + ) class Resolver(BaseResolver): """DNS stub resolver.""" - def resolve(self, qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN, - tcp=False, source=None, raise_on_no_answer=True, source_port=0, - lifetime=None, search=None): # pylint: disable=arguments-differ + def resolve( + self, + qname: Union[dns.name.Name, str], + rdtype: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.A, + rdclass: Union[dns.rdataclass.RdataClass, str] = dns.rdataclass.IN, + tcp: bool = False, + source: Optional[str] = None, + raise_on_no_answer: bool = True, + source_port: int = 0, + lifetime: Optional[float] = None, + search: Optional[bool] = None, + ) -> Answer: # pylint: disable=arguments-differ """Query nameservers to find the answer to the question. The *qname*, *rdtype*, and *rdclass* parameters may be objects @@ -1046,8 +1182,9 @@ class Resolver(BaseResolver): """ - resolution = _Resolution(self, qname, rdtype, rdclass, tcp, - raise_on_no_answer, search) + resolution = _Resolution( + self, qname, rdtype, rdclass, tcp, raise_on_no_answer, search + ) start = time.time() while True: (request, answer) = resolution.next_request() @@ -1058,32 +1195,36 @@ class Resolver(BaseResolver): if answer is not None: # cache hit! return answer + assert request is not None # needed for type checking done = False while not done: (nameserver, port, tcp, backoff) = resolution.next_nameserver() if backoff: time.sleep(backoff) - timeout = self._compute_timeout(start, lifetime, - resolution.errors) + timeout = self._compute_timeout(start, lifetime, resolution.errors) try: if dns.inet.is_address(nameserver): if tcp: - response = dns.query.tcp(request, nameserver, - timeout=timeout, - port=port, - source=source, - source_port=source_port) + response = dns.query.tcp( + request, + nameserver, + timeout=timeout, + port=port, + source=source, + source_port=source_port, + ) else: - response = dns.query.udp(request, - nameserver, - timeout=timeout, - port=port, - source=source, - source_port=source_port, - raise_on_truncation=True) + response = dns.query.udp( + request, + nameserver, + timeout=timeout, + port=port, + source=source, + source_port=source_port, + raise_on_truncation=True, + ) else: - response = dns.query.https(request, nameserver, - timeout=timeout) + response = dns.query.https(request, nameserver, timeout=timeout) except Exception as ex: (_, done) = resolution.query_result(None, ex) continue @@ -1095,9 +1236,17 @@ class Resolver(BaseResolver): if answer is not None: return answer - def query(self, qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN, - tcp=False, source=None, raise_on_no_answer=True, source_port=0, - lifetime=None): # pragma: no cover + def query( + self, + qname: Union[dns.name.Name, str], + rdtype: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.A, + rdclass: Union[dns.rdataclass.RdataClass, str] = dns.rdataclass.IN, + tcp: bool = False, + source: Optional[str] = None, + raise_on_no_answer: bool = True, + source_port: int = 0, + lifetime: Optional[float] = None, + ) -> Answer: # pragma: no cover """Query nameservers to find the answer to the question. This method calls resolve() with ``search=True``, and is @@ -1105,13 +1254,24 @@ class Resolver(BaseResolver): dnspython. See the documentation for the resolve() method for further details. """ - warnings.warn('please use dns.resolver.Resolver.resolve() instead', - DeprecationWarning, stacklevel=2) - return self.resolve(qname, rdtype, rdclass, tcp, source, - raise_on_no_answer, source_port, lifetime, - True) + warnings.warn( + "please use dns.resolver.Resolver.resolve() instead", + DeprecationWarning, + stacklevel=2, + ) + return self.resolve( + qname, + rdtype, + rdclass, + tcp, + source, + raise_on_no_answer, + source_port, + lifetime, + True, + ) - def resolve_address(self, ipaddr, *args, **kwargs): + def resolve_address(self, ipaddr: str, *args: Any, **kwargs: Any) -> Answer: """Use a resolver to run a reverse query for PTR records. This utilizes the resolve() method to perform a PTR lookup on the @@ -1124,15 +1284,20 @@ class Resolver(BaseResolver): except for rdtype and rdclass are also supported by this function. """ - - return self.resolve(dns.reversename.from_address(ipaddr), - rdtype=dns.rdatatype.PTR, - rdclass=dns.rdataclass.IN, - *args, **kwargs) + # We make a modified kwargs for type checking happiness, as otherwise + # we get a legit warning about possibly having rdtype and rdclass + # in the kwargs more than once. + modified_kwargs: Dict[str, Any] = {} + modified_kwargs.update(kwargs) + modified_kwargs["rdtype"] = dns.rdatatype.PTR + modified_kwargs["rdclass"] = dns.rdataclass.IN + return self.resolve( + dns.reversename.from_address(ipaddr), *args, **modified_kwargs + ) # type: ignore[arg-type] # pylint: disable=redefined-outer-name - def canonical_name(self, name): + def canonical_name(self, name: Union[dns.name.Name, str]) -> dns.name.Name: """Determine the canonical name of *name*. The canonical name is the name the resolver uses for queries @@ -1157,13 +1322,14 @@ class Resolver(BaseResolver): #: The default resolver. -default_resolver = None +default_resolver: Optional[Resolver] = None -def get_default_resolver(): +def get_default_resolver() -> Resolver: """Get the default resolver, initializing it if necessary.""" if default_resolver is None: reset_default_resolver() + assert default_resolver is not None return default_resolver @@ -1178,9 +1344,18 @@ def reset_default_resolver(): default_resolver = Resolver() -def resolve(qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN, - tcp=False, source=None, raise_on_no_answer=True, - source_port=0, lifetime=None, search=None): +def resolve( + qname: Union[dns.name.Name, str], + rdtype: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.A, + rdclass: Union[dns.rdataclass.RdataClass, str] = dns.rdataclass.IN, + tcp: bool = False, + source: Optional[str] = None, + raise_on_no_answer: bool = True, + source_port: int = 0, + lifetime: Optional[float] = None, + search: Optional[bool] = None, +) -> Answer: # pragma: no cover + """Query nameservers to find the answer to the question. This is a convenience function that uses the default resolver @@ -1190,13 +1365,29 @@ def resolve(qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN, parameters. """ - return get_default_resolver().resolve(qname, rdtype, rdclass, tcp, source, - raise_on_no_answer, source_port, - lifetime, search) + return get_default_resolver().resolve( + qname, + rdtype, + rdclass, + tcp, + source, + raise_on_no_answer, + source_port, + lifetime, + search, + ) -def query(qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN, - tcp=False, source=None, raise_on_no_answer=True, - source_port=0, lifetime=None): # pragma: no cover + +def query( + qname: Union[dns.name.Name, str], + rdtype: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.A, + rdclass: Union[dns.rdataclass.RdataClass, str] = dns.rdataclass.IN, + tcp: bool = False, + source: Optional[str] = None, + raise_on_no_answer: bool = True, + source_port: int = 0, + lifetime: Optional[float] = None, +) -> Answer: # pragma: no cover """Query nameservers to find the answer to the question. This method calls resolve() with ``search=True``, and is @@ -1204,14 +1395,23 @@ def query(qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN, dnspython. See the documentation for the resolve() method for further details. """ - warnings.warn('please use dns.resolver.resolve() instead', - DeprecationWarning, stacklevel=2) - return resolve(qname, rdtype, rdclass, tcp, source, - raise_on_no_answer, source_port, lifetime, - True) + warnings.warn( + "please use dns.resolver.resolve() instead", DeprecationWarning, stacklevel=2 + ) + return resolve( + qname, + rdtype, + rdclass, + tcp, + source, + raise_on_no_answer, + source_port, + lifetime, + True, + ) -def resolve_address(ipaddr, *args, **kwargs): +def resolve_address(ipaddr: str, *args: Any, **kwargs: Any) -> Answer: """Use a resolver to run a reverse query for PTR records. See ``dns.resolver.Resolver.resolve_address`` for more information on the @@ -1221,7 +1421,7 @@ def resolve_address(ipaddr, *args, **kwargs): return get_default_resolver().resolve_address(ipaddr, *args, **kwargs) -def canonical_name(name): +def canonical_name(name: Union[dns.name.Name, str]) -> dns.name.Name: """Determine the canonical name of *name*. See ``dns.resolver.Resolver.canonical_name`` for more information on the @@ -1231,8 +1431,13 @@ def canonical_name(name): return get_default_resolver().canonical_name(name) -def zone_for_name(name, rdclass=dns.rdataclass.IN, tcp=False, resolver=None, - lifetime=None): +def zone_for_name( + name: Union[dns.name.Name, str], + rdclass: dns.rdataclass.RdataClass = dns.rdataclass.IN, + tcp: bool = False, + resolver: Optional[Resolver] = None, + lifetime: Optional[float] = None, +) -> dns.name.Name: """Find the name of the zone which contains the specified name. *name*, an absolute ``dns.name.Name`` or ``str``, the query name. @@ -1265,20 +1470,24 @@ def zone_for_name(name, rdclass=dns.rdataclass.IN, tcp=False, resolver=None, if not name.is_absolute(): raise NotAbsolute(name) start = time.time() + expiration: Optional[float] if lifetime is not None: expiration = start + lifetime else: expiration = None while 1: try: + rlifetime: Optional[float] if expiration: rlifetime = expiration - time.time() if rlifetime <= 0: rlifetime = 0 else: rlifetime = None - answer = resolver.resolve(name, dns.rdatatype.SOA, rdclass, tcp, - lifetime=rlifetime) + answer = resolver.resolve( + name, dns.rdatatype.SOA, rdclass, tcp, lifetime=rlifetime + ) + assert answer.rrset is not None if answer.rrset.name == name: return name # otherwise we were CNAMEd or DNAMEd and need to look higher @@ -1289,8 +1498,7 @@ def zone_for_name(name, rdclass=dns.rdataclass.IN, tcp=False, resolver=None, response = e.response() # pylint: disable=no-value-for-parameter if response: for rrs in response.authority: - if rrs.rdtype == dns.rdatatype.SOA and \ - rrs.rdclass == rdclass: + if rrs.rdtype == dns.rdatatype.SOA and rrs.rdclass == rdclass: (nr, _, _) = rrs.name.fullcompare(name) if nr == dns.name.NAMERELN_SUPERDOMAIN: # We're doing a proper superdomain check as @@ -1307,6 +1515,7 @@ def zone_for_name(name, rdclass=dns.rdataclass.IN, tcp=False, resolver=None, except dns.name.NoParent: raise NoRootSOA + # # Support for overriding the system resolver for all python code in the # running process. @@ -1326,16 +1535,16 @@ _original_gethostbyname_ex = socket.gethostbyname_ex _original_gethostbyaddr = socket.gethostbyaddr -def _getaddrinfo(host=None, service=None, family=socket.AF_UNSPEC, socktype=0, - proto=0, flags=0): +def _getaddrinfo( + host=None, service=None, family=socket.AF_UNSPEC, socktype=0, proto=0, flags=0 +): if flags & socket.AI_NUMERICHOST != 0: # Short circuit directly into the system's getaddrinfo(). We're # not adding any value in this case, and this avoids infinite loops # because dns.query.* needs to call getaddrinfo() for IPv6 scoping # reasons. We will also do this short circuit below if we # discover that the host is an address literal. - return _original_getaddrinfo(host, service, family, socktype, proto, - flags) + return _original_getaddrinfo(host, service, family, socktype, proto, flags) if flags & (socket.AI_ADDRCONFIG | socket.AI_V4MAPPED) != 0: # Not implemented. We raise a gaierror as opposed to a # NotImplementedError as it helps callers handle errors more @@ -1345,55 +1554,53 @@ def _getaddrinfo(host=None, service=None, family=socket.AF_UNSPEC, socktype=0, # no EAI_SYSTEM on Windows [Issue #416]. We didn't go for # EAI_BADFLAGS as the flags aren't bad, we just don't # implement them. - raise socket.gaierror(socket.EAI_FAIL, - 'Non-recoverable failure in name resolution') + raise socket.gaierror( + socket.EAI_FAIL, "Non-recoverable failure in name resolution" + ) if host is None and service is None: - raise socket.gaierror(socket.EAI_NONAME, 'Name or service not known') + raise socket.gaierror(socket.EAI_NONAME, "Name or service not known") v6addrs = [] v4addrs = [] canonical_name = None # pylint: disable=redefined-outer-name # Is host None or an address literal? If so, use the system's # getaddrinfo(). if host is None: - return _original_getaddrinfo(host, service, family, socktype, - proto, flags) + return _original_getaddrinfo(host, service, family, socktype, proto, flags) try: # We don't care about the result of af_for_address(), we're just # calling it so it raises an exception if host is not an IPv4 or # IPv6 address. dns.inet.af_for_address(host) - return _original_getaddrinfo(host, service, family, socktype, - proto, flags) + return _original_getaddrinfo(host, service, family, socktype, proto, flags) except Exception: pass # Something needs resolution! try: if family == socket.AF_INET6 or family == socket.AF_UNSPEC: - v6 = _resolver.resolve(host, dns.rdatatype.AAAA, - raise_on_no_answer=False) + v6 = _resolver.resolve(host, dns.rdatatype.AAAA, raise_on_no_answer=False) # Note that setting host ensures we query the same name - # for A as we did for AAAA. + # for A as we did for AAAA. (This is just in case search lists + # are active by default in the resolver configuration and + # we might be talking to a server that says NXDOMAIN when it + # wants to say NOERROR no data. host = v6.qname canonical_name = v6.canonical_name.to_text(True) if v6.rrset is not None: for rdata in v6.rrset: v6addrs.append(rdata.address) if family == socket.AF_INET or family == socket.AF_UNSPEC: - v4 = _resolver.resolve(host, dns.rdatatype.A, - raise_on_no_answer=False) - host = v4.qname + v4 = _resolver.resolve(host, dns.rdatatype.A, raise_on_no_answer=False) canonical_name = v4.canonical_name.to_text(True) if v4.rrset is not None: for rdata in v4.rrset: v4addrs.append(rdata.address) except dns.resolver.NXDOMAIN: - raise socket.gaierror(socket.EAI_NONAME, 'Name or service not known') + raise socket.gaierror(socket.EAI_NONAME, "Name or service not known") except Exception: # We raise EAI_AGAIN here as the failure may be temporary # (e.g. a timeout) and EAI_SYSTEM isn't defined on Windows. # [Issue #416] - raise socket.gaierror(socket.EAI_AGAIN, - 'Temporary failure in name resolution') + raise socket.gaierror(socket.EAI_AGAIN, "Temporary failure in name resolution") port = None try: # Is it a port literal? @@ -1408,7 +1615,7 @@ def _getaddrinfo(host=None, service=None, family=socket.AF_UNSPEC, socktype=0, except Exception: pass if port is None: - raise socket.gaierror(socket.EAI_NONAME, 'Name or service not known') + raise socket.gaierror(socket.EAI_NONAME, "Name or service not known") tuples = [] if socktype == 0: socktypes = [socket.SOCK_DGRAM, socket.SOCK_STREAM] @@ -1417,21 +1624,23 @@ def _getaddrinfo(host=None, service=None, family=socket.AF_UNSPEC, socktype=0, if flags & socket.AI_CANONNAME != 0: cname = canonical_name else: - cname = '' + cname = "" if family == socket.AF_INET6 or family == socket.AF_UNSPEC: for addr in v6addrs: for socktype in socktypes: for proto in _protocols_for_socktype[socktype]: - tuples.append((socket.AF_INET6, socktype, proto, - cname, (addr, port, 0, 0))) + tuples.append( + (socket.AF_INET6, socktype, proto, cname, (addr, port, 0, 0)) + ) if family == socket.AF_INET or family == socket.AF_UNSPEC: for addr in v4addrs: for socktype in socktypes: for proto in _protocols_for_socktype[socktype]: - tuples.append((socket.AF_INET, socktype, proto, - cname, (addr, port))) + tuples.append( + (socket.AF_INET, socktype, proto, cname, (addr, port)) + ) if len(tuples) == 0: - raise socket.gaierror(socket.EAI_NONAME, 'Name or service not known') + raise socket.gaierror(socket.EAI_NONAME, "Name or service not known") return tuples @@ -1444,31 +1653,29 @@ def _getnameinfo(sockaddr, flags=0): else: scope = None family = socket.AF_INET - tuples = _getaddrinfo(host, port, family, socket.SOCK_STREAM, - socket.SOL_TCP, 0) + tuples = _getaddrinfo(host, port, family, socket.SOCK_STREAM, socket.SOL_TCP, 0) if len(tuples) > 1: - raise socket.error('sockaddr resolved to multiple addresses') + raise socket.error("sockaddr resolved to multiple addresses") addr = tuples[0][4][0] if flags & socket.NI_DGRAM: - pname = 'udp' + pname = "udp" else: - pname = 'tcp' + pname = "tcp" qname = dns.reversename.from_address(addr) if flags & socket.NI_NUMERICHOST == 0: try: - answer = _resolver.resolve(qname, 'PTR') + answer = _resolver.resolve(qname, "PTR") hostname = answer.rrset[0].target.to_text(True) except (dns.resolver.NXDOMAIN, dns.resolver.NoAnswer): if flags & socket.NI_NAMEREQD: - raise socket.gaierror(socket.EAI_NONAME, - 'Name or service not known') + raise socket.gaierror(socket.EAI_NONAME, "Name or service not known") hostname = addr if scope is not None: - hostname += '%' + str(scope) + hostname += "%" + str(scope) else: hostname = addr if scope is not None: - hostname += '%' + str(scope) + hostname += "%" + str(scope) if flags & socket.NI_NUMERICSERV: service = str(port) else: @@ -1495,8 +1702,9 @@ def _gethostbyname(name): def _gethostbyname_ex(name): aliases = [] addresses = [] - tuples = _getaddrinfo(name, 0, socket.AF_INET, socket.SOCK_STREAM, - socket.SOL_TCP, socket.AI_CANONNAME) + tuples = _getaddrinfo( + name, 0, socket.AF_INET, socket.SOCK_STREAM, socket.SOL_TCP, socket.AI_CANONNAME + ) canonical = tuples[0][3] for item in tuples: addresses.append(item[4][0]) @@ -1513,15 +1721,15 @@ def _gethostbyaddr(ip): try: dns.ipv4.inet_aton(ip) except Exception: - raise socket.gaierror(socket.EAI_NONAME, - 'Name or service not known') + raise socket.gaierror(socket.EAI_NONAME, "Name or service not known") sockaddr = (ip, 80) family = socket.AF_INET (name, _) = _getnameinfo(sockaddr, socket.NI_NAMEREQD) aliases = [] addresses = [] - tuples = _getaddrinfo(name, 0, family, socket.SOCK_STREAM, socket.SOL_TCP, - socket.AI_CANONNAME) + tuples = _getaddrinfo( + name, 0, family, socket.SOCK_STREAM, socket.SOL_TCP, socket.AI_CANONNAME + ) canonical = tuples[0][3] # We only want to include an address from the tuples if it's the # same as the one we asked about. We do this comparison in binary @@ -1536,7 +1744,7 @@ def _gethostbyaddr(ip): return (canonical, aliases, addresses) -def override_system_resolver(resolver=None): +def override_system_resolver(resolver: Optional[Resolver] = None) -> None: """Override the system resolver routines in the socket module with versions which use dnspython's resolver. @@ -1562,7 +1770,7 @@ def override_system_resolver(resolver=None): socket.gethostbyaddr = _gethostbyaddr -def restore_system_resolver(): +def restore_system_resolver() -> None: """Undo the effects of prior override_system_resolver().""" global _resolver diff --git a/lib/dns/resolver.pyi b/lib/dns/resolver.pyi deleted file mode 100644 index 348df4da..00000000 --- a/lib/dns/resolver.pyi +++ /dev/null @@ -1,66 +0,0 @@ -from typing import Union, Optional, List, Any, Dict -from . import exception, rdataclass, name, rdatatype - -import socket -_gethostbyname = socket.gethostbyname - -class NXDOMAIN(exception.DNSException): ... -class YXDOMAIN(exception.DNSException): ... -class NoAnswer(exception.DNSException): ... -class NoNameservers(exception.DNSException): ... -class NotAbsolute(exception.DNSException): ... -class NoRootSOA(exception.DNSException): ... -class NoMetaqueries(exception.DNSException): ... -class NoResolverConfiguration(exception.DNSException): ... -Timeout = exception.Timeout - -def resolve(qname : str, rdtype : Union[int,str] = 0, - rdclass : Union[int,str] = 0, - tcp=False, source=None, raise_on_no_answer=True, - source_port=0, lifetime : Optional[float]=None, - search : Optional[bool]=None): - ... -def query(qname : str, rdtype : Union[int,str] = 0, - rdclass : Union[int,str] = 0, - tcp=False, source=None, raise_on_no_answer=True, - source_port=0, lifetime : Optional[float]=None): - ... -def resolve_address(ipaddr: str, *args: Any, **kwargs: Optional[Dict]): - ... -class LRUCache: - def __init__(self, max_size=1000): - ... - def get(self, key): - ... - def put(self, key, val): - ... -class Answer: - def __init__(self, qname, rdtype, rdclass, response, - raise_on_no_answer=True): - ... -def zone_for_name(name, rdclass : int = rdataclass.IN, tcp=False, - resolver : Optional[Resolver] = None): - ... - -class Resolver: - def __init__(self, filename : Optional[str] = '/etc/resolv.conf', - configure : Optional[bool] = True): - self.nameservers : List[str] - def resolve(self, qname : str, rdtype : Union[int,str] = rdatatype.A, - rdclass : Union[int,str] = rdataclass.IN, - tcp : bool = False, source : Optional[str] = None, - raise_on_no_answer=True, source_port : int = 0, - lifetime : Optional[float]=None, - search : Optional[bool]=None): - ... - def query(self, qname : str, rdtype : Union[int,str] = rdatatype.A, - rdclass : Union[int,str] = rdataclass.IN, - tcp : bool = False, source : Optional[str] = None, - raise_on_no_answer=True, source_port : int = 0, - lifetime : Optional[float]=None): - ... -default_resolver: typing.Optional[Resolver] -def reset_default_resolver() -> None: - ... -def get_default_resolver() -> Resolver: - ... diff --git a/lib/dns/reversename.py b/lib/dns/reversename.py index e0beb03d..eb6a3b6b 100644 --- a/lib/dns/reversename.py +++ b/lib/dns/reversename.py @@ -23,12 +23,15 @@ import dns.name import dns.ipv6 import dns.ipv4 -ipv4_reverse_domain = dns.name.from_text('in-addr.arpa.') -ipv6_reverse_domain = dns.name.from_text('ip6.arpa.') +ipv4_reverse_domain = dns.name.from_text("in-addr.arpa.") +ipv6_reverse_domain = dns.name.from_text("ip6.arpa.") -def from_address(text, v4_origin=ipv4_reverse_domain, - v6_origin=ipv6_reverse_domain): +def from_address( + text: str, + v4_origin: dns.name.Name = ipv4_reverse_domain, + v6_origin: dns.name.Name = ipv6_reverse_domain, +) -> dns.name.Name: """Convert an IPv4 or IPv6 address in textual form into a Name object whose value is the reverse-map domain name of the address. @@ -51,20 +54,22 @@ def from_address(text, v4_origin=ipv4_reverse_domain, try: v6 = dns.ipv6.inet_aton(text) if dns.ipv6.is_mapped(v6): - parts = ['%d' % byte for byte in v6[12:]] + parts = ["%d" % byte for byte in v6[12:]] origin = v4_origin else: parts = [x for x in str(binascii.hexlify(v6).decode())] origin = v6_origin except Exception: - parts = ['%d' % - byte for byte in dns.ipv4.inet_aton(text)] + parts = ["%d" % byte for byte in dns.ipv4.inet_aton(text)] origin = v4_origin - return dns.name.from_text('.'.join(reversed(parts)), origin=origin) + return dns.name.from_text(".".join(reversed(parts)), origin=origin) -def to_address(name, v4_origin=ipv4_reverse_domain, - v6_origin=ipv6_reverse_domain): +def to_address( + name: dns.name.Name, + v4_origin: dns.name.Name = ipv4_reverse_domain, + v6_origin: dns.name.Name = ipv6_reverse_domain, +) -> str: """Convert a reverse map domain name into textual address form. *name*, a ``dns.name.Name``, an IPv4 or IPv6 address in reverse-map name @@ -84,7 +89,7 @@ def to_address(name, v4_origin=ipv4_reverse_domain, if name.is_subdomain(v4_origin): name = name.relativize(v4_origin) - text = b'.'.join(reversed(name.labels)) + text = b".".join(reversed(name.labels)) # run through inet_ntoa() to check syntax and make pretty. return dns.ipv4.inet_ntoa(dns.ipv4.inet_aton(text)) elif name.is_subdomain(v6_origin): @@ -92,9 +97,9 @@ def to_address(name, v4_origin=ipv4_reverse_domain, labels = list(reversed(name.labels)) parts = [] for i in range(0, len(labels), 4): - parts.append(b''.join(labels[i:i + 4])) - text = b':'.join(parts) + parts.append(b"".join(labels[i : i + 4])) + text = b":".join(parts) # run through inet_ntoa() to check syntax and make pretty. return dns.ipv6.inet_ntoa(dns.ipv6.inet_aton(text)) else: - raise dns.exception.SyntaxError('unknown reverse-map address family') + raise dns.exception.SyntaxError("unknown reverse-map address family") diff --git a/lib/dns/reversename.pyi b/lib/dns/reversename.pyi deleted file mode 100644 index 97f072ea..00000000 --- a/lib/dns/reversename.pyi +++ /dev/null @@ -1,6 +0,0 @@ -from . import name -def from_address(text : str) -> name.Name: - ... - -def to_address(name : name.Name) -> str: - ... diff --git a/lib/dns/rrset.py b/lib/dns/rrset.py index a71d4573..3f22a90c 100644 --- a/lib/dns/rrset.py +++ b/lib/dns/rrset.py @@ -17,6 +17,7 @@ """DNS RRsets (an RRset is a named rdataset)""" +from typing import Any, cast, Collection, Dict, Optional, Union import dns.name import dns.rdataset @@ -35,10 +36,16 @@ class RRset(dns.rdataset.Rdataset): name. """ - __slots__ = ['name', 'deleting'] + __slots__ = ["name", "deleting"] - def __init__(self, name, rdclass, rdtype, covers=dns.rdatatype.NONE, - deleting=None): + def __init__( + self, + name: dns.name.Name, + rdclass: dns.rdataclass.RdataClass, + rdtype: dns.rdatatype.RdataType, + covers: dns.rdatatype.RdataType = dns.rdatatype.NONE, + deleting: Optional[dns.rdataclass.RdataClass] = None, + ): """Create a new RRset.""" super().__init__(rdclass, rdtype, covers) @@ -53,17 +60,26 @@ class RRset(dns.rdataset.Rdataset): def __repr__(self): if self.covers == 0: - ctext = '' + ctext = "" else: - ctext = '(' + dns.rdatatype.to_text(self.covers) + ')' + ctext = "(" + dns.rdatatype.to_text(self.covers) + ")" if self.deleting is not None: - dtext = ' delete=' + dns.rdataclass.to_text(self.deleting) + dtext = " delete=" + dns.rdataclass.to_text(self.deleting) else: - dtext = '' - return '' + dtext = "" + return ( + "" + ) def __str__(self): return self.to_text() @@ -76,7 +92,7 @@ class RRset(dns.rdataset.Rdataset): return False return super().__eq__(other) - def match(self, *args, **kwargs): + def match(self, *args: Any, **kwargs: Any) -> bool: # type: ignore[override] """Does this rrset match the specified attributes? Behaves as :py:func:`full_match()` if the first argument is a @@ -89,12 +105,18 @@ class RRset(dns.rdataset.Rdataset): compatibility.) """ if isinstance(args[0], dns.name.Name): - return self.full_match(*args, **kwargs) + return self.full_match(*args, **kwargs) # type: ignore[arg-type] else: - return super().match(*args, **kwargs) + return super().match(*args, **kwargs) # type: ignore[arg-type] - def full_match(self, name, rdclass, rdtype, covers, - deleting=None): + def full_match( + self, + name: dns.name.Name, + rdclass: dns.rdataclass.RdataClass, + rdtype: dns.rdatatype.RdataType, + covers: dns.rdatatype.RdataType, + deleting: Optional[dns.rdataclass.RdataClass] = None, + ) -> bool: """Returns ``True`` if this rrset matches the specified name, class, type, covers, and deletion state. """ @@ -106,7 +128,12 @@ class RRset(dns.rdataset.Rdataset): # pylint: disable=arguments-differ - def to_text(self, origin=None, relativize=True, **kw): + def to_text( # type: ignore[override] + self, + origin: Optional[dns.name.Name] = None, + relativize: bool = True, + **kw: Dict[str, Any] + ) -> str: """Convert the RRset into DNS zone file format. See ``dns.name.Name.choose_relativity`` for more information @@ -123,11 +150,17 @@ class RRset(dns.rdataset.Rdataset): to *origin*. """ - return super().to_text(self.name, origin, relativize, - self.deleting, **kw) + return super().to_text( + self.name, origin, relativize, self.deleting, **kw # type: ignore + ) - def to_wire(self, file, compress=None, origin=None, - **kw): + def to_wire( # type: ignore[override] + self, + file: Any, + compress: Optional[dns.name.CompressType] = None, # type: ignore + origin: Optional[dns.name.Name] = None, + **kw: Dict[str, Any] + ) -> int: """Convert the RRset to wire format. All keyword arguments are passed to ``dns.rdataset.to_wire()``; see @@ -136,12 +169,13 @@ class RRset(dns.rdataset.Rdataset): Returns an ``int``, the number of records emitted. """ - return super().to_wire(self.name, file, compress, origin, - self.deleting, **kw) + return super().to_wire( + self.name, file, compress, origin, self.deleting, **kw # type:ignore + ) # pylint: enable=arguments-differ - def to_rdataset(self): + def to_rdataset(self) -> dns.rdataset.Rdataset: """Convert an RRset into an Rdataset. Returns a ``dns.rdataset.Rdataset``. @@ -149,9 +183,17 @@ class RRset(dns.rdataset.Rdataset): return dns.rdataset.from_rdata_list(self.ttl, list(self)) -def from_text_list(name, ttl, rdclass, rdtype, text_rdatas, - idna_codec=None, origin=None, relativize=True, - relativize_to=None): +def from_text_list( + name: Union[dns.name.Name, str], + ttl: int, + rdclass: Union[dns.rdataclass.RdataClass, str], + rdtype: Union[dns.rdatatype.RdataType, str], + text_rdatas: Collection[str], + idna_codec: Optional[dns.name.IDNACodec] = None, + origin: Optional[dns.name.Name] = None, + relativize: bool = True, + relativize_to: Optional[dns.name.Name] = None, +) -> RRset: """Create an RRset with the specified name, TTL, class, and type, and with the specified list of rdatas in text format. @@ -172,28 +214,42 @@ def from_text_list(name, ttl, rdclass, rdtype, text_rdatas, if isinstance(name, str): name = dns.name.from_text(name, None, idna_codec=idna_codec) - rdclass = dns.rdataclass.RdataClass.make(rdclass) - rdtype = dns.rdatatype.RdataType.make(rdtype) - r = RRset(name, rdclass, rdtype) + the_rdclass = dns.rdataclass.RdataClass.make(rdclass) + the_rdtype = dns.rdatatype.RdataType.make(rdtype) + r = RRset(name, the_rdclass, the_rdtype) r.update_ttl(ttl) for t in text_rdatas: - rd = dns.rdata.from_text(r.rdclass, r.rdtype, t, origin, relativize, - relativize_to, idna_codec) + rd = dns.rdata.from_text( + r.rdclass, r.rdtype, t, origin, relativize, relativize_to, idna_codec + ) r.add(rd) return r -def from_text(name, ttl, rdclass, rdtype, *text_rdatas): +def from_text( + name: Union[dns.name.Name, str], + ttl: int, + rdclass: Union[dns.rdataclass.RdataClass, str], + rdtype: Union[dns.rdatatype.RdataType, str], + *text_rdatas: Any +) -> RRset: """Create an RRset with the specified name, TTL, class, and type and with the specified rdatas in text format. Returns a ``dns.rrset.RRset`` object. """ - return from_text_list(name, ttl, rdclass, rdtype, text_rdatas) + return from_text_list( + name, ttl, rdclass, rdtype, cast(Collection[str], text_rdatas) + ) -def from_rdata_list(name, ttl, rdatas, idna_codec=None): +def from_rdata_list( + name: Union[dns.name.Name, str], + ttl: int, + rdatas: Collection[dns.rdata.Rdata], + idna_codec: Optional[dns.name.IDNACodec] = None, +) -> RRset: """Create an RRset with the specified name and TTL, and with the specified list of rdata objects. @@ -216,14 +272,15 @@ def from_rdata_list(name, ttl, rdatas, idna_codec=None): r = RRset(name, rd.rdclass, rd.rdtype) r.update_ttl(ttl) r.add(rd) + assert r is not None return r -def from_rdata(name, ttl, *rdatas): +def from_rdata(name: Union[dns.name.Name, str], ttl: int, *rdatas: Any) -> RRset: """Create an RRset with the specified name and TTL, and with the specified rdata objects. Returns a ``dns.rrset.RRset`` object. """ - return from_rdata_list(name, ttl, rdatas) + return from_rdata_list(name, ttl, cast(Collection[dns.rdata.Rdata], rdatas)) diff --git a/lib/dns/rrset.pyi b/lib/dns/rrset.pyi deleted file mode 100644 index 0a81a2a0..00000000 --- a/lib/dns/rrset.pyi +++ /dev/null @@ -1,10 +0,0 @@ -from typing import List, Optional -from . import rdataset, rdatatype - -class RRset(rdataset.Rdataset): - def __init__(self, name, rdclass : int , rdtype : int, covers=rdatatype.NONE, - deleting : Optional[int] =None) -> None: - self.name = name - self.deleting = deleting -def from_text(name : str, ttl : int, rdclass : str, rdtype : str, *text_rdatas : str): - ... diff --git a/lib/dns/serial.py b/lib/dns/serial.py index b0474151..3417299b 100644 --- a/lib/dns/serial.py +++ b/lib/dns/serial.py @@ -2,13 +2,14 @@ """Serial Number Arthimetic from RFC 1982""" + class Serial: - def __init__(self, value, bits=32): - self.value = value % 2 ** bits + def __init__(self, value: int, bits: int = 32): + self.value = value % 2**bits self.bits = bits def __repr__(self): - return f'dns.serial.Serial({self.value}, {self.bits})' + return f"dns.serial.Serial({self.value}, {self.bits})" def __eq__(self, other): if isinstance(other, int): @@ -29,11 +30,11 @@ class Serial: other = Serial(other, self.bits) elif not isinstance(other, Serial) or other.bits != self.bits: return NotImplemented - if self.value < other.value and \ - other.value - self.value < 2 ** (self.bits - 1): + if self.value < other.value and other.value - self.value < 2 ** (self.bits - 1): return True - elif self.value > other.value and \ - self.value - other.value > 2 ** (self.bits - 1): + elif self.value > other.value and self.value - other.value > 2 ** ( + self.bits - 1 + ): return True else: return False @@ -46,11 +47,11 @@ class Serial: other = Serial(other, self.bits) elif not isinstance(other, Serial) or other.bits != self.bits: return NotImplemented - if self.value < other.value and \ - other.value - self.value > 2 ** (self.bits - 1): + if self.value < other.value and other.value - self.value > 2 ** (self.bits - 1): return True - elif self.value > other.value and \ - self.value - other.value < 2 ** (self.bits - 1): + elif self.value > other.value and self.value - other.value < 2 ** ( + self.bits - 1 + ): return True else: return False @@ -69,7 +70,7 @@ class Serial: if abs(delta) > (2 ** (self.bits - 1) - 1): raise ValueError v += delta - v = v % 2 ** self.bits + v = v % 2**self.bits return Serial(v, self.bits) def __iadd__(self, other): @@ -83,7 +84,7 @@ class Serial: if abs(delta) > (2 ** (self.bits - 1) - 1): raise ValueError v += delta - v = v % 2 ** self.bits + v = v % 2**self.bits self.value = v return self @@ -98,7 +99,7 @@ class Serial: if abs(delta) > (2 ** (self.bits - 1) - 1): raise ValueError v -= delta - v = v % 2 ** self.bits + v = v % 2**self.bits return Serial(v, self.bits) def __isub__(self, other): @@ -112,6 +113,6 @@ class Serial: if abs(delta) > (2 ** (self.bits - 1) - 1): raise ValueError v -= delta - v = v % 2 ** self.bits + v = v % 2**self.bits self.value = v return self diff --git a/lib/dns/set.py b/lib/dns/set.py index 1fd4d0ae..fa50ed97 100644 --- a/lib/dns/set.py +++ b/lib/dns/set.py @@ -16,12 +16,7 @@ # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. import itertools -import sys -if sys.version_info >= (3, 7): - odict = dict -else: - from collections import OrderedDict as odict # pragma: no cover class Set: @@ -33,7 +28,7 @@ class Set: ability is widely used in dnspython applications. """ - __slots__ = ['items'] + __slots__ = ["items"] def __init__(self, items=None): """Initialize the set. @@ -41,24 +36,24 @@ class Set: *items*, an iterable or ``None``, the initial set of items. """ - self.items = odict() + self.items = dict() if items is not None: for item in items: - self.add(item) + # This is safe for how we use set, but if other code + # subclasses it could be a legitimate issue. + self.add(item) # lgtm[py/init-calls-subclass] def __repr__(self): return "dns.set.Set(%s)" % repr(list(self.items.keys())) def add(self, item): - """Add an item to the set. - """ + """Add an item to the set.""" if item not in self.items: self.items[item] = None def remove(self, item): - """Remove an item from the set. - """ + """Remove an item from the set.""" try: del self.items[item] @@ -66,12 +61,16 @@ class Set: raise ValueError def discard(self, item): - """Remove an item from the set if present. - """ + """Remove an item from the set if present.""" self.items.pop(item, None) - def _clone(self): + def pop(self): + """Remove an arbitrary item from the set.""" + (k, _) = self.items.popitem() + return k + + def _clone(self) -> "Set": """Make a (shallow) copy of the set. There is a 'clone protocol' that subclasses of this class @@ -84,24 +83,22 @@ class Set: subclasses. """ - if hasattr(self, '_clone_class'): - cls = self._clone_class + if hasattr(self, "_clone_class"): + cls = self._clone_class # type: ignore else: cls = self.__class__ obj = cls.__new__(cls) - obj.items = odict() + obj.items = dict() obj.items.update(self.items) return obj def __copy__(self): - """Make a (shallow) copy of the set. - """ + """Make a (shallow) copy of the set.""" return self._clone() def copy(self): - """Make a (shallow) copy of the set. - """ + """Make a (shallow) copy of the set.""" return self._clone() @@ -111,8 +108,8 @@ class Set: """ if not isinstance(other, Set): - raise ValueError('other must be a Set instance') - if self is other: + raise ValueError("other must be a Set instance") + if self is other: # lgtm[py/comparison-using-is] return for item in other.items: self.add(item) @@ -123,8 +120,8 @@ class Set: """ if not isinstance(other, Set): - raise ValueError('other must be a Set instance') - if self is other: + raise ValueError("other must be a Set instance") + if self is other: # lgtm[py/comparison-using-is] return # we make a copy of the list so that we can remove items from # the list without breaking the iterator. @@ -138,13 +135,25 @@ class Set: """ if not isinstance(other, Set): - raise ValueError('other must be a Set instance') - if self is other: + raise ValueError("other must be a Set instance") + if self is other: # lgtm[py/comparison-using-is] self.items.clear() else: for item in other.items: self.discard(item) + def symmetric_difference_update(self, other): + """Update the set, retaining only elements unique to both sets.""" + + if not isinstance(other, Set): + raise ValueError("other must be a Set instance") + if self is other: # lgtm[py/comparison-using-is] + self.items.clear() + else: + overlap = self.intersection(other) + self.union_update(other) + self.difference_update(overlap) + def union(self, other): """Return a new set which is the union of ``self`` and ``other``. @@ -177,6 +186,18 @@ class Set: obj.difference_update(other) return obj + def symmetric_difference(self, other): + """Return a new set which (``self`` - ``other``) | (``other`` + - ``self), ie: the items in either ``self`` or ``other`` which + are not contained in their intersection. + + Returns the same Set type as this set. + """ + + obj = self._clone() + obj.symmetric_difference_update(other) + return obj + def __or__(self, other): return self.union(other) @@ -189,6 +210,9 @@ class Set: def __sub__(self, other): return self.difference(other) + def __xor__(self, other): + return self.symmetric_difference(other) + def __ior__(self, other): self.union_update(other) return self @@ -205,6 +229,10 @@ class Set: self.difference_update(other) return self + def __ixor__(self, other): + self.symmetric_difference_update(other) + return self + def update(self, other): """Update the set, adding any elements from other which are not already in the set. @@ -221,13 +249,7 @@ class Set: self.items.clear() def __eq__(self, other): - if odict == dict: - return self.items == other.items - else: - # We don't want an ordered comparison. - if len(self.items) != len(other.items): - return False - return all(elt in other.items for elt in self.items) + return self.items == other.items def __ne__(self, other): return not self.__eq__(other) @@ -258,7 +280,7 @@ class Set: """ if not isinstance(other, Set): - raise ValueError('other must be a Set instance') + raise ValueError("other must be a Set instance") for item in self.items: if item not in other.items: return False @@ -271,8 +293,16 @@ class Set: """ if not isinstance(other, Set): - raise ValueError('other must be a Set instance') + raise ValueError("other must be a Set instance") for item in other.items: if item not in self.items: return False return True + + def isdisjoint(self, other): + if not isinstance(other, Set): + raise ValueError("other must be a Set instance") + for item in other.items: + if item in self.items: + return False + return True diff --git a/lib/dns/tokenizer.py b/lib/dns/tokenizer.py index cb6a6302..0551578a 100644 --- a/lib/dns/tokenizer.py +++ b/lib/dns/tokenizer.py @@ -17,6 +17,8 @@ """Tokenize DNS zone file format""" +from typing import Any, Optional, List, Tuple + import io import sys @@ -24,7 +26,7 @@ import dns.exception import dns.name import dns.ttl -_DELIMITERS = {' ', '\t', '\n', ';', '(', ')', '"'} +_DELIMITERS = {" ", "\t", "\n", ";", "(", ")", '"'} _QUOTING_DELIMITERS = {'"'} EOF = 0 @@ -48,7 +50,13 @@ class Token: has_escape: Does the token value contain escapes? """ - def __init__(self, ttype, value='', has_escape=False, comment=None): + def __init__( + self, + ttype: int, + value: Any = "", + has_escape: bool = False, + comment: Optional[str] = None, + ): """Initialize a token instance.""" self.ttype = ttype @@ -56,55 +64,53 @@ class Token: self.has_escape = has_escape self.comment = comment - def is_eof(self): + def is_eof(self) -> bool: return self.ttype == EOF - def is_eol(self): + def is_eol(self) -> bool: return self.ttype == EOL - def is_whitespace(self): + def is_whitespace(self) -> bool: return self.ttype == WHITESPACE - def is_identifier(self): + def is_identifier(self) -> bool: return self.ttype == IDENTIFIER - def is_quoted_string(self): + def is_quoted_string(self) -> bool: return self.ttype == QUOTED_STRING - def is_comment(self): + def is_comment(self) -> bool: return self.ttype == COMMENT - def is_delimiter(self): # pragma: no cover (we don't return delimiters yet) + def is_delimiter(self) -> bool: # pragma: no cover (we don't return delimiters yet) return self.ttype == DELIMITER - def is_eol_or_eof(self): + def is_eol_or_eof(self) -> bool: return self.ttype == EOL or self.ttype == EOF def __eq__(self, other): if not isinstance(other, Token): return False - return (self.ttype == other.ttype and - self.value == other.value) + return self.ttype == other.ttype and self.value == other.value def __ne__(self, other): if not isinstance(other, Token): return True - return (self.ttype != other.ttype or - self.value != other.value) + return self.ttype != other.ttype or self.value != other.value def __str__(self): return '%d "%s"' % (self.ttype, self.value) - def unescape(self): + def unescape(self) -> "Token": if not self.has_escape: return self - unescaped = '' + unescaped = "" l = len(self.value) i = 0 while i < l: c = self.value[i] i += 1 - if c == '\\': + if c == "\\": if i >= l: # pragma: no cover (can't happen via get()) raise dns.exception.UnexpectedEnd c = self.value[i] @@ -127,7 +133,7 @@ class Token: unescaped += c return Token(self.ttype, unescaped) - def unescape_to_bytes(self): + def unescape_to_bytes(self) -> "Token": # We used to use unescape() for TXT-like records, but this # caused problems as we'd process DNS escapes into Unicode code # points instead of byte values, and then a to_text() of the @@ -152,13 +158,13 @@ class Token: # # foo\226\128\139bar # - unescaped = b'' + unescaped = b"" l = len(self.value) i = 0 while i < l: c = self.value[i] i += 1 - if c == '\\': + if c == "\\": if i >= l: # pragma: no cover (can't happen via get()) raise dns.exception.UnexpectedEnd c = self.value[i] @@ -177,7 +183,7 @@ class Token: codepoint = int(c) * 100 + int(c2) * 10 + int(c3) if codepoint > 255: raise dns.exception.SyntaxError - unescaped += b'%c' % (codepoint) + unescaped += b"%c" % (codepoint) else: # Note that as mentioned above, if c is a Unicode # code point outside of the ASCII range, then this @@ -223,7 +229,12 @@ class Tokenizer: encoder/decoder is used. """ - def __init__(self, f=sys.stdin, filename=None, idna_codec=None): + def __init__( + self, + f: Any = sys.stdin, + filename: Optional[str] = None, + idna_codec: Optional[dns.name.IDNACodec] = None, + ): """Initialize a tokenizer instance. f: The file to tokenize. The default is sys.stdin. @@ -241,49 +252,50 @@ class Tokenizer: if isinstance(f, str): f = io.StringIO(f) if filename is None: - filename = '' + filename = "" elif isinstance(f, bytes): f = io.StringIO(f.decode()) if filename is None: - filename = '' + filename = "" else: if filename is None: if f is sys.stdin: - filename = '' + filename = "" else: - filename = '' + filename = "" self.file = f - self.ungotten_char = None - self.ungotten_token = None + self.ungotten_char: Optional[str] = None + self.ungotten_token: Optional[Token] = None self.multiline = 0 self.quoting = False self.eof = False self.delimiters = _DELIMITERS self.line_number = 1 + assert filename is not None self.filename = filename if idna_codec is None: - idna_codec = dns.name.IDNA_2003 - self.idna_codec = idna_codec + self.idna_codec: dns.name.IDNACodec = dns.name.IDNA_2003 + else: + self.idna_codec = idna_codec - def _get_char(self): - """Read a character from input. - """ + def _get_char(self) -> str: + """Read a character from input.""" if self.ungotten_char is None: if self.eof: - c = '' + c = "" else: c = self.file.read(1) - if c == '': + if c == "": self.eof = True - elif c == '\n': + elif c == "\n": self.line_number += 1 else: c = self.ungotten_char self.ungotten_char = None return c - def where(self): + def where(self) -> Tuple[str, int]: """Return the current location in the input. Returns a (string, int) tuple. The first item is the filename of @@ -292,7 +304,7 @@ class Tokenizer: return (self.filename, self.line_number) - def _unget_char(self, c): + def _unget_char(self, c: str) -> None: """Unget a character. The unget buffer for characters is only one character large; it is @@ -308,7 +320,7 @@ class Tokenizer: raise UngetBufferFull # pragma: no cover self.ungotten_char = c - def skip_whitespace(self): + def skip_whitespace(self) -> int: """Consume input until a non-whitespace character is encountered. The non-whitespace character is then ungotten, and the number of @@ -322,13 +334,13 @@ class Tokenizer: skipped = 0 while True: c = self._get_char() - if c != ' ' and c != '\t': - if (c != '\n') or not self.multiline: + if c != " " and c != "\t": + if (c != "\n") or not self.multiline: self._unget_char(c) return skipped skipped += 1 - def get(self, want_leading=False, want_comment=False): + def get(self, want_leading: bool = False, want_comment: bool = False) -> Token: """Get the next token. want_leading: If True, return a WHITESPACE token if the @@ -345,33 +357,33 @@ class Tokenizer: """ if self.ungotten_token is not None: - token = self.ungotten_token + utoken = self.ungotten_token self.ungotten_token = None - if token.is_whitespace(): + if utoken.is_whitespace(): if want_leading: - return token - elif token.is_comment(): + return utoken + elif utoken.is_comment(): if want_comment: - return token + return utoken else: - return token + return utoken skipped = self.skip_whitespace() if want_leading and skipped > 0: - return Token(WHITESPACE, ' ') - token = '' + return Token(WHITESPACE, " ") + token = "" ttype = IDENTIFIER has_escape = False while True: c = self._get_char() - if c == '' or c in self.delimiters: - if c == '' and self.quoting: + if c == "" or c in self.delimiters: + if c == "" and self.quoting: raise dns.exception.UnexpectedEnd - if token == '' and ttype != QUOTED_STRING: - if c == '(': + if token == "" and ttype != QUOTED_STRING: + if c == "(": self.multiline += 1 self.skip_whitespace() continue - elif c == ')': + elif c == ")": if self.multiline <= 0: raise dns.exception.SyntaxError self.multiline -= 1 @@ -388,28 +400,29 @@ class Tokenizer: self.delimiters = _DELIMITERS self.skip_whitespace() continue - elif c == '\n': - return Token(EOL, '\n') - elif c == ';': + elif c == "\n": + return Token(EOL, "\n") + elif c == ";": while 1: c = self._get_char() - if c == '\n' or c == '': + if c == "\n" or c == "": break token += c if want_comment: self._unget_char(c) return Token(COMMENT, token) - elif c == '': + elif c == "": if self.multiline: raise dns.exception.SyntaxError( - 'unbalanced parentheses') + "unbalanced parentheses" + ) return Token(EOF, comment=token) elif self.multiline: self.skip_whitespace() - token = '' + token = "" continue else: - return Token(EOL, '\n', comment=token) + return Token(EOL, "\n", comment=token) else: # This code exists in case we ever want a # delimiter to be returned. It never produces @@ -419,9 +432,9 @@ class Tokenizer: else: self._unget_char(c) break - elif self.quoting and c == '\n': - raise dns.exception.SyntaxError('newline in quoted string') - elif c == '\\': + elif self.quoting and c == "\n": + raise dns.exception.SyntaxError("newline in quoted string") + elif c == "\\": # # It's an escape. Put it and the next character into # the token; it will be checked later for goodness. @@ -429,16 +442,16 @@ class Tokenizer: token += c has_escape = True c = self._get_char() - if c == '' or (c == '\n' and not self.quoting): + if c == "" or (c == "\n" and not self.quoting): raise dns.exception.UnexpectedEnd token += c - if token == '' and ttype != QUOTED_STRING: + if token == "" and ttype != QUOTED_STRING: if self.multiline: - raise dns.exception.SyntaxError('unbalanced parentheses') + raise dns.exception.SyntaxError("unbalanced parentheses") ttype = EOF return Token(ttype, token, has_escape) - def unget(self, token): + def unget(self, token: Token) -> None: """Unget a token. The unget buffer for tokens is only one token large; it is @@ -472,7 +485,7 @@ class Tokenizer: # Helpers - def get_int(self, base=10): + def get_int(self, base: int = 10) -> int: """Read the next token and interpret it as an unsigned integer. Raises dns.exception.SyntaxError if not an unsigned integer. @@ -482,12 +495,12 @@ class Tokenizer: token = self.get().unescape() if not token.is_identifier(): - raise dns.exception.SyntaxError('expecting an identifier') + raise dns.exception.SyntaxError("expecting an identifier") if not token.value.isdigit(): - raise dns.exception.SyntaxError('expecting an integer') + raise dns.exception.SyntaxError("expecting an integer") return int(token.value, base) - def get_uint8(self): + def get_uint8(self) -> int: """Read the next token and interpret it as an 8-bit unsigned integer. @@ -499,10 +512,11 @@ class Tokenizer: value = self.get_int() if value < 0 or value > 255: raise dns.exception.SyntaxError( - '%d is not an unsigned 8-bit integer' % value) + "%d is not an unsigned 8-bit integer" % value + ) return value - def get_uint16(self, base=10): + def get_uint16(self, base: int = 10) -> int: """Read the next token and interpret it as a 16-bit unsigned integer. @@ -515,13 +529,15 @@ class Tokenizer: if value < 0 or value > 65535: if base == 8: raise dns.exception.SyntaxError( - '%o is not an octal unsigned 16-bit integer' % value) + "%o is not an octal unsigned 16-bit integer" % value + ) else: raise dns.exception.SyntaxError( - '%d is not an unsigned 16-bit integer' % value) + "%d is not an unsigned 16-bit integer" % value + ) return value - def get_uint32(self, base=10): + def get_uint32(self, base: int = 10) -> int: """Read the next token and interpret it as a 32-bit unsigned integer. @@ -533,10 +549,11 @@ class Tokenizer: value = self.get_int(base=base) if value < 0 or value > 4294967295: raise dns.exception.SyntaxError( - '%d is not an unsigned 32-bit integer' % value) + "%d is not an unsigned 32-bit integer" % value + ) return value - def get_uint48(self, base=10): + def get_uint48(self, base: int = 10) -> int: """Read the next token and interpret it as a 48-bit unsigned integer. @@ -548,10 +565,11 @@ class Tokenizer: value = self.get_int(base=base) if value < 0 or value > 281474976710655: raise dns.exception.SyntaxError( - '%d is not an unsigned 48-bit integer' % value) + "%d is not an unsigned 48-bit integer" % value + ) return value - def get_string(self, max_length=None): + def get_string(self, max_length: Optional[int] = None) -> str: """Read the next token and interpret it as a string. Raises dns.exception.SyntaxError if not a string. @@ -563,12 +581,12 @@ class Tokenizer: token = self.get().unescape() if not (token.is_identifier() or token.is_quoted_string()): - raise dns.exception.SyntaxError('expecting a string') + raise dns.exception.SyntaxError("expecting a string") if max_length and len(token.value) > max_length: raise dns.exception.SyntaxError("string too long") return token.value - def get_identifier(self): + def get_identifier(self) -> str: """Read the next token, which should be an identifier. Raises dns.exception.SyntaxError if not an identifier. @@ -578,10 +596,10 @@ class Tokenizer: token = self.get().unescape() if not token.is_identifier(): - raise dns.exception.SyntaxError('expecting an identifier') + raise dns.exception.SyntaxError("expecting an identifier") return token.value - def get_remaining(self, max_tokens=None): + def get_remaining(self, max_tokens: Optional[int] = None) -> List[Token]: """Return the remaining tokens on the line, until an EOL or EOF is seen. max_tokens: If not None, stop after this number of tokens. @@ -600,7 +618,7 @@ class Tokenizer: break return tokens - def concatenate_remaining_identifiers(self, allow_empty=False): + def concatenate_remaining_identifiers(self, allow_empty: bool = False) -> str: """Read the remaining tokens on the line, which should be identifiers. Raises dns.exception.SyntaxError if there are no remaining tokens, @@ -622,10 +640,16 @@ class Tokenizer: raise dns.exception.SyntaxError s += token.value if not (allow_empty or s): - raise dns.exception.SyntaxError('expecting another identifier') + raise dns.exception.SyntaxError("expecting another identifier") return s - def as_name(self, token, origin=None, relativize=False, relativize_to=None): + def as_name( + self, + token: Token, + origin: Optional[dns.name.Name] = None, + relativize: bool = False, + relativize_to: Optional[dns.name.Name] = None, + ) -> dns.name.Name: """Try to interpret the token as a DNS name. Raises dns.exception.SyntaxError if not a name. @@ -633,11 +657,16 @@ class Tokenizer: Returns a dns.name.Name. """ if not token.is_identifier(): - raise dns.exception.SyntaxError('expecting an identifier') + raise dns.exception.SyntaxError("expecting an identifier") name = dns.name.from_text(token.value, origin, self.idna_codec) return name.choose_relativity(relativize_to or origin, relativize) - def get_name(self, origin=None, relativize=False, relativize_to=None): + def get_name( + self, + origin: Optional[dns.name.Name] = None, + relativize: bool = False, + relativize_to: Optional[dns.name.Name] = None, + ) -> dns.name.Name: """Read the next token and interpret it as a DNS name. Raises dns.exception.SyntaxError if not a name. @@ -648,7 +677,7 @@ class Tokenizer: token = self.get() return self.as_name(token, origin, relativize, relativize_to) - def get_eol_as_token(self): + def get_eol_as_token(self) -> Token: """Read the next token and raise an exception if it isn't EOL or EOF. @@ -658,14 +687,14 @@ class Tokenizer: token = self.get() if not token.is_eol_or_eof(): raise dns.exception.SyntaxError( - 'expected EOL or EOF, got %d "%s"' % (token.ttype, - token.value)) + 'expected EOL or EOF, got %d "%s"' % (token.ttype, token.value) + ) return token - def get_eol(self): + def get_eol(self) -> str: return self.get_eol_as_token().value - def get_ttl(self): + def get_ttl(self) -> int: """Read the next token and interpret it as a DNS TTL. Raises dns.exception.SyntaxError or dns.ttl.BadTTL if not an @@ -676,5 +705,5 @@ class Tokenizer: token = self.get().unescape() if not token.is_identifier(): - raise dns.exception.SyntaxError('expecting an identifier') + raise dns.exception.SyntaxError("expecting an identifier") return dns.ttl.from_text(token.value) diff --git a/lib/dns/transaction.py b/lib/dns/transaction.py index d7254924..c4a9e1f6 100644 --- a/lib/dns/transaction.py +++ b/lib/dns/transaction.py @@ -1,9 +1,12 @@ # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license +from typing import Any, Callable, List, Optional, Tuple, Union + import collections import dns.exception import dns.name +import dns.node import dns.rdataclass import dns.rdataset import dns.rdatatype @@ -13,11 +16,11 @@ import dns.ttl class TransactionManager: - def reader(self): + def reader(self) -> "Transaction": """Begin a read-only transaction.""" raise NotImplementedError # pragma: no cover - def writer(self, replacement=False): + def writer(self, replacement: bool = False) -> "Transaction": """Begin a writable transaction. *replacement*, a ``bool``. If `True`, the content of the @@ -27,7 +30,9 @@ class TransactionManager: """ raise NotImplementedError # pragma: no cover - def origin_information(self): + def origin_information( + self, + ) -> Tuple[Optional[dns.name.Name], bool, Optional[dns.name.Name]]: """Returns a tuple (absolute_origin, relativize, effective_origin) @@ -52,14 +57,12 @@ class TransactionManager: """ raise NotImplementedError # pragma: no cover - def get_class(self): - """The class of the transaction manager. - """ + def get_class(self) -> dns.rdataclass.RdataClass: + """The class of the transaction manager.""" raise NotImplementedError # pragma: no cover - def from_wire_origin(self): - """Origin to use in from_wire() calls. - """ + def from_wire_origin(self) -> Optional[dns.name.Name]: + """Origin to use in from_wire() calls.""" (absolute_origin, relativize, _) = self.origin_information() if relativize: return absolute_origin @@ -84,28 +87,51 @@ def _ensure_immutable_rdataset(rdataset): return rdataset return dns.rdataset.ImmutableRdataset(rdataset) + def _ensure_immutable_node(node): if node is None or node.is_immutable(): return node return dns.node.ImmutableNode(node) -class Transaction: +CheckPutRdatasetType = Callable[ + ["Transaction", dns.name.Name, dns.rdataset.Rdataset], None +] +CheckDeleteRdatasetType = Callable[ + ["Transaction", dns.name.Name, dns.rdatatype.RdataType, dns.rdatatype.RdataType], + None, +] +CheckDeleteNameType = Callable[["Transaction", dns.name.Name], None] - def __init__(self, manager, replacement=False, read_only=False): + +class Transaction: + def __init__( + self, + manager: TransactionManager, + replacement: bool = False, + read_only: bool = False, + ): self.manager = manager self.replacement = replacement self.read_only = read_only self._ended = False - self._check_put_rdataset = [] - self._check_delete_rdataset = [] - self._check_delete_name = [] + self._check_put_rdataset: List[CheckPutRdatasetType] = [] + self._check_delete_rdataset: List[CheckDeleteRdatasetType] = [] + self._check_delete_name: List[CheckDeleteNameType] = [] # # This is the high level API # + # Note that we currently use non-immutable types in the return type signature to + # avoid covariance problems, e.g. if the caller has a List[Rdataset], mypy will be + # unhappy if we return an ImmutableRdataset. - def get(self, name, rdtype, covers=dns.rdatatype.NONE): + def get( + self, + name: Optional[Union[dns.name.Name, str]], + rdtype: Union[dns.rdatatype.RdataType, str], + covers: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.NONE, + ) -> dns.rdataset.Rdataset: """Return the rdataset associated with *name*, *rdtype*, and *covers*, or `None` if not found. @@ -115,21 +141,22 @@ class Transaction: if isinstance(name, str): name = dns.name.from_text(name, None) rdtype = dns.rdatatype.RdataType.make(rdtype) + covers = dns.rdatatype.RdataType.make(covers) rdataset = self._get_rdataset(name, rdtype, covers) return _ensure_immutable_rdataset(rdataset) - def get_node(self, name): + def get_node(self, name: dns.name.Name) -> Optional[dns.node.Node]: """Return the node at *name*, if any. Returns an immutable node or ``None``. """ return _ensure_immutable_node(self._get_node(name)) - def _check_read_only(self): + def _check_read_only(self) -> None: if self.read_only: raise ReadOnly - def add(self, *args): + def add(self, *args: Any) -> None: """Add records. The arguments may be: @@ -142,9 +169,9 @@ class Transaction: """ self._check_ended() self._check_read_only() - return self._add(False, args) + self._add(False, args) - def replace(self, *args): + def replace(self, *args: Any) -> None: """Replace the existing rdataset at the name with the specified rdataset, or add the specified rdataset if there was no existing rdataset. @@ -163,9 +190,9 @@ class Transaction: """ self._check_ended() self._check_read_only() - return self._add(True, args) + self._add(True, args) - def delete(self, *args): + def delete(self, *args: Any) -> None: """Delete records. It is not an error if some of the records are not in the existing @@ -185,9 +212,9 @@ class Transaction: """ self._check_ended() self._check_read_only() - return self._delete(False, args) + self._delete(False, args) - def delete_exact(self, *args): + def delete_exact(self, *args: Any) -> None: """Delete records. The arguments may be: @@ -208,16 +235,21 @@ class Transaction: """ self._check_ended() self._check_read_only() - return self._delete(True, args) + self._delete(True, args) - def name_exists(self, name): + def name_exists(self, name: Union[dns.name.Name, str]) -> bool: """Does the specified name exist?""" self._check_ended() if isinstance(name, str): name = dns.name.from_text(name, None) return self._name_exists(name) - def update_serial(self, value=1, relative=True, name=dns.name.empty): + def update_serial( + self, + value: int = 1, + relative: bool = True, + name: dns.name.Name = dns.name.empty, + ) -> None: """Update the serial number. *value*, an `int`, is an increment if *relative* is `True`, or the @@ -231,11 +263,10 @@ class Transaction: """ self._check_ended() if value < 0: - raise ValueError('negative update_serial() value') + raise ValueError("negative update_serial() value") if isinstance(name, str): name = dns.name.from_text(name, None) - rdataset = self._get_rdataset(name, dns.rdatatype.SOA, - dns.rdatatype.NONE) + rdataset = self._get_rdataset(name, dns.rdatatype.SOA, dns.rdatatype.NONE) if rdataset is None or len(rdataset) == 0: raise KeyError if relative: @@ -253,7 +284,7 @@ class Transaction: self._check_ended() return self._iterate_rdatasets() - def changed(self): + def changed(self) -> bool: """Has this transaction changed anything? For read-only transactions, the result is always `False`. @@ -264,7 +295,7 @@ class Transaction: self._check_ended() return self._changed() - def commit(self): + def commit(self) -> None: """Commit the transaction. Normally transactions are used as context managers and commit @@ -277,7 +308,7 @@ class Transaction: """ self._end(True) - def rollback(self): + def rollback(self) -> None: """Rollback the transaction. Normally transactions are used as context managers and commit @@ -289,7 +320,7 @@ class Transaction: """ self._end(False) - def check_put_rdataset(self, check): + def check_put_rdataset(self, check: CheckPutRdatasetType) -> None: """Call *check* before putting (storing) an rdataset. The function is called with the transaction, the name, and the rdataset. @@ -301,7 +332,7 @@ class Transaction: """ self._check_put_rdataset.append(check) - def check_delete_rdataset(self, check): + def check_delete_rdataset(self, check: CheckDeleteRdatasetType) -> None: """Call *check* before deleting an rdataset. The function is called with the transaction, the name, the rdatatype, @@ -314,7 +345,7 @@ class Transaction: """ self._check_delete_rdataset.append(check) - def check_delete_name(self, check): + def check_delete_name(self, check: CheckDeleteNameType) -> None: """Call *check* before putting (storing) an rdataset. The function is called with the transaction and the name. @@ -332,7 +363,7 @@ class Transaction: def _raise_if_not_empty(self, method, args): if len(args) != 0: - raise TypeError(f'extra parameters to {method}') + raise TypeError(f"extra parameters to {method}") def _rdataset_from_args(self, method, deleting, args): try: @@ -348,29 +379,29 @@ class Transaction: if isinstance(arg, int): ttl = arg if ttl > dns.ttl.MAX_TTL: - raise ValueError(f'{method}: TTL value too big') + raise ValueError(f"{method}: TTL value too big") else: - raise TypeError(f'{method}: expected a TTL') + raise TypeError(f"{method}: expected a TTL") arg = args.popleft() if isinstance(arg, dns.rdata.Rdata): rdataset = dns.rdataset.from_rdata(ttl, arg) else: - raise TypeError(f'{method}: expected an Rdata') + raise TypeError(f"{method}: expected an Rdata") return rdataset except IndexError: if deleting: return None else: # reraise - raise TypeError(f'{method}: expected more arguments') + raise TypeError(f"{method}: expected more arguments") def _add(self, replace, args): try: args = collections.deque(args) if replace: - method = 'replace()' + method = "replace()" else: - method = 'add()' + method = "add()" arg = args.popleft() if isinstance(arg, str): arg = dns.name.from_text(arg, None) @@ -384,44 +415,45 @@ class Transaction: # same and can't be stored in nodes, so convert. rdataset = rrset.to_rdataset() else: - raise TypeError(f'{method} requires a name or RRset ' + - 'as the first argument') + raise TypeError( + f"{method} requires a name or RRset " + "as the first argument" + ) if rdataset.rdclass != self.manager.get_class(): - raise ValueError(f'{method} has objects of wrong RdataClass') + raise ValueError(f"{method} has objects of wrong RdataClass") if rdataset.rdtype == dns.rdatatype.SOA: (_, _, origin) = self._origin_information() if name != origin: - raise ValueError(f'{method} has non-origin SOA') + raise ValueError(f"{method} has non-origin SOA") self._raise_if_not_empty(method, args) if not replace: - existing = self._get_rdataset(name, rdataset.rdtype, - rdataset.covers) + existing = self._get_rdataset(name, rdataset.rdtype, rdataset.covers) if existing is not None: if isinstance(existing, dns.rdataset.ImmutableRdataset): - trds = dns.rdataset.Rdataset(existing.rdclass, - existing.rdtype, - existing.covers) + trds = dns.rdataset.Rdataset( + existing.rdclass, existing.rdtype, existing.covers + ) trds.update(existing) existing = trds rdataset = existing.union(rdataset) self._checked_put_rdataset(name, rdataset) except IndexError: - raise TypeError(f'not enough parameters to {method}') + raise TypeError(f"not enough parameters to {method}") def _delete(self, exact, args): try: args = collections.deque(args) if exact: - method = 'delete_exact()' + method = "delete_exact()" else: - method = 'delete()' + method = "delete()" arg = args.popleft() if isinstance(arg, str): arg = dns.name.from_text(arg, None) if isinstance(arg, dns.name.Name): name = arg - if len(args) > 0 and (isinstance(args[0], int) or - isinstance(args[0], str)): + if len(args) > 0 and ( + isinstance(args[0], int) or isinstance(args[0], str) + ): # deleting by type and (optionally) covers rdtype = dns.rdatatype.RdataType.make(args.popleft()) if len(args) > 0: @@ -432,7 +464,7 @@ class Transaction: existing = self._get_rdataset(name, rdtype, covers) if existing is None: if exact: - raise DeleteNotExact(f'{method}: missing rdataset') + raise DeleteNotExact(f"{method}: missing rdataset") else: self._delete_rdataset(name, rdtype, covers) return @@ -442,34 +474,34 @@ class Transaction: rdataset = arg # rrsets are also rdatasets name = rdataset.name else: - raise TypeError(f'{method} requires a name or RRset ' + - 'as the first argument') + raise TypeError( + f"{method} requires a name or RRset " + "as the first argument" + ) self._raise_if_not_empty(method, args) if rdataset: if rdataset.rdclass != self.manager.get_class(): - raise ValueError(f'{method} has objects of wrong ' - 'RdataClass') - existing = self._get_rdataset(name, rdataset.rdtype, - rdataset.covers) + raise ValueError(f"{method} has objects of wrong RdataClass") + existing = self._get_rdataset(name, rdataset.rdtype, rdataset.covers) if existing is not None: if exact: intersection = existing.intersection(rdataset) if intersection != rdataset: - raise DeleteNotExact(f'{method}: missing rdatas') + raise DeleteNotExact(f"{method}: missing rdatas") rdataset = existing.difference(rdataset) if len(rdataset) == 0: - self._checked_delete_rdataset(name, rdataset.rdtype, - rdataset.covers) + self._checked_delete_rdataset( + name, rdataset.rdtype, rdataset.covers + ) else: self._checked_put_rdataset(name, rdataset) elif exact: - raise DeleteNotExact(f'{method}: missing rdataset') + raise DeleteNotExact(f"{method}: missing rdataset") else: if exact and not self._name_exists(name): - raise DeleteNotExact(f'{method}: name not known') + raise DeleteNotExact(f"{method}: name not known") self._checked_delete_name(name) except IndexError: - raise TypeError(f'not enough parameters to {method}') + raise TypeError(f"not enough parameters to {method}") def _check_ended(self): if self._ended: @@ -575,8 +607,7 @@ class Transaction: raise NotImplementedError # pragma: no cover def _iterate_rdatasets(self): - """Return an iterator that yields (name, rdataset) tuples. - """ + """Return an iterator that yields (name, rdataset) tuples.""" raise NotImplementedError # pragma: no cover def _get_node(self, name): diff --git a/lib/dns/tsig.py b/lib/dns/tsig.py index 50b2d47e..2476fdfb 100644 --- a/lib/dns/tsig.py +++ b/lib/dns/tsig.py @@ -27,6 +27,7 @@ import dns.rdataclass import dns.name import dns.rcode + class BadTime(dns.exception.DNSException): """The current time is not within the TSIG's validity time.""" @@ -87,6 +88,19 @@ GSS_TSIG = dns.name.from_text("gss-tsig") default_algorithm = HMAC_SHA256 +mac_sizes = { + HMAC_SHA1: 20, + HMAC_SHA224: 28, + HMAC_SHA256: 32, + HMAC_SHA256_128: 16, + HMAC_SHA384: 48, + HMAC_SHA384_192: 24, + HMAC_SHA512: 64, + HMAC_SHA512_256: 32, + HMAC_MD5: 16, + GSS_TSIG: 128, # This is what we assume to be the worst case! +} + class GSSTSig: """ @@ -97,10 +111,11 @@ class GSSTSig: In order to avoid a direct GSSAPI dependency, the keyring holds a ref to the GSSAPI object required, rather than the key itself. """ + def __init__(self, gssapi_context): self.gssapi_context = gssapi_context - self.data = b'' - self.name = 'gss-tsig' + self.data = b"" + self.name = "gss-tsig" def update(self, data): self.data += data @@ -139,9 +154,9 @@ class GSSTSigAdapter: # client to complete the GSSAPI negotiation before attempting # to verify the signed response to a TKEY message exchange try: - rrset = message.find_rrset(message.answer, keyname, - dns.rdataclass.ANY, - dns.rdatatype.TKEY) + rrset = message.find_rrset( + message.answer, keyname, dns.rdataclass.ANY, dns.rdatatype.TKEY + ) if rrset: token = rrset[0].key gssapi_context = key.secret @@ -172,8 +187,9 @@ class HMACTSig: try: hashinfo = self._hashes[algorithm] except KeyError: - raise NotImplementedError(f"TSIG algorithm {algorithm} " + - "is not supported") + raise NotImplementedError( + f"TSIG algorithm {algorithm} " + "is not supported" + ) # create the HMAC context if isinstance(hashinfo, tuple): @@ -184,7 +200,7 @@ class HMACTSig: self.size = None self.name = self.hmac_context.name if self.size: - self.name += f'-{self.size}' + self.name += f"-{self.size}" def update(self, data): return self.hmac_context.update(data) @@ -203,8 +219,7 @@ class HMACTSig: raise BadSignature -def _digest(wire, key, rdata, time=None, request_mac=None, ctx=None, - multi=None): +def _digest(wire, key, rdata, time=None, request_mac=None, ctx=None, multi=None): """Return a context containing the TSIG rdata for the input parameters @rtype: dns.tsig.HMACTSig or dns.tsig.GSSTSig object @raises ValueError: I{other_data} is too long @@ -215,25 +230,25 @@ def _digest(wire, key, rdata, time=None, request_mac=None, ctx=None, if first: ctx = get_context(key) if request_mac: - ctx.update(struct.pack('!H', len(request_mac))) + ctx.update(struct.pack("!H", len(request_mac))) ctx.update(request_mac) - ctx.update(struct.pack('!H', rdata.original_id)) + ctx.update(struct.pack("!H", rdata.original_id)) ctx.update(wire[2:]) if first: ctx.update(key.name.to_digestable()) - ctx.update(struct.pack('!H', dns.rdataclass.ANY)) - ctx.update(struct.pack('!I', 0)) + ctx.update(struct.pack("!H", dns.rdataclass.ANY)) + ctx.update(struct.pack("!I", 0)) if time is None: time = rdata.time_signed - upper_time = (time >> 32) & 0xffff - lower_time = time & 0xffffffff - time_encoded = struct.pack('!HIH', upper_time, lower_time, rdata.fudge) + upper_time = (time >> 32) & 0xFFFF + lower_time = time & 0xFFFFFFFF + time_encoded = struct.pack("!HIH", upper_time, lower_time, rdata.fudge) other_len = len(rdata.other) if other_len > 65535: - raise ValueError('TSIG Other Data is > 65535 bytes') + raise ValueError("TSIG Other Data is > 65535 bytes") if first: ctx.update(key.algorithm.to_digestable() + time_encoded) - ctx.update(struct.pack('!HH', rdata.error, other_len) + rdata.other) + ctx.update(struct.pack("!HH", rdata.error, other_len) + rdata.other) else: ctx.update(time_encoded) return ctx @@ -246,7 +261,7 @@ def _maybe_start_digest(key, mac, multi): """ if multi: ctx = get_context(key) - ctx.update(struct.pack('!H', len(mac))) + ctx.update(struct.pack("!H", len(mac))) ctx.update(mac) return ctx else: @@ -269,8 +284,9 @@ def sign(wire, key, rdata, time=None, request_mac=None, ctx=None, multi=False): return (tsig, _maybe_start_digest(key, mac, multi)) -def validate(wire, key, owner, rdata, now, request_mac, tsig_start, ctx=None, - multi=False): +def validate( + wire, key, owner, rdata, now, request_mac, tsig_start, ctx=None, multi=False +): """Validate the specified TSIG rdata against the other input parameters. @raises FormError: The TSIG is badly formed. @@ -294,7 +310,7 @@ def validate(wire, key, owner, rdata, now, request_mac, tsig_start, ctx=None, elif rdata.error == dns.rcode.BADTRUNC: raise PeerBadTruncation else: - raise PeerError('unknown TSIG error code %d' % rdata.error) + raise PeerError("unknown TSIG error code %d" % rdata.error) if abs(rdata.time_signed - now) > rdata.fudge: raise BadTime if key.name != owner: @@ -332,14 +348,15 @@ class Key: self.algorithm = algorithm def __eq__(self, other): - return (isinstance(other, Key) and - self.name == other.name and - self.secret == other.secret and - self.algorithm == other.algorithm) + return ( + isinstance(other, Key) + and self.name == other.name + and self.secret == other.secret + and self.algorithm == other.algorithm + ) def __repr__(self): - r = f" Dict[dns.name.Name, dns.tsig.Key]: """Convert a dictionary containing (textual DNS name, base64 secret) pairs into a binary keyring which has (dns.name.Name, bytes) pairs, or a dictionary containing (textual DNS name, (algorithm, base64 secret)) @@ -32,16 +34,16 @@ def from_text(textring): keyring = {} for (name, value) in textring.items(): - name = dns.name.from_text(name) + kname = dns.name.from_text(name) if isinstance(value, str): - keyring[name] = dns.tsig.Key(name, value).secret + keyring[kname] = dns.tsig.Key(kname, value).secret else: (algorithm, secret) = value - keyring[name] = dns.tsig.Key(name, secret, algorithm) + keyring[kname] = dns.tsig.Key(kname, secret, algorithm) return keyring -def to_text(keyring): +def to_text(keyring: Dict[dns.name.Name, Any]) -> Dict[str, Any]: """Convert a dictionary containing (dns.name.Name, dns.tsig.Key) pairs into a text keyring which has (textual DNS name, (textual algorithm, base64 secret)) pairs, or a dictionary containing (dns.name.Name, bytes) @@ -49,17 +51,19 @@ def to_text(keyring): @rtype: dict""" textring = {} + def b64encode(secret): return base64.encodebytes(secret).decode().rstrip() + for (name, key) in keyring.items(): - name = name.to_text() + tname = name.to_text() if isinstance(key, bytes): - textring[name] = b64encode(key) + textring[tname] = b64encode(key) else: if isinstance(key.secret, bytes): text_secret = b64encode(key.secret) else: text_secret = str(key.secret) - textring[name] = (key.algorithm.to_text(), text_secret) + textring[tname] = (key.algorithm.to_text(), text_secret) return textring diff --git a/lib/dns/tsigkeyring.pyi b/lib/dns/tsigkeyring.pyi deleted file mode 100644 index b5d51e15..00000000 --- a/lib/dns/tsigkeyring.pyi +++ /dev/null @@ -1,7 +0,0 @@ -from typing import Dict -from . import name - -def from_text(textring : Dict[str,str]) -> Dict[name.Name,bytes]: - ... -def to_text(keyring : Dict[name.Name,bytes]) -> Dict[str, str]: - ... diff --git a/lib/dns/ttl.py b/lib/dns/ttl.py index df92b2b6..264b0338 100644 --- a/lib/dns/ttl.py +++ b/lib/dns/ttl.py @@ -17,6 +17,8 @@ """DNS TTL conversion.""" +from typing import Union + import dns.exception # Technically TTLs are supposed to be between 0 and 2**31 - 1, with values @@ -31,7 +33,7 @@ class BadTTL(dns.exception.SyntaxError): """DNS TTL value is not well-formed.""" -def from_text(text): +def from_text(text: str) -> int: """Convert the text form of a TTL to an integer. The BIND 8 units syntax for TTLs (e.g. '1w6d4h3m10s') is supported. @@ -60,15 +62,15 @@ def from_text(text): if need_digit: raise BadTTL c = c.lower() - if c == 'w': + if c == "w": total += current * 604800 - elif c == 'd': + elif c == "d": total += current * 86400 - elif c == 'h': + elif c == "h": total += current * 3600 - elif c == 'm': + elif c == "m": total += current * 60 - elif c == 's': + elif c == "s": total += current else: raise BadTTL("unknown unit '%s'" % c) @@ -81,10 +83,10 @@ def from_text(text): return total -def make(value): +def make(value: Union[int, str]) -> int: if isinstance(value, int): return value elif isinstance(value, str): return dns.ttl.from_text(value) else: - raise ValueError('cannot convert value to TTL') + raise ValueError("cannot convert value to TTL") diff --git a/lib/dns/update.py b/lib/dns/update.py index a541af22..647e5b19 100644 --- a/lib/dns/update.py +++ b/lib/dns/update.py @@ -17,18 +17,21 @@ """DNS Dynamic Update Support""" +from typing import Any, List, Optional, Union import dns.message import dns.name import dns.opcode import dns.rdata import dns.rdataclass +import dns.rdatatype import dns.rdataset import dns.tsig class UpdateSection(dns.enum.IntEnum): """Update sections""" + ZONE = 0 PREREQ = 1 UPDATE = 2 @@ -39,13 +42,20 @@ class UpdateSection(dns.enum.IntEnum): return 3 -class UpdateMessage(dns.message.Message): +class UpdateMessage(dns.message.Message): # lgtm[py/missing-equals] - _section_enum = UpdateSection + # ignore the mypy error here as we mean to use a different enum + _section_enum = UpdateSection # type: ignore - def __init__(self, zone=None, rdclass=dns.rdataclass.IN, keyring=None, - keyname=None, keyalgorithm=dns.tsig.default_algorithm, - id=None): + def __init__( + self, + zone: Optional[Union[dns.name.Name, str]] = None, + rdclass: dns.rdataclass.RdataClass = dns.rdataclass.IN, + keyring: Optional[Any] = None, + keyname: Optional[dns.name.Name] = None, + keyalgorithm: Union[dns.name.Name, str] = dns.tsig.default_algorithm, + id: Optional[int] = None, + ): """Initialize a new DNS Update object. See the documentation of the Message class for a complete @@ -69,13 +79,19 @@ class UpdateMessage(dns.message.Message): rdclass = dns.rdataclass.RdataClass.make(rdclass) self.zone_rdclass = rdclass if self.origin: - self.find_rrset(self.zone, self.origin, rdclass, dns.rdatatype.SOA, - create=True, force_unique=True) + self.find_rrset( + self.zone, + self.origin, + rdclass, + dns.rdatatype.SOA, + create=True, + force_unique=True, + ) if keyring is not None: self.use_tsig(keyring, keyname, algorithm=keyalgorithm) @property - def zone(self): + def zone(self) -> List[dns.rrset.RRset]: """The zone section.""" return self.sections[0] @@ -84,7 +100,7 @@ class UpdateMessage(dns.message.Message): self.sections[0] = v @property - def prerequisite(self): + def prerequisite(self) -> List[dns.rrset.RRset]: """The prerequisite section.""" return self.sections[1] @@ -93,7 +109,7 @@ class UpdateMessage(dns.message.Message): self.sections[1] = v @property - def update(self): + def update(self) -> List[dns.rrset.RRset]: """The update section.""" return self.sections[2] @@ -107,8 +123,9 @@ class UpdateMessage(dns.message.Message): if section is None: section = self.update covers = rd.covers() - rrset = self.find_rrset(section, name, self.zone_rdclass, rd.rdtype, - covers, deleting, True, True) + rrset = self.find_rrset( + section, name, self.zone_rdclass, rd.rdtype, covers, deleting, True, True + ) rrset.add(rd, ttl) def _add(self, replace, section, name, *args): @@ -148,11 +165,10 @@ class UpdateMessage(dns.message.Message): if replace: self.delete(name, rdtype) for s in args: - rd = dns.rdata.from_text(self.zone_rdclass, rdtype, s, - self.origin) + rd = dns.rdata.from_text(self.zone_rdclass, rdtype, s, self.origin) self._add_rr(name, ttl, rd, section=section) - def add(self, name, *args): + def add(self, name: Union[dns.name.Name, str], *args: Any) -> None: """Add records. The first argument is always a name. The other @@ -167,7 +183,7 @@ class UpdateMessage(dns.message.Message): self._add(False, self.update, name, *args) - def delete(self, name, *args): + def delete(self, name: Union[dns.name.Name, str], *args: Any) -> None: """Delete records. The first argument is always a name. The other @@ -185,33 +201,49 @@ class UpdateMessage(dns.message.Message): if isinstance(name, str): name = dns.name.from_text(name, None) if len(args) == 0: - self.find_rrset(self.update, name, dns.rdataclass.ANY, - dns.rdatatype.ANY, dns.rdatatype.NONE, - dns.rdatatype.ANY, True, True) + self.find_rrset( + self.update, + name, + dns.rdataclass.ANY, + dns.rdatatype.ANY, + dns.rdatatype.NONE, + dns.rdataclass.ANY, + True, + True, + ) elif isinstance(args[0], dns.rdataset.Rdataset): for rds in args: for rd in rds: self._add_rr(name, 0, rd, dns.rdataclass.NONE) else: - args = list(args) - if isinstance(args[0], dns.rdata.Rdata): - for rd in args: + largs = list(args) + if isinstance(largs[0], dns.rdata.Rdata): + for rd in largs: self._add_rr(name, 0, rd, dns.rdataclass.NONE) else: - rdtype = dns.rdatatype.RdataType.make(args.pop(0)) - if len(args) == 0: - self.find_rrset(self.update, name, - self.zone_rdclass, rdtype, - dns.rdatatype.NONE, - dns.rdataclass.ANY, - True, True) + rdtype = dns.rdatatype.RdataType.make(largs.pop(0)) + if len(largs) == 0: + self.find_rrset( + self.update, + name, + self.zone_rdclass, + rdtype, + dns.rdatatype.NONE, + dns.rdataclass.ANY, + True, + True, + ) else: - for s in args: - rd = dns.rdata.from_text(self.zone_rdclass, rdtype, s, - self.origin) + for s in largs: + rd = dns.rdata.from_text( + self.zone_rdclass, + rdtype, + s, # type: ignore[arg-type] + self.origin, + ) self._add_rr(name, 0, rd, dns.rdataclass.NONE) - def replace(self, name, *args): + def replace(self, name: Union[dns.name.Name, str], *args: Any) -> None: """Replace records. The first argument is always a name. The other @@ -229,7 +261,7 @@ class UpdateMessage(dns.message.Message): self._add(True, self.update, name, *args) - def present(self, name, *args): + def present(self, name: Union[dns.name.Name, str], *args: Any) -> None: """Require that an owner name (and optionally an rdata type, or specific rdataset) exists as a prerequisite to the execution of the update. @@ -247,42 +279,74 @@ class UpdateMessage(dns.message.Message): if isinstance(name, str): name = dns.name.from_text(name, None) if len(args) == 0: - self.find_rrset(self.prerequisite, name, - dns.rdataclass.ANY, dns.rdatatype.ANY, - dns.rdatatype.NONE, None, - True, True) - elif isinstance(args[0], dns.rdataset.Rdataset) or \ - isinstance(args[0], dns.rdata.Rdata) or \ - len(args) > 1: + self.find_rrset( + self.prerequisite, + name, + dns.rdataclass.ANY, + dns.rdatatype.ANY, + dns.rdatatype.NONE, + None, + True, + True, + ) + elif ( + isinstance(args[0], dns.rdataset.Rdataset) + or isinstance(args[0], dns.rdata.Rdata) + or len(args) > 1 + ): if not isinstance(args[0], dns.rdataset.Rdataset): # Add a 0 TTL - args = list(args) - args.insert(0, 0) - self._add(False, self.prerequisite, name, *args) + largs = list(args) + largs.insert(0, 0) # type: ignore[arg-type] + self._add(False, self.prerequisite, name, *largs) + else: + self._add(False, self.prerequisite, name, *args) else: rdtype = dns.rdatatype.RdataType.make(args[0]) - self.find_rrset(self.prerequisite, name, - dns.rdataclass.ANY, rdtype, - dns.rdatatype.NONE, None, - True, True) + self.find_rrset( + self.prerequisite, + name, + dns.rdataclass.ANY, + rdtype, + dns.rdatatype.NONE, + None, + True, + True, + ) - def absent(self, name, rdtype=None): + def absent( + self, + name: Union[dns.name.Name, str], + rdtype: Optional[Union[dns.rdatatype.RdataType, str]] = None, + ) -> None: """Require that an owner name (and optionally an rdata type) does not exist as a prerequisite to the execution of the update.""" if isinstance(name, str): name = dns.name.from_text(name, None) if rdtype is None: - self.find_rrset(self.prerequisite, name, - dns.rdataclass.NONE, dns.rdatatype.ANY, - dns.rdatatype.NONE, None, - True, True) + self.find_rrset( + self.prerequisite, + name, + dns.rdataclass.NONE, + dns.rdatatype.ANY, + dns.rdatatype.NONE, + None, + True, + True, + ) else: - rdtype = dns.rdatatype.RdataType.make(rdtype) - self.find_rrset(self.prerequisite, name, - dns.rdataclass.NONE, rdtype, - dns.rdatatype.NONE, None, - True, True) + the_rdtype = dns.rdatatype.RdataType.make(rdtype) + self.find_rrset( + self.prerequisite, + name, + dns.rdataclass.NONE, + the_rdtype, + dns.rdatatype.NONE, + None, + True, + True, + ) def _get_one_rr_per_rrset(self, value): # Updates are always one_rr_per_rrset @@ -292,9 +356,11 @@ class UpdateMessage(dns.message.Message): deleting = None empty = False if section == UpdateSection.ZONE: - if dns.rdataclass.is_metaclass(rdclass) or \ - rdtype != dns.rdatatype.SOA or \ - self.zone: + if ( + dns.rdataclass.is_metaclass(rdclass) + or rdtype != dns.rdatatype.SOA + or self.zone + ): raise dns.exception.FormError else: if not self.zone: @@ -302,10 +368,12 @@ class UpdateMessage(dns.message.Message): if rdclass in (dns.rdataclass.ANY, dns.rdataclass.NONE): deleting = rdclass rdclass = self.zone[0].rdclass - empty = (deleting == dns.rdataclass.ANY or - section == UpdateSection.PREREQ) + empty = ( + deleting == dns.rdataclass.ANY or section == UpdateSection.PREREQ + ) return (rdclass, rdtype, deleting, empty) + # backwards compatibility Update = UpdateMessage diff --git a/lib/dns/update.pyi b/lib/dns/update.pyi deleted file mode 100644 index eeac0591..00000000 --- a/lib/dns/update.pyi +++ /dev/null @@ -1,21 +0,0 @@ -from typing import Optional,Dict,Union,Any - -from . import message, tsig, rdataclass, name - -class Update(message.Message): - def __init__(self, zone : Union[name.Name, str], rdclass : Union[int,str] = rdataclass.IN, keyring : Optional[Dict[name.Name,bytes]] = None, - keyname : Optional[name.Name] = None, keyalgorithm : Optional[name.Name] = tsig.default_algorithm) -> None: - self.id : int - def add(self, name : Union[str,name.Name], *args : Any): - ... - def delete(self, name, *args : Any): - ... - def replace(self, name : Union[str,name.Name], *args : Any): - ... - def present(self, name : Union[str,name.Name], *args : Any): - ... - def absent(self, name : Union[str,name.Name], rdtype=None): - """Require that an owner name (and optionally an rdata type) does - not exist as a prerequisite to the execution of the update.""" - def to_wire(self, origin : Optional[name.Name] = None, max_size=65535, **kw) -> bytes: - ... diff --git a/lib/dns/version.py b/lib/dns/version.py index 65017872..89d4cf1a 100644 --- a/lib/dns/version.py +++ b/lib/dns/version.py @@ -20,27 +20,39 @@ #: MAJOR MAJOR = 2 #: MINOR -MINOR = 2 +MINOR = 3 #: MICRO -MICRO = 1 +MICRO = 0 #: RELEASELEVEL -RELEASELEVEL = 0x0f +RELEASELEVEL = 0x0F #: SERIAL -SERIAL = 0 +SERIAL = 1 -if RELEASELEVEL == 0x0f: # pragma: no cover +if RELEASELEVEL == 0x0F: # pragma: no cover lgtm[py/unreachable-statement] #: version - version = '%d.%d.%d' % (MAJOR, MINOR, MICRO) -elif RELEASELEVEL == 0x00: # pragma: no cover - version = '%d.%d.%ddev%d' % \ - (MAJOR, MINOR, MICRO, SERIAL) -elif RELEASELEVEL == 0x0c: # pragma: no cover - version = '%d.%d.%drc%d' % \ - (MAJOR, MINOR, MICRO, SERIAL) -else: # pragma: no cover - version = '%d.%d.%d%x%d' % \ - (MAJOR, MINOR, MICRO, RELEASELEVEL, SERIAL) + version = "%d.%d.%d" % (MAJOR, MINOR, MICRO) # lgtm[py/unreachable-statement] +elif RELEASELEVEL == 0x00: # pragma: no cover lgtm[py/unreachable-statement] + version = "%d.%d.%ddev%d" % ( + MAJOR, + MINOR, + MICRO, + SERIAL, + ) # lgtm[py/unreachable-statement] +elif RELEASELEVEL == 0x0C: # pragma: no cover lgtm[py/unreachable-statement] + version = "%d.%d.%drc%d" % ( + MAJOR, + MINOR, + MICRO, + SERIAL, + ) # lgtm[py/unreachable-statement] +else: # pragma: no cover lgtm[py/unreachable-statement] + version = "%d.%d.%d%x%d" % ( + MAJOR, + MINOR, + MICRO, + RELEASELEVEL, + SERIAL, + ) # lgtm[py/unreachable-statement] #: hexversion -hexversion = MAJOR << 24 | MINOR << 16 | MICRO << 8 | RELEASELEVEL << 4 | \ - SERIAL +hexversion = MAJOR << 24 | MINOR << 16 | MICRO << 8 | RELEASELEVEL << 4 | SERIAL diff --git a/lib/dns/versioned.py b/lib/dns/versioned.py index 8b6c275f..02e24122 100644 --- a/lib/dns/versioned.py +++ b/lib/dns/versioned.py @@ -2,16 +2,17 @@ """DNS Versioned Zones.""" +from typing import Callable, Deque, Optional, Set, Union + import collections -try: - import threading as _threading -except ImportError: # pragma: no cover - import dummy_threading as _threading # type: ignore +import threading import dns.exception import dns.immutable import dns.name +import dns.node import dns.rdataclass +import dns.rdataset import dns.rdatatype import dns.rdtypes.ANY.SOA import dns.zone @@ -30,16 +31,27 @@ ImmutableVersion = dns.zone.ImmutableVersion Transaction = dns.zone.Transaction -class Zone(dns.zone.Zone): +class Zone(dns.zone.Zone): # lgtm[py/missing-equals] - __slots__ = ['_versions', '_versions_lock', '_write_txn', - '_write_waiters', '_write_event', '_pruning_policy', - '_readers'] + __slots__ = [ + "_versions", + "_versions_lock", + "_write_txn", + "_write_waiters", + "_write_event", + "_pruning_policy", + "_readers", + ] node_factory = Node - def __init__(self, origin, rdclass=dns.rdataclass.IN, relativize=True, - pruning_policy=None): + def __init__( + self, + origin: Optional[Union[dns.name.Name, str]], + rdclass: dns.rdataclass.RdataClass = dns.rdataclass.IN, + relativize: bool = True, + pruning_policy: Optional[Callable[["Zone", Version], Optional[bool]]] = None, + ): """Initialize a versioned zone object. *origin* is the origin of the zone. It may be a ``dns.name.Name``, @@ -51,28 +63,30 @@ class Zone(dns.zone.Zone): *relativize*, a ``bool``, determine's whether domain names are relativized to the zone's origin. The default is ``True``. - *pruning policy*, a function taking a `Version` and returning - a `bool`, or `None`. Should the version be pruned? If `None`, + *pruning policy*, a function taking a ``Zone`` and a ``Version`` and returning + a ``bool``, or ``None``. Should the version be pruned? If ``None``, the default policy, which retains one version is used. """ super().__init__(origin, rdclass, relativize) - self._versions = collections.deque() - self._version_lock = _threading.Lock() + self._versions: Deque[Version] = collections.deque() + self._version_lock = threading.Lock() if pruning_policy is None: self._pruning_policy = self._default_pruning_policy else: self._pruning_policy = pruning_policy - self._write_txn = None - self._write_event = None - self._write_waiters = collections.deque() - self._readers = set() - self._commit_version_unlocked(None, - WritableVersion(self, replacement=True), - origin) + self._write_txn: Optional[Transaction] = None + self._write_event: Optional[threading.Event] = None + self._write_waiters: Deque[threading.Event] = collections.deque() + self._readers: Set[Transaction] = set() + self._commit_version_unlocked( + None, WritableVersion(self, replacement=True), origin + ) - def reader(self, id=None, serial=None): # pylint: disable=arguments-differ + def reader( + self, id: Optional[int] = None, serial: Optional[int] = None + ) -> Transaction: # pylint: disable=arguments-differ if id is not None and serial is not None: - raise ValueError('cannot specify both id and serial') + raise ValueError("cannot specify both id and serial") with self._version_lock: if id is not None: version = None @@ -81,11 +95,12 @@ class Zone(dns.zone.Zone): version = v break if version is None: - raise KeyError('version not found') + raise KeyError("version not found") elif serial is not None: if self.relativize: oname = dns.name.empty else: + assert self.origin is not None oname = self.origin version = None for v in reversed(self._versions): @@ -96,14 +111,14 @@ class Zone(dns.zone.Zone): version = v break if version is None: - raise KeyError('serial not found') + raise KeyError("serial not found") else: version = self._versions[-1] txn = Transaction(self, False, version) self._readers.add(txn) return txn - def writer(self, replacement=False): + def writer(self, replacement: bool = False) -> Transaction: event = None while True: with self._version_lock: @@ -117,15 +132,16 @@ class Zone(dns.zone.Zone): # give up the lock, so that we hold the lock as # short a time as possible. This is why we call # _setup_version() below. - self._write_txn = Transaction(self, replacement, - make_immutable=True) + self._write_txn = Transaction( + self, replacement, make_immutable=True + ) # give up our exclusive right to make a Transaction self._write_event = None break # Someone else is writing already, so we will have to # wait, but we want to do the actual wait outside the # lock. - event = _threading.Event() + event = threading.Event() self._write_waiters.append(event) # wait (note we gave up the lock!) # @@ -159,6 +175,7 @@ class Zone(dns.zone.Zone): # pylint: disable=unused-argument def _default_pruning_policy(self, zone, version): return True + # pylint: enable=unused-argument def _prune_versions_unlocked(self): @@ -174,25 +191,32 @@ class Zone(dns.zone.Zone): least_kept = min(txn.version.id for txn in self._readers) else: least_kept = self._versions[-1].id - while self._versions[0].id < least_kept and \ - self._pruning_policy(self, self._versions[0]): + while self._versions[0].id < least_kept and self._pruning_policy( + self, self._versions[0] + ): self._versions.popleft() - def set_max_versions(self, max_versions): + def set_max_versions(self, max_versions: Optional[int]) -> None: """Set a pruning policy that retains up to the specified number of versions """ if max_versions is not None and max_versions < 1: - raise ValueError('max versions must be at least 1') + raise ValueError("max versions must be at least 1") if max_versions is None: - def policy(*_): + + def policy(zone, _): # pylint: disable=unused-argument return False + else: + def policy(zone, _): return len(zone._versions) > max_versions + self.set_pruning_policy(policy) - def set_pruning_policy(self, policy): + def set_pruning_policy( + self, policy: Optional[Callable[["Zone", Version], Optional[bool]]] + ) -> None: """Set the pruning policy for the zone. The *policy* function takes a `Version` and returns `True` if @@ -245,30 +269,52 @@ class Zone(dns.zone.Zone): id = 1 return id - def find_node(self, name, create=False): + def find_node( + self, name: Union[dns.name.Name, str], create: bool = False + ) -> dns.node.Node: if create: raise UseTransaction return super().find_node(name) - def delete_node(self, name): + def delete_node(self, name: Union[dns.name.Name, str]) -> None: raise UseTransaction - def find_rdataset(self, name, rdtype, covers=dns.rdatatype.NONE, - create=False): + def find_rdataset( + self, + name: Union[dns.name.Name, str], + rdtype: Union[dns.rdatatype.RdataType, str], + covers: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.NONE, + create: bool = False, + ) -> dns.rdataset.Rdataset: if create: raise UseTransaction rdataset = super().find_rdataset(name, rdtype, covers) return dns.rdataset.ImmutableRdataset(rdataset) - def get_rdataset(self, name, rdtype, covers=dns.rdatatype.NONE, - create=False): + def get_rdataset( + self, + name: Union[dns.name.Name, str], + rdtype: Union[dns.rdatatype.RdataType, str], + covers: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.NONE, + create: bool = False, + ) -> Optional[dns.rdataset.Rdataset]: if create: raise UseTransaction rdataset = super().get_rdataset(name, rdtype, covers) - return dns.rdataset.ImmutableRdataset(rdataset) + if rdataset is not None: + return dns.rdataset.ImmutableRdataset(rdataset) + else: + return None - def delete_rdataset(self, name, rdtype, covers=dns.rdatatype.NONE): + def delete_rdataset( + self, + name: Union[dns.name.Name, str], + rdtype: Union[dns.rdatatype.RdataType, str], + covers: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.NONE, + ) -> None: raise UseTransaction - def replace_rdataset(self, name, replacement): + def replace_rdataset( + self, name: Union[dns.name.Name, str], replacement: dns.rdataset.Rdataset + ) -> None: raise UseTransaction diff --git a/lib/dns/win32util.py b/lib/dns/win32util.py index 745317a3..ac314750 100644 --- a/lib/dns/win32util.py +++ b/lib/dns/win32util.py @@ -1,20 +1,26 @@ import sys -if sys.platform == 'win32': +if sys.platform == "win32": + + from typing import Any import dns.name _prefer_wmi = True - import winreg + import winreg # pylint: disable=import-error + + # Keep pylint quiet on non-windows. + try: + WindowsError is None # pylint: disable=used-before-assignment + except KeyError: + WindowsError = Exception try: - try: - import threading as _threading - except ImportError: # pragma: no cover - import dummy_threading as _threading # type: ignore - import pythoncom - import wmi + import threading + import pythoncom # pylint: disable=import-error + import wmi # pylint: disable=import-error + _have_wmi = True except Exception: _have_wmi = False @@ -23,7 +29,7 @@ if sys.platform == 'win32': # Sometimes DHCP servers add a '.' prefix to the default domain, and # Windows just stores such values in the registry (see #687). # Check for this and fix it. - if domain.startswith('.'): + if domain.startswith("."): domain = domain[1:] return dns.name.from_text(domain) @@ -34,7 +40,8 @@ if sys.platform == 'win32': self.search = [] if _have_wmi: - class _WMIGetter(_threading.Thread): + + class _WMIGetter(threading.Thread): def __init__(self): super().__init__() self.info = DnsInfo() @@ -44,11 +51,14 @@ if sys.platform == 'win32': try: system = wmi.WMI() for interface in system.Win32_NetworkAdapterConfiguration(): - if interface.IPEnabled: + if interface.IPEnabled and interface.DNSDomain: self.info.domain = _config_domain(interface.DNSDomain) self.info.nameservers = list(interface.DNSServerSearchOrder) - self.info.search = [dns.name.from_text(x) for x in - interface.DNSDomainSuffixSearchOrder] + if interface.DNSDomainSuffixSearchOrder: + self.info.search = [ + _config_domain(x) + for x in interface.DNSDomainSuffixSearchOrder + ] break finally: pythoncom.CoUninitialize() @@ -59,10 +69,11 @@ if sys.platform == 'win32': self.start() self.join() return self.info - else: - class _WMIGetter: - pass + else: + + class _WMIGetter: # type: ignore + pass class _RegistryGetter: def __init__(self): @@ -74,13 +85,13 @@ if sys.platform == 'win32': # delimiter in between ' ' and ',' (and vice-versa) in various # versions of windows. # - if entry.find(' ') >= 0: - split_char = ' ' - elif entry.find(',') >= 0: - split_char = ',' + if entry.find(" ") >= 0: + split_char = " " + elif entry.find(",") >= 0: + split_char = "," else: # probably a singleton; treat as a space-separated list. - split_char = ' ' + split_char = " " return split_char def _config_nameservers(self, nameservers): @@ -94,44 +105,44 @@ if sys.platform == 'win32': split_char = self._determine_split_char(search) search_list = search.split(split_char) for s in search_list: - s = dns.name.from_text(s) + s = _config_domain(s) if s not in self.info.search: self.info.search.append(s) def _config_fromkey(self, key, always_try_domain): try: - servers, _ = winreg.QueryValueEx(key, 'NameServer') + servers, _ = winreg.QueryValueEx(key, "NameServer") except WindowsError: servers = None if servers: self._config_nameservers(servers) if servers or always_try_domain: try: - dom, _ = winreg.QueryValueEx(key, 'Domain') + dom, _ = winreg.QueryValueEx(key, "Domain") if dom: self.info.domain = _config_domain(dom) except WindowsError: pass else: try: - servers, _ = winreg.QueryValueEx(key, 'DhcpNameServer') + servers, _ = winreg.QueryValueEx(key, "DhcpNameServer") except WindowsError: servers = None if servers: self._config_nameservers(servers) try: - dom, _ = winreg.QueryValueEx(key, 'DhcpDomain') + dom, _ = winreg.QueryValueEx(key, "DhcpDomain") if dom: self.info.domain = _config_domain(dom) except WindowsError: pass try: - search, _ = winreg.QueryValueEx(key, 'SearchList') + search, _ = winreg.QueryValueEx(key, "SearchList") except WindowsError: search = None if search is None: try: - search, _ = winreg.QueryValueEx(key, 'DhcpSearchList') + search, _ = winreg.QueryValueEx(key, "DhcpSearchList") except WindowsError: search = None if search: @@ -148,25 +159,27 @@ if sys.platform == 'win32': # from Windows 2000 through Vista. connection_key = winreg.OpenKey( lm, - r'SYSTEM\CurrentControlSet\Control\Network' - r'\{4D36E972-E325-11CE-BFC1-08002BE10318}' - r'\%s\Connection' % guid) + r"SYSTEM\CurrentControlSet\Control\Network" + r"\{4D36E972-E325-11CE-BFC1-08002BE10318}" + r"\%s\Connection" % guid, + ) try: # The PnpInstanceID points to a key inside Enum (pnp_id, ttype) = winreg.QueryValueEx( - connection_key, 'PnpInstanceID') + connection_key, "PnpInstanceID" + ) if ttype != winreg.REG_SZ: raise ValueError # pragma: no cover device_key = winreg.OpenKey( - lm, r'SYSTEM\CurrentControlSet\Enum\%s' % pnp_id) + lm, r"SYSTEM\CurrentControlSet\Enum\%s" % pnp_id + ) try: # Get ConfigFlags for this device - (flags, ttype) = winreg.QueryValueEx( - device_key, 'ConfigFlags') + (flags, ttype) = winreg.QueryValueEx(device_key, "ConfigFlags") if ttype != winreg.REG_DWORD: raise ValueError # pragma: no cover @@ -192,17 +205,19 @@ if sys.platform == 'win32': lm = winreg.ConnectRegistry(None, winreg.HKEY_LOCAL_MACHINE) try: - tcp_params = winreg.OpenKey(lm, - r'SYSTEM\CurrentControlSet' - r'\Services\Tcpip\Parameters') + tcp_params = winreg.OpenKey( + lm, r"SYSTEM\CurrentControlSet" r"\Services\Tcpip\Parameters" + ) try: self._config_fromkey(tcp_params, True) finally: tcp_params.Close() - interfaces = winreg.OpenKey(lm, - r'SYSTEM\CurrentControlSet' - r'\Services\Tcpip\Parameters' - r'\Interfaces') + interfaces = winreg.OpenKey( + lm, + r"SYSTEM\CurrentControlSet" + r"\Services\Tcpip\Parameters" + r"\Interfaces", + ) try: i = 0 while True: @@ -224,6 +239,7 @@ if sys.platform == 'win32': lm.Close() return self.info + _getter_class: Any if _have_wmi and _prefer_wmi: _getter_class = _WMIGetter else: diff --git a/lib/dns/wire.py b/lib/dns/wire.py index 572e27e7..cadf1686 100644 --- a/lib/dns/wire.py +++ b/lib/dns/wire.py @@ -1,13 +1,16 @@ # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license +from typing import Iterator, Optional, Tuple + import contextlib import struct import dns.exception import dns.name + class Parser: - def __init__(self, wire, current=0): + def __init__(self, wire: bytes, current: int = 0): self.wire = wire self.current = 0 self.end = len(self.wire) @@ -15,46 +18,47 @@ class Parser: self.seek(current) self.furthest = current - def remaining(self): + def remaining(self) -> int: return self.end - self.current - def get_bytes(self, size): + def get_bytes(self, size: int) -> bytes: + assert size >= 0 if size > self.remaining(): raise dns.exception.FormError - output = self.wire[self.current:self.current + size] + output = self.wire[self.current : self.current + size] self.current += size self.furthest = max(self.furthest, self.current) return output - def get_counted_bytes(self, length_size=1): - length = int.from_bytes(self.get_bytes(length_size), 'big') + def get_counted_bytes(self, length_size: int = 1) -> bytes: + length = int.from_bytes(self.get_bytes(length_size), "big") return self.get_bytes(length) - def get_remaining(self): + def get_remaining(self) -> bytes: return self.get_bytes(self.remaining()) - def get_uint8(self): - return struct.unpack('!B', self.get_bytes(1))[0] + def get_uint8(self) -> int: + return struct.unpack("!B", self.get_bytes(1))[0] - def get_uint16(self): - return struct.unpack('!H', self.get_bytes(2))[0] + def get_uint16(self) -> int: + return struct.unpack("!H", self.get_bytes(2))[0] - def get_uint32(self): - return struct.unpack('!I', self.get_bytes(4))[0] + def get_uint32(self) -> int: + return struct.unpack("!I", self.get_bytes(4))[0] - def get_uint48(self): - return int.from_bytes(self.get_bytes(6), 'big') + def get_uint48(self) -> int: + return int.from_bytes(self.get_bytes(6), "big") - def get_struct(self, format): + def get_struct(self, format: str) -> Tuple: return struct.unpack(format, self.get_bytes(struct.calcsize(format))) - def get_name(self, origin=None): + def get_name(self, origin: Optional["dns.name.Name"] = None) -> "dns.name.Name": name = dns.name.from_wire_parser(self) if origin: name = name.relativize(origin) return name - def seek(self, where): + def seek(self, where: int) -> None: # Note that seeking to the end is OK! (If you try to read # after such a seek, you'll get an exception as expected.) if where < 0 or where > self.end: @@ -62,7 +66,8 @@ class Parser: self.current = where @contextlib.contextmanager - def restrict_to(self, size): + def restrict_to(self, size: int) -> Iterator: + assert size >= 0 if size > self.remaining(): raise dns.exception.FormError saved_end = self.end @@ -78,7 +83,7 @@ class Parser: self.end = saved_end @contextlib.contextmanager - def restore_furthest(self): + def restore_furthest(self) -> Iterator: try: yield None finally: diff --git a/lib/dns/xfr.py b/lib/dns/xfr.py index cf9a163e..bb165888 100644 --- a/lib/dns/xfr.py +++ b/lib/dns/xfr.py @@ -15,12 +15,17 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. +from typing import Any, List, Optional, Tuple, Union + import dns.exception import dns.message import dns.name import dns.rcode import dns.serial +import dns.rdataset import dns.rdatatype +import dns.transaction +import dns.tsig import dns.zone @@ -28,7 +33,7 @@ class TransferError(dns.exception.DNSException): """A zone transfer response got a non-zero rcode.""" def __init__(self, rcode): - message = 'Zone transfer error: %s' % dns.rcode.to_text(rcode) + message = "Zone transfer error: %s" % dns.rcode.to_text(rcode) super().__init__(message) self.rcode = rcode @@ -46,8 +51,13 @@ class Inbound: State machine for zone transfers. """ - def __init__(self, txn_manager, rdtype=dns.rdatatype.AXFR, - serial=None, is_udp=False): + def __init__( + self, + txn_manager: dns.transaction.TransactionManager, + rdtype: dns.rdatatype.RdataType = dns.rdatatype.AXFR, + serial: Optional[int] = None, + is_udp: bool = False, + ): """Initialize an inbound zone transfer. *txn_manager* is a :py:class:`dns.transaction.TransactionManager`. @@ -61,22 +71,22 @@ class Inbound: XFR. """ self.txn_manager = txn_manager - self.txn = None + self.txn: Optional[dns.transaction.Transaction] = None self.rdtype = rdtype if rdtype == dns.rdatatype.IXFR: if serial is None: - raise ValueError('a starting serial must be supplied for IXFRs') + raise ValueError("a starting serial must be supplied for IXFRs") elif is_udp: - raise ValueError('is_udp specified for AXFR') + raise ValueError("is_udp specified for AXFR") self.serial = serial self.is_udp = is_udp (_, _, self.origin) = txn_manager.origin_information() - self.soa_rdataset = None + self.soa_rdataset: Optional[dns.rdataset.Rdataset] = None self.done = False self.expecting_SOA = False self.delete_mode = False - def process_message(self, message): + def process_message(self, message: dns.message.Message) -> bool: """Process one message in the transfer. The message should have the same relativization as was specified when @@ -107,10 +117,8 @@ class Inbound: # the origin. # if not message.answer or message.answer[0].name != self.origin: - raise dns.exception.FormError("No answer or RRset not " - "for zone origin") + raise dns.exception.FormError("No answer or RRset not for zone origin") rrset = message.answer[0] - name = rrset.name rdataset = rrset if rdataset.rdtype != dns.rdatatype.SOA: raise dns.exception.FormError("first RRset is not an SOA") @@ -122,8 +130,7 @@ class Inbound: # We're already up-to-date. # self.done = True - elif dns.serial.Serial(self.soa_rdataset[0].serial) < \ - self.serial: + elif dns.serial.Serial(self.soa_rdataset[0].serial) < self.serial: # It went backwards! raise SerialWentBackwards else: @@ -147,8 +154,8 @@ class Inbound: rdataset = rrset if self.done: raise dns.exception.FormError("answers after final SOA") - if rdataset.rdtype == dns.rdatatype.SOA and \ - name == self.origin: + assert self.txn is not None # for mypy + if rdataset.rdtype == dns.rdatatype.SOA and name == self.origin: # # Every time we see an origin SOA delete_mode inverts # @@ -160,20 +167,21 @@ class Inbound: # check that we're seeing the record in the expected # part of the response. # - if rdataset == self.soa_rdataset and \ - (self.rdtype == dns.rdatatype.AXFR or - (self.rdtype == dns.rdatatype.IXFR and - self.delete_mode)): + if rdataset == self.soa_rdataset and ( + self.rdtype == dns.rdatatype.AXFR + or (self.rdtype == dns.rdatatype.IXFR and self.delete_mode) + ): # # This is the final SOA # if self.expecting_SOA: # We got an empty IXFR sequence! - raise dns.exception.FormError('empty IXFR sequence') - if self.rdtype == dns.rdatatype.IXFR \ - and self.serial != rdataset[0].serial: - raise dns.exception.FormError('unexpected end of IXFR ' - 'sequence') + raise dns.exception.FormError("empty IXFR sequence") + if ( + self.rdtype == dns.rdatatype.IXFR + and self.serial != rdataset[0].serial + ): + raise dns.exception.FormError("unexpected end of IXFR sequence") self.txn.replace(name, rdataset) self.txn.commit() self.txn = None @@ -188,15 +196,15 @@ class Inbound: # This is the start of an IXFR deletion set if rdataset[0].serial != self.serial: raise dns.exception.FormError( - "IXFR base serial mismatch") + "IXFR base serial mismatch" + ) else: # This is the start of an IXFR addition set self.serial = rdataset[0].serial self.txn.replace(name, rdataset) else: # We saw a non-final SOA for the origin in an AXFR. - raise dns.exception.FormError('unexpected origin SOA ' - 'in AXFR') + raise dns.exception.FormError("unexpected origin SOA in AXFR") continue if self.expecting_SOA: # @@ -223,7 +231,7 @@ class Inbound: # This is a UDP IXFR and we didn't get to done, and we didn't # get the proper "truncated" response # - raise dns.exception.FormError('unexpected end of UDP IXFR') + raise dns.exception.FormError("unexpected end of UDP IXFR") return self.done # @@ -239,11 +247,18 @@ class Inbound: return False -def make_query(txn_manager, serial=0, - use_edns=None, ednsflags=None, payload=None, - request_payload=None, options=None, - keyring=None, keyname=None, - keyalgorithm=dns.tsig.default_algorithm): +def make_query( + txn_manager: dns.transaction.TransactionManager, + serial: Optional[int] = 0, + use_edns: Optional[Union[int, bool]] = None, + ednsflags: Optional[int] = None, + payload: Optional[int] = None, + request_payload: Optional[int] = None, + options: Optional[List[dns.edns.Option]] = None, + keyring: Any = None, + keyname: Optional[dns.name.Name] = None, + keyalgorithm: Union[dns.name.Name, str] = dns.tsig.default_algorithm, +) -> Tuple[dns.message.QueryMessage, Optional[int]]: """Make an AXFR or IXFR query. *txn_manager* is a ``dns.transaction.TransactionManager``, typically a @@ -264,13 +279,15 @@ def make_query(txn_manager, serial=0, Returns a `(query, serial)` tuple. """ (zone_origin, _, origin) = txn_manager.origin_information() + if zone_origin is None: + raise ValueError("no zone origin") if serial is None: rdtype = dns.rdatatype.AXFR elif not isinstance(serial, int): - raise ValueError('serial is not an integer') + raise ValueError("serial is not an integer") elif serial == 0: with txn_manager.reader() as txn: - rdataset = txn.get(origin, 'SOA') + rdataset = txn.get(origin, "SOA") if rdataset: serial = rdataset[0].serial rdtype = dns.rdatatype.IXFR @@ -280,34 +297,47 @@ def make_query(txn_manager, serial=0, elif serial > 0 and serial < 4294967296: rdtype = dns.rdatatype.IXFR else: - raise ValueError('serial out-of-range') + raise ValueError("serial out-of-range") rdclass = txn_manager.get_class() - q = dns.message.make_query(zone_origin, rdtype, rdclass, - use_edns, False, ednsflags, payload, - request_payload, options) + q = dns.message.make_query( + zone_origin, + rdtype, + rdclass, + use_edns, + False, + ednsflags, + payload, + request_payload, + options, + ) if serial is not None: - rdata = dns.rdata.from_text(rdclass, 'SOA', f'. . {serial} 0 0 0 0') - rrset = q.find_rrset(q.authority, zone_origin, rdclass, - dns.rdatatype.SOA, create=True) + rdata = dns.rdata.from_text(rdclass, "SOA", f". . {serial} 0 0 0 0") + rrset = q.find_rrset( + q.authority, zone_origin, rdclass, dns.rdatatype.SOA, create=True + ) rrset.add(rdata, 0) if keyring is not None: q.use_tsig(keyring, keyname, algorithm=keyalgorithm) return (q, serial) -def extract_serial_from_query(query): + +def extract_serial_from_query(query: dns.message.Message) -> Optional[int]: """Extract the SOA serial number from query if it is an IXFR and return it, otherwise return None. *query* is a dns.message.QueryMessage that is an IXFR or AXFR request. Raises if the query is not an IXFR or AXFR, or if an IXFR doesn't have - an appropriate SOA RRset in the authority section.""" - + an appropriate SOA RRset in the authority section. + """ + if not isinstance(query, dns.message.QueryMessage): + raise ValueError("query not a QueryMessage") question = query.question[0] if question.rdtype == dns.rdatatype.AXFR: return None elif question.rdtype != dns.rdatatype.IXFR: raise ValueError("query is not an AXFR or IXFR") - soa = query.find_rrset(query.authority, question.name, question.rdclass, - dns.rdatatype.SOA) + soa = query.find_rrset( + query.authority, question.name, question.rdclass, dns.rdatatype.SOA + ) return soa[0].serial diff --git a/lib/dns/zone.py b/lib/dns/zone.py index 5a649404..cc8268da 100644 --- a/lib/dns/zone.py +++ b/lib/dns/zone.py @@ -17,8 +17,9 @@ """DNS Zones.""" +from typing import Any, Dict, Iterator, Iterable, List, Optional, Set, Tuple, Union + import contextlib -import hashlib import io import os import struct @@ -30,6 +31,7 @@ import dns.node import dns.rdataclass import dns.rdatatype import dns.rdata +import dns.rdataset import dns.rdtypes.ANY.SOA import dns.rdtypes.ANY.ZONEMD import dns.rrset @@ -38,6 +40,7 @@ import dns.transaction import dns.ttl import dns.grange import dns.zonefile +from dns.zonetypes import DigestScheme, DigestHashAlgorithm, _digest_hashers class BadZone(dns.exception.DNSException): @@ -80,33 +83,6 @@ class DigestVerificationFailure(dns.exception.DNSException): """The ZONEMD digest failed to verify.""" -class DigestScheme(dns.enum.IntEnum): - """ZONEMD Scheme""" - - SIMPLE = 1 - - @classmethod - def _maximum(cls): - return 255 - - -class DigestHashAlgorithm(dns.enum.IntEnum): - """ZONEMD Hash Algorithm""" - - SHA384 = 1 - SHA512 = 2 - - @classmethod - def _maximum(cls): - return 255 - - -_digest_hashers = { - DigestHashAlgorithm.SHA384: hashlib.sha384, - DigestHashAlgorithm.SHA512: hashlib.sha512, -} - - class Zone(dns.transaction.TransactionManager): """A DNS zone. @@ -121,9 +97,14 @@ class Zone(dns.transaction.TransactionManager): node_factory = dns.node.Node - __slots__ = ['rdclass', 'origin', 'nodes', 'relativize'] + __slots__ = ["rdclass", "origin", "nodes", "relativize"] - def __init__(self, origin, rdclass=dns.rdataclass.IN, relativize=True): + def __init__( + self, + origin: Optional[Union[dns.name.Name, str]], + rdclass: dns.rdataclass.RdataClass = dns.rdataclass.IN, + relativize: bool = True, + ): """Initialize a zone object. *origin* is the origin of the zone. It may be a ``dns.name.Name``, @@ -140,13 +121,12 @@ class Zone(dns.transaction.TransactionManager): if isinstance(origin, str): origin = dns.name.from_text(origin) elif not isinstance(origin, dns.name.Name): - raise ValueError("origin parameter must be convertible to a " - "DNS name") + raise ValueError("origin parameter must be convertible to a DNS name") if not origin.is_absolute(): raise ValueError("origin parameter must be an absolute name") self.origin = origin self.rdclass = rdclass - self.nodes = {} + self.nodes: Dict[dns.name.Name, dns.node.Node] = {} self.relativize = relativize def __eq__(self, other): @@ -158,9 +138,11 @@ class Zone(dns.transaction.TransactionManager): if not isinstance(other, Zone): return False - if self.rdclass != other.rdclass or \ - self.origin != other.origin or \ - self.nodes != other.nodes: + if ( + self.rdclass != other.rdclass + or self.origin != other.origin + or self.nodes != other.nodes + ): return False return True @@ -172,21 +154,25 @@ class Zone(dns.transaction.TransactionManager): return not self.__eq__(other) - def _validate_name(self, name): + def _validate_name(self, name: Union[dns.name.Name, str]) -> dns.name.Name: if isinstance(name, str): name = dns.name.from_text(name, None) elif not isinstance(name, dns.name.Name): raise KeyError("name parameter must be convertible to a DNS name") if name.is_absolute(): + if self.origin is None: + # This should probably never happen as other code (e.g. + # _rr_line) will notice the lack of an origin before us, but + # we check just in case! + raise KeyError("no zone origin is defined") if not name.is_subdomain(self.origin): - raise KeyError( - "name parameter must be a subdomain of the zone origin") + raise KeyError("name parameter must be a subdomain of the zone origin") if self.relativize: name = name.relativize(self.origin) elif not self.relativize: # We have a relative name in a non-relative zone, so derelativize. if self.origin is None: - raise KeyError('no zone origin is defined') + raise KeyError("no zone origin is defined") name = name.derelativize(self.origin) return name @@ -222,7 +208,9 @@ class Zone(dns.transaction.TransactionManager): key = self._validate_name(key) return key in self.nodes - def find_node(self, name, create=False): + def find_node( + self, name: Union[dns.name.Name, str], create: bool = False + ) -> dns.node.Node: """Find a node in the zone, possibly creating it. *name*: the name of the node to find. @@ -248,7 +236,9 @@ class Zone(dns.transaction.TransactionManager): self.nodes[name] = node return node - def get_node(self, name, create=False): + def get_node( + self, name: Union[dns.name.Name, str], create: bool = False + ) -> Optional[dns.node.Node]: """Get a node in the zone, possibly creating it. This method is like ``find_node()``, except it returns None instead @@ -275,7 +265,7 @@ class Zone(dns.transaction.TransactionManager): node = None return node - def delete_node(self, name): + def delete_node(self, name: Union[dns.name.Name, str]) -> None: """Delete the specified node if it exists. *name*: the name of the node to find. @@ -290,8 +280,13 @@ class Zone(dns.transaction.TransactionManager): if name in self.nodes: del self.nodes[name] - def find_rdataset(self, name, rdtype, covers=dns.rdatatype.NONE, - create=False): + def find_rdataset( + self, + name: Union[dns.name.Name, str], + rdtype: Union[dns.rdatatype.RdataType, str], + covers: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.NONE, + create: bool = False, + ) -> dns.rdataset.Rdataset: """Look for an rdataset with the specified name and type in the zone, and return an rdataset encapsulating it. @@ -305,9 +300,9 @@ class Zone(dns.transaction.TransactionManager): name must be a subdomain of the zone's origin. If ``zone.relativize`` is ``True``, then the name will be relativized. - *rdtype*, an ``int`` or ``str``, the rdata type desired. + *rdtype*, a ``dns.rdatatype.RdataType`` or ``str``, the rdata type desired. - *covers*, an ``int`` or ``str`` or ``None``, the covered type. + *covers*, a ``dns.rdatatype.RdataType`` or ``str`` the covered type. Usually this value is ``dns.rdatatype.NONE``, but if the rdtype is ``dns.rdatatype.SIG`` or ``dns.rdatatype.RRSIG``, then the covers value will be the rdata type the SIG/RRSIG @@ -326,15 +321,19 @@ class Zone(dns.transaction.TransactionManager): Returns a ``dns.rdataset.Rdataset``. """ - name = self._validate_name(name) - rdtype = dns.rdatatype.RdataType.make(rdtype) - if covers is not None: - covers = dns.rdatatype.RdataType.make(covers) - node = self.find_node(name, create) - return node.find_rdataset(self.rdclass, rdtype, covers, create) + the_name = self._validate_name(name) + the_rdtype = dns.rdatatype.RdataType.make(rdtype) + the_covers = dns.rdatatype.RdataType.make(covers) + node = self.find_node(the_name, create) + return node.find_rdataset(self.rdclass, the_rdtype, the_covers, create) - def get_rdataset(self, name, rdtype, covers=dns.rdatatype.NONE, - create=False): + def get_rdataset( + self, + name: Union[dns.name.Name, str], + rdtype: Union[dns.rdatatype.RdataType, str], + covers: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.NONE, + create: bool = False, + ) -> Optional[dns.rdataset.Rdataset]: """Look for an rdataset with the specified name and type in the zone. This method is like ``find_rdataset()``, except it returns None instead @@ -349,9 +348,9 @@ class Zone(dns.transaction.TransactionManager): name must be a subdomain of the zone's origin. If ``zone.relativize`` is ``True``, then the name will be relativized. - *rdtype*, an ``int`` or ``str``, the rdata type desired. + *rdtype*, a ``dns.rdatatype.RdataType`` or ``str``, the rdata type desired. - *covers*, an ``int`` or ``str`` or ``None``, the covered type. + *covers*, a ``dns.rdatatype.RdataType`` or ``str``, the covered type. Usually this value is ``dns.rdatatype.NONE``, but if the rdtype is ``dns.rdatatype.SIG`` or ``dns.rdatatype.RRSIG``, then the covers value will be the rdata type the SIG/RRSIG @@ -376,45 +375,47 @@ class Zone(dns.transaction.TransactionManager): rdataset = None return rdataset - def delete_rdataset(self, name, rdtype, covers=dns.rdatatype.NONE): + def delete_rdataset( + self, + name: Union[dns.name.Name, str], + rdtype: Union[dns.rdatatype.RdataType, str], + covers: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.NONE, + ) -> None: """Delete the rdataset matching *rdtype* and *covers*, if it exists at the node specified by *name*. - It is not an error if the node does not exist, or if there is no - matching rdataset at the node. + It is not an error if the node does not exist, or if there is no matching + rdataset at the node. - If the node has no rdatasets after the deletion, it will itself - be deleted. + If the node has no rdatasets after the deletion, it will itself be deleted. - *name*: the name of the node to find. - The value may be a ``dns.name.Name`` or a ``str``. If absolute, the - name must be a subdomain of the zone's origin. If ``zone.relativize`` - is ``True``, then the name will be relativized. + *name*: the name of the node to find. The value may be a ``dns.name.Name`` or a + ``str``. If absolute, the name must be a subdomain of the zone's origin. If + ``zone.relativize`` is ``True``, then the name will be relativized. - *rdtype*, an ``int`` or ``str``, the rdata type desired. + *rdtype*, a ``dns.rdatatype.RdataType`` or ``str``, the rdata type desired. - *covers*, an ``int`` or ``str`` or ``None``, the covered type. - Usually this value is ``dns.rdatatype.NONE``, but if the - rdtype is ``dns.rdatatype.SIG`` or ``dns.rdatatype.RRSIG``, - then the covers value will be the rdata type the SIG/RRSIG - covers. The library treats the SIG and RRSIG types as if they - were a family of types, e.g. RRSIG(A), RRSIG(NS), RRSIG(SOA). - This makes RRSIGs much easier to work with than if RRSIGs - covering different rdata types were aggregated into a single - RRSIG rdataset. + *covers*, a ``dns.rdatatype.RdataType`` or ``str`` or ``None``, the covered + type. Usually this value is ``dns.rdatatype.NONE``, but if the rdtype is + ``dns.rdatatype.SIG`` or ``dns.rdatatype.RRSIG``, then the covers value will be + the rdata type the SIG/RRSIG covers. The library treats the SIG and RRSIG types + as if they were a family of types, e.g. RRSIG(A), RRSIG(NS), RRSIG(SOA). This + makes RRSIGs much easier to work with than if RRSIGs covering different rdata + types were aggregated into a single RRSIG rdataset. """ - name = self._validate_name(name) - rdtype = dns.rdatatype.RdataType.make(rdtype) - if covers is not None: - covers = dns.rdatatype.RdataType.make(covers) - node = self.get_node(name) + the_name = self._validate_name(name) + the_rdtype = dns.rdatatype.RdataType.make(rdtype) + the_covers = dns.rdatatype.RdataType.make(covers) + node = self.get_node(the_name) if node is not None: - node.delete_rdataset(self.rdclass, rdtype, covers) + node.delete_rdataset(self.rdclass, the_rdtype, the_covers) if len(node) == 0: - self.delete_node(name) + self.delete_node(the_name) - def replace_rdataset(self, name, replacement): + def replace_rdataset( + self, name: Union[dns.name.Name, str], replacement: dns.rdataset.Rdataset + ) -> None: """Replace an rdataset at name. It is not an error if there is no rdataset matching I{replacement}. @@ -434,11 +435,16 @@ class Zone(dns.transaction.TransactionManager): """ if replacement.rdclass != self.rdclass: - raise ValueError('replacement.rdclass != zone.rdclass') + raise ValueError("replacement.rdclass != zone.rdclass") node = self.find_node(name, True) node.replace_rdataset(replacement) - def find_rrset(self, name, rdtype, covers=dns.rdatatype.NONE): + def find_rrset( + self, + name: Union[dns.name.Name, str], + rdtype: Union[dns.rdatatype.RdataType, str], + covers: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.NONE, + ) -> dns.rrset.RRset: """Look for an rdataset with the specified name and type in the zone, and return an RRset encapsulating it. @@ -456,9 +462,9 @@ class Zone(dns.transaction.TransactionManager): name must be a subdomain of the zone's origin. If ``zone.relativize`` is ``True``, then the name will be relativized. - *rdtype*, an ``int`` or ``str``, the rdata type desired. + *rdtype*, a ``dns.rdatatype.RdataType`` or ``str``, the rdata type desired. - *covers*, an ``int`` or ``str`` or ``None``, the covered type. + *covers*, a ``dns.rdatatype.RdataType`` or ``str``, the covered type. Usually this value is ``dns.rdatatype.NONE``, but if the rdtype is ``dns.rdatatype.SIG`` or ``dns.rdatatype.RRSIG``, then the covers value will be the rdata type the SIG/RRSIG @@ -477,16 +483,20 @@ class Zone(dns.transaction.TransactionManager): Returns a ``dns.rrset.RRset`` or ``None``. """ - name = self._validate_name(name) - rdtype = dns.rdatatype.RdataType.make(rdtype) - if covers is not None: - covers = dns.rdatatype.RdataType.make(covers) - rdataset = self.nodes[name].find_rdataset(self.rdclass, rdtype, covers) - rrset = dns.rrset.RRset(name, self.rdclass, rdtype, covers) + vname = self._validate_name(name) + the_rdtype = dns.rdatatype.RdataType.make(rdtype) + the_covers = dns.rdatatype.RdataType.make(covers) + rdataset = self.nodes[vname].find_rdataset(self.rdclass, the_rdtype, the_covers) + rrset = dns.rrset.RRset(vname, self.rdclass, the_rdtype, the_covers) rrset.update(rdataset) return rrset - def get_rrset(self, name, rdtype, covers=dns.rdatatype.NONE): + def get_rrset( + self, + name: Union[dns.name.Name, str], + rdtype: Union[dns.rdatatype.RdataType, str], + covers: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.NONE, + ) -> Optional[dns.rrset.RRset]: """Look for an rdataset with the specified name and type in the zone, and return an RRset encapsulating it. @@ -503,9 +513,9 @@ class Zone(dns.transaction.TransactionManager): name must be a subdomain of the zone's origin. If ``zone.relativize`` is ``True``, then the name will be relativized. - *rdtype*, an ``int`` or ``str``, the rdata type desired. + *rdtype*, a ``dns.rdataset.Rdataset`` or ``str``, the rdata type desired. - *covers*, an ``int`` or ``str`` or ``None``, the covered type. + *covers*, a ``dns.rdataset.Rdataset`` or ``str``, the covered type. Usually this value is ``dns.rdatatype.NONE``, but if the rdtype is ``dns.rdatatype.SIG`` or ``dns.rdatatype.RRSIG``, then the covers value will be the rdata type the SIG/RRSIG @@ -530,16 +540,19 @@ class Zone(dns.transaction.TransactionManager): rrset = None return rrset - def iterate_rdatasets(self, rdtype=dns.rdatatype.ANY, - covers=dns.rdatatype.NONE): + def iterate_rdatasets( + self, + rdtype: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.ANY, + covers: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.NONE, + ) -> Iterator[Tuple[dns.name.Name, dns.rdataset.Rdataset]]: """Return a generator which yields (name, rdataset) tuples for all rdatasets in the zone which have the specified *rdtype* and *covers*. If *rdtype* is ``dns.rdatatype.ANY``, the default, then all rdatasets will be matched. - *rdtype*, an ``int`` or ``str``, the rdata type desired. + *rdtype*, a ``dns.rdataset.Rdataset`` or ``str``, the rdata type desired. - *covers*, an ``int`` or ``str`` or ``None``, the covered type. + *covers*, a ``dns.rdataset.Rdataset`` or ``str``, the covered type. Usually this value is ``dns.rdatatype.NONE``, but if the rdtype is ``dns.rdatatype.SIG`` or ``dns.rdatatype.RRSIG``, then the covers value will be the rdata type the SIG/RRSIG @@ -551,24 +564,27 @@ class Zone(dns.transaction.TransactionManager): """ rdtype = dns.rdatatype.RdataType.make(rdtype) - if covers is not None: - covers = dns.rdatatype.RdataType.make(covers) + covers = dns.rdatatype.RdataType.make(covers) for (name, node) in self.items(): for rds in node: - if rdtype == dns.rdatatype.ANY or \ - (rds.rdtype == rdtype and rds.covers == covers): + if rdtype == dns.rdatatype.ANY or ( + rds.rdtype == rdtype and rds.covers == covers + ): yield (name, rds) - def iterate_rdatas(self, rdtype=dns.rdatatype.ANY, - covers=dns.rdatatype.NONE): + def iterate_rdatas( + self, + rdtype: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.ANY, + covers: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.NONE, + ) -> Iterator[Tuple[dns.name.Name, int, dns.rdata.Rdata]]: """Return a generator which yields (name, ttl, rdata) tuples for all rdatas in the zone which have the specified *rdtype* and *covers*. If *rdtype* is ``dns.rdatatype.ANY``, the default, then all rdatas will be matched. - *rdtype*, an ``int`` or ``str``, the rdata type desired. + *rdtype*, a ``dns.rdataset.Rdataset`` or ``str``, the rdata type desired. - *covers*, an ``int`` or ``str`` or ``None``, the covered type. + *covers*, a ``dns.rdataset.Rdataset`` or ``str``, the covered type. Usually this value is ``dns.rdatatype.NONE``, but if the rdtype is ``dns.rdatatype.SIG`` or ``dns.rdatatype.RRSIG``, then the covers value will be the rdata type the SIG/RRSIG @@ -580,17 +596,24 @@ class Zone(dns.transaction.TransactionManager): """ rdtype = dns.rdatatype.RdataType.make(rdtype) - if covers is not None: - covers = dns.rdatatype.RdataType.make(covers) + covers = dns.rdatatype.RdataType.make(covers) for (name, node) in self.items(): for rds in node: - if rdtype == dns.rdatatype.ANY or \ - (rds.rdtype == rdtype and rds.covers == covers): + if rdtype == dns.rdatatype.ANY or ( + rds.rdtype == rdtype and rds.covers == covers + ): for rdata in rds: yield (name, rds.ttl, rdata) - def to_file(self, f, sorted=True, relativize=True, nl=None, - want_comments=False, want_origin=False): + def to_file( + self, + f: Any, + sorted: bool = True, + relativize: bool = True, + nl: Optional[str] = None, + want_comments: bool = False, + want_origin: bool = False, + ) -> None: """Write a zone to a file. *f*, a file or `str`. If *f* is a string, it is treated @@ -618,20 +641,21 @@ class Zone(dns.transaction.TransactionManager): one. """ - with contextlib.ExitStack() as stack: - if isinstance(f, str): - f = stack.enter_context(open(f, 'wb')) - + if isinstance(f, str): + cm: contextlib.AbstractContextManager = open(f, "wb") + else: + cm = contextlib.nullcontext(f) + with cm as f: # must be in this way, f.encoding may contain None, or even # attribute may not be there - file_enc = getattr(f, 'encoding', None) + file_enc = getattr(f, "encoding", None) if file_enc is None: - file_enc = 'utf-8' + file_enc = "utf-8" if nl is None: # binary mode, '\n' is not enough nl_b = os.linesep.encode(file_enc) - nl = '\n' + nl = "\n" elif isinstance(nl, str): nl_b = nl.encode(file_enc) else: @@ -639,7 +663,8 @@ class Zone(dns.transaction.TransactionManager): nl = nl.decode() if want_origin: - l = '$ORIGIN ' + self.origin.to_text() + assert self.origin is not None + l = "$ORIGIN " + self.origin.to_text() l_b = l.encode(file_enc) try: f.write(l_b) @@ -654,9 +679,12 @@ class Zone(dns.transaction.TransactionManager): else: names = self.keys() for n in names: - l = self[n].to_text(n, origin=self.origin, - relativize=relativize, - want_comments=want_comments) + l = self[n].to_text( + n, + origin=self.origin, + relativize=relativize, + want_comments=want_comments, + ) l_b = l.encode(file_enc) try: @@ -666,8 +694,14 @@ class Zone(dns.transaction.TransactionManager): f.write(l) f.write(nl) - def to_text(self, sorted=True, relativize=True, nl=None, - want_comments=False, want_origin=False): + def to_text( + self, + sorted: bool = True, + relativize: bool = True, + nl: Optional[str] = None, + want_comments: bool = False, + want_origin: bool = False, + ) -> str: """Return a zone's text as though it were written to a file. *sorted*, a ``bool``. If True, the default, then the file @@ -694,13 +728,12 @@ class Zone(dns.transaction.TransactionManager): Returns a ``str``. """ temp_buffer = io.StringIO() - self.to_file(temp_buffer, sorted, relativize, nl, want_comments, - want_origin) + self.to_file(temp_buffer, sorted, relativize, nl, want_comments, want_origin) return_value = temp_buffer.getvalue() temp_buffer.close() return return_value - def check_origin(self): + def check_origin(self) -> None: """Do some simple checking of the zone's origin. Raises ``dns.zone.NoSOA`` if there is no SOA RRset. @@ -712,13 +745,44 @@ class Zone(dns.transaction.TransactionManager): if self.relativize: name = dns.name.empty else: + assert self.origin is not None name = self.origin if self.get_rdataset(name, dns.rdatatype.SOA) is None: raise NoSOA if self.get_rdataset(name, dns.rdatatype.NS) is None: raise NoNS - def _compute_digest(self, hash_algorithm, scheme=DigestScheme.SIMPLE): + def get_soa( + self, txn: Optional[dns.transaction.Transaction] = None + ) -> dns.rdtypes.ANY.SOA.SOA: + """Get the zone SOA rdata. + + Raises ``dns.zone.NoSOA`` if there is no SOA RRset. + + Returns a ``dns.rdtypes.ANY.SOA.SOA`` Rdata. + """ + if self.relativize: + origin_name = dns.name.empty + else: + if self.origin is None: + # get_soa() has been called very early, and there must not be + # an SOA if there is no origin. + raise NoSOA + origin_name = self.origin + soa: Optional[dns.rdataset.Rdataset] + if txn: + soa = txn.get(origin_name, dns.rdatatype.SOA) + else: + soa = self.get_rdataset(origin_name, dns.rdatatype.SOA) + if soa is None: + raise NoSOA + return soa[0] + + def _compute_digest( + self, + hash_algorithm: DigestHashAlgorithm, + scheme: DigestScheme = DigestScheme.SIMPLE, + ) -> bytes: hashinfo = _digest_hashers.get(hash_algorithm) if not hashinfo: raise UnsupportedDigestHashAlgorithm @@ -728,47 +792,52 @@ class Zone(dns.transaction.TransactionManager): if self.relativize: origin_name = dns.name.empty else: + assert self.origin is not None origin_name = self.origin hasher = hashinfo() for (name, node) in sorted(self.items()): rrnamebuf = name.to_digestable(self.origin) - for rdataset in sorted(node, - key=lambda rds: (rds.rdtype, rds.covers)): - if name == origin_name and \ - dns.rdatatype.ZONEMD in (rdataset.rdtype, rdataset.covers): + for rdataset in sorted(node, key=lambda rds: (rds.rdtype, rds.covers)): + if name == origin_name and dns.rdatatype.ZONEMD in ( + rdataset.rdtype, + rdataset.covers, + ): continue - rrfixed = struct.pack('!HHI', rdataset.rdtype, - rdataset.rdclass, rdataset.ttl) - rdatas = [rdata.to_digestable(self.origin) - for rdata in rdataset] + rrfixed = struct.pack( + "!HHI", rdataset.rdtype, rdataset.rdclass, rdataset.ttl + ) + rdatas = [rdata.to_digestable(self.origin) for rdata in rdataset] for rdata in sorted(rdatas): - rrlen = struct.pack('!H', len(rdata)) + rrlen = struct.pack("!H", len(rdata)) hasher.update(rrnamebuf + rrfixed + rrlen + rdata) return hasher.digest() - def compute_digest(self, hash_algorithm, scheme=DigestScheme.SIMPLE): - if self.relativize: - origin_name = dns.name.empty - else: - origin_name = self.origin - serial = self.get_rdataset(origin_name, dns.rdatatype.SOA)[0].serial + def compute_digest( + self, + hash_algorithm: DigestHashAlgorithm, + scheme: DigestScheme = DigestScheme.SIMPLE, + ) -> dns.rdtypes.ANY.ZONEMD.ZONEMD: + serial = self.get_soa().serial digest = self._compute_digest(hash_algorithm, scheme) - return dns.rdtypes.ANY.ZONEMD.ZONEMD(self.rdclass, - dns.rdatatype.ZONEMD, - serial, scheme, hash_algorithm, - digest) + return dns.rdtypes.ANY.ZONEMD.ZONEMD( + self.rdclass, dns.rdatatype.ZONEMD, serial, scheme, hash_algorithm, digest + ) - def verify_digest(self, zonemd=None): + def verify_digest( + self, zonemd: Optional[dns.rdtypes.ANY.ZONEMD.ZONEMD] = None + ) -> None: + digests: Union[dns.rdataset.Rdataset, List[dns.rdtypes.ANY.ZONEMD.ZONEMD]] if zonemd: digests = [zonemd] else: - digests = self.get_rdataset(self.origin, dns.rdatatype.ZONEMD) - if digests is None: + assert self.origin is not None + rds = self.get_rdataset(self.origin, dns.rdatatype.ZONEMD) + if rds is None: raise NoDigest + digests = rds for digest in digests: try: - computed = self._compute_digest(digest.hash_algorithm, - digest.scheme) + computed = self._compute_digest(digest.hash_algorithm, digest.scheme) if computed == digest.digest: return except Exception: @@ -777,16 +846,18 @@ class Zone(dns.transaction.TransactionManager): # TransactionManager methods - def reader(self): - return Transaction(self, False, - Version(self, 1, self.nodes, self.origin)) + def reader(self) -> "Transaction": + return Transaction(self, False, Version(self, 1, self.nodes, self.origin)) - def writer(self, replacement=False): + def writer(self, replacement: bool = False) -> "Transaction": txn = Transaction(self, replacement) txn._setup_version() return txn - def origin_information(self): + def origin_information( + self, + ) -> Tuple[Optional[dns.name.Name], bool, Optional[dns.name.Name]]: + effective: Optional[dns.name.Name] if self.relativize: effective = dns.name.empty else: @@ -821,8 +892,9 @@ class Zone(dns.transaction.TransactionManager): # A node with a version id. -class VersionedNode(dns.node.Node): - __slots__ = ['id'] + +class VersionedNode(dns.node.Node): # lgtm[py/missing-equals] + __slots__ = ["id"] def __init__(self): super().__init__() @@ -832,8 +904,6 @@ class VersionedNode(dns.node.Node): @dns.immutable.immutable class ImmutableVersionedNode(VersionedNode): - __slots__ = ['id'] - def __init__(self, node): super().__init__() self.id = node.id @@ -841,30 +911,51 @@ class ImmutableVersionedNode(VersionedNode): [dns.rdataset.ImmutableRdataset(rds) for rds in node.rdatasets] ) - def find_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE, - create=False): + def find_rdataset( + self, + rdclass: dns.rdataclass.RdataClass, + rdtype: dns.rdatatype.RdataType, + covers: dns.rdatatype.RdataType = dns.rdatatype.NONE, + create: bool = False, + ) -> dns.rdataset.Rdataset: if create: raise TypeError("immutable") return super().find_rdataset(rdclass, rdtype, covers, False) - def get_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE, - create=False): + def get_rdataset( + self, + rdclass: dns.rdataclass.RdataClass, + rdtype: dns.rdatatype.RdataType, + covers: dns.rdatatype.RdataType = dns.rdatatype.NONE, + create: bool = False, + ) -> Optional[dns.rdataset.Rdataset]: if create: raise TypeError("immutable") return super().get_rdataset(rdclass, rdtype, covers, False) - def delete_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE): + def delete_rdataset( + self, + rdclass: dns.rdataclass.RdataClass, + rdtype: dns.rdatatype.RdataType, + covers: dns.rdatatype.RdataType = dns.rdatatype.NONE, + ) -> None: raise TypeError("immutable") - def replace_rdataset(self, replacement): + def replace_rdataset(self, replacement: dns.rdataset.Rdataset) -> None: raise TypeError("immutable") - def is_immutable(self): + def is_immutable(self) -> bool: return True class Version: - def __init__(self, zone, id, nodes=None, origin=None): + def __init__( + self, + zone: Zone, + id: int, + nodes: Optional[Dict[dns.name.Name, dns.node.Node]] = None, + origin: Optional[dns.name.Name] = None, + ): self.zone = zone self.id = id if nodes is not None: @@ -873,13 +964,13 @@ class Version: self.nodes = {} self.origin = origin - def _validate_name(self, name): + def _validate_name(self, name: dns.name.Name) -> dns.name.Name: if name.is_absolute(): if self.origin is None: # This should probably never happen as other code (e.g. # _rr_line) will notice the lack of an origin before us, but # we check just in case! - raise KeyError('no zone origin is defined') + raise KeyError("no zone origin is defined") if not name.is_subdomain(self.origin): raise KeyError("name is not a subdomain of the zone origin") if self.zone.relativize: @@ -887,15 +978,20 @@ class Version: elif not self.zone.relativize: # We have a relative name in a non-relative zone, so derelativize. if self.origin is None: - raise KeyError('no zone origin is defined') + raise KeyError("no zone origin is defined") name = name.derelativize(self.origin) return name - def get_node(self, name): + def get_node(self, name: dns.name.Name) -> Optional[dns.node.Node]: name = self._validate_name(name) return self.nodes.get(name) - def get_rdataset(self, name, rdtype, covers): + def get_rdataset( + self, + name: dns.name.Name, + rdtype: dns.rdatatype.RdataType, + covers: dns.rdatatype.RdataType, + ) -> Optional[dns.rdataset.Rdataset]: node = self.get_node(name) if node is None: return None @@ -906,7 +1002,7 @@ class Version: class WritableVersion(Version): - def __init__(self, zone, replacement=False): + def __init__(self, zone: Zone, replacement: bool = False): # The zone._versions_lock must be held by our caller in a versioned # zone. id = zone._get_next_version_id() @@ -920,19 +1016,21 @@ class WritableVersion(Version): # We have to copy the zone origin as it may be None in the first # version, and we don't want to mutate the zone until we commit. self.origin = zone.origin - self.changed = set() + self.changed: Set[dns.name.Name] = set() - def _maybe_cow(self, name): + def _maybe_cow(self, name: dns.name.Name) -> dns.node.Node: name = self._validate_name(name) node = self.nodes.get(name) if node is None or name not in self.changed: new_node = self.zone.node_factory() - if hasattr(new_node, 'id'): + if hasattr(new_node, "id"): # We keep doing this for backwards compatibility, as earlier # code used new_node.id != self.id for the "do we need to CoW?" # test. Now we use the changed set as this works with both # regular zones and versioned zones. - new_node.id = self.id + # + # We ignore the mypy error as this is safe but it doesn't see it. + new_node.id = self.id # type: ignore if node is not None: # moo! copy on write! new_node.rdatasets.extend(node.rdatasets) @@ -942,17 +1040,24 @@ class WritableVersion(Version): else: return node - def delete_node(self, name): + def delete_node(self, name: dns.name.Name) -> None: name = self._validate_name(name) if name in self.nodes: del self.nodes[name] self.changed.add(name) - def put_rdataset(self, name, rdataset): + def put_rdataset( + self, name: dns.name.Name, rdataset: dns.rdataset.Rdataset + ) -> None: node = self._maybe_cow(name) node.replace_rdataset(rdataset) - def delete_rdataset(self, name, rdtype, covers): + def delete_rdataset( + self, + name: dns.name.Name, + rdtype: dns.rdatatype.RdataType, + covers: dns.rdatatype.RdataType, + ) -> None: node = self._maybe_cow(name) node.delete_rdataset(self.zone.rdclass, rdtype, covers) if len(node) == 0: @@ -961,7 +1066,7 @@ class WritableVersion(Version): @dns.immutable.immutable class ImmutableVersion(Version): - def __init__(self, version): + def __init__(self, version: WritableVersion): # We tell super() that it's a replacement as we don't want it # to copy the nodes, as we're about to do that with an # immutable Dict. @@ -976,11 +1081,12 @@ class ImmutableVersion(Version): # it might not exist if we deleted it in the version if node: version.nodes[name] = ImmutableVersionedNode(node) - self.nodes = dns.immutable.Dict(version.nodes, True) + # We're changing the type of the nodes dictionary here on purpose, so + # we ignore the mypy error. + self.nodes = dns.immutable.Dict(version.nodes, True) # type: ignore class Transaction(dns.transaction.Transaction): - def __init__(self, zone, replacement, version=None, make_immutable=False): read_only = version is not None super().__init__(zone, replacement, read_only) @@ -1057,9 +1163,18 @@ class Transaction(dns.transaction.Transaction): return (absolute, relativize, effective) -def from_text(text, origin=None, rdclass=dns.rdataclass.IN, - relativize=True, zone_factory=Zone, filename=None, - allow_include=False, check_origin=True, idna_codec=None): +def from_text( + text: str, + origin: Optional[Union[dns.name.Name, str]] = None, + rdclass: dns.rdataclass.RdataClass = dns.rdataclass.IN, + relativize: bool = True, + zone_factory: Any = Zone, + filename: Optional[str] = None, + allow_include: bool = False, + check_origin: bool = True, + idna_codec: Optional[dns.name.IDNACodec] = None, + allow_directives: Union[bool, Iterable[str]] = True, +) -> Zone: """Build a zone object from a zone file format string. *text*, a ``str``, the zone file format input. @@ -1068,7 +1183,8 @@ def from_text(text, origin=None, rdclass=dns.rdataclass.IN, of the zone; if not specified, the first ``$ORIGIN`` statement in the zone file will determine the origin of the zone. - *rdclass*, an ``int``, the zone's rdata class; the default is class IN. + *rdclass*, a ``dns.rdataclass.RdataClass``, the zone's rdata class; the default is + class IN. *relativize*, a ``bool``, determine's whether domain names are relativized to the zone's origin. The default is ``True``. @@ -1092,6 +1208,13 @@ def from_text(text, origin=None, rdclass=dns.rdataclass.IN, encoder/decoder. If ``None``, the default IDNA 2003 encoder/decoder is used. + *allow_directives*, a ``bool`` or an iterable of `str`. If ``True``, the default, + then directives are permitted, and the *allow_include* parameter controls whether + ``$INCLUDE`` is permitted. If ``False`` or an empty iterable, then no directive + processing is done and any directive-like text will be treated as a regular owner + name. If a non-empty iterable, then only the listed directives (including the + ``$``) are allowed. + Raises ``dns.zone.NoSOA`` if there is no SOA RRset. Raises ``dns.zone.NoNS`` if there is no NS RRset. @@ -1106,12 +1229,17 @@ def from_text(text, origin=None, rdclass=dns.rdataclass.IN, # interface is from_file(). if filename is None: - filename = '' + filename = "" zone = zone_factory(origin, rdclass, relativize=relativize) with zone.writer(True) as txn: tok = dns.tokenizer.Tokenizer(text, filename, idna_codec=idna_codec) - reader = dns.zonefile.Reader(tok, rdclass, txn, - allow_include=allow_include) + reader = dns.zonefile.Reader( + tok, + rdclass, + txn, + allow_include=allow_include, + allow_directives=allow_directives, + ) try: reader.read() except dns.zonefile.UnknownOrigin: @@ -1123,9 +1251,18 @@ def from_text(text, origin=None, rdclass=dns.rdataclass.IN, return zone -def from_file(f, origin=None, rdclass=dns.rdataclass.IN, - relativize=True, zone_factory=Zone, filename=None, - allow_include=True, check_origin=True): +def from_file( + f: Any, + origin: Optional[Union[dns.name.Name, str]] = None, + rdclass: dns.rdataclass.RdataClass = dns.rdataclass.IN, + relativize: bool = True, + zone_factory: Any = Zone, + filename: Optional[str] = None, + allow_include: bool = True, + check_origin: bool = True, + idna_codec: Optional[dns.name.IDNACodec] = None, + allow_directives: Union[bool, Iterable[str]] = True, +) -> Zone: """Read a zone file and build a zone object. *f*, a file or ``str``. If *f* is a string, it is treated @@ -1159,6 +1296,13 @@ def from_file(f, origin=None, rdclass=dns.rdataclass.IN, encoder/decoder. If ``None``, the default IDNA 2003 encoder/decoder is used. + *allow_directives*, a ``bool`` or an iterable of `str`. If ``True``, the default, + then directives are permitted, and the *allow_include* parameter controls whether + ``$INCLUDE`` is permitted. If ``False`` or an empty iterable, then no directive + processing is done and any directive-like text will be treated as a regular owner + name. If a non-empty iterable, then only the listed directives (including the + ``$``) are allowed. + Raises ``dns.zone.NoSOA`` if there is no SOA RRset. Raises ``dns.zone.NoNS`` if there is no NS RRset. @@ -1168,16 +1312,34 @@ def from_file(f, origin=None, rdclass=dns.rdataclass.IN, Returns a subclass of ``dns.zone.Zone``. """ - with contextlib.ExitStack() as stack: - if isinstance(f, str): - if filename is None: - filename = f - f = stack.enter_context(open(f)) - return from_text(f, origin, rdclass, relativize, zone_factory, - filename, allow_include, check_origin) + if isinstance(f, str): + if filename is None: + filename = f + cm: contextlib.AbstractContextManager = open(f) + else: + cm = contextlib.nullcontext(f) + with cm as f: + return from_text( + f, + origin, + rdclass, + relativize, + zone_factory, + filename, + allow_include, + check_origin, + idna_codec, + allow_directives, + ) + assert False # make mypy happy lgtm[py/unreachable-statement] -def from_xfr(xfr, zone_factory=Zone, relativize=True, check_origin=True): +def from_xfr( + xfr: Any, + zone_factory: Any = Zone, + relativize: bool = True, + check_origin: bool = True, +) -> Zone: """Convert the output of a zone transfer generator into a zone object. *xfr*, a generator of ``dns.message.Message`` objects, typically @@ -1198,6 +1360,8 @@ def from_xfr(xfr, zone_factory=Zone, relativize=True, check_origin=True): Raises ``KeyError`` if there is no origin node. + Raises ``ValueError`` if no messages are yielded by the generator. + Returns a subclass of ``dns.zone.Zone``. """ @@ -1215,11 +1379,12 @@ def from_xfr(xfr, zone_factory=Zone, relativize=True, check_origin=True): if not znode: znode = z.node_factory() z.nodes[rrset.name] = znode - zrds = znode.find_rdataset(rrset.rdclass, rrset.rdtype, - rrset.covers, True) + zrds = znode.find_rdataset(rrset.rdclass, rrset.rdtype, rrset.covers, True) zrds.update_ttl(rrset.ttl) for rd in rrset: zrds.add(rd) + if z is None: + raise ValueError("empty transfer") if check_origin: z.check_origin() return z diff --git a/lib/dns/zone.pyi b/lib/dns/zone.pyi deleted file mode 100644 index 272814fe..00000000 --- a/lib/dns/zone.pyi +++ /dev/null @@ -1,55 +0,0 @@ -from typing import Generator, Optional, Union, Tuple, Iterable, Callable, Any, Iterator, TextIO, BinaryIO, Dict -from . import rdata, zone, rdataclass, name, rdataclass, message, rdatatype, exception, node, rdataset, rrset, rdatatype - -class BadZone(exception.DNSException): ... -class NoSOA(BadZone): ... -class NoNS(BadZone): ... -class UnknownOrigin(BadZone): ... - -class Zone: - def __getitem__(self, key : str) -> node.Node: - ... - def __init__(self, origin : Union[str,name.Name], rdclass : int = rdataclass.IN, relativize : bool = True) -> None: - self.nodes : Dict[str,node.Node] - self.origin = origin - def values(self): - return self.nodes.values() - def iterate_rdatas(self, rdtype : Union[int,str] = rdatatype.ANY, covers : Union[int,str] = None) -> Iterable[Tuple[name.Name, int, rdata.Rdata]]: - ... - def __iter__(self) -> Iterator[str]: - ... - def get_node(self, name : Union[name.Name,str], create=False) -> Optional[node.Node]: - ... - def find_rrset(self, name : Union[str,name.Name], rdtype : Union[int,str], covers=rdatatype.NONE) -> rrset.RRset: - ... - def find_rdataset(self, name : Union[str,name.Name], rdtype : Union[str,int], covers=rdatatype.NONE, - create=False) -> rdataset.Rdataset: - ... - def get_rdataset(self, name : Union[str,name.Name], rdtype : Union[str,int], covers=rdatatype.NONE, create=False) -> Optional[rdataset.Rdataset]: - ... - def get_rrset(self, name : Union[str,name.Name], rdtype : Union[str,int], covers=rdatatype.NONE) -> Optional[rrset.RRset]: - ... - def replace_rdataset(self, name : Union[str,name.Name], replacement : rdataset.Rdataset) -> None: - ... - def delete_rdataset(self, name : Union[str,name.Name], rdtype : Union[str,int], covers=rdatatype.NONE) -> None: - ... - def iterate_rdatasets(self, rdtype : Union[str,int] =rdatatype.ANY, - covers : Union[str,int] =rdatatype.NONE): - ... - def to_file(self, f : Union[TextIO, BinaryIO, str], sorted=True, relativize=True, nl : Optional[bytes] = None): - ... - def to_text(self, sorted=True, relativize=True, nl : Optional[str] = None) -> str: - ... - -def from_xfr(xfr : Generator[Any,Any,message.Message], zone_factory : Callable[..., zone.Zone] = zone.Zone, relativize=True, check_origin=True): - ... - -def from_text(text : str, origin : Optional[Union[str,name.Name]] = None, rdclass : int = rdataclass.IN, - relativize=True, zone_factory : Callable[...,zone.Zone] = zone.Zone, filename : Optional[str] = None, - allow_include=False, check_origin=True) -> zone.Zone: - ... - -def from_file(f, origin : Optional[Union[str,name.Name]] = None, rdclass=rdataclass.IN, - relativize=True, zone_factory : Callable[..., zone.Zone] = Zone, filename : Optional[str] = None, - allow_include=True, check_origin=True) -> zone.Zone: - ... diff --git a/lib/dns/zonefile.py b/lib/dns/zonefile.py index 53b40880..1a53f5bc 100644 --- a/lib/dns/zonefile.py +++ b/lib/dns/zonefile.py @@ -17,6 +17,8 @@ """DNS Zones.""" +from typing import Any, Iterable, List, Optional, Set, Tuple, Union + import re import sys @@ -49,29 +51,60 @@ def _check_cname_and_other_data(txn, name, rdataset): # empty nodes are neutral. return node_kind = node.classify() - if node_kind == dns.node.NodeKind.CNAME and \ - rdataset_kind == dns.node.NodeKind.REGULAR: - raise CNAMEAndOtherData('rdataset type is not compatible with a ' - 'CNAME node') - elif node_kind == dns.node.NodeKind.REGULAR and \ - rdataset_kind == dns.node.NodeKind.CNAME: - raise CNAMEAndOtherData('CNAME rdataset is not compatible with a ' - 'regular data node') + if ( + node_kind == dns.node.NodeKind.CNAME + and rdataset_kind == dns.node.NodeKind.REGULAR + ): + raise CNAMEAndOtherData("rdataset type is not compatible with a CNAME node") + elif ( + node_kind == dns.node.NodeKind.REGULAR + and rdataset_kind == dns.node.NodeKind.CNAME + ): + raise CNAMEAndOtherData( + "CNAME rdataset is not compatible with a regular data node" + ) # Otherwise at least one of the node and the rdataset is neutral, so # adding the rdataset is ok +SavedStateType = Tuple[ + dns.tokenizer.Tokenizer, + Optional[dns.name.Name], # current_origin + Optional[dns.name.Name], # last_name + Optional[Any], # current_file + int, # last_ttl + bool, # last_ttl_known + int, # default_ttl + bool, +] # default_ttl_known + + +def _upper_dollarize(s): + s = s.upper() + if not s.startswith("$"): + s = "$" + s + return s + + class Reader: """Read a DNS zone file into a transaction.""" - def __init__(self, tok, rdclass, txn, allow_include=False, - allow_directives=True, force_name=None, - force_ttl=None, force_rdclass=None, force_rdtype=None, - default_ttl=None): + def __init__( + self, + tok: dns.tokenizer.Tokenizer, + rdclass: dns.rdataclass.RdataClass, + txn: dns.transaction.Transaction, + allow_include: bool = False, + allow_directives: Union[bool, Iterable[str]] = True, + force_name: Optional[dns.name.Name] = None, + force_ttl: Optional[int] = None, + force_rdclass: Optional[dns.rdataclass.RdataClass] = None, + force_rdtype: Optional[dns.rdatatype.RdataType] = None, + default_ttl: Optional[int] = None, + ): self.tok = tok - (self.zone_origin, self.relativize, _) = \ - txn.manager.origin_information() + (self.zone_origin, self.relativize, _) = txn.manager.origin_information() self.current_origin = self.zone_origin self.last_ttl = 0 self.last_ttl_known = False @@ -86,10 +119,21 @@ class Reader: self.last_name = self.current_origin self.zone_rdclass = rdclass self.txn = txn - self.saved_state = [] - self.current_file = None - self.allow_include = allow_include - self.allow_directives = allow_directives + self.saved_state: List[SavedStateType] = [] + self.current_file: Optional[Any] = None + self.allowed_directives: Set[str] + if allow_directives is True: + self.allowed_directives = {"$GENERATE", "$ORIGIN", "$TTL"} + if allow_include: + self.allowed_directives.add("$INCLUDE") + elif allow_directives is False: + # allow_include was ignored in earlier releases if allow_directives was + # False, so we continue that. + self.allowed_directives = set() + else: + # Note that if directives are explicitly specified, then allow_include + # is ignored. + self.allowed_directives = set(_upper_dollarize(d) for d in allow_directives) self.force_name = force_name self.force_ttl = force_ttl self.force_rdclass = force_rdclass @@ -176,13 +220,17 @@ class Reader: try: rdtype = dns.rdatatype.from_text(token.value) except Exception: - raise dns.exception.SyntaxError( - "unknown rdatatype '%s'" % token.value) + raise dns.exception.SyntaxError("unknown rdatatype '%s'" % token.value) try: - rd = dns.rdata.from_text(rdclass, rdtype, self.tok, - self.current_origin, self.relativize, - self.zone_origin) + rd = dns.rdata.from_text( + rdclass, + rdtype, + self.tok, + self.current_origin, + self.relativize, + self.zone_origin, + ) except dns.exception.SyntaxError: # Catch and reraise. raise @@ -194,7 +242,8 @@ class Reader: # helpful filename:line info. (ty, va) = sys.exc_info()[:2] raise dns.exception.SyntaxError( - "caught exception {}: {}".format(str(ty), str(va))) + "caught exception {}: {}".format(str(ty), str(va)) + ) if not self.default_ttl_known and rdtype == dns.rdatatype.SOA: # The pre-RFC2308 and pre-BIND9 behavior inherits the zone default @@ -213,7 +262,7 @@ class Reader: self.txn.add(name, ttl, rd) - def _parse_modify(self, side): + def _parse_modify(self, side: str) -> Tuple[str, str, int, int, str]: # Here we catch everything in '{' '}' in a group so we can replace it # with ''. is_generate1 = re.compile(r"^.*\$({(\+|-?)(\d+),(\d+),(.)}).*$") @@ -225,31 +274,36 @@ class Reader: g1 = is_generate1.match(side) if g1: mod, sign, offset, width, base = g1.groups() - if sign == '': - sign = '+' + if sign == "": + sign = "+" g2 = is_generate2.match(side) if g2: mod, sign, offset = g2.groups() - if sign == '': - sign = '+' + if sign == "": + sign = "+" width = 0 - base = 'd' + base = "d" g3 = is_generate3.match(side) if g3: mod, sign, offset, width = g3.groups() - if sign == '': - sign = '+' - base = 'd' + if sign == "": + sign = "+" + base = "d" if not (g1 or g2 or g3): - mod = '' - sign = '+' + mod = "" + sign = "+" offset = 0 width = 0 - base = 'd' + base = "d" - if base != 'd': - raise NotImplementedError() + offset = int(offset) + width = int(width) + + if sign not in ["+", "-"]: + raise dns.exception.SyntaxError("invalid offset sign %s" % sign) + if base not in ["d", "o", "x", "X", "n", "N"]: + raise dns.exception.SyntaxError("invalid type %s" % base) return mod, sign, offset, width, base @@ -313,37 +367,47 @@ class Reader: if not token.is_identifier(): raise dns.exception.SyntaxError except Exception: - raise dns.exception.SyntaxError("unknown rdatatype '%s'" % - token.value) + raise dns.exception.SyntaxError("unknown rdatatype '%s'" % token.value) # rhs (required) rhs = token.value - # The code currently only supports base 'd', so the last value - # in the tuple _parse_modify returns is ignored - lmod, lsign, loffset, lwidth, _ = self._parse_modify(lhs) - rmod, rsign, roffset, rwidth, _ = self._parse_modify(rhs) + def _calculate_index(counter: int, offset_sign: str, offset: int) -> int: + """Calculate the index from the counter and offset.""" + if offset_sign == "-": + offset *= -1 + return counter + offset + + def _format_index(index: int, base: str, width: int) -> str: + """Format the index with the given base, and zero-fill it + to the given width.""" + if base in ["d", "o", "x", "X"]: + return format(index, base).zfill(width) + + # base can only be n or N here + hexa = _format_index(index, "x", width) + nibbles = ".".join(hexa[::-1])[:width] + if base == "N": + nibbles = nibbles.upper() + return nibbles + + lmod, lsign, loffset, lwidth, lbase = self._parse_modify(lhs) + rmod, rsign, roffset, rwidth, rbase = self._parse_modify(rhs) for i in range(start, stop + 1, step): # +1 because bind is inclusive and python is exclusive - if lsign == '+': - lindex = i + int(loffset) - elif lsign == '-': - lindex = i - int(loffset) + lindex = _calculate_index(i, lsign, loffset) + rindex = _calculate_index(i, rsign, roffset) - if rsign == '-': - rindex = i - int(roffset) - elif rsign == '+': - rindex = i + int(roffset) + lzfindex = _format_index(lindex, lbase, lwidth) + rzfindex = _format_index(rindex, rbase, rwidth) - lzfindex = str(lindex).zfill(int(lwidth)) - rzfindex = str(rindex).zfill(int(rwidth)) + name = lhs.replace("$%s" % (lmod), lzfindex) + rdata = rhs.replace("$%s" % (rmod), rzfindex) - name = lhs.replace('$%s' % (lmod), lzfindex) - rdata = rhs.replace('$%s' % (rmod), rzfindex) - - self.last_name = dns.name.from_text(name, self.current_origin, - self.tok.idna_codec) + self.last_name = dns.name.from_text( + name, self.current_origin, self.tok.idna_codec + ) name = self.last_name if not name.is_subdomain(self.zone_origin): self._eat_line() @@ -352,9 +416,14 @@ class Reader: name = name.relativize(self.zone_origin) try: - rd = dns.rdata.from_text(rdclass, rdtype, rdata, - self.current_origin, self.relativize, - self.zone_origin) + rd = dns.rdata.from_text( + rdclass, + rdtype, + rdata, + self.current_origin, + self.relativize, + self.zone_origin, + ) except dns.exception.SyntaxError: # Catch and reraise. raise @@ -365,12 +434,13 @@ class Reader: # We convert them to syntax errors so that we can emit # helpful filename:line info. (ty, va) = sys.exc_info()[:2] - raise dns.exception.SyntaxError("caught exception %s: %s" % - (str(ty), str(va))) + raise dns.exception.SyntaxError( + "caught exception %s: %s" % (str(ty), str(va)) + ) self.txn.add(name, ttl, rd) - def read(self): + def read(self) -> None: """Read a DNS zone file and build a zone object. @raises dns.zone.NoSOA: No SOA RR was found at the zone origin @@ -384,14 +454,16 @@ class Reader: if self.current_file is not None: self.current_file.close() if len(self.saved_state) > 0: - (self.tok, - self.current_origin, - self.last_name, - self.current_file, - self.last_ttl, - self.last_ttl_known, - self.default_ttl, - self.default_ttl_known) = self.saved_state.pop(-1) + ( + self.tok, + self.current_origin, + self.last_name, + self.current_file, + self.last_ttl, + self.last_ttl_known, + self.default_ttl, + self.default_ttl_known, + ) = self.saved_state.pop(-1) continue break elif token.is_eol(): @@ -399,53 +471,62 @@ class Reader: elif token.is_comment(): self.tok.get_eol() continue - elif token.value[0] == '$' and self.allow_directives: + elif token.value[0] == "$" and len(self.allowed_directives) > 0: + # Note that we only run directive processing code if at least + # one directive is allowed in order to be backwards compatible c = token.value.upper() - if c == '$TTL': + if c not in self.allowed_directives: + raise dns.exception.SyntaxError( + f"zone file directive '{c}' is not allowed" + ) + if c == "$TTL": token = self.tok.get() if not token.is_identifier(): raise dns.exception.SyntaxError("bad $TTL") self.default_ttl = dns.ttl.from_text(token.value) self.default_ttl_known = True self.tok.get_eol() - elif c == '$ORIGIN': + elif c == "$ORIGIN": self.current_origin = self.tok.get_name() self.tok.get_eol() if self.zone_origin is None: self.zone_origin = self.current_origin self.txn._set_origin(self.current_origin) - elif c == '$INCLUDE' and self.allow_include: + elif c == "$INCLUDE": token = self.tok.get() filename = token.value token = self.tok.get() + new_origin: Optional[dns.name.Name] if token.is_identifier(): - new_origin =\ - dns.name.from_text(token.value, - self.current_origin, - self.tok.idna_codec) + new_origin = dns.name.from_text( + token.value, self.current_origin, self.tok.idna_codec + ) self.tok.get_eol() elif not token.is_eol_or_eof(): - raise dns.exception.SyntaxError( - "bad origin in $INCLUDE") + raise dns.exception.SyntaxError("bad origin in $INCLUDE") else: new_origin = self.current_origin - self.saved_state.append((self.tok, - self.current_origin, - self.last_name, - self.current_file, - self.last_ttl, - self.last_ttl_known, - self.default_ttl, - self.default_ttl_known)) - self.current_file = open(filename, 'r') - self.tok = dns.tokenizer.Tokenizer(self.current_file, - filename) + self.saved_state.append( + ( + self.tok, + self.current_origin, + self.last_name, + self.current_file, + self.last_ttl, + self.last_ttl_known, + self.default_ttl, + self.default_ttl_known, + ) + ) + self.current_file = open(filename, "r") + self.tok = dns.tokenizer.Tokenizer(self.current_file, filename) self.current_origin = new_origin - elif c == '$GENERATE': + elif c == "$GENERATE": self._generate_line() else: raise dns.exception.SyntaxError( - "Unknown zone file directive '" + c + "'") + f"Unknown zone file directive '{c}'" + ) continue self.tok.unget(token) self._rr_line() @@ -454,13 +535,13 @@ class Reader: if detail is None: detail = "syntax error" ex = dns.exception.SyntaxError( - "%s:%d: %s" % (filename, line_number, detail)) + "%s:%d: %s" % (filename, line_number, detail) + ) tb = sys.exc_info()[2] raise ex.with_traceback(tb) from None class RRsetsReaderTransaction(dns.transaction.Transaction): - def __init__(self, manager, replacement, read_only): assert not read_only super().__init__(manager, replacement, read_only) @@ -512,8 +593,9 @@ class RRsetsReaderTransaction(dns.transaction.Transaction): if commit and self._changed(): rrsets = [] for (name, _, _), rdataset in self.rdatasets.items(): - rrset = dns.rrset.RRset(name, rdataset.rdclass, rdataset.rdtype, - rdataset.covers) + rrset = dns.rrset.RRset( + name, rdataset.rdclass, rdataset.rdtype, rdataset.covers + ) rrset.update(rdataset) rrsets.append(rrset) self.manager.set_rrsets(rrsets) @@ -521,15 +603,22 @@ class RRsetsReaderTransaction(dns.transaction.Transaction): def _set_origin(self, origin): pass + def _iterate_rdatasets(self): + raise NotImplementedError # pragma: no cover + class RRSetsReaderManager(dns.transaction.TransactionManager): - def __init__(self, origin=dns.name.root, relativize=False, - rdclass=dns.rdataclass.IN): + def __init__( + self, origin=dns.name.root, relativize=False, rdclass=dns.rdataclass.IN + ): self.origin = origin self.relativize = relativize self.rdclass = rdclass self.rrsets = [] + def reader(self): # pragma: no cover + raise NotImplementedError + def writer(self, replacement=False): assert replacement is True return RRsetsReaderTransaction(self, True, False) @@ -548,10 +637,18 @@ class RRSetsReaderManager(dns.transaction.TransactionManager): self.rrsets = rrsets -def read_rrsets(text, name=None, ttl=None, rdclass=dns.rdataclass.IN, - default_rdclass=dns.rdataclass.IN, - rdtype=None, default_ttl=None, idna_codec=None, - origin=dns.name.root, relativize=False): +def read_rrsets( + text: Any, + name: Optional[Union[dns.name.Name, str]] = None, + ttl: Optional[int] = None, + rdclass: Optional[Union[dns.rdataclass.RdataClass, str]] = dns.rdataclass.IN, + default_rdclass: Union[dns.rdataclass.RdataClass, str] = dns.rdataclass.IN, + rdtype: Optional[Union[dns.rdatatype.RdataType, str]] = None, + default_ttl: Optional[Union[int, str]] = None, + idna_codec: Optional[dns.name.IDNACodec] = None, + origin: Optional[Union[dns.name.Name, str]] = dns.name.root, + relativize: bool = False, +) -> List[dns.rrset.RRset]: """Read one or more rrsets from the specified text, possibly subject to restrictions. @@ -610,15 +707,27 @@ def read_rrsets(text, name=None, ttl=None, rdclass=dns.rdataclass.IN, if isinstance(default_ttl, str): default_ttl = dns.ttl.from_text(default_ttl) if rdclass is not None: - rdclass = dns.rdataclass.RdataClass.make(rdclass) - default_rdclass = dns.rdataclass.RdataClass.make(default_rdclass) + the_rdclass = dns.rdataclass.RdataClass.make(rdclass) + else: + the_rdclass = None + the_default_rdclass = dns.rdataclass.RdataClass.make(default_rdclass) if rdtype is not None: - rdtype = dns.rdatatype.RdataType.make(rdtype) + the_rdtype = dns.rdatatype.RdataType.make(rdtype) + else: + the_rdtype = None manager = RRSetsReaderManager(origin, relativize, default_rdclass) with manager.writer(True) as txn: - tok = dns.tokenizer.Tokenizer(text, '', idna_codec=idna_codec) - reader = Reader(tok, default_rdclass, txn, allow_directives=False, - force_name=name, force_ttl=ttl, force_rdclass=rdclass, - force_rdtype=rdtype, default_ttl=default_ttl) + tok = dns.tokenizer.Tokenizer(text, "", idna_codec=idna_codec) + reader = Reader( + tok, + the_default_rdclass, + txn, + allow_directives=False, + force_name=name, + force_ttl=ttl, + force_rdclass=the_rdclass, + force_rdtype=the_rdtype, + default_ttl=default_ttl, + ) reader.read() return manager.rrsets diff --git a/lib/dns/zonetypes.py b/lib/dns/zonetypes.py new file mode 100644 index 00000000..195ee2ec --- /dev/null +++ b/lib/dns/zonetypes.py @@ -0,0 +1,37 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +"""Common zone-related types.""" + +# This is a separate file to avoid import circularity between dns.zone and +# the implementation of the ZONEMD type. + +import hashlib + +import dns.enum + + +class DigestScheme(dns.enum.IntEnum): + """ZONEMD Scheme""" + + SIMPLE = 1 + + @classmethod + def _maximum(cls): + return 255 + + +class DigestHashAlgorithm(dns.enum.IntEnum): + """ZONEMD Hash Algorithm""" + + SHA384 = 1 + SHA512 = 2 + + @classmethod + def _maximum(cls): + return 255 + + +_digest_hashers = { + DigestHashAlgorithm.SHA384: hashlib.sha384, + DigestHashAlgorithm.SHA512: hashlib.sha512, +} diff --git a/requirements.txt b/requirements.txt index 017ae555..69e57f37 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,7 +11,7 @@ cheroot==9.0.0 cherrypy==18.8.0 cloudinary==1.30.0 distro==1.8.0 -dnspython==2.2.1 +dnspython==2.3.0 facebook-sdk==3.1.0 future==0.18.3 ga4mp==2.0.4