Bump dnspython from 2.2.1 to 2.3.0 (#1975)

* Bump dnspython from 2.2.1 to 2.3.0

Bumps [dnspython](https://github.com/rthalley/dnspython) from 2.2.1 to 2.3.0.
- [Release notes](https://github.com/rthalley/dnspython/releases)
- [Changelog](https://github.com/rthalley/dnspython/blob/master/doc/whatsnew.rst)
- [Commits](https://github.com/rthalley/dnspython/compare/v2.2.1...v2.3.0)

---
updated-dependencies:
- dependency-name: dnspython
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>

* Update dnspython==2.3.0

---------

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: JonnyWong16 <9099342+JonnyWong16@users.noreply.github.com>

[skip ci]
This commit is contained in:
dependabot[bot] 2023-03-02 20:54:32 -08:00 committed by GitHub
parent 6910079330
commit 32c06a8b72
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
137 changed files with 7699 additions and 4277 deletions

View file

@ -18,49 +18,52 @@
"""dnspython DNS toolkit""" """dnspython DNS toolkit"""
__all__ = [ __all__ = [
'asyncbackend', "asyncbackend",
'asyncquery', "asyncquery",
'asyncresolver', "asyncresolver",
'dnssec', "dnssec",
'e164', "dnssectypes",
'edns', "e164",
'entropy', "edns",
'exception', "entropy",
'flags', "exception",
'immutable', "flags",
'inet', "immutable",
'ipv4', "inet",
'ipv6', "ipv4",
'message', "ipv6",
'name', "message",
'namedict', "name",
'node', "namedict",
'opcode', "node",
'query', "opcode",
'rcode', "query",
'rdata', "quic",
'rdataclass', "rcode",
'rdataset', "rdata",
'rdatatype', "rdataclass",
'renderer', "rdataset",
'resolver', "rdatatype",
'reversename', "renderer",
'rrset', "resolver",
'serial', "reversename",
'set', "rrset",
'tokenizer', "serial",
'transaction', "set",
'tsig', "tokenizer",
'tsigkeyring', "transaction",
'ttl', "tsig",
'rdtypes', "tsigkeyring",
'update', "ttl",
'version', "rdtypes",
'versioned', "update",
'wire', "version",
'xfr', "versioned",
'zone', "wire",
'zonefile', "xfr",
"zone",
"zonetypes",
"zonefile",
] ]
from dns.version import version as __version__ # noqa from dns.version import version as __version__ # noqa

View file

@ -3,6 +3,7 @@
# This is a nullcontext for both sync and async. 3.7 has a nullcontext, # This is a nullcontext for both sync and async. 3.7 has a nullcontext,
# but it is only for sync use. # but it is only for sync use.
class NullContext: class NullContext:
def __init__(self, enter_result=None): def __init__(self, enter_result=None):
self.enter_result = enter_result self.enter_result = enter_result
@ -23,6 +24,7 @@ class NullContext:
# These are declared here so backends can import them without creating # These are declared here so backends can import them without creating
# circular dependencies with dns.asyncbackend. # circular dependencies with dns.asyncbackend.
class Socket: # pragma: no cover class Socket: # pragma: no cover
async def close(self): async def close(self):
pass pass
@ -41,6 +43,9 @@ class Socket: # pragma: no cover
class DatagramSocket(Socket): # pragma: no cover class DatagramSocket(Socket): # pragma: no cover
def __init__(self, family: int):
self.family = family
async def sendto(self, what, destination, timeout): async def sendto(self, what, destination, timeout):
raise NotImplementedError raise NotImplementedError
@ -58,12 +63,23 @@ class StreamSocket(Socket): # pragma: no cover
class Backend: # pragma: no cover class Backend: # pragma: no cover
def name(self): def name(self):
return 'unknown' return "unknown"
async def make_socket(self, af, socktype, proto=0, async def make_socket(
source=None, destination=None, timeout=None, self,
ssl_context=None, server_hostname=None): af,
socktype,
proto=0,
source=None,
destination=None,
timeout=None,
ssl_context=None,
server_hostname=None,
):
raise NotImplementedError raise NotImplementedError
def datagram_connection_required(self): def datagram_connection_required(self):
return False return False
async def sleep(self, interval):
raise NotImplementedError

View file

@ -10,7 +10,8 @@ import dns._asyncbackend
import dns.exception import dns.exception
_is_win32 = sys.platform == 'win32' _is_win32 = sys.platform == "win32"
def _get_running_loop(): def _get_running_loop():
try: try:
@ -30,7 +31,6 @@ class _DatagramProtocol:
def datagram_received(self, data, addr): def datagram_received(self, data, addr):
if self.recvfrom and not self.recvfrom.done(): if self.recvfrom and not self.recvfrom.done():
self.recvfrom.set_result((data, addr)) self.recvfrom.set_result((data, addr))
self.recvfrom = None
def error_received(self, exc): # pragma: no cover def error_received(self, exc): # pragma: no cover
if self.recvfrom and not self.recvfrom.done(): if self.recvfrom and not self.recvfrom.done():
@ -56,30 +56,34 @@ async def _maybe_wait_for(awaitable, timeout):
class DatagramSocket(dns._asyncbackend.DatagramSocket): class DatagramSocket(dns._asyncbackend.DatagramSocket):
def __init__(self, family, transport, protocol): def __init__(self, family, transport, protocol):
self.family = family super().__init__(family)
self.transport = transport self.transport = transport
self.protocol = protocol self.protocol = protocol
async def sendto(self, what, destination, timeout): # pragma: no cover async def sendto(self, what, destination, timeout): # pragma: no cover
# no timeout for asyncio sendto # no timeout for asyncio sendto
self.transport.sendto(what, destination) self.transport.sendto(what, destination)
return len(what)
async def recvfrom(self, size, timeout): async def recvfrom(self, size, timeout):
# ignore size as there's no way I know to tell protocol about it # ignore size as there's no way I know to tell protocol about it
done = _get_running_loop().create_future() done = _get_running_loop().create_future()
try:
assert self.protocol.recvfrom is None assert self.protocol.recvfrom is None
self.protocol.recvfrom = done self.protocol.recvfrom = done
await _maybe_wait_for(done, timeout) await _maybe_wait_for(done, timeout)
return done.result() return done.result()
finally:
self.protocol.recvfrom = None
async def close(self): async def close(self):
self.protocol.close() self.protocol.close()
async def getpeername(self): async def getpeername(self):
return self.transport.get_extra_info('peername') return self.transport.get_extra_info("peername")
async def getsockname(self): async def getsockname(self):
return self.transport.get_extra_info('sockname') return self.transport.get_extra_info("sockname")
class StreamSocket(dns._asyncbackend.StreamSocket): class StreamSocket(dns._asyncbackend.StreamSocket):
@ -93,8 +97,7 @@ class StreamSocket(dns._asyncbackend.StreamSocket):
return await _maybe_wait_for(self.writer.drain(), timeout) return await _maybe_wait_for(self.writer.drain(), timeout)
async def recv(self, size, timeout): async def recv(self, size, timeout):
return await _maybe_wait_for(self.reader.read(size), return await _maybe_wait_for(self.reader.read(size), timeout)
timeout)
async def close(self): async def close(self):
self.writer.close() self.writer.close()
@ -104,43 +107,64 @@ class StreamSocket(dns._asyncbackend.StreamSocket):
pass pass
async def getpeername(self): async def getpeername(self):
return self.writer.get_extra_info('peername') return self.writer.get_extra_info("peername")
async def getsockname(self): async def getsockname(self):
return self.writer.get_extra_info('sockname') return self.writer.get_extra_info("sockname")
class Backend(dns._asyncbackend.Backend): class Backend(dns._asyncbackend.Backend):
def name(self): def name(self):
return 'asyncio' return "asyncio"
async def make_socket(self, af, socktype, proto=0, async def make_socket(
source=None, destination=None, timeout=None, self,
ssl_context=None, server_hostname=None): af,
if destination is None and socktype == socket.SOCK_DGRAM and \ socktype,
_is_win32: proto=0,
raise NotImplementedError('destinationless datagram sockets ' source=None,
'are not supported by asyncio ' destination=None,
'on Windows') timeout=None,
ssl_context=None,
server_hostname=None,
):
if destination is None and socktype == socket.SOCK_DGRAM and _is_win32:
raise NotImplementedError(
"destinationless datagram sockets "
"are not supported by asyncio "
"on Windows"
)
loop = _get_running_loop() loop = _get_running_loop()
if socktype == socket.SOCK_DGRAM: if socktype == socket.SOCK_DGRAM:
transport, protocol = await loop.create_datagram_endpoint( transport, protocol = await loop.create_datagram_endpoint(
_DatagramProtocol, source, family=af, _DatagramProtocol,
proto=proto, remote_addr=destination) source,
family=af,
proto=proto,
remote_addr=destination,
)
return DatagramSocket(af, transport, protocol) return DatagramSocket(af, transport, protocol)
elif socktype == socket.SOCK_STREAM: elif socktype == socket.SOCK_STREAM:
if destination is None:
# This shouldn't happen, but we check to make code analysis software
# happier.
raise ValueError("destination required for stream sockets")
(r, w) = await _maybe_wait_for( (r, w) = await _maybe_wait_for(
asyncio.open_connection(destination[0], asyncio.open_connection(
destination[0],
destination[1], destination[1],
ssl=ssl_context, ssl=ssl_context,
family=af, family=af,
proto=proto, proto=proto,
local_addr=source, local_addr=source,
server_hostname=server_hostname), server_hostname=server_hostname,
timeout) ),
timeout,
)
return StreamSocket(af, r, w) return StreamSocket(af, r, w)
raise NotImplementedError('unsupported socket ' + raise NotImplementedError(
f'type {socktype}') # pragma: no cover "unsupported socket " + f"type {socktype}"
) # pragma: no cover
async def sleep(self, interval): async def sleep(self, interval):
await asyncio.sleep(interval) await asyncio.sleep(interval)

View file

@ -26,18 +26,20 @@ _lltuple = dns.inet.low_level_address_tuple
class DatagramSocket(dns._asyncbackend.DatagramSocket): class DatagramSocket(dns._asyncbackend.DatagramSocket):
def __init__(self, socket): def __init__(self, socket):
super().__init__(socket.family)
self.socket = socket self.socket = socket
self.family = socket.family
async def sendto(self, what, destination, timeout): async def sendto(self, what, destination, timeout):
async with _maybe_timeout(timeout): async with _maybe_timeout(timeout):
return await self.socket.sendto(what, destination) return await self.socket.sendto(what, destination)
raise dns.exception.Timeout(timeout=timeout) # pragma: no cover raise dns.exception.Timeout(
timeout=timeout
) # pragma: no cover lgtm[py/unreachable-statement]
async def recvfrom(self, size, timeout): async def recvfrom(self, size, timeout):
async with _maybe_timeout(timeout): async with _maybe_timeout(timeout):
return await self.socket.recvfrom(size) return await self.socket.recvfrom(size)
raise dns.exception.Timeout(timeout=timeout) raise dns.exception.Timeout(timeout=timeout) # lgtm[py/unreachable-statement]
async def close(self): async def close(self):
await self.socket.close() await self.socket.close()
@ -57,12 +59,12 @@ class StreamSocket(dns._asyncbackend.StreamSocket):
async def sendall(self, what, timeout): async def sendall(self, what, timeout):
async with _maybe_timeout(timeout): async with _maybe_timeout(timeout):
return await self.socket.sendall(what) return await self.socket.sendall(what)
raise dns.exception.Timeout(timeout=timeout) raise dns.exception.Timeout(timeout=timeout) # lgtm[py/unreachable-statement]
async def recv(self, size, timeout): async def recv(self, size, timeout):
async with _maybe_timeout(timeout): async with _maybe_timeout(timeout):
return await self.socket.recv(size) return await self.socket.recv(size)
raise dns.exception.Timeout(timeout=timeout) raise dns.exception.Timeout(timeout=timeout) # lgtm[py/unreachable-statement]
async def close(self): async def close(self):
await self.socket.close() await self.socket.close()
@ -76,11 +78,19 @@ class StreamSocket(dns._asyncbackend.StreamSocket):
class Backend(dns._asyncbackend.Backend): class Backend(dns._asyncbackend.Backend):
def name(self): def name(self):
return 'curio' return "curio"
async def make_socket(self, af, socktype, proto=0, async def make_socket(
source=None, destination=None, timeout=None, self,
ssl_context=None, server_hostname=None): af,
socktype,
proto=0,
source=None,
destination=None,
timeout=None,
ssl_context=None,
server_hostname=None,
):
if socktype == socket.SOCK_DGRAM: if socktype == socket.SOCK_DGRAM:
s = curio.socket.socket(af, socktype, proto) s = curio.socket.socket(af, socktype, proto)
try: try:
@ -96,13 +106,17 @@ class Backend(dns._asyncbackend.Backend):
else: else:
source_addr = None source_addr = None
async with _maybe_timeout(timeout): async with _maybe_timeout(timeout):
s = await curio.open_connection(destination[0], destination[1], s = await curio.open_connection(
destination[0],
destination[1],
ssl=ssl_context, ssl=ssl_context,
source_addr=source_addr, source_addr=source_addr,
server_hostname=server_hostname) server_hostname=server_hostname,
)
return StreamSocket(s) return StreamSocket(s)
raise NotImplementedError('unsupported socket ' + raise NotImplementedError(
f'type {socktype}') # pragma: no cover "unsupported socket " + f"type {socktype}"
) # pragma: no cover
async def sleep(self, interval): async def sleep(self, interval):
await curio.sleep(interval) await curio.sleep(interval)

View file

@ -1,84 +0,0 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# This implementation of the immutable decorator is for python 3.6,
# which doesn't have Context Variables. This implementation is somewhat
# costly for classes with slots, as it adds a __dict__ to them.
import inspect
class _Immutable:
"""Immutable mixin class"""
# Note we MUST NOT have __slots__ as that causes
#
# TypeError: multiple bases have instance lay-out conflict
#
# when we get mixed in with another class with slots. When we
# get mixed into something with slots, it effectively adds __dict__ to
# the slots of the other class, which allows attribute setting to work,
# albeit at the cost of the dictionary.
def __setattr__(self, name, value):
if not hasattr(self, '_immutable_init') or \
self._immutable_init is not self:
raise TypeError("object doesn't support attribute assignment")
else:
super().__setattr__(name, value)
def __delattr__(self, name):
if not hasattr(self, '_immutable_init') or \
self._immutable_init is not self:
raise TypeError("object doesn't support attribute assignment")
else:
super().__delattr__(name)
def _immutable_init(f):
def nf(*args, **kwargs):
try:
# Are we already initializing an immutable class?
previous = args[0]._immutable_init
except AttributeError:
# We are the first!
previous = None
object.__setattr__(args[0], '_immutable_init', args[0])
try:
# call the actual __init__
f(*args, **kwargs)
finally:
if not previous:
# If we started the initialization, establish immutability
# by removing the attribute that allows mutation
object.__delattr__(args[0], '_immutable_init')
nf.__signature__ = inspect.signature(f)
return nf
def immutable(cls):
if _Immutable in cls.__mro__:
# Some ancestor already has the mixin, so just make sure we keep
# following the __init__ protocol.
cls.__init__ = _immutable_init(cls.__init__)
if hasattr(cls, '__setstate__'):
cls.__setstate__ = _immutable_init(cls.__setstate__)
ncls = cls
else:
# Mixin the Immutable class and follow the __init__ protocol.
class ncls(_Immutable, cls):
@_immutable_init
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if hasattr(cls, '__setstate__'):
@_immutable_init
def __setstate__(self, *args, **kwargs):
super().__setstate__(*args, **kwargs)
# make ncls have the same name and module as cls
ncls.__name__ = cls.__name__
ncls.__qualname__ = cls.__qualname__
ncls.__module__ = cls.__module__
return ncls

View file

@ -8,7 +8,7 @@ import contextvars
import inspect import inspect
_in__init__ = contextvars.ContextVar('_immutable_in__init__', default=False) _in__init__ = contextvars.ContextVar("_immutable_in__init__", default=False)
class _Immutable: class _Immutable:
@ -41,6 +41,7 @@ def _immutable_init(f):
f(*args, **kwargs) f(*args, **kwargs)
finally: finally:
_in__init__.reset(previous) _in__init__.reset(previous)
nf.__signature__ = inspect.signature(f) nf.__signature__ = inspect.signature(f)
return nf return nf
@ -50,7 +51,7 @@ def immutable(cls):
# Some ancestor already has the mixin, so just make sure we keep # Some ancestor already has the mixin, so just make sure we keep
# following the __init__ protocol. # following the __init__ protocol.
cls.__init__ = _immutable_init(cls.__init__) cls.__init__ = _immutable_init(cls.__init__)
if hasattr(cls, '__setstate__'): if hasattr(cls, "__setstate__"):
cls.__setstate__ = _immutable_init(cls.__setstate__) cls.__setstate__ = _immutable_init(cls.__setstate__)
ncls = cls ncls = cls
else: else:
@ -63,7 +64,8 @@ def immutable(cls):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
if hasattr(cls, '__setstate__'): if hasattr(cls, "__setstate__"):
@_immutable_init @_immutable_init
def __setstate__(self, *args, **kwargs): def __setstate__(self, *args, **kwargs):
super().__setstate__(*args, **kwargs) super().__setstate__(*args, **kwargs)

View file

@ -26,18 +26,20 @@ _lltuple = dns.inet.low_level_address_tuple
class DatagramSocket(dns._asyncbackend.DatagramSocket): class DatagramSocket(dns._asyncbackend.DatagramSocket):
def __init__(self, socket): def __init__(self, socket):
super().__init__(socket.family)
self.socket = socket self.socket = socket
self.family = socket.family
async def sendto(self, what, destination, timeout): async def sendto(self, what, destination, timeout):
with _maybe_timeout(timeout): with _maybe_timeout(timeout):
return await self.socket.sendto(what, destination) return await self.socket.sendto(what, destination)
raise dns.exception.Timeout(timeout=timeout) # pragma: no cover raise dns.exception.Timeout(
timeout=timeout
) # pragma: no cover lgtm[py/unreachable-statement]
async def recvfrom(self, size, timeout): async def recvfrom(self, size, timeout):
with _maybe_timeout(timeout): with _maybe_timeout(timeout):
return await self.socket.recvfrom(size) return await self.socket.recvfrom(size)
raise dns.exception.Timeout(timeout=timeout) raise dns.exception.Timeout(timeout=timeout) # lgtm[py/unreachable-statement]
async def close(self): async def close(self):
self.socket.close() self.socket.close()
@ -58,12 +60,12 @@ class StreamSocket(dns._asyncbackend.StreamSocket):
async def sendall(self, what, timeout): async def sendall(self, what, timeout):
with _maybe_timeout(timeout): with _maybe_timeout(timeout):
return await self.stream.send_all(what) return await self.stream.send_all(what)
raise dns.exception.Timeout(timeout=timeout) raise dns.exception.Timeout(timeout=timeout) # lgtm[py/unreachable-statement]
async def recv(self, size, timeout): async def recv(self, size, timeout):
with _maybe_timeout(timeout): with _maybe_timeout(timeout):
return await self.stream.receive_some(size) return await self.stream.receive_some(size)
raise dns.exception.Timeout(timeout=timeout) raise dns.exception.Timeout(timeout=timeout) # lgtm[py/unreachable-statement]
async def close(self): async def close(self):
await self.stream.aclose() await self.stream.aclose()
@ -83,11 +85,19 @@ class StreamSocket(dns._asyncbackend.StreamSocket):
class Backend(dns._asyncbackend.Backend): class Backend(dns._asyncbackend.Backend):
def name(self): def name(self):
return 'trio' return "trio"
async def make_socket(self, af, socktype, proto=0, source=None, async def make_socket(
destination=None, timeout=None, self,
ssl_context=None, server_hostname=None): af,
socktype,
proto=0,
source=None,
destination=None,
timeout=None,
ssl_context=None,
server_hostname=None,
):
s = trio.socket.socket(af, socktype, proto) s = trio.socket.socket(af, socktype, proto)
stream = None stream = None
try: try:
@ -103,19 +113,20 @@ class Backend(dns._asyncbackend.Backend):
return DatagramSocket(s) return DatagramSocket(s)
elif socktype == socket.SOCK_STREAM: elif socktype == socket.SOCK_STREAM:
stream = trio.SocketStream(s) stream = trio.SocketStream(s)
s = None
tls = False tls = False
if ssl_context: if ssl_context:
tls = True tls = True
try: try:
stream = trio.SSLStream(stream, ssl_context, stream = trio.SSLStream(
server_hostname=server_hostname) stream, ssl_context, server_hostname=server_hostname
)
except Exception: # pragma: no cover except Exception: # pragma: no cover
await stream.aclose() await stream.aclose()
raise raise
return StreamSocket(af, stream, tls) return StreamSocket(af, stream, tls)
raise NotImplementedError('unsupported socket ' + raise NotImplementedError(
f'type {socktype}') # pragma: no cover "unsupported socket " + f"type {socktype}"
) # pragma: no cover
async def sleep(self, interval): async def sleep(self, interval):
await trio.sleep(interval) await trio.sleep(interval)

View file

@ -1,26 +1,33 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
from typing import Dict
import dns.exception import dns.exception
# pylint: disable=unused-import # pylint: disable=unused-import
from dns._asyncbackend import Socket, DatagramSocket, \ from dns._asyncbackend import (
StreamSocket, Backend # noqa: Socket,
DatagramSocket,
StreamSocket,
Backend,
) # noqa: F401 lgtm[py/unused-import]
# pylint: enable=unused-import # pylint: enable=unused-import
_default_backend = None _default_backend = None
_backends = {} _backends: Dict[str, Backend] = {}
# Allow sniffio import to be disabled for testing purposes # Allow sniffio import to be disabled for testing purposes
_no_sniffio = False _no_sniffio = False
class AsyncLibraryNotFoundError(dns.exception.DNSException): class AsyncLibraryNotFoundError(dns.exception.DNSException):
pass pass
def get_backend(name): def get_backend(name: str) -> Backend:
"""Get the specified asynchronous backend. """Get the specified asynchronous backend.
*name*, a ``str``, the name of the backend. Currently the "trio", *name*, a ``str``, the name of the backend. Currently the "trio",
@ -32,22 +39,25 @@ def get_backend(name):
backend = _backends.get(name) backend = _backends.get(name)
if backend: if backend:
return backend return backend
if name == 'trio': if name == "trio":
import dns._trio_backend import dns._trio_backend
backend = dns._trio_backend.Backend() backend = dns._trio_backend.Backend()
elif name == 'curio': elif name == "curio":
import dns._curio_backend import dns._curio_backend
backend = dns._curio_backend.Backend() backend = dns._curio_backend.Backend()
elif name == 'asyncio': elif name == "asyncio":
import dns._asyncio_backend import dns._asyncio_backend
backend = dns._asyncio_backend.Backend() backend = dns._asyncio_backend.Backend()
else: else:
raise NotImplementedError(f'unimplemented async backend {name}') raise NotImplementedError(f"unimplemented async backend {name}")
_backends[name] = backend _backends[name] = backend
return backend return backend
def sniff(): def sniff() -> str:
"""Attempt to determine the in-use asynchronous I/O library by using """Attempt to determine the in-use asynchronous I/O library by using
the ``sniffio`` module if it is available. the ``sniffio`` module if it is available.
@ -59,35 +69,32 @@ def sniff():
if _no_sniffio: if _no_sniffio:
raise ImportError raise ImportError
import sniffio import sniffio
try: try:
return sniffio.current_async_library() return sniffio.current_async_library()
except sniffio.AsyncLibraryNotFoundError: except sniffio.AsyncLibraryNotFoundError:
raise AsyncLibraryNotFoundError('sniffio cannot determine ' + raise AsyncLibraryNotFoundError(
'async library') "sniffio cannot determine " + "async library"
)
except ImportError: except ImportError:
import asyncio import asyncio
try: try:
asyncio.get_running_loop() asyncio.get_running_loop()
return 'asyncio' return "asyncio"
except RuntimeError: except RuntimeError:
raise AsyncLibraryNotFoundError('no async library detected') raise AsyncLibraryNotFoundError("no async library detected")
except AttributeError: # pragma: no cover
# we have to check current_task on 3.6
if not asyncio.Task.current_task():
raise AsyncLibraryNotFoundError('no async library detected')
return 'asyncio'
def get_default_backend(): def get_default_backend() -> Backend:
"""Get the default backend, initializing it if necessary. """Get the default backend, initializing it if necessary."""
"""
if _default_backend: if _default_backend:
return _default_backend return _default_backend
return set_default_backend(sniff()) return set_default_backend(sniff())
def set_default_backend(name): def set_default_backend(name: str) -> Backend:
"""Set the default backend. """Set the default backend.
It's not normally necessary to call this method, as It's not normally necessary to call this method, as

View file

@ -1,13 +0,0 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
class Backend:
...
def get_backend(name: str) -> Backend:
...
def sniff() -> str:
...
def get_default_backend() -> Backend:
...
def set_default_backend(name: str) -> Backend:
...

View file

@ -17,7 +17,10 @@
"""Talk to a DNS server.""" """Talk to a DNS server."""
from typing import Any, Dict, Optional, Tuple, Union
import base64 import base64
import contextlib
import socket import socket
import struct import struct
import time import time
@ -27,12 +30,24 @@ import dns.exception
import dns.inet import dns.inet
import dns.name import dns.name
import dns.message import dns.message
import dns.quic
import dns.rcode import dns.rcode
import dns.rdataclass import dns.rdataclass
import dns.rdatatype import dns.rdatatype
import dns.transaction
from dns.query import _compute_times, _matches_destination, BadResponse, ssl, \ from dns._asyncbackend import NullContext
UDPMode, _have_httpx, _have_http2, NoDOH from dns.query import (
_compute_times,
_matches_destination,
BadResponse,
ssl,
UDPMode,
_have_httpx,
_have_http2,
NoDOH,
NoDOQ,
)
if _have_httpx: if _have_httpx:
import httpx import httpx
@ -47,11 +62,11 @@ def _source_tuple(af, address, port):
if address or port: if address or port:
if address is None: if address is None:
if af == socket.AF_INET: if af == socket.AF_INET:
address = '0.0.0.0' address = "0.0.0.0"
elif af == socket.AF_INET6: elif af == socket.AF_INET6:
address = '::' address = "::"
else: else:
raise NotImplementedError(f'unknown address family {af}') raise NotImplementedError(f"unknown address family {af}")
return (address, port) return (address, port)
else: else:
return None return None
@ -66,7 +81,12 @@ def _timeout(expiration, now=None):
return None return None
async def send_udp(sock, what, destination, expiration=None): async def send_udp(
sock: dns.asyncbackend.DatagramSocket,
what: Union[dns.message.Message, bytes],
destination: Any,
expiration: Optional[float] = None,
) -> Tuple[int, float]:
"""Send a DNS message to the specified UDP socket. """Send a DNS message to the specified UDP socket.
*sock*, a ``dns.asyncbackend.DatagramSocket``. *sock*, a ``dns.asyncbackend.DatagramSocket``.
@ -78,7 +98,8 @@ async def send_udp(sock, what, destination, expiration=None):
*expiration*, a ``float`` or ``None``, the absolute time at which *expiration*, a ``float`` or ``None``, the absolute time at which
a timeout exception should be raised. If ``None``, no timeout will a timeout exception should be raised. If ``None``, no timeout will
occur. occur. The expiration value is meaningless for the asyncio backend, as
asyncio's transport sendto() never blocks.
Returns an ``(int, float)`` tuple of bytes sent and the sent time. Returns an ``(int, float)`` tuple of bytes sent and the sent time.
""" """
@ -90,35 +111,61 @@ async def send_udp(sock, what, destination, expiration=None):
return (n, sent_time) return (n, sent_time)
async def receive_udp(sock, destination=None, expiration=None, async def receive_udp(
ignore_unexpected=False, one_rr_per_rrset=False, sock: dns.asyncbackend.DatagramSocket,
keyring=None, request_mac=b'', ignore_trailing=False, destination: Optional[Any] = None,
raise_on_truncation=False): expiration: Optional[float] = None,
ignore_unexpected: bool = False,
one_rr_per_rrset: bool = False,
keyring: Optional[Dict[dns.name.Name, dns.tsig.Key]] = None,
request_mac: Optional[bytes] = b"",
ignore_trailing: bool = False,
raise_on_truncation: bool = False,
) -> Any:
"""Read a DNS message from a UDP socket. """Read a DNS message from a UDP socket.
*sock*, a ``dns.asyncbackend.DatagramSocket``. *sock*, a ``dns.asyncbackend.DatagramSocket``.
See :py:func:`dns.query.receive_udp()` for the documentation of the other See :py:func:`dns.query.receive_udp()` for the documentation of the other
parameters, exceptions, and return type of this method. parameters, and exceptions.
Returns a ``(dns.message.Message, float, tuple)`` tuple of the received message, the
received time, and the address where the message arrived from.
""" """
wire = b'' wire = b""
while 1: while 1:
(wire, from_address) = await sock.recvfrom(65535, _timeout(expiration)) (wire, from_address) = await sock.recvfrom(65535, _timeout(expiration))
if _matches_destination(sock.family, from_address, destination, if _matches_destination(
ignore_unexpected): sock.family, from_address, destination, ignore_unexpected
):
break break
received_time = time.time() received_time = time.time()
r = dns.message.from_wire(wire, keyring=keyring, request_mac=request_mac, r = dns.message.from_wire(
wire,
keyring=keyring,
request_mac=request_mac,
one_rr_per_rrset=one_rr_per_rrset, one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing, ignore_trailing=ignore_trailing,
raise_on_truncation=raise_on_truncation) raise_on_truncation=raise_on_truncation,
)
return (r, received_time, from_address) return (r, received_time, from_address)
async def udp(q, where, timeout=None, port=53, source=None, source_port=0,
ignore_unexpected=False, one_rr_per_rrset=False, async def udp(
ignore_trailing=False, raise_on_truncation=False, sock=None, q: dns.message.Message,
backend=None): where: str,
timeout: Optional[float] = None,
port: int = 53,
source: Optional[str] = None,
source_port: int = 0,
ignore_unexpected: bool = False,
one_rr_per_rrset: bool = False,
ignore_trailing: bool = False,
raise_on_truncation: bool = False,
sock: Optional[dns.asyncbackend.DatagramSocket] = None,
backend: Optional[dns.asyncbackend.Backend] = None,
) -> dns.message.Message:
"""Return the response obtained after sending a query via UDP. """Return the response obtained after sending a query via UDP.
*sock*, a ``dns.asyncbackend.DatagramSocket``, or ``None``, *sock*, a ``dns.asyncbackend.DatagramSocket``, or ``None``,
@ -134,13 +181,10 @@ async def udp(q, where, timeout=None, port=53, source=None, source_port=0,
""" """
wire = q.to_wire() wire = q.to_wire()
(begin_time, expiration) = _compute_times(timeout) (begin_time, expiration) = _compute_times(timeout)
s = None
# After 3.6 is no longer supported, this can use an AsyncExitStack.
try:
af = dns.inet.af_for_address(where) af = dns.inet.af_for_address(where)
destination = _lltuple((where, port), af) destination = _lltuple((where, port), af)
if sock: if sock:
s = sock cm: contextlib.AbstractAsyncContextManager = NullContext(sock)
else: else:
if not backend: if not backend:
backend = dns.asyncbackend.get_default_backend() backend = dns.asyncbackend.get_default_backend()
@ -149,27 +193,40 @@ async def udp(q, where, timeout=None, port=53, source=None, source_port=0,
dtuple = (where, port) dtuple = (where, port)
else: else:
dtuple = None dtuple = None
s = await backend.make_socket(af, socket.SOCK_DGRAM, 0, stuple, cm = await backend.make_socket(af, socket.SOCK_DGRAM, 0, stuple, dtuple)
dtuple) async with cm as s:
await send_udp(s, wire, destination, expiration) await send_udp(s, wire, destination, expiration)
(r, received_time, _) = await receive_udp(s, destination, expiration, (r, received_time, _) = await receive_udp(
s,
destination,
expiration,
ignore_unexpected, ignore_unexpected,
one_rr_per_rrset, one_rr_per_rrset,
q.keyring, q.mac, q.keyring,
q.mac,
ignore_trailing, ignore_trailing,
raise_on_truncation) raise_on_truncation,
)
r.time = received_time - begin_time r.time = received_time - begin_time
if not q.is_response(r): if not q.is_response(r):
raise BadResponse raise BadResponse
return r return r
finally:
if not sock and s:
await s.close()
async def udp_with_fallback(q, where, timeout=None, port=53, source=None,
source_port=0, ignore_unexpected=False, async def udp_with_fallback(
one_rr_per_rrset=False, ignore_trailing=False, q: dns.message.Message,
udp_sock=None, tcp_sock=None, backend=None): where: str,
timeout: Optional[float] = None,
port: int = 53,
source: Optional[str] = None,
source_port: int = 0,
ignore_unexpected: bool = False,
one_rr_per_rrset: bool = False,
ignore_trailing: bool = False,
udp_sock: Optional[dns.asyncbackend.DatagramSocket] = None,
tcp_sock: Optional[dns.asyncbackend.StreamSocket] = None,
backend: Optional[dns.asyncbackend.Backend] = None,
) -> Tuple[dns.message.Message, bool]:
"""Return the response to the query, trying UDP first and falling back """Return the response to the query, trying UDP first and falling back
to TCP if UDP results in a truncated response. to TCP if UDP results in a truncated response.
@ -191,18 +248,42 @@ async def udp_with_fallback(q, where, timeout=None, port=53, source=None,
method. method.
""" """
try: try:
response = await udp(q, where, timeout, port, source, source_port, response = await udp(
ignore_unexpected, one_rr_per_rrset, q,
ignore_trailing, True, udp_sock, backend) where,
timeout,
port,
source,
source_port,
ignore_unexpected,
one_rr_per_rrset,
ignore_trailing,
True,
udp_sock,
backend,
)
return (response, False) return (response, False)
except dns.message.Truncated: except dns.message.Truncated:
response = await tcp(q, where, timeout, port, source, source_port, response = await tcp(
one_rr_per_rrset, ignore_trailing, tcp_sock, q,
backend) where,
timeout,
port,
source,
source_port,
one_rr_per_rrset,
ignore_trailing,
tcp_sock,
backend,
)
return (response, True) return (response, True)
async def send_tcp(sock, what, expiration=None): async def send_tcp(
sock: dns.asyncbackend.StreamSocket,
what: Union[dns.message.Message, bytes],
expiration: Optional[float] = None,
) -> Tuple[int, float]:
"""Send a DNS message to the specified TCP socket. """Send a DNS message to the specified TCP socket.
*sock*, a ``dns.asyncbackend.StreamSocket``. *sock*, a ``dns.asyncbackend.StreamSocket``.
@ -212,12 +293,14 @@ async def send_tcp(sock, what, expiration=None):
""" """
if isinstance(what, dns.message.Message): if isinstance(what, dns.message.Message):
what = what.to_wire() wire = what.to_wire()
l = len(what) else:
wire = what
l = len(wire)
# copying the wire into tcpmsg is inefficient, but lets us # copying the wire into tcpmsg is inefficient, but lets us
# avoid writev() or doing a short write that would get pushed # avoid writev() or doing a short write that would get pushed
# onto the net # onto the net
tcpmsg = struct.pack("!H", l) + what tcpmsg = struct.pack("!H", l) + wire
sent_time = time.time() sent_time = time.time()
await sock.sendall(tcpmsg, _timeout(expiration, sent_time)) await sock.sendall(tcpmsg, _timeout(expiration, sent_time))
return (len(tcpmsg), sent_time) return (len(tcpmsg), sent_time)
@ -227,18 +310,24 @@ async def _read_exactly(sock, count, expiration):
"""Read the specified number of bytes from stream. Keep trying until we """Read the specified number of bytes from stream. Keep trying until we
either get the desired amount, or we hit EOF. either get the desired amount, or we hit EOF.
""" """
s = b'' s = b""
while count > 0: while count > 0:
n = await sock.recv(count, _timeout(expiration)) n = await sock.recv(count, _timeout(expiration))
if n == b'': if n == b"":
raise EOFError raise EOFError
count = count - len(n) count = count - len(n)
s = s + n s = s + n
return s return s
async def receive_tcp(sock, expiration=None, one_rr_per_rrset=False, async def receive_tcp(
keyring=None, request_mac=b'', ignore_trailing=False): sock: dns.asyncbackend.StreamSocket,
expiration: Optional[float] = None,
one_rr_per_rrset: bool = False,
keyring: Optional[Dict[dns.name.Name, dns.tsig.Key]] = None,
request_mac: Optional[bytes] = b"",
ignore_trailing: bool = False,
) -> Tuple[dns.message.Message, float]:
"""Read a DNS message from a TCP socket. """Read a DNS message from a TCP socket.
*sock*, a ``dns.asyncbackend.StreamSocket``. *sock*, a ``dns.asyncbackend.StreamSocket``.
@ -251,15 +340,28 @@ async def receive_tcp(sock, expiration=None, one_rr_per_rrset=False,
(l,) = struct.unpack("!H", ldata) (l,) = struct.unpack("!H", ldata)
wire = await _read_exactly(sock, l, expiration) wire = await _read_exactly(sock, l, expiration)
received_time = time.time() received_time = time.time()
r = dns.message.from_wire(wire, keyring=keyring, request_mac=request_mac, r = dns.message.from_wire(
wire,
keyring=keyring,
request_mac=request_mac,
one_rr_per_rrset=one_rr_per_rrset, one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing) ignore_trailing=ignore_trailing,
)
return (r, received_time) return (r, received_time)
async def tcp(q, where, timeout=None, port=53, source=None, source_port=0, async def tcp(
one_rr_per_rrset=False, ignore_trailing=False, sock=None, q: dns.message.Message,
backend=None): where: str,
timeout: Optional[float] = None,
port: int = 53,
source: Optional[str] = None,
source_port: int = 0,
one_rr_per_rrset: bool = False,
ignore_trailing: bool = False,
sock: Optional[dns.asyncbackend.StreamSocket] = None,
backend: Optional[dns.asyncbackend.Backend] = None,
) -> dns.message.Message:
"""Return the response obtained after sending a query via TCP. """Return the response obtained after sending a query via TCP.
*sock*, a ``dns.asyncbacket.StreamSocket``, or ``None``, the *sock*, a ``dns.asyncbacket.StreamSocket``, or ``None``, the
@ -276,41 +378,48 @@ async def tcp(q, where, timeout=None, port=53, source=None, source_port=0,
wire = q.to_wire() wire = q.to_wire()
(begin_time, expiration) = _compute_times(timeout) (begin_time, expiration) = _compute_times(timeout)
s = None
# After 3.6 is no longer supported, this can use an AsyncExitStack.
try:
if sock: if sock:
# Verify that the socket is connected, as if it's not connected, # Verify that the socket is connected, as if it's not connected,
# it's not writable, and the polling in send_tcp() will time out or # it's not writable, and the polling in send_tcp() will time out or
# hang forever. # hang forever.
await sock.getpeername() await sock.getpeername()
s = sock cm: contextlib.AbstractAsyncContextManager = NullContext(sock)
else: else:
# These are simple (address, port) pairs, not # These are simple (address, port) pairs, not family-dependent tuples
# family-dependent tuples you pass to lowlevel socket # you pass to low-level socket code.
# code.
af = dns.inet.af_for_address(where) af = dns.inet.af_for_address(where)
stuple = _source_tuple(af, source, source_port) stuple = _source_tuple(af, source, source_port)
dtuple = (where, port) dtuple = (where, port)
if not backend: if not backend:
backend = dns.asyncbackend.get_default_backend() backend = dns.asyncbackend.get_default_backend()
s = await backend.make_socket(af, socket.SOCK_STREAM, 0, stuple, cm = await backend.make_socket(
dtuple, timeout) af, socket.SOCK_STREAM, 0, stuple, dtuple, timeout
)
async with cm as s:
await send_tcp(s, wire, expiration) await send_tcp(s, wire, expiration)
(r, received_time) = await receive_tcp(s, expiration, one_rr_per_rrset, (r, received_time) = await receive_tcp(
q.keyring, q.mac, s, expiration, one_rr_per_rrset, q.keyring, q.mac, ignore_trailing
ignore_trailing) )
r.time = received_time - begin_time r.time = received_time - begin_time
if not q.is_response(r): if not q.is_response(r):
raise BadResponse raise BadResponse
return r return r
finally:
if not sock and s:
await s.close()
async def tls(q, where, timeout=None, port=853, source=None, source_port=0,
one_rr_per_rrset=False, ignore_trailing=False, sock=None, async def tls(
backend=None, ssl_context=None, server_hostname=None): q: dns.message.Message,
where: str,
timeout: Optional[float] = None,
port: int = 853,
source: Optional[str] = None,
source_port: int = 0,
one_rr_per_rrset: bool = False,
ignore_trailing: bool = False,
sock: Optional[dns.asyncbackend.StreamSocket] = None,
backend: Optional[dns.asyncbackend.Backend] = None,
ssl_context: Optional[ssl.SSLContext] = None,
server_hostname: Optional[str] = None,
) -> dns.message.Message:
"""Return the response obtained after sending a query via TLS. """Return the response obtained after sending a query via TLS.
*sock*, an ``asyncbackend.StreamSocket``, or ``None``, the socket *sock*, an ``asyncbackend.StreamSocket``, or ``None``, the socket
@ -326,11 +435,14 @@ async def tls(q, where, timeout=None, port=853, source=None, source_port=0,
See :py:func:`dns.query.tls()` for the documentation of the other See :py:func:`dns.query.tls()` for the documentation of the other
parameters, exceptions, and return type of this method. parameters, exceptions, and return type of this method.
""" """
# After 3.6 is no longer supported, this can use an AsyncExitStack.
(begin_time, expiration) = _compute_times(timeout) (begin_time, expiration) = _compute_times(timeout)
if not sock: if sock:
cm: contextlib.AbstractAsyncContextManager = NullContext(sock)
else:
if ssl_context is None: if ssl_context is None:
ssl_context = ssl.create_default_context() # See the comment about ssl.create_default_context() in query.py
ssl_context = ssl.create_default_context() # lgtm[py/insecure-protocol]
ssl_context.minimum_version = ssl.TLSVersion.TLSv1_2
if server_hostname is None: if server_hostname is None:
ssl_context.check_hostname = False ssl_context.check_hostname = False
else: else:
@ -341,25 +453,49 @@ async def tls(q, where, timeout=None, port=853, source=None, source_port=0,
dtuple = (where, port) dtuple = (where, port)
if not backend: if not backend:
backend = dns.asyncbackend.get_default_backend() backend = dns.asyncbackend.get_default_backend()
s = await backend.make_socket(af, socket.SOCK_STREAM, 0, stuple, cm = await backend.make_socket(
dtuple, timeout, ssl_context, af,
server_hostname) socket.SOCK_STREAM,
else: 0,
s = sock stuple,
try: dtuple,
timeout,
ssl_context,
server_hostname,
)
async with cm as s:
timeout = _timeout(expiration) timeout = _timeout(expiration)
response = await tcp(q, where, timeout, port, source, source_port, response = await tcp(
one_rr_per_rrset, ignore_trailing, s, backend) q,
where,
timeout,
port,
source,
source_port,
one_rr_per_rrset,
ignore_trailing,
s,
backend,
)
end_time = time.time() end_time = time.time()
response.time = end_time - begin_time response.time = end_time - begin_time
return response return response
finally:
if not sock and s:
await s.close()
async def https(q, where, timeout=None, port=443, source=None, source_port=0,
one_rr_per_rrset=False, ignore_trailing=False, client=None, async def https(
path='/dns-query', post=True, verify=True): q: dns.message.Message,
where: str,
timeout: Optional[float] = None,
port: int = 443,
source: Optional[str] = None,
source_port: int = 0, # pylint: disable=W0613
one_rr_per_rrset: bool = False,
ignore_trailing: bool = False,
client: Optional["httpx.AsyncClient"] = None,
path: str = "/dns-query",
post: bool = True,
verify: Union[bool, str] = True,
) -> dns.message.Message:
"""Return the response obtained after sending a query via DNS-over-HTTPS. """Return the response obtained after sending a query via DNS-over-HTTPS.
*client*, a ``httpx.AsyncClient``. If provided, the client to use for *client*, a ``httpx.AsyncClient``. If provided, the client to use for
@ -373,7 +509,7 @@ async def https(q, where, timeout=None, port=443, source=None, source_port=0,
""" """
if not _have_httpx: if not _have_httpx:
raise NoDOH('httpx is not available.') # pragma: no cover raise NoDOH("httpx is not available.") # pragma: no cover
wire = q.to_wire() wire = q.to_wire()
try: try:
@ -381,65 +517,78 @@ async def https(q, where, timeout=None, port=443, source=None, source_port=0,
except ValueError: except ValueError:
af = None af = None
transport = None transport = None
headers = { headers = {"accept": "application/dns-message"}
"accept": "application/dns-message"
}
if af is not None: if af is not None:
if af == socket.AF_INET: if af == socket.AF_INET:
url = 'https://{}:{}{}'.format(where, port, path) url = "https://{}:{}{}".format(where, port, path)
elif af == socket.AF_INET6: elif af == socket.AF_INET6:
url = 'https://[{}]:{}{}'.format(where, port, path) url = "https://[{}]:{}{}".format(where, port, path)
else: else:
url = where url = where
if source is not None: if source is not None:
transport = httpx.AsyncHTTPTransport(local_address=source[0]) transport = httpx.AsyncHTTPTransport(local_address=source[0])
# After 3.6 is no longer supported, this can use an AsyncExitStack if client:
client_to_close = None cm: contextlib.AbstractAsyncContextManager = NullContext(client)
try: else:
if not client: cm = httpx.AsyncClient(
client = httpx.AsyncClient(http1=True, http2=_have_http2, http1=True, http2=_have_http2, verify=verify, transport=transport
verify=verify, transport=transport) )
client_to_close = client
async with cm as the_client:
# see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH # see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH
# GET and POST examples # GET and POST examples
if post: if post:
headers.update({ headers.update(
{
"content-type": "application/dns-message", "content-type": "application/dns-message",
"content-length": str(len(wire)) "content-length": str(len(wire)),
}) }
response = await client.post(url, headers=headers, content=wire, )
timeout=timeout) response = await the_client.post(
url, headers=headers, content=wire, timeout=timeout
)
else: else:
wire = base64.urlsafe_b64encode(wire).rstrip(b"=") wire = base64.urlsafe_b64encode(wire).rstrip(b"=")
wire = wire.decode() # httpx does a repr() if we give it bytes twire = wire.decode() # httpx does a repr() if we give it bytes
response = await client.get(url, headers=headers, timeout=timeout, response = await the_client.get(
params={"dns": wire}) url, headers=headers, timeout=timeout, params={"dns": twire}
finally: )
if client_to_close:
await client.aclose()
# see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH # see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH
# status codes # status codes
if response.status_code < 200 or response.status_code > 299: if response.status_code < 200 or response.status_code > 299:
raise ValueError('{} responded with status code {}' raise ValueError(
'\nResponse body: {}'.format(where, "{} responded with status code {}"
response.status_code, "\nResponse body: {!r}".format(
response.content)) where, response.status_code, response.content
r = dns.message.from_wire(response.content, )
)
r = dns.message.from_wire(
response.content,
keyring=q.keyring, keyring=q.keyring,
request_mac=q.request_mac, request_mac=q.request_mac,
one_rr_per_rrset=one_rr_per_rrset, one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing) ignore_trailing=ignore_trailing,
r.time = response.elapsed )
r.time = response.elapsed.total_seconds()
if not q.is_response(r): if not q.is_response(r):
raise BadResponse raise BadResponse
return r return r
async def inbound_xfr(where, txn_manager, query=None,
port=53, timeout=None, lifetime=None, source=None, async def inbound_xfr(
source_port=0, udp_mode=UDPMode.NEVER, backend=None): where: str,
txn_manager: dns.transaction.TransactionManager,
query: Optional[dns.message.Message] = None,
port: int = 53,
timeout: Optional[float] = None,
lifetime: Optional[float] = None,
source: Optional[str] = None,
source_port: int = 0,
udp_mode: UDPMode = UDPMode.NEVER,
backend: Optional[dns.asyncbackend.Backend] = None,
) -> None:
"""Conduct an inbound transfer and apply it via a transaction from the """Conduct an inbound transfer and apply it via a transaction from the
txn_manager. txn_manager.
@ -472,42 +621,48 @@ async def inbound_xfr(where, txn_manager, query=None,
is_udp = False is_udp = False
if not backend: if not backend:
backend = dns.asyncbackend.get_default_backend() backend = dns.asyncbackend.get_default_backend()
s = await backend.make_socket(af, sock_type, 0, stuple, dtuple, s = await backend.make_socket(
_timeout(expiration)) af, sock_type, 0, stuple, dtuple, _timeout(expiration)
)
async with s: async with s:
if is_udp: if is_udp:
await s.sendto(wire, dtuple, _timeout(expiration)) await s.sendto(wire, dtuple, _timeout(expiration))
else: else:
tcpmsg = struct.pack("!H", len(wire)) + wire tcpmsg = struct.pack("!H", len(wire)) + wire
await s.sendall(tcpmsg, expiration) await s.sendall(tcpmsg, expiration)
with dns.xfr.Inbound(txn_manager, rdtype, serial, with dns.xfr.Inbound(txn_manager, rdtype, serial, is_udp) as inbound:
is_udp) as inbound:
done = False done = False
tsig_ctx = None tsig_ctx = None
while not done: while not done:
(_, mexpiration) = _compute_times(timeout) (_, mexpiration) = _compute_times(timeout)
if mexpiration is None or \ if mexpiration is None or (
(expiration is not None and mexpiration > expiration): expiration is not None and mexpiration > expiration
):
mexpiration = expiration mexpiration = expiration
if is_udp: if is_udp:
destination = _lltuple((where, port), af) destination = _lltuple((where, port), af)
while True: while True:
timeout = _timeout(mexpiration) timeout = _timeout(mexpiration)
(rwire, from_address) = await s.recvfrom(65535, (rwire, from_address) = await s.recvfrom(65535, timeout)
timeout) if _matches_destination(
if _matches_destination(af, from_address, af, from_address, destination, True
destination, True): ):
break break
else: else:
ldata = await _read_exactly(s, 2, mexpiration) ldata = await _read_exactly(s, 2, mexpiration)
(l,) = struct.unpack("!H", ldata) (l,) = struct.unpack("!H", ldata)
rwire = await _read_exactly(s, l, mexpiration) rwire = await _read_exactly(s, l, mexpiration)
is_ixfr = (rdtype == dns.rdatatype.IXFR) is_ixfr = rdtype == dns.rdatatype.IXFR
r = dns.message.from_wire(rwire, keyring=query.keyring, r = dns.message.from_wire(
request_mac=query.mac, xfr=True, rwire,
origin=origin, tsig_ctx=tsig_ctx, keyring=query.keyring,
request_mac=query.mac,
xfr=True,
origin=origin,
tsig_ctx=tsig_ctx,
multi=(not is_udp), multi=(not is_udp),
one_rr_per_rrset=is_ixfr) one_rr_per_rrset=is_ixfr,
)
try: try:
done = inbound.process_message(r) done = inbound.process_message(r)
except dns.xfr.UseTCP: except dns.xfr.UseTCP:
@ -521,3 +676,62 @@ async def inbound_xfr(where, txn_manager, query=None,
tsig_ctx = r.tsig_ctx tsig_ctx = r.tsig_ctx
if not retry and query.keyring and not r.had_tsig: if not retry and query.keyring and not r.had_tsig:
raise dns.exception.FormError("missing TSIG") raise dns.exception.FormError("missing TSIG")
async def quic(
q: dns.message.Message,
where: str,
timeout: Optional[float] = None,
port: int = 853,
source: Optional[str] = None,
source_port: int = 0,
one_rr_per_rrset: bool = False,
ignore_trailing: bool = False,
connection: Optional[dns.quic.AsyncQuicConnection] = None,
verify: Union[bool, str] = True,
backend: Optional[dns.asyncbackend.Backend] = None,
) -> dns.message.Message:
"""Return the response obtained after sending an asynchronous query via
DNS-over-QUIC.
*backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``,
the default, then dnspython will use the default backend.
See :py:func:`dns.query.quic()` for the documentation of the other
parameters, exceptions, and return type of this method.
"""
if not dns.quic.have_quic:
raise NoDOQ("DNS-over-QUIC is not available.") # pragma: no cover
q.id = 0
wire = q.to_wire()
the_connection: dns.quic.AsyncQuicConnection
if connection:
cfactory = dns.quic.null_factory
mfactory = dns.quic.null_factory
the_connection = connection
else:
(cfactory, mfactory) = dns.quic.factories_for_backend(backend)
async with cfactory() as context:
async with mfactory(context, verify_mode=verify) as the_manager:
if not connection:
the_connection = the_manager.connect(where, port, source, source_port)
start = time.time()
stream = await the_connection.make_stream()
async with stream:
await stream.send(wire, True)
wire = await stream.receive(timeout)
finish = time.time()
r = dns.message.from_wire(
wire,
keyring=q.keyring,
request_mac=q.request_mac,
one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing,
)
r.time = max(finish - start, 0.0)
if not q.is_response(r):
raise BadResponse
return r

View file

@ -1,43 +0,0 @@
from typing import Optional, Union, Dict, Generator, Any
from . import tsig, rdatatype, rdataclass, name, message, asyncbackend
# If the ssl import works, then
#
# error: Name 'ssl' already defined (by an import)
#
# is expected and can be ignored.
try:
import ssl
except ImportError:
class ssl: # type: ignore
SSLContext : Dict = {}
async def udp(q : message.Message, where : str,
timeout : Optional[float] = None, port=53,
source : Optional[str] = None, source_port : Optional[int] = 0,
ignore_unexpected : Optional[bool] = False,
one_rr_per_rrset : Optional[bool] = False,
ignore_trailing : Optional[bool] = False,
sock : Optional[asyncbackend.DatagramSocket] = None,
backend : Optional[asyncbackend.Backend] = None) -> message.Message:
pass
async def tcp(q : message.Message, where : str, timeout : float = None, port=53,
af : Optional[int] = None, source : Optional[str] = None,
source_port : Optional[int] = 0,
one_rr_per_rrset : Optional[bool] = False,
ignore_trailing : Optional[bool] = False,
sock : Optional[asyncbackend.StreamSocket] = None,
backend : Optional[asyncbackend.Backend] = None) -> message.Message:
pass
async def tls(q : message.Message, where : str,
timeout : Optional[float] = None, port=53,
source : Optional[str] = None, source_port : Optional[int] = 0,
one_rr_per_rrset : Optional[bool] = False,
ignore_trailing : Optional[bool] = False,
sock : Optional[asyncbackend.StreamSocket] = None,
backend : Optional[asyncbackend.Backend] = None,
ssl_context: Optional[ssl.SSLContext] = None,
server_hostname: Optional[str] = None) -> message.Message:
pass

View file

@ -17,13 +17,18 @@
"""Asynchronous DNS stub resolver.""" """Asynchronous DNS stub resolver."""
from typing import Any, Dict, Optional, Union
import time import time
import dns.asyncbackend import dns.asyncbackend
import dns.asyncquery import dns.asyncquery
import dns.exception import dns.exception
import dns.name
import dns.query import dns.query
import dns.resolver import dns.rdataclass
import dns.rdatatype
import dns.resolver # lgtm[py/import-and-import-from]
# import some resolver symbols for brevity # import some resolver symbols for brevity
from dns.resolver import NXDOMAIN, NoAnswer, NotAbsolute, NoRootSOA from dns.resolver import NXDOMAIN, NoAnswer, NotAbsolute, NoRootSOA
@ -37,11 +42,19 @@ _tcp = dns.asyncquery.tcp
class Resolver(dns.resolver.BaseResolver): class Resolver(dns.resolver.BaseResolver):
"""Asynchronous DNS stub resolver.""" """Asynchronous DNS stub resolver."""
async def resolve(self, qname, rdtype=dns.rdatatype.A, async def resolve(
rdclass=dns.rdataclass.IN, self,
tcp=False, source=None, raise_on_no_answer=True, qname: Union[dns.name.Name, str],
source_port=0, lifetime=None, search=None, rdtype: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.A,
backend=None): rdclass: Union[dns.rdataclass.RdataClass, str] = dns.rdataclass.IN,
tcp: bool = False,
source: Optional[str] = None,
raise_on_no_answer: bool = True,
source_port: int = 0,
lifetime: Optional[float] = None,
search: Optional[bool] = None,
backend: Optional[dns.asyncbackend.Backend] = None,
) -> dns.resolver.Answer:
"""Query nameservers asynchronously to find the answer to the question. """Query nameservers asynchronously to find the answer to the question.
*backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``, *backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``,
@ -52,8 +65,9 @@ class Resolver(dns.resolver.BaseResolver):
type of this method. type of this method.
""" """
resolution = dns.resolver._Resolution(self, qname, rdtype, rdclass, tcp, resolution = dns.resolver._Resolution(
raise_on_no_answer, search) self, qname, rdtype, rdclass, tcp, raise_on_no_answer, search
)
if not backend: if not backend:
backend = dns.asyncbackend.get_default_backend() backend = dns.asyncbackend.get_default_backend()
start = time.time() start = time.time()
@ -66,30 +80,40 @@ class Resolver(dns.resolver.BaseResolver):
if answer is not None: if answer is not None:
# cache hit! # cache hit!
return answer return answer
assert request is not None # needed for type checking
done = False done = False
while not done: while not done:
(nameserver, port, tcp, backoff) = resolution.next_nameserver() (nameserver, port, tcp, backoff) = resolution.next_nameserver()
if backoff: if backoff:
await backend.sleep(backoff) await backend.sleep(backoff)
timeout = self._compute_timeout(start, lifetime, timeout = self._compute_timeout(start, lifetime, resolution.errors)
resolution.errors)
try: try:
if dns.inet.is_address(nameserver): if dns.inet.is_address(nameserver):
if tcp: if tcp:
response = await _tcp(request, nameserver, response = await _tcp(
timeout, port, request,
source, source_port,
backend=backend)
else:
response = await _udp(request, nameserver,
timeout, port,
source, source_port,
raise_on_truncation=True,
backend=backend)
else:
response = await dns.asyncquery.https(request,
nameserver, nameserver,
timeout=timeout) timeout,
port,
source,
source_port,
backend=backend,
)
else:
response = await _udp(
request,
nameserver,
timeout,
port,
source,
source_port,
raise_on_truncation=True,
backend=backend,
)
else:
response = await dns.asyncquery.https(
request, nameserver, timeout=timeout
)
except Exception as ex: except Exception as ex:
(_, done) = resolution.query_result(None, ex) (_, done) = resolution.query_result(None, ex)
continue continue
@ -101,7 +125,9 @@ class Resolver(dns.resolver.BaseResolver):
if answer is not None: if answer is not None:
return answer return answer
async def resolve_address(self, ipaddr, *args, **kwargs): async def resolve_address(
self, ipaddr: str, *args: Any, **kwargs: Any
) -> dns.resolver.Answer:
"""Use an asynchronous resolver to run a reverse query for PTR """Use an asynchronous resolver to run a reverse query for PTR
records. records.
@ -116,15 +142,20 @@ class Resolver(dns.resolver.BaseResolver):
function. function.
""" """
# We make a modified kwargs for type checking happiness, as otherwise
return await self.resolve(dns.reversename.from_address(ipaddr), # we get a legit warning about possibly having rdtype and rdclass
rdtype=dns.rdatatype.PTR, # in the kwargs more than once.
rdclass=dns.rdataclass.IN, modified_kwargs: Dict[str, Any] = {}
*args, **kwargs) modified_kwargs.update(kwargs)
modified_kwargs["rdtype"] = dns.rdatatype.PTR
modified_kwargs["rdclass"] = dns.rdataclass.IN
return await self.resolve(
dns.reversename.from_address(ipaddr), *args, **modified_kwargs
)
# pylint: disable=redefined-outer-name # pylint: disable=redefined-outer-name
async def canonical_name(self, name): async def canonical_name(self, name: Union[dns.name.Name, str]) -> dns.name.Name:
"""Determine the canonical name of *name*. """Determine the canonical name of *name*.
The canonical name is the name the resolver uses for queries The canonical name is the name the resolver uses for queries
@ -149,14 +180,15 @@ class Resolver(dns.resolver.BaseResolver):
default_resolver = None default_resolver = None
def get_default_resolver(): def get_default_resolver() -> Resolver:
"""Get the default asynchronous resolver, initializing it if necessary.""" """Get the default asynchronous resolver, initializing it if necessary."""
if default_resolver is None: if default_resolver is None:
reset_default_resolver() reset_default_resolver()
assert default_resolver is not None
return default_resolver return default_resolver
def reset_default_resolver(): def reset_default_resolver() -> None:
"""Re-initialize default asynchronous resolver. """Re-initialize default asynchronous resolver.
Note that the resolver configuration (i.e. /etc/resolv.conf on UNIX Note that the resolver configuration (i.e. /etc/resolv.conf on UNIX
@ -167,9 +199,18 @@ def reset_default_resolver():
default_resolver = Resolver() default_resolver = Resolver()
async def resolve(qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN, async def resolve(
tcp=False, source=None, raise_on_no_answer=True, qname: Union[dns.name.Name, str],
source_port=0, lifetime=None, search=None, backend=None): rdtype: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.A,
rdclass: Union[dns.rdataclass.RdataClass, str] = dns.rdataclass.IN,
tcp: bool = False,
source: Optional[str] = None,
raise_on_no_answer: bool = True,
source_port: int = 0,
lifetime: Optional[float] = None,
search: Optional[bool] = None,
backend: Optional[dns.asyncbackend.Backend] = None,
) -> dns.resolver.Answer:
"""Query nameservers asynchronously to find the answer to the question. """Query nameservers asynchronously to find the answer to the question.
This is a convenience function that uses the default resolver This is a convenience function that uses the default resolver
@ -179,13 +220,23 @@ async def resolve(qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN,
information on the parameters. information on the parameters.
""" """
return await get_default_resolver().resolve(qname, rdtype, rdclass, tcp, return await get_default_resolver().resolve(
source, raise_on_no_answer, qname,
source_port, lifetime, search, rdtype,
backend) rdclass,
tcp,
source,
raise_on_no_answer,
source_port,
lifetime,
search,
backend,
)
async def resolve_address(ipaddr, *args, **kwargs): async def resolve_address(
ipaddr: str, *args: Any, **kwargs: Any
) -> dns.resolver.Answer:
"""Use a resolver to run a reverse query for PTR records. """Use a resolver to run a reverse query for PTR records.
See :py:func:`dns.asyncresolver.Resolver.resolve_address` for more See :py:func:`dns.asyncresolver.Resolver.resolve_address` for more
@ -194,7 +245,8 @@ async def resolve_address(ipaddr, *args, **kwargs):
return await get_default_resolver().resolve_address(ipaddr, *args, **kwargs) return await get_default_resolver().resolve_address(ipaddr, *args, **kwargs)
async def canonical_name(name):
async def canonical_name(name: Union[dns.name.Name, str]) -> dns.name.Name:
"""Determine the canonical name of *name*. """Determine the canonical name of *name*.
See :py:func:`dns.resolver.Resolver.canonical_name` for more See :py:func:`dns.resolver.Resolver.canonical_name` for more
@ -203,8 +255,14 @@ async def canonical_name(name):
return await get_default_resolver().canonical_name(name) return await get_default_resolver().canonical_name(name)
async def zone_for_name(name, rdclass=dns.rdataclass.IN, tcp=False,
resolver=None, backend=None): async def zone_for_name(
name: Union[dns.name.Name, str],
rdclass: dns.rdataclass.RdataClass = dns.rdataclass.IN,
tcp: bool = False,
resolver: Optional[Resolver] = None,
backend: Optional[dns.asyncbackend.Backend] = None,
) -> dns.name.Name:
"""Find the name of the zone which contains the specified name. """Find the name of the zone which contains the specified name.
See :py:func:`dns.resolver.Resolver.zone_for_name` for more See :py:func:`dns.resolver.Resolver.zone_for_name` for more
@ -219,8 +277,10 @@ async def zone_for_name(name, rdclass=dns.rdataclass.IN, tcp=False,
raise NotAbsolute(name) raise NotAbsolute(name)
while True: while True:
try: try:
answer = await resolver.resolve(name, dns.rdatatype.SOA, rdclass, answer = await resolver.resolve(
tcp, backend=backend) name, dns.rdatatype.SOA, rdclass, tcp, backend=backend
)
assert answer.rrset is not None
if answer.rrset.name == name: if answer.rrset.name == name:
return name return name
# otherwise we were CNAMEd or DNAMEd and need to look higher # otherwise we were CNAMEd or DNAMEd and need to look higher

View file

@ -1,26 +0,0 @@
from typing import Union, Optional, List, Any, Dict
from . import exception, rdataclass, name, rdatatype, asyncbackend
async def resolve(qname : str, rdtype : Union[int,str] = 0,
rdclass : Union[int,str] = 0,
tcp=False, source=None, raise_on_no_answer=True,
source_port=0, lifetime : Optional[float]=None,
search : Optional[bool]=None,
backend : Optional[asyncbackend.Backend]=None):
...
async def resolve_address(self, ipaddr: str,
*args: Any, **kwargs: Optional[Dict]):
...
class Resolver:
def __init__(self, filename : Optional[str] = '/etc/resolv.conf',
configure : Optional[bool] = True):
self.nameservers : List[str]
async def resolve(self, qname : str, rdtype : Union[int,str] = rdatatype.A,
rdclass : Union[int,str] = rdataclass.IN,
tcp : bool = False, source : Optional[str] = None,
raise_on_no_answer=True, source_port : int = 0,
lifetime : Optional[float]=None,
search : Optional[bool]=None,
backend : Optional[asyncbackend.Backend]=None):
...

File diff suppressed because it is too large Load diff

View file

@ -1,21 +0,0 @@
from typing import Union, Dict, Tuple, Optional
from . import rdataset, rrset, exception, name, rdtypes, rdata, node
import dns.rdtypes.ANY.DS as DS
import dns.rdtypes.ANY.DNSKEY as DNSKEY
_have_pyca : bool
def validate_rrsig(rrset : Union[Tuple[name.Name, rdataset.Rdataset], rrset.RRset], rrsig : rdata.Rdata, keys : Dict[name.Name, Union[node.Node, rdataset.Rdataset]], origin : Optional[name.Name] = None, now : Optional[int] = None) -> None:
...
def validate(rrset: Union[Tuple[name.Name, rdataset.Rdataset], rrset.RRset], rrsigset : Union[Tuple[name.Name, rdataset.Rdataset], rrset.RRset], keys : Dict[name.Name, Union[node.Node, rdataset.Rdataset]], origin=None, now=None) -> None:
...
class ValidationFailure(exception.DNSException):
...
def make_ds(name : name.Name, key : DNSKEY.DNSKEY, algorithm : str, origin : Optional[name.Name] = None) -> DS.DS:
...
def nsec3_hash(domain: str, salt: Optional[Union[str, bytes]], iterations: int, algo: int) -> str:
...

71
lib/dns/dnssectypes.py Normal file
View file

@ -0,0 +1,71 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# Copyright (C) 2003-2017 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
"""Common DNSSEC-related types."""
# This is a separate file to avoid import circularity between dns.dnssec and
# the implementations of the DS and DNSKEY types.
import dns.enum
class Algorithm(dns.enum.IntEnum):
RSAMD5 = 1
DH = 2
DSA = 3
ECC = 4
RSASHA1 = 5
DSANSEC3SHA1 = 6
RSASHA1NSEC3SHA1 = 7
RSASHA256 = 8
RSASHA512 = 10
ECCGOST = 12
ECDSAP256SHA256 = 13
ECDSAP384SHA384 = 14
ED25519 = 15
ED448 = 16
INDIRECT = 252
PRIVATEDNS = 253
PRIVATEOID = 254
@classmethod
def _maximum(cls):
return 255
class DSDigest(dns.enum.IntEnum):
"""DNSSEC Delegation Signer Digest Algorithm"""
NULL = 0
SHA1 = 1
SHA256 = 2
GOST = 3
SHA384 = 4
@classmethod
def _maximum(cls):
return 255
class NSEC3Hash(dns.enum.IntEnum):
"""NSEC3 hash algorithm"""
SHA1 = 1
@classmethod
def _maximum(cls):
return 255

View file

@ -17,15 +17,19 @@
"""DNS E.164 helpers.""" """DNS E.164 helpers."""
from typing import Iterable, Optional, Union
import dns.exception import dns.exception
import dns.name import dns.name
import dns.resolver import dns.resolver
#: The public E.164 domain. #: The public E.164 domain.
public_enum_domain = dns.name.from_text('e164.arpa.') public_enum_domain = dns.name.from_text("e164.arpa.")
def from_e164(text, origin=public_enum_domain): def from_e164(
text: str, origin: Optional[dns.name.Name] = public_enum_domain
) -> dns.name.Name:
"""Convert an E.164 number in textual form into a Name object whose """Convert an E.164 number in textual form into a Name object whose
value is the ENUM domain name for that number. value is the ENUM domain name for that number.
@ -42,10 +46,14 @@ def from_e164(text, origin=public_enum_domain):
parts = [d for d in text if d.isdigit()] parts = [d for d in text if d.isdigit()]
parts.reverse() parts.reverse()
return dns.name.from_text('.'.join(parts), origin=origin) return dns.name.from_text(".".join(parts), origin=origin)
def to_e164(name, origin=public_enum_domain, want_plus_prefix=True): def to_e164(
name: dns.name.Name,
origin: Optional[dns.name.Name] = public_enum_domain,
want_plus_prefix: bool = True,
) -> str:
"""Convert an ENUM domain name into an E.164 number. """Convert an ENUM domain name into an E.164 number.
Note that dnspython does not have any information about preferred Note that dnspython does not have any information about preferred
@ -69,15 +77,19 @@ def to_e164(name, origin=public_enum_domain, want_plus_prefix=True):
name = name.relativize(origin) name = name.relativize(origin)
dlabels = [d for d in name.labels if d.isdigit() and len(d) == 1] dlabels = [d for d in name.labels if d.isdigit() and len(d) == 1]
if len(dlabels) != len(name.labels): if len(dlabels) != len(name.labels):
raise dns.exception.SyntaxError('non-digit labels in ENUM domain name') raise dns.exception.SyntaxError("non-digit labels in ENUM domain name")
dlabels.reverse() dlabels.reverse()
text = b''.join(dlabels) text = b"".join(dlabels)
if want_plus_prefix: if want_plus_prefix:
text = b'+' + text text = b"+" + text
return text.decode() return text.decode()
def query(number, domains, resolver=None): def query(
number: str,
domains: Iterable[Union[dns.name.Name, str]],
resolver: Optional[dns.resolver.Resolver] = None,
) -> dns.resolver.Answer:
"""Look for NAPTR RRs for the specified number in the specified domains. """Look for NAPTR RRs for the specified number in the specified domains.
e.g. lookup('16505551212', ['e164.dnspython.org.', 'e164.arpa.']) e.g. lookup('16505551212', ['e164.dnspython.org.', 'e164.arpa.'])
@ -98,7 +110,7 @@ def query(number, domains, resolver=None):
domain = dns.name.from_text(domain) domain = dns.name.from_text(domain)
qname = dns.e164.from_e164(number, domain) qname = dns.e164.from_e164(number, domain)
try: try:
return resolver.resolve(qname, 'NAPTR') return resolver.resolve(qname, "NAPTR")
except dns.resolver.NXDOMAIN as e: except dns.resolver.NXDOMAIN as e:
e_nx += e e_nx += e
raise e_nx raise e_nx

View file

@ -1,10 +0,0 @@
from typing import Optional, Iterable
from . import name, resolver
def from_e164(text : str, origin=name.Name(".")) -> name.Name:
...
def to_e164(name : name.Name, origin : Optional[name.Name] = None, want_plus_prefix=True) -> str:
...
def query(number : str, domains : Iterable[str], resolver : Optional[resolver.Resolver] = None) -> resolver.Answer:
...

View file

@ -17,6 +17,8 @@
"""EDNS Options""" """EDNS Options"""
from typing import Any, Dict, Optional, Union
import math import math
import socket import socket
import struct import struct
@ -24,6 +26,7 @@ import struct
import dns.enum import dns.enum
import dns.inet import dns.inet
import dns.rdata import dns.rdata
import dns.wire
class OptionType(dns.enum.IntEnum): class OptionType(dns.enum.IntEnum):
@ -59,14 +62,14 @@ class Option:
"""Base class for all EDNS option types.""" """Base class for all EDNS option types."""
def __init__(self, otype): def __init__(self, otype: Union[OptionType, str]):
"""Initialize an option. """Initialize an option.
*otype*, an ``int``, is the option type. *otype*, a ``dns.edns.OptionType``, is the option type.
""" """
self.otype = OptionType.make(otype) self.otype = OptionType.make(otype)
def to_wire(self, file=None): def to_wire(self, file: Optional[Any] = None) -> Optional[bytes]:
"""Convert an option to wire format. """Convert an option to wire format.
Returns a ``bytes`` or ``None``. Returns a ``bytes`` or ``None``.
@ -75,10 +78,10 @@ class Option:
raise NotImplementedError # pragma: no cover raise NotImplementedError # pragma: no cover
@classmethod @classmethod
def from_wire_parser(cls, otype, parser): def from_wire_parser(cls, otype: OptionType, parser: "dns.wire.Parser") -> "Option":
"""Build an EDNS option object from wire format. """Build an EDNS option object from wire format.
*otype*, an ``int``, is the option type. *otype*, a ``dns.edns.OptionType``, is the option type.
*parser*, a ``dns.wire.Parser``, the parser, which should be *parser*, a ``dns.wire.Parser``, the parser, which should be
restructed to the option length. restructed to the option length.
@ -115,26 +118,22 @@ class Option:
return self._cmp(other) != 0 return self._cmp(other) != 0
def __lt__(self, other): def __lt__(self, other):
if not isinstance(other, Option) or \ if not isinstance(other, Option) or self.otype != other.otype:
self.otype != other.otype:
return NotImplemented return NotImplemented
return self._cmp(other) < 0 return self._cmp(other) < 0
def __le__(self, other): def __le__(self, other):
if not isinstance(other, Option) or \ if not isinstance(other, Option) or self.otype != other.otype:
self.otype != other.otype:
return NotImplemented return NotImplemented
return self._cmp(other) <= 0 return self._cmp(other) <= 0
def __ge__(self, other): def __ge__(self, other):
if not isinstance(other, Option) or \ if not isinstance(other, Option) or self.otype != other.otype:
self.otype != other.otype:
return NotImplemented return NotImplemented
return self._cmp(other) >= 0 return self._cmp(other) >= 0
def __gt__(self, other): def __gt__(self, other):
if not isinstance(other, Option) or \ if not isinstance(other, Option) or self.otype != other.otype:
self.otype != other.otype:
return NotImplemented return NotImplemented
return self._cmp(other) > 0 return self._cmp(other) > 0
@ -142,7 +141,7 @@ class Option:
return self.to_text() return self.to_text()
class GenericOption(Option): class GenericOption(Option): # lgtm[py/missing-equals]
"""Generic Option Class """Generic Option Class
@ -150,28 +149,31 @@ class GenericOption(Option):
implementation. implementation.
""" """
def __init__(self, otype, data): def __init__(self, otype: Union[OptionType, str], data: Union[bytes, str]):
super().__init__(otype) super().__init__(otype)
self.data = dns.rdata.Rdata._as_bytes(data, True) self.data = dns.rdata.Rdata._as_bytes(data, True)
def to_wire(self, file=None): def to_wire(self, file: Optional[Any] = None) -> Optional[bytes]:
if file: if file:
file.write(self.data) file.write(self.data)
return None
else: else:
return self.data return self.data
def to_text(self): def to_text(self) -> str:
return "Generic %d" % self.otype return "Generic %d" % self.otype
@classmethod @classmethod
def from_wire_parser(cls, otype, parser): def from_wire_parser(
cls, otype: Union[OptionType, str], parser: "dns.wire.Parser"
) -> Option:
return cls(otype, parser.get_remaining()) return cls(otype, parser.get_remaining())
class ECSOption(Option): class ECSOption(Option): # lgtm[py/missing-equals]
"""EDNS Client Subnet (ECS, RFC7871)""" """EDNS Client Subnet (ECS, RFC7871)"""
def __init__(self, address, srclen=None, scopelen=0): def __init__(self, address: str, srclen: Optional[int] = None, scopelen: int = 0):
"""*address*, a ``str``, is the client address information. """*address*, a ``str``, is the client address information.
*srclen*, an ``int``, the source prefix length, which is the *srclen*, an ``int``, the source prefix length, which is the
@ -200,8 +202,9 @@ class ECSOption(Option):
srclen = dns.rdata.Rdata._as_int(srclen, 0, 32) srclen = dns.rdata.Rdata._as_int(srclen, 0, 32)
scopelen = dns.rdata.Rdata._as_int(scopelen, 0, 32) scopelen = dns.rdata.Rdata._as_int(scopelen, 0, 32)
else: # pragma: no cover (this will never happen) else: # pragma: no cover (this will never happen)
raise ValueError('Bad address family') raise ValueError("Bad address family")
assert srclen is not None
self.address = address self.address = address
self.srclen = srclen self.srclen = srclen
self.scopelen = scopelen self.scopelen = scopelen
@ -214,16 +217,14 @@ class ECSOption(Option):
self.addrdata = addrdata[:nbytes] self.addrdata = addrdata[:nbytes]
nbits = srclen % 8 nbits = srclen % 8
if nbits != 0: if nbits != 0:
last = struct.pack('B', last = struct.pack("B", ord(self.addrdata[-1:]) & (0xFF << (8 - nbits)))
ord(self.addrdata[-1:]) & (0xff << (8 - nbits)))
self.addrdata = self.addrdata[:-1] + last self.addrdata = self.addrdata[:-1] + last
def to_text(self): def to_text(self) -> str:
return "ECS {}/{} scope/{}".format(self.address, self.srclen, return "ECS {}/{} scope/{}".format(self.address, self.srclen, self.scopelen)
self.scopelen)
@staticmethod @staticmethod
def from_text(text): def from_text(text: str) -> Option:
"""Convert a string into a `dns.edns.ECSOption` """Convert a string into a `dns.edns.ECSOption`
*text*, a `str`, the text form of the option. *text*, a `str`, the text form of the option.
@ -246,7 +247,7 @@ class ECSOption(Option):
>>> # it understands results from `dns.edns.ECSOption.to_text()` >>> # it understands results from `dns.edns.ECSOption.to_text()`
>>> dns.edns.ECSOption.from_text('ECS 1.2.3.4/24/32') >>> dns.edns.ECSOption.from_text('ECS 1.2.3.4/24/32')
""" """
optional_prefix = 'ECS' optional_prefix = "ECS"
tokens = text.split() tokens = text.split()
ecs_text = None ecs_text = None
if len(tokens) == 1: if len(tokens) == 1:
@ -257,47 +258,53 @@ class ECSOption(Option):
ecs_text = tokens[1] ecs_text = tokens[1]
else: else:
raise ValueError('could not parse ECS from "{}"'.format(text)) raise ValueError('could not parse ECS from "{}"'.format(text))
n_slashes = ecs_text.count('/') n_slashes = ecs_text.count("/")
if n_slashes == 1: if n_slashes == 1:
address, srclen = ecs_text.split('/') address, tsrclen = ecs_text.split("/")
scope = 0 tscope = "0"
elif n_slashes == 2: elif n_slashes == 2:
address, srclen, scope = ecs_text.split('/') address, tsrclen, tscope = ecs_text.split("/")
else: else:
raise ValueError('could not parse ECS from "{}"'.format(text)) raise ValueError('could not parse ECS from "{}"'.format(text))
try: try:
scope = int(scope) scope = int(tscope)
except ValueError: except ValueError:
raise ValueError('invalid scope ' + raise ValueError(
'"{}": scope must be an integer'.format(scope)) "invalid scope " + '"{}": scope must be an integer'.format(tscope)
)
try: try:
srclen = int(srclen) srclen = int(tsrclen)
except ValueError: except ValueError:
raise ValueError('invalid srclen ' + raise ValueError(
'"{}": srclen must be an integer'.format(srclen)) "invalid srclen " + '"{}": srclen must be an integer'.format(tsrclen)
)
return ECSOption(address, srclen, scope) return ECSOption(address, srclen, scope)
def to_wire(self, file=None): def to_wire(self, file: Optional[Any] = None) -> Optional[bytes]:
value = (struct.pack('!HBB', self.family, self.srclen, self.scopelen) + value = (
self.addrdata) struct.pack("!HBB", self.family, self.srclen, self.scopelen) + self.addrdata
)
if file: if file:
file.write(value) file.write(value)
return None
else: else:
return value return value
@classmethod @classmethod
def from_wire_parser(cls, otype, parser): def from_wire_parser(
family, src, scope = parser.get_struct('!HBB') cls, otype: Union[OptionType, str], parser: "dns.wire.Parser"
) -> Option:
family, src, scope = parser.get_struct("!HBB")
addrlen = int(math.ceil(src / 8.0)) addrlen = int(math.ceil(src / 8.0))
prefix = parser.get_bytes(addrlen) prefix = parser.get_bytes(addrlen)
if family == 1: if family == 1:
pad = 4 - addrlen pad = 4 - addrlen
addr = dns.ipv4.inet_ntoa(prefix + b'\x00' * pad) addr = dns.ipv4.inet_ntoa(prefix + b"\x00" * pad)
elif family == 2: elif family == 2:
pad = 16 - addrlen pad = 16 - addrlen
addr = dns.ipv6.inet_ntoa(prefix + b'\x00' * pad) addr = dns.ipv6.inet_ntoa(prefix + b"\x00" * pad)
else: else:
raise ValueError('unsupported family') raise ValueError("unsupported family")
return cls(addr, src, scope) return cls(addr, src, scope)
@ -334,10 +341,10 @@ class EDECode(dns.enum.IntEnum):
return 65535 return 65535
class EDEOption(Option): class EDEOption(Option): # lgtm[py/missing-equals]
"""Extended DNS Error (EDE, RFC8914)""" """Extended DNS Error (EDE, RFC8914)"""
def __init__(self, code, text=None): def __init__(self, code: Union[EDECode, str], text: Optional[str] = None):
"""*code*, a ``dns.edns.EDECode`` or ``str``, the info code of the """*code*, a ``dns.edns.EDECode`` or ``str``, the info code of the
extended error. extended error.
@ -349,49 +356,50 @@ class EDEOption(Option):
self.code = EDECode.make(code) self.code = EDECode.make(code)
if text is not None and not isinstance(text, str): if text is not None and not isinstance(text, str):
raise ValueError('text must be string or None') raise ValueError("text must be string or None")
self.code = code
self.text = text self.text = text
def to_text(self): def to_text(self) -> str:
output = f'EDE {self.code}' output = f"EDE {self.code}"
if self.text is not None: if self.text is not None:
output += f': {self.text}' output += f": {self.text}"
return output return output
def to_wire(self, file=None): def to_wire(self, file: Optional[Any] = None) -> Optional[bytes]:
value = struct.pack('!H', self.code) value = struct.pack("!H", self.code)
if self.text is not None: if self.text is not None:
value += self.text.encode('utf8') value += self.text.encode("utf8")
if file: if file:
file.write(value) file.write(value)
return None
else: else:
return value return value
@classmethod @classmethod
def from_wire_parser(cls, otype, parser): def from_wire_parser(
code = parser.get_uint16() cls, otype: Union[OptionType, str], parser: "dns.wire.Parser"
) -> Option:
the_code = EDECode.make(parser.get_uint16())
text = parser.get_remaining() text = parser.get_remaining()
if text: if text:
if text[-1] == 0: # text MAY be null-terminated if text[-1] == 0: # text MAY be null-terminated
text = text[:-1] text = text[:-1]
text = text.decode('utf8') btext = text.decode("utf8")
else: else:
text = None btext = None
return cls(code, text) return cls(the_code, btext)
_type_to_class = { _type_to_class: Dict[OptionType, Any] = {
OptionType.ECS: ECSOption, OptionType.ECS: ECSOption,
OptionType.EDE: EDEOption, OptionType.EDE: EDEOption,
} }
def get_option_class(otype): def get_option_class(otype: OptionType) -> Any:
"""Return the class for the specified option type. """Return the class for the specified option type.
The GenericOption class is used if a more specific class is not The GenericOption class is used if a more specific class is not
@ -404,7 +412,9 @@ def get_option_class(otype):
return cls return cls
def option_from_wire_parser(otype, parser): def option_from_wire_parser(
otype: Union[OptionType, str], parser: "dns.wire.Parser"
) -> Option:
"""Build an EDNS option object from wire format. """Build an EDNS option object from wire format.
*otype*, an ``int``, is the option type. *otype*, an ``int``, is the option type.
@ -414,12 +424,14 @@ def option_from_wire_parser(otype, parser):
Returns an instance of a subclass of ``dns.edns.Option``. Returns an instance of a subclass of ``dns.edns.Option``.
""" """
cls = get_option_class(otype) the_otype = OptionType.make(otype)
otype = OptionType.make(otype) cls = get_option_class(the_otype)
return cls.from_wire_parser(otype, parser) return cls.from_wire_parser(otype, parser)
def option_from_wire(otype, wire, current, olen): def option_from_wire(
otype: Union[OptionType, str], wire: bytes, current: int, olen: int
) -> Option:
"""Build an EDNS option object from wire format. """Build an EDNS option object from wire format.
*otype*, an ``int``, is the option type. *otype*, an ``int``, is the option type.
@ -437,7 +449,8 @@ def option_from_wire(otype, wire, current, olen):
with parser.restrict_to(olen): with parser.restrict_to(olen):
return option_from_wire_parser(otype, parser) return option_from_wire_parser(otype, parser)
def register_type(implementation, otype):
def register_type(implementation: Any, otype: OptionType) -> None:
"""Register the implementation of an option type. """Register the implementation of an option type.
*implementation*, a ``class``, is a subclass of ``dns.edns.Option``. *implementation*, a ``class``, is a subclass of ``dns.edns.Option``.
@ -447,6 +460,7 @@ def register_type(implementation, otype):
_type_to_class[otype] = implementation _type_to_class[otype] = implementation
### BEGIN generated OptionType constants ### BEGIN generated OptionType constants
NSID = OptionType.NSID NSID = OptionType.NSID

View file

@ -15,14 +15,13 @@
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
from typing import Any, Optional
import os import os
import hashlib import hashlib
import random import random
import threading
import time import time
try:
import threading as _threading
except ImportError: # pragma: no cover
import dummy_threading as _threading # type: ignore
class EntropyPool: class EntropyPool:
@ -32,51 +31,51 @@ class EntropyPool:
# leaving this code doesn't hurt anything as the library code # leaving this code doesn't hurt anything as the library code
# is used if present. # is used if present.
def __init__(self, seed=None): def __init__(self, seed: Optional[bytes] = None):
self.pool_index = 0 self.pool_index = 0
self.digest = None self.digest: Optional[bytearray] = None
self.next_byte = 0 self.next_byte = 0
self.lock = _threading.Lock() self.lock = threading.Lock()
self.hash = hashlib.sha1() self.hash = hashlib.sha1()
self.hash_len = 20 self.hash_len = 20
self.pool = bytearray(b'\0' * self.hash_len) self.pool = bytearray(b"\0" * self.hash_len)
if seed is not None: if seed is not None:
self._stir(bytearray(seed)) self._stir(seed)
self.seeded = True self.seeded = True
self.seed_pid = os.getpid() self.seed_pid = os.getpid()
else: else:
self.seeded = False self.seeded = False
self.seed_pid = 0 self.seed_pid = 0
def _stir(self, entropy): def _stir(self, entropy: bytes) -> None:
for c in entropy: for c in entropy:
if self.pool_index == self.hash_len: if self.pool_index == self.hash_len:
self.pool_index = 0 self.pool_index = 0
b = c & 0xff b = c & 0xFF
self.pool[self.pool_index] ^= b self.pool[self.pool_index] ^= b
self.pool_index += 1 self.pool_index += 1
def stir(self, entropy): def stir(self, entropy: bytes) -> None:
with self.lock: with self.lock:
self._stir(entropy) self._stir(entropy)
def _maybe_seed(self): def _maybe_seed(self) -> None:
if not self.seeded or self.seed_pid != os.getpid(): if not self.seeded or self.seed_pid != os.getpid():
try: try:
seed = os.urandom(16) seed = os.urandom(16)
except Exception: # pragma: no cover except Exception: # pragma: no cover
try: try:
with open('/dev/urandom', 'rb', 0) as r: with open("/dev/urandom", "rb", 0) as r:
seed = r.read(16) seed = r.read(16)
except Exception: except Exception:
seed = str(time.time()) seed = str(time.time()).encode()
self.seeded = True self.seeded = True
self.seed_pid = os.getpid() self.seed_pid = os.getpid()
self.digest = None self.digest = None
seed = bytearray(seed) seed = bytearray(seed)
self._stir(seed) self._stir(seed)
def random_8(self): def random_8(self) -> int:
with self.lock: with self.lock:
self._maybe_seed() self._maybe_seed()
if self.digest is None or self.next_byte == self.hash_len: if self.digest is None or self.next_byte == self.hash_len:
@ -88,16 +87,16 @@ class EntropyPool:
self.next_byte += 1 self.next_byte += 1
return value return value
def random_16(self): def random_16(self) -> int:
return self.random_8() * 256 + self.random_8() return self.random_8() * 256 + self.random_8()
def random_32(self): def random_32(self) -> int:
return self.random_16() * 65536 + self.random_16() return self.random_16() * 65536 + self.random_16()
def random_between(self, first, last): def random_between(self, first: int, last: int) -> int:
size = last - first + 1 size = last - first + 1
if size > 4294967296: if size > 4294967296:
raise ValueError('too big') raise ValueError("too big")
if size > 65536: if size > 65536:
rand = self.random_32 rand = self.random_32
max = 4294967295 max = 4294967295
@ -109,20 +108,24 @@ class EntropyPool:
max = 255 max = 255
return first + size * rand() // (max + 1) return first + size * rand() // (max + 1)
pool = EntropyPool() pool = EntropyPool()
system_random: Optional[Any]
try: try:
system_random = random.SystemRandom() system_random = random.SystemRandom()
except Exception: # pragma: no cover except Exception: # pragma: no cover
system_random = None system_random = None
def random_16():
def random_16() -> int:
if system_random is not None: if system_random is not None:
return system_random.randrange(0, 65536) return system_random.randrange(0, 65536)
else: else:
return pool.random_16() return pool.random_16()
def between(first, last):
def between(first: int, last: int) -> int:
if system_random is not None: if system_random is not None:
return system_random.randrange(first, last + 1) return system_random.randrange(first, last + 1)
else: else:

View file

@ -1,10 +0,0 @@
from typing import Optional
from random import SystemRandom
system_random : Optional[SystemRandom]
def random_16() -> int:
pass
def between(first: int, last: int) -> int:
pass

View file

@ -17,6 +17,7 @@
import enum import enum
class IntEnum(enum.IntEnum): class IntEnum(enum.IntEnum):
@classmethod @classmethod
def _check_value(cls, value): def _check_value(cls, value):
@ -32,6 +33,9 @@ class IntEnum(enum.IntEnum):
return cls[text] return cls[text]
except KeyError: except KeyError:
pass pass
value = cls._extra_from_text(text)
if value:
return value
prefix = cls._prefix() prefix = cls._prefix()
if text.startswith(prefix) and text[len(prefix) :].isdigit(): if text.startswith(prefix) and text[len(prefix) :].isdigit():
value = int(text[len(prefix) :]) value = int(text[len(prefix) :])
@ -46,9 +50,13 @@ class IntEnum(enum.IntEnum):
def to_text(cls, value): def to_text(cls, value):
cls._check_value(value) cls._check_value(value)
try: try:
return cls(value).name text = cls(value).name
except ValueError: except ValueError:
return f"{cls._prefix()}{value}" text = None
text = cls._extra_to_text(value, text)
if text is None:
text = f"{cls._prefix()}{value}"
return text
@classmethod @classmethod
def make(cls, value): def make(cls, value):
@ -83,7 +91,15 @@ class IntEnum(enum.IntEnum):
@classmethod @classmethod
def _prefix(cls): def _prefix(cls):
return '' return ""
@classmethod
def _extra_from_text(cls, text): # pylint: disable=W0613
return None
@classmethod
def _extra_to_text(cls, value, current_text): # pylint: disable=W0613
return current_text
@classmethod @classmethod
def _unknown_exception_class(cls): def _unknown_exception_class(cls):

View file

@ -21,6 +21,10 @@ Dnspython modules may also define their own exceptions, which will
always be subclasses of ``DNSException``. always be subclasses of ``DNSException``.
""" """
from typing import Optional, Set
class DNSException(Exception): class DNSException(Exception):
"""Abstract base class shared by all dnspython exceptions. """Abstract base class shared by all dnspython exceptions.
@ -44,14 +48,15 @@ class DNSException(Exception):
and ``fmt`` class variables to get nice parametrized messages. and ``fmt`` class variables to get nice parametrized messages.
""" """
msg = None # non-parametrized message msg: Optional[str] = None # non-parametrized message
supp_kwargs = set() # accepted parameters for _fmt_kwargs (sanity check) supp_kwargs: Set[str] = set() # accepted parameters for _fmt_kwargs (sanity check)
fmt = None # message parametrized with results from _fmt_kwargs fmt: Optional[str] = None # message parametrized with results from _fmt_kwargs
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
self._check_params(*args, **kwargs) self._check_params(*args, **kwargs)
if kwargs: if kwargs:
self.kwargs = self._check_kwargs(**kwargs) # This call to a virtual method from __init__ is ok in our usage
self.kwargs = self._check_kwargs(**kwargs) # lgtm[py/init-calls-subclass]
self.msg = str(self) self.msg = str(self)
else: else:
self.kwargs = dict() # defined but empty for old mode exceptions self.kwargs = dict() # defined but empty for old mode exceptions
@ -68,14 +73,15 @@ class DNSException(Exception):
For sanity we do not allow to mix old and new behavior.""" For sanity we do not allow to mix old and new behavior."""
if args or kwargs: if args or kwargs:
assert bool(args) != bool(kwargs), \ assert bool(args) != bool(
'keyword arguments are mutually exclusive with positional args' kwargs
), "keyword arguments are mutually exclusive with positional args"
def _check_kwargs(self, **kwargs): def _check_kwargs(self, **kwargs):
if kwargs: if kwargs:
assert set(kwargs.keys()) == self.supp_kwargs, \ assert (
'following set of keyword args is required: %s' % ( set(kwargs.keys()) == self.supp_kwargs
self.supp_kwargs) ), "following set of keyword args is required: %s" % (self.supp_kwargs)
return kwargs return kwargs
def _fmt_kwargs(self, **kwargs): def _fmt_kwargs(self, **kwargs):
@ -124,9 +130,15 @@ class TooBig(DNSException):
class Timeout(DNSException): class Timeout(DNSException):
"""The DNS operation timed out.""" """The DNS operation timed out."""
supp_kwargs = {'timeout'}
supp_kwargs = {"timeout"}
fmt = "The DNS operation timed out after {timeout:.3f} seconds" fmt = "The DNS operation timed out after {timeout:.3f} seconds"
# We do this as otherwise mypy complains about unexpected keyword argument
# idna_exception
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
class ExceptionWrapper: class ExceptionWrapper:
def __init__(self, exception_class): def __init__(self, exception_class):
@ -136,7 +148,6 @@ class ExceptionWrapper:
return self return self
def __exit__(self, exc_type, exc_val, exc_tb): def __exit__(self, exc_type, exc_val, exc_tb):
if exc_type is not None and not isinstance(exc_val, if exc_type is not None and not isinstance(exc_val, self.exception_class):
self.exception_class):
raise self.exception_class(str(exc_val)) from exc_val raise self.exception_class(str(exc_val)) from exc_val
return False return False

View file

@ -1,12 +0,0 @@
from typing import Set, Optional, Dict
class DNSException(Exception):
supp_kwargs : Set[str]
kwargs : Optional[Dict]
fmt : Optional[str]
class SyntaxError(DNSException): ...
class FormError(DNSException): ...
class Timeout(DNSException): ...
class TooBig(DNSException): ...
class UnexpectedEnd(SyntaxError): ...

View file

@ -17,10 +17,13 @@
"""DNS Message Flags.""" """DNS Message Flags."""
from typing import Any
import enum import enum
# Standard DNS flags # Standard DNS flags
class Flag(enum.IntFlag): class Flag(enum.IntFlag):
#: Query Response #: Query Response
QR = 0x8000 QR = 0x8000
@ -40,12 +43,13 @@ class Flag(enum.IntFlag):
# EDNS flags # EDNS flags
class EDNSFlag(enum.IntFlag): class EDNSFlag(enum.IntFlag):
#: DNSSEC answer OK #: DNSSEC answer OK
DO = 0x8000 DO = 0x8000
def _from_text(text, enum_class): def _from_text(text: str, enum_class: Any) -> int:
flags = 0 flags = 0
tokens = text.split() tokens = text.split()
for t in tokens: for t in tokens:
@ -53,15 +57,15 @@ def _from_text(text, enum_class):
return flags return flags
def _to_text(flags, enum_class): def _to_text(flags: int, enum_class: Any) -> str:
text_flags = [] text_flags = []
for k, v in enum_class.__members__.items(): for k, v in enum_class.__members__.items():
if flags & v != 0: if flags & v != 0:
text_flags.append(k) text_flags.append(k)
return ' '.join(text_flags) return " ".join(text_flags)
def from_text(text): def from_text(text: str) -> int:
"""Convert a space-separated list of flag text values into a flags """Convert a space-separated list of flag text values into a flags
value. value.
@ -71,7 +75,7 @@ def from_text(text):
return _from_text(text, Flag) return _from_text(text, Flag)
def to_text(flags): def to_text(flags: int) -> str:
"""Convert a flags value into a space-separated list of flag text """Convert a flags value into a space-separated list of flag text
values. values.
@ -81,7 +85,7 @@ def to_text(flags):
return _to_text(flags, Flag) return _to_text(flags, Flag)
def edns_from_text(text): def edns_from_text(text: str) -> int:
"""Convert a space-separated list of EDNS flag text values into a EDNS """Convert a space-separated list of EDNS flag text values into a EDNS
flags value. flags value.
@ -91,7 +95,7 @@ def edns_from_text(text):
return _from_text(text, EDNSFlag) return _from_text(text, EDNSFlag)
def edns_to_text(flags): def edns_to_text(flags: int) -> str:
"""Convert an EDNS flags value into a space-separated list of EDNS flag """Convert an EDNS flags value into a space-separated list of EDNS flag
text values. text values.
@ -100,6 +104,7 @@ def edns_to_text(flags):
return _to_text(flags, EDNSFlag) return _to_text(flags, EDNSFlag)
### BEGIN generated Flag constants ### BEGIN generated Flag constants
QR = Flag.QR QR = Flag.QR

View file

@ -17,9 +17,12 @@
"""DNS GENERATE range conversion.""" """DNS GENERATE range conversion."""
from typing import Tuple
import dns import dns
def from_text(text):
def from_text(text: str) -> Tuple[int, int, int]:
"""Convert the text form of a range in a ``$GENERATE`` statement to an """Convert the text form of a range in a ``$GENERATE`` statement to an
integer. integer.
@ -31,22 +34,22 @@ def from_text(text):
start = -1 start = -1
stop = -1 stop = -1
step = 1 step = 1
cur = '' cur = ""
state = 0 state = 0
# state 0 1 2 # state 0 1 2
# x - y / z # x - y / z
if text and text[0] == '-': if text and text[0] == "-":
raise dns.exception.SyntaxError("Start cannot be a negative number") raise dns.exception.SyntaxError("Start cannot be a negative number")
for c in text: for c in text:
if c == '-' and state == 0: if c == "-" and state == 0:
start = int(cur) start = int(cur)
cur = '' cur = ""
state = 1 state = 1
elif c == '/': elif c == "/":
stop = int(cur) stop = int(cur)
cur = '' cur = ""
state = 2 state = 2
elif c.isdigit(): elif c.isdigit():
cur += c cur += c
@ -64,6 +67,6 @@ def from_text(text):
assert step >= 1 assert step >= 1
assert start >= 0 assert start >= 0
if start > stop: if start > stop:
raise dns.exception.SyntaxError('start must be <= stop') raise dns.exception.SyntaxError("start must be <= stop")
return (start, stop, step) return (start, stop, step)

View file

@ -1,32 +1,25 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
import collections.abc from typing import Any
import sys
import collections.abc
# pylint: disable=unused-import
if sys.version_info >= (3, 7):
odict = dict
from dns._immutable_ctx import immutable from dns._immutable_ctx import immutable
else:
# pragma: no cover
from collections import OrderedDict as odict
from dns._immutable_attr import immutable # noqa
# pylint: enable=unused-import
@immutable @immutable
class Dict(collections.abc.Mapping): class Dict(collections.abc.Mapping): # lgtm[py/missing-equals]
def __init__(self, dictionary, no_copy=False): def __init__(self, dictionary: Any, no_copy: bool = False):
"""Make an immutable dictionary from the specified dictionary. """Make an immutable dictionary from the specified dictionary.
If *no_copy* is `True`, then *dictionary* will be wrapped instead If *no_copy* is `True`, then *dictionary* will be wrapped instead
of copied. Only set this if you are sure there will be no external of copied. Only set this if you are sure there will be no external
references to the dictionary. references to the dictionary.
""" """
if no_copy and isinstance(dictionary, odict): if no_copy and isinstance(dictionary, dict):
self._odict = dictionary self._odict = dictionary
else: else:
self._odict = odict(dictionary) self._odict = dict(dictionary)
self._hash = None self._hash = None
def __getitem__(self, key): def __getitem__(self, key):
@ -37,7 +30,7 @@ class Dict(collections.abc.Mapping):
h = 0 h = 0
for key in sorted(self._odict.keys()): for key in sorted(self._odict.keys()):
h ^= hash(key) h ^= hash(key)
object.__setattr__(self, '_hash', h) object.__setattr__(self, "_hash", h)
# this does return an int, but pylint doesn't figure that out # this does return an int, but pylint doesn't figure that out
return self._hash return self._hash
@ -48,7 +41,7 @@ class Dict(collections.abc.Mapping):
return iter(self._odict) return iter(self._odict)
def constify(o): def constify(o: Any) -> Any:
""" """
Convert mutable types to immutable types. Convert mutable types to immutable types.
""" """
@ -63,7 +56,7 @@ def constify(o):
if isinstance(o, list): if isinstance(o, list):
return tuple(constify(elt) for elt in o) return tuple(constify(elt) for elt in o)
if isinstance(o, dict): if isinstance(o, dict):
cdict = odict() cdict = dict()
for k, v in o.items(): for k, v in o.items():
cdict[k] = constify(v) cdict[k] = constify(v)
return Dict(cdict, True) return Dict(cdict, True)

View file

@ -17,6 +17,8 @@
"""Generic Internet address helper functions.""" """Generic Internet address helper functions."""
from typing import Any, Optional, Tuple
import socket import socket
import dns.ipv4 import dns.ipv4
@ -30,7 +32,7 @@ AF_INET = socket.AF_INET
AF_INET6 = socket.AF_INET6 AF_INET6 = socket.AF_INET6
def inet_pton(family, text): def inet_pton(family: int, text: str) -> bytes:
"""Convert the textual form of a network address into its binary form. """Convert the textual form of a network address into its binary form.
*family* is an ``int``, the address family. *family* is an ``int``, the address family.
@ -51,7 +53,7 @@ def inet_pton(family, text):
raise NotImplementedError raise NotImplementedError
def inet_ntop(family, address): def inet_ntop(family: int, address: bytes) -> str:
"""Convert the binary form of a network address into its textual form. """Convert the binary form of a network address into its textual form.
*family* is an ``int``, the address family. *family* is an ``int``, the address family.
@ -72,7 +74,7 @@ def inet_ntop(family, address):
raise NotImplementedError raise NotImplementedError
def af_for_address(text): def af_for_address(text: str) -> int:
"""Determine the address family of a textual-form network address. """Determine the address family of a textual-form network address.
*text*, a ``str``, the textual address. *text*, a ``str``, the textual address.
@ -94,7 +96,7 @@ def af_for_address(text):
raise ValueError raise ValueError
def is_multicast(text): def is_multicast(text: str) -> bool:
"""Is the textual-form network address a multicast address? """Is the textual-form network address a multicast address?
*text*, a ``str``, the textual address. *text*, a ``str``, the textual address.
@ -116,7 +118,7 @@ def is_multicast(text):
raise ValueError raise ValueError
def is_address(text): def is_address(text: str) -> bool:
"""Is the specified string an IPv4 or IPv6 address? """Is the specified string an IPv4 or IPv6 address?
*text*, a ``str``, the textual address. *text*, a ``str``, the textual address.
@ -135,7 +137,9 @@ def is_address(text):
return False return False
def low_level_address_tuple(high_tuple, af=None): def low_level_address_tuple(
high_tuple: Tuple[str, int], af: Optional[int] = None
) -> Any:
"""Given a "high-level" address tuple, i.e. """Given a "high-level" address tuple, i.e.
an (address, port) return the appropriate "low-level" address tuple an (address, port) return the appropriate "low-level" address tuple
suitable for use in socket calls. suitable for use in socket calls.
@ -143,7 +147,6 @@ def low_level_address_tuple(high_tuple, af=None):
If an *af* other than ``None`` is provided, it is assumed the If an *af* other than ``None`` is provided, it is assumed the
address in the high-level tuple is valid and has that af. If af address in the high-level tuple is valid and has that af. If af
is ``None``, then af_for_address will be called. is ``None``, then af_for_address will be called.
""" """
address, port = high_tuple address, port = high_tuple
if af is None: if af is None:
@ -151,7 +154,7 @@ def low_level_address_tuple(high_tuple, af=None):
if af == AF_INET: if af == AF_INET:
return (address, port) return (address, port)
elif af == AF_INET6: elif af == AF_INET6:
i = address.find('%') i = address.find("%")
if i < 0: if i < 0:
# no scope, shortcut! # no scope, shortcut!
return (address, port, 0, 0) return (address, port, 0, 0)
@ -167,4 +170,4 @@ def low_level_address_tuple(high_tuple, af=None):
((*_, tup), *_) = socket.getaddrinfo(address, port, flags=ai_flags) ((*_, tup), *_) = socket.getaddrinfo(address, port, flags=ai_flags)
return tup return tup
else: else:
raise NotImplementedError(f'unknown address family {af}') raise NotImplementedError(f"unknown address family {af}")

View file

@ -1,4 +0,0 @@
from typing import Union
from socket import AddressFamily
AF_INET6 : Union[int, AddressFamily]

View file

@ -17,11 +17,14 @@
"""IPv4 helper functions.""" """IPv4 helper functions."""
from typing import Union
import struct import struct
import dns.exception import dns.exception
def inet_ntoa(address):
def inet_ntoa(address: bytes) -> str:
"""Convert an IPv4 address in binary form to text form. """Convert an IPv4 address in binary form to text form.
*address*, a ``bytes``, the IPv4 address in binary form. *address*, a ``bytes``, the IPv4 address in binary form.
@ -31,30 +34,32 @@ def inet_ntoa(address):
if len(address) != 4: if len(address) != 4:
raise dns.exception.SyntaxError raise dns.exception.SyntaxError
return ('%u.%u.%u.%u' % (address[0], address[1], return "%u.%u.%u.%u" % (address[0], address[1], address[2], address[3])
address[2], address[3]))
def inet_aton(text):
def inet_aton(text: Union[str, bytes]) -> bytes:
"""Convert an IPv4 address in text form to binary form. """Convert an IPv4 address in text form to binary form.
*text*, a ``str``, the IPv4 address in textual form. *text*, a ``str`` or ``bytes``, the IPv4 address in textual form.
Returns a ``bytes``. Returns a ``bytes``.
""" """
if not isinstance(text, bytes): if not isinstance(text, bytes):
text = text.encode() btext = text.encode()
parts = text.split(b'.') else:
btext = text
parts = btext.split(b".")
if len(parts) != 4: if len(parts) != 4:
raise dns.exception.SyntaxError raise dns.exception.SyntaxError
for part in parts: for part in parts:
if not part.isdigit(): if not part.isdigit():
raise dns.exception.SyntaxError raise dns.exception.SyntaxError
if len(part) > 1 and part[0] == ord('0'): if len(part) > 1 and part[0] == ord("0"):
# No leading zeros # No leading zeros
raise dns.exception.SyntaxError raise dns.exception.SyntaxError
try: try:
b = [int(part) for part in parts] b = [int(part) for part in parts]
return struct.pack('BBBB', *b) return struct.pack("BBBB", *b)
except Exception: except Exception:
raise dns.exception.SyntaxError raise dns.exception.SyntaxError

View file

@ -17,15 +17,18 @@
"""IPv6 helper functions.""" """IPv6 helper functions."""
from typing import List, Union
import re import re
import binascii import binascii
import dns.exception import dns.exception
import dns.ipv4 import dns.ipv4
_leading_zero = re.compile(r'0+([0-9a-f]+)') _leading_zero = re.compile(r"0+([0-9a-f]+)")
def inet_ntoa(address):
def inet_ntoa(address: bytes) -> str:
"""Convert an IPv6 address in binary form to text form. """Convert an IPv6 address in binary form to text form.
*address*, a ``bytes``, the IPv6 address in binary form. *address*, a ``bytes``, the IPv6 address in binary form.
@ -58,7 +61,7 @@ def inet_ntoa(address):
start = -1 start = -1
last_was_zero = False last_was_zero = False
for i in range(8): for i in range(8):
if chunks[i] != '0': if chunks[i] != "0":
if last_was_zero: if last_was_zero:
end = i end = i
current_len = end - start current_len = end - start
@ -76,27 +79,30 @@ def inet_ntoa(address):
best_start = start best_start = start
best_len = current_len best_len = current_len
if best_len > 1: if best_len > 1:
if best_start == 0 and \ if best_start == 0 and (best_len == 6 or best_len == 5 and chunks[5] == "ffff"):
(best_len == 6 or
best_len == 5 and chunks[5] == 'ffff'):
# We have an embedded IPv4 address # We have an embedded IPv4 address
if best_len == 6: if best_len == 6:
prefix = '::' prefix = "::"
else: else:
prefix = '::ffff:' prefix = "::ffff:"
hex = prefix + dns.ipv4.inet_ntoa(address[12:]) thex = prefix + dns.ipv4.inet_ntoa(address[12:])
else: else:
hex = ':'.join(chunks[:best_start]) + '::' + \ thex = (
':'.join(chunks[best_start + best_len:]) ":".join(chunks[:best_start])
+ "::"
+ ":".join(chunks[best_start + best_len :])
)
else: else:
hex = ':'.join(chunks) thex = ":".join(chunks)
return hex return thex
_v4_ending = re.compile(br'(.*):(\d+\.\d+\.\d+\.\d+)$')
_colon_colon_start = re.compile(br'::.*')
_colon_colon_end = re.compile(br'.*::$')
def inet_aton(text, ignore_scope=False): _v4_ending = re.compile(rb"(.*):(\d+\.\d+\.\d+\.\d+)$")
_colon_colon_start = re.compile(rb"::.*")
_colon_colon_end = re.compile(rb".*::$")
def inet_aton(text: Union[str, bytes], ignore_scope: bool = False) -> bytes:
"""Convert an IPv6 address in text form to binary form. """Convert an IPv6 address in text form to binary form.
*text*, a ``str``, the IPv6 address in textual form. *text*, a ``str``, the IPv6 address in textual form.
@ -111,82 +117,88 @@ def inet_aton(text, ignore_scope=False):
# Our aim here is not something fast; we just want something that works. # Our aim here is not something fast; we just want something that works.
# #
if not isinstance(text, bytes): if not isinstance(text, bytes):
text = text.encode() btext = text.encode()
else:
btext = text
if ignore_scope: if ignore_scope:
parts = text.split(b'%') parts = btext.split(b"%")
l = len(parts) l = len(parts)
if l == 2: if l == 2:
text = parts[0] btext = parts[0]
elif l > 2: elif l > 2:
raise dns.exception.SyntaxError raise dns.exception.SyntaxError
if text == b'': if btext == b"":
raise dns.exception.SyntaxError raise dns.exception.SyntaxError
elif text.endswith(b':') and not text.endswith(b'::'): elif btext.endswith(b":") and not btext.endswith(b"::"):
raise dns.exception.SyntaxError raise dns.exception.SyntaxError
elif text.startswith(b':') and not text.startswith(b'::'): elif btext.startswith(b":") and not btext.startswith(b"::"):
raise dns.exception.SyntaxError raise dns.exception.SyntaxError
elif text == b'::': elif btext == b"::":
text = b'0::' btext = b"0::"
# #
# Get rid of the icky dot-quad syntax if we have it. # Get rid of the icky dot-quad syntax if we have it.
# #
m = _v4_ending.match(text) m = _v4_ending.match(btext)
if m is not None: if m is not None:
b = dns.ipv4.inet_aton(m.group(2)) b = dns.ipv4.inet_aton(m.group(2))
text = ("{}:{:02x}{:02x}:{:02x}{:02x}".format(m.group(1).decode(), btext = (
b[0], b[1], b[2], "{}:{:02x}{:02x}:{:02x}{:02x}".format(
b[3])).encode() m.group(1).decode(), b[0], b[1], b[2], b[3]
)
).encode()
# #
# Try to turn '::<whatever>' into ':<whatever>'; if no match try to # Try to turn '::<whatever>' into ':<whatever>'; if no match try to
# turn '<whatever>::' into '<whatever>:' # turn '<whatever>::' into '<whatever>:'
# #
m = _colon_colon_start.match(text) m = _colon_colon_start.match(btext)
if m is not None: if m is not None:
text = text[1:] btext = btext[1:]
else: else:
m = _colon_colon_end.match(text) m = _colon_colon_end.match(btext)
if m is not None: if m is not None:
text = text[:-1] btext = btext[:-1]
# #
# Now canonicalize into 8 chunks of 4 hex digits each # Now canonicalize into 8 chunks of 4 hex digits each
# #
chunks = text.split(b':') chunks = btext.split(b":")
l = len(chunks) l = len(chunks)
if l > 8: if l > 8:
raise dns.exception.SyntaxError raise dns.exception.SyntaxError
seen_empty = False seen_empty = False
canonical = [] canonical: List[bytes] = []
for c in chunks: for c in chunks:
if c == b'': if c == b"":
if seen_empty: if seen_empty:
raise dns.exception.SyntaxError raise dns.exception.SyntaxError
seen_empty = True seen_empty = True
for _ in range(0, 8 - l + 1): for _ in range(0, 8 - l + 1):
canonical.append(b'0000') canonical.append(b"0000")
else: else:
lc = len(c) lc = len(c)
if lc > 4: if lc > 4:
raise dns.exception.SyntaxError raise dns.exception.SyntaxError
if lc != 4: if lc != 4:
c = (b'0' * (4 - lc)) + c c = (b"0" * (4 - lc)) + c
canonical.append(c) canonical.append(c)
if l < 8 and not seen_empty: if l < 8 and not seen_empty:
raise dns.exception.SyntaxError raise dns.exception.SyntaxError
text = b''.join(canonical) btext = b"".join(canonical)
# #
# Finally we can go to binary. # Finally we can go to binary.
# #
try: try:
return binascii.unhexlify(text) return binascii.unhexlify(btext)
except (binascii.Error, TypeError): except (binascii.Error, TypeError):
raise dns.exception.SyntaxError raise dns.exception.SyntaxError
_mapped_prefix = b'\x00' * 10 + b'\xff\xff'
def is_mapped(address): _mapped_prefix = b"\x00" * 10 + b"\xff\xff"
def is_mapped(address: bytes) -> bool:
"""Is the specified address a mapped IPv4 address? """Is the specified address a mapped IPv4 address?
*address*, a ``bytes`` is an IPv6 address in binary form. *address*, a ``bytes`` is an IPv6 address in binary form.

File diff suppressed because it is too large Load diff

View file

@ -1,47 +0,0 @@
from typing import Optional, Dict, List, Tuple, Union
from . import name, rrset, tsig, rdatatype, entropy, edns, rdataclass, rcode
import hmac
class Message:
def to_wire(self, origin : Optional[name.Name]=None, max_size=0, **kw) -> bytes:
...
def find_rrset(self, section : List[rrset.RRset], name : name.Name, rdclass : int, rdtype : int,
covers=rdatatype.NONE, deleting : Optional[int]=None, create=False,
force_unique=False) -> rrset.RRset:
...
def __init__(self, id : Optional[int] =None) -> None:
self.id : int
self.flags = 0
self.sections : List[List[rrset.RRset]] = [[], [], [], []]
self.opt : rrset.RRset = None
self.request_payload = 0
self.keyring = None
self.tsig : rrset.RRset = None
self.request_mac = b''
self.xfr = False
self.origin = None
self.tsig_ctx = None
self.index : Dict[Tuple[rrset.RRset, name.Name, int, int, Union[int,str], int], rrset.RRset] = {}
def is_response(self, other : Message) -> bool:
...
def set_rcode(self, rcode : rcode.Rcode):
...
def from_text(a : str, idna_codec : Optional[name.IDNACodec] = None) -> Message:
...
def from_wire(wire, keyring : Optional[Dict[name.Name,bytes]] = None, request_mac = b'', xfr=False, origin=None,
tsig_ctx : Optional[Union[dns.tsig.HMACTSig, dns.tsig.GSSTSig]] = None, multi=False,
question_only=False, one_rr_per_rrset=False,
ignore_trailing=False) -> Message:
...
def make_response(query : Message, recursion_available=False, our_payload=8192,
fudge=300) -> Message:
...
def make_query(qname : Union[name.Name,str], rdtype : Union[str,int], rdclass : Union[int,str] =rdataclass.IN, use_edns : Optional[bool] = None,
want_dnssec=False, ednsflags : Optional[int] = None, payload : Optional[int] = None,
request_payload : Optional[int] = None, options : Optional[List[edns.Option]] = None) -> Message:
...

View file

@ -18,32 +18,61 @@
"""DNS Names. """DNS Names.
""" """
from typing import Any, Dict, Iterable, Optional, Tuple, Union
import copy import copy
import struct import struct
import encodings.idna # type: ignore import encodings.idna # type: ignore
try: try:
import idna # type: ignore import idna # type: ignore
have_idna_2008 = True have_idna_2008 = True
except ImportError: # pragma: no cover except ImportError: # pragma: no cover
have_idna_2008 = False have_idna_2008 = False
import dns.enum
import dns.wire import dns.wire
import dns.exception import dns.exception
import dns.immutable import dns.immutable
# fullcompare() result values
CompressType = Dict["Name", int]
class NameRelation(dns.enum.IntEnum):
"""Name relation result from fullcompare()."""
# This is an IntEnum for backwards compatibility in case anyone
# has hardwired the constants.
#: The compared names have no relationship to each other. #: The compared names have no relationship to each other.
NAMERELN_NONE = 0 NONE = 0
#: the first name is a superdomain of the second. #: the first name is a superdomain of the second.
NAMERELN_SUPERDOMAIN = 1 SUPERDOMAIN = 1
#: The first name is a subdomain of the second. #: The first name is a subdomain of the second.
NAMERELN_SUBDOMAIN = 2 SUBDOMAIN = 2
#: The compared names are equal. #: The compared names are equal.
NAMERELN_EQUAL = 3 EQUAL = 3
#: The compared names have a common ancestor. #: The compared names have a common ancestor.
NAMERELN_COMMONANCESTOR = 4 COMMONANCESTOR = 4
@classmethod
def _maximum(cls):
return cls.COMMONANCESTOR
@classmethod
def _short_name(cls):
return cls.__name__
# Backwards compatibility
NAMERELN_NONE = NameRelation.NONE
NAMERELN_SUPERDOMAIN = NameRelation.SUPERDOMAIN
NAMERELN_SUBDOMAIN = NameRelation.SUBDOMAIN
NAMERELN_EQUAL = NameRelation.EQUAL
NAMERELN_COMMONANCESTOR = NameRelation.COMMONANCESTOR
class EmptyLabel(dns.exception.SyntaxError): class EmptyLabel(dns.exception.SyntaxError):
@ -84,6 +113,7 @@ class NoParent(dns.exception.DNSException):
"""An attempt was made to get the parent of the root name """An attempt was made to get the parent of the root name
or the empty name.""" or the empty name."""
class NoIDNA2008(dns.exception.DNSException): class NoIDNA2008(dns.exception.DNSException):
"""IDNA 2008 processing was requested but the idna module is not """IDNA 2008 processing was requested but the idna module is not
available.""" available."""
@ -92,9 +122,47 @@ class NoIDNA2008(dns.exception.DNSException):
class IDNAException(dns.exception.DNSException): class IDNAException(dns.exception.DNSException):
"""IDNA processing raised an exception.""" """IDNA processing raised an exception."""
supp_kwargs = {'idna_exception'} supp_kwargs = {"idna_exception"}
fmt = "IDNA processing exception: {idna_exception}" fmt = "IDNA processing exception: {idna_exception}"
# We do this as otherwise mypy complains about unexpected keyword argument
# idna_exception
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
_escaped = b'"().;\\@$'
_escaped_text = '"().;\\@$'
def _escapify(label: Union[bytes, str]) -> str:
"""Escape the characters in label which need it.
@returns: the escaped string
@rtype: string"""
if isinstance(label, bytes):
# Ordinary DNS label mode. Escape special characters and values
# < 0x20 or > 0x7f.
text = ""
for c in label:
if c in _escaped:
text += "\\" + chr(c)
elif c > 0x20 and c < 0x7F:
text += chr(c)
else:
text += "\\%03d" % c
return text
# Unicode label mode. Escape only special characters and values < 0x20
text = ""
for uc in label:
if uc in _escaped_text:
text += "\\" + uc
elif uc <= "\x20":
text += "\\%03d" % ord(uc)
else:
text += uc
return text
class IDNACodec: class IDNACodec:
"""Abstract base class for IDNA encoder/decoders.""" """Abstract base class for IDNA encoder/decoders."""
@ -102,26 +170,28 @@ class IDNACodec:
def __init__(self): def __init__(self):
pass pass
def is_idna(self, label): def is_idna(self, label: bytes) -> bool:
return label.lower().startswith(b'xn--') return label.lower().startswith(b"xn--")
def encode(self, label): def encode(self, label: str) -> bytes:
raise NotImplementedError # pragma: no cover raise NotImplementedError # pragma: no cover
def decode(self, label): def decode(self, label: bytes) -> str:
# We do not apply any IDNA policy on decode. # We do not apply any IDNA policy on decode.
if self.is_idna(label): if self.is_idna(label):
try: try:
label = label[4:].decode('punycode') slabel = label[4:].decode("punycode")
return _escapify(slabel)
except Exception as e: except Exception as e:
raise IDNAException(idna_exception=e) raise IDNAException(idna_exception=e)
else:
return _escapify(label) return _escapify(label)
class IDNA2003Codec(IDNACodec): class IDNA2003Codec(IDNACodec):
"""IDNA 2003 encoder/decoder.""" """IDNA 2003 encoder/decoder."""
def __init__(self, strict_decode=False): def __init__(self, strict_decode: bool = False):
"""Initialize the IDNA 2003 encoder/decoder. """Initialize the IDNA 2003 encoder/decoder.
*strict_decode* is a ``bool``. If `True`, then IDNA2003 checking *strict_decode* is a ``bool``. If `True`, then IDNA2003 checking
@ -132,22 +202,22 @@ class IDNA2003Codec(IDNACodec):
super().__init__() super().__init__()
self.strict_decode = strict_decode self.strict_decode = strict_decode
def encode(self, label): def encode(self, label: str) -> bytes:
"""Encode *label*.""" """Encode *label*."""
if label == '': if label == "":
return b'' return b""
try: try:
return encodings.idna.ToASCII(label) return encodings.idna.ToASCII(label)
except UnicodeError: except UnicodeError:
raise LabelTooLong raise LabelTooLong
def decode(self, label): def decode(self, label: bytes) -> str:
"""Decode *label*.""" """Decode *label*."""
if not self.strict_decode: if not self.strict_decode:
return super().decode(label) return super().decode(label)
if label == b'': if label == b"":
return '' return ""
try: try:
return _escapify(encodings.idna.ToUnicode(label)) return _escapify(encodings.idna.ToUnicode(label))
except Exception as e: except Exception as e:
@ -155,16 +225,20 @@ class IDNA2003Codec(IDNACodec):
class IDNA2008Codec(IDNACodec): class IDNA2008Codec(IDNACodec):
"""IDNA 2008 encoder/decoder. """IDNA 2008 encoder/decoder."""
"""
def __init__(self, uts_46=False, transitional=False, def __init__(
allow_pure_ascii=False, strict_decode=False): self,
uts_46: bool = False,
transitional: bool = False,
allow_pure_ascii: bool = False,
strict_decode: bool = False,
):
"""Initialize the IDNA 2008 encoder/decoder. """Initialize the IDNA 2008 encoder/decoder.
*uts_46* is a ``bool``. If True, apply Unicode IDNA *uts_46* is a ``bool``. If True, apply Unicode IDNA
compatibility processing as described in Unicode Technical compatibility processing as described in Unicode Technical
Standard #46 (http://unicode.org/reports/tr46/). Standard #46 (https://unicode.org/reports/tr46/).
If False, do not apply the mapping. The default is False. If False, do not apply the mapping. The default is False.
*transitional* is a ``bool``: If True, use the *transitional* is a ``bool``: If True, use the
@ -188,11 +262,11 @@ class IDNA2008Codec(IDNACodec):
self.allow_pure_ascii = allow_pure_ascii self.allow_pure_ascii = allow_pure_ascii
self.strict_decode = strict_decode self.strict_decode = strict_decode
def encode(self, label): def encode(self, label: str) -> bytes:
if label == '': if label == "":
return b'' return b""
if self.allow_pure_ascii and is_all_ascii(label): if self.allow_pure_ascii and is_all_ascii(label):
encoded = label.encode('ascii') encoded = label.encode("ascii")
if len(encoded) > 63: if len(encoded) > 63:
raise LabelTooLong raise LabelTooLong
return encoded return encoded
@ -203,16 +277,16 @@ class IDNA2008Codec(IDNACodec):
label = idna.uts46_remap(label, False, self.transitional) label = idna.uts46_remap(label, False, self.transitional)
return idna.alabel(label) return idna.alabel(label)
except idna.IDNAError as e: except idna.IDNAError as e:
if e.args[0] == 'Label too long': if e.args[0] == "Label too long":
raise LabelTooLong raise LabelTooLong
else: else:
raise IDNAException(idna_exception=e) raise IDNAException(idna_exception=e)
def decode(self, label): def decode(self, label: bytes) -> str:
if not self.strict_decode: if not self.strict_decode:
return super().decode(label) return super().decode(label)
if label == b'': if label == b"":
return '' return ""
if not have_idna_2008: if not have_idna_2008:
raise NoIDNA2008 raise NoIDNA2008
try: try:
@ -223,8 +297,6 @@ class IDNA2008Codec(IDNACodec):
except (idna.IDNAError, UnicodeError) as e: except (idna.IDNAError, UnicodeError) as e:
raise IDNAException(idna_exception=e) raise IDNAException(idna_exception=e)
_escaped = b'"().;\\@$'
_escaped_text = '"().;\\@$'
IDNA_2003_Practical = IDNA2003Codec(False) IDNA_2003_Practical = IDNA2003Codec(False)
IDNA_2003_Strict = IDNA2003Codec(True) IDNA_2003_Strict = IDNA2003Codec(True)
@ -235,35 +307,8 @@ IDNA_2008_Strict = IDNA2008Codec(False, False, False, True)
IDNA_2008_Transitional = IDNA2008Codec(True, True, False, False) IDNA_2008_Transitional = IDNA2008Codec(True, True, False, False)
IDNA_2008 = IDNA_2008_Practical IDNA_2008 = IDNA_2008_Practical
def _escapify(label):
"""Escape the characters in label which need it.
@returns: the escaped string
@rtype: string"""
if isinstance(label, bytes):
# Ordinary DNS label mode. Escape special characters and values
# < 0x20 or > 0x7f.
text = ''
for c in label:
if c in _escaped:
text += '\\' + chr(c)
elif c > 0x20 and c < 0x7F:
text += chr(c)
else:
text += '\\%03d' % c
return text
# Unicode label mode. Escape only special characters and values < 0x20 def _validate_labels(labels: Tuple[bytes, ...]) -> None:
text = ''
for c in label:
if c in _escaped_text:
text += '\\' + c
elif c <= '\x20':
text += '\\%03d' % ord(c)
else:
text += c
return text
def _validate_labels(labels):
"""Check for empty labels in the middle of a label sequence, """Check for empty labels in the middle of a label sequence,
labels that are too long, and for too many labels. labels that are too long, and for too many labels.
@ -284,7 +329,7 @@ def _validate_labels(labels):
total += ll + 1 total += ll + 1
if ll > 63: if ll > 63:
raise LabelTooLong raise LabelTooLong
if i < 0 and label == b'': if i < 0 and label == b"":
i = j i = j
j += 1 j += 1
if total > 255: if total > 255:
@ -293,7 +338,7 @@ def _validate_labels(labels):
raise EmptyLabel raise EmptyLabel
def _maybe_convert_to_binary(label): def _maybe_convert_to_binary(label: Union[bytes, str]) -> bytes:
"""If label is ``str``, convert it to ``bytes``. If it is already """If label is ``str``, convert it to ``bytes``. If it is already
``bytes`` just return it. ``bytes`` just return it.
@ -316,14 +361,13 @@ class Name:
of the class are immutable. of the class are immutable.
""" """
__slots__ = ['labels'] __slots__ = ["labels"]
def __init__(self, labels): def __init__(self, labels: Iterable[Union[bytes, str]]):
"""*labels* is any iterable whose values are ``str`` or ``bytes``. """*labels* is any iterable whose values are ``str`` or ``bytes``."""
"""
labels = [_maybe_convert_to_binary(x) for x in labels] blabels = [_maybe_convert_to_binary(x) for x in labels]
self.labels = tuple(labels) self.labels = tuple(blabels)
_validate_labels(self.labels) _validate_labels(self.labels)
def __copy__(self): def __copy__(self):
@ -334,29 +378,29 @@ class Name:
def __getstate__(self): def __getstate__(self):
# Names can be pickled # Names can be pickled
return {'labels': self.labels} return {"labels": self.labels}
def __setstate__(self, state): def __setstate__(self, state):
super().__setattr__('labels', state['labels']) super().__setattr__("labels", state["labels"])
_validate_labels(self.labels) _validate_labels(self.labels)
def is_absolute(self): def is_absolute(self) -> bool:
"""Is the most significant label of this name the root label? """Is the most significant label of this name the root label?
Returns a ``bool``. Returns a ``bool``.
""" """
return len(self.labels) > 0 and self.labels[-1] == b'' return len(self.labels) > 0 and self.labels[-1] == b""
def is_wild(self): def is_wild(self) -> bool:
"""Is this name wild? (I.e. Is the least significant label '*'?) """Is this name wild? (I.e. Is the least significant label '*'?)
Returns a ``bool``. Returns a ``bool``.
""" """
return len(self.labels) > 0 and self.labels[0] == b'*' return len(self.labels) > 0 and self.labels[0] == b"*"
def __hash__(self): def __hash__(self) -> int:
"""Return a case-insensitive hash of the name. """Return a case-insensitive hash of the name.
Returns an ``int``. Returns an ``int``.
@ -368,14 +412,14 @@ class Name:
h += (h << 3) + c h += (h << 3) + c
return h return h
def fullcompare(self, other): def fullcompare(self, other: "Name") -> Tuple[NameRelation, int, int]:
"""Compare two names, returning a 3-tuple """Compare two names, returning a 3-tuple
``(relation, order, nlabels)``. ``(relation, order, nlabels)``.
*relation* describes the relation ship between the names, *relation* describes the relation ship between the names,
and is one of: ``dns.name.NAMERELN_NONE``, and is one of: ``dns.name.NameRelation.NONE``,
``dns.name.NAMERELN_SUPERDOMAIN``, ``dns.name.NAMERELN_SUBDOMAIN``, ``dns.name.NameRelation.SUPERDOMAIN``, ``dns.name.NameRelation.SUBDOMAIN``,
``dns.name.NAMERELN_EQUAL``, or ``dns.name.NAMERELN_COMMONANCESTOR``. ``dns.name.NameRelation.EQUAL``, or ``dns.name.NameRelation.COMMONANCESTOR``.
*order* is < 0 if *self* < *other*, > 0 if *self* > *other*, and == *order* is < 0 if *self* < *other*, > 0 if *self* > *other*, and ==
0 if *self* == *other*. A relative name is always less than an 0 if *self* == *other*. A relative name is always less than an
@ -404,9 +448,9 @@ class Name:
oabs = other.is_absolute() oabs = other.is_absolute()
if sabs != oabs: if sabs != oabs:
if sabs: if sabs:
return (NAMERELN_NONE, 1, 0) return (NameRelation.NONE, 1, 0)
else: else:
return (NAMERELN_NONE, -1, 0) return (NameRelation.NONE, -1, 0)
l1 = len(self.labels) l1 = len(self.labels)
l2 = len(other.labels) l2 = len(other.labels)
ldiff = l1 - l2 ldiff = l1 - l2
@ -417,7 +461,7 @@ class Name:
order = 0 order = 0
nlabels = 0 nlabels = 0
namereln = NAMERELN_NONE namereln = NameRelation.NONE
while l > 0: while l > 0:
l -= 1 l -= 1
l1 -= 1 l1 -= 1
@ -427,52 +471,52 @@ class Name:
if label1 < label2: if label1 < label2:
order = -1 order = -1
if nlabels > 0: if nlabels > 0:
namereln = NAMERELN_COMMONANCESTOR namereln = NameRelation.COMMONANCESTOR
return (namereln, order, nlabels) return (namereln, order, nlabels)
elif label1 > label2: elif label1 > label2:
order = 1 order = 1
if nlabels > 0: if nlabels > 0:
namereln = NAMERELN_COMMONANCESTOR namereln = NameRelation.COMMONANCESTOR
return (namereln, order, nlabels) return (namereln, order, nlabels)
nlabels += 1 nlabels += 1
order = ldiff order = ldiff
if ldiff < 0: if ldiff < 0:
namereln = NAMERELN_SUPERDOMAIN namereln = NameRelation.SUPERDOMAIN
elif ldiff > 0: elif ldiff > 0:
namereln = NAMERELN_SUBDOMAIN namereln = NameRelation.SUBDOMAIN
else: else:
namereln = NAMERELN_EQUAL namereln = NameRelation.EQUAL
return (namereln, order, nlabels) return (namereln, order, nlabels)
def is_subdomain(self, other): def is_subdomain(self, other: "Name") -> bool:
"""Is self a subdomain of other? """Is self a subdomain of other?
Note that the notion of subdomain includes equality, e.g. Note that the notion of subdomain includes equality, e.g.
"dnpython.org" is a subdomain of itself. "dnspython.org" is a subdomain of itself.
Returns a ``bool``. Returns a ``bool``.
""" """
(nr, _, _) = self.fullcompare(other) (nr, _, _) = self.fullcompare(other)
if nr == NAMERELN_SUBDOMAIN or nr == NAMERELN_EQUAL: if nr == NameRelation.SUBDOMAIN or nr == NameRelation.EQUAL:
return True return True
return False return False
def is_superdomain(self, other): def is_superdomain(self, other: "Name") -> bool:
"""Is self a superdomain of other? """Is self a superdomain of other?
Note that the notion of superdomain includes equality, e.g. Note that the notion of superdomain includes equality, e.g.
"dnpython.org" is a superdomain of itself. "dnspython.org" is a superdomain of itself.
Returns a ``bool``. Returns a ``bool``.
""" """
(nr, _, _) = self.fullcompare(other) (nr, _, _) = self.fullcompare(other)
if nr == NAMERELN_SUPERDOMAIN or nr == NAMERELN_EQUAL: if nr == NameRelation.SUPERDOMAIN or nr == NameRelation.EQUAL:
return True return True
return False return False
def canonicalize(self): def canonicalize(self) -> "Name":
"""Return a name which is equal to the current name, but is in """Return a name which is equal to the current name, but is in
DNSSEC canonical form. DNSSEC canonical form.
""" """
@ -516,12 +560,12 @@ class Name:
return NotImplemented return NotImplemented
def __repr__(self): def __repr__(self):
return '<DNS name ' + self.__str__() + '>' return "<DNS name " + self.__str__() + ">"
def __str__(self): def __str__(self):
return self.to_text(False) return self.to_text(False)
def to_text(self, omit_final_dot=False): def to_text(self, omit_final_dot: bool = False) -> str:
"""Convert name to DNS text format. """Convert name to DNS text format.
*omit_final_dot* is a ``bool``. If True, don't emit the final *omit_final_dot* is a ``bool``. If True, don't emit the final
@ -532,17 +576,19 @@ class Name:
""" """
if len(self.labels) == 0: if len(self.labels) == 0:
return '@' return "@"
if len(self.labels) == 1 and self.labels[0] == b'': if len(self.labels) == 1 and self.labels[0] == b"":
return '.' return "."
if omit_final_dot and self.is_absolute(): if omit_final_dot and self.is_absolute():
l = self.labels[:-1] l = self.labels[:-1]
else: else:
l = self.labels l = self.labels
s = '.'.join(map(_escapify, l)) s = ".".join(map(_escapify, l))
return s return s
def to_unicode(self, omit_final_dot=False, idna_codec=None): def to_unicode(
self, omit_final_dot: bool = False, idna_codec: Optional[IDNACodec] = None
) -> str:
"""Convert name to Unicode text format. """Convert name to Unicode text format.
IDN ACE labels are converted to Unicode. IDN ACE labels are converted to Unicode.
@ -561,18 +607,18 @@ class Name:
""" """
if len(self.labels) == 0: if len(self.labels) == 0:
return '@' return "@"
if len(self.labels) == 1 and self.labels[0] == b'': if len(self.labels) == 1 and self.labels[0] == b"":
return '.' return "."
if omit_final_dot and self.is_absolute(): if omit_final_dot and self.is_absolute():
l = self.labels[:-1] l = self.labels[:-1]
else: else:
l = self.labels l = self.labels
if idna_codec is None: if idna_codec is None:
idna_codec = IDNA_2003_Practical idna_codec = IDNA_2003_Practical
return '.'.join([idna_codec.decode(x) for x in l]) return ".".join([idna_codec.decode(x) for x in l])
def to_digestable(self, origin=None): def to_digestable(self, origin: Optional["Name"] = None) -> bytes:
"""Convert name to a format suitable for digesting in hashes. """Convert name to a format suitable for digesting in hashes.
The name is canonicalized and converted to uncompressed wire The name is canonicalized and converted to uncompressed wire
@ -589,10 +635,17 @@ class Name:
Returns a ``bytes``. Returns a ``bytes``.
""" """
return self.to_wire(origin=origin, canonicalize=True) digest = self.to_wire(origin=origin, canonicalize=True)
assert digest is not None
return digest
def to_wire(self, file=None, compress=None, origin=None, def to_wire(
canonicalize=False): self,
file: Optional[Any] = None,
compress: Optional[CompressType] = None,
origin: Optional["Name"] = None,
canonicalize: bool = False,
) -> Optional[bytes]:
"""Convert name to wire format, possibly compressing it. """Convert name to wire format, possibly compressing it.
*file* is the file where the name is emitted (typically an *file* is the file where the name is emitted (typically an
@ -638,6 +691,7 @@ class Name:
out += label out += label
return bytes(out) return bytes(out)
labels: Iterable[bytes]
if not self.is_absolute(): if not self.is_absolute():
if origin is None or not origin.is_absolute(): if origin is None or not origin.is_absolute():
raise NeedAbsoluteNameOrOrigin raise NeedAbsoluteNameOrOrigin
@ -654,24 +708,25 @@ class Name:
else: else:
pos = None pos = None
if pos is not None: if pos is not None:
value = 0xc000 + pos value = 0xC000 + pos
s = struct.pack('!H', value) s = struct.pack("!H", value)
file.write(s) file.write(s)
break break
else: else:
if compress is not None and len(n) > 1: if compress is not None and len(n) > 1:
pos = file.tell() pos = file.tell()
if pos <= 0x3fff: if pos <= 0x3FFF:
compress[n] = pos compress[n] = pos
l = len(label) l = len(label)
file.write(struct.pack('!B', l)) file.write(struct.pack("!B", l))
if l > 0: if l > 0:
if canonicalize: if canonicalize:
file.write(label.lower()) file.write(label.lower())
else: else:
file.write(label) file.write(label)
return None
def __len__(self): def __len__(self) -> int:
"""The length of the name (in labels). """The length of the name (in labels).
Returns an ``int``. Returns an ``int``.
@ -688,7 +743,7 @@ class Name:
def __sub__(self, other): def __sub__(self, other):
return self.relativize(other) return self.relativize(other)
def split(self, depth): def split(self, depth: int) -> Tuple["Name", "Name"]:
"""Split a name into a prefix and suffix names at the specified depth. """Split a name into a prefix and suffix names at the specified depth.
*depth* is an ``int`` specifying the number of labels in the suffix *depth* is an ``int`` specifying the number of labels in the suffix
@ -705,11 +760,10 @@ class Name:
elif depth == l: elif depth == l:
return (dns.name.empty, self) return (dns.name.empty, self)
elif depth < 0 or depth > l: elif depth < 0 or depth > l:
raise ValueError( raise ValueError("depth must be >= 0 and <= the length of the name")
'depth must be >= 0 and <= the length of the name')
return (Name(self[:-depth]), Name(self[-depth:])) return (Name(self[:-depth]), Name(self[-depth:]))
def concatenate(self, other): def concatenate(self, other: "Name") -> "Name":
"""Return a new name which is the concatenation of self and other. """Return a new name which is the concatenation of self and other.
Raises ``dns.name.AbsoluteConcatenation`` if the name is Raises ``dns.name.AbsoluteConcatenation`` if the name is
@ -724,7 +778,7 @@ class Name:
labels.extend(list(other.labels)) labels.extend(list(other.labels))
return Name(labels) return Name(labels)
def relativize(self, origin): def relativize(self, origin: "Name") -> "Name":
"""If the name is a subdomain of *origin*, return a new name which is """If the name is a subdomain of *origin*, return a new name which is
the name relative to origin. Otherwise return the name. the name relative to origin. Otherwise return the name.
@ -740,7 +794,7 @@ class Name:
else: else:
return self return self
def derelativize(self, origin): def derelativize(self, origin: "Name") -> "Name":
"""If the name is a relative name, return a new name which is the """If the name is a relative name, return a new name which is the
concatenation of the name and origin. Otherwise return the name. concatenation of the name and origin. Otherwise return the name.
@ -756,7 +810,9 @@ class Name:
else: else:
return self return self
def choose_relativity(self, origin=None, relativize=True): def choose_relativity(
self, origin: Optional["Name"] = None, relativize: bool = True
) -> "Name":
"""Return a name with the relativity desired by the caller. """Return a name with the relativity desired by the caller.
If *origin* is ``None``, then the name is returned. If *origin* is ``None``, then the name is returned.
@ -775,7 +831,7 @@ class Name:
else: else:
return self return self
def parent(self): def parent(self) -> "Name":
"""Return the parent of the name. """Return the parent of the name.
For example, the parent of ``www.dnspython.org.`` is ``dnspython.org``. For example, the parent of ``www.dnspython.org.`` is ``dnspython.org``.
@ -790,13 +846,17 @@ class Name:
raise NoParent raise NoParent
return Name(self.labels[1:]) return Name(self.labels[1:])
#: The root name, '.' #: The root name, '.'
root = Name([b'']) root = Name([b""])
#: The empty name. #: The empty name.
empty = Name([]) empty = Name([])
def from_unicode(text, origin=root, idna_codec=None):
def from_unicode(
text: str, origin: Optional[Name] = root, idna_codec: Optional[IDNACodec] = None
) -> Name:
"""Convert unicode text into a Name object. """Convert unicode text into a Name object.
Labels are encoded in IDN ACE form according to rules specified by Labels are encoded in IDN ACE form according to rules specified by
@ -819,17 +879,17 @@ def from_unicode(text, origin=root, idna_codec=None):
if not (origin is None or isinstance(origin, Name)): if not (origin is None or isinstance(origin, Name)):
raise ValueError("origin must be a Name or None") raise ValueError("origin must be a Name or None")
labels = [] labels = []
label = '' label = ""
escaping = False escaping = False
edigits = 0 edigits = 0
total = 0 total = 0
if idna_codec is None: if idna_codec is None:
idna_codec = IDNA_2003 idna_codec = IDNA_2003
if text == '@': if text == "@":
text = '' text = ""
if text: if text:
if text in ['.', '\u3002', '\uff0e', '\uff61']: if text in [".", "\u3002", "\uff0e", "\uff61"]:
return Name([b'']) # no Unicode "u" on this constant! return Name([b""]) # no Unicode "u" on this constant!
for c in text: for c in text:
if escaping: if escaping:
if edigits == 0: if edigits == 0:
@ -848,12 +908,12 @@ def from_unicode(text, origin=root, idna_codec=None):
if edigits == 3: if edigits == 3:
escaping = False escaping = False
label += chr(total) label += chr(total)
elif c in ['.', '\u3002', '\uff0e', '\uff61']: elif c in [".", "\u3002", "\uff0e", "\uff61"]:
if len(label) == 0: if len(label) == 0:
raise EmptyLabel raise EmptyLabel
labels.append(idna_codec.encode(label)) labels.append(idna_codec.encode(label))
label = '' label = ""
elif c == '\\': elif c == "\\":
escaping = True escaping = True
edigits = 0 edigits = 0
total = 0 total = 0
@ -864,22 +924,28 @@ def from_unicode(text, origin=root, idna_codec=None):
if len(label) > 0: if len(label) > 0:
labels.append(idna_codec.encode(label)) labels.append(idna_codec.encode(label))
else: else:
labels.append(b'') labels.append(b"")
if (len(labels) == 0 or labels[-1] != b'') and origin is not None: if (len(labels) == 0 or labels[-1] != b"") and origin is not None:
labels.extend(list(origin.labels)) labels.extend(list(origin.labels))
return Name(labels) return Name(labels)
def is_all_ascii(text):
def is_all_ascii(text: str) -> bool:
for c in text: for c in text:
if ord(c) > 0x7f: if ord(c) > 0x7F:
return False return False
return True return True
def from_text(text, origin=root, idna_codec=None):
def from_text(
text: Union[bytes, str],
origin: Optional[Name] = root,
idna_codec: Optional[IDNACodec] = None,
) -> Name:
"""Convert text into a Name object. """Convert text into a Name object.
*text*, a ``str``, is the text to convert into a name. *text*, a ``bytes`` or ``str``, is the text to convert into a name.
*origin*, a ``dns.name.Name``, specifies the origin to *origin*, a ``dns.name.Name``, specifies the origin to
append to non-absolute names. The default is the root name. append to non-absolute names. The default is the root name.
@ -903,23 +969,23 @@ def from_text(text, origin=root, idna_codec=None):
# #
# then it's still "all ASCII" even though the domain name has # then it's still "all ASCII" even though the domain name has
# codepoints > 127. # codepoints > 127.
text = text.encode('ascii') text = text.encode("ascii")
if not isinstance(text, bytes): if not isinstance(text, bytes):
raise ValueError("input to from_text() must be a string") raise ValueError("input to from_text() must be a string")
if not (origin is None or isinstance(origin, Name)): if not (origin is None or isinstance(origin, Name)):
raise ValueError("origin must be a Name or None") raise ValueError("origin must be a Name or None")
labels = [] labels = []
label = b'' label = b""
escaping = False escaping = False
edigits = 0 edigits = 0
total = 0 total = 0
if text == b'@': if text == b"@":
text = b'' text = b""
if text: if text:
if text == b'.': if text == b".":
return Name([b'']) return Name([b""])
for c in text: for c in text:
byte_ = struct.pack('!B', c) byte_ = struct.pack("!B", c)
if escaping: if escaping:
if edigits == 0: if edigits == 0:
if byte_.isdigit(): if byte_.isdigit():
@ -936,13 +1002,13 @@ def from_text(text, origin=root, idna_codec=None):
edigits += 1 edigits += 1
if edigits == 3: if edigits == 3:
escaping = False escaping = False
label += struct.pack('!B', total) label += struct.pack("!B", total)
elif byte_ == b'.': elif byte_ == b".":
if len(label) == 0: if len(label) == 0:
raise EmptyLabel raise EmptyLabel
labels.append(label) labels.append(label)
label = b'' label = b""
elif byte_ == b'\\': elif byte_ == b"\\":
escaping = True escaping = True
edigits = 0 edigits = 0
total = 0 total = 0
@ -953,13 +1019,16 @@ def from_text(text, origin=root, idna_codec=None):
if len(label) > 0: if len(label) > 0:
labels.append(label) labels.append(label)
else: else:
labels.append(b'') labels.append(b"")
if (len(labels) == 0 or labels[-1] != b'') and origin is not None: if (len(labels) == 0 or labels[-1] != b"") and origin is not None:
labels.extend(list(origin.labels)) labels.extend(list(origin.labels))
return Name(labels) return Name(labels)
def from_wire_parser(parser): # we need 'dns.wire.Parser' quoted as dns.name and dns.wire depend on each other.
def from_wire_parser(parser: "dns.wire.Parser") -> Name:
"""Convert possibly compressed wire format into a Name. """Convert possibly compressed wire format into a Name.
*parser* is a dns.wire.Parser. *parser* is a dns.wire.Parser.
@ -980,7 +1049,7 @@ def from_wire_parser(parser):
if count < 64: if count < 64:
labels.append(parser.get_bytes(count)) labels.append(parser.get_bytes(count))
elif count >= 192: elif count >= 192:
current = (count & 0x3f) * 256 + parser.get_uint8() current = (count & 0x3F) * 256 + parser.get_uint8()
if current >= biggest_pointer: if current >= biggest_pointer:
raise BadPointer raise BadPointer
biggest_pointer = current biggest_pointer = current
@ -988,11 +1057,11 @@ def from_wire_parser(parser):
else: else:
raise BadLabelType raise BadLabelType
count = parser.get_uint8() count = parser.get_uint8()
labels.append(b'') labels.append(b"")
return Name(labels) return Name(labels)
def from_wire(message, current): def from_wire(message: bytes, current: int) -> Tuple[Name, int]:
"""Convert possibly compressed wire format into a Name. """Convert possibly compressed wire format into a Name.
*message* is a ``bytes`` containing an entire DNS message in DNS *message* is a ``bytes`` containing an entire DNS message in DNS

View file

@ -1,40 +0,0 @@
from typing import Optional, Union, Tuple, Iterable, List
have_idna_2008: bool
class Name:
def is_subdomain(self, o : Name) -> bool: ...
def is_superdomain(self, o : Name) -> bool: ...
def __init__(self, labels : Iterable[Union[bytes,str]]) -> None:
self.labels : List[bytes]
def is_absolute(self) -> bool: ...
def is_wild(self) -> bool: ...
def fullcompare(self, other) -> Tuple[int,int,int]: ...
def canonicalize(self) -> Name: ...
def __eq__(self, other) -> bool: ...
def __ne__(self, other) -> bool: ...
def __lt__(self, other : Name) -> bool: ...
def __le__(self, other : Name) -> bool: ...
def __ge__(self, other : Name) -> bool: ...
def __gt__(self, other : Name) -> bool: ...
def to_text(self, omit_final_dot=False) -> str: ...
def to_unicode(self, omit_final_dot=False, idna_codec=None) -> str: ...
def to_digestable(self, origin=None) -> bytes: ...
def to_wire(self, file=None, compress=None, origin=None,
canonicalize=False) -> Optional[bytes]: ...
def __add__(self, other : Name) -> Name: ...
def __sub__(self, other : Name) -> Name: ...
def split(self, depth) -> List[Tuple[str,str]]: ...
def concatenate(self, other : Name) -> Name: ...
def relativize(self, origin) -> Name: ...
def derelativize(self, origin) -> Name: ...
def choose_relativity(self, origin : Optional[Name] = None, relativize=True) -> Name: ...
def parent(self) -> Name: ...
class IDNACodec:
pass
def from_text(text, origin : Optional[Name] = Name('.'), idna_codec : Optional[IDNACodec] = None) -> Name:
...
empty : Name

View file

@ -27,7 +27,8 @@
"""DNS name dictionary""" """DNS name dictionary"""
from collections.abc import MutableMapping # pylint seems to be confused about this one!
from collections.abc import MutableMapping # pylint: disable=no-name-in-module
import dns.name import dns.name
@ -62,7 +63,7 @@ class NameDict(MutableMapping):
def __setitem__(self, key, value): def __setitem__(self, key, value):
if not isinstance(key, dns.name.Name): if not isinstance(key, dns.name.Name):
raise ValueError('NameDict key must be a name') raise ValueError("NameDict key must be a name")
self.__store[key] = value self.__store[key] = value
self.__update_max_depth(key) self.__update_max_depth(key)

View file

@ -17,12 +17,17 @@
"""DNS nodes. A node is a set of rdatasets.""" """DNS nodes. A node is a set of rdatasets."""
from typing import Any, Dict, Optional
import enum import enum
import io import io
import dns.immutable import dns.immutable
import dns.name
import dns.rdataclass
import dns.rdataset import dns.rdataset
import dns.rdatatype import dns.rdatatype
import dns.rrset
import dns.renderer import dns.renderer
@ -37,21 +42,23 @@ _neutral_types = {
dns.rdatatype.KEY, # RFC 4035 section 2.5, RFC 3007 dns.rdatatype.KEY, # RFC 4035 section 2.5, RFC 3007
} }
def _matches_type_or_its_signature(rdtypes, rdtype, covers): def _matches_type_or_its_signature(rdtypes, rdtype, covers):
return rdtype in rdtypes or \ return rdtype in rdtypes or (rdtype == dns.rdatatype.RRSIG and covers in rdtypes)
(rdtype == dns.rdatatype.RRSIG and covers in rdtypes)
@enum.unique @enum.unique
class NodeKind(enum.Enum): class NodeKind(enum.Enum):
"""Rdatasets in nodes """Rdatasets in nodes"""
"""
REGULAR = 0 # a.k.a "other data" REGULAR = 0 # a.k.a "other data"
NEUTRAL = 1 NEUTRAL = 1
CNAME = 2 CNAME = 2
@classmethod @classmethod
def classify(cls, rdtype, covers): def classify(
cls, rdtype: dns.rdatatype.RdataType, covers: dns.rdatatype.RdataType
) -> "NodeKind":
if _matches_type_or_its_signature(_cname_types, rdtype, covers): if _matches_type_or_its_signature(_cname_types, rdtype, covers):
return NodeKind.CNAME return NodeKind.CNAME
elif _matches_type_or_its_signature(_neutral_types, rdtype, covers): elif _matches_type_or_its_signature(_neutral_types, rdtype, covers):
@ -60,7 +67,7 @@ class NodeKind(enum.Enum):
return NodeKind.REGULAR return NodeKind.REGULAR
@classmethod @classmethod
def classify_rdataset(cls, rdataset): def classify_rdataset(cls, rdataset: dns.rdataset.Rdataset) -> "NodeKind":
return cls.classify(rdataset.rdtype, rdataset.covers) return cls.classify(rdataset.rdtype, rdataset.covers)
@ -81,19 +88,19 @@ class Node:
deleted. deleted.
""" """
__slots__ = ['rdatasets'] __slots__ = ["rdatasets"]
def __init__(self): def __init__(self):
# the set of rdatasets, represented as a list. # the set of rdatasets, represented as a list.
self.rdatasets = [] self.rdatasets = []
def to_text(self, name, **kw): def to_text(self, name: dns.name.Name, **kw: Dict[str, Any]) -> str:
"""Convert a node to text format. """Convert a node to text format.
Each rdataset at the node is printed. Any keyword arguments Each rdataset at the node is printed. Any keyword arguments
to this method are passed on to the rdataset's to_text() method. to this method are passed on to the rdataset's to_text() method.
*name*, a ``dns.name.Name`` or ``str``, the owner name of the *name*, a ``dns.name.Name``, the owner name of the
rdatasets. rdatasets.
Returns a ``str``. Returns a ``str``.
@ -103,12 +110,12 @@ class Node:
s = io.StringIO() s = io.StringIO()
for rds in self.rdatasets: for rds in self.rdatasets:
if len(rds) > 0: if len(rds) > 0:
s.write(rds.to_text(name, **kw)) s.write(rds.to_text(name, **kw)) # type: ignore[arg-type]
s.write('\n') s.write("\n")
return s.getvalue()[:-1] return s.getvalue()[:-1]
def __repr__(self): def __repr__(self):
return '<DNS node ' + str(id(self)) + '>' return "<DNS node " + str(id(self)) + ">"
def __eq__(self, other): def __eq__(self, other):
# #
@ -144,27 +151,36 @@ class Node:
if len(self.rdatasets) > 0: if len(self.rdatasets) > 0:
kind = NodeKind.classify_rdataset(rdataset) kind = NodeKind.classify_rdataset(rdataset)
if kind == NodeKind.CNAME: if kind == NodeKind.CNAME:
self.rdatasets = [rds for rds in self.rdatasets if self.rdatasets = [
NodeKind.classify_rdataset(rds) != rds
NodeKind.REGULAR] for rds in self.rdatasets
if NodeKind.classify_rdataset(rds) != NodeKind.REGULAR
]
elif kind == NodeKind.REGULAR: elif kind == NodeKind.REGULAR:
self.rdatasets = [rds for rds in self.rdatasets if self.rdatasets = [
NodeKind.classify_rdataset(rds) != rds
NodeKind.CNAME] for rds in self.rdatasets
if NodeKind.classify_rdataset(rds) != NodeKind.CNAME
]
# Otherwise the rdataset is NodeKind.NEUTRAL and we do not need to # Otherwise the rdataset is NodeKind.NEUTRAL and we do not need to
# edit self.rdatasets. # edit self.rdatasets.
self.rdatasets.append(rdataset) self.rdatasets.append(rdataset)
def find_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE, def find_rdataset(
create=False): self,
rdclass: dns.rdataclass.RdataClass,
rdtype: dns.rdatatype.RdataType,
covers: dns.rdatatype.RdataType = dns.rdatatype.NONE,
create: bool = False,
) -> dns.rdataset.Rdataset:
"""Find an rdataset matching the specified properties in the """Find an rdataset matching the specified properties in the
current node. current node.
*rdclass*, an ``int``, the class of the rdataset. *rdclass*, a ``dns.rdataclass.RdataClass``, the class of the rdataset.
*rdtype*, an ``int``, the type of the rdataset. *rdtype*, a ``dns.rdatatype.RdataType``, the type of the rdataset.
*covers*, an ``int`` or ``None``, the covered type. *covers*, a ``dns.rdatatype.RdataType``, the covered type.
Usually this value is ``dns.rdatatype.NONE``, but if the Usually this value is ``dns.rdatatype.NONE``, but if the
rdtype is ``dns.rdatatype.SIG`` or ``dns.rdatatype.RRSIG``, rdtype is ``dns.rdatatype.SIG`` or ``dns.rdatatype.RRSIG``,
then the covers value will be the rdata type the SIG/RRSIG then the covers value will be the rdata type the SIG/RRSIG
@ -191,8 +207,13 @@ class Node:
self._append_rdataset(rds) self._append_rdataset(rds)
return rds return rds
def get_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE, def get_rdataset(
create=False): self,
rdclass: dns.rdataclass.RdataClass,
rdtype: dns.rdatatype.RdataType,
covers: dns.rdatatype.RdataType = dns.rdatatype.NONE,
create: bool = False,
) -> Optional[dns.rdataset.Rdataset]:
"""Get an rdataset matching the specified properties in the """Get an rdataset matching the specified properties in the
current node. current node.
@ -223,7 +244,12 @@ class Node:
rds = None rds = None
return rds return rds
def delete_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE): def delete_rdataset(
self,
rdclass: dns.rdataclass.RdataClass,
rdtype: dns.rdatatype.RdataType,
covers: dns.rdatatype.RdataType = dns.rdatatype.NONE,
) -> None:
"""Delete the rdataset matching the specified properties in the """Delete the rdataset matching the specified properties in the
current node. current node.
@ -240,7 +266,7 @@ class Node:
if rds is not None: if rds is not None:
self.rdatasets.remove(rds) self.rdatasets.remove(rds)
def replace_rdataset(self, replacement): def replace_rdataset(self, replacement: dns.rdataset.Rdataset) -> None:
"""Replace an rdataset. """Replace an rdataset.
It is not an error if there is no rdataset matching *replacement*. It is not an error if there is no rdataset matching *replacement*.
@ -256,16 +282,17 @@ class Node:
""" """
if not isinstance(replacement, dns.rdataset.Rdataset): if not isinstance(replacement, dns.rdataset.Rdataset):
raise ValueError('replacement is not an rdataset') raise ValueError("replacement is not an rdataset")
if isinstance(replacement, dns.rrset.RRset): if isinstance(replacement, dns.rrset.RRset):
# RRsets are not good replacements as the match() method # RRsets are not good replacements as the match() method
# is not compatible. # is not compatible.
replacement = replacement.to_rdataset() replacement = replacement.to_rdataset()
self.delete_rdataset(replacement.rdclass, replacement.rdtype, self.delete_rdataset(
replacement.covers) replacement.rdclass, replacement.rdtype, replacement.covers
)
self._append_rdataset(replacement) self._append_rdataset(replacement)
def classify(self): def classify(self) -> NodeKind:
"""Classify a node. """Classify a node.
A node which contains a CNAME or RRSIG(CNAME) is a A node which contains a CNAME or RRSIG(CNAME) is a
@ -286,7 +313,7 @@ class Node:
return kind return kind
return NodeKind.NEUTRAL return NodeKind.NEUTRAL
def is_immutable(self): def is_immutable(self) -> bool:
return False return False
@ -298,23 +325,38 @@ class ImmutableNode(Node):
[dns.rdataset.ImmutableRdataset(rds) for rds in node.rdatasets] [dns.rdataset.ImmutableRdataset(rds) for rds in node.rdatasets]
) )
def find_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE, def find_rdataset(
create=False): self,
rdclass: dns.rdataclass.RdataClass,
rdtype: dns.rdatatype.RdataType,
covers: dns.rdatatype.RdataType = dns.rdatatype.NONE,
create: bool = False,
) -> dns.rdataset.Rdataset:
if create: if create:
raise TypeError("immutable") raise TypeError("immutable")
return super().find_rdataset(rdclass, rdtype, covers, False) return super().find_rdataset(rdclass, rdtype, covers, False)
def get_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE, def get_rdataset(
create=False): self,
rdclass: dns.rdataclass.RdataClass,
rdtype: dns.rdatatype.RdataType,
covers: dns.rdatatype.RdataType = dns.rdatatype.NONE,
create: bool = False,
) -> Optional[dns.rdataset.Rdataset]:
if create: if create:
raise TypeError("immutable") raise TypeError("immutable")
return super().get_rdataset(rdclass, rdtype, covers, False) return super().get_rdataset(rdclass, rdtype, covers, False)
def delete_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE): def delete_rdataset(
self,
rdclass: dns.rdataclass.RdataClass,
rdtype: dns.rdatatype.RdataType,
covers: dns.rdatatype.RdataType = dns.rdatatype.NONE,
) -> None:
raise TypeError("immutable") raise TypeError("immutable")
def replace_rdataset(self, replacement): def replace_rdataset(self, replacement: dns.rdataset.Rdataset) -> None:
raise TypeError("immutable") raise TypeError("immutable")
def is_immutable(self): def is_immutable(self) -> bool:
return True return True

View file

@ -1,17 +0,0 @@
from typing import List, Optional, Union
from . import rdataset, rdatatype, name
class Node:
def __init__(self):
self.rdatasets : List[rdataset.Rdataset]
def to_text(self, name : Union[str,name.Name], **kw) -> str:
...
def find_rdataset(self, rdclass : int, rdtype : int, covers=rdatatype.NONE,
create=False) -> rdataset.Rdataset:
...
def get_rdataset(self, rdclass : int, rdtype : int, covers=rdatatype.NONE,
create=False) -> Optional[rdataset.Rdataset]:
...
def delete_rdataset(self, rdclass : int, rdtype : int, covers=rdatatype.NONE):
...
def replace_rdataset(self, replacement : rdataset.Rdataset) -> None:
...

View file

@ -20,6 +20,7 @@
import dns.enum import dns.enum
import dns.exception import dns.exception
class Opcode(dns.enum.IntEnum): class Opcode(dns.enum.IntEnum):
#: Query #: Query
QUERY = 0 QUERY = 0
@ -45,7 +46,7 @@ class UnknownOpcode(dns.exception.DNSException):
"""An DNS opcode is unknown.""" """An DNS opcode is unknown."""
def from_text(text): def from_text(text: str) -> Opcode:
"""Convert text into an opcode. """Convert text into an opcode.
*text*, a ``str``, the textual opcode *text*, a ``str``, the textual opcode
@ -58,7 +59,7 @@ def from_text(text):
return Opcode.from_text(text) return Opcode.from_text(text)
def from_flags(flags): def from_flags(flags: int) -> Opcode:
"""Extract an opcode from DNS message flags. """Extract an opcode from DNS message flags.
*flags*, an ``int``, the DNS flags. *flags*, an ``int``, the DNS flags.
@ -66,10 +67,10 @@ def from_flags(flags):
Returns an ``int``. Returns an ``int``.
""" """
return (flags & 0x7800) >> 11 return Opcode((flags & 0x7800) >> 11)
def to_flags(value): def to_flags(value: Opcode) -> int:
"""Convert an opcode to a value suitable for ORing into DNS message """Convert an opcode to a value suitable for ORing into DNS message
flags. flags.
@ -81,7 +82,7 @@ def to_flags(value):
return (value << 11) & 0x7800 return (value << 11) & 0x7800
def to_text(value): def to_text(value: Opcode) -> str:
"""Convert an opcode to text. """Convert an opcode to text.
*value*, an ``int`` the opcode value, *value*, an ``int`` the opcode value,
@ -94,7 +95,7 @@ def to_text(value):
return Opcode.to_text(value) return Opcode.to_text(value)
def is_update(flags): def is_update(flags: int) -> bool:
"""Is the opcode in flags UPDATE? """Is the opcode in flags UPDATE?
*flags*, an ``int``, the DNS message flags. *flags*, an ``int``, the DNS message flags.
@ -104,6 +105,7 @@ def is_update(flags):
return from_flags(flags) == Opcode.UPDATE return from_flags(flags) == Opcode.UPDATE
### BEGIN generated Opcode constants ### BEGIN generated Opcode constants
QUERY = Opcode.QUERY QUERY = Opcode.QUERY

File diff suppressed because it is too large Load diff

View file

@ -1,64 +0,0 @@
from typing import Optional, Union, Dict, Generator, Any
from . import tsig, rdatatype, rdataclass, name, message
from requests.sessions import Session
import socket
# If the ssl import works, then
#
# error: Name 'ssl' already defined (by an import)
#
# is expected and can be ignored.
try:
import ssl
except ImportError:
class ssl: # type: ignore
SSLContext : Dict = {}
have_doh: bool
def https(q : message.Message, where: str, timeout : Optional[float] = None,
port : Optional[int] = 443, source : Optional[str] = None,
source_port : Optional[int] = 0,
session: Optional[Session] = None,
path : Optional[str] = '/dns-query', post : Optional[bool] = True,
bootstrap_address : Optional[str] = None,
verify : Optional[bool] = True) -> message.Message:
pass
def tcp(q : message.Message, where : str, timeout : float = None, port=53,
af : Optional[int] = None, source : Optional[str] = None,
source_port : Optional[int] = 0,
one_rr_per_rrset : Optional[bool] = False,
ignore_trailing : Optional[bool] = False,
sock : Optional[socket.socket] = None) -> message.Message:
pass
def xfr(where : None, zone : Union[name.Name,str], rdtype=rdatatype.AXFR,
rdclass=rdataclass.IN,
timeout : Optional[float] = None, port=53,
keyring : Optional[Dict[name.Name, bytes]] = None,
keyname : Union[str,name.Name]= None, relativize=True,
lifetime : Optional[float] = None,
source : Optional[str] = None, source_port=0, serial=0,
use_udp : Optional[bool] = False,
keyalgorithm=tsig.default_algorithm) \
-> Generator[Any,Any,message.Message]:
pass
def udp(q : message.Message, where : str, timeout : Optional[float] = None,
port=53, source : Optional[str] = None, source_port : Optional[int] = 0,
ignore_unexpected : Optional[bool] = False,
one_rr_per_rrset : Optional[bool] = False,
ignore_trailing : Optional[bool] = False,
sock : Optional[socket.socket] = None) -> message.Message:
pass
def tls(q : message.Message, where : str, timeout : Optional[float] = None,
port=53, source : Optional[str] = None, source_port : Optional[int] = 0,
one_rr_per_rrset : Optional[bool] = False,
ignore_trailing : Optional[bool] = False,
sock : Optional[socket.socket] = None,
ssl_context: Optional[ssl.SSLContext] = None,
server_hostname: Optional[str] = None) -> message.Message:
pass

74
lib/dns/quic/__init__.py Normal file
View file

@ -0,0 +1,74 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
try:
import aioquic.quic.configuration # type: ignore
import dns.asyncbackend
from dns._asyncbackend import NullContext
from dns.quic._sync import SyncQuicManager, SyncQuicConnection, SyncQuicStream
from dns.quic._asyncio import (
AsyncioQuicManager,
AsyncioQuicConnection,
AsyncioQuicStream,
)
from dns.quic._common import AsyncQuicConnection, AsyncQuicManager
have_quic = True
def null_factory(
*args, # pylint: disable=unused-argument
**kwargs # pylint: disable=unused-argument
):
return NullContext(None)
def _asyncio_manager_factory(
context, *args, **kwargs # pylint: disable=unused-argument
):
return AsyncioQuicManager(*args, **kwargs)
# We have a context factory and a manager factory as for trio we need to have
# a nursery.
_async_factories = {"asyncio": (null_factory, _asyncio_manager_factory)}
try:
import trio
from dns.quic._trio import ( # pylint: disable=ungrouped-imports
TrioQuicManager,
TrioQuicConnection,
TrioQuicStream,
)
def _trio_context_factory():
return trio.open_nursery()
def _trio_manager_factory(context, *args, **kwargs):
return TrioQuicManager(context, *args, **kwargs)
_async_factories["trio"] = (_trio_context_factory, _trio_manager_factory)
except ImportError:
pass
def factories_for_backend(backend=None):
if backend is None:
backend = dns.asyncbackend.get_default_backend()
return _async_factories[backend.name()]
except ImportError:
have_quic = False
from typing import Any
class AsyncQuicStream: # type: ignore
pass
class AsyncQuicConnection: # type: ignore
async def make_stream(self) -> Any:
raise NotImplementedError
class SyncQuicStream: # type: ignore
pass
class SyncQuicConnection: # type: ignore
def make_stream(self) -> Any:
raise NotImplementedError

206
lib/dns/quic/_asyncio.py Normal file
View file

@ -0,0 +1,206 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
import asyncio
import socket
import ssl
import struct
import time
import aioquic.quic.configuration # type: ignore
import aioquic.quic.connection # type: ignore
import aioquic.quic.events # type: ignore
import dns.inet
import dns.asyncbackend
from dns.quic._common import (
BaseQuicStream,
AsyncQuicConnection,
AsyncQuicManager,
QUIC_MAX_DATAGRAM,
)
class AsyncioQuicStream(BaseQuicStream):
def __init__(self, connection, stream_id):
super().__init__(connection, stream_id)
self._wake_up = asyncio.Condition()
async def _wait_for_wake_up(self):
async with self._wake_up:
await self._wake_up.wait()
async def wait_for(self, amount, expiration):
timeout = self._timeout_from_expiration(expiration)
while True:
if self._buffer.have(amount):
return
self._expecting = amount
try:
await asyncio.wait_for(self._wait_for_wake_up(), timeout)
except Exception:
pass
self._expecting = 0
async def receive(self, timeout=None):
expiration = self._expiration_from_timeout(timeout)
await self.wait_for(2, expiration)
(size,) = struct.unpack("!H", self._buffer.get(2))
await self.wait_for(size, expiration)
return self._buffer.get(size)
async def send(self, datagram, is_end=False):
data = self._encapsulate(datagram)
await self._connection.write(self._stream_id, data, is_end)
async def _add_input(self, data, is_end):
if self._common_add_input(data, is_end):
async with self._wake_up:
self._wake_up.notify()
async def close(self):
self._close()
# Streams are async context managers
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.close()
async with self._wake_up:
self._wake_up.notify()
return False
class AsyncioQuicConnection(AsyncQuicConnection):
def __init__(self, connection, address, port, source, source_port, manager=None):
super().__init__(connection, address, port, source, source_port, manager)
self._socket = None
self._handshake_complete = asyncio.Event()
self._socket_created = asyncio.Event()
self._wake_timer = asyncio.Condition()
self._receiver_task = None
self._sender_task = None
async def _receiver(self):
try:
af = dns.inet.af_for_address(self._address)
backend = dns.asyncbackend.get_backend("asyncio")
self._socket = await backend.make_socket(
af, socket.SOCK_DGRAM, 0, self._source, self._peer
)
self._socket_created.set()
async with self._socket:
while not self._done:
(datagram, address) = await self._socket.recvfrom(
QUIC_MAX_DATAGRAM, None
)
if address[0] != self._peer[0] or address[1] != self._peer[1]:
continue
self._connection.receive_datagram(
datagram, self._peer[0], time.time()
)
# Wake up the timer in case the sender is sleeping, as there may be
# stuff to send now.
async with self._wake_timer:
self._wake_timer.notify_all()
except Exception:
pass
async def _wait_for_wake_timer(self):
async with self._wake_timer:
await self._wake_timer.wait()
async def _sender(self):
await self._socket_created.wait()
while not self._done:
datagrams = self._connection.datagrams_to_send(time.time())
for (datagram, address) in datagrams:
assert address == self._peer[0]
await self._socket.sendto(datagram, self._peer, None)
(expiration, interval) = self._get_timer_values()
try:
await asyncio.wait_for(self._wait_for_wake_timer(), interval)
except Exception:
pass
self._handle_timer(expiration)
await self._handle_events()
async def _handle_events(self):
count = 0
while True:
event = self._connection.next_event()
if event is None:
return
if isinstance(event, aioquic.quic.events.StreamDataReceived):
stream = self._streams.get(event.stream_id)
if stream:
await stream._add_input(event.data, event.end_stream)
elif isinstance(event, aioquic.quic.events.HandshakeCompleted):
self._handshake_complete.set()
elif isinstance(
event, aioquic.quic.events.ConnectionTerminated
) or isinstance(event, aioquic.quic.events.StreamReset):
self._done = True
self._receiver_task.cancel()
count += 1
if count > 10:
# yield
count = 0
await asyncio.sleep(0)
async def write(self, stream, data, is_end=False):
self._connection.send_stream_data(stream, data, is_end)
async with self._wake_timer:
self._wake_timer.notify_all()
def run(self):
if self._closed:
return
self._receiver_task = asyncio.Task(self._receiver())
self._sender_task = asyncio.Task(self._sender())
async def make_stream(self):
await self._handshake_complete.wait()
stream_id = self._connection.get_next_available_stream_id(False)
stream = AsyncioQuicStream(self, stream_id)
self._streams[stream_id] = stream
return stream
async def close(self):
if not self._closed:
self._manager.closed(self._peer[0], self._peer[1])
self._closed = True
self._connection.close()
async with self._wake_timer:
self._wake_timer.notify_all()
try:
await self._receiver_task
except asyncio.CancelledError:
pass
try:
await self._sender_task
except asyncio.CancelledError:
pass
class AsyncioQuicManager(AsyncQuicManager):
def __init__(self, conf=None, verify_mode=ssl.CERT_REQUIRED):
super().__init__(conf, verify_mode, AsyncioQuicConnection)
def connect(self, address, port=853, source=None, source_port=0):
(connection, start) = self._connect(address, port, source, source_port)
if start:
connection.run()
return connection
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
# Copy the itertor into a list as exiting things will mutate the connections
# table.
connections = list(self._connections.values())
for connection in connections:
await connection.close()
return False

180
lib/dns/quic/_common.py Normal file
View file

@ -0,0 +1,180 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
import socket
import struct
import time
from typing import Any
import aioquic.quic.configuration # type: ignore
import aioquic.quic.connection # type: ignore
import dns.inet
QUIC_MAX_DATAGRAM = 2048
class UnexpectedEOF(Exception):
pass
class Buffer:
def __init__(self):
self._buffer = b""
self._seen_end = False
def put(self, data, is_end):
if self._seen_end:
return
self._buffer += data
if is_end:
self._seen_end = True
def have(self, amount):
if len(self._buffer) >= amount:
return True
if self._seen_end:
raise UnexpectedEOF
return False
def seen_end(self):
return self._seen_end
def get(self, amount):
assert self.have(amount)
data = self._buffer[:amount]
self._buffer = self._buffer[amount:]
return data
class BaseQuicStream:
def __init__(self, connection, stream_id):
self._connection = connection
self._stream_id = stream_id
self._buffer = Buffer()
self._expecting = 0
def id(self):
return self._stream_id
def _expiration_from_timeout(self, timeout):
if timeout is not None:
expiration = time.time() + timeout
else:
expiration = None
return expiration
def _timeout_from_expiration(self, expiration):
if expiration is not None:
timeout = max(expiration - time.time(), 0.0)
else:
timeout = None
return timeout
# Subclass must implement receive() as sync / async and which returns a message
# or raises UnexpectedEOF.
def _encapsulate(self, datagram):
l = len(datagram)
return struct.pack("!H", l) + datagram
def _common_add_input(self, data, is_end):
self._buffer.put(data, is_end)
return self._expecting > 0 and self._buffer.have(self._expecting)
def _close(self):
self._connection.close_stream(self._stream_id)
self._buffer.put(b"", True) # send EOF in case we haven't seen it.
class BaseQuicConnection:
def __init__(
self, connection, address, port, source=None, source_port=0, manager=None
):
self._done = False
self._connection = connection
self._address = address
self._port = port
self._closed = False
self._manager = manager
self._streams = {}
self._af = dns.inet.af_for_address(address)
self._peer = dns.inet.low_level_address_tuple((address, port))
if source is None and source_port != 0:
if self._af == socket.AF_INET:
source = "0.0.0.0"
elif self._af == socket.AF_INET6:
source = "::"
else:
raise NotImplementedError
if source:
self._source = (source, source_port)
else:
self._source = None
def close_stream(self, stream_id):
del self._streams[stream_id]
def _get_timer_values(self, closed_is_special=True):
now = time.time()
expiration = self._connection.get_timer()
if expiration is None:
expiration = now + 3600 # arbitrary "big" value
interval = max(expiration - now, 0)
if self._closed and closed_is_special:
# lower sleep interval to avoid a race in the closing process
# which can lead to higher latency closing due to sleeping when
# we have events.
interval = min(interval, 0.05)
return (expiration, interval)
def _handle_timer(self, expiration):
now = time.time()
if expiration <= now:
self._connection.handle_timer(now)
class AsyncQuicConnection(BaseQuicConnection):
async def make_stream(self) -> Any:
pass
class BaseQuicManager:
def __init__(self, conf, verify_mode, connection_factory):
self._connections = {}
self._connection_factory = connection_factory
if conf is None:
verify_path = None
if isinstance(verify_mode, str):
verify_path = verify_mode
verify_mode = True
conf = aioquic.quic.configuration.QuicConfiguration(
alpn_protocols=["doq", "doq-i03"],
verify_mode=verify_mode,
)
if verify_path is not None:
conf.load_verify_locations(verify_path)
self._conf = conf
def _connect(self, address, port=853, source=None, source_port=0):
connection = self._connections.get((address, port))
if connection is not None:
return (connection, False)
qconn = aioquic.quic.connection.QuicConnection(configuration=self._conf)
qconn.connect(address, time.time())
connection = self._connection_factory(
qconn, address, port, source, source_port, self
)
self._connections[(address, port)] = connection
return (connection, True)
def closed(self, address, port):
try:
del self._connections[(address, port)]
except KeyError:
pass
class AsyncQuicManager(BaseQuicManager):
def connect(self, address, port=853, source=None, source_port=0):
raise NotImplementedError

214
lib/dns/quic/_sync.py Normal file
View file

@ -0,0 +1,214 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
import socket
import ssl
import selectors
import struct
import threading
import time
import aioquic.quic.configuration # type: ignore
import aioquic.quic.connection # type: ignore
import aioquic.quic.events # type: ignore
import dns.inet
from dns.quic._common import (
BaseQuicStream,
BaseQuicConnection,
BaseQuicManager,
QUIC_MAX_DATAGRAM,
)
# Avoid circularity with dns.query
if hasattr(selectors, "PollSelector"):
_selector_class = selectors.PollSelector # type: ignore
else:
_selector_class = selectors.SelectSelector # type: ignore
class SyncQuicStream(BaseQuicStream):
def __init__(self, connection, stream_id):
super().__init__(connection, stream_id)
self._wake_up = threading.Condition()
self._lock = threading.Lock()
def wait_for(self, amount, expiration):
timeout = self._timeout_from_expiration(expiration)
while True:
with self._lock:
if self._buffer.have(amount):
return
self._expecting = amount
with self._wake_up:
self._wake_up.wait(timeout)
self._expecting = 0
def receive(self, timeout=None):
expiration = self._expiration_from_timeout(timeout)
self.wait_for(2, expiration)
with self._lock:
(size,) = struct.unpack("!H", self._buffer.get(2))
self.wait_for(size, expiration)
with self._lock:
return self._buffer.get(size)
def send(self, datagram, is_end=False):
data = self._encapsulate(datagram)
self._connection.write(self._stream_id, data, is_end)
def _add_input(self, data, is_end):
if self._common_add_input(data, is_end):
with self._wake_up:
self._wake_up.notify()
def close(self):
with self._lock:
self._close()
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
with self._wake_up:
self._wake_up.notify()
return False
class SyncQuicConnection(BaseQuicConnection):
def __init__(self, connection, address, port, source, source_port, manager):
super().__init__(connection, address, port, source, source_port, manager)
self._socket = socket.socket(self._af, socket.SOCK_DGRAM, 0)
self._socket.connect(self._peer)
(self._send_wakeup, self._receive_wakeup) = socket.socketpair()
self._receive_wakeup.setblocking(False)
self._socket.setblocking(False)
if self._source is not None:
try:
self._socket.bind(
dns.inet.low_level_address_tuple(self._source, self._af)
)
except Exception:
self._socket.close()
raise
self._handshake_complete = threading.Event()
self._worker_thread = None
self._lock = threading.Lock()
def _read(self):
count = 0
while count < 10:
count += 1
try:
datagram = self._socket.recv(QUIC_MAX_DATAGRAM)
except BlockingIOError:
return
with self._lock:
self._connection.receive_datagram(datagram, self._peer[0], time.time())
def _drain_wakeup(self):
while True:
try:
self._receive_wakeup.recv(32)
except BlockingIOError:
return
def _worker(self):
sel = _selector_class()
sel.register(self._socket, selectors.EVENT_READ, self._read)
sel.register(self._receive_wakeup, selectors.EVENT_READ, self._drain_wakeup)
while not self._done:
(expiration, interval) = self._get_timer_values(False)
items = sel.select(interval)
for (key, _) in items:
key.data()
with self._lock:
self._handle_timer(expiration)
datagrams = self._connection.datagrams_to_send(time.time())
for (datagram, _) in datagrams:
try:
self._socket.send(datagram)
except BlockingIOError:
# we let QUIC handle any lossage
pass
self._handle_events()
def _handle_events(self):
while True:
with self._lock:
event = self._connection.next_event()
if event is None:
return
if isinstance(event, aioquic.quic.events.StreamDataReceived):
with self._lock:
stream = self._streams.get(event.stream_id)
if stream:
stream._add_input(event.data, event.end_stream)
elif isinstance(event, aioquic.quic.events.HandshakeCompleted):
self._handshake_complete.set()
elif isinstance(
event, aioquic.quic.events.ConnectionTerminated
) or isinstance(event, aioquic.quic.events.StreamReset):
with self._lock:
self._done = True
def write(self, stream, data, is_end=False):
with self._lock:
self._connection.send_stream_data(stream, data, is_end)
self._send_wakeup.send(b"\x01")
def run(self):
if self._closed:
return
self._worker_thread = threading.Thread(target=self._worker)
self._worker_thread.start()
def make_stream(self):
self._handshake_complete.wait()
with self._lock:
stream_id = self._connection.get_next_available_stream_id(False)
stream = SyncQuicStream(self, stream_id)
self._streams[stream_id] = stream
return stream
def close_stream(self, stream_id):
with self._lock:
super().close_stream(stream_id)
def close(self):
with self._lock:
if self._closed:
return
self._manager.closed(self._peer[0], self._peer[1])
self._closed = True
self._connection.close()
self._send_wakeup.send(b"\x01")
self._worker_thread.join()
class SyncQuicManager(BaseQuicManager):
def __init__(self, conf=None, verify_mode=ssl.CERT_REQUIRED):
super().__init__(conf, verify_mode, SyncQuicConnection)
self._lock = threading.Lock()
def connect(self, address, port=853, source=None, source_port=0):
with self._lock:
(connection, start) = self._connect(address, port, source, source_port)
if start:
connection.run()
return connection
def closed(self, address, port):
with self._lock:
super().closed(address, port)
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
# Copy the itertor into a list as exiting things will mutate the connections
# table.
connections = list(self._connections.values())
for connection in connections:
connection.close()
return False

170
lib/dns/quic/_trio.py Normal file
View file

@ -0,0 +1,170 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
import socket
import ssl
import struct
import time
import aioquic.quic.configuration # type: ignore
import aioquic.quic.connection # type: ignore
import aioquic.quic.events # type: ignore
import trio
import dns.inet
from dns._asyncbackend import NullContext
from dns.quic._common import (
BaseQuicStream,
AsyncQuicConnection,
AsyncQuicManager,
QUIC_MAX_DATAGRAM,
)
class TrioQuicStream(BaseQuicStream):
def __init__(self, connection, stream_id):
super().__init__(connection, stream_id)
self._wake_up = trio.Condition()
async def wait_for(self, amount):
while True:
if self._buffer.have(amount):
return
self._expecting = amount
async with self._wake_up:
await self._wake_up.wait()
self._expecting = 0
async def receive(self, timeout=None):
if timeout is None:
context = NullContext(None)
else:
context = trio.move_on_after(timeout)
with context:
await self.wait_for(2)
(size,) = struct.unpack("!H", self._buffer.get(2))
await self.wait_for(size)
return self._buffer.get(size)
async def send(self, datagram, is_end=False):
data = self._encapsulate(datagram)
await self._connection.write(self._stream_id, data, is_end)
async def _add_input(self, data, is_end):
if self._common_add_input(data, is_end):
async with self._wake_up:
self._wake_up.notify()
async def close(self):
self._close()
# Streams are async context managers
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.close()
async with self._wake_up:
self._wake_up.notify()
return False
class TrioQuicConnection(AsyncQuicConnection):
def __init__(self, connection, address, port, source, source_port, manager=None):
super().__init__(connection, address, port, source, source_port, manager)
self._socket = trio.socket.socket(self._af, socket.SOCK_DGRAM, 0)
if self._source:
trio.socket.bind(dns.inet.low_level_address_tuple(self._source, self._af))
self._handshake_complete = trio.Event()
self._run_done = trio.Event()
self._worker_scope = None
async def _worker(self):
await self._socket.connect(self._peer)
while not self._done:
(expiration, interval) = self._get_timer_values(False)
with trio.CancelScope(
deadline=trio.current_time() + interval
) as self._worker_scope:
datagram = await self._socket.recv(QUIC_MAX_DATAGRAM)
self._connection.receive_datagram(datagram, self._peer[0], time.time())
self._worker_scope = None
self._handle_timer(expiration)
datagrams = self._connection.datagrams_to_send(time.time())
for (datagram, _) in datagrams:
await self._socket.send(datagram)
await self._handle_events()
async def _handle_events(self):
count = 0
while True:
event = self._connection.next_event()
if event is None:
return
if isinstance(event, aioquic.quic.events.StreamDataReceived):
stream = self._streams.get(event.stream_id)
if stream:
await stream._add_input(event.data, event.end_stream)
elif isinstance(event, aioquic.quic.events.HandshakeCompleted):
self._handshake_complete.set()
elif isinstance(
event, aioquic.quic.events.ConnectionTerminated
) or isinstance(event, aioquic.quic.events.StreamReset):
self._done = True
self._socket.close()
count += 1
if count > 10:
# yield
count = 0
await trio.sleep(0)
async def write(self, stream, data, is_end=False):
self._connection.send_stream_data(stream, data, is_end)
if self._worker_scope is not None:
self._worker_scope.cancel()
async def run(self):
if self._closed:
return
async with trio.open_nursery() as nursery:
nursery.start_soon(self._worker)
self._run_done.set()
async def make_stream(self):
await self._handshake_complete.wait()
stream_id = self._connection.get_next_available_stream_id(False)
stream = TrioQuicStream(self, stream_id)
self._streams[stream_id] = stream
return stream
async def close(self):
if not self._closed:
self._manager.closed(self._peer[0], self._peer[1])
self._closed = True
self._connection.close()
if self._worker_scope is not None:
self._worker_scope.cancel()
await self._run_done.wait()
class TrioQuicManager(AsyncQuicManager):
def __init__(self, nursery, conf=None, verify_mode=ssl.CERT_REQUIRED):
super().__init__(conf, verify_mode, TrioQuicConnection)
self._nursery = nursery
def connect(self, address, port=853, source=None, source_port=0):
(connection, start) = self._connect(address, port, source, source_port)
if start:
self._nursery.start_soon(connection.run)
return connection
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
# Copy the itertor into a list as exiting things will mutate the connections
# table.
connections = list(self._connections.values())
for connection in connections:
await connection.close()
return False

View file

@ -17,9 +17,12 @@
"""DNS Result Codes.""" """DNS Result Codes."""
from typing import Tuple
import dns.enum import dns.enum
import dns.exception import dns.exception
class Rcode(dns.enum.IntEnum): class Rcode(dns.enum.IntEnum):
#: No error #: No error
NOERROR = 0 NOERROR = 0
@ -77,20 +80,20 @@ class UnknownRcode(dns.exception.DNSException):
"""A DNS rcode is unknown.""" """A DNS rcode is unknown."""
def from_text(text): def from_text(text: str) -> Rcode:
"""Convert text into an rcode. """Convert text into an rcode.
*text*, a ``str``, the textual rcode or an integer in textual form. *text*, a ``str``, the textual rcode or an integer in textual form.
Raises ``dns.rcode.UnknownRcode`` if the rcode mnemonic is unknown. Raises ``dns.rcode.UnknownRcode`` if the rcode mnemonic is unknown.
Returns an ``int``. Returns a ``dns.rcode.Rcode``.
""" """
return Rcode.from_text(text) return Rcode.from_text(text)
def from_flags(flags, ednsflags): def from_flags(flags: int, ednsflags: int) -> Rcode:
"""Return the rcode value encoded by flags and ednsflags. """Return the rcode value encoded by flags and ednsflags.
*flags*, an ``int``, the DNS flags field. *flags*, an ``int``, the DNS flags field.
@ -99,17 +102,17 @@ def from_flags(flags, ednsflags):
Raises ``ValueError`` if rcode is < 0 or > 4095 Raises ``ValueError`` if rcode is < 0 or > 4095
Returns an ``int``. Returns a ``dns.rcode.Rcode``.
""" """
value = (flags & 0x000f) | ((ednsflags >> 20) & 0xff0) value = (flags & 0x000F) | ((ednsflags >> 20) & 0xFF0)
return value return Rcode.make(value)
def to_flags(value): def to_flags(value: Rcode) -> Tuple[int, int]:
"""Return a (flags, ednsflags) tuple which encodes the rcode. """Return a (flags, ednsflags) tuple which encodes the rcode.
*value*, an ``int``, the rcode. *value*, a ``dns.rcode.Rcode``, the rcode.
Raises ``ValueError`` if rcode is < 0 or > 4095. Raises ``ValueError`` if rcode is < 0 or > 4095.
@ -117,16 +120,16 @@ def to_flags(value):
""" """
if value < 0 or value > 4095: if value < 0 or value > 4095:
raise ValueError('rcode must be >= 0 and <= 4095') raise ValueError("rcode must be >= 0 and <= 4095")
v = value & 0xf v = value & 0xF
ev = (value & 0xff0) << 20 ev = (value & 0xFF0) << 20
return (v, ev) return (v, ev)
def to_text(value, tsig=False): def to_text(value: Rcode, tsig: bool = False) -> str:
"""Convert rcode into text. """Convert rcode into text.
*value*, an ``int``, the rcode. *value*, a ``dns.rcode.Rcode``, the rcode.
Raises ``ValueError`` if rcode is < 0 or > 4095. Raises ``ValueError`` if rcode is < 0 or > 4095.
@ -134,9 +137,10 @@ def to_text(value, tsig=False):
""" """
if tsig and value == Rcode.BADVERS: if tsig and value == Rcode.BADVERS:
return 'BADSIG' return "BADSIG"
return Rcode.to_text(value) return Rcode.to_text(value)
### BEGIN generated Rcode constants ### BEGIN generated Rcode constants
NOERROR = Rcode.NOERROR NOERROR = Rcode.NOERROR

View file

@ -17,6 +17,8 @@
"""DNS rdata.""" """DNS rdata."""
from typing import Any, Dict, Optional, Tuple, Union
from importlib import import_module from importlib import import_module
import base64 import base64
import binascii import binascii
@ -55,21 +57,22 @@ class NoRelativeRdataOrdering(dns.exception.DNSException):
""" """
def _wordbreak(data, chunksize=_chunksize, separator=b' '): def _wordbreak(data, chunksize=_chunksize, separator=b" "):
"""Break a binary string into chunks of chunksize characters separated by """Break a binary string into chunks of chunksize characters separated by
a space. a space.
""" """
if not chunksize: if not chunksize:
return data.decode() return data.decode()
return separator.join([data[i:i + chunksize] return separator.join(
for i [data[i : i + chunksize] for i in range(0, len(data), chunksize)]
in range(0, len(data), chunksize)]).decode() ).decode()
# pylint: disable=unused-argument # pylint: disable=unused-argument
def _hexify(data, chunksize=_chunksize, separator=b' ', **kw):
def _hexify(data, chunksize=_chunksize, separator=b" ", **kw):
"""Convert a binary string into its hex encoding, broken up into chunks """Convert a binary string into its hex encoding, broken up into chunks
of chunksize characters separated by a separator. of chunksize characters separated by a separator.
""" """
@ -77,17 +80,19 @@ def _hexify(data, chunksize=_chunksize, separator=b' ', **kw):
return _wordbreak(binascii.hexlify(data), chunksize, separator) return _wordbreak(binascii.hexlify(data), chunksize, separator)
def _base64ify(data, chunksize=_chunksize, separator=b' ', **kw): def _base64ify(data, chunksize=_chunksize, separator=b" ", **kw):
"""Convert a binary string into its base64 encoding, broken up into chunks """Convert a binary string into its base64 encoding, broken up into chunks
of chunksize characters separated by a separator. of chunksize characters separated by a separator.
""" """
return _wordbreak(base64.b64encode(data), chunksize, separator) return _wordbreak(base64.b64encode(data), chunksize, separator)
# pylint: enable=unused-argument # pylint: enable=unused-argument
__escaped = b'"\\' __escaped = b'"\\'
def _escapify(qstring): def _escapify(qstring):
"""Escape the characters in a quoted string which need it.""" """Escape the characters in a quoted string which need it."""
@ -96,14 +101,14 @@ def _escapify(qstring):
if not isinstance(qstring, bytearray): if not isinstance(qstring, bytearray):
qstring = bytearray(qstring) qstring = bytearray(qstring)
text = '' text = ""
for c in qstring: for c in qstring:
if c in __escaped: if c in __escaped:
text += '\\' + chr(c) text += "\\" + chr(c)
elif c >= 0x20 and c < 0x7F: elif c >= 0x20 and c < 0x7F:
text += chr(c) text += chr(c)
else: else:
text += '\\%03d' % c text += "\\%03d" % c
return text return text
@ -117,6 +122,7 @@ def _truncate_bitmap(what):
return what[0 : i + 1] return what[0 : i + 1]
return what[0:1] return what[0:1]
# So we don't have to edit all the rdata classes... # So we don't have to edit all the rdata classes...
_constify = dns.immutable.constify _constify = dns.immutable.constify
@ -125,7 +131,7 @@ _constify = dns.immutable.constify
class Rdata: class Rdata:
"""Base class for all DNS rdata types.""" """Base class for all DNS rdata types."""
__slots__ = ['rdclass', 'rdtype', 'rdcomment'] __slots__ = ["rdclass", "rdtype", "rdcomment"]
def __init__(self, rdclass, rdtype): def __init__(self, rdclass, rdtype):
"""Initialize an rdata. """Initialize an rdata.
@ -140,8 +146,9 @@ class Rdata:
self.rdcomment = None self.rdcomment = None
def _get_all_slots(self): def _get_all_slots(self):
return itertools.chain.from_iterable(getattr(cls, '__slots__', []) return itertools.chain.from_iterable(
for cls in self.__class__.__mro__) getattr(cls, "__slots__", []) for cls in self.__class__.__mro__
)
def __getstate__(self): def __getstate__(self):
# We used to try to do a tuple of all slots here, but it # We used to try to do a tuple of all slots here, but it
@ -160,12 +167,12 @@ class Rdata:
def __setstate__(self, state): def __setstate__(self, state):
for slot, val in state.items(): for slot, val in state.items():
object.__setattr__(self, slot, val) object.__setattr__(self, slot, val)
if not hasattr(self, 'rdcomment'): if not hasattr(self, "rdcomment"):
# Pickled rdata from 2.0.x might not have a rdcomment, so add # Pickled rdata from 2.0.x might not have a rdcomment, so add
# it if needed. # it if needed.
object.__setattr__(self, 'rdcomment', None) object.__setattr__(self, "rdcomment", None)
def covers(self): def covers(self) -> dns.rdatatype.RdataType:
"""Return the type a Rdata covers. """Return the type a Rdata covers.
DNS SIG/RRSIG rdatas apply to a specific type; this type is DNS SIG/RRSIG rdatas apply to a specific type; this type is
@ -174,12 +181,12 @@ class Rdata:
creating rdatasets, allowing the rdataset to contain only RRSIGs creating rdatasets, allowing the rdataset to contain only RRSIGs
of a particular type, e.g. RRSIG(NS). of a particular type, e.g. RRSIG(NS).
Returns an ``int``. Returns a ``dns.rdatatype.RdataType``.
""" """
return dns.rdatatype.NONE return dns.rdatatype.NONE
def extended_rdatatype(self): def extended_rdatatype(self) -> int:
"""Return a 32-bit type value, the least significant 16 bits of """Return a 32-bit type value, the least significant 16 bits of
which are the ordinary DNS type, and the upper 16 bits of which are which are the ordinary DNS type, and the upper 16 bits of which are
the "covered" type, if any. the "covered" type, if any.
@ -189,7 +196,12 @@ class Rdata:
return self.covers() << 16 | self.rdtype return self.covers() << 16 | self.rdtype
def to_text(self, origin=None, relativize=True, **kw): def to_text(
self,
origin: Optional[dns.name.Name] = None,
relativize: bool = True,
**kw: Dict[str, Any]
) -> str:
"""Convert an rdata to text format. """Convert an rdata to text format.
Returns a ``str``. Returns a ``str``.
@ -197,11 +209,22 @@ class Rdata:
raise NotImplementedError # pragma: no cover raise NotImplementedError # pragma: no cover
def _to_wire(self, file, compress=None, origin=None, canonicalize=False): def _to_wire(
self,
file: Optional[Any],
compress: Optional[dns.name.CompressType] = None,
origin: Optional[dns.name.Name] = None,
canonicalize: bool = False,
) -> bytes:
raise NotImplementedError # pragma: no cover raise NotImplementedError # pragma: no cover
def to_wire(self, file=None, compress=None, origin=None, def to_wire(
canonicalize=False): self,
file: Optional[Any] = None,
compress: Optional[dns.name.CompressType] = None,
origin: Optional[dns.name.Name] = None,
canonicalize: bool = False,
) -> bytes:
"""Convert an rdata to wire format. """Convert an rdata to wire format.
Returns a ``bytes`` or ``None``. Returns a ``bytes`` or ``None``.
@ -214,15 +237,18 @@ class Rdata:
self._to_wire(f, compress, origin, canonicalize) self._to_wire(f, compress, origin, canonicalize)
return f.getvalue() return f.getvalue()
def to_generic(self, origin=None): def to_generic(
self, origin: Optional[dns.name.Name] = None
) -> "dns.rdata.GenericRdata":
"""Creates a dns.rdata.GenericRdata equivalent of this rdata. """Creates a dns.rdata.GenericRdata equivalent of this rdata.
Returns a ``dns.rdata.GenericRdata``. Returns a ``dns.rdata.GenericRdata``.
""" """
return dns.rdata.GenericRdata(self.rdclass, self.rdtype, return dns.rdata.GenericRdata(
self.to_wire(origin=origin)) self.rdclass, self.rdtype, self.to_wire(origin=origin)
)
def to_digestable(self, origin=None): def to_digestable(self, origin: Optional[dns.name.Name] = None) -> bytes:
"""Convert rdata to a format suitable for digesting in hashes. This """Convert rdata to a format suitable for digesting in hashes. This
is also the DNSSEC canonical form. is also the DNSSEC canonical form.
@ -234,12 +260,19 @@ class Rdata:
def __repr__(self): def __repr__(self):
covers = self.covers() covers = self.covers()
if covers == dns.rdatatype.NONE: if covers == dns.rdatatype.NONE:
ctext = '' ctext = ""
else: else:
ctext = '(' + dns.rdatatype.to_text(covers) + ')' ctext = "(" + dns.rdatatype.to_text(covers) + ")"
return '<DNS ' + dns.rdataclass.to_text(self.rdclass) + ' ' + \ return (
dns.rdatatype.to_text(self.rdtype) + ctext + ' rdata: ' + \ "<DNS "
str(self) + '>' + dns.rdataclass.to_text(self.rdclass)
+ " "
+ dns.rdatatype.to_text(self.rdtype)
+ ctext
+ " rdata: "
+ str(self)
+ ">"
)
def __str__(self): def __str__(self):
return self.to_text() return self.to_text()
@ -320,27 +353,39 @@ class Rdata:
return not self.__eq__(other) return not self.__eq__(other)
def __lt__(self, other): def __lt__(self, other):
if not isinstance(other, Rdata) or \ if (
self.rdclass != other.rdclass or self.rdtype != other.rdtype: not isinstance(other, Rdata)
or self.rdclass != other.rdclass
or self.rdtype != other.rdtype
):
return NotImplemented return NotImplemented
return self._cmp(other) < 0 return self._cmp(other) < 0
def __le__(self, other): def __le__(self, other):
if not isinstance(other, Rdata) or \ if (
self.rdclass != other.rdclass or self.rdtype != other.rdtype: not isinstance(other, Rdata)
or self.rdclass != other.rdclass
or self.rdtype != other.rdtype
):
return NotImplemented return NotImplemented
return self._cmp(other) <= 0 return self._cmp(other) <= 0
def __ge__(self, other): def __ge__(self, other):
if not isinstance(other, Rdata) or \ if (
self.rdclass != other.rdclass or self.rdtype != other.rdtype: not isinstance(other, Rdata)
or self.rdclass != other.rdclass
or self.rdtype != other.rdtype
):
return NotImplemented return NotImplemented
return self._cmp(other) >= 0 return self._cmp(other) >= 0
def __gt__(self, other): def __gt__(self, other):
if not isinstance(other, Rdata) or \ if (
self.rdclass != other.rdclass or self.rdtype != other.rdtype: not isinstance(other, Rdata)
or self.rdclass != other.rdclass
or self.rdtype != other.rdtype
):
return NotImplemented return NotImplemented
return self._cmp(other) > 0 return self._cmp(other) > 0
@ -348,15 +393,28 @@ class Rdata:
return hash(self.to_digestable(dns.name.root)) return hash(self.to_digestable(dns.name.root))
@classmethod @classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, def from_text(
relativize_to=None): cls,
rdclass: dns.rdataclass.RdataClass,
rdtype: dns.rdatatype.RdataType,
tok: dns.tokenizer.Tokenizer,
origin: Optional[dns.name.Name] = None,
relativize: bool = True,
relativize_to: Optional[dns.name.Name] = None,
) -> "Rdata":
raise NotImplementedError # pragma: no cover raise NotImplementedError # pragma: no cover
@classmethod @classmethod
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): def from_wire_parser(
cls,
rdclass: dns.rdataclass.RdataClass,
rdtype: dns.rdatatype.RdataType,
parser: dns.wire.Parser,
origin: Optional[dns.name.Name] = None,
) -> "Rdata":
raise NotImplementedError # pragma: no cover raise NotImplementedError # pragma: no cover
def replace(self, **kwargs): def replace(self, **kwargs: Any) -> "Rdata":
""" """
Create a new Rdata instance based on the instance replace was Create a new Rdata instance based on the instance replace was
invoked on. It is possible to pass different parameters to invoked on. It is possible to pass different parameters to
@ -369,19 +427,25 @@ class Rdata:
""" """
# Get the constructor parameters. # Get the constructor parameters.
parameters = inspect.signature(self.__init__).parameters parameters = inspect.signature(self.__init__).parameters # type: ignore
# Ensure that all of the arguments correspond to valid fields. # Ensure that all of the arguments correspond to valid fields.
# Don't allow rdclass or rdtype to be changed, though. # Don't allow rdclass or rdtype to be changed, though.
for key in kwargs: for key in kwargs:
if key == 'rdcomment': if key == "rdcomment":
continue continue
if key not in parameters: if key not in parameters:
raise AttributeError("'{}' object has no attribute '{}'" raise AttributeError(
.format(self.__class__.__name__, key)) "'{}' object has no attribute '{}'".format(
if key in ('rdclass', 'rdtype'): self.__class__.__name__, key
raise AttributeError("Cannot overwrite '{}' attribute '{}'" )
.format(self.__class__.__name__, key)) )
if key in ("rdclass", "rdtype"):
raise AttributeError(
"Cannot overwrite '{}' attribute '{}'".format(
self.__class__.__name__, key
)
)
# Construct the parameter list. For each field, use the value in # Construct the parameter list. For each field, use the value in
# kwargs if present, and the current value otherwise. # kwargs if present, and the current value otherwise.
@ -391,9 +455,9 @@ class Rdata:
rd = self.__class__(*args) rd = self.__class__(*args)
# The comment is not set in the constructor, so give it special # The comment is not set in the constructor, so give it special
# handling. # handling.
rdcomment = kwargs.get('rdcomment', self.rdcomment) rdcomment = kwargs.get("rdcomment", self.rdcomment)
if rdcomment is not None: if rdcomment is not None:
object.__setattr__(rd, 'rdcomment', rdcomment) object.__setattr__(rd, "rdcomment", rdcomment)
return rd return rd
# Type checking and conversion helpers. These are class methods as # Type checking and conversion helpers. These are class methods as
@ -408,18 +472,26 @@ class Rdata:
return dns.rdatatype.RdataType.make(value) return dns.rdatatype.RdataType.make(value)
@classmethod @classmethod
def _as_bytes(cls, value, encode=False, max_length=None, empty_ok=True): def _as_bytes(
cls,
value: Any,
encode: bool = False,
max_length: Optional[int] = None,
empty_ok: bool = True,
) -> bytes:
if encode and isinstance(value, str): if encode and isinstance(value, str):
value = value.encode() bvalue = value.encode()
elif isinstance(value, bytearray): elif isinstance(value, bytearray):
value = bytes(value) bvalue = bytes(value)
elif not isinstance(value, bytes): elif isinstance(value, bytes):
raise ValueError('not bytes') bvalue = value
if max_length is not None and len(value) > max_length: else:
raise ValueError('too long') raise ValueError("not bytes")
if not empty_ok and len(value) == 0: if max_length is not None and len(bvalue) > max_length:
raise ValueError('empty bytes not allowed') raise ValueError("too long")
return value if not empty_ok and len(bvalue) == 0:
raise ValueError("empty bytes not allowed")
return bvalue
@classmethod @classmethod
def _as_name(cls, value): def _as_name(cls, value):
@ -429,49 +501,49 @@ class Rdata:
if isinstance(value, str): if isinstance(value, str):
return dns.name.from_text(value) return dns.name.from_text(value)
elif not isinstance(value, dns.name.Name): elif not isinstance(value, dns.name.Name):
raise ValueError('not a name') raise ValueError("not a name")
return value return value
@classmethod @classmethod
def _as_uint8(cls, value): def _as_uint8(cls, value):
if not isinstance(value, int): if not isinstance(value, int):
raise ValueError('not an integer') raise ValueError("not an integer")
if value < 0 or value > 255: if value < 0 or value > 255:
raise ValueError('not a uint8') raise ValueError("not a uint8")
return value return value
@classmethod @classmethod
def _as_uint16(cls, value): def _as_uint16(cls, value):
if not isinstance(value, int): if not isinstance(value, int):
raise ValueError('not an integer') raise ValueError("not an integer")
if value < 0 or value > 65535: if value < 0 or value > 65535:
raise ValueError('not a uint16') raise ValueError("not a uint16")
return value return value
@classmethod @classmethod
def _as_uint32(cls, value): def _as_uint32(cls, value):
if not isinstance(value, int): if not isinstance(value, int):
raise ValueError('not an integer') raise ValueError("not an integer")
if value < 0 or value > 4294967295: if value < 0 or value > 4294967295:
raise ValueError('not a uint32') raise ValueError("not a uint32")
return value return value
@classmethod @classmethod
def _as_uint48(cls, value): def _as_uint48(cls, value):
if not isinstance(value, int): if not isinstance(value, int):
raise ValueError('not an integer') raise ValueError("not an integer")
if value < 0 or value > 281474976710655: if value < 0 or value > 281474976710655:
raise ValueError('not a uint48') raise ValueError("not a uint48")
return value return value
@classmethod @classmethod
def _as_int(cls, value, low=None, high=None): def _as_int(cls, value, low=None, high=None):
if not isinstance(value, int): if not isinstance(value, int):
raise ValueError('not an integer') raise ValueError("not an integer")
if low is not None and value < low: if low is not None and value < low:
raise ValueError('value too small') raise ValueError("value too small")
if high is not None and value > high: if high is not None and value > high:
raise ValueError('value too large') raise ValueError("value too large")
return value return value
@classmethod @classmethod
@ -483,7 +555,7 @@ class Rdata:
elif isinstance(value, bytes): elif isinstance(value, bytes):
return dns.ipv4.inet_ntoa(value) return dns.ipv4.inet_ntoa(value)
else: else:
raise ValueError('not an IPv4 address') raise ValueError("not an IPv4 address")
@classmethod @classmethod
def _as_ipv6_address(cls, value): def _as_ipv6_address(cls, value):
@ -494,14 +566,14 @@ class Rdata:
elif isinstance(value, bytes): elif isinstance(value, bytes):
return dns.ipv6.inet_ntoa(value) return dns.ipv6.inet_ntoa(value)
else: else:
raise ValueError('not an IPv6 address') raise ValueError("not an IPv6 address")
@classmethod @classmethod
def _as_bool(cls, value): def _as_bool(cls, value):
if isinstance(value, bool): if isinstance(value, bool):
return value return value
else: else:
raise ValueError('not a boolean') raise ValueError("not a boolean")
@classmethod @classmethod
def _as_ttl(cls, value): def _as_ttl(cls, value):
@ -510,7 +582,7 @@ class Rdata:
elif isinstance(value, str): elif isinstance(value, str):
return dns.ttl.from_text(value) return dns.ttl.from_text(value)
else: else:
raise ValueError('not a TTL') raise ValueError("not a TTL")
@classmethod @classmethod
def _as_tuple(cls, value, as_value): def _as_tuple(cls, value, as_value):
@ -532,6 +604,7 @@ class Rdata:
return items return items
@dns.immutable.immutable
class GenericRdata(Rdata): class GenericRdata(Rdata):
"""Generic Rdata Class """Generic Rdata Class
@ -540,28 +613,32 @@ class GenericRdata(Rdata):
implementation. It implements the DNS "unknown RRs" scheme. implementation. It implements the DNS "unknown RRs" scheme.
""" """
__slots__ = ['data'] __slots__ = ["data"]
def __init__(self, rdclass, rdtype, data): def __init__(self, rdclass, rdtype, data):
super().__init__(rdclass, rdtype) super().__init__(rdclass, rdtype)
object.__setattr__(self, 'data', data) self.data = data
def to_text(self, origin=None, relativize=True, **kw): def to_text(
return r'\# %d ' % len(self.data) + _hexify(self.data, **kw) self,
origin: Optional[dns.name.Name] = None,
relativize: bool = True,
**kw: Dict[str, Any]
) -> str:
return r"\# %d " % len(self.data) + _hexify(self.data, **kw)
@classmethod @classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, def from_text(
relativize_to=None): cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
token = tok.get() token = tok.get()
if not token.is_identifier() or token.value != r'\#': if not token.is_identifier() or token.value != r"\#":
raise dns.exception.SyntaxError( raise dns.exception.SyntaxError(r"generic rdata does not start with \#")
r'generic rdata does not start with \#')
length = tok.get_int() length = tok.get_int()
hex = tok.concatenate_remaining_identifiers(True).encode() hex = tok.concatenate_remaining_identifiers(True).encode()
data = binascii.unhexlify(hex) data = binascii.unhexlify(hex)
if len(data) != length: if len(data) != length:
raise dns.exception.SyntaxError( raise dns.exception.SyntaxError("generic rdata hex data has wrong length")
'generic rdata hex data has wrong length')
return cls(rdclass, rdtype, data) return cls(rdclass, rdtype, data)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False): def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
@ -571,8 +648,12 @@ class GenericRdata(Rdata):
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
return cls(rdclass, rdtype, parser.get_remaining()) return cls(rdclass, rdtype, parser.get_remaining())
_rdata_classes = {}
_module_prefix = 'dns.rdtypes' _rdata_classes: Dict[
Tuple[dns.rdataclass.RdataClass, dns.rdatatype.RdataType], Any
] = {}
_module_prefix = "dns.rdtypes"
def get_rdata_class(rdclass, rdtype): def get_rdata_class(rdclass, rdtype):
cls = _rdata_classes.get((rdclass, rdtype)) cls = _rdata_classes.get((rdclass, rdtype))
@ -581,16 +662,16 @@ def get_rdata_class(rdclass, rdtype):
if not cls: if not cls:
rdclass_text = dns.rdataclass.to_text(rdclass) rdclass_text = dns.rdataclass.to_text(rdclass)
rdtype_text = dns.rdatatype.to_text(rdtype) rdtype_text = dns.rdatatype.to_text(rdtype)
rdtype_text = rdtype_text.replace('-', '_') rdtype_text = rdtype_text.replace("-", "_")
try: try:
mod = import_module('.'.join([_module_prefix, mod = import_module(
rdclass_text, rdtype_text])) ".".join([_module_prefix, rdclass_text, rdtype_text])
)
cls = getattr(mod, rdtype_text) cls = getattr(mod, rdtype_text)
_rdata_classes[(rdclass, rdtype)] = cls _rdata_classes[(rdclass, rdtype)] = cls
except ImportError: except ImportError:
try: try:
mod = import_module('.'.join([_module_prefix, mod = import_module(".".join([_module_prefix, "ANY", rdtype_text]))
'ANY', rdtype_text]))
cls = getattr(mod, rdtype_text) cls = getattr(mod, rdtype_text)
_rdata_classes[(dns.rdataclass.ANY, rdtype)] = cls _rdata_classes[(dns.rdataclass.ANY, rdtype)] = cls
_rdata_classes[(rdclass, rdtype)] = cls _rdata_classes[(rdclass, rdtype)] = cls
@ -602,8 +683,15 @@ def get_rdata_class(rdclass, rdtype):
return cls return cls
def from_text(rdclass, rdtype, tok, origin=None, relativize=True, def from_text(
relativize_to=None, idna_codec=None): rdclass: Union[dns.rdataclass.RdataClass, str],
rdtype: Union[dns.rdatatype.RdataType, str],
tok: Union[dns.tokenizer.Tokenizer, str],
origin: Optional[dns.name.Name] = None,
relativize: bool = True,
relativize_to: Optional[dns.name.Name] = None,
idna_codec: Optional[dns.name.IDNACodec] = None,
) -> Rdata:
"""Build an rdata object from text format. """Build an rdata object from text format.
This function attempts to dynamically load a class which This function attempts to dynamically load a class which
@ -617,9 +705,9 @@ def from_text(rdclass, rdtype, tok, origin=None, relativize=True,
If *tok* is a ``str``, then a tokenizer is created and the string If *tok* is a ``str``, then a tokenizer is created and the string
is used as its input. is used as its input.
*rdclass*, an ``int``, the rdataclass. *rdclass*, a ``dns.rdataclass.RdataClass`` or ``str``, the rdataclass.
*rdtype*, an ``int``, the rdatatype. *rdtype*, a ``dns.rdatatype.RdataType`` or ``str``, the rdatatype.
*tok*, a ``dns.tokenizer.Tokenizer`` or a ``str``. *tok*, a ``dns.tokenizer.Tokenizer`` or a ``str``.
@ -651,17 +739,18 @@ def from_text(rdclass, rdtype, tok, origin=None, relativize=True,
# peek at first token # peek at first token
token = tok.get() token = tok.get()
tok.unget(token) tok.unget(token)
if token.is_identifier() and \ if token.is_identifier() and token.value == r"\#":
token.value == r'\#':
# #
# Known type using the generic syntax. Extract the # Known type using the generic syntax. Extract the
# wire form from the generic syntax, and then run # wire form from the generic syntax, and then run
# from_wire on it. # from_wire on it.
# #
grdata = GenericRdata.from_text(rdclass, rdtype, tok, origin, grdata = GenericRdata.from_text(
relativize, relativize_to) rdclass, rdtype, tok, origin, relativize, relativize_to
rdata = from_wire(rdclass, rdtype, grdata.data, 0, )
len(grdata.data), origin) rdata = from_wire(
rdclass, rdtype, grdata.data, 0, len(grdata.data), origin
)
# #
# If this comparison isn't equal, then there must have been # If this comparison isn't equal, then there must have been
# compressed names in the wire format, which is an error, # compressed names in the wire format, which is an error,
@ -669,19 +758,27 @@ def from_text(rdclass, rdtype, tok, origin=None, relativize=True,
# #
rwire = rdata.to_wire() rwire = rdata.to_wire()
if rwire != grdata.data: if rwire != grdata.data:
raise dns.exception.SyntaxError('compressed data in ' raise dns.exception.SyntaxError(
'generic syntax form ' "compressed data in "
'of known rdatatype') "generic syntax form "
"of known rdatatype"
)
if rdata is None: if rdata is None:
rdata = cls.from_text(rdclass, rdtype, tok, origin, relativize, rdata = cls.from_text(
relativize_to) rdclass, rdtype, tok, origin, relativize, relativize_to
)
token = tok.get_eol_as_token() token = tok.get_eol_as_token()
if token.comment is not None: if token.comment is not None:
object.__setattr__(rdata, 'rdcomment', token.comment) object.__setattr__(rdata, "rdcomment", token.comment)
return rdata return rdata
def from_wire_parser(rdclass, rdtype, parser, origin=None): def from_wire_parser(
rdclass: Union[dns.rdataclass.RdataClass, str],
rdtype: Union[dns.rdatatype.RdataType, str],
parser: dns.wire.Parser,
origin: Optional[dns.name.Name] = None,
) -> Rdata:
"""Build an rdata object from wire format """Build an rdata object from wire format
This function attempts to dynamically load a class which This function attempts to dynamically load a class which
@ -692,9 +789,9 @@ def from_wire_parser(rdclass, rdtype, parser, origin=None):
Once a class is chosen, its from_wire() class method is called Once a class is chosen, its from_wire() class method is called
with the parameters to this function. with the parameters to this function.
*rdclass*, an ``int``, the rdataclass. *rdclass*, a ``dns.rdataclass.RdataClass`` or ``str``, the rdataclass.
*rdtype*, an ``int``, the rdatatype. *rdtype*, a ``dns.rdatatype.RdataType`` or ``str``, the rdatatype.
*parser*, a ``dns.wire.Parser``, the parser, which should be *parser*, a ``dns.wire.Parser``, the parser, which should be
restricted to the rdata length. restricted to the rdata length.
@ -712,7 +809,14 @@ def from_wire_parser(rdclass, rdtype, parser, origin=None):
return cls.from_wire_parser(rdclass, rdtype, parser, origin) return cls.from_wire_parser(rdclass, rdtype, parser, origin)
def from_wire(rdclass, rdtype, wire, current, rdlen, origin=None): def from_wire(
rdclass: Union[dns.rdataclass.RdataClass, str],
rdtype: Union[dns.rdatatype.RdataType, str],
wire: bytes,
current: int,
rdlen: int,
origin: Optional[dns.name.Name] = None,
) -> Rdata:
"""Build an rdata object from wire format """Build an rdata object from wire format
This function attempts to dynamically load a class which This function attempts to dynamically load a class which
@ -746,13 +850,21 @@ def from_wire(rdclass, rdtype, wire, current, rdlen, origin=None):
class RdatatypeExists(dns.exception.DNSException): class RdatatypeExists(dns.exception.DNSException):
"""DNS rdatatype already exists.""" """DNS rdatatype already exists."""
supp_kwargs = {'rdclass', 'rdtype'}
fmt = "The rdata type with class {rdclass:d} and rdtype {rdtype:d} " + \ supp_kwargs = {"rdclass", "rdtype"}
"already exists." fmt = (
"The rdata type with class {rdclass:d} and rdtype {rdtype:d} "
+ "already exists."
)
def register_type(implementation, rdtype, rdtype_text, is_singleton=False, def register_type(
rdclass=dns.rdataclass.IN): implementation: Any,
rdtype: int,
rdtype_text: str,
is_singleton: bool = False,
rdclass: dns.rdataclass.RdataClass = dns.rdataclass.IN,
) -> None:
"""Dynamically register a module to handle an rdatatype. """Dynamically register a module to handle an rdatatype.
*implementation*, a module implementing the type in the usual dnspython *implementation*, a module implementing the type in the usual dnspython
@ -769,14 +881,16 @@ def register_type(implementation, rdtype, rdtype_text, is_singleton=False,
it applies to all classes. it applies to all classes.
""" """
existing_cls = get_rdata_class(rdclass, rdtype) the_rdtype = dns.rdatatype.RdataType.make(rdtype)
if existing_cls != GenericRdata or dns.rdatatype.is_metatype(rdtype): existing_cls = get_rdata_class(rdclass, the_rdtype)
raise RdatatypeExists(rdclass=rdclass, rdtype=rdtype) if existing_cls != GenericRdata or dns.rdatatype.is_metatype(the_rdtype):
raise RdatatypeExists(rdclass=rdclass, rdtype=the_rdtype)
try: try:
if dns.rdatatype.RdataType(rdtype).name != rdtype_text: if dns.rdatatype.RdataType(the_rdtype).name != rdtype_text:
raise RdatatypeExists(rdclass=rdclass, rdtype=rdtype) raise RdatatypeExists(rdclass=rdclass, rdtype=the_rdtype)
except ValueError: except ValueError:
pass pass
_rdata_classes[(rdclass, rdtype)] = getattr(implementation, _rdata_classes[(rdclass, the_rdtype)] = getattr(
rdtype_text.replace('-', '_')) implementation, rdtype_text.replace("-", "_")
dns.rdatatype.register_type(rdtype, rdtype_text, is_singleton) )
dns.rdatatype.register_type(the_rdtype, rdtype_text, is_singleton)

View file

@ -1,19 +0,0 @@
from typing import Dict, Tuple, Any, Optional, BinaryIO
from .name import Name, IDNACodec
class Rdata:
def __init__(self):
self.address : str
def to_wire(self, file : Optional[BinaryIO], compress : Optional[Dict[Name,int]], origin : Optional[Name], canonicalize : Optional[bool]) -> Optional[bytes]:
...
@classmethod
def from_text(cls, rdclass : int, rdtype : int, tok, origin=None, relativize=True):
...
_rdata_modules : Dict[Tuple[Any,Rdata],Any]
def from_text(rdclass : int, rdtype : int, tok : Optional[str], origin : Optional[Name] = None,
relativize : bool = True, relativize_to : Optional[Name] = None,
idna_codec : Optional[IDNACodec] = None):
...
def from_wire(rdclass : int, rdtype : int, wire : bytes, current : int, rdlen : int, origin : Optional[Name] = None):
...

View file

@ -20,8 +20,10 @@
import dns.enum import dns.enum
import dns.exception import dns.exception
class RdataClass(dns.enum.IntEnum): class RdataClass(dns.enum.IntEnum):
"""DNS Rdata Class""" """DNS Rdata Class"""
RESERVED0 = 0 RESERVED0 = 0
IN = 1 IN = 1
INTERNET = IN INTERNET = IN
@ -56,7 +58,7 @@ class UnknownRdataclass(dns.exception.DNSException):
"""A DNS class is unknown.""" """A DNS class is unknown."""
def from_text(text): def from_text(text: str) -> RdataClass:
"""Convert text into a DNS rdata class value. """Convert text into a DNS rdata class value.
The input text can be a defined DNS RR class mnemonic or The input text can be a defined DNS RR class mnemonic or
@ -68,13 +70,13 @@ def from_text(text):
Raises ``ValueError`` if the rdata class value is not >= 0 and <= 65535. Raises ``ValueError`` if the rdata class value is not >= 0 and <= 65535.
Returns an ``int``. Returns a ``dns.rdataclass.RdataClass``.
""" """
return RdataClass.from_text(text) return RdataClass.from_text(text)
def to_text(value): def to_text(value: RdataClass) -> str:
"""Convert a DNS rdata class value to text. """Convert a DNS rdata class value to text.
If the value has a known mnemonic, it will be used, otherwise the If the value has a known mnemonic, it will be used, otherwise the
@ -88,18 +90,19 @@ def to_text(value):
return RdataClass.to_text(value) return RdataClass.to_text(value)
def is_metaclass(rdclass): def is_metaclass(rdclass: RdataClass) -> bool:
"""True if the specified class is a metaclass. """True if the specified class is a metaclass.
The currently defined metaclasses are ANY and NONE. The currently defined metaclasses are ANY and NONE.
*rdclass* is an ``int``. *rdclass* is a ``dns.rdataclass.RdataClass``.
""" """
if rdclass in _metaclasses: if rdclass in _metaclasses:
return True return True
return False return False
### BEGIN generated RdataClass constants ### BEGIN generated RdataClass constants
RESERVED0 = RdataClass.RESERVED0 RESERVED0 = RdataClass.RESERVED0

View file

@ -17,16 +17,20 @@
"""DNS rdatasets (an rdataset is a set of rdatas of a given type and class)""" """DNS rdatasets (an rdataset is a set of rdatas of a given type and class)"""
from typing import Any, cast, Collection, Dict, List, Optional, Union
import io import io
import random import random
import struct import struct
import dns.exception import dns.exception
import dns.immutable import dns.immutable
import dns.name
import dns.rdatatype import dns.rdatatype
import dns.rdataclass import dns.rdataclass
import dns.rdata import dns.rdata
import dns.set import dns.set
import dns.ttl
# define SimpleSet here for backwards compatibility # define SimpleSet here for backwards compatibility
SimpleSet = dns.set.Set SimpleSet = dns.set.Set
@ -45,24 +49,30 @@ class Rdataset(dns.set.Set):
"""A DNS rdataset.""" """A DNS rdataset."""
__slots__ = ['rdclass', 'rdtype', 'covers', 'ttl'] __slots__ = ["rdclass", "rdtype", "covers", "ttl"]
def __init__(self, rdclass, rdtype, covers=dns.rdatatype.NONE, ttl=0): def __init__(
self,
rdclass: dns.rdataclass.RdataClass,
rdtype: dns.rdatatype.RdataType,
covers: dns.rdatatype.RdataType = dns.rdatatype.NONE,
ttl: int = 0,
):
"""Create a new rdataset of the specified class and type. """Create a new rdataset of the specified class and type.
*rdclass*, an ``int``, the rdataclass. *rdclass*, a ``dns.rdataclass.RdataClass``, the rdataclass.
*rdtype*, an ``int``, the rdatatype. *rdtype*, an ``dns.rdatatype.RdataType``, the rdatatype.
*covers*, an ``int``, the covered rdatatype. *covers*, an ``dns.rdatatype.RdataType``, the covered rdatatype.
*ttl*, an ``int``, the TTL. *ttl*, an ``int``, the TTL.
""" """
super().__init__() super().__init__()
self.rdclass = rdclass self.rdclass = rdclass
self.rdtype = rdtype self.rdtype: dns.rdatatype.RdataType = rdtype
self.covers = covers self.covers: dns.rdatatype.RdataType = covers
self.ttl = ttl self.ttl = ttl
def _clone(self): def _clone(self):
@ -73,7 +83,7 @@ class Rdataset(dns.set.Set):
obj.ttl = self.ttl obj.ttl = self.ttl
return obj return obj
def update_ttl(self, ttl): def update_ttl(self, ttl: int) -> None:
"""Perform TTL minimization. """Perform TTL minimization.
Set the TTL of the rdataset to be the lesser of the set's current Set the TTL of the rdataset to be the lesser of the set's current
@ -88,7 +98,9 @@ class Rdataset(dns.set.Set):
elif ttl < self.ttl: elif ttl < self.ttl:
self.ttl = ttl self.ttl = ttl
def add(self, rd, ttl=None): # pylint: disable=arguments-differ def add( # pylint: disable=arguments-differ,arguments-renamed
self, rd: dns.rdata.Rdata, ttl: Optional[int] = None
) -> None:
"""Add the specified rdata to the rdataset. """Add the specified rdata to the rdataset.
If the optional *ttl* parameter is supplied, then If the optional *ttl* parameter is supplied, then
@ -115,8 +127,7 @@ class Rdataset(dns.set.Set):
raise IncompatibleTypes raise IncompatibleTypes
if ttl is not None: if ttl is not None:
self.update_ttl(ttl) self.update_ttl(ttl)
if self.rdtype == dns.rdatatype.RRSIG or \ if self.rdtype == dns.rdatatype.RRSIG or self.rdtype == dns.rdatatype.SIG:
self.rdtype == dns.rdatatype.SIG:
covers = rd.covers() covers = rd.covers()
if len(self) == 0 and self.covers == dns.rdatatype.NONE: if len(self) == 0 and self.covers == dns.rdatatype.NONE:
self.covers = covers self.covers = covers
@ -147,19 +158,26 @@ class Rdataset(dns.set.Set):
def _rdata_repr(self): def _rdata_repr(self):
def maybe_truncate(s): def maybe_truncate(s):
if len(s) > 100: if len(s) > 100:
return s[:100] + '...' return s[:100] + "..."
return s return s
return '[%s]' % ', '.join('<%s>' % maybe_truncate(str(rr))
for rr in self) return "[%s]" % ", ".join("<%s>" % maybe_truncate(str(rr)) for rr in self)
def __repr__(self): def __repr__(self):
if self.covers == 0: if self.covers == 0:
ctext = '' ctext = ""
else: else:
ctext = '(' + dns.rdatatype.to_text(self.covers) + ')' ctext = "(" + dns.rdatatype.to_text(self.covers) + ")"
return '<DNS ' + dns.rdataclass.to_text(self.rdclass) + ' ' + \ return (
dns.rdatatype.to_text(self.rdtype) + ctext + \ "<DNS "
' rdataset: ' + self._rdata_repr() + '>' + dns.rdataclass.to_text(self.rdclass)
+ " "
+ dns.rdatatype.to_text(self.rdtype)
+ ctext
+ " rdataset: "
+ self._rdata_repr()
+ ">"
)
def __str__(self): def __str__(self):
return self.to_text() return self.to_text()
@ -167,17 +185,26 @@ class Rdataset(dns.set.Set):
def __eq__(self, other): def __eq__(self, other):
if not isinstance(other, Rdataset): if not isinstance(other, Rdataset):
return False return False
if self.rdclass != other.rdclass or \ if (
self.rdtype != other.rdtype or \ self.rdclass != other.rdclass
self.covers != other.covers: or self.rdtype != other.rdtype
or self.covers != other.covers
):
return False return False
return super().__eq__(other) return super().__eq__(other)
def __ne__(self, other): def __ne__(self, other):
return not self.__eq__(other) return not self.__eq__(other)
def to_text(self, name=None, origin=None, relativize=True, def to_text(
override_rdclass=None, want_comments=False, **kw): self,
name: Optional[dns.name.Name] = None,
origin: Optional[dns.name.Name] = None,
relativize: bool = True,
override_rdclass: Optional[dns.rdataclass.RdataClass] = None,
want_comments: bool = False,
**kw: Dict[str, Any],
) -> str:
"""Convert the rdataset into DNS zone file format. """Convert the rdataset into DNS zone file format.
See ``dns.name.Name.choose_relativity`` for more information See ``dns.name.Name.choose_relativity`` for more information
@ -206,10 +233,10 @@ class Rdataset(dns.set.Set):
if name is not None: if name is not None:
name = name.choose_relativity(origin, relativize) name = name.choose_relativity(origin, relativize)
ntext = str(name) ntext = str(name)
pad = ' ' pad = " "
else: else:
ntext = '' ntext = ""
pad = '' pad = ""
s = io.StringIO() s = io.StringIO()
if override_rdclass is not None: if override_rdclass is not None:
rdclass = override_rdclass rdclass = override_rdclass
@ -221,28 +248,46 @@ class Rdataset(dns.set.Set):
# some dynamic updates, so we don't need to print out the TTL # some dynamic updates, so we don't need to print out the TTL
# (which is meaningless anyway). # (which is meaningless anyway).
# #
s.write('{}{}{} {}\n'.format(ntext, pad, s.write(
"{}{}{} {}\n".format(
ntext,
pad,
dns.rdataclass.to_text(rdclass), dns.rdataclass.to_text(rdclass),
dns.rdatatype.to_text(self.rdtype))) dns.rdatatype.to_text(self.rdtype),
)
)
else: else:
for rd in self: for rd in self:
extra = '' extra = ""
if want_comments: if want_comments:
if rd.rdcomment: if rd.rdcomment:
extra = f' ;{rd.rdcomment}' extra = f" ;{rd.rdcomment}"
s.write('%s%s%d %s %s %s%s\n' % s.write(
(ntext, pad, self.ttl, dns.rdataclass.to_text(rdclass), "%s%s%d %s %s %s%s\n"
% (
ntext,
pad,
self.ttl,
dns.rdataclass.to_text(rdclass),
dns.rdatatype.to_text(self.rdtype), dns.rdatatype.to_text(self.rdtype),
rd.to_text(origin=origin, relativize=relativize, rd.to_text(origin=origin, relativize=relativize, **kw),
**kw), extra,
extra)) )
)
# #
# We strip off the final \n for the caller's convenience in printing # We strip off the final \n for the caller's convenience in printing
# #
return s.getvalue()[:-1] return s.getvalue()[:-1]
def to_wire(self, name, file, compress=None, origin=None, def to_wire(
override_rdclass=None, want_shuffle=True): self,
name: dns.name.Name,
file: Any,
compress: Optional[dns.name.CompressType] = None,
origin: Optional[dns.name.Name] = None,
override_rdclass: Optional[dns.rdataclass.RdataClass] = None,
want_shuffle: bool = True,
) -> int:
"""Convert the rdataset to wire format. """Convert the rdataset to wire format.
*name*, a ``dns.name.Name`` is the owner name to use. *name*, a ``dns.name.Name`` is the owner name to use.
@ -279,6 +324,7 @@ class Rdataset(dns.set.Set):
file.write(stuff) file.write(stuff)
return 1 return 1
else: else:
l: Union[Rdataset, List[dns.rdata.Rdata]]
if want_shuffle: if want_shuffle:
l = list(self) l = list(self)
random.shuffle(l) random.shuffle(l)
@ -286,8 +332,7 @@ class Rdataset(dns.set.Set):
l = self l = self
for rd in l: for rd in l:
name.to_wire(file, compress, origin) name.to_wire(file, compress, origin)
stuff = struct.pack("!HHIH", self.rdtype, rdclass, stuff = struct.pack("!HHIH", self.rdtype, rdclass, self.ttl, 0)
self.ttl, 0)
file.write(stuff) file.write(stuff)
start = file.tell() start = file.tell()
rd.to_wire(file, compress, origin) rd.to_wire(file, compress, origin)
@ -299,17 +344,20 @@ class Rdataset(dns.set.Set):
file.seek(0, io.SEEK_END) file.seek(0, io.SEEK_END)
return len(self) return len(self)
def match(self, rdclass, rdtype, covers): def match(
self,
rdclass: dns.rdataclass.RdataClass,
rdtype: dns.rdatatype.RdataType,
covers: dns.rdatatype.RdataType,
) -> bool:
"""Returns ``True`` if this rdataset matches the specified class, """Returns ``True`` if this rdataset matches the specified class,
type, and covers. type, and covers.
""" """
if self.rdclass == rdclass and \ if self.rdclass == rdclass and self.rdtype == rdtype and self.covers == covers:
self.rdtype == rdtype and \
self.covers == covers:
return True return True
return False return False
def processing_order(self): def processing_order(self) -> List[dns.rdata.Rdata]:
"""Return rdatas in a valid processing order according to the type's """Return rdatas in a valid processing order according to the type's
specification. For example, MX records are in preference order from specification. For example, MX records are in preference order from
lowest to highest preferences, with items of the same preference lowest to highest preferences, with items of the same preference
@ -325,51 +373,56 @@ class Rdataset(dns.set.Set):
@dns.immutable.immutable @dns.immutable.immutable
class ImmutableRdataset(Rdataset): class ImmutableRdataset(Rdataset): # lgtm[py/missing-equals]
"""An immutable DNS rdataset.""" """An immutable DNS rdataset."""
_clone_class = Rdataset _clone_class = Rdataset
def __init__(self, rdataset): def __init__(self, rdataset: Rdataset):
"""Create an immutable rdataset from the specified rdataset.""" """Create an immutable rdataset from the specified rdataset."""
super().__init__(rdataset.rdclass, rdataset.rdtype, rdataset.covers, super().__init__(
rdataset.ttl) rdataset.rdclass, rdataset.rdtype, rdataset.covers, rdataset.ttl
)
self.items = dns.immutable.Dict(rdataset.items) self.items = dns.immutable.Dict(rdataset.items)
def update_ttl(self, ttl): def update_ttl(self, ttl):
raise TypeError('immutable') raise TypeError("immutable")
def add(self, rd, ttl=None): def add(self, rd, ttl=None):
raise TypeError('immutable') raise TypeError("immutable")
def union_update(self, other): def union_update(self, other):
raise TypeError('immutable') raise TypeError("immutable")
def intersection_update(self, other): def intersection_update(self, other):
raise TypeError('immutable') raise TypeError("immutable")
def update(self, other): def update(self, other):
raise TypeError('immutable') raise TypeError("immutable")
def __delitem__(self, i): def __delitem__(self, i):
raise TypeError('immutable') raise TypeError("immutable")
def __ior__(self, other): # lgtm complains about these not raising ArithmeticError, but there is
raise TypeError('immutable') # precedent for overrides of these methods in other classes to raise
# TypeError, and it seems like the better exception.
def __iand__(self, other): def __ior__(self, other): # lgtm[py/unexpected-raise-in-special-method]
raise TypeError('immutable') raise TypeError("immutable")
def __iadd__(self, other): def __iand__(self, other): # lgtm[py/unexpected-raise-in-special-method]
raise TypeError('immutable') raise TypeError("immutable")
def __isub__(self, other): def __iadd__(self, other): # lgtm[py/unexpected-raise-in-special-method]
raise TypeError('immutable') raise TypeError("immutable")
def __isub__(self, other): # lgtm[py/unexpected-raise-in-special-method]
raise TypeError("immutable")
def clear(self): def clear(self):
raise TypeError('immutable') raise TypeError("immutable")
def __copy__(self): def __copy__(self):
return ImmutableRdataset(super().copy()) return ImmutableRdataset(super().copy())
@ -386,9 +439,20 @@ class ImmutableRdataset(Rdataset):
def difference(self, other): def difference(self, other):
return ImmutableRdataset(super().difference(other)) return ImmutableRdataset(super().difference(other))
def symmetric_difference(self, other):
return ImmutableRdataset(super().symmetric_difference(other))
def from_text_list(rdclass, rdtype, ttl, text_rdatas, idna_codec=None,
origin=None, relativize=True, relativize_to=None): def from_text_list(
rdclass: Union[dns.rdataclass.RdataClass, str],
rdtype: Union[dns.rdatatype.RdataType, str],
ttl: int,
text_rdatas: Collection[str],
idna_codec: Optional[dns.name.IDNACodec] = None,
origin: Optional[dns.name.Name] = None,
relativize: bool = True,
relativize_to: Optional[dns.name.Name] = None,
) -> Rdataset:
"""Create an rdataset with the specified class, type, and TTL, and with """Create an rdataset with the specified class, type, and TTL, and with
the specified list of rdatas in text format. the specified list of rdatas in text format.
@ -407,28 +471,34 @@ def from_text_list(rdclass, rdtype, ttl, text_rdatas, idna_codec=None,
Returns a ``dns.rdataset.Rdataset`` object. Returns a ``dns.rdataset.Rdataset`` object.
""" """
rdclass = dns.rdataclass.RdataClass.make(rdclass) the_rdclass = dns.rdataclass.RdataClass.make(rdclass)
rdtype = dns.rdatatype.RdataType.make(rdtype) the_rdtype = dns.rdatatype.RdataType.make(rdtype)
r = Rdataset(rdclass, rdtype) r = Rdataset(the_rdclass, the_rdtype)
r.update_ttl(ttl) r.update_ttl(ttl)
for t in text_rdatas: for t in text_rdatas:
rd = dns.rdata.from_text(r.rdclass, r.rdtype, t, origin, relativize, rd = dns.rdata.from_text(
relativize_to, idna_codec) r.rdclass, r.rdtype, t, origin, relativize, relativize_to, idna_codec
)
r.add(rd) r.add(rd)
return r return r
def from_text(rdclass, rdtype, ttl, *text_rdatas): def from_text(
rdclass: Union[dns.rdataclass.RdataClass, str],
rdtype: Union[dns.rdatatype.RdataType, str],
ttl: int,
*text_rdatas: Any,
) -> Rdataset:
"""Create an rdataset with the specified class, type, and TTL, and with """Create an rdataset with the specified class, type, and TTL, and with
the specified rdatas in text format. the specified rdatas in text format.
Returns a ``dns.rdataset.Rdataset`` object. Returns a ``dns.rdataset.Rdataset`` object.
""" """
return from_text_list(rdclass, rdtype, ttl, text_rdatas) return from_text_list(rdclass, rdtype, ttl, cast(Collection[str], text_rdatas))
def from_rdata_list(ttl, rdatas): def from_rdata_list(ttl: int, rdatas: Collection[dns.rdata.Rdata]) -> Rdataset:
"""Create an rdataset with the specified TTL, and with """Create an rdataset with the specified TTL, and with
the specified list of rdata objects. the specified list of rdata objects.
@ -443,14 +513,15 @@ def from_rdata_list(ttl, rdatas):
r = Rdataset(rd.rdclass, rd.rdtype) r = Rdataset(rd.rdclass, rd.rdtype)
r.update_ttl(ttl) r.update_ttl(ttl)
r.add(rd) r.add(rd)
assert r is not None
return r return r
def from_rdata(ttl, *rdatas): def from_rdata(ttl: int, *rdatas: Any) -> Rdataset:
"""Create an rdataset with the specified TTL, and with """Create an rdataset with the specified TTL, and with
the specified rdata objects. the specified rdata objects.
Returns a ``dns.rdataset.Rdataset`` object. Returns a ``dns.rdataset.Rdataset`` object.
""" """
return from_rdata_list(ttl, rdatas) return from_rdata_list(ttl, cast(Collection[dns.rdata.Rdata], rdatas))

View file

@ -1,58 +0,0 @@
from typing import Optional, Dict, List, Union
from io import BytesIO
from . import exception, name, set, rdatatype, rdata, rdataset
class DifferingCovers(exception.DNSException):
"""An attempt was made to add a DNS SIG/RRSIG whose covered type
is not the same as that of the other rdatas in the rdataset."""
class IncompatibleTypes(exception.DNSException):
"""An attempt was made to add DNS RR data of an incompatible type."""
class Rdataset(set.Set):
def __init__(self, rdclass, rdtype, covers=rdatatype.NONE, ttl=0):
self.rdclass : int = rdclass
self.rdtype : int = rdtype
self.covers : int = covers
self.ttl : int = ttl
def update_ttl(self, ttl : int) -> None:
...
def add(self, rd : rdata.Rdata, ttl : Optional[int] =None):
...
def union_update(self, other : Rdataset):
...
def intersection_update(self, other : Rdataset):
...
def update(self, other : Rdataset):
...
def to_text(self, name : Optional[name.Name] =None, origin : Optional[name.Name] =None, relativize=True,
override_rdclass : Optional[int] =None, **kw) -> bytes:
...
def to_wire(self, name : Optional[name.Name], file : BytesIO, compress : Optional[Dict[name.Name, int]] = None, origin : Optional[name.Name] = None,
override_rdclass : Optional[int] = None, want_shuffle=True) -> int:
...
def match(self, rdclass : int, rdtype : int, covers : int) -> bool:
...
def from_text_list(rdclass : Union[int,str], rdtype : Union[int,str], ttl : int, text_rdatas : str, idna_codec : Optional[name.IDNACodec] = None) -> rdataset.Rdataset:
...
def from_text(rdclass : Union[int,str], rdtype : Union[int,str], ttl : int, *text_rdatas : str) -> rdataset.Rdataset:
...
def from_rdata_list(ttl : int, rdatas : List[rdata.Rdata]) -> rdataset.Rdataset:
...
def from_rdata(ttl : int, *rdatas : List[rdata.Rdata]) -> rdataset.Rdataset:
...

View file

@ -17,11 +17,15 @@
"""DNS Rdata Types.""" """DNS Rdata Types."""
from typing import Dict
import dns.enum import dns.enum
import dns.exception import dns.exception
class RdataType(dns.enum.IntEnum): class RdataType(dns.enum.IntEnum):
"""DNS Rdata Type""" """DNS Rdata Type"""
TYPE0 = 0 TYPE0 = 0
NONE = 0 NONE = 0
A = 1 A = 1
@ -116,24 +120,47 @@ class RdataType(dns.enum.IntEnum):
def _prefix(cls): def _prefix(cls):
return "TYPE" return "TYPE"
@classmethod
def _extra_from_text(cls, text):
if text.find("-") >= 0:
try:
return cls[text.replace("-", "_")]
except KeyError:
pass
return _registered_by_text.get(text)
@classmethod
def _extra_to_text(cls, value, current_text):
if current_text is None:
return _registered_by_value.get(value)
if current_text.find("_") >= 0:
return current_text.replace("_", "-")
return current_text
@classmethod @classmethod
def _unknown_exception_class(cls): def _unknown_exception_class(cls):
return UnknownRdatatype return UnknownRdatatype
_registered_by_text = {}
_registered_by_value = {} _registered_by_text: Dict[str, RdataType] = {}
_registered_by_value: Dict[RdataType, str] = {}
_metatypes = {RdataType.OPT} _metatypes = {RdataType.OPT}
_singletons = {RdataType.SOA, RdataType.NXT, RdataType.DNAME, _singletons = {
RdataType.NSEC, RdataType.CNAME} RdataType.SOA,
RdataType.NXT,
RdataType.DNAME,
RdataType.NSEC,
RdataType.CNAME,
}
class UnknownRdatatype(dns.exception.DNSException): class UnknownRdatatype(dns.exception.DNSException):
"""DNS resource record type is unknown.""" """DNS resource record type is unknown."""
def from_text(text): def from_text(text: str) -> RdataType:
"""Convert text into a DNS rdata type value. """Convert text into a DNS rdata type value.
The input text can be a defined DNS RR type mnemonic or The input text can be a defined DNS RR type mnemonic or
@ -145,20 +172,13 @@ def from_text(text):
Raises ``ValueError`` if the rdata type value is not >= 0 and <= 65535. Raises ``ValueError`` if the rdata type value is not >= 0 and <= 65535.
Returns an ``int``. Returns a ``dns.rdatatype.RdataType``.
""" """
text = text.upper().replace('-', '_')
try:
return RdataType.from_text(text) return RdataType.from_text(text)
except UnknownRdatatype:
registered_type = _registered_by_text.get(text)
if registered_type:
return registered_type
raise
def to_text(value): def to_text(value: RdataType) -> str:
"""Convert a DNS rdata type value to text. """Convert a DNS rdata type value to text.
If the value has a known mnemonic, it will be used, otherwise the If the value has a known mnemonic, it will be used, otherwise the
@ -169,18 +189,13 @@ def to_text(value):
Returns a ``str``. Returns a ``str``.
""" """
text = RdataType.to_text(value) return RdataType.to_text(value)
if text.startswith("TYPE"):
registered_text = _registered_by_value.get(value)
if registered_text:
text = registered_text
return text.replace('_', '-')
def is_metatype(rdtype): def is_metatype(rdtype: RdataType) -> bool:
"""True if the specified type is a metatype. """True if the specified type is a metatype.
*rdtype* is an ``int``. *rdtype* is a ``dns.rdatatype.RdataType``.
The currently defined metatypes are TKEY, TSIG, IXFR, AXFR, MAILA, The currently defined metatypes are TKEY, TSIG, IXFR, AXFR, MAILA,
MAILB, ANY, and OPT. MAILB, ANY, and OPT.
@ -191,7 +206,7 @@ def is_metatype(rdtype):
return (256 > rdtype >= 128) or rdtype in _metatypes return (256 > rdtype >= 128) or rdtype in _metatypes
def is_singleton(rdtype): def is_singleton(rdtype: RdataType) -> bool:
"""Is the specified type a singleton type? """Is the specified type a singleton type?
Singleton types can only have a single rdata in an rdataset, or a single Singleton types can only have a single rdata in an rdataset, or a single
@ -209,11 +224,14 @@ def is_singleton(rdtype):
return True return True
return False return False
# pylint: disable=redefined-outer-name # pylint: disable=redefined-outer-name
def register_type(rdtype, rdtype_text, is_singleton=False): def register_type(
rdtype: RdataType, rdtype_text: str, is_singleton: bool = False
) -> None:
"""Dynamically register an rdatatype. """Dynamically register an rdatatype.
*rdtype*, an ``int``, the rdatatype to register. *rdtype*, a ``dns.rdatatype.RdataType``, the rdatatype to register.
*rdtype_text*, a ``str``, the textual form of the rdatatype. *rdtype_text*, a ``str``, the textual form of the rdatatype.
@ -226,6 +244,7 @@ def register_type(rdtype, rdtype_text, is_singleton=False):
if is_singleton: if is_singleton:
_singletons.add(rdtype) _singletons.add(rdtype)
### BEGIN generated RdataType constants ### BEGIN generated RdataType constants
TYPE0 = RdataType.TYPE0 TYPE0 = RdataType.TYPE0

View file

@ -23,7 +23,7 @@ import dns.rdtypes.util
class Relay(dns.rdtypes.util.Gateway): class Relay(dns.rdtypes.util.Gateway):
name = 'AMTRELAY relay' name = "AMTRELAY relay"
@property @property
def relay(self): def relay(self):
@ -37,10 +37,11 @@ class AMTRELAY(dns.rdata.Rdata):
# see: RFC 8777 # see: RFC 8777
__slots__ = ['precedence', 'discovery_optional', 'relay_type', 'relay'] __slots__ = ["precedence", "discovery_optional", "relay_type", "relay"]
def __init__(self, rdclass, rdtype, precedence, discovery_optional, def __init__(
relay_type, relay): self, rdclass, rdtype, precedence, discovery_optional, relay_type, relay
):
super().__init__(rdclass, rdtype) super().__init__(rdclass, rdtype)
relay = Relay(relay_type, relay) relay = Relay(relay_type, relay)
self.precedence = self._as_uint8(precedence) self.precedence = self._as_uint8(precedence)
@ -50,37 +51,42 @@ class AMTRELAY(dns.rdata.Rdata):
def to_text(self, origin=None, relativize=True, **kw): def to_text(self, origin=None, relativize=True, **kw):
relay = Relay(self.relay_type, self.relay).to_text(origin, relativize) relay = Relay(self.relay_type, self.relay).to_text(origin, relativize)
return '%d %d %d %s' % (self.precedence, self.discovery_optional, return "%d %d %d %s" % (
self.relay_type, relay) self.precedence,
self.discovery_optional,
self.relay_type,
relay,
)
@classmethod @classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, def from_text(
relativize_to=None): cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
precedence = tok.get_uint8() precedence = tok.get_uint8()
discovery_optional = tok.get_uint8() discovery_optional = tok.get_uint8()
if discovery_optional > 1: if discovery_optional > 1:
raise dns.exception.SyntaxError('expecting 0 or 1') raise dns.exception.SyntaxError("expecting 0 or 1")
discovery_optional = bool(discovery_optional) discovery_optional = bool(discovery_optional)
relay_type = tok.get_uint8() relay_type = tok.get_uint8()
if relay_type > 0x7f: if relay_type > 0x7F:
raise dns.exception.SyntaxError('expecting an integer <= 127') raise dns.exception.SyntaxError("expecting an integer <= 127")
relay = Relay.from_text(relay_type, tok, origin, relativize, relay = Relay.from_text(relay_type, tok, origin, relativize, relativize_to)
relativize_to) return cls(
return cls(rdclass, rdtype, precedence, discovery_optional, relay_type, rdclass, rdtype, precedence, discovery_optional, relay_type, relay.relay
relay.relay) )
def _to_wire(self, file, compress=None, origin=None, canonicalize=False): def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
relay_type = self.relay_type | (self.discovery_optional << 7) relay_type = self.relay_type | (self.discovery_optional << 7)
header = struct.pack("!BB", self.precedence, relay_type) header = struct.pack("!BB", self.precedence, relay_type)
file.write(header) file.write(header)
Relay(self.relay_type, self.relay).to_wire(file, compress, origin, Relay(self.relay_type, self.relay).to_wire(file, compress, origin, canonicalize)
canonicalize)
@classmethod @classmethod
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
(precedence, relay_type) = parser.get_struct('!BB') (precedence, relay_type) = parser.get_struct("!BB")
discovery_optional = bool(relay_type >> 7) discovery_optional = bool(relay_type >> 7)
relay_type &= 0x7f relay_type &= 0x7F
relay = Relay.from_wire_parser(relay_type, parser, origin) relay = Relay.from_wire_parser(relay_type, parser, origin)
return cls(rdclass, rdtype, precedence, discovery_optional, relay_type, return cls(
relay.relay) rdclass, rdtype, precedence, discovery_optional, relay_type, relay.relay
)

View file

@ -30,7 +30,7 @@ class CAA(dns.rdata.Rdata):
# see: RFC 6844 # see: RFC 6844
__slots__ = ['flags', 'tag', 'value'] __slots__ = ["flags", "tag", "value"]
def __init__(self, rdclass, rdtype, flags, tag, value): def __init__(self, rdclass, rdtype, flags, tag, value):
super().__init__(rdclass, rdtype) super().__init__(rdclass, rdtype)
@ -41,23 +41,26 @@ class CAA(dns.rdata.Rdata):
self.value = self._as_bytes(value) self.value = self._as_bytes(value)
def to_text(self, origin=None, relativize=True, **kw): def to_text(self, origin=None, relativize=True, **kw):
return '%u %s "%s"' % (self.flags, return '%u %s "%s"' % (
self.flags,
dns.rdata._escapify(self.tag), dns.rdata._escapify(self.tag),
dns.rdata._escapify(self.value)) dns.rdata._escapify(self.value),
)
@classmethod @classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, def from_text(
relativize_to=None): cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
flags = tok.get_uint8() flags = tok.get_uint8()
tag = tok.get_string().encode() tag = tok.get_string().encode()
value = tok.get_string().encode() value = tok.get_string().encode()
return cls(rdclass, rdtype, flags, tag, value) return cls(rdclass, rdtype, flags, tag, value)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False): def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
file.write(struct.pack('!B', self.flags)) file.write(struct.pack("!B", self.flags))
l = len(self.tag) l = len(self.tag)
assert l < 256 assert l < 256
file.write(struct.pack('!B', l)) file.write(struct.pack("!B", l))
file.write(self.tag) file.write(self.tag)
file.write(self.value) file.write(self.value)

View file

@ -15,13 +15,19 @@
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import dns.rdtypes.dnskeybase import dns.rdtypes.dnskeybase # lgtm[py/import-and-import-from]
import dns.immutable import dns.immutable
# pylint: disable=unused-import # pylint: disable=unused-import
from dns.rdtypes.dnskeybase import SEP, REVOKE, ZONE # noqa: F401 from dns.rdtypes.dnskeybase import (
SEP,
REVOKE,
ZONE,
) # noqa: F401 lgtm[py/unused-import]
# pylint: enable=unused-import # pylint: enable=unused-import
@dns.immutable.immutable @dns.immutable.immutable
class CDNSKEY(dns.rdtypes.dnskeybase.DNSKEYBase): class CDNSKEY(dns.rdtypes.dnskeybase.DNSKEYBase):

View file

@ -20,34 +20,34 @@ import base64
import dns.exception import dns.exception
import dns.immutable import dns.immutable
import dns.dnssec import dns.dnssectypes
import dns.rdata import dns.rdata
import dns.tokenizer import dns.tokenizer
_ctype_by_value = { _ctype_by_value = {
1: 'PKIX', 1: "PKIX",
2: 'SPKI', 2: "SPKI",
3: 'PGP', 3: "PGP",
4: 'IPKIX', 4: "IPKIX",
5: 'ISPKI', 5: "ISPKI",
6: 'IPGP', 6: "IPGP",
7: 'ACPKIX', 7: "ACPKIX",
8: 'IACPKIX', 8: "IACPKIX",
253: 'URI', 253: "URI",
254: 'OID', 254: "OID",
} }
_ctype_by_name = { _ctype_by_name = {
'PKIX': 1, "PKIX": 1,
'SPKI': 2, "SPKI": 2,
'PGP': 3, "PGP": 3,
'IPKIX': 4, "IPKIX": 4,
'ISPKI': 5, "ISPKI": 5,
'IPGP': 6, "IPGP": 6,
'ACPKIX': 7, "ACPKIX": 7,
'IACPKIX': 8, "IACPKIX": 8,
'URI': 253, "URI": 253,
'OID': 254, "OID": 254,
} }
@ -72,10 +72,11 @@ class CERT(dns.rdata.Rdata):
# see RFC 4398 # see RFC 4398
__slots__ = ['certificate_type', 'key_tag', 'algorithm', 'certificate'] __slots__ = ["certificate_type", "key_tag", "algorithm", "certificate"]
def __init__(self, rdclass, rdtype, certificate_type, key_tag, algorithm, def __init__(
certificate): self, rdclass, rdtype, certificate_type, key_tag, algorithm, certificate
):
super().__init__(rdclass, rdtype) super().__init__(rdclass, rdtype)
self.certificate_type = self._as_uint16(certificate_type) self.certificate_type = self._as_uint16(certificate_type)
self.key_tag = self._as_uint16(key_tag) self.key_tag = self._as_uint16(key_tag)
@ -84,24 +85,28 @@ class CERT(dns.rdata.Rdata):
def to_text(self, origin=None, relativize=True, **kw): def to_text(self, origin=None, relativize=True, **kw):
certificate_type = _ctype_to_text(self.certificate_type) certificate_type = _ctype_to_text(self.certificate_type)
return "%s %d %s %s" % (certificate_type, self.key_tag, return "%s %d %s %s" % (
dns.dnssec.algorithm_to_text(self.algorithm), certificate_type,
dns.rdata._base64ify(self.certificate, **kw)) self.key_tag,
dns.dnssectypes.Algorithm.to_text(self.algorithm),
dns.rdata._base64ify(self.certificate, **kw),
)
@classmethod @classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, def from_text(
relativize_to=None): cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
certificate_type = _ctype_from_text(tok.get_string()) certificate_type = _ctype_from_text(tok.get_string())
key_tag = tok.get_uint16() key_tag = tok.get_uint16()
algorithm = dns.dnssec.algorithm_from_text(tok.get_string()) algorithm = dns.dnssectypes.Algorithm.from_text(tok.get_string())
b64 = tok.concatenate_remaining_identifiers().encode() b64 = tok.concatenate_remaining_identifiers().encode()
certificate = base64.b64decode(b64) certificate = base64.b64decode(b64)
return cls(rdclass, rdtype, certificate_type, key_tag, return cls(rdclass, rdtype, certificate_type, key_tag, algorithm, certificate)
algorithm, certificate)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False): def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
prefix = struct.pack("!HHB", self.certificate_type, self.key_tag, prefix = struct.pack(
self.algorithm) "!HHB", self.certificate_type, self.key_tag, self.algorithm
)
file.write(prefix) file.write(prefix)
file.write(self.certificate) file.write(self.certificate)
@ -109,5 +114,4 @@ class CERT(dns.rdata.Rdata):
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
(certificate_type, key_tag, algorithm) = parser.get_struct("!HHB") (certificate_type, key_tag, algorithm) = parser.get_struct("!HHB")
certificate = parser.get_remaining() certificate = parser.get_remaining()
return cls(rdclass, rdtype, certificate_type, key_tag, algorithm, return cls(rdclass, rdtype, certificate_type, key_tag, algorithm, certificate)
certificate)

View file

@ -27,7 +27,7 @@ import dns.rdtypes.util
@dns.immutable.immutable @dns.immutable.immutable
class Bitmap(dns.rdtypes.util.Bitmap): class Bitmap(dns.rdtypes.util.Bitmap):
type_name = 'CSYNC' type_name = "CSYNC"
@dns.immutable.immutable @dns.immutable.immutable
@ -35,7 +35,7 @@ class CSYNC(dns.rdata.Rdata):
"""CSYNC record""" """CSYNC record"""
__slots__ = ['serial', 'flags', 'windows'] __slots__ = ["serial", "flags", "windows"]
def __init__(self, rdclass, rdtype, serial, flags, windows): def __init__(self, rdclass, rdtype, serial, flags, windows):
super().__init__(rdclass, rdtype) super().__init__(rdclass, rdtype)
@ -47,18 +47,19 @@ class CSYNC(dns.rdata.Rdata):
def to_text(self, origin=None, relativize=True, **kw): def to_text(self, origin=None, relativize=True, **kw):
text = Bitmap(self.windows).to_text() text = Bitmap(self.windows).to_text()
return '%d %d%s' % (self.serial, self.flags, text) return "%d %d%s" % (self.serial, self.flags, text)
@classmethod @classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, def from_text(
relativize_to=None): cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
serial = tok.get_uint32() serial = tok.get_uint32()
flags = tok.get_uint16() flags = tok.get_uint16()
bitmap = Bitmap.from_text(tok) bitmap = Bitmap.from_text(tok)
return cls(rdclass, rdtype, serial, flags, bitmap) return cls(rdclass, rdtype, serial, flags, bitmap)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False): def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
file.write(struct.pack('!IH', self.serial, self.flags)) file.write(struct.pack("!IH", self.serial, self.flags))
Bitmap(self.windows).to_wire(file) Bitmap(self.windows).to_wire(file)
@classmethod @classmethod

View file

@ -15,13 +15,19 @@
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import dns.rdtypes.dnskeybase import dns.rdtypes.dnskeybase # lgtm[py/import-and-import-from]
import dns.immutable import dns.immutable
# pylint: disable=unused-import # pylint: disable=unused-import
from dns.rdtypes.dnskeybase import SEP, REVOKE, ZONE # noqa: F401 from dns.rdtypes.dnskeybase import (
SEP,
REVOKE,
ZONE,
) # noqa: F401 lgtm[py/unused-import]
# pylint: enable=unused-import # pylint: enable=unused-import
@dns.immutable.immutable @dns.immutable.immutable
class DNSKEY(dns.rdtypes.dnskeybase.DNSKEYBase): class DNSKEY(dns.rdtypes.dnskeybase.DNSKEYBase):

View file

@ -26,19 +26,19 @@ import dns.tokenizer
def _validate_float_string(what): def _validate_float_string(what):
if len(what) == 0: if len(what) == 0:
raise dns.exception.FormError raise dns.exception.FormError
if what[0] == b'-'[0] or what[0] == b'+'[0]: if what[0] == b"-"[0] or what[0] == b"+"[0]:
what = what[1:] what = what[1:]
if what.isdigit(): if what.isdigit():
return return
try: try:
(left, right) = what.split(b'.') (left, right) = what.split(b".")
except ValueError: except ValueError:
raise dns.exception.FormError raise dns.exception.FormError
if left == b'' and right == b'': if left == b"" and right == b"":
raise dns.exception.FormError raise dns.exception.FormError
if not left == b'' and not left.decode().isdigit(): if not left == b"" and not left.decode().isdigit():
raise dns.exception.FormError raise dns.exception.FormError
if not right == b'' and not right.decode().isdigit(): if not right == b"" and not right.decode().isdigit():
raise dns.exception.FormError raise dns.exception.FormError
@ -49,18 +49,15 @@ class GPOS(dns.rdata.Rdata):
# see: RFC 1712 # see: RFC 1712
__slots__ = ['latitude', 'longitude', 'altitude'] __slots__ = ["latitude", "longitude", "altitude"]
def __init__(self, rdclass, rdtype, latitude, longitude, altitude): def __init__(self, rdclass, rdtype, latitude, longitude, altitude):
super().__init__(rdclass, rdtype) super().__init__(rdclass, rdtype)
if isinstance(latitude, float) or \ if isinstance(latitude, float) or isinstance(latitude, int):
isinstance(latitude, int):
latitude = str(latitude) latitude = str(latitude)
if isinstance(longitude, float) or \ if isinstance(longitude, float) or isinstance(longitude, int):
isinstance(longitude, int):
longitude = str(longitude) longitude = str(longitude)
if isinstance(altitude, float) or \ if isinstance(altitude, float) or isinstance(altitude, int):
isinstance(altitude, int):
altitude = str(altitude) altitude = str(altitude)
latitude = self._as_bytes(latitude, True, 255) latitude = self._as_bytes(latitude, True, 255)
longitude = self._as_bytes(longitude, True, 255) longitude = self._as_bytes(longitude, True, 255)
@ -73,19 +70,20 @@ class GPOS(dns.rdata.Rdata):
self.altitude = altitude self.altitude = altitude
flat = self.float_latitude flat = self.float_latitude
if flat < -90.0 or flat > 90.0: if flat < -90.0 or flat > 90.0:
raise dns.exception.FormError('bad latitude') raise dns.exception.FormError("bad latitude")
flong = self.float_longitude flong = self.float_longitude
if flong < -180.0 or flong > 180.0: if flong < -180.0 or flong > 180.0:
raise dns.exception.FormError('bad longitude') raise dns.exception.FormError("bad longitude")
def to_text(self, origin=None, relativize=True, **kw): def to_text(self, origin=None, relativize=True, **kw):
return '{} {} {}'.format(self.latitude.decode(), return "{} {} {}".format(
self.longitude.decode(), self.latitude.decode(), self.longitude.decode(), self.altitude.decode()
self.altitude.decode()) )
@classmethod @classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, def from_text(
relativize_to=None): cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
latitude = tok.get_string() latitude = tok.get_string()
longitude = tok.get_string() longitude = tok.get_string()
altitude = tok.get_string() altitude = tok.get_string()
@ -94,15 +92,15 @@ class GPOS(dns.rdata.Rdata):
def _to_wire(self, file, compress=None, origin=None, canonicalize=False): def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
l = len(self.latitude) l = len(self.latitude)
assert l < 256 assert l < 256
file.write(struct.pack('!B', l)) file.write(struct.pack("!B", l))
file.write(self.latitude) file.write(self.latitude)
l = len(self.longitude) l = len(self.longitude)
assert l < 256 assert l < 256
file.write(struct.pack('!B', l)) file.write(struct.pack("!B", l))
file.write(self.longitude) file.write(self.longitude)
l = len(self.altitude) l = len(self.altitude)
assert l < 256 assert l < 256
file.write(struct.pack('!B', l)) file.write(struct.pack("!B", l))
file.write(self.altitude) file.write(self.altitude)
@classmethod @classmethod

View file

@ -30,7 +30,7 @@ class HINFO(dns.rdata.Rdata):
# see: RFC 1035 # see: RFC 1035
__slots__ = ['cpu', 'os'] __slots__ = ["cpu", "os"]
def __init__(self, rdclass, rdtype, cpu, os): def __init__(self, rdclass, rdtype, cpu, os):
super().__init__(rdclass, rdtype) super().__init__(rdclass, rdtype)
@ -38,12 +38,14 @@ class HINFO(dns.rdata.Rdata):
self.os = self._as_bytes(os, True, 255) self.os = self._as_bytes(os, True, 255)
def to_text(self, origin=None, relativize=True, **kw): def to_text(self, origin=None, relativize=True, **kw):
return '"{}" "{}"'.format(dns.rdata._escapify(self.cpu), return '"{}" "{}"'.format(
dns.rdata._escapify(self.os)) dns.rdata._escapify(self.cpu), dns.rdata._escapify(self.os)
)
@classmethod @classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, def from_text(
relativize_to=None): cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
cpu = tok.get_string(max_length=255) cpu = tok.get_string(max_length=255)
os = tok.get_string(max_length=255) os = tok.get_string(max_length=255)
return cls(rdclass, rdtype, cpu, os) return cls(rdclass, rdtype, cpu, os)
@ -51,11 +53,11 @@ class HINFO(dns.rdata.Rdata):
def _to_wire(self, file, compress=None, origin=None, canonicalize=False): def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
l = len(self.cpu) l = len(self.cpu)
assert l < 256 assert l < 256
file.write(struct.pack('!B', l)) file.write(struct.pack("!B", l))
file.write(self.cpu) file.write(self.cpu)
l = len(self.os) l = len(self.os)
assert l < 256 assert l < 256
file.write(struct.pack('!B', l)) file.write(struct.pack("!B", l))
file.write(self.os) file.write(self.os)
@classmethod @classmethod

View file

@ -32,7 +32,7 @@ class HIP(dns.rdata.Rdata):
# see: RFC 5205 # see: RFC 5205
__slots__ = ['hit', 'algorithm', 'key', 'servers'] __slots__ = ["hit", "algorithm", "key", "servers"]
def __init__(self, rdclass, rdtype, hit, algorithm, key, servers): def __init__(self, rdclass, rdtype, hit, algorithm, key, servers):
super().__init__(rdclass, rdtype) super().__init__(rdclass, rdtype)
@ -43,18 +43,19 @@ class HIP(dns.rdata.Rdata):
def to_text(self, origin=None, relativize=True, **kw): def to_text(self, origin=None, relativize=True, **kw):
hit = binascii.hexlify(self.hit).decode() hit = binascii.hexlify(self.hit).decode()
key = base64.b64encode(self.key).replace(b'\n', b'').decode() key = base64.b64encode(self.key).replace(b"\n", b"").decode()
text = '' text = ""
servers = [] servers = []
for server in self.servers: for server in self.servers:
servers.append(server.choose_relativity(origin, relativize)) servers.append(server.choose_relativity(origin, relativize))
if len(servers) > 0: if len(servers) > 0:
text += (' ' + ' '.join((x.to_unicode() for x in servers))) text += " " + " ".join((x.to_unicode() for x in servers))
return '%u %s %s%s' % (self.algorithm, hit, key, text) return "%u %s %s%s" % (self.algorithm, hit, key, text)
@classmethod @classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, def from_text(
relativize_to=None): cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
algorithm = tok.get_uint8() algorithm = tok.get_uint8()
hit = binascii.unhexlify(tok.get_string().encode()) hit = binascii.unhexlify(tok.get_string().encode())
key = base64.b64decode(tok.get_string().encode()) key = base64.b64decode(tok.get_string().encode())
@ -75,7 +76,7 @@ class HIP(dns.rdata.Rdata):
@classmethod @classmethod
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
(lh, algorithm, lk) = parser.get_struct('!BBH') (lh, algorithm, lk) = parser.get_struct("!BBH")
hit = parser.get_bytes(lh) hit = parser.get_bytes(lh)
key = parser.get_bytes(lk) key = parser.get_bytes(lk)
servers = [] servers = []

View file

@ -30,7 +30,7 @@ class ISDN(dns.rdata.Rdata):
# see: RFC 1183 # see: RFC 1183
__slots__ = ['address', 'subaddress'] __slots__ = ["address", "subaddress"]
def __init__(self, rdclass, rdtype, address, subaddress): def __init__(self, rdclass, rdtype, address, subaddress):
super().__init__(rdclass, rdtype) super().__init__(rdclass, rdtype)
@ -39,31 +39,33 @@ class ISDN(dns.rdata.Rdata):
def to_text(self, origin=None, relativize=True, **kw): def to_text(self, origin=None, relativize=True, **kw):
if self.subaddress: if self.subaddress:
return '"{}" "{}"'.format(dns.rdata._escapify(self.address), return '"{}" "{}"'.format(
dns.rdata._escapify(self.subaddress)) dns.rdata._escapify(self.address), dns.rdata._escapify(self.subaddress)
)
else: else:
return '"%s"' % dns.rdata._escapify(self.address) return '"%s"' % dns.rdata._escapify(self.address)
@classmethod @classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, def from_text(
relativize_to=None): cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
address = tok.get_string() address = tok.get_string()
tokens = tok.get_remaining(max_tokens=1) tokens = tok.get_remaining(max_tokens=1)
if len(tokens) >= 1: if len(tokens) >= 1:
subaddress = tokens[0].unescape().value subaddress = tokens[0].unescape().value
else: else:
subaddress = '' subaddress = ""
return cls(rdclass, rdtype, address, subaddress) return cls(rdclass, rdtype, address, subaddress)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False): def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
l = len(self.address) l = len(self.address)
assert l < 256 assert l < 256
file.write(struct.pack('!B', l)) file.write(struct.pack("!B", l))
file.write(self.address) file.write(self.address)
l = len(self.subaddress) l = len(self.subaddress)
if l > 0: if l > 0:
assert l < 256 assert l < 256
file.write(struct.pack('!B', l)) file.write(struct.pack("!B", l))
file.write(self.subaddress) file.write(self.subaddress)
@classmethod @classmethod
@ -72,5 +74,5 @@ class ISDN(dns.rdata.Rdata):
if parser.remaining() > 0: if parser.remaining() > 0:
subaddress = parser.get_counted_bytes() subaddress = parser.get_counted_bytes()
else: else:
subaddress = b'' subaddress = b""
return cls(rdclass, rdtype, address, subaddress) return cls(rdclass, rdtype, address, subaddress)

View file

@ -3,6 +3,7 @@
import struct import struct
import dns.immutable import dns.immutable
import dns.rdata
@dns.immutable.immutable @dns.immutable.immutable
@ -12,7 +13,7 @@ class L32(dns.rdata.Rdata):
# see: rfc6742.txt # see: rfc6742.txt
__slots__ = ['preference', 'locator32'] __slots__ = ["preference", "locator32"]
def __init__(self, rdclass, rdtype, preference, locator32): def __init__(self, rdclass, rdtype, preference, locator32):
super().__init__(rdclass, rdtype) super().__init__(rdclass, rdtype)
@ -20,17 +21,18 @@ class L32(dns.rdata.Rdata):
self.locator32 = self._as_ipv4_address(locator32) self.locator32 = self._as_ipv4_address(locator32)
def to_text(self, origin=None, relativize=True, **kw): def to_text(self, origin=None, relativize=True, **kw):
return f'{self.preference} {self.locator32}' return f"{self.preference} {self.locator32}"
@classmethod @classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, def from_text(
relativize_to=None): cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
preference = tok.get_uint16() preference = tok.get_uint16()
nodeid = tok.get_identifier() nodeid = tok.get_identifier()
return cls(rdclass, rdtype, preference, nodeid) return cls(rdclass, rdtype, preference, nodeid)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False): def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
file.write(struct.pack('!H', self.preference)) file.write(struct.pack("!H", self.preference))
file.write(dns.ipv4.inet_aton(self.locator32)) file.write(dns.ipv4.inet_aton(self.locator32))
@classmethod @classmethod

View file

@ -13,33 +13,33 @@ class L64(dns.rdata.Rdata):
# see: rfc6742.txt # see: rfc6742.txt
__slots__ = ['preference', 'locator64'] __slots__ = ["preference", "locator64"]
def __init__(self, rdclass, rdtype, preference, locator64): def __init__(self, rdclass, rdtype, preference, locator64):
super().__init__(rdclass, rdtype) super().__init__(rdclass, rdtype)
self.preference = self._as_uint16(preference) self.preference = self._as_uint16(preference)
if isinstance(locator64, bytes): if isinstance(locator64, bytes):
if len(locator64) != 8: if len(locator64) != 8:
raise ValueError('invalid locator64') raise ValueError("invalid locator64")
self.locator64 = dns.rdata._hexify(locator64, 4, b':') self.locator64 = dns.rdata._hexify(locator64, 4, b":")
else: else:
dns.rdtypes.util.parse_formatted_hex(locator64, 4, 4, ':') dns.rdtypes.util.parse_formatted_hex(locator64, 4, 4, ":")
self.locator64 = locator64 self.locator64 = locator64
def to_text(self, origin=None, relativize=True, **kw): def to_text(self, origin=None, relativize=True, **kw):
return f'{self.preference} {self.locator64}' return f"{self.preference} {self.locator64}"
@classmethod @classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, def from_text(
relativize_to=None): cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
preference = tok.get_uint16() preference = tok.get_uint16()
locator64 = tok.get_identifier() locator64 = tok.get_identifier()
return cls(rdclass, rdtype, preference, locator64) return cls(rdclass, rdtype, preference, locator64)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False): def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
file.write(struct.pack('!H', self.preference)) file.write(struct.pack("!H", self.preference))
file.write(dns.rdtypes.util.parse_formatted_hex(self.locator64, file.write(dns.rdtypes.util.parse_formatted_hex(self.locator64, 4, 4, ":"))
4, 4, ':'))
@classmethod @classmethod
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):

View file

@ -93,15 +93,15 @@ def _decode_size(what, desc):
def _check_coordinate_list(value, low, high): def _check_coordinate_list(value, low, high):
if value[0] < low or value[0] > high: if value[0] < low or value[0] > high:
raise ValueError(f'not in range [{low}, {high}]') raise ValueError(f"not in range [{low}, {high}]")
if value[1] < 0 or value[1] > 59: if value[1] < 0 or value[1] > 59:
raise ValueError('bad minutes value') raise ValueError("bad minutes value")
if value[2] < 0 or value[2] > 59: if value[2] < 0 or value[2] > 59:
raise ValueError('bad seconds value') raise ValueError("bad seconds value")
if value[3] < 0 or value[3] > 999: if value[3] < 0 or value[3] > 999:
raise ValueError('bad milliseconds value') raise ValueError("bad milliseconds value")
if value[4] != 1 and value[4] != -1: if value[4] != 1 and value[4] != -1:
raise ValueError('bad hemisphere value') raise ValueError("bad hemisphere value")
@dns.immutable.immutable @dns.immutable.immutable
@ -111,12 +111,26 @@ class LOC(dns.rdata.Rdata):
# see: RFC 1876 # see: RFC 1876
__slots__ = ['latitude', 'longitude', 'altitude', 'size', __slots__ = [
'horizontal_precision', 'vertical_precision'] "latitude",
"longitude",
"altitude",
"size",
"horizontal_precision",
"vertical_precision",
]
def __init__(self, rdclass, rdtype, latitude, longitude, altitude, def __init__(
size=_default_size, hprec=_default_hprec, self,
vprec=_default_vprec): rdclass,
rdtype,
latitude,
longitude,
altitude,
size=_default_size,
hprec=_default_hprec,
vprec=_default_vprec,
):
"""Initialize a LOC record instance. """Initialize a LOC record instance.
The parameters I{latitude} and I{longitude} may be either a 4-tuple The parameters I{latitude} and I{longitude} may be either a 4-tuple
@ -145,34 +159,44 @@ class LOC(dns.rdata.Rdata):
def to_text(self, origin=None, relativize=True, **kw): def to_text(self, origin=None, relativize=True, **kw):
if self.latitude[4] > 0: if self.latitude[4] > 0:
lat_hemisphere = 'N' lat_hemisphere = "N"
else: else:
lat_hemisphere = 'S' lat_hemisphere = "S"
if self.longitude[4] > 0: if self.longitude[4] > 0:
long_hemisphere = 'E' long_hemisphere = "E"
else: else:
long_hemisphere = 'W' long_hemisphere = "W"
text = "%d %d %d.%03d %s %d %d %d.%03d %s %0.2fm" % ( text = "%d %d %d.%03d %s %d %d %d.%03d %s %0.2fm" % (
self.latitude[0], self.latitude[1], self.latitude[0],
self.latitude[2], self.latitude[3], lat_hemisphere, self.latitude[1],
self.longitude[0], self.longitude[1], self.longitude[2], self.latitude[2],
self.longitude[3], long_hemisphere, self.latitude[3],
self.altitude / 100.0 lat_hemisphere,
self.longitude[0],
self.longitude[1],
self.longitude[2],
self.longitude[3],
long_hemisphere,
self.altitude / 100.0,
) )
# do not print default values # do not print default values
if self.size != _default_size or \ if (
self.horizontal_precision != _default_hprec or \ self.size != _default_size
self.vertical_precision != _default_vprec: or self.horizontal_precision != _default_hprec
or self.vertical_precision != _default_vprec
):
text += " {:0.2f}m {:0.2f}m {:0.2f}m".format( text += " {:0.2f}m {:0.2f}m {:0.2f}m".format(
self.size / 100.0, self.horizontal_precision / 100.0, self.size / 100.0,
self.vertical_precision / 100.0 self.horizontal_precision / 100.0,
self.vertical_precision / 100.0,
) )
return text return text
@classmethod @classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, def from_text(
relativize_to=None): cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
latitude = [0, 0, 0, 0, 1] latitude = [0, 0, 0, 0, 1]
longitude = [0, 0, 0, 0, 1] longitude = [0, 0, 0, 0, 1]
size = _default_size size = _default_size
@ -184,16 +208,14 @@ class LOC(dns.rdata.Rdata):
if t.isdigit(): if t.isdigit():
latitude[1] = int(t) latitude[1] = int(t)
t = tok.get_string() t = tok.get_string()
if '.' in t: if "." in t:
(seconds, milliseconds) = t.split('.') (seconds, milliseconds) = t.split(".")
if not seconds.isdigit(): if not seconds.isdigit():
raise dns.exception.SyntaxError( raise dns.exception.SyntaxError("bad latitude seconds value")
'bad latitude seconds value')
latitude[2] = int(seconds) latitude[2] = int(seconds)
l = len(milliseconds) l = len(milliseconds)
if l == 0 or l > 3 or not milliseconds.isdigit(): if l == 0 or l > 3 or not milliseconds.isdigit():
raise dns.exception.SyntaxError( raise dns.exception.SyntaxError("bad latitude milliseconds value")
'bad latitude milliseconds value')
if l == 1: if l == 1:
m = 100 m = 100
elif l == 2: elif l == 2:
@ -205,26 +227,24 @@ class LOC(dns.rdata.Rdata):
elif t.isdigit(): elif t.isdigit():
latitude[2] = int(t) latitude[2] = int(t)
t = tok.get_string() t = tok.get_string()
if t == 'S': if t == "S":
latitude[4] = -1 latitude[4] = -1
elif t != 'N': elif t != "N":
raise dns.exception.SyntaxError('bad latitude hemisphere value') raise dns.exception.SyntaxError("bad latitude hemisphere value")
longitude[0] = tok.get_int() longitude[0] = tok.get_int()
t = tok.get_string() t = tok.get_string()
if t.isdigit(): if t.isdigit():
longitude[1] = int(t) longitude[1] = int(t)
t = tok.get_string() t = tok.get_string()
if '.' in t: if "." in t:
(seconds, milliseconds) = t.split('.') (seconds, milliseconds) = t.split(".")
if not seconds.isdigit(): if not seconds.isdigit():
raise dns.exception.SyntaxError( raise dns.exception.SyntaxError("bad longitude seconds value")
'bad longitude seconds value')
longitude[2] = int(seconds) longitude[2] = int(seconds)
l = len(milliseconds) l = len(milliseconds)
if l == 0 or l > 3 or not milliseconds.isdigit(): if l == 0 or l > 3 or not milliseconds.isdigit():
raise dns.exception.SyntaxError( raise dns.exception.SyntaxError("bad longitude milliseconds value")
'bad longitude milliseconds value')
if l == 1: if l == 1:
m = 100 m = 100
elif l == 2: elif l == 2:
@ -236,30 +256,30 @@ class LOC(dns.rdata.Rdata):
elif t.isdigit(): elif t.isdigit():
longitude[2] = int(t) longitude[2] = int(t)
t = tok.get_string() t = tok.get_string()
if t == 'W': if t == "W":
longitude[4] = -1 longitude[4] = -1
elif t != 'E': elif t != "E":
raise dns.exception.SyntaxError('bad longitude hemisphere value') raise dns.exception.SyntaxError("bad longitude hemisphere value")
t = tok.get_string() t = tok.get_string()
if t[-1] == 'm': if t[-1] == "m":
t = t[0:-1] t = t[0:-1]
altitude = float(t) * 100.0 # m -> cm altitude = float(t) * 100.0 # m -> cm
tokens = tok.get_remaining(max_tokens=3) tokens = tok.get_remaining(max_tokens=3)
if len(tokens) >= 1: if len(tokens) >= 1:
value = tokens[0].unescape().value value = tokens[0].unescape().value
if value[-1] == 'm': if value[-1] == "m":
value = value[0:-1] value = value[0:-1]
size = float(value) * 100.0 # m -> cm size = float(value) * 100.0 # m -> cm
if len(tokens) >= 2: if len(tokens) >= 2:
value = tokens[1].unescape().value value = tokens[1].unescape().value
if value[-1] == 'm': if value[-1] == "m":
value = value[0:-1] value = value[0:-1]
hprec = float(value) * 100.0 # m -> cm hprec = float(value) * 100.0 # m -> cm
if len(tokens) >= 3: if len(tokens) >= 3:
value = tokens[2].unescape().value value = tokens[2].unescape().value
if value[-1] == 'm': if value[-1] == "m":
value = value[0:-1] value = value[0:-1]
vprec = float(value) * 100.0 # m -> cm vprec = float(value) * 100.0 # m -> cm
@ -268,32 +288,43 @@ class LOC(dns.rdata.Rdata):
_encode_size(hprec, "horizontal precision") _encode_size(hprec, "horizontal precision")
_encode_size(vprec, "vertical precision") _encode_size(vprec, "vertical precision")
return cls(rdclass, rdtype, latitude, longitude, altitude, return cls(rdclass, rdtype, latitude, longitude, altitude, size, hprec, vprec)
size, hprec, vprec)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False): def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
milliseconds = (self.latitude[0] * 3600000 + milliseconds = (
self.latitude[1] * 60000 + self.latitude[0] * 3600000
self.latitude[2] * 1000 + + self.latitude[1] * 60000
self.latitude[3]) * self.latitude[4] + self.latitude[2] * 1000
+ self.latitude[3]
) * self.latitude[4]
latitude = 0x80000000 + milliseconds latitude = 0x80000000 + milliseconds
milliseconds = (self.longitude[0] * 3600000 + milliseconds = (
self.longitude[1] * 60000 + self.longitude[0] * 3600000
self.longitude[2] * 1000 + + self.longitude[1] * 60000
self.longitude[3]) * self.longitude[4] + self.longitude[2] * 1000
+ self.longitude[3]
) * self.longitude[4]
longitude = 0x80000000 + milliseconds longitude = 0x80000000 + milliseconds
altitude = int(self.altitude) + 10000000 altitude = int(self.altitude) + 10000000
size = _encode_size(self.size, "size") size = _encode_size(self.size, "size")
hprec = _encode_size(self.horizontal_precision, "horizontal precision") hprec = _encode_size(self.horizontal_precision, "horizontal precision")
vprec = _encode_size(self.vertical_precision, "vertical precision") vprec = _encode_size(self.vertical_precision, "vertical precision")
wire = struct.pack("!BBBBIII", 0, size, hprec, vprec, latitude, wire = struct.pack(
longitude, altitude) "!BBBBIII", 0, size, hprec, vprec, latitude, longitude, altitude
)
file.write(wire) file.write(wire)
@classmethod @classmethod
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
(version, size, hprec, vprec, latitude, longitude, altitude) = \ (
parser.get_struct("!BBBBIII") version,
size,
hprec,
vprec,
latitude,
longitude,
altitude,
) = parser.get_struct("!BBBBIII")
if version != 0: if version != 0:
raise dns.exception.FormError("LOC version not zero") raise dns.exception.FormError("LOC version not zero")
if latitude < _MIN_LATITUDE or latitude > _MAX_LATITUDE: if latitude < _MIN_LATITUDE or latitude > _MAX_LATITUDE:
@ -312,8 +343,7 @@ class LOC(dns.rdata.Rdata):
size = _decode_size(size, "size") size = _decode_size(size, "size")
hprec = _decode_size(hprec, "horizontal precision") hprec = _decode_size(hprec, "horizontal precision")
vprec = _decode_size(vprec, "vertical precision") vprec = _decode_size(vprec, "vertical precision")
return cls(rdclass, rdtype, latitude, longitude, altitude, return cls(rdclass, rdtype, latitude, longitude, altitude, size, hprec, vprec)
size, hprec, vprec)
@property @property
def float_latitude(self): def float_latitude(self):

View file

@ -3,6 +3,7 @@
import struct import struct
import dns.immutable import dns.immutable
import dns.rdata
@dns.immutable.immutable @dns.immutable.immutable
@ -12,7 +13,7 @@ class LP(dns.rdata.Rdata):
# see: rfc6742.txt # see: rfc6742.txt
__slots__ = ['preference', 'fqdn'] __slots__ = ["preference", "fqdn"]
def __init__(self, rdclass, rdtype, preference, fqdn): def __init__(self, rdclass, rdtype, preference, fqdn):
super().__init__(rdclass, rdtype) super().__init__(rdclass, rdtype)
@ -21,17 +22,18 @@ class LP(dns.rdata.Rdata):
def to_text(self, origin=None, relativize=True, **kw): def to_text(self, origin=None, relativize=True, **kw):
fqdn = self.fqdn.choose_relativity(origin, relativize) fqdn = self.fqdn.choose_relativity(origin, relativize)
return '%d %s' % (self.preference, fqdn) return "%d %s" % (self.preference, fqdn)
@classmethod @classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, def from_text(
relativize_to=None): cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
preference = tok.get_uint16() preference = tok.get_uint16()
fqdn = tok.get_name(origin, relativize, relativize_to) fqdn = tok.get_name(origin, relativize, relativize_to)
return cls(rdclass, rdtype, preference, fqdn) return cls(rdclass, rdtype, preference, fqdn)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False): def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
file.write(struct.pack('!H', self.preference)) file.write(struct.pack("!H", self.preference))
self.fqdn.to_wire(file, compress, origin, canonicalize) self.fqdn.to_wire(file, compress, origin, canonicalize)
@classmethod @classmethod

View file

@ -13,32 +13,33 @@ class NID(dns.rdata.Rdata):
# see: rfc6742.txt # see: rfc6742.txt
__slots__ = ['preference', 'nodeid'] __slots__ = ["preference", "nodeid"]
def __init__(self, rdclass, rdtype, preference, nodeid): def __init__(self, rdclass, rdtype, preference, nodeid):
super().__init__(rdclass, rdtype) super().__init__(rdclass, rdtype)
self.preference = self._as_uint16(preference) self.preference = self._as_uint16(preference)
if isinstance(nodeid, bytes): if isinstance(nodeid, bytes):
if len(nodeid) != 8: if len(nodeid) != 8:
raise ValueError('invalid nodeid') raise ValueError("invalid nodeid")
self.nodeid = dns.rdata._hexify(nodeid, 4, b':') self.nodeid = dns.rdata._hexify(nodeid, 4, b":")
else: else:
dns.rdtypes.util.parse_formatted_hex(nodeid, 4, 4, ':') dns.rdtypes.util.parse_formatted_hex(nodeid, 4, 4, ":")
self.nodeid = nodeid self.nodeid = nodeid
def to_text(self, origin=None, relativize=True, **kw): def to_text(self, origin=None, relativize=True, **kw):
return f'{self.preference} {self.nodeid}' return f"{self.preference} {self.nodeid}"
@classmethod @classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, def from_text(
relativize_to=None): cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
preference = tok.get_uint16() preference = tok.get_uint16()
nodeid = tok.get_identifier() nodeid = tok.get_identifier()
return cls(rdclass, rdtype, preference, nodeid) return cls(rdclass, rdtype, preference, nodeid)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False): def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
file.write(struct.pack('!H', self.preference)) file.write(struct.pack("!H", self.preference))
file.write(dns.rdtypes.util.parse_formatted_hex(self.nodeid, 4, 4, ':')) file.write(dns.rdtypes.util.parse_formatted_hex(self.nodeid, 4, 4, ":"))
@classmethod @classmethod
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):

View file

@ -25,7 +25,7 @@ import dns.rdtypes.util
@dns.immutable.immutable @dns.immutable.immutable
class Bitmap(dns.rdtypes.util.Bitmap): class Bitmap(dns.rdtypes.util.Bitmap):
type_name = 'NSEC' type_name = "NSEC"
@dns.immutable.immutable @dns.immutable.immutable
@ -33,7 +33,7 @@ class NSEC(dns.rdata.Rdata):
"""NSEC record""" """NSEC record"""
__slots__ = ['next', 'windows'] __slots__ = ["next", "windows"]
def __init__(self, rdclass, rdtype, next, windows): def __init__(self, rdclass, rdtype, next, windows):
super().__init__(rdclass, rdtype) super().__init__(rdclass, rdtype)
@ -45,11 +45,12 @@ class NSEC(dns.rdata.Rdata):
def to_text(self, origin=None, relativize=True, **kw): def to_text(self, origin=None, relativize=True, **kw):
next = self.next.choose_relativity(origin, relativize) next = self.next.choose_relativity(origin, relativize)
text = Bitmap(self.windows).to_text() text = Bitmap(self.windows).to_text()
return '{}{}'.format(next, text) return "{}{}".format(next, text)
@classmethod @classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, def from_text(
relativize_to=None): cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
next = tok.get_name(origin, relativize, relativize_to) next = tok.get_name(origin, relativize, relativize_to)
windows = Bitmap.from_text(tok) windows = Bitmap.from_text(tok)
return cls(rdclass, rdtype, next, windows) return cls(rdclass, rdtype, next, windows)

View file

@ -26,10 +26,12 @@ import dns.rdatatype
import dns.rdtypes.util import dns.rdtypes.util
b32_hex_to_normal = bytes.maketrans(b'0123456789ABCDEFGHIJKLMNOPQRSTUV', b32_hex_to_normal = bytes.maketrans(
b'ABCDEFGHIJKLMNOPQRSTUVWXYZ234567') b"0123456789ABCDEFGHIJKLMNOPQRSTUV", b"ABCDEFGHIJKLMNOPQRSTUVWXYZ234567"
b32_normal_to_hex = bytes.maketrans(b'ABCDEFGHIJKLMNOPQRSTUVWXYZ234567', )
b'0123456789ABCDEFGHIJKLMNOPQRSTUV') b32_normal_to_hex = bytes.maketrans(
b"ABCDEFGHIJKLMNOPQRSTUVWXYZ234567", b"0123456789ABCDEFGHIJKLMNOPQRSTUV"
)
# hash algorithm constants # hash algorithm constants
SHA1 = 1 SHA1 = 1
@ -40,7 +42,7 @@ OPTOUT = 1
@dns.immutable.immutable @dns.immutable.immutable
class Bitmap(dns.rdtypes.util.Bitmap): class Bitmap(dns.rdtypes.util.Bitmap):
type_name = 'NSEC3' type_name = "NSEC3"
@dns.immutable.immutable @dns.immutable.immutable
@ -48,10 +50,11 @@ class NSEC3(dns.rdata.Rdata):
"""NSEC3 record""" """NSEC3 record"""
__slots__ = ['algorithm', 'flags', 'iterations', 'salt', 'next', 'windows'] __slots__ = ["algorithm", "flags", "iterations", "salt", "next", "windows"]
def __init__(self, rdclass, rdtype, algorithm, flags, iterations, salt, def __init__(
next, windows): self, rdclass, rdtype, algorithm, flags, iterations, salt, next, windows
):
super().__init__(rdclass, rdtype) super().__init__(rdclass, rdtype)
self.algorithm = self._as_uint8(algorithm) self.algorithm = self._as_uint8(algorithm)
self.flags = self._as_uint8(flags) self.flags = self._as_uint8(flags)
@ -63,38 +66,41 @@ class NSEC3(dns.rdata.Rdata):
self.windows = tuple(windows.windows) self.windows = tuple(windows.windows)
def to_text(self, origin=None, relativize=True, **kw): def to_text(self, origin=None, relativize=True, **kw):
next = base64.b32encode(self.next).translate( next = base64.b32encode(self.next).translate(b32_normal_to_hex).lower().decode()
b32_normal_to_hex).lower().decode() if self.salt == b"":
if self.salt == b'': salt = "-"
salt = '-'
else: else:
salt = binascii.hexlify(self.salt).decode() salt = binascii.hexlify(self.salt).decode()
text = Bitmap(self.windows).to_text() text = Bitmap(self.windows).to_text()
return '%u %u %u %s %s%s' % (self.algorithm, self.flags, return "%u %u %u %s %s%s" % (
self.iterations, salt, next, text) self.algorithm,
self.flags,
self.iterations,
salt,
next,
text,
)
@classmethod @classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, def from_text(
relativize_to=None): cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
algorithm = tok.get_uint8() algorithm = tok.get_uint8()
flags = tok.get_uint8() flags = tok.get_uint8()
iterations = tok.get_uint16() iterations = tok.get_uint16()
salt = tok.get_string() salt = tok.get_string()
if salt == '-': if salt == "-":
salt = b'' salt = b""
else: else:
salt = binascii.unhexlify(salt.encode('ascii')) salt = binascii.unhexlify(salt.encode("ascii"))
next = tok.get_string().encode( next = tok.get_string().encode("ascii").upper().translate(b32_hex_to_normal)
'ascii').upper().translate(b32_hex_to_normal)
next = base64.b32decode(next) next = base64.b32decode(next)
bitmap = Bitmap.from_text(tok) bitmap = Bitmap.from_text(tok)
return cls(rdclass, rdtype, algorithm, flags, iterations, salt, next, return cls(rdclass, rdtype, algorithm, flags, iterations, salt, next, bitmap)
bitmap)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False): def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
l = len(self.salt) l = len(self.salt)
file.write(struct.pack("!BBHB", self.algorithm, self.flags, file.write(struct.pack("!BBHB", self.algorithm, self.flags, self.iterations, l))
self.iterations, l))
file.write(self.salt) file.write(self.salt)
l = len(self.next) l = len(self.next)
file.write(struct.pack("!B", l)) file.write(struct.pack("!B", l))
@ -103,9 +109,8 @@ class NSEC3(dns.rdata.Rdata):
@classmethod @classmethod
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
(algorithm, flags, iterations) = parser.get_struct('!BBH') (algorithm, flags, iterations) = parser.get_struct("!BBH")
salt = parser.get_counted_bytes() salt = parser.get_counted_bytes()
next = parser.get_counted_bytes() next = parser.get_counted_bytes()
bitmap = Bitmap.from_wire_parser(parser) bitmap = Bitmap.from_wire_parser(parser)
return cls(rdclass, rdtype, algorithm, flags, iterations, salt, next, return cls(rdclass, rdtype, algorithm, flags, iterations, salt, next, bitmap)
bitmap)

View file

@ -28,7 +28,7 @@ class NSEC3PARAM(dns.rdata.Rdata):
"""NSEC3PARAM record""" """NSEC3PARAM record"""
__slots__ = ['algorithm', 'flags', 'iterations', 'salt'] __slots__ = ["algorithm", "flags", "iterations", "salt"]
def __init__(self, rdclass, rdtype, algorithm, flags, iterations, salt): def __init__(self, rdclass, rdtype, algorithm, flags, iterations, salt):
super().__init__(rdclass, rdtype) super().__init__(rdclass, rdtype)
@ -38,34 +38,33 @@ class NSEC3PARAM(dns.rdata.Rdata):
self.salt = self._as_bytes(salt, True, 255) self.salt = self._as_bytes(salt, True, 255)
def to_text(self, origin=None, relativize=True, **kw): def to_text(self, origin=None, relativize=True, **kw):
if self.salt == b'': if self.salt == b"":
salt = '-' salt = "-"
else: else:
salt = binascii.hexlify(self.salt).decode() salt = binascii.hexlify(self.salt).decode()
return '%u %u %u %s' % (self.algorithm, self.flags, self.iterations, return "%u %u %u %s" % (self.algorithm, self.flags, self.iterations, salt)
salt)
@classmethod @classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, def from_text(
relativize_to=None): cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
algorithm = tok.get_uint8() algorithm = tok.get_uint8()
flags = tok.get_uint8() flags = tok.get_uint8()
iterations = tok.get_uint16() iterations = tok.get_uint16()
salt = tok.get_string() salt = tok.get_string()
if salt == '-': if salt == "-":
salt = '' salt = ""
else: else:
salt = binascii.unhexlify(salt.encode()) salt = binascii.unhexlify(salt.encode())
return cls(rdclass, rdtype, algorithm, flags, iterations, salt) return cls(rdclass, rdtype, algorithm, flags, iterations, salt)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False): def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
l = len(self.salt) l = len(self.salt)
file.write(struct.pack("!BBHB", self.algorithm, self.flags, file.write(struct.pack("!BBHB", self.algorithm, self.flags, self.iterations, l))
self.iterations, l))
file.write(self.salt) file.write(self.salt)
@classmethod @classmethod
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
(algorithm, flags, iterations) = parser.get_struct('!BBH') (algorithm, flags, iterations) = parser.get_struct("!BBH")
salt = parser.get_counted_bytes() salt = parser.get_counted_bytes()
return cls(rdclass, rdtype, algorithm, flags, iterations, salt) return cls(rdclass, rdtype, algorithm, flags, iterations, salt)

View file

@ -22,6 +22,7 @@ import dns.immutable
import dns.rdata import dns.rdata
import dns.tokenizer import dns.tokenizer
@dns.immutable.immutable @dns.immutable.immutable
class OPENPGPKEY(dns.rdata.Rdata): class OPENPGPKEY(dns.rdata.Rdata):
@ -37,8 +38,9 @@ class OPENPGPKEY(dns.rdata.Rdata):
return dns.rdata._base64ify(self.key, chunksize=None, **kw) return dns.rdata._base64ify(self.key, chunksize=None, **kw)
@classmethod @classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, def from_text(
relativize_to=None): cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
b64 = tok.concatenate_remaining_identifiers().encode() b64 = tok.concatenate_remaining_identifiers().encode()
key = base64.b64decode(b64) key = base64.b64decode(b64)
return cls(rdclass, rdtype, key) return cls(rdclass, rdtype, key)

View file

@ -26,12 +26,13 @@ import dns.rdata
# We don't implement from_text, and that's ok. # We don't implement from_text, and that's ok.
# pylint: disable=abstract-method # pylint: disable=abstract-method
@dns.immutable.immutable @dns.immutable.immutable
class OPT(dns.rdata.Rdata): class OPT(dns.rdata.Rdata):
"""OPT record""" """OPT record"""
__slots__ = ['options'] __slots__ = ["options"]
def __init__(self, rdclass, rdtype, options): def __init__(self, rdclass, rdtype, options):
"""Initialize an OPT rdata. """Initialize an OPT rdata.
@ -45,10 +46,12 @@ class OPT(dns.rdata.Rdata):
""" """
super().__init__(rdclass, rdtype) super().__init__(rdclass, rdtype)
def as_option(option): def as_option(option):
if not isinstance(option, dns.edns.Option): if not isinstance(option, dns.edns.Option):
raise ValueError('option is not a dns.edns.option') raise ValueError("option is not a dns.edns.option")
return option return option
self.options = self._as_tuple(options, as_option) self.options = self._as_tuple(options, as_option)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False): def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
@ -58,13 +61,13 @@ class OPT(dns.rdata.Rdata):
file.write(owire) file.write(owire)
def to_text(self, origin=None, relativize=True, **kw): def to_text(self, origin=None, relativize=True, **kw):
return ' '.join(opt.to_text() for opt in self.options) return " ".join(opt.to_text() for opt in self.options)
@classmethod @classmethod
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
options = [] options = []
while parser.remaining() > 0: while parser.remaining() > 0:
(otype, olen) = parser.get_struct('!HH') (otype, olen) = parser.get_struct("!HH")
with parser.restrict_to(olen): with parser.restrict_to(olen):
opt = dns.edns.option_from_wire_parser(otype, parser) opt = dns.edns.option_from_wire_parser(otype, parser)
options.append(opt) options.append(opt)

View file

@ -28,7 +28,7 @@ class RP(dns.rdata.Rdata):
# see: RFC 1183 # see: RFC 1183
__slots__ = ['mbox', 'txt'] __slots__ = ["mbox", "txt"]
def __init__(self, rdclass, rdtype, mbox, txt): def __init__(self, rdclass, rdtype, mbox, txt):
super().__init__(rdclass, rdtype) super().__init__(rdclass, rdtype)
@ -41,8 +41,9 @@ class RP(dns.rdata.Rdata):
return "{} {}".format(str(mbox), str(txt)) return "{} {}".format(str(mbox), str(txt))
@classmethod @classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, def from_text(
relativize_to=None): cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
mbox = tok.get_name(origin, relativize, relativize_to) mbox = tok.get_name(origin, relativize, relativize_to)
txt = tok.get_name(origin, relativize, relativize_to) txt = tok.get_name(origin, relativize, relativize_to)
return cls(rdclass, rdtype, mbox, txt) return cls(rdclass, rdtype, mbox, txt)

View file

@ -20,7 +20,7 @@ import calendar
import struct import struct
import time import time
import dns.dnssec import dns.dnssectypes
import dns.immutable import dns.immutable
import dns.exception import dns.exception
import dns.rdata import dns.rdata
@ -43,12 +43,11 @@ def sigtime_to_posixtime(what):
hour = int(what[8:10]) hour = int(what[8:10])
minute = int(what[10:12]) minute = int(what[10:12])
second = int(what[12:14]) second = int(what[12:14])
return calendar.timegm((year, month, day, hour, minute, second, return calendar.timegm((year, month, day, hour, minute, second, 0, 0, 0))
0, 0, 0))
def posixtime_to_sigtime(what): def posixtime_to_sigtime(what):
return time.strftime('%Y%m%d%H%M%S', time.gmtime(what)) return time.strftime("%Y%m%d%H%M%S", time.gmtime(what))
@dns.immutable.immutable @dns.immutable.immutable
@ -56,16 +55,35 @@ class RRSIG(dns.rdata.Rdata):
"""RRSIG record""" """RRSIG record"""
__slots__ = ['type_covered', 'algorithm', 'labels', 'original_ttl', __slots__ = [
'expiration', 'inception', 'key_tag', 'signer', "type_covered",
'signature'] "algorithm",
"labels",
"original_ttl",
"expiration",
"inception",
"key_tag",
"signer",
"signature",
]
def __init__(self, rdclass, rdtype, type_covered, algorithm, labels, def __init__(
original_ttl, expiration, inception, key_tag, signer, self,
signature): rdclass,
rdtype,
type_covered,
algorithm,
labels,
original_ttl,
expiration,
inception,
key_tag,
signer,
signature,
):
super().__init__(rdclass, rdtype) super().__init__(rdclass, rdtype)
self.type_covered = self._as_rdatatype(type_covered) self.type_covered = self._as_rdatatype(type_covered)
self.algorithm = dns.dnssec.Algorithm.make(algorithm) self.algorithm = dns.dnssectypes.Algorithm.make(algorithm)
self.labels = self._as_uint8(labels) self.labels = self._as_uint8(labels)
self.original_ttl = self._as_ttl(original_ttl) self.original_ttl = self._as_ttl(original_ttl)
self.expiration = self._as_uint32(expiration) self.expiration = self._as_uint32(expiration)
@ -78,7 +96,7 @@ class RRSIG(dns.rdata.Rdata):
return self.type_covered return self.type_covered
def to_text(self, origin=None, relativize=True, **kw): def to_text(self, origin=None, relativize=True, **kw):
return '%s %d %d %d %s %s %d %s %s' % ( return "%s %d %d %d %s %s %d %s %s" % (
dns.rdatatype.to_text(self.type_covered), dns.rdatatype.to_text(self.type_covered),
self.algorithm, self.algorithm,
self.labels, self.labels,
@ -87,14 +105,15 @@ class RRSIG(dns.rdata.Rdata):
posixtime_to_sigtime(self.inception), posixtime_to_sigtime(self.inception),
self.key_tag, self.key_tag,
self.signer.choose_relativity(origin, relativize), self.signer.choose_relativity(origin, relativize),
dns.rdata._base64ify(self.signature, **kw) dns.rdata._base64ify(self.signature, **kw),
) )
@classmethod @classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, def from_text(
relativize_to=None): cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
type_covered = dns.rdatatype.from_text(tok.get_string()) type_covered = dns.rdatatype.from_text(tok.get_string())
algorithm = dns.dnssec.algorithm_from_text(tok.get_string()) algorithm = dns.dnssectypes.Algorithm.from_text(tok.get_string())
labels = tok.get_int() labels = tok.get_int()
original_ttl = tok.get_ttl() original_ttl = tok.get_ttl()
expiration = sigtime_to_posixtime(tok.get_string()) expiration = sigtime_to_posixtime(tok.get_string())
@ -103,22 +122,38 @@ class RRSIG(dns.rdata.Rdata):
signer = tok.get_name(origin, relativize, relativize_to) signer = tok.get_name(origin, relativize, relativize_to)
b64 = tok.concatenate_remaining_identifiers().encode() b64 = tok.concatenate_remaining_identifiers().encode()
signature = base64.b64decode(b64) signature = base64.b64decode(b64)
return cls(rdclass, rdtype, type_covered, algorithm, labels, return cls(
original_ttl, expiration, inception, key_tag, signer, rdclass,
signature) rdtype,
type_covered,
algorithm,
labels,
original_ttl,
expiration,
inception,
key_tag,
signer,
signature,
)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False): def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
header = struct.pack('!HBBIIIH', self.type_covered, header = struct.pack(
self.algorithm, self.labels, "!HBBIIIH",
self.original_ttl, self.expiration, self.type_covered,
self.inception, self.key_tag) self.algorithm,
self.labels,
self.original_ttl,
self.expiration,
self.inception,
self.key_tag,
)
file.write(header) file.write(header)
self.signer.to_wire(file, None, origin, canonicalize) self.signer.to_wire(file, None, origin, canonicalize)
file.write(self.signature) file.write(self.signature)
@classmethod @classmethod
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
header = parser.get_struct('!HBBIIIH') header = parser.get_struct("!HBBIIIH")
signer = parser.get_name(origin) signer = parser.get_name(origin)
signature = parser.get_remaining() signature = parser.get_remaining()
return cls(rdclass, rdtype, *header, signer, signature) return cls(rdclass, rdtype, *header, signer, signature)

View file

@ -30,11 +30,11 @@ class SOA(dns.rdata.Rdata):
# see: RFC 1035 # see: RFC 1035
__slots__ = ['mname', 'rname', 'serial', 'refresh', 'retry', 'expire', __slots__ = ["mname", "rname", "serial", "refresh", "retry", "expire", "minimum"]
'minimum']
def __init__(self, rdclass, rdtype, mname, rname, serial, refresh, retry, def __init__(
expire, minimum): self, rdclass, rdtype, mname, rname, serial, refresh, retry, expire, minimum
):
super().__init__(rdclass, rdtype) super().__init__(rdclass, rdtype)
self.mname = self._as_name(mname) self.mname = self._as_name(mname)
self.rname = self._as_name(rname) self.rname = self._as_name(rname)
@ -47,13 +47,20 @@ class SOA(dns.rdata.Rdata):
def to_text(self, origin=None, relativize=True, **kw): def to_text(self, origin=None, relativize=True, **kw):
mname = self.mname.choose_relativity(origin, relativize) mname = self.mname.choose_relativity(origin, relativize)
rname = self.rname.choose_relativity(origin, relativize) rname = self.rname.choose_relativity(origin, relativize)
return '%s %s %d %d %d %d %d' % ( return "%s %s %d %d %d %d %d" % (
mname, rname, self.serial, self.refresh, self.retry, mname,
self.expire, self.minimum) rname,
self.serial,
self.refresh,
self.retry,
self.expire,
self.minimum,
)
@classmethod @classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, def from_text(
relativize_to=None): cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
mname = tok.get_name(origin, relativize, relativize_to) mname = tok.get_name(origin, relativize, relativize_to)
rname = tok.get_name(origin, relativize, relativize_to) rname = tok.get_name(origin, relativize, relativize_to)
serial = tok.get_uint32() serial = tok.get_uint32()
@ -61,18 +68,20 @@ class SOA(dns.rdata.Rdata):
retry = tok.get_ttl() retry = tok.get_ttl()
expire = tok.get_ttl() expire = tok.get_ttl()
minimum = tok.get_ttl() minimum = tok.get_ttl()
return cls(rdclass, rdtype, mname, rname, serial, refresh, retry, return cls(
expire, minimum) rdclass, rdtype, mname, rname, serial, refresh, retry, expire, minimum
)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False): def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
self.mname.to_wire(file, compress, origin, canonicalize) self.mname.to_wire(file, compress, origin, canonicalize)
self.rname.to_wire(file, compress, origin, canonicalize) self.rname.to_wire(file, compress, origin, canonicalize)
five_ints = struct.pack('!IIIII', self.serial, self.refresh, five_ints = struct.pack(
self.retry, self.expire, self.minimum) "!IIIII", self.serial, self.refresh, self.retry, self.expire, self.minimum
)
file.write(five_ints) file.write(five_ints)
@classmethod @classmethod
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
mname = parser.get_name(origin) mname = parser.get_name(origin)
rname = parser.get_name(origin) rname = parser.get_name(origin)
return cls(rdclass, rdtype, mname, rname, *parser.get_struct('!IIIII')) return cls(rdclass, rdtype, mname, rname, *parser.get_struct("!IIIII"))

View file

@ -30,10 +30,9 @@ class SSHFP(dns.rdata.Rdata):
# See RFC 4255 # See RFC 4255
__slots__ = ['algorithm', 'fp_type', 'fingerprint'] __slots__ = ["algorithm", "fp_type", "fingerprint"]
def __init__(self, rdclass, rdtype, algorithm, fp_type, def __init__(self, rdclass, rdtype, algorithm, fp_type, fingerprint):
fingerprint):
super().__init__(rdclass, rdtype) super().__init__(rdclass, rdtype)
self.algorithm = self._as_uint8(algorithm) self.algorithm = self._as_uint8(algorithm)
self.fp_type = self._as_uint8(fp_type) self.fp_type = self._as_uint8(fp_type)
@ -41,16 +40,17 @@ class SSHFP(dns.rdata.Rdata):
def to_text(self, origin=None, relativize=True, **kw): def to_text(self, origin=None, relativize=True, **kw):
kw = kw.copy() kw = kw.copy()
chunksize = kw.pop('chunksize', 128) chunksize = kw.pop("chunksize", 128)
return '%d %d %s' % (self.algorithm, return "%d %d %s" % (
self.algorithm,
self.fp_type, self.fp_type,
dns.rdata._hexify(self.fingerprint, dns.rdata._hexify(self.fingerprint, chunksize=chunksize, **kw),
chunksize=chunksize, )
**kw))
@classmethod @classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, def from_text(
relativize_to=None): cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
algorithm = tok.get_uint8() algorithm = tok.get_uint8()
fp_type = tok.get_uint8() fp_type = tok.get_uint8()
fingerprint = tok.concatenate_remaining_identifiers().encode() fingerprint = tok.concatenate_remaining_identifiers().encode()

View file

@ -18,7 +18,6 @@
import base64 import base64
import struct import struct
import dns.dnssec
import dns.immutable import dns.immutable
import dns.exception import dns.exception
import dns.rdata import dns.rdata
@ -29,11 +28,28 @@ class TKEY(dns.rdata.Rdata):
"""TKEY Record""" """TKEY Record"""
__slots__ = ['algorithm', 'inception', 'expiration', 'mode', 'error', __slots__ = [
'key', 'other'] "algorithm",
"inception",
"expiration",
"mode",
"error",
"key",
"other",
]
def __init__(self, rdclass, rdtype, algorithm, inception, expiration, def __init__(
mode, error, key, other=b''): self,
rdclass,
rdtype,
algorithm,
inception,
expiration,
mode,
error,
key,
other=b"",
):
super().__init__(rdclass, rdtype) super().__init__(rdclass, rdtype)
self.algorithm = self._as_name(algorithm) self.algorithm = self._as_name(algorithm)
self.inception = self._as_uint32(inception) self.inception = self._as_uint32(inception)
@ -45,17 +61,23 @@ class TKEY(dns.rdata.Rdata):
def to_text(self, origin=None, relativize=True, **kw): def to_text(self, origin=None, relativize=True, **kw):
_algorithm = self.algorithm.choose_relativity(origin, relativize) _algorithm = self.algorithm.choose_relativity(origin, relativize)
text = '%s %u %u %u %u %s' % (str(_algorithm), self.inception, text = "%s %u %u %u %u %s" % (
self.expiration, self.mode, self.error, str(_algorithm),
dns.rdata._base64ify(self.key, 0)) self.inception,
self.expiration,
self.mode,
self.error,
dns.rdata._base64ify(self.key, 0),
)
if len(self.other) > 0: if len(self.other) > 0:
text += ' %s' % (dns.rdata._base64ify(self.other, 0)) text += " %s" % (dns.rdata._base64ify(self.other, 0))
return text return text
@classmethod @classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, def from_text(
relativize_to=None): cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
algorithm = tok.get_name(relativize=False) algorithm = tok.get_name(relativize=False)
inception = tok.get_uint32() inception = tok.get_uint32()
expiration = tok.get_uint32() expiration = tok.get_uint32()
@ -66,13 +88,15 @@ class TKEY(dns.rdata.Rdata):
other_b64 = tok.concatenate_remaining_identifiers(True).encode() other_b64 = tok.concatenate_remaining_identifiers(True).encode()
other = base64.b64decode(other_b64) other = base64.b64decode(other_b64)
return cls(rdclass, rdtype, algorithm, inception, expiration, mode, return cls(
error, key, other) rdclass, rdtype, algorithm, inception, expiration, mode, error, key, other
)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False): def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
self.algorithm.to_wire(file, compress, origin) self.algorithm.to_wire(file, compress, origin)
file.write(struct.pack("!IIHH", self.inception, self.expiration, file.write(
self.mode, self.error)) struct.pack("!IIHH", self.inception, self.expiration, self.mode, self.error)
)
file.write(struct.pack("!H", len(self.key))) file.write(struct.pack("!H", len(self.key)))
file.write(self.key) file.write(self.key)
file.write(struct.pack("!H", len(self.other))) file.write(struct.pack("!H", len(self.other)))
@ -86,8 +110,9 @@ class TKEY(dns.rdata.Rdata):
key = parser.get_counted_bytes(2) key = parser.get_counted_bytes(2)
other = parser.get_counted_bytes(2) other = parser.get_counted_bytes(2)
return cls(rdclass, rdtype, algorithm, inception, expiration, mode, return cls(
error, key, other) rdclass, rdtype, algorithm, inception, expiration, mode, error, key, other
)
# Constants for the mode field - from RFC 2930: # Constants for the mode field - from RFC 2930:
# 2.5 The Mode Field # 2.5 The Mode Field

View file

@ -29,11 +29,28 @@ class TSIG(dns.rdata.Rdata):
"""TSIG record""" """TSIG record"""
__slots__ = ['algorithm', 'time_signed', 'fudge', 'mac', __slots__ = [
'original_id', 'error', 'other'] "algorithm",
"time_signed",
"fudge",
"mac",
"original_id",
"error",
"other",
]
def __init__(self, rdclass, rdtype, algorithm, time_signed, fudge, mac, def __init__(
original_id, error, other): self,
rdclass,
rdtype,
algorithm,
time_signed,
fudge,
mac,
original_id,
error,
other,
):
"""Initialize a TSIG rdata. """Initialize a TSIG rdata.
*rdclass*, an ``int`` is the rdataclass of the Rdata. *rdclass*, an ``int`` is the rdataclass of the Rdata.
@ -67,45 +84,60 @@ class TSIG(dns.rdata.Rdata):
def to_text(self, origin=None, relativize=True, **kw): def to_text(self, origin=None, relativize=True, **kw):
algorithm = self.algorithm.choose_relativity(origin, relativize) algorithm = self.algorithm.choose_relativity(origin, relativize)
error = dns.rcode.to_text(self.error, True) error = dns.rcode.to_text(self.error, True)
text = f"{algorithm} {self.time_signed} {self.fudge} " + \ text = (
f"{len(self.mac)} {dns.rdata._base64ify(self.mac, 0)} " + \ f"{algorithm} {self.time_signed} {self.fudge} "
f"{self.original_id} {error} {len(self.other)}" + f"{len(self.mac)} {dns.rdata._base64ify(self.mac, 0)} "
+ f"{self.original_id} {error} {len(self.other)}"
)
if self.other: if self.other:
text += f" {dns.rdata._base64ify(self.other, 0)}" text += f" {dns.rdata._base64ify(self.other, 0)}"
return text return text
@classmethod @classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, def from_text(
relativize_to=None): cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
algorithm = tok.get_name(relativize=False) algorithm = tok.get_name(relativize=False)
time_signed = tok.get_uint48() time_signed = tok.get_uint48()
fudge = tok.get_uint16() fudge = tok.get_uint16()
mac_len = tok.get_uint16() mac_len = tok.get_uint16()
mac = base64.b64decode(tok.get_string()) mac = base64.b64decode(tok.get_string())
if len(mac) != mac_len: if len(mac) != mac_len:
raise SyntaxError('invalid MAC') raise SyntaxError("invalid MAC")
original_id = tok.get_uint16() original_id = tok.get_uint16()
error = dns.rcode.from_text(tok.get_string()) error = dns.rcode.from_text(tok.get_string())
other_len = tok.get_uint16() other_len = tok.get_uint16()
if other_len > 0: if other_len > 0:
other = base64.b64decode(tok.get_string()) other = base64.b64decode(tok.get_string())
if len(other) != other_len: if len(other) != other_len:
raise SyntaxError('invalid other data') raise SyntaxError("invalid other data")
else: else:
other = b'' other = b""
return cls(rdclass, rdtype, algorithm, time_signed, fudge, mac, return cls(
original_id, error, other) rdclass,
rdtype,
algorithm,
time_signed,
fudge,
mac,
original_id,
error,
other,
)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False): def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
self.algorithm.to_wire(file, None, origin, False) self.algorithm.to_wire(file, None, origin, False)
file.write(struct.pack('!HIHH', file.write(
(self.time_signed >> 32) & 0xffff, struct.pack(
self.time_signed & 0xffffffff, "!HIHH",
(self.time_signed >> 32) & 0xFFFF,
self.time_signed & 0xFFFFFFFF,
self.fudge, self.fudge,
len(self.mac))) len(self.mac),
)
)
file.write(self.mac) file.write(self.mac)
file.write(struct.pack('!HHH', self.original_id, self.error, file.write(struct.pack("!HHH", self.original_id, self.error, len(self.other)))
len(self.other)))
file.write(self.other) file.write(self.other)
@classmethod @classmethod
@ -114,7 +146,16 @@ class TSIG(dns.rdata.Rdata):
time_signed = parser.get_uint48() time_signed = parser.get_uint48()
fudge = parser.get_uint16() fudge = parser.get_uint16()
mac = parser.get_counted_bytes(2) mac = parser.get_counted_bytes(2)
(original_id, error) = parser.get_struct('!HH') (original_id, error) = parser.get_struct("!HH")
other = parser.get_counted_bytes(2) other = parser.get_counted_bytes(2)
return cls(rdclass, rdtype, algorithm, time_signed, fudge, mac, return cls(
original_id, error, other) rdclass,
rdtype,
algorithm,
time_signed,
fudge,
mac,
original_id,
error,
other,
)

View file

@ -32,7 +32,7 @@ class URI(dns.rdata.Rdata):
# see RFC 7553 # see RFC 7553
__slots__ = ['priority', 'weight', 'target'] __slots__ = ["priority", "weight", "target"]
def __init__(self, rdclass, rdtype, priority, weight, target): def __init__(self, rdclass, rdtype, priority, weight, target):
super().__init__(rdclass, rdtype) super().__init__(rdclass, rdtype)
@ -43,12 +43,12 @@ class URI(dns.rdata.Rdata):
raise dns.exception.SyntaxError("URI target cannot be empty") raise dns.exception.SyntaxError("URI target cannot be empty")
def to_text(self, origin=None, relativize=True, **kw): def to_text(self, origin=None, relativize=True, **kw):
return '%d %d "%s"' % (self.priority, self.weight, return '%d %d "%s"' % (self.priority, self.weight, self.target.decode())
self.target.decode())
@classmethod @classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, def from_text(
relativize_to=None): cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
priority = tok.get_uint16() priority = tok.get_uint16()
weight = tok.get_uint16() weight = tok.get_uint16()
target = tok.get().unescape() target = tok.get().unescape()
@ -63,10 +63,10 @@ class URI(dns.rdata.Rdata):
@classmethod @classmethod
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
(priority, weight) = parser.get_struct('!HH') (priority, weight) = parser.get_struct("!HH")
target = parser.get_remaining() target = parser.get_remaining()
if len(target) == 0: if len(target) == 0:
raise dns.exception.FormError('URI target may not be empty') raise dns.exception.FormError("URI target may not be empty")
return cls(rdclass, rdtype, priority, weight, target) return cls(rdclass, rdtype, priority, weight, target)
def _processing_priority(self): def _processing_priority(self):

View file

@ -30,7 +30,7 @@ class X25(dns.rdata.Rdata):
# see RFC 1183 # see RFC 1183
__slots__ = ['address'] __slots__ = ["address"]
def __init__(self, rdclass, rdtype, address): def __init__(self, rdclass, rdtype, address):
super().__init__(rdclass, rdtype) super().__init__(rdclass, rdtype)
@ -40,15 +40,16 @@ class X25(dns.rdata.Rdata):
return '"%s"' % dns.rdata._escapify(self.address) return '"%s"' % dns.rdata._escapify(self.address)
@classmethod @classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, def from_text(
relativize_to=None): cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
address = tok.get_string() address = tok.get_string()
return cls(rdclass, rdtype, address) return cls(rdclass, rdtype, address)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False): def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
l = len(self.address) l = len(self.address)
assert l < 256 assert l < 256
file.write(struct.pack('!B', l)) file.write(struct.pack("!B", l))
file.write(self.address) file.write(self.address)
@classmethod @classmethod

View file

@ -6,7 +6,7 @@ import binascii
import dns.immutable import dns.immutable
import dns.rdata import dns.rdata
import dns.rdatatype import dns.rdatatype
import dns.zone import dns.zonetypes
@dns.immutable.immutable @dns.immutable.immutable
@ -16,35 +16,38 @@ class ZONEMD(dns.rdata.Rdata):
# See RFC 8976 # See RFC 8976
__slots__ = ['serial', 'scheme', 'hash_algorithm', 'digest'] __slots__ = ["serial", "scheme", "hash_algorithm", "digest"]
def __init__(self, rdclass, rdtype, serial, scheme, hash_algorithm, digest): def __init__(self, rdclass, rdtype, serial, scheme, hash_algorithm, digest):
super().__init__(rdclass, rdtype) super().__init__(rdclass, rdtype)
self.serial = self._as_uint32(serial) self.serial = self._as_uint32(serial)
self.scheme = dns.zone.DigestScheme.make(scheme) self.scheme = dns.zonetypes.DigestScheme.make(scheme)
self.hash_algorithm = dns.zone.DigestHashAlgorithm.make(hash_algorithm) self.hash_algorithm = dns.zonetypes.DigestHashAlgorithm.make(hash_algorithm)
self.digest = self._as_bytes(digest) self.digest = self._as_bytes(digest)
if self.scheme == 0: # reserved, RFC 8976 Sec. 5.2 if self.scheme == 0: # reserved, RFC 8976 Sec. 5.2
raise ValueError('scheme 0 is reserved') raise ValueError("scheme 0 is reserved")
if self.hash_algorithm == 0: # reserved, RFC 8976 Sec. 5.3 if self.hash_algorithm == 0: # reserved, RFC 8976 Sec. 5.3
raise ValueError('hash_algorithm 0 is reserved') raise ValueError("hash_algorithm 0 is reserved")
hasher = dns.zone._digest_hashers.get(self.hash_algorithm) hasher = dns.zonetypes._digest_hashers.get(self.hash_algorithm)
if hasher and hasher().digest_size != len(self.digest): if hasher and hasher().digest_size != len(self.digest):
raise ValueError('digest length inconsistent with hash algorithm') raise ValueError("digest length inconsistent with hash algorithm")
def to_text(self, origin=None, relativize=True, **kw): def to_text(self, origin=None, relativize=True, **kw):
kw = kw.copy() kw = kw.copy()
chunksize = kw.pop('chunksize', 128) chunksize = kw.pop("chunksize", 128)
return '%d %d %d %s' % (self.serial, self.scheme, self.hash_algorithm, return "%d %d %d %s" % (
dns.rdata._hexify(self.digest, self.serial,
chunksize=chunksize, self.scheme,
**kw)) self.hash_algorithm,
dns.rdata._hexify(self.digest, chunksize=chunksize, **kw),
)
@classmethod @classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, def from_text(
relativize_to=None): cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
serial = tok.get_uint32() serial = tok.get_uint32()
scheme = tok.get_uint8() scheme = tok.get_uint8()
hash_algorithm = tok.get_uint8() hash_algorithm = tok.get_uint8()
@ -53,8 +56,7 @@ class ZONEMD(dns.rdata.Rdata):
return cls(rdclass, rdtype, serial, scheme, hash_algorithm, digest) return cls(rdclass, rdtype, serial, scheme, hash_algorithm, digest)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False): def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
header = struct.pack("!IBB", self.serial, self.scheme, header = struct.pack("!IBB", self.serial, self.scheme, self.hash_algorithm)
self.hash_algorithm)
file.write(header) file.write(header)
file.write(self.digest) file.write(self.digest)

View file

@ -18,51 +18,51 @@
"""Class ANY (generic) rdata type classes.""" """Class ANY (generic) rdata type classes."""
__all__ = [ __all__ = [
'AFSDB', "AFSDB",
'AMTRELAY', "AMTRELAY",
'AVC', "AVC",
'CAA', "CAA",
'CDNSKEY', "CDNSKEY",
'CDS', "CDS",
'CERT', "CERT",
'CNAME', "CNAME",
'CSYNC', "CSYNC",
'DLV', "DLV",
'DNAME', "DNAME",
'DNSKEY', "DNSKEY",
'DS', "DS",
'EUI48', "EUI48",
'EUI64', "EUI64",
'GPOS', "GPOS",
'HINFO', "HINFO",
'HIP', "HIP",
'ISDN', "ISDN",
'L32', "L32",
'L64', "L64",
'LOC', "LOC",
'LP', "LP",
'MX', "MX",
'NID', "NID",
'NINFO', "NINFO",
'NS', "NS",
'NSEC', "NSEC",
'NSEC3', "NSEC3",
'NSEC3PARAM', "NSEC3PARAM",
'OPENPGPKEY', "OPENPGPKEY",
'OPT', "OPT",
'PTR', "PTR",
'RP', "RP",
'RRSIG', "RRSIG",
'RT', "RT",
'SMIMEA', "SMIMEA",
'SOA', "SOA",
'SPF', "SPF",
'SSHFP', "SSHFP",
'TKEY', "TKEY",
'TLSA', "TLSA",
'TSIG', "TSIG",
'TXT', "TXT",
'URI', "URI",
'X25', "X25",
'ZONEMD', "ZONEMD",
] ]

View file

@ -20,6 +20,7 @@ import struct
import dns.rdtypes.mxbase import dns.rdtypes.mxbase
import dns.immutable import dns.immutable
@dns.immutable.immutable @dns.immutable.immutable
class A(dns.rdata.Rdata): class A(dns.rdata.Rdata):
@ -28,7 +29,7 @@ class A(dns.rdata.Rdata):
# domain: the domain of the address # domain: the domain of the address
# address: the 16-bit address # address: the 16-bit address
__slots__ = ['domain', 'address'] __slots__ = ["domain", "address"]
def __init__(self, rdclass, rdtype, domain, address): def __init__(self, rdclass, rdtype, domain, address):
super().__init__(rdclass, rdtype) super().__init__(rdclass, rdtype)
@ -37,11 +38,12 @@ class A(dns.rdata.Rdata):
def to_text(self, origin=None, relativize=True, **kw): def to_text(self, origin=None, relativize=True, **kw):
domain = self.domain.choose_relativity(origin, relativize) domain = self.domain.choose_relativity(origin, relativize)
return '%s %o' % (domain, self.address) return "%s %o" % (domain, self.address)
@classmethod @classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, def from_text(
relativize_to=None): cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
domain = tok.get_name(origin, relativize, relativize_to) domain = tok.get_name(origin, relativize, relativize_to)
address = tok.get_uint16(base=8) address = tok.get_uint16(base=8)
return cls(rdclass, rdtype, domain, address) return cls(rdclass, rdtype, domain, address)

View file

@ -18,5 +18,5 @@
"""Class CH rdata type classes.""" """Class CH rdata type classes."""
__all__ = [ __all__ = [
'A', "A",
] ]

View file

@ -27,7 +27,7 @@ class A(dns.rdata.Rdata):
"""A record.""" """A record."""
__slots__ = ['address'] __slots__ = ["address"]
def __init__(self, rdclass, rdtype, address): def __init__(self, rdclass, rdtype, address):
super().__init__(rdclass, rdtype) super().__init__(rdclass, rdtype)
@ -37,8 +37,9 @@ class A(dns.rdata.Rdata):
return self.address return self.address
@classmethod @classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, def from_text(
relativize_to=None): cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
address = tok.get_identifier() address = tok.get_identifier()
return cls(rdclass, rdtype, address) return cls(rdclass, rdtype, address)

View file

@ -27,7 +27,7 @@ class AAAA(dns.rdata.Rdata):
"""AAAA record.""" """AAAA record."""
__slots__ = ['address'] __slots__ = ["address"]
def __init__(self, rdclass, rdtype, address): def __init__(self, rdclass, rdtype, address):
super().__init__(rdclass, rdtype) super().__init__(rdclass, rdtype)
@ -37,8 +37,9 @@ class AAAA(dns.rdata.Rdata):
return self.address return self.address
@classmethod @classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, def from_text(
relativize_to=None): cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
address = tok.get_identifier() address = tok.get_identifier()
return cls(rdclass, rdtype, address) return cls(rdclass, rdtype, address)

View file

@ -26,12 +26,13 @@ import dns.ipv6
import dns.rdata import dns.rdata
import dns.tokenizer import dns.tokenizer
@dns.immutable.immutable @dns.immutable.immutable
class APLItem: class APLItem:
"""An APL list item.""" """An APL list item."""
__slots__ = ['family', 'negation', 'address', 'prefix'] __slots__ = ["family", "negation", "address", "prefix"]
def __init__(self, family, negation, address, prefix): def __init__(self, family, negation, address, prefix):
self.family = dns.rdata.Rdata._as_uint16(family) self.family = dns.rdata.Rdata._as_uint16(family)
@ -72,7 +73,7 @@ class APLItem:
assert l < 128 assert l < 128
if self.negation: if self.negation:
l |= 0x80 l |= 0x80
header = struct.pack('!HBB', self.family, self.prefix, l) header = struct.pack("!HBB", self.family, self.prefix, l)
file.write(header) file.write(header)
file.write(address) file.write(address)
@ -84,32 +85,33 @@ class APL(dns.rdata.Rdata):
# see: RFC 3123 # see: RFC 3123
__slots__ = ['items'] __slots__ = ["items"]
def __init__(self, rdclass, rdtype, items): def __init__(self, rdclass, rdtype, items):
super().__init__(rdclass, rdtype) super().__init__(rdclass, rdtype)
for item in items: for item in items:
if not isinstance(item, APLItem): if not isinstance(item, APLItem):
raise ValueError('item not an APLItem') raise ValueError("item not an APLItem")
self.items = tuple(items) self.items = tuple(items)
def to_text(self, origin=None, relativize=True, **kw): def to_text(self, origin=None, relativize=True, **kw):
return ' '.join(map(str, self.items)) return " ".join(map(str, self.items))
@classmethod @classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, def from_text(
relativize_to=None): cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
items = [] items = []
for token in tok.get_remaining(): for token in tok.get_remaining():
item = token.unescape().value item = token.unescape().value
if item[0] == '!': if item[0] == "!":
negation = True negation = True
item = item[1:] item = item[1:]
else: else:
negation = False negation = False
(family, rest) = item.split(':', 1) (family, rest) = item.split(":", 1)
family = int(family) family = int(family)
(address, prefix) = rest.split('/', 1) (address, prefix) = rest.split("/", 1)
prefix = int(prefix) prefix = int(prefix)
item = APLItem(family, negation, address, prefix) item = APLItem(family, negation, address, prefix)
items.append(item) items.append(item)
@ -125,7 +127,7 @@ class APL(dns.rdata.Rdata):
items = [] items = []
while parser.remaining() > 0: while parser.remaining() > 0:
header = parser.get_struct('!HBB') header = parser.get_struct("!HBB")
afdlen = header[2] afdlen = header[2]
if afdlen > 127: if afdlen > 127:
negation = True negation = True
@ -136,16 +138,16 @@ class APL(dns.rdata.Rdata):
l = len(address) l = len(address)
if header[0] == 1: if header[0] == 1:
if l < 4: if l < 4:
address += b'\x00' * (4 - l) address += b"\x00" * (4 - l)
elif header[0] == 2: elif header[0] == 2:
if l < 16: if l < 16:
address += b'\x00' * (16 - l) address += b"\x00" * (16 - l)
else: else:
# #
# This isn't really right according to the RFC, but it # This isn't really right according to the RFC, but it
# seems better than throwing an exception # seems better than throwing an exception
# #
address = codecs.encode(address, 'hex_codec') address = codecs.encode(address, "hex_codec")
item = APLItem(header[0], negation, address, header[1]) item = APLItem(header[0], negation, address, header[1])
items.append(item) items.append(item)
return cls(rdclass, rdtype, items) return cls(rdclass, rdtype, items)

View file

@ -19,6 +19,7 @@ import base64
import dns.exception import dns.exception
import dns.immutable import dns.immutable
import dns.rdata
@dns.immutable.immutable @dns.immutable.immutable
@ -28,7 +29,7 @@ class DHCID(dns.rdata.Rdata):
# see: RFC 4701 # see: RFC 4701
__slots__ = ['data'] __slots__ = ["data"]
def __init__(self, rdclass, rdtype, data): def __init__(self, rdclass, rdtype, data):
super().__init__(rdclass, rdtype) super().__init__(rdclass, rdtype)
@ -38,8 +39,9 @@ class DHCID(dns.rdata.Rdata):
return dns.rdata._base64ify(self.data, **kw) return dns.rdata._base64ify(self.data, **kw)
@classmethod @classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, def from_text(
relativize_to=None): cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
b64 = tok.concatenate_remaining_identifiers().encode() b64 = tok.concatenate_remaining_identifiers().encode()
data = base64.b64decode(b64) data = base64.b64decode(b64)
return cls(rdclass, rdtype, data) return cls(rdclass, rdtype, data)

View file

@ -3,6 +3,7 @@
import dns.rdtypes.svcbbase import dns.rdtypes.svcbbase
import dns.immutable import dns.immutable
@dns.immutable.immutable @dns.immutable.immutable
class HTTPS(dns.rdtypes.svcbbase.SVCBBase): class HTTPS(dns.rdtypes.svcbbase.SVCBBase):
"""HTTPS record""" """HTTPS record"""

View file

@ -24,7 +24,8 @@ import dns.rdtypes.util
class Gateway(dns.rdtypes.util.Gateway): class Gateway(dns.rdtypes.util.Gateway):
name = 'IPSECKEY gateway' name = "IPSECKEY gateway"
@dns.immutable.immutable @dns.immutable.immutable
class IPSECKEY(dns.rdata.Rdata): class IPSECKEY(dns.rdata.Rdata):
@ -33,10 +34,11 @@ class IPSECKEY(dns.rdata.Rdata):
# see: RFC 4025 # see: RFC 4025
__slots__ = ['precedence', 'gateway_type', 'algorithm', 'gateway', 'key'] __slots__ = ["precedence", "gateway_type", "algorithm", "gateway", "key"]
def __init__(self, rdclass, rdtype, precedence, gateway_type, algorithm, def __init__(
gateway, key): self, rdclass, rdtype, precedence, gateway_type, algorithm, gateway, key
):
super().__init__(rdclass, rdtype) super().__init__(rdclass, rdtype)
gateway = Gateway(gateway_type, gateway) gateway = Gateway(gateway_type, gateway)
self.precedence = self._as_uint8(precedence) self.precedence = self._as_uint8(precedence)
@ -46,38 +48,45 @@ class IPSECKEY(dns.rdata.Rdata):
self.key = self._as_bytes(key) self.key = self._as_bytes(key)
def to_text(self, origin=None, relativize=True, **kw): def to_text(self, origin=None, relativize=True, **kw):
gateway = Gateway(self.gateway_type, self.gateway).to_text(origin, gateway = Gateway(self.gateway_type, self.gateway).to_text(origin, relativize)
relativize) return "%d %d %d %s %s" % (
return '%d %d %d %s %s' % (self.precedence, self.gateway_type, self.precedence,
self.algorithm, gateway, self.gateway_type,
dns.rdata._base64ify(self.key, **kw)) self.algorithm,
gateway,
dns.rdata._base64ify(self.key, **kw),
)
@classmethod @classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, def from_text(
relativize_to=None): cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
precedence = tok.get_uint8() precedence = tok.get_uint8()
gateway_type = tok.get_uint8() gateway_type = tok.get_uint8()
algorithm = tok.get_uint8() algorithm = tok.get_uint8()
gateway = Gateway.from_text(gateway_type, tok, origin, relativize, gateway = Gateway.from_text(
relativize_to) gateway_type, tok, origin, relativize, relativize_to
)
b64 = tok.concatenate_remaining_identifiers().encode() b64 = tok.concatenate_remaining_identifiers().encode()
key = base64.b64decode(b64) key = base64.b64decode(b64)
return cls(rdclass, rdtype, precedence, gateway_type, algorithm, return cls(
gateway.gateway, key) rdclass, rdtype, precedence, gateway_type, algorithm, gateway.gateway, key
)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False): def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
header = struct.pack("!BBB", self.precedence, self.gateway_type, header = struct.pack("!BBB", self.precedence, self.gateway_type, self.algorithm)
self.algorithm)
file.write(header) file.write(header)
Gateway(self.gateway_type, self.gateway).to_wire(file, compress, Gateway(self.gateway_type, self.gateway).to_wire(
origin, canonicalize) file, compress, origin, canonicalize
)
file.write(self.key) file.write(self.key)
@classmethod @classmethod
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
header = parser.get_struct('!BBB') header = parser.get_struct("!BBB")
gateway_type = header[1] gateway_type = header[1]
gateway = Gateway.from_wire_parser(gateway_type, parser, origin) gateway = Gateway.from_wire_parser(gateway_type, parser, origin)
key = parser.get_remaining() key = parser.get_remaining()
return cls(rdclass, rdtype, header[0], gateway_type, header[2], return cls(
gateway.gateway, key) rdclass, rdtype, header[0], gateway_type, header[2], gateway.gateway, key
)

View file

@ -27,7 +27,7 @@ import dns.rdtypes.util
def _write_string(file, s): def _write_string(file, s):
l = len(s) l = len(s)
assert l < 256 assert l < 256
file.write(struct.pack('!B', l)) file.write(struct.pack("!B", l))
file.write(s) file.write(s)
@ -38,11 +38,11 @@ class NAPTR(dns.rdata.Rdata):
# see: RFC 3403 # see: RFC 3403
__slots__ = ['order', 'preference', 'flags', 'service', 'regexp', __slots__ = ["order", "preference", "flags", "service", "regexp", "replacement"]
'replacement']
def __init__(self, rdclass, rdtype, order, preference, flags, service, def __init__(
regexp, replacement): self, rdclass, rdtype, order, preference, flags, service, regexp, replacement
):
super().__init__(rdclass, rdtype) super().__init__(rdclass, rdtype)
self.flags = self._as_bytes(flags, True, 255) self.flags = self._as_bytes(flags, True, 255)
self.service = self._as_bytes(service, True, 255) self.service = self._as_bytes(service, True, 255)
@ -53,24 +53,28 @@ class NAPTR(dns.rdata.Rdata):
def to_text(self, origin=None, relativize=True, **kw): def to_text(self, origin=None, relativize=True, **kw):
replacement = self.replacement.choose_relativity(origin, relativize) replacement = self.replacement.choose_relativity(origin, relativize)
return '%d %d "%s" "%s" "%s" %s' % \ return '%d %d "%s" "%s" "%s" %s' % (
(self.order, self.preference, self.order,
self.preference,
dns.rdata._escapify(self.flags), dns.rdata._escapify(self.flags),
dns.rdata._escapify(self.service), dns.rdata._escapify(self.service),
dns.rdata._escapify(self.regexp), dns.rdata._escapify(self.regexp),
replacement) replacement,
)
@classmethod @classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, def from_text(
relativize_to=None): cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
order = tok.get_uint16() order = tok.get_uint16()
preference = tok.get_uint16() preference = tok.get_uint16()
flags = tok.get_string() flags = tok.get_string()
service = tok.get_string() service = tok.get_string()
regexp = tok.get_string() regexp = tok.get_string()
replacement = tok.get_name(origin, relativize, relativize_to) replacement = tok.get_name(origin, relativize, relativize_to)
return cls(rdclass, rdtype, order, preference, flags, service, return cls(
regexp, replacement) rdclass, rdtype, order, preference, flags, service, regexp, replacement
)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False): def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
two_ints = struct.pack("!HH", self.order, self.preference) two_ints = struct.pack("!HH", self.order, self.preference)
@ -82,14 +86,22 @@ class NAPTR(dns.rdata.Rdata):
@classmethod @classmethod
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
(order, preference) = parser.get_struct('!HH') (order, preference) = parser.get_struct("!HH")
strings = [] strings = []
for _ in range(3): for _ in range(3):
s = parser.get_counted_bytes() s = parser.get_counted_bytes()
strings.append(s) strings.append(s)
replacement = parser.get_name(origin) replacement = parser.get_name(origin)
return cls(rdclass, rdtype, order, preference, strings[0], strings[1], return cls(
strings[2], replacement) rdclass,
rdtype,
order,
preference,
strings[0],
strings[1],
strings[2],
replacement,
)
def _processing_priority(self): def _processing_priority(self):
return (self.order, self.preference) return (self.order, self.preference)

View file

@ -30,7 +30,7 @@ class NSAP(dns.rdata.Rdata):
# see: RFC 1706 # see: RFC 1706
__slots__ = ['address'] __slots__ = ["address"]
def __init__(self, rdclass, rdtype, address): def __init__(self, rdclass, rdtype, address):
super().__init__(rdclass, rdtype) super().__init__(rdclass, rdtype)
@ -40,14 +40,15 @@ class NSAP(dns.rdata.Rdata):
return "0x%s" % binascii.hexlify(self.address).decode() return "0x%s" % binascii.hexlify(self.address).decode()
@classmethod @classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, def from_text(
relativize_to=None): cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
address = tok.get_string() address = tok.get_string()
if address[0:2] != '0x': if address[0:2] != "0x":
raise dns.exception.SyntaxError('string does not start with 0x') raise dns.exception.SyntaxError("string does not start with 0x")
address = address[2:].replace('.', '') address = address[2:].replace(".", "")
if len(address) % 2 != 0: if len(address) % 2 != 0:
raise dns.exception.SyntaxError('hexstring has odd length') raise dns.exception.SyntaxError("hexstring has odd length")
address = binascii.unhexlify(address.encode()) address = binascii.unhexlify(address.encode())
return cls(rdclass, rdtype, address) return cls(rdclass, rdtype, address)

View file

@ -31,7 +31,7 @@ class PX(dns.rdata.Rdata):
# see: RFC 2163 # see: RFC 2163
__slots__ = ['preference', 'map822', 'mapx400'] __slots__ = ["preference", "map822", "mapx400"]
def __init__(self, rdclass, rdtype, preference, map822, mapx400): def __init__(self, rdclass, rdtype, preference, map822, mapx400):
super().__init__(rdclass, rdtype) super().__init__(rdclass, rdtype)
@ -42,11 +42,12 @@ class PX(dns.rdata.Rdata):
def to_text(self, origin=None, relativize=True, **kw): def to_text(self, origin=None, relativize=True, **kw):
map822 = self.map822.choose_relativity(origin, relativize) map822 = self.map822.choose_relativity(origin, relativize)
mapx400 = self.mapx400.choose_relativity(origin, relativize) mapx400 = self.mapx400.choose_relativity(origin, relativize)
return '%d %s %s' % (self.preference, map822, mapx400) return "%d %s %s" % (self.preference, map822, mapx400)
@classmethod @classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, def from_text(
relativize_to=None): cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
preference = tok.get_uint16() preference = tok.get_uint16()
map822 = tok.get_name(origin, relativize, relativize_to) map822 = tok.get_name(origin, relativize, relativize_to)
mapx400 = tok.get_name(origin, relativize, relativize_to) mapx400 = tok.get_name(origin, relativize, relativize_to)

View file

@ -31,7 +31,7 @@ class SRV(dns.rdata.Rdata):
# see: RFC 2782 # see: RFC 2782
__slots__ = ['priority', 'weight', 'port', 'target'] __slots__ = ["priority", "weight", "port", "target"]
def __init__(self, rdclass, rdtype, priority, weight, port, target): def __init__(self, rdclass, rdtype, priority, weight, port, target):
super().__init__(rdclass, rdtype) super().__init__(rdclass, rdtype)
@ -42,12 +42,12 @@ class SRV(dns.rdata.Rdata):
def to_text(self, origin=None, relativize=True, **kw): def to_text(self, origin=None, relativize=True, **kw):
target = self.target.choose_relativity(origin, relativize) target = self.target.choose_relativity(origin, relativize)
return '%d %d %d %s' % (self.priority, self.weight, self.port, return "%d %d %d %s" % (self.priority, self.weight, self.port, target)
target)
@classmethod @classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, def from_text(
relativize_to=None): cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
priority = tok.get_uint16() priority = tok.get_uint16()
weight = tok.get_uint16() weight = tok.get_uint16()
port = tok.get_uint16() port = tok.get_uint16()
@ -61,7 +61,7 @@ class SRV(dns.rdata.Rdata):
@classmethod @classmethod
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
(priority, weight, port) = parser.get_struct('!HHH') (priority, weight, port) = parser.get_struct("!HHH")
target = parser.get_name(origin) target = parser.get_name(origin)
return cls(rdclass, rdtype, priority, weight, port, target) return cls(rdclass, rdtype, priority, weight, port, target)

View file

@ -3,6 +3,7 @@
import dns.rdtypes.svcbbase import dns.rdtypes.svcbbase
import dns.immutable import dns.immutable
@dns.immutable.immutable @dns.immutable.immutable
class SVCB(dns.rdtypes.svcbbase.SVCBBase): class SVCB(dns.rdtypes.svcbbase.SVCBBase):
"""SVCB record""" """SVCB record"""

View file

@ -23,13 +23,14 @@ import dns.immutable
import dns.rdata import dns.rdata
try: try:
_proto_tcp = socket.getprotobyname('tcp') _proto_tcp = socket.getprotobyname("tcp")
_proto_udp = socket.getprotobyname('udp') _proto_udp = socket.getprotobyname("udp")
except OSError: except OSError:
# Fall back to defaults in case /etc/protocols is unavailable. # Fall back to defaults in case /etc/protocols is unavailable.
_proto_tcp = 6 _proto_tcp = 6
_proto_udp = 17 _proto_udp = 17
@dns.immutable.immutable @dns.immutable.immutable
class WKS(dns.rdata.Rdata): class WKS(dns.rdata.Rdata):
@ -37,7 +38,7 @@ class WKS(dns.rdata.Rdata):
# see: RFC 1035 # see: RFC 1035
__slots__ = ['address', 'protocol', 'bitmap'] __slots__ = ["address", "protocol", "bitmap"]
def __init__(self, rdclass, rdtype, address, protocol, bitmap): def __init__(self, rdclass, rdtype, address, protocol, bitmap):
super().__init__(rdclass, rdtype) super().__init__(rdclass, rdtype)
@ -51,12 +52,13 @@ class WKS(dns.rdata.Rdata):
for j in range(0, 8): for j in range(0, 8):
if byte & (0x80 >> j): if byte & (0x80 >> j):
bits.append(str(i * 8 + j)) bits.append(str(i * 8 + j))
text = ' '.join(bits) text = " ".join(bits)
return '%s %d %s' % (self.address, self.protocol, text) return "%s %d %s" % (self.address, self.protocol, text)
@classmethod @classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, def from_text(
relativize_to=None): cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
address = tok.get_string() address = tok.get_string()
protocol = tok.get_string() protocol = tok.get_string()
if protocol.isdigit(): if protocol.isdigit():
@ -87,7 +89,7 @@ class WKS(dns.rdata.Rdata):
def _to_wire(self, file, compress=None, origin=None, canonicalize=False): def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
file.write(dns.ipv4.inet_aton(self.address)) file.write(dns.ipv4.inet_aton(self.address))
protocol = struct.pack('!B', self.protocol) protocol = struct.pack("!B", self.protocol)
file.write(protocol) file.write(protocol)
file.write(self.bitmap) file.write(self.bitmap)

View file

@ -18,18 +18,18 @@
"""Class IN rdata type classes.""" """Class IN rdata type classes."""
__all__ = [ __all__ = [
'A', "A",
'AAAA', "AAAA",
'APL', "APL",
'DHCID', "DHCID",
'HTTPS', "HTTPS",
'IPSECKEY', "IPSECKEY",
'KX', "KX",
'NAPTR', "NAPTR",
'NSAP', "NSAP",
'NSAP_PTR', "NSAP_PTR",
'PX', "PX",
'SRV', "SRV",
'SVCB', "SVCB",
'WKS', "WKS",
] ]

View file

@ -18,16 +18,16 @@
"""DNS rdata type classes""" """DNS rdata type classes"""
__all__ = [ __all__ = [
'ANY', "ANY",
'IN', "IN",
'CH', "CH",
'dnskeybase', "dnskeybase",
'dsbase', "dsbase",
'euibase', "euibase",
'mxbase', "mxbase",
'nsbase', "nsbase",
'svcbbase', "svcbbase",
'tlsabase', "tlsabase",
'txtbase', "txtbase",
'util' "util",
] ]

View file

@ -21,12 +21,13 @@ import struct
import dns.exception import dns.exception
import dns.immutable import dns.immutable
import dns.dnssec import dns.dnssectypes
import dns.rdata import dns.rdata
# wildcard import # wildcard import
__all__ = ["SEP", "REVOKE", "ZONE"] # noqa: F822 __all__ = ["SEP", "REVOKE", "ZONE"] # noqa: F822
class Flag(enum.IntFlag): class Flag(enum.IntFlag):
SEP = 0x0001 SEP = 0x0001
REVOKE = 0x0080 REVOKE = 0x0080
@ -38,22 +39,27 @@ class DNSKEYBase(dns.rdata.Rdata):
"""Base class for rdata that is like a DNSKEY record""" """Base class for rdata that is like a DNSKEY record"""
__slots__ = ['flags', 'protocol', 'algorithm', 'key'] __slots__ = ["flags", "protocol", "algorithm", "key"]
def __init__(self, rdclass, rdtype, flags, protocol, algorithm, key): def __init__(self, rdclass, rdtype, flags, protocol, algorithm, key):
super().__init__(rdclass, rdtype) super().__init__(rdclass, rdtype)
self.flags = self._as_uint16(flags) self.flags = self._as_uint16(flags)
self.protocol = self._as_uint8(protocol) self.protocol = self._as_uint8(protocol)
self.algorithm = dns.dnssec.Algorithm.make(algorithm) self.algorithm = dns.dnssectypes.Algorithm.make(algorithm)
self.key = self._as_bytes(key) self.key = self._as_bytes(key)
def to_text(self, origin=None, relativize=True, **kw): def to_text(self, origin=None, relativize=True, **kw):
return '%d %d %d %s' % (self.flags, self.protocol, self.algorithm, return "%d %d %d %s" % (
dns.rdata._base64ify(self.key, **kw)) self.flags,
self.protocol,
self.algorithm,
dns.rdata._base64ify(self.key, **kw),
)
@classmethod @classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, def from_text(
relativize_to=None): cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
flags = tok.get_uint16() flags = tok.get_uint16()
protocol = tok.get_uint8() protocol = tok.get_uint8()
algorithm = tok.get_string() algorithm = tok.get_string()
@ -68,10 +74,10 @@ class DNSKEYBase(dns.rdata.Rdata):
@classmethod @classmethod
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
header = parser.get_struct('!HBB') header = parser.get_struct("!HBB")
key = parser.get_remaining() key = parser.get_remaining()
return cls(rdclass, rdtype, header[0], header[1], header[2], return cls(rdclass, rdtype, header[0], header[1], header[2], key)
key)
### BEGIN generated Flag constants ### BEGIN generated Flag constants

Some files were not shown because too many files have changed in this diff Show more