Bump dnspython from 2.2.1 to 2.3.0 (#1975)

* Bump dnspython from 2.2.1 to 2.3.0

Bumps [dnspython](https://github.com/rthalley/dnspython) from 2.2.1 to 2.3.0.
- [Release notes](https://github.com/rthalley/dnspython/releases)
- [Changelog](https://github.com/rthalley/dnspython/blob/master/doc/whatsnew.rst)
- [Commits](https://github.com/rthalley/dnspython/compare/v2.2.1...v2.3.0)

---
updated-dependencies:
- dependency-name: dnspython
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>

* Update dnspython==2.3.0

---------

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: JonnyWong16 <9099342+JonnyWong16@users.noreply.github.com>

[skip ci]
This commit is contained in:
dependabot[bot] 2023-03-02 20:54:32 -08:00 committed by GitHub
parent 6910079330
commit 32c06a8b72
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
137 changed files with 7699 additions and 4277 deletions

View file

@ -18,49 +18,52 @@
"""dnspython DNS toolkit"""
__all__ = [
'asyncbackend',
'asyncquery',
'asyncresolver',
'dnssec',
'e164',
'edns',
'entropy',
'exception',
'flags',
'immutable',
'inet',
'ipv4',
'ipv6',
'message',
'name',
'namedict',
'node',
'opcode',
'query',
'rcode',
'rdata',
'rdataclass',
'rdataset',
'rdatatype',
'renderer',
'resolver',
'reversename',
'rrset',
'serial',
'set',
'tokenizer',
'transaction',
'tsig',
'tsigkeyring',
'ttl',
'rdtypes',
'update',
'version',
'versioned',
'wire',
'xfr',
'zone',
'zonefile',
"asyncbackend",
"asyncquery",
"asyncresolver",
"dnssec",
"dnssectypes",
"e164",
"edns",
"entropy",
"exception",
"flags",
"immutable",
"inet",
"ipv4",
"ipv6",
"message",
"name",
"namedict",
"node",
"opcode",
"query",
"quic",
"rcode",
"rdata",
"rdataclass",
"rdataset",
"rdatatype",
"renderer",
"resolver",
"reversename",
"rrset",
"serial",
"set",
"tokenizer",
"transaction",
"tsig",
"tsigkeyring",
"ttl",
"rdtypes",
"update",
"version",
"versioned",
"wire",
"xfr",
"zone",
"zonetypes",
"zonefile",
]
from dns.version import version as __version__ # noqa

View file

@ -3,6 +3,7 @@
# This is a nullcontext for both sync and async. 3.7 has a nullcontext,
# but it is only for sync use.
class NullContext:
def __init__(self, enter_result=None):
self.enter_result = enter_result
@ -23,6 +24,7 @@ class NullContext:
# These are declared here so backends can import them without creating
# circular dependencies with dns.asyncbackend.
class Socket: # pragma: no cover
async def close(self):
pass
@ -41,6 +43,9 @@ class Socket: # pragma: no cover
class DatagramSocket(Socket): # pragma: no cover
def __init__(self, family: int):
self.family = family
async def sendto(self, what, destination, timeout):
raise NotImplementedError
@ -56,14 +61,25 @@ class StreamSocket(Socket): # pragma: no cover
raise NotImplementedError
class Backend: # pragma: no cover
class Backend: # pragma: no cover
def name(self):
return 'unknown'
return "unknown"
async def make_socket(self, af, socktype, proto=0,
source=None, destination=None, timeout=None,
ssl_context=None, server_hostname=None):
async def make_socket(
self,
af,
socktype,
proto=0,
source=None,
destination=None,
timeout=None,
ssl_context=None,
server_hostname=None,
):
raise NotImplementedError
def datagram_connection_required(self):
return False
async def sleep(self, interval):
raise NotImplementedError

View file

@ -10,7 +10,8 @@ import dns._asyncbackend
import dns.exception
_is_win32 = sys.platform == 'win32'
_is_win32 = sys.platform == "win32"
def _get_running_loop():
try:
@ -30,7 +31,6 @@ class _DatagramProtocol:
def datagram_received(self, data, addr):
if self.recvfrom and not self.recvfrom.done():
self.recvfrom.set_result((data, addr))
self.recvfrom = None
def error_received(self, exc): # pragma: no cover
if self.recvfrom and not self.recvfrom.done():
@ -56,30 +56,34 @@ async def _maybe_wait_for(awaitable, timeout):
class DatagramSocket(dns._asyncbackend.DatagramSocket):
def __init__(self, family, transport, protocol):
self.family = family
super().__init__(family)
self.transport = transport
self.protocol = protocol
async def sendto(self, what, destination, timeout): # pragma: no cover
# no timeout for asyncio sendto
self.transport.sendto(what, destination)
return len(what)
async def recvfrom(self, size, timeout):
# ignore size as there's no way I know to tell protocol about it
done = _get_running_loop().create_future()
assert self.protocol.recvfrom is None
self.protocol.recvfrom = done
await _maybe_wait_for(done, timeout)
return done.result()
try:
assert self.protocol.recvfrom is None
self.protocol.recvfrom = done
await _maybe_wait_for(done, timeout)
return done.result()
finally:
self.protocol.recvfrom = None
async def close(self):
self.protocol.close()
async def getpeername(self):
return self.transport.get_extra_info('peername')
return self.transport.get_extra_info("peername")
async def getsockname(self):
return self.transport.get_extra_info('sockname')
return self.transport.get_extra_info("sockname")
class StreamSocket(dns._asyncbackend.StreamSocket):
@ -93,8 +97,7 @@ class StreamSocket(dns._asyncbackend.StreamSocket):
return await _maybe_wait_for(self.writer.drain(), timeout)
async def recv(self, size, timeout):
return await _maybe_wait_for(self.reader.read(size),
timeout)
return await _maybe_wait_for(self.reader.read(size), timeout)
async def close(self):
self.writer.close()
@ -104,43 +107,64 @@ class StreamSocket(dns._asyncbackend.StreamSocket):
pass
async def getpeername(self):
return self.writer.get_extra_info('peername')
return self.writer.get_extra_info("peername")
async def getsockname(self):
return self.writer.get_extra_info('sockname')
return self.writer.get_extra_info("sockname")
class Backend(dns._asyncbackend.Backend):
def name(self):
return 'asyncio'
return "asyncio"
async def make_socket(self, af, socktype, proto=0,
source=None, destination=None, timeout=None,
ssl_context=None, server_hostname=None):
if destination is None and socktype == socket.SOCK_DGRAM and \
_is_win32:
raise NotImplementedError('destinationless datagram sockets '
'are not supported by asyncio '
'on Windows')
async def make_socket(
self,
af,
socktype,
proto=0,
source=None,
destination=None,
timeout=None,
ssl_context=None,
server_hostname=None,
):
if destination is None and socktype == socket.SOCK_DGRAM and _is_win32:
raise NotImplementedError(
"destinationless datagram sockets "
"are not supported by asyncio "
"on Windows"
)
loop = _get_running_loop()
if socktype == socket.SOCK_DGRAM:
transport, protocol = await loop.create_datagram_endpoint(
_DatagramProtocol, source, family=af,
proto=proto, remote_addr=destination)
_DatagramProtocol,
source,
family=af,
proto=proto,
remote_addr=destination,
)
return DatagramSocket(af, transport, protocol)
elif socktype == socket.SOCK_STREAM:
if destination is None:
# This shouldn't happen, but we check to make code analysis software
# happier.
raise ValueError("destination required for stream sockets")
(r, w) = await _maybe_wait_for(
asyncio.open_connection(destination[0],
destination[1],
ssl=ssl_context,
family=af,
proto=proto,
local_addr=source,
server_hostname=server_hostname),
timeout)
asyncio.open_connection(
destination[0],
destination[1],
ssl=ssl_context,
family=af,
proto=proto,
local_addr=source,
server_hostname=server_hostname,
),
timeout,
)
return StreamSocket(af, r, w)
raise NotImplementedError('unsupported socket ' +
f'type {socktype}') # pragma: no cover
raise NotImplementedError(
"unsupported socket " + f"type {socktype}"
) # pragma: no cover
async def sleep(self, interval):
await asyncio.sleep(interval)

View file

@ -26,18 +26,20 @@ _lltuple = dns.inet.low_level_address_tuple
class DatagramSocket(dns._asyncbackend.DatagramSocket):
def __init__(self, socket):
super().__init__(socket.family)
self.socket = socket
self.family = socket.family
async def sendto(self, what, destination, timeout):
async with _maybe_timeout(timeout):
return await self.socket.sendto(what, destination)
raise dns.exception.Timeout(timeout=timeout) # pragma: no cover
raise dns.exception.Timeout(
timeout=timeout
) # pragma: no cover lgtm[py/unreachable-statement]
async def recvfrom(self, size, timeout):
async with _maybe_timeout(timeout):
return await self.socket.recvfrom(size)
raise dns.exception.Timeout(timeout=timeout)
raise dns.exception.Timeout(timeout=timeout) # lgtm[py/unreachable-statement]
async def close(self):
await self.socket.close()
@ -57,12 +59,12 @@ class StreamSocket(dns._asyncbackend.StreamSocket):
async def sendall(self, what, timeout):
async with _maybe_timeout(timeout):
return await self.socket.sendall(what)
raise dns.exception.Timeout(timeout=timeout)
raise dns.exception.Timeout(timeout=timeout) # lgtm[py/unreachable-statement]
async def recv(self, size, timeout):
async with _maybe_timeout(timeout):
return await self.socket.recv(size)
raise dns.exception.Timeout(timeout=timeout)
raise dns.exception.Timeout(timeout=timeout) # lgtm[py/unreachable-statement]
async def close(self):
await self.socket.close()
@ -76,11 +78,19 @@ class StreamSocket(dns._asyncbackend.StreamSocket):
class Backend(dns._asyncbackend.Backend):
def name(self):
return 'curio'
return "curio"
async def make_socket(self, af, socktype, proto=0,
source=None, destination=None, timeout=None,
ssl_context=None, server_hostname=None):
async def make_socket(
self,
af,
socktype,
proto=0,
source=None,
destination=None,
timeout=None,
ssl_context=None,
server_hostname=None,
):
if socktype == socket.SOCK_DGRAM:
s = curio.socket.socket(af, socktype, proto)
try:
@ -96,13 +106,17 @@ class Backend(dns._asyncbackend.Backend):
else:
source_addr = None
async with _maybe_timeout(timeout):
s = await curio.open_connection(destination[0], destination[1],
ssl=ssl_context,
source_addr=source_addr,
server_hostname=server_hostname)
s = await curio.open_connection(
destination[0],
destination[1],
ssl=ssl_context,
source_addr=source_addr,
server_hostname=server_hostname,
)
return StreamSocket(s)
raise NotImplementedError('unsupported socket ' +
f'type {socktype}') # pragma: no cover
raise NotImplementedError(
"unsupported socket " + f"type {socktype}"
) # pragma: no cover
async def sleep(self, interval):
await curio.sleep(interval)

View file

@ -1,84 +0,0 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# This implementation of the immutable decorator is for python 3.6,
# which doesn't have Context Variables. This implementation is somewhat
# costly for classes with slots, as it adds a __dict__ to them.
import inspect
class _Immutable:
"""Immutable mixin class"""
# Note we MUST NOT have __slots__ as that causes
#
# TypeError: multiple bases have instance lay-out conflict
#
# when we get mixed in with another class with slots. When we
# get mixed into something with slots, it effectively adds __dict__ to
# the slots of the other class, which allows attribute setting to work,
# albeit at the cost of the dictionary.
def __setattr__(self, name, value):
if not hasattr(self, '_immutable_init') or \
self._immutable_init is not self:
raise TypeError("object doesn't support attribute assignment")
else:
super().__setattr__(name, value)
def __delattr__(self, name):
if not hasattr(self, '_immutable_init') or \
self._immutable_init is not self:
raise TypeError("object doesn't support attribute assignment")
else:
super().__delattr__(name)
def _immutable_init(f):
def nf(*args, **kwargs):
try:
# Are we already initializing an immutable class?
previous = args[0]._immutable_init
except AttributeError:
# We are the first!
previous = None
object.__setattr__(args[0], '_immutable_init', args[0])
try:
# call the actual __init__
f(*args, **kwargs)
finally:
if not previous:
# If we started the initialization, establish immutability
# by removing the attribute that allows mutation
object.__delattr__(args[0], '_immutable_init')
nf.__signature__ = inspect.signature(f)
return nf
def immutable(cls):
if _Immutable in cls.__mro__:
# Some ancestor already has the mixin, so just make sure we keep
# following the __init__ protocol.
cls.__init__ = _immutable_init(cls.__init__)
if hasattr(cls, '__setstate__'):
cls.__setstate__ = _immutable_init(cls.__setstate__)
ncls = cls
else:
# Mixin the Immutable class and follow the __init__ protocol.
class ncls(_Immutable, cls):
@_immutable_init
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if hasattr(cls, '__setstate__'):
@_immutable_init
def __setstate__(self, *args, **kwargs):
super().__setstate__(*args, **kwargs)
# make ncls have the same name and module as cls
ncls.__name__ = cls.__name__
ncls.__qualname__ = cls.__qualname__
ncls.__module__ = cls.__module__
return ncls

View file

@ -8,7 +8,7 @@ import contextvars
import inspect
_in__init__ = contextvars.ContextVar('_immutable_in__init__', default=False)
_in__init__ = contextvars.ContextVar("_immutable_in__init__", default=False)
class _Immutable:
@ -41,6 +41,7 @@ def _immutable_init(f):
f(*args, **kwargs)
finally:
_in__init__.reset(previous)
nf.__signature__ = inspect.signature(f)
return nf
@ -50,7 +51,7 @@ def immutable(cls):
# Some ancestor already has the mixin, so just make sure we keep
# following the __init__ protocol.
cls.__init__ = _immutable_init(cls.__init__)
if hasattr(cls, '__setstate__'):
if hasattr(cls, "__setstate__"):
cls.__setstate__ = _immutable_init(cls.__setstate__)
ncls = cls
else:
@ -63,7 +64,8 @@ def immutable(cls):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if hasattr(cls, '__setstate__'):
if hasattr(cls, "__setstate__"):
@_immutable_init
def __setstate__(self, *args, **kwargs):
super().__setstate__(*args, **kwargs)

View file

@ -26,18 +26,20 @@ _lltuple = dns.inet.low_level_address_tuple
class DatagramSocket(dns._asyncbackend.DatagramSocket):
def __init__(self, socket):
super().__init__(socket.family)
self.socket = socket
self.family = socket.family
async def sendto(self, what, destination, timeout):
with _maybe_timeout(timeout):
return await self.socket.sendto(what, destination)
raise dns.exception.Timeout(timeout=timeout) # pragma: no cover
raise dns.exception.Timeout(
timeout=timeout
) # pragma: no cover lgtm[py/unreachable-statement]
async def recvfrom(self, size, timeout):
with _maybe_timeout(timeout):
return await self.socket.recvfrom(size)
raise dns.exception.Timeout(timeout=timeout)
raise dns.exception.Timeout(timeout=timeout) # lgtm[py/unreachable-statement]
async def close(self):
self.socket.close()
@ -58,12 +60,12 @@ class StreamSocket(dns._asyncbackend.StreamSocket):
async def sendall(self, what, timeout):
with _maybe_timeout(timeout):
return await self.stream.send_all(what)
raise dns.exception.Timeout(timeout=timeout)
raise dns.exception.Timeout(timeout=timeout) # lgtm[py/unreachable-statement]
async def recv(self, size, timeout):
with _maybe_timeout(timeout):
return await self.stream.receive_some(size)
raise dns.exception.Timeout(timeout=timeout)
raise dns.exception.Timeout(timeout=timeout) # lgtm[py/unreachable-statement]
async def close(self):
await self.stream.aclose()
@ -83,11 +85,19 @@ class StreamSocket(dns._asyncbackend.StreamSocket):
class Backend(dns._asyncbackend.Backend):
def name(self):
return 'trio'
return "trio"
async def make_socket(self, af, socktype, proto=0, source=None,
destination=None, timeout=None,
ssl_context=None, server_hostname=None):
async def make_socket(
self,
af,
socktype,
proto=0,
source=None,
destination=None,
timeout=None,
ssl_context=None,
server_hostname=None,
):
s = trio.socket.socket(af, socktype, proto)
stream = None
try:
@ -103,19 +113,20 @@ class Backend(dns._asyncbackend.Backend):
return DatagramSocket(s)
elif socktype == socket.SOCK_STREAM:
stream = trio.SocketStream(s)
s = None
tls = False
if ssl_context:
tls = True
try:
stream = trio.SSLStream(stream, ssl_context,
server_hostname=server_hostname)
stream = trio.SSLStream(
stream, ssl_context, server_hostname=server_hostname
)
except Exception: # pragma: no cover
await stream.aclose()
raise
return StreamSocket(af, stream, tls)
raise NotImplementedError('unsupported socket ' +
f'type {socktype}') # pragma: no cover
raise NotImplementedError(
"unsupported socket " + f"type {socktype}"
) # pragma: no cover
async def sleep(self, interval):
await trio.sleep(interval)

View file

@ -1,26 +1,33 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
from typing import Dict
import dns.exception
# pylint: disable=unused-import
from dns._asyncbackend import Socket, DatagramSocket, \
StreamSocket, Backend # noqa:
from dns._asyncbackend import (
Socket,
DatagramSocket,
StreamSocket,
Backend,
) # noqa: F401 lgtm[py/unused-import]
# pylint: enable=unused-import
_default_backend = None
_backends = {}
_backends: Dict[str, Backend] = {}
# Allow sniffio import to be disabled for testing purposes
_no_sniffio = False
class AsyncLibraryNotFoundError(dns.exception.DNSException):
pass
def get_backend(name):
def get_backend(name: str) -> Backend:
"""Get the specified asynchronous backend.
*name*, a ``str``, the name of the backend. Currently the "trio",
@ -32,22 +39,25 @@ def get_backend(name):
backend = _backends.get(name)
if backend:
return backend
if name == 'trio':
if name == "trio":
import dns._trio_backend
backend = dns._trio_backend.Backend()
elif name == 'curio':
elif name == "curio":
import dns._curio_backend
backend = dns._curio_backend.Backend()
elif name == 'asyncio':
elif name == "asyncio":
import dns._asyncio_backend
backend = dns._asyncio_backend.Backend()
else:
raise NotImplementedError(f'unimplemented async backend {name}')
raise NotImplementedError(f"unimplemented async backend {name}")
_backends[name] = backend
return backend
def sniff():
def sniff() -> str:
"""Attempt to determine the in-use asynchronous I/O library by using
the ``sniffio`` module if it is available.
@ -59,35 +69,32 @@ def sniff():
if _no_sniffio:
raise ImportError
import sniffio
try:
return sniffio.current_async_library()
except sniffio.AsyncLibraryNotFoundError:
raise AsyncLibraryNotFoundError('sniffio cannot determine ' +
'async library')
raise AsyncLibraryNotFoundError(
"sniffio cannot determine " + "async library"
)
except ImportError:
import asyncio
try:
asyncio.get_running_loop()
return 'asyncio'
return "asyncio"
except RuntimeError:
raise AsyncLibraryNotFoundError('no async library detected')
except AttributeError: # pragma: no cover
# we have to check current_task on 3.6
if not asyncio.Task.current_task():
raise AsyncLibraryNotFoundError('no async library detected')
return 'asyncio'
raise AsyncLibraryNotFoundError("no async library detected")
def get_default_backend():
"""Get the default backend, initializing it if necessary.
"""
def get_default_backend() -> Backend:
"""Get the default backend, initializing it if necessary."""
if _default_backend:
return _default_backend
return set_default_backend(sniff())
def set_default_backend(name):
def set_default_backend(name: str) -> Backend:
"""Set the default backend.
It's not normally necessary to call this method, as

View file

@ -1,13 +0,0 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
class Backend:
...
def get_backend(name: str) -> Backend:
...
def sniff() -> str:
...
def get_default_backend() -> Backend:
...
def set_default_backend(name: str) -> Backend:
...

View file

@ -17,7 +17,10 @@
"""Talk to a DNS server."""
from typing import Any, Dict, Optional, Tuple, Union
import base64
import contextlib
import socket
import struct
import time
@ -27,12 +30,24 @@ import dns.exception
import dns.inet
import dns.name
import dns.message
import dns.quic
import dns.rcode
import dns.rdataclass
import dns.rdatatype
import dns.transaction
from dns.query import _compute_times, _matches_destination, BadResponse, ssl, \
UDPMode, _have_httpx, _have_http2, NoDOH
from dns._asyncbackend import NullContext
from dns.query import (
_compute_times,
_matches_destination,
BadResponse,
ssl,
UDPMode,
_have_httpx,
_have_http2,
NoDOH,
NoDOQ,
)
if _have_httpx:
import httpx
@ -47,11 +62,11 @@ def _source_tuple(af, address, port):
if address or port:
if address is None:
if af == socket.AF_INET:
address = '0.0.0.0'
address = "0.0.0.0"
elif af == socket.AF_INET6:
address = '::'
address = "::"
else:
raise NotImplementedError(f'unknown address family {af}')
raise NotImplementedError(f"unknown address family {af}")
return (address, port)
else:
return None
@ -66,7 +81,12 @@ def _timeout(expiration, now=None):
return None
async def send_udp(sock, what, destination, expiration=None):
async def send_udp(
sock: dns.asyncbackend.DatagramSocket,
what: Union[dns.message.Message, bytes],
destination: Any,
expiration: Optional[float] = None,
) -> Tuple[int, float]:
"""Send a DNS message to the specified UDP socket.
*sock*, a ``dns.asyncbackend.DatagramSocket``.
@ -78,7 +98,8 @@ async def send_udp(sock, what, destination, expiration=None):
*expiration*, a ``float`` or ``None``, the absolute time at which
a timeout exception should be raised. If ``None``, no timeout will
occur.
occur. The expiration value is meaningless for the asyncio backend, as
asyncio's transport sendto() never blocks.
Returns an ``(int, float)`` tuple of bytes sent and the sent time.
"""
@ -90,35 +111,61 @@ async def send_udp(sock, what, destination, expiration=None):
return (n, sent_time)
async def receive_udp(sock, destination=None, expiration=None,
ignore_unexpected=False, one_rr_per_rrset=False,
keyring=None, request_mac=b'', ignore_trailing=False,
raise_on_truncation=False):
async def receive_udp(
sock: dns.asyncbackend.DatagramSocket,
destination: Optional[Any] = None,
expiration: Optional[float] = None,
ignore_unexpected: bool = False,
one_rr_per_rrset: bool = False,
keyring: Optional[Dict[dns.name.Name, dns.tsig.Key]] = None,
request_mac: Optional[bytes] = b"",
ignore_trailing: bool = False,
raise_on_truncation: bool = False,
) -> Any:
"""Read a DNS message from a UDP socket.
*sock*, a ``dns.asyncbackend.DatagramSocket``.
See :py:func:`dns.query.receive_udp()` for the documentation of the other
parameters, exceptions, and return type of this method.
parameters, and exceptions.
Returns a ``(dns.message.Message, float, tuple)`` tuple of the received message, the
received time, and the address where the message arrived from.
"""
wire = b''
wire = b""
while 1:
(wire, from_address) = await sock.recvfrom(65535, _timeout(expiration))
if _matches_destination(sock.family, from_address, destination,
ignore_unexpected):
if _matches_destination(
sock.family, from_address, destination, ignore_unexpected
):
break
received_time = time.time()
r = dns.message.from_wire(wire, keyring=keyring, request_mac=request_mac,
one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing,
raise_on_truncation=raise_on_truncation)
r = dns.message.from_wire(
wire,
keyring=keyring,
request_mac=request_mac,
one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing,
raise_on_truncation=raise_on_truncation,
)
return (r, received_time, from_address)
async def udp(q, where, timeout=None, port=53, source=None, source_port=0,
ignore_unexpected=False, one_rr_per_rrset=False,
ignore_trailing=False, raise_on_truncation=False, sock=None,
backend=None):
async def udp(
q: dns.message.Message,
where: str,
timeout: Optional[float] = None,
port: int = 53,
source: Optional[str] = None,
source_port: int = 0,
ignore_unexpected: bool = False,
one_rr_per_rrset: bool = False,
ignore_trailing: bool = False,
raise_on_truncation: bool = False,
sock: Optional[dns.asyncbackend.DatagramSocket] = None,
backend: Optional[dns.asyncbackend.Backend] = None,
) -> dns.message.Message:
"""Return the response obtained after sending a query via UDP.
*sock*, a ``dns.asyncbackend.DatagramSocket``, or ``None``,
@ -134,42 +181,52 @@ async def udp(q, where, timeout=None, port=53, source=None, source_port=0,
"""
wire = q.to_wire()
(begin_time, expiration) = _compute_times(timeout)
s = None
# After 3.6 is no longer supported, this can use an AsyncExitStack.
try:
af = dns.inet.af_for_address(where)
destination = _lltuple((where, port), af)
if sock:
s = sock
af = dns.inet.af_for_address(where)
destination = _lltuple((where, port), af)
if sock:
cm: contextlib.AbstractAsyncContextManager = NullContext(sock)
else:
if not backend:
backend = dns.asyncbackend.get_default_backend()
stuple = _source_tuple(af, source, source_port)
if backend.datagram_connection_required():
dtuple = (where, port)
else:
if not backend:
backend = dns.asyncbackend.get_default_backend()
stuple = _source_tuple(af, source, source_port)
if backend.datagram_connection_required():
dtuple = (where, port)
else:
dtuple = None
s = await backend.make_socket(af, socket.SOCK_DGRAM, 0, stuple,
dtuple)
dtuple = None
cm = await backend.make_socket(af, socket.SOCK_DGRAM, 0, stuple, dtuple)
async with cm as s:
await send_udp(s, wire, destination, expiration)
(r, received_time, _) = await receive_udp(s, destination, expiration,
ignore_unexpected,
one_rr_per_rrset,
q.keyring, q.mac,
ignore_trailing,
raise_on_truncation)
(r, received_time, _) = await receive_udp(
s,
destination,
expiration,
ignore_unexpected,
one_rr_per_rrset,
q.keyring,
q.mac,
ignore_trailing,
raise_on_truncation,
)
r.time = received_time - begin_time
if not q.is_response(r):
raise BadResponse
return r
finally:
if not sock and s:
await s.close()
async def udp_with_fallback(q, where, timeout=None, port=53, source=None,
source_port=0, ignore_unexpected=False,
one_rr_per_rrset=False, ignore_trailing=False,
udp_sock=None, tcp_sock=None, backend=None):
async def udp_with_fallback(
q: dns.message.Message,
where: str,
timeout: Optional[float] = None,
port: int = 53,
source: Optional[str] = None,
source_port: int = 0,
ignore_unexpected: bool = False,
one_rr_per_rrset: bool = False,
ignore_trailing: bool = False,
udp_sock: Optional[dns.asyncbackend.DatagramSocket] = None,
tcp_sock: Optional[dns.asyncbackend.StreamSocket] = None,
backend: Optional[dns.asyncbackend.Backend] = None,
) -> Tuple[dns.message.Message, bool]:
"""Return the response to the query, trying UDP first and falling back
to TCP if UDP results in a truncated response.
@ -191,18 +248,42 @@ async def udp_with_fallback(q, where, timeout=None, port=53, source=None,
method.
"""
try:
response = await udp(q, where, timeout, port, source, source_port,
ignore_unexpected, one_rr_per_rrset,
ignore_trailing, True, udp_sock, backend)
response = await udp(
q,
where,
timeout,
port,
source,
source_port,
ignore_unexpected,
one_rr_per_rrset,
ignore_trailing,
True,
udp_sock,
backend,
)
return (response, False)
except dns.message.Truncated:
response = await tcp(q, where, timeout, port, source, source_port,
one_rr_per_rrset, ignore_trailing, tcp_sock,
backend)
response = await tcp(
q,
where,
timeout,
port,
source,
source_port,
one_rr_per_rrset,
ignore_trailing,
tcp_sock,
backend,
)
return (response, True)
async def send_tcp(sock, what, expiration=None):
async def send_tcp(
sock: dns.asyncbackend.StreamSocket,
what: Union[dns.message.Message, bytes],
expiration: Optional[float] = None,
) -> Tuple[int, float]:
"""Send a DNS message to the specified TCP socket.
*sock*, a ``dns.asyncbackend.StreamSocket``.
@ -212,12 +293,14 @@ async def send_tcp(sock, what, expiration=None):
"""
if isinstance(what, dns.message.Message):
what = what.to_wire()
l = len(what)
wire = what.to_wire()
else:
wire = what
l = len(wire)
# copying the wire into tcpmsg is inefficient, but lets us
# avoid writev() or doing a short write that would get pushed
# onto the net
tcpmsg = struct.pack("!H", l) + what
tcpmsg = struct.pack("!H", l) + wire
sent_time = time.time()
await sock.sendall(tcpmsg, _timeout(expiration, sent_time))
return (len(tcpmsg), sent_time)
@ -227,18 +310,24 @@ async def _read_exactly(sock, count, expiration):
"""Read the specified number of bytes from stream. Keep trying until we
either get the desired amount, or we hit EOF.
"""
s = b''
s = b""
while count > 0:
n = await sock.recv(count, _timeout(expiration))
if n == b'':
if n == b"":
raise EOFError
count = count - len(n)
s = s + n
return s
async def receive_tcp(sock, expiration=None, one_rr_per_rrset=False,
keyring=None, request_mac=b'', ignore_trailing=False):
async def receive_tcp(
sock: dns.asyncbackend.StreamSocket,
expiration: Optional[float] = None,
one_rr_per_rrset: bool = False,
keyring: Optional[Dict[dns.name.Name, dns.tsig.Key]] = None,
request_mac: Optional[bytes] = b"",
ignore_trailing: bool = False,
) -> Tuple[dns.message.Message, float]:
"""Read a DNS message from a TCP socket.
*sock*, a ``dns.asyncbackend.StreamSocket``.
@ -251,15 +340,28 @@ async def receive_tcp(sock, expiration=None, one_rr_per_rrset=False,
(l,) = struct.unpack("!H", ldata)
wire = await _read_exactly(sock, l, expiration)
received_time = time.time()
r = dns.message.from_wire(wire, keyring=keyring, request_mac=request_mac,
one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing)
r = dns.message.from_wire(
wire,
keyring=keyring,
request_mac=request_mac,
one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing,
)
return (r, received_time)
async def tcp(q, where, timeout=None, port=53, source=None, source_port=0,
one_rr_per_rrset=False, ignore_trailing=False, sock=None,
backend=None):
async def tcp(
q: dns.message.Message,
where: str,
timeout: Optional[float] = None,
port: int = 53,
source: Optional[str] = None,
source_port: int = 0,
one_rr_per_rrset: bool = False,
ignore_trailing: bool = False,
sock: Optional[dns.asyncbackend.StreamSocket] = None,
backend: Optional[dns.asyncbackend.Backend] = None,
) -> dns.message.Message:
"""Return the response obtained after sending a query via TCP.
*sock*, a ``dns.asyncbacket.StreamSocket``, or ``None``, the
@ -276,41 +378,48 @@ async def tcp(q, where, timeout=None, port=53, source=None, source_port=0,
wire = q.to_wire()
(begin_time, expiration) = _compute_times(timeout)
s = None
# After 3.6 is no longer supported, this can use an AsyncExitStack.
try:
if sock:
# Verify that the socket is connected, as if it's not connected,
# it's not writable, and the polling in send_tcp() will time out or
# hang forever.
await sock.getpeername()
s = sock
else:
# These are simple (address, port) pairs, not
# family-dependent tuples you pass to lowlevel socket
# code.
af = dns.inet.af_for_address(where)
stuple = _source_tuple(af, source, source_port)
dtuple = (where, port)
if not backend:
backend = dns.asyncbackend.get_default_backend()
s = await backend.make_socket(af, socket.SOCK_STREAM, 0, stuple,
dtuple, timeout)
if sock:
# Verify that the socket is connected, as if it's not connected,
# it's not writable, and the polling in send_tcp() will time out or
# hang forever.
await sock.getpeername()
cm: contextlib.AbstractAsyncContextManager = NullContext(sock)
else:
# These are simple (address, port) pairs, not family-dependent tuples
# you pass to low-level socket code.
af = dns.inet.af_for_address(where)
stuple = _source_tuple(af, source, source_port)
dtuple = (where, port)
if not backend:
backend = dns.asyncbackend.get_default_backend()
cm = await backend.make_socket(
af, socket.SOCK_STREAM, 0, stuple, dtuple, timeout
)
async with cm as s:
await send_tcp(s, wire, expiration)
(r, received_time) = await receive_tcp(s, expiration, one_rr_per_rrset,
q.keyring, q.mac,
ignore_trailing)
(r, received_time) = await receive_tcp(
s, expiration, one_rr_per_rrset, q.keyring, q.mac, ignore_trailing
)
r.time = received_time - begin_time
if not q.is_response(r):
raise BadResponse
return r
finally:
if not sock and s:
await s.close()
async def tls(q, where, timeout=None, port=853, source=None, source_port=0,
one_rr_per_rrset=False, ignore_trailing=False, sock=None,
backend=None, ssl_context=None, server_hostname=None):
async def tls(
q: dns.message.Message,
where: str,
timeout: Optional[float] = None,
port: int = 853,
source: Optional[str] = None,
source_port: int = 0,
one_rr_per_rrset: bool = False,
ignore_trailing: bool = False,
sock: Optional[dns.asyncbackend.StreamSocket] = None,
backend: Optional[dns.asyncbackend.Backend] = None,
ssl_context: Optional[ssl.SSLContext] = None,
server_hostname: Optional[str] = None,
) -> dns.message.Message:
"""Return the response obtained after sending a query via TLS.
*sock*, an ``asyncbackend.StreamSocket``, or ``None``, the socket
@ -326,11 +435,14 @@ async def tls(q, where, timeout=None, port=853, source=None, source_port=0,
See :py:func:`dns.query.tls()` for the documentation of the other
parameters, exceptions, and return type of this method.
"""
# After 3.6 is no longer supported, this can use an AsyncExitStack.
(begin_time, expiration) = _compute_times(timeout)
if not sock:
if sock:
cm: contextlib.AbstractAsyncContextManager = NullContext(sock)
else:
if ssl_context is None:
ssl_context = ssl.create_default_context()
# See the comment about ssl.create_default_context() in query.py
ssl_context = ssl.create_default_context() # lgtm[py/insecure-protocol]
ssl_context.minimum_version = ssl.TLSVersion.TLSv1_2
if server_hostname is None:
ssl_context.check_hostname = False
else:
@ -341,25 +453,49 @@ async def tls(q, where, timeout=None, port=853, source=None, source_port=0,
dtuple = (where, port)
if not backend:
backend = dns.asyncbackend.get_default_backend()
s = await backend.make_socket(af, socket.SOCK_STREAM, 0, stuple,
dtuple, timeout, ssl_context,
server_hostname)
else:
s = sock
try:
cm = await backend.make_socket(
af,
socket.SOCK_STREAM,
0,
stuple,
dtuple,
timeout,
ssl_context,
server_hostname,
)
async with cm as s:
timeout = _timeout(expiration)
response = await tcp(q, where, timeout, port, source, source_port,
one_rr_per_rrset, ignore_trailing, s, backend)
response = await tcp(
q,
where,
timeout,
port,
source,
source_port,
one_rr_per_rrset,
ignore_trailing,
s,
backend,
)
end_time = time.time()
response.time = end_time - begin_time
return response
finally:
if not sock and s:
await s.close()
async def https(q, where, timeout=None, port=443, source=None, source_port=0,
one_rr_per_rrset=False, ignore_trailing=False, client=None,
path='/dns-query', post=True, verify=True):
async def https(
q: dns.message.Message,
where: str,
timeout: Optional[float] = None,
port: int = 443,
source: Optional[str] = None,
source_port: int = 0, # pylint: disable=W0613
one_rr_per_rrset: bool = False,
ignore_trailing: bool = False,
client: Optional["httpx.AsyncClient"] = None,
path: str = "/dns-query",
post: bool = True,
verify: Union[bool, str] = True,
) -> dns.message.Message:
"""Return the response obtained after sending a query via DNS-over-HTTPS.
*client*, a ``httpx.AsyncClient``. If provided, the client to use for
@ -373,7 +509,7 @@ async def https(q, where, timeout=None, port=443, source=None, source_port=0,
"""
if not _have_httpx:
raise NoDOH('httpx is not available.') # pragma: no cover
raise NoDOH("httpx is not available.") # pragma: no cover
wire = q.to_wire()
try:
@ -381,65 +517,78 @@ async def https(q, where, timeout=None, port=443, source=None, source_port=0,
except ValueError:
af = None
transport = None
headers = {
"accept": "application/dns-message"
}
headers = {"accept": "application/dns-message"}
if af is not None:
if af == socket.AF_INET:
url = 'https://{}:{}{}'.format(where, port, path)
url = "https://{}:{}{}".format(where, port, path)
elif af == socket.AF_INET6:
url = 'https://[{}]:{}{}'.format(where, port, path)
url = "https://[{}]:{}{}".format(where, port, path)
else:
url = where
if source is not None:
transport = httpx.AsyncHTTPTransport(local_address=source[0])
# After 3.6 is no longer supported, this can use an AsyncExitStack
client_to_close = None
try:
if not client:
client = httpx.AsyncClient(http1=True, http2=_have_http2,
verify=verify, transport=transport)
client_to_close = client
if client:
cm: contextlib.AbstractAsyncContextManager = NullContext(client)
else:
cm = httpx.AsyncClient(
http1=True, http2=_have_http2, verify=verify, transport=transport
)
async with cm as the_client:
# see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH
# GET and POST examples
if post:
headers.update({
"content-type": "application/dns-message",
"content-length": str(len(wire))
})
response = await client.post(url, headers=headers, content=wire,
timeout=timeout)
headers.update(
{
"content-type": "application/dns-message",
"content-length": str(len(wire)),
}
)
response = await the_client.post(
url, headers=headers, content=wire, timeout=timeout
)
else:
wire = base64.urlsafe_b64encode(wire).rstrip(b"=")
wire = wire.decode() # httpx does a repr() if we give it bytes
response = await client.get(url, headers=headers, timeout=timeout,
params={"dns": wire})
finally:
if client_to_close:
await client.aclose()
twire = wire.decode() # httpx does a repr() if we give it bytes
response = await the_client.get(
url, headers=headers, timeout=timeout, params={"dns": twire}
)
# see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH
# status codes
if response.status_code < 200 or response.status_code > 299:
raise ValueError('{} responded with status code {}'
'\nResponse body: {}'.format(where,
response.status_code,
response.content))
r = dns.message.from_wire(response.content,
keyring=q.keyring,
request_mac=q.request_mac,
one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing)
r.time = response.elapsed
raise ValueError(
"{} responded with status code {}"
"\nResponse body: {!r}".format(
where, response.status_code, response.content
)
)
r = dns.message.from_wire(
response.content,
keyring=q.keyring,
request_mac=q.request_mac,
one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing,
)
r.time = response.elapsed.total_seconds()
if not q.is_response(r):
raise BadResponse
return r
async def inbound_xfr(where, txn_manager, query=None,
port=53, timeout=None, lifetime=None, source=None,
source_port=0, udp_mode=UDPMode.NEVER, backend=None):
async def inbound_xfr(
where: str,
txn_manager: dns.transaction.TransactionManager,
query: Optional[dns.message.Message] = None,
port: int = 53,
timeout: Optional[float] = None,
lifetime: Optional[float] = None,
source: Optional[str] = None,
source_port: int = 0,
udp_mode: UDPMode = UDPMode.NEVER,
backend: Optional[dns.asyncbackend.Backend] = None,
) -> None:
"""Conduct an inbound transfer and apply it via a transaction from the
txn_manager.
@ -472,42 +621,48 @@ async def inbound_xfr(where, txn_manager, query=None,
is_udp = False
if not backend:
backend = dns.asyncbackend.get_default_backend()
s = await backend.make_socket(af, sock_type, 0, stuple, dtuple,
_timeout(expiration))
s = await backend.make_socket(
af, sock_type, 0, stuple, dtuple, _timeout(expiration)
)
async with s:
if is_udp:
await s.sendto(wire, dtuple, _timeout(expiration))
else:
tcpmsg = struct.pack("!H", len(wire)) + wire
await s.sendall(tcpmsg, expiration)
with dns.xfr.Inbound(txn_manager, rdtype, serial,
is_udp) as inbound:
with dns.xfr.Inbound(txn_manager, rdtype, serial, is_udp) as inbound:
done = False
tsig_ctx = None
while not done:
(_, mexpiration) = _compute_times(timeout)
if mexpiration is None or \
(expiration is not None and mexpiration > expiration):
if mexpiration is None or (
expiration is not None and mexpiration > expiration
):
mexpiration = expiration
if is_udp:
destination = _lltuple((where, port), af)
while True:
timeout = _timeout(mexpiration)
(rwire, from_address) = await s.recvfrom(65535,
timeout)
if _matches_destination(af, from_address,
destination, True):
(rwire, from_address) = await s.recvfrom(65535, timeout)
if _matches_destination(
af, from_address, destination, True
):
break
else:
ldata = await _read_exactly(s, 2, mexpiration)
(l,) = struct.unpack("!H", ldata)
rwire = await _read_exactly(s, l, mexpiration)
is_ixfr = (rdtype == dns.rdatatype.IXFR)
r = dns.message.from_wire(rwire, keyring=query.keyring,
request_mac=query.mac, xfr=True,
origin=origin, tsig_ctx=tsig_ctx,
multi=(not is_udp),
one_rr_per_rrset=is_ixfr)
is_ixfr = rdtype == dns.rdatatype.IXFR
r = dns.message.from_wire(
rwire,
keyring=query.keyring,
request_mac=query.mac,
xfr=True,
origin=origin,
tsig_ctx=tsig_ctx,
multi=(not is_udp),
one_rr_per_rrset=is_ixfr,
)
try:
done = inbound.process_message(r)
except dns.xfr.UseTCP:
@ -521,3 +676,62 @@ async def inbound_xfr(where, txn_manager, query=None,
tsig_ctx = r.tsig_ctx
if not retry and query.keyring and not r.had_tsig:
raise dns.exception.FormError("missing TSIG")
async def quic(
q: dns.message.Message,
where: str,
timeout: Optional[float] = None,
port: int = 853,
source: Optional[str] = None,
source_port: int = 0,
one_rr_per_rrset: bool = False,
ignore_trailing: bool = False,
connection: Optional[dns.quic.AsyncQuicConnection] = None,
verify: Union[bool, str] = True,
backend: Optional[dns.asyncbackend.Backend] = None,
) -> dns.message.Message:
"""Return the response obtained after sending an asynchronous query via
DNS-over-QUIC.
*backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``,
the default, then dnspython will use the default backend.
See :py:func:`dns.query.quic()` for the documentation of the other
parameters, exceptions, and return type of this method.
"""
if not dns.quic.have_quic:
raise NoDOQ("DNS-over-QUIC is not available.") # pragma: no cover
q.id = 0
wire = q.to_wire()
the_connection: dns.quic.AsyncQuicConnection
if connection:
cfactory = dns.quic.null_factory
mfactory = dns.quic.null_factory
the_connection = connection
else:
(cfactory, mfactory) = dns.quic.factories_for_backend(backend)
async with cfactory() as context:
async with mfactory(context, verify_mode=verify) as the_manager:
if not connection:
the_connection = the_manager.connect(where, port, source, source_port)
start = time.time()
stream = await the_connection.make_stream()
async with stream:
await stream.send(wire, True)
wire = await stream.receive(timeout)
finish = time.time()
r = dns.message.from_wire(
wire,
keyring=q.keyring,
request_mac=q.request_mac,
one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing,
)
r.time = max(finish - start, 0.0)
if not q.is_response(r):
raise BadResponse
return r

View file

@ -1,43 +0,0 @@
from typing import Optional, Union, Dict, Generator, Any
from . import tsig, rdatatype, rdataclass, name, message, asyncbackend
# If the ssl import works, then
#
# error: Name 'ssl' already defined (by an import)
#
# is expected and can be ignored.
try:
import ssl
except ImportError:
class ssl: # type: ignore
SSLContext : Dict = {}
async def udp(q : message.Message, where : str,
timeout : Optional[float] = None, port=53,
source : Optional[str] = None, source_port : Optional[int] = 0,
ignore_unexpected : Optional[bool] = False,
one_rr_per_rrset : Optional[bool] = False,
ignore_trailing : Optional[bool] = False,
sock : Optional[asyncbackend.DatagramSocket] = None,
backend : Optional[asyncbackend.Backend] = None) -> message.Message:
pass
async def tcp(q : message.Message, where : str, timeout : float = None, port=53,
af : Optional[int] = None, source : Optional[str] = None,
source_port : Optional[int] = 0,
one_rr_per_rrset : Optional[bool] = False,
ignore_trailing : Optional[bool] = False,
sock : Optional[asyncbackend.StreamSocket] = None,
backend : Optional[asyncbackend.Backend] = None) -> message.Message:
pass
async def tls(q : message.Message, where : str,
timeout : Optional[float] = None, port=53,
source : Optional[str] = None, source_port : Optional[int] = 0,
one_rr_per_rrset : Optional[bool] = False,
ignore_trailing : Optional[bool] = False,
sock : Optional[asyncbackend.StreamSocket] = None,
backend : Optional[asyncbackend.Backend] = None,
ssl_context: Optional[ssl.SSLContext] = None,
server_hostname: Optional[str] = None) -> message.Message:
pass

View file

@ -17,13 +17,18 @@
"""Asynchronous DNS stub resolver."""
from typing import Any, Dict, Optional, Union
import time
import dns.asyncbackend
import dns.asyncquery
import dns.exception
import dns.name
import dns.query
import dns.resolver
import dns.rdataclass
import dns.rdatatype
import dns.resolver # lgtm[py/import-and-import-from]
# import some resolver symbols for brevity
from dns.resolver import NXDOMAIN, NoAnswer, NotAbsolute, NoRootSOA
@ -37,11 +42,19 @@ _tcp = dns.asyncquery.tcp
class Resolver(dns.resolver.BaseResolver):
"""Asynchronous DNS stub resolver."""
async def resolve(self, qname, rdtype=dns.rdatatype.A,
rdclass=dns.rdataclass.IN,
tcp=False, source=None, raise_on_no_answer=True,
source_port=0, lifetime=None, search=None,
backend=None):
async def resolve(
self,
qname: Union[dns.name.Name, str],
rdtype: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.A,
rdclass: Union[dns.rdataclass.RdataClass, str] = dns.rdataclass.IN,
tcp: bool = False,
source: Optional[str] = None,
raise_on_no_answer: bool = True,
source_port: int = 0,
lifetime: Optional[float] = None,
search: Optional[bool] = None,
backend: Optional[dns.asyncbackend.Backend] = None,
) -> dns.resolver.Answer:
"""Query nameservers asynchronously to find the answer to the question.
*backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``,
@ -52,8 +65,9 @@ class Resolver(dns.resolver.BaseResolver):
type of this method.
"""
resolution = dns.resolver._Resolution(self, qname, rdtype, rdclass, tcp,
raise_on_no_answer, search)
resolution = dns.resolver._Resolution(
self, qname, rdtype, rdclass, tcp, raise_on_no_answer, search
)
if not backend:
backend = dns.asyncbackend.get_default_backend()
start = time.time()
@ -66,30 +80,40 @@ class Resolver(dns.resolver.BaseResolver):
if answer is not None:
# cache hit!
return answer
assert request is not None # needed for type checking
done = False
while not done:
(nameserver, port, tcp, backoff) = resolution.next_nameserver()
if backoff:
await backend.sleep(backoff)
timeout = self._compute_timeout(start, lifetime,
resolution.errors)
timeout = self._compute_timeout(start, lifetime, resolution.errors)
try:
if dns.inet.is_address(nameserver):
if tcp:
response = await _tcp(request, nameserver,
timeout, port,
source, source_port,
backend=backend)
response = await _tcp(
request,
nameserver,
timeout,
port,
source,
source_port,
backend=backend,
)
else:
response = await _udp(request, nameserver,
timeout, port,
source, source_port,
raise_on_truncation=True,
backend=backend)
response = await _udp(
request,
nameserver,
timeout,
port,
source,
source_port,
raise_on_truncation=True,
backend=backend,
)
else:
response = await dns.asyncquery.https(request,
nameserver,
timeout=timeout)
response = await dns.asyncquery.https(
request, nameserver, timeout=timeout
)
except Exception as ex:
(_, done) = resolution.query_result(None, ex)
continue
@ -101,7 +125,9 @@ class Resolver(dns.resolver.BaseResolver):
if answer is not None:
return answer
async def resolve_address(self, ipaddr, *args, **kwargs):
async def resolve_address(
self, ipaddr: str, *args: Any, **kwargs: Any
) -> dns.resolver.Answer:
"""Use an asynchronous resolver to run a reverse query for PTR
records.
@ -116,15 +142,20 @@ class Resolver(dns.resolver.BaseResolver):
function.
"""
return await self.resolve(dns.reversename.from_address(ipaddr),
rdtype=dns.rdatatype.PTR,
rdclass=dns.rdataclass.IN,
*args, **kwargs)
# We make a modified kwargs for type checking happiness, as otherwise
# we get a legit warning about possibly having rdtype and rdclass
# in the kwargs more than once.
modified_kwargs: Dict[str, Any] = {}
modified_kwargs.update(kwargs)
modified_kwargs["rdtype"] = dns.rdatatype.PTR
modified_kwargs["rdclass"] = dns.rdataclass.IN
return await self.resolve(
dns.reversename.from_address(ipaddr), *args, **modified_kwargs
)
# pylint: disable=redefined-outer-name
async def canonical_name(self, name):
async def canonical_name(self, name: Union[dns.name.Name, str]) -> dns.name.Name:
"""Determine the canonical name of *name*.
The canonical name is the name the resolver uses for queries
@ -149,14 +180,15 @@ class Resolver(dns.resolver.BaseResolver):
default_resolver = None
def get_default_resolver():
def get_default_resolver() -> Resolver:
"""Get the default asynchronous resolver, initializing it if necessary."""
if default_resolver is None:
reset_default_resolver()
assert default_resolver is not None
return default_resolver
def reset_default_resolver():
def reset_default_resolver() -> None:
"""Re-initialize default asynchronous resolver.
Note that the resolver configuration (i.e. /etc/resolv.conf on UNIX
@ -167,9 +199,18 @@ def reset_default_resolver():
default_resolver = Resolver()
async def resolve(qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN,
tcp=False, source=None, raise_on_no_answer=True,
source_port=0, lifetime=None, search=None, backend=None):
async def resolve(
qname: Union[dns.name.Name, str],
rdtype: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.A,
rdclass: Union[dns.rdataclass.RdataClass, str] = dns.rdataclass.IN,
tcp: bool = False,
source: Optional[str] = None,
raise_on_no_answer: bool = True,
source_port: int = 0,
lifetime: Optional[float] = None,
search: Optional[bool] = None,
backend: Optional[dns.asyncbackend.Backend] = None,
) -> dns.resolver.Answer:
"""Query nameservers asynchronously to find the answer to the question.
This is a convenience function that uses the default resolver
@ -179,13 +220,23 @@ async def resolve(qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN,
information on the parameters.
"""
return await get_default_resolver().resolve(qname, rdtype, rdclass, tcp,
source, raise_on_no_answer,
source_port, lifetime, search,
backend)
return await get_default_resolver().resolve(
qname,
rdtype,
rdclass,
tcp,
source,
raise_on_no_answer,
source_port,
lifetime,
search,
backend,
)
async def resolve_address(ipaddr, *args, **kwargs):
async def resolve_address(
ipaddr: str, *args: Any, **kwargs: Any
) -> dns.resolver.Answer:
"""Use a resolver to run a reverse query for PTR records.
See :py:func:`dns.asyncresolver.Resolver.resolve_address` for more
@ -194,7 +245,8 @@ async def resolve_address(ipaddr, *args, **kwargs):
return await get_default_resolver().resolve_address(ipaddr, *args, **kwargs)
async def canonical_name(name):
async def canonical_name(name: Union[dns.name.Name, str]) -> dns.name.Name:
"""Determine the canonical name of *name*.
See :py:func:`dns.resolver.Resolver.canonical_name` for more
@ -203,8 +255,14 @@ async def canonical_name(name):
return await get_default_resolver().canonical_name(name)
async def zone_for_name(name, rdclass=dns.rdataclass.IN, tcp=False,
resolver=None, backend=None):
async def zone_for_name(
name: Union[dns.name.Name, str],
rdclass: dns.rdataclass.RdataClass = dns.rdataclass.IN,
tcp: bool = False,
resolver: Optional[Resolver] = None,
backend: Optional[dns.asyncbackend.Backend] = None,
) -> dns.name.Name:
"""Find the name of the zone which contains the specified name.
See :py:func:`dns.resolver.Resolver.zone_for_name` for more
@ -219,8 +277,10 @@ async def zone_for_name(name, rdclass=dns.rdataclass.IN, tcp=False,
raise NotAbsolute(name)
while True:
try:
answer = await resolver.resolve(name, dns.rdatatype.SOA, rdclass,
tcp, backend=backend)
answer = await resolver.resolve(
name, dns.rdatatype.SOA, rdclass, tcp, backend=backend
)
assert answer.rrset is not None
if answer.rrset.name == name:
return name
# otherwise we were CNAMEd or DNAMEd and need to look higher

View file

@ -1,26 +0,0 @@
from typing import Union, Optional, List, Any, Dict
from . import exception, rdataclass, name, rdatatype, asyncbackend
async def resolve(qname : str, rdtype : Union[int,str] = 0,
rdclass : Union[int,str] = 0,
tcp=False, source=None, raise_on_no_answer=True,
source_port=0, lifetime : Optional[float]=None,
search : Optional[bool]=None,
backend : Optional[asyncbackend.Backend]=None):
...
async def resolve_address(self, ipaddr: str,
*args: Any, **kwargs: Optional[Dict]):
...
class Resolver:
def __init__(self, filename : Optional[str] = '/etc/resolv.conf',
configure : Optional[bool] = True):
self.nameservers : List[str]
async def resolve(self, qname : str, rdtype : Union[int,str] = rdatatype.A,
rdclass : Union[int,str] = rdataclass.IN,
tcp : bool = False, source : Optional[str] = None,
raise_on_no_answer=True, source_port : int = 0,
lifetime : Optional[float]=None,
search : Optional[bool]=None,
backend : Optional[asyncbackend.Backend]=None):
...

File diff suppressed because it is too large Load diff

View file

@ -1,21 +0,0 @@
from typing import Union, Dict, Tuple, Optional
from . import rdataset, rrset, exception, name, rdtypes, rdata, node
import dns.rdtypes.ANY.DS as DS
import dns.rdtypes.ANY.DNSKEY as DNSKEY
_have_pyca : bool
def validate_rrsig(rrset : Union[Tuple[name.Name, rdataset.Rdataset], rrset.RRset], rrsig : rdata.Rdata, keys : Dict[name.Name, Union[node.Node, rdataset.Rdataset]], origin : Optional[name.Name] = None, now : Optional[int] = None) -> None:
...
def validate(rrset: Union[Tuple[name.Name, rdataset.Rdataset], rrset.RRset], rrsigset : Union[Tuple[name.Name, rdataset.Rdataset], rrset.RRset], keys : Dict[name.Name, Union[node.Node, rdataset.Rdataset]], origin=None, now=None) -> None:
...
class ValidationFailure(exception.DNSException):
...
def make_ds(name : name.Name, key : DNSKEY.DNSKEY, algorithm : str, origin : Optional[name.Name] = None) -> DS.DS:
...
def nsec3_hash(domain: str, salt: Optional[Union[str, bytes]], iterations: int, algo: int) -> str:
...

71
lib/dns/dnssectypes.py Normal file
View file

@ -0,0 +1,71 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# Copyright (C) 2003-2017 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
"""Common DNSSEC-related types."""
# This is a separate file to avoid import circularity between dns.dnssec and
# the implementations of the DS and DNSKEY types.
import dns.enum
class Algorithm(dns.enum.IntEnum):
RSAMD5 = 1
DH = 2
DSA = 3
ECC = 4
RSASHA1 = 5
DSANSEC3SHA1 = 6
RSASHA1NSEC3SHA1 = 7
RSASHA256 = 8
RSASHA512 = 10
ECCGOST = 12
ECDSAP256SHA256 = 13
ECDSAP384SHA384 = 14
ED25519 = 15
ED448 = 16
INDIRECT = 252
PRIVATEDNS = 253
PRIVATEOID = 254
@classmethod
def _maximum(cls):
return 255
class DSDigest(dns.enum.IntEnum):
"""DNSSEC Delegation Signer Digest Algorithm"""
NULL = 0
SHA1 = 1
SHA256 = 2
GOST = 3
SHA384 = 4
@classmethod
def _maximum(cls):
return 255
class NSEC3Hash(dns.enum.IntEnum):
"""NSEC3 hash algorithm"""
SHA1 = 1
@classmethod
def _maximum(cls):
return 255

View file

@ -17,15 +17,19 @@
"""DNS E.164 helpers."""
from typing import Iterable, Optional, Union
import dns.exception
import dns.name
import dns.resolver
#: The public E.164 domain.
public_enum_domain = dns.name.from_text('e164.arpa.')
public_enum_domain = dns.name.from_text("e164.arpa.")
def from_e164(text, origin=public_enum_domain):
def from_e164(
text: str, origin: Optional[dns.name.Name] = public_enum_domain
) -> dns.name.Name:
"""Convert an E.164 number in textual form into a Name object whose
value is the ENUM domain name for that number.
@ -42,10 +46,14 @@ def from_e164(text, origin=public_enum_domain):
parts = [d for d in text if d.isdigit()]
parts.reverse()
return dns.name.from_text('.'.join(parts), origin=origin)
return dns.name.from_text(".".join(parts), origin=origin)
def to_e164(name, origin=public_enum_domain, want_plus_prefix=True):
def to_e164(
name: dns.name.Name,
origin: Optional[dns.name.Name] = public_enum_domain,
want_plus_prefix: bool = True,
) -> str:
"""Convert an ENUM domain name into an E.164 number.
Note that dnspython does not have any information about preferred
@ -69,15 +77,19 @@ def to_e164(name, origin=public_enum_domain, want_plus_prefix=True):
name = name.relativize(origin)
dlabels = [d for d in name.labels if d.isdigit() and len(d) == 1]
if len(dlabels) != len(name.labels):
raise dns.exception.SyntaxError('non-digit labels in ENUM domain name')
raise dns.exception.SyntaxError("non-digit labels in ENUM domain name")
dlabels.reverse()
text = b''.join(dlabels)
text = b"".join(dlabels)
if want_plus_prefix:
text = b'+' + text
text = b"+" + text
return text.decode()
def query(number, domains, resolver=None):
def query(
number: str,
domains: Iterable[Union[dns.name.Name, str]],
resolver: Optional[dns.resolver.Resolver] = None,
) -> dns.resolver.Answer:
"""Look for NAPTR RRs for the specified number in the specified domains.
e.g. lookup('16505551212', ['e164.dnspython.org.', 'e164.arpa.'])
@ -98,7 +110,7 @@ def query(number, domains, resolver=None):
domain = dns.name.from_text(domain)
qname = dns.e164.from_e164(number, domain)
try:
return resolver.resolve(qname, 'NAPTR')
return resolver.resolve(qname, "NAPTR")
except dns.resolver.NXDOMAIN as e:
e_nx += e
raise e_nx

View file

@ -1,10 +0,0 @@
from typing import Optional, Iterable
from . import name, resolver
def from_e164(text : str, origin=name.Name(".")) -> name.Name:
...
def to_e164(name : name.Name, origin : Optional[name.Name] = None, want_plus_prefix=True) -> str:
...
def query(number : str, domains : Iterable[str], resolver : Optional[resolver.Resolver] = None) -> resolver.Answer:
...

View file

@ -17,6 +17,8 @@
"""EDNS Options"""
from typing import Any, Dict, Optional, Union
import math
import socket
import struct
@ -24,6 +26,7 @@ import struct
import dns.enum
import dns.inet
import dns.rdata
import dns.wire
class OptionType(dns.enum.IntEnum):
@ -59,14 +62,14 @@ class Option:
"""Base class for all EDNS option types."""
def __init__(self, otype):
def __init__(self, otype: Union[OptionType, str]):
"""Initialize an option.
*otype*, an ``int``, is the option type.
*otype*, a ``dns.edns.OptionType``, is the option type.
"""
self.otype = OptionType.make(otype)
def to_wire(self, file=None):
def to_wire(self, file: Optional[Any] = None) -> Optional[bytes]:
"""Convert an option to wire format.
Returns a ``bytes`` or ``None``.
@ -75,10 +78,10 @@ class Option:
raise NotImplementedError # pragma: no cover
@classmethod
def from_wire_parser(cls, otype, parser):
def from_wire_parser(cls, otype: OptionType, parser: "dns.wire.Parser") -> "Option":
"""Build an EDNS option object from wire format.
*otype*, an ``int``, is the option type.
*otype*, a ``dns.edns.OptionType``, is the option type.
*parser*, a ``dns.wire.Parser``, the parser, which should be
restructed to the option length.
@ -115,26 +118,22 @@ class Option:
return self._cmp(other) != 0
def __lt__(self, other):
if not isinstance(other, Option) or \
self.otype != other.otype:
if not isinstance(other, Option) or self.otype != other.otype:
return NotImplemented
return self._cmp(other) < 0
def __le__(self, other):
if not isinstance(other, Option) or \
self.otype != other.otype:
if not isinstance(other, Option) or self.otype != other.otype:
return NotImplemented
return self._cmp(other) <= 0
def __ge__(self, other):
if not isinstance(other, Option) or \
self.otype != other.otype:
if not isinstance(other, Option) or self.otype != other.otype:
return NotImplemented
return self._cmp(other) >= 0
def __gt__(self, other):
if not isinstance(other, Option) or \
self.otype != other.otype:
if not isinstance(other, Option) or self.otype != other.otype:
return NotImplemented
return self._cmp(other) > 0
@ -142,7 +141,7 @@ class Option:
return self.to_text()
class GenericOption(Option):
class GenericOption(Option): # lgtm[py/missing-equals]
"""Generic Option Class
@ -150,28 +149,31 @@ class GenericOption(Option):
implementation.
"""
def __init__(self, otype, data):
def __init__(self, otype: Union[OptionType, str], data: Union[bytes, str]):
super().__init__(otype)
self.data = dns.rdata.Rdata._as_bytes(data, True)
def to_wire(self, file=None):
def to_wire(self, file: Optional[Any] = None) -> Optional[bytes]:
if file:
file.write(self.data)
return None
else:
return self.data
def to_text(self):
def to_text(self) -> str:
return "Generic %d" % self.otype
@classmethod
def from_wire_parser(cls, otype, parser):
def from_wire_parser(
cls, otype: Union[OptionType, str], parser: "dns.wire.Parser"
) -> Option:
return cls(otype, parser.get_remaining())
class ECSOption(Option):
class ECSOption(Option): # lgtm[py/missing-equals]
"""EDNS Client Subnet (ECS, RFC7871)"""
def __init__(self, address, srclen=None, scopelen=0):
def __init__(self, address: str, srclen: Optional[int] = None, scopelen: int = 0):
"""*address*, a ``str``, is the client address information.
*srclen*, an ``int``, the source prefix length, which is the
@ -200,8 +202,9 @@ class ECSOption(Option):
srclen = dns.rdata.Rdata._as_int(srclen, 0, 32)
scopelen = dns.rdata.Rdata._as_int(scopelen, 0, 32)
else: # pragma: no cover (this will never happen)
raise ValueError('Bad address family')
raise ValueError("Bad address family")
assert srclen is not None
self.address = address
self.srclen = srclen
self.scopelen = scopelen
@ -214,16 +217,14 @@ class ECSOption(Option):
self.addrdata = addrdata[:nbytes]
nbits = srclen % 8
if nbits != 0:
last = struct.pack('B',
ord(self.addrdata[-1:]) & (0xff << (8 - nbits)))
last = struct.pack("B", ord(self.addrdata[-1:]) & (0xFF << (8 - nbits)))
self.addrdata = self.addrdata[:-1] + last
def to_text(self):
return "ECS {}/{} scope/{}".format(self.address, self.srclen,
self.scopelen)
def to_text(self) -> str:
return "ECS {}/{} scope/{}".format(self.address, self.srclen, self.scopelen)
@staticmethod
def from_text(text):
def from_text(text: str) -> Option:
"""Convert a string into a `dns.edns.ECSOption`
*text*, a `str`, the text form of the option.
@ -246,7 +247,7 @@ class ECSOption(Option):
>>> # it understands results from `dns.edns.ECSOption.to_text()`
>>> dns.edns.ECSOption.from_text('ECS 1.2.3.4/24/32')
"""
optional_prefix = 'ECS'
optional_prefix = "ECS"
tokens = text.split()
ecs_text = None
if len(tokens) == 1:
@ -257,47 +258,53 @@ class ECSOption(Option):
ecs_text = tokens[1]
else:
raise ValueError('could not parse ECS from "{}"'.format(text))
n_slashes = ecs_text.count('/')
n_slashes = ecs_text.count("/")
if n_slashes == 1:
address, srclen = ecs_text.split('/')
scope = 0
address, tsrclen = ecs_text.split("/")
tscope = "0"
elif n_slashes == 2:
address, srclen, scope = ecs_text.split('/')
address, tsrclen, tscope = ecs_text.split("/")
else:
raise ValueError('could not parse ECS from "{}"'.format(text))
try:
scope = int(scope)
scope = int(tscope)
except ValueError:
raise ValueError('invalid scope ' +
'"{}": scope must be an integer'.format(scope))
raise ValueError(
"invalid scope " + '"{}": scope must be an integer'.format(tscope)
)
try:
srclen = int(srclen)
srclen = int(tsrclen)
except ValueError:
raise ValueError('invalid srclen ' +
'"{}": srclen must be an integer'.format(srclen))
raise ValueError(
"invalid srclen " + '"{}": srclen must be an integer'.format(tsrclen)
)
return ECSOption(address, srclen, scope)
def to_wire(self, file=None):
value = (struct.pack('!HBB', self.family, self.srclen, self.scopelen) +
self.addrdata)
def to_wire(self, file: Optional[Any] = None) -> Optional[bytes]:
value = (
struct.pack("!HBB", self.family, self.srclen, self.scopelen) + self.addrdata
)
if file:
file.write(value)
return None
else:
return value
@classmethod
def from_wire_parser(cls, otype, parser):
family, src, scope = parser.get_struct('!HBB')
def from_wire_parser(
cls, otype: Union[OptionType, str], parser: "dns.wire.Parser"
) -> Option:
family, src, scope = parser.get_struct("!HBB")
addrlen = int(math.ceil(src / 8.0))
prefix = parser.get_bytes(addrlen)
if family == 1:
pad = 4 - addrlen
addr = dns.ipv4.inet_ntoa(prefix + b'\x00' * pad)
addr = dns.ipv4.inet_ntoa(prefix + b"\x00" * pad)
elif family == 2:
pad = 16 - addrlen
addr = dns.ipv6.inet_ntoa(prefix + b'\x00' * pad)
addr = dns.ipv6.inet_ntoa(prefix + b"\x00" * pad)
else:
raise ValueError('unsupported family')
raise ValueError("unsupported family")
return cls(addr, src, scope)
@ -334,10 +341,10 @@ class EDECode(dns.enum.IntEnum):
return 65535
class EDEOption(Option):
class EDEOption(Option): # lgtm[py/missing-equals]
"""Extended DNS Error (EDE, RFC8914)"""
def __init__(self, code, text=None):
def __init__(self, code: Union[EDECode, str], text: Optional[str] = None):
"""*code*, a ``dns.edns.EDECode`` or ``str``, the info code of the
extended error.
@ -349,49 +356,50 @@ class EDEOption(Option):
self.code = EDECode.make(code)
if text is not None and not isinstance(text, str):
raise ValueError('text must be string or None')
self.code = code
raise ValueError("text must be string or None")
self.text = text
def to_text(self):
output = f'EDE {self.code}'
def to_text(self) -> str:
output = f"EDE {self.code}"
if self.text is not None:
output += f': {self.text}'
output += f": {self.text}"
return output
def to_wire(self, file=None):
value = struct.pack('!H', self.code)
def to_wire(self, file: Optional[Any] = None) -> Optional[bytes]:
value = struct.pack("!H", self.code)
if self.text is not None:
value += self.text.encode('utf8')
value += self.text.encode("utf8")
if file:
file.write(value)
return None
else:
return value
@classmethod
def from_wire_parser(cls, otype, parser):
code = parser.get_uint16()
def from_wire_parser(
cls, otype: Union[OptionType, str], parser: "dns.wire.Parser"
) -> Option:
the_code = EDECode.make(parser.get_uint16())
text = parser.get_remaining()
if text:
if text[-1] == 0: # text MAY be null-terminated
text = text[:-1]
text = text.decode('utf8')
btext = text.decode("utf8")
else:
text = None
btext = None
return cls(code, text)
return cls(the_code, btext)
_type_to_class = {
_type_to_class: Dict[OptionType, Any] = {
OptionType.ECS: ECSOption,
OptionType.EDE: EDEOption,
}
def get_option_class(otype):
def get_option_class(otype: OptionType) -> Any:
"""Return the class for the specified option type.
The GenericOption class is used if a more specific class is not
@ -404,7 +412,9 @@ def get_option_class(otype):
return cls
def option_from_wire_parser(otype, parser):
def option_from_wire_parser(
otype: Union[OptionType, str], parser: "dns.wire.Parser"
) -> Option:
"""Build an EDNS option object from wire format.
*otype*, an ``int``, is the option type.
@ -414,12 +424,14 @@ def option_from_wire_parser(otype, parser):
Returns an instance of a subclass of ``dns.edns.Option``.
"""
cls = get_option_class(otype)
otype = OptionType.make(otype)
the_otype = OptionType.make(otype)
cls = get_option_class(the_otype)
return cls.from_wire_parser(otype, parser)
def option_from_wire(otype, wire, current, olen):
def option_from_wire(
otype: Union[OptionType, str], wire: bytes, current: int, olen: int
) -> Option:
"""Build an EDNS option object from wire format.
*otype*, an ``int``, is the option type.
@ -437,7 +449,8 @@ def option_from_wire(otype, wire, current, olen):
with parser.restrict_to(olen):
return option_from_wire_parser(otype, parser)
def register_type(implementation, otype):
def register_type(implementation: Any, otype: OptionType) -> None:
"""Register the implementation of an option type.
*implementation*, a ``class``, is a subclass of ``dns.edns.Option``.
@ -447,6 +460,7 @@ def register_type(implementation, otype):
_type_to_class[otype] = implementation
### BEGIN generated OptionType constants
NSID = OptionType.NSID

View file

@ -15,14 +15,13 @@
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
from typing import Any, Optional
import os
import hashlib
import random
import threading
import time
try:
import threading as _threading
except ImportError: # pragma: no cover
import dummy_threading as _threading # type: ignore
class EntropyPool:
@ -32,51 +31,51 @@ class EntropyPool:
# leaving this code doesn't hurt anything as the library code
# is used if present.
def __init__(self, seed=None):
def __init__(self, seed: Optional[bytes] = None):
self.pool_index = 0
self.digest = None
self.digest: Optional[bytearray] = None
self.next_byte = 0
self.lock = _threading.Lock()
self.lock = threading.Lock()
self.hash = hashlib.sha1()
self.hash_len = 20
self.pool = bytearray(b'\0' * self.hash_len)
self.pool = bytearray(b"\0" * self.hash_len)
if seed is not None:
self._stir(bytearray(seed))
self._stir(seed)
self.seeded = True
self.seed_pid = os.getpid()
else:
self.seeded = False
self.seed_pid = 0
def _stir(self, entropy):
def _stir(self, entropy: bytes) -> None:
for c in entropy:
if self.pool_index == self.hash_len:
self.pool_index = 0
b = c & 0xff
b = c & 0xFF
self.pool[self.pool_index] ^= b
self.pool_index += 1
def stir(self, entropy):
def stir(self, entropy: bytes) -> None:
with self.lock:
self._stir(entropy)
def _maybe_seed(self):
def _maybe_seed(self) -> None:
if not self.seeded or self.seed_pid != os.getpid():
try:
seed = os.urandom(16)
except Exception: # pragma: no cover
try:
with open('/dev/urandom', 'rb', 0) as r:
with open("/dev/urandom", "rb", 0) as r:
seed = r.read(16)
except Exception:
seed = str(time.time())
seed = str(time.time()).encode()
self.seeded = True
self.seed_pid = os.getpid()
self.digest = None
seed = bytearray(seed)
self._stir(seed)
def random_8(self):
def random_8(self) -> int:
with self.lock:
self._maybe_seed()
if self.digest is None or self.next_byte == self.hash_len:
@ -88,16 +87,16 @@ class EntropyPool:
self.next_byte += 1
return value
def random_16(self):
def random_16(self) -> int:
return self.random_8() * 256 + self.random_8()
def random_32(self):
def random_32(self) -> int:
return self.random_16() * 65536 + self.random_16()
def random_between(self, first, last):
def random_between(self, first: int, last: int) -> int:
size = last - first + 1
if size > 4294967296:
raise ValueError('too big')
raise ValueError("too big")
if size > 65536:
rand = self.random_32
max = 4294967295
@ -109,20 +108,24 @@ class EntropyPool:
max = 255
return first + size * rand() // (max + 1)
pool = EntropyPool()
system_random: Optional[Any]
try:
system_random = random.SystemRandom()
except Exception: # pragma: no cover
system_random = None
def random_16():
def random_16() -> int:
if system_random is not None:
return system_random.randrange(0, 65536)
else:
return pool.random_16()
def between(first, last):
def between(first: int, last: int) -> int:
if system_random is not None:
return system_random.randrange(first, last + 1)
else:

View file

@ -1,10 +0,0 @@
from typing import Optional
from random import SystemRandom
system_random : Optional[SystemRandom]
def random_16() -> int:
pass
def between(first: int, last: int) -> int:
pass

View file

@ -17,6 +17,7 @@
import enum
class IntEnum(enum.IntEnum):
@classmethod
def _check_value(cls, value):
@ -32,9 +33,12 @@ class IntEnum(enum.IntEnum):
return cls[text]
except KeyError:
pass
value = cls._extra_from_text(text)
if value:
return value
prefix = cls._prefix()
if text.startswith(prefix) and text[len(prefix):].isdigit():
value = int(text[len(prefix):])
if text.startswith(prefix) and text[len(prefix) :].isdigit():
value = int(text[len(prefix) :])
cls._check_value(value)
try:
return cls(value)
@ -46,9 +50,13 @@ class IntEnum(enum.IntEnum):
def to_text(cls, value):
cls._check_value(value)
try:
return cls(value).name
text = cls(value).name
except ValueError:
return f"{cls._prefix()}{value}"
text = None
text = cls._extra_to_text(value, text)
if text is None:
text = f"{cls._prefix()}{value}"
return text
@classmethod
def make(cls, value):
@ -83,7 +91,15 @@ class IntEnum(enum.IntEnum):
@classmethod
def _prefix(cls):
return ''
return ""
@classmethod
def _extra_from_text(cls, text): # pylint: disable=W0613
return None
@classmethod
def _extra_to_text(cls, value, current_text): # pylint: disable=W0613
return current_text
@classmethod
def _unknown_exception_class(cls):

View file

@ -21,6 +21,10 @@ Dnspython modules may also define their own exceptions, which will
always be subclasses of ``DNSException``.
"""
from typing import Optional, Set
class DNSException(Exception):
"""Abstract base class shared by all dnspython exceptions.
@ -44,14 +48,15 @@ class DNSException(Exception):
and ``fmt`` class variables to get nice parametrized messages.
"""
msg = None # non-parametrized message
supp_kwargs = set() # accepted parameters for _fmt_kwargs (sanity check)
fmt = None # message parametrized with results from _fmt_kwargs
msg: Optional[str] = None # non-parametrized message
supp_kwargs: Set[str] = set() # accepted parameters for _fmt_kwargs (sanity check)
fmt: Optional[str] = None # message parametrized with results from _fmt_kwargs
def __init__(self, *args, **kwargs):
self._check_params(*args, **kwargs)
if kwargs:
self.kwargs = self._check_kwargs(**kwargs)
# This call to a virtual method from __init__ is ok in our usage
self.kwargs = self._check_kwargs(**kwargs) # lgtm[py/init-calls-subclass]
self.msg = str(self)
else:
self.kwargs = dict() # defined but empty for old mode exceptions
@ -68,14 +73,15 @@ class DNSException(Exception):
For sanity we do not allow to mix old and new behavior."""
if args or kwargs:
assert bool(args) != bool(kwargs), \
'keyword arguments are mutually exclusive with positional args'
assert bool(args) != bool(
kwargs
), "keyword arguments are mutually exclusive with positional args"
def _check_kwargs(self, **kwargs):
if kwargs:
assert set(kwargs.keys()) == self.supp_kwargs, \
'following set of keyword args is required: %s' % (
self.supp_kwargs)
assert (
set(kwargs.keys()) == self.supp_kwargs
), "following set of keyword args is required: %s" % (self.supp_kwargs)
return kwargs
def _fmt_kwargs(self, **kwargs):
@ -124,9 +130,15 @@ class TooBig(DNSException):
class Timeout(DNSException):
"""The DNS operation timed out."""
supp_kwargs = {'timeout'}
supp_kwargs = {"timeout"}
fmt = "The DNS operation timed out after {timeout:.3f} seconds"
# We do this as otherwise mypy complains about unexpected keyword argument
# idna_exception
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
class ExceptionWrapper:
def __init__(self, exception_class):
@ -136,7 +148,6 @@ class ExceptionWrapper:
return self
def __exit__(self, exc_type, exc_val, exc_tb):
if exc_type is not None and not isinstance(exc_val,
self.exception_class):
if exc_type is not None and not isinstance(exc_val, self.exception_class):
raise self.exception_class(str(exc_val)) from exc_val
return False

View file

@ -1,12 +0,0 @@
from typing import Set, Optional, Dict
class DNSException(Exception):
supp_kwargs : Set[str]
kwargs : Optional[Dict]
fmt : Optional[str]
class SyntaxError(DNSException): ...
class FormError(DNSException): ...
class Timeout(DNSException): ...
class TooBig(DNSException): ...
class UnexpectedEnd(SyntaxError): ...

View file

@ -17,10 +17,13 @@
"""DNS Message Flags."""
from typing import Any
import enum
# Standard DNS flags
class Flag(enum.IntFlag):
#: Query Response
QR = 0x8000
@ -40,12 +43,13 @@ class Flag(enum.IntFlag):
# EDNS flags
class EDNSFlag(enum.IntFlag):
#: DNSSEC answer OK
DO = 0x8000
def _from_text(text, enum_class):
def _from_text(text: str, enum_class: Any) -> int:
flags = 0
tokens = text.split()
for t in tokens:
@ -53,15 +57,15 @@ def _from_text(text, enum_class):
return flags
def _to_text(flags, enum_class):
def _to_text(flags: int, enum_class: Any) -> str:
text_flags = []
for k, v in enum_class.__members__.items():
if flags & v != 0:
text_flags.append(k)
return ' '.join(text_flags)
return " ".join(text_flags)
def from_text(text):
def from_text(text: str) -> int:
"""Convert a space-separated list of flag text values into a flags
value.
@ -71,7 +75,7 @@ def from_text(text):
return _from_text(text, Flag)
def to_text(flags):
def to_text(flags: int) -> str:
"""Convert a flags value into a space-separated list of flag text
values.
@ -81,7 +85,7 @@ def to_text(flags):
return _to_text(flags, Flag)
def edns_from_text(text):
def edns_from_text(text: str) -> int:
"""Convert a space-separated list of EDNS flag text values into a EDNS
flags value.
@ -91,7 +95,7 @@ def edns_from_text(text):
return _from_text(text, EDNSFlag)
def edns_to_text(flags):
def edns_to_text(flags: int) -> str:
"""Convert an EDNS flags value into a space-separated list of EDNS flag
text values.
@ -100,6 +104,7 @@ def edns_to_text(flags):
return _to_text(flags, EDNSFlag)
### BEGIN generated Flag constants
QR = Flag.QR

View file

@ -17,9 +17,12 @@
"""DNS GENERATE range conversion."""
from typing import Tuple
import dns
def from_text(text):
def from_text(text: str) -> Tuple[int, int, int]:
"""Convert the text form of a range in a ``$GENERATE`` statement to an
integer.
@ -31,22 +34,22 @@ def from_text(text):
start = -1
stop = -1
step = 1
cur = ''
cur = ""
state = 0
# state 0 1 2
# x - y / z
if text and text[0] == '-':
if text and text[0] == "-":
raise dns.exception.SyntaxError("Start cannot be a negative number")
for c in text:
if c == '-' and state == 0:
if c == "-" and state == 0:
start = int(cur)
cur = ''
cur = ""
state = 1
elif c == '/':
elif c == "/":
stop = int(cur)
cur = ''
cur = ""
state = 2
elif c.isdigit():
cur += c
@ -64,6 +67,6 @@ def from_text(text):
assert step >= 1
assert start >= 0
if start > stop:
raise dns.exception.SyntaxError('start must be <= stop')
raise dns.exception.SyntaxError("start must be <= stop")
return (start, stop, step)

View file

@ -1,32 +1,25 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
import collections.abc
import sys
from typing import Any
# pylint: disable=unused-import
if sys.version_info >= (3, 7):
odict = dict
from dns._immutable_ctx import immutable
else:
# pragma: no cover
from collections import OrderedDict as odict
from dns._immutable_attr import immutable # noqa
# pylint: enable=unused-import
import collections.abc
from dns._immutable_ctx import immutable
@immutable
class Dict(collections.abc.Mapping):
def __init__(self, dictionary, no_copy=False):
class Dict(collections.abc.Mapping): # lgtm[py/missing-equals]
def __init__(self, dictionary: Any, no_copy: bool = False):
"""Make an immutable dictionary from the specified dictionary.
If *no_copy* is `True`, then *dictionary* will be wrapped instead
of copied. Only set this if you are sure there will be no external
references to the dictionary.
"""
if no_copy and isinstance(dictionary, odict):
if no_copy and isinstance(dictionary, dict):
self._odict = dictionary
else:
self._odict = odict(dictionary)
self._odict = dict(dictionary)
self._hash = None
def __getitem__(self, key):
@ -37,7 +30,7 @@ class Dict(collections.abc.Mapping):
h = 0
for key in sorted(self._odict.keys()):
h ^= hash(key)
object.__setattr__(self, '_hash', h)
object.__setattr__(self, "_hash", h)
# this does return an int, but pylint doesn't figure that out
return self._hash
@ -48,7 +41,7 @@ class Dict(collections.abc.Mapping):
return iter(self._odict)
def constify(o):
def constify(o: Any) -> Any:
"""
Convert mutable types to immutable types.
"""
@ -63,7 +56,7 @@ def constify(o):
if isinstance(o, list):
return tuple(constify(elt) for elt in o)
if isinstance(o, dict):
cdict = odict()
cdict = dict()
for k, v in o.items():
cdict[k] = constify(v)
return Dict(cdict, True)

View file

@ -17,6 +17,8 @@
"""Generic Internet address helper functions."""
from typing import Any, Optional, Tuple
import socket
import dns.ipv4
@ -30,7 +32,7 @@ AF_INET = socket.AF_INET
AF_INET6 = socket.AF_INET6
def inet_pton(family, text):
def inet_pton(family: int, text: str) -> bytes:
"""Convert the textual form of a network address into its binary form.
*family* is an ``int``, the address family.
@ -51,7 +53,7 @@ def inet_pton(family, text):
raise NotImplementedError
def inet_ntop(family, address):
def inet_ntop(family: int, address: bytes) -> str:
"""Convert the binary form of a network address into its textual form.
*family* is an ``int``, the address family.
@ -72,7 +74,7 @@ def inet_ntop(family, address):
raise NotImplementedError
def af_for_address(text):
def af_for_address(text: str) -> int:
"""Determine the address family of a textual-form network address.
*text*, a ``str``, the textual address.
@ -94,7 +96,7 @@ def af_for_address(text):
raise ValueError
def is_multicast(text):
def is_multicast(text: str) -> bool:
"""Is the textual-form network address a multicast address?
*text*, a ``str``, the textual address.
@ -116,7 +118,7 @@ def is_multicast(text):
raise ValueError
def is_address(text):
def is_address(text: str) -> bool:
"""Is the specified string an IPv4 or IPv6 address?
*text*, a ``str``, the textual address.
@ -135,7 +137,9 @@ def is_address(text):
return False
def low_level_address_tuple(high_tuple, af=None):
def low_level_address_tuple(
high_tuple: Tuple[str, int], af: Optional[int] = None
) -> Any:
"""Given a "high-level" address tuple, i.e.
an (address, port) return the appropriate "low-level" address tuple
suitable for use in socket calls.
@ -143,7 +147,6 @@ def low_level_address_tuple(high_tuple, af=None):
If an *af* other than ``None`` is provided, it is assumed the
address in the high-level tuple is valid and has that af. If af
is ``None``, then af_for_address will be called.
"""
address, port = high_tuple
if af is None:
@ -151,13 +154,13 @@ def low_level_address_tuple(high_tuple, af=None):
if af == AF_INET:
return (address, port)
elif af == AF_INET6:
i = address.find('%')
i = address.find("%")
if i < 0:
# no scope, shortcut!
return (address, port, 0, 0)
# try to avoid getaddrinfo()
addrpart = address[:i]
scope = address[i + 1:]
scope = address[i + 1 :]
if scope.isdigit():
return (addrpart, port, 0, int(scope))
try:
@ -167,4 +170,4 @@ def low_level_address_tuple(high_tuple, af=None):
((*_, tup), *_) = socket.getaddrinfo(address, port, flags=ai_flags)
return tup
else:
raise NotImplementedError(f'unknown address family {af}')
raise NotImplementedError(f"unknown address family {af}")

View file

@ -1,4 +0,0 @@
from typing import Union
from socket import AddressFamily
AF_INET6 : Union[int, AddressFamily]

View file

@ -17,11 +17,14 @@
"""IPv4 helper functions."""
from typing import Union
import struct
import dns.exception
def inet_ntoa(address):
def inet_ntoa(address: bytes) -> str:
"""Convert an IPv4 address in binary form to text form.
*address*, a ``bytes``, the IPv4 address in binary form.
@ -31,30 +34,32 @@ def inet_ntoa(address):
if len(address) != 4:
raise dns.exception.SyntaxError
return ('%u.%u.%u.%u' % (address[0], address[1],
address[2], address[3]))
return "%u.%u.%u.%u" % (address[0], address[1], address[2], address[3])
def inet_aton(text):
def inet_aton(text: Union[str, bytes]) -> bytes:
"""Convert an IPv4 address in text form to binary form.
*text*, a ``str``, the IPv4 address in textual form.
*text*, a ``str`` or ``bytes``, the IPv4 address in textual form.
Returns a ``bytes``.
"""
if not isinstance(text, bytes):
text = text.encode()
parts = text.split(b'.')
btext = text.encode()
else:
btext = text
parts = btext.split(b".")
if len(parts) != 4:
raise dns.exception.SyntaxError
for part in parts:
if not part.isdigit():
raise dns.exception.SyntaxError
if len(part) > 1 and part[0] == ord('0'):
if len(part) > 1 and part[0] == ord("0"):
# No leading zeros
raise dns.exception.SyntaxError
try:
b = [int(part) for part in parts]
return struct.pack('BBBB', *b)
return struct.pack("BBBB", *b)
except Exception:
raise dns.exception.SyntaxError

View file

@ -17,15 +17,18 @@
"""IPv6 helper functions."""
from typing import List, Union
import re
import binascii
import dns.exception
import dns.ipv4
_leading_zero = re.compile(r'0+([0-9a-f]+)')
_leading_zero = re.compile(r"0+([0-9a-f]+)")
def inet_ntoa(address):
def inet_ntoa(address: bytes) -> str:
"""Convert an IPv6 address in binary form to text form.
*address*, a ``bytes``, the IPv6 address in binary form.
@ -41,7 +44,7 @@ def inet_ntoa(address):
i = 0
l = len(hex)
while i < l:
chunk = hex[i:i + 4].decode()
chunk = hex[i : i + 4].decode()
# strip leading zeros. we do this with an re instead of
# with lstrip() because lstrip() didn't support chars until
# python 2.2.2
@ -58,7 +61,7 @@ def inet_ntoa(address):
start = -1
last_was_zero = False
for i in range(8):
if chunks[i] != '0':
if chunks[i] != "0":
if last_was_zero:
end = i
current_len = end - start
@ -76,27 +79,30 @@ def inet_ntoa(address):
best_start = start
best_len = current_len
if best_len > 1:
if best_start == 0 and \
(best_len == 6 or
best_len == 5 and chunks[5] == 'ffff'):
if best_start == 0 and (best_len == 6 or best_len == 5 and chunks[5] == "ffff"):
# We have an embedded IPv4 address
if best_len == 6:
prefix = '::'
prefix = "::"
else:
prefix = '::ffff:'
hex = prefix + dns.ipv4.inet_ntoa(address[12:])
prefix = "::ffff:"
thex = prefix + dns.ipv4.inet_ntoa(address[12:])
else:
hex = ':'.join(chunks[:best_start]) + '::' + \
':'.join(chunks[best_start + best_len:])
thex = (
":".join(chunks[:best_start])
+ "::"
+ ":".join(chunks[best_start + best_len :])
)
else:
hex = ':'.join(chunks)
return hex
thex = ":".join(chunks)
return thex
_v4_ending = re.compile(br'(.*):(\d+\.\d+\.\d+\.\d+)$')
_colon_colon_start = re.compile(br'::.*')
_colon_colon_end = re.compile(br'.*::$')
def inet_aton(text, ignore_scope=False):
_v4_ending = re.compile(rb"(.*):(\d+\.\d+\.\d+\.\d+)$")
_colon_colon_start = re.compile(rb"::.*")
_colon_colon_end = re.compile(rb".*::$")
def inet_aton(text: Union[str, bytes], ignore_scope: bool = False) -> bytes:
"""Convert an IPv6 address in text form to binary form.
*text*, a ``str``, the IPv6 address in textual form.
@ -111,82 +117,88 @@ def inet_aton(text, ignore_scope=False):
# Our aim here is not something fast; we just want something that works.
#
if not isinstance(text, bytes):
text = text.encode()
btext = text.encode()
else:
btext = text
if ignore_scope:
parts = text.split(b'%')
parts = btext.split(b"%")
l = len(parts)
if l == 2:
text = parts[0]
btext = parts[0]
elif l > 2:
raise dns.exception.SyntaxError
if text == b'':
if btext == b"":
raise dns.exception.SyntaxError
elif text.endswith(b':') and not text.endswith(b'::'):
elif btext.endswith(b":") and not btext.endswith(b"::"):
raise dns.exception.SyntaxError
elif text.startswith(b':') and not text.startswith(b'::'):
elif btext.startswith(b":") and not btext.startswith(b"::"):
raise dns.exception.SyntaxError
elif text == b'::':
text = b'0::'
elif btext == b"::":
btext = b"0::"
#
# Get rid of the icky dot-quad syntax if we have it.
#
m = _v4_ending.match(text)
m = _v4_ending.match(btext)
if m is not None:
b = dns.ipv4.inet_aton(m.group(2))
text = ("{}:{:02x}{:02x}:{:02x}{:02x}".format(m.group(1).decode(),
b[0], b[1], b[2],
b[3])).encode()
btext = (
"{}:{:02x}{:02x}:{:02x}{:02x}".format(
m.group(1).decode(), b[0], b[1], b[2], b[3]
)
).encode()
#
# Try to turn '::<whatever>' into ':<whatever>'; if no match try to
# turn '<whatever>::' into '<whatever>:'
#
m = _colon_colon_start.match(text)
m = _colon_colon_start.match(btext)
if m is not None:
text = text[1:]
btext = btext[1:]
else:
m = _colon_colon_end.match(text)
m = _colon_colon_end.match(btext)
if m is not None:
text = text[:-1]
btext = btext[:-1]
#
# Now canonicalize into 8 chunks of 4 hex digits each
#
chunks = text.split(b':')
chunks = btext.split(b":")
l = len(chunks)
if l > 8:
raise dns.exception.SyntaxError
seen_empty = False
canonical = []
canonical: List[bytes] = []
for c in chunks:
if c == b'':
if c == b"":
if seen_empty:
raise dns.exception.SyntaxError
seen_empty = True
for _ in range(0, 8 - l + 1):
canonical.append(b'0000')
canonical.append(b"0000")
else:
lc = len(c)
if lc > 4:
raise dns.exception.SyntaxError
if lc != 4:
c = (b'0' * (4 - lc)) + c
c = (b"0" * (4 - lc)) + c
canonical.append(c)
if l < 8 and not seen_empty:
raise dns.exception.SyntaxError
text = b''.join(canonical)
btext = b"".join(canonical)
#
# Finally we can go to binary.
#
try:
return binascii.unhexlify(text)
return binascii.unhexlify(btext)
except (binascii.Error, TypeError):
raise dns.exception.SyntaxError
_mapped_prefix = b'\x00' * 10 + b'\xff\xff'
def is_mapped(address):
_mapped_prefix = b"\x00" * 10 + b"\xff\xff"
def is_mapped(address: bytes) -> bool:
"""Is the specified address a mapped IPv4 address?
*address*, a ``bytes`` is an IPv6 address in binary form.

File diff suppressed because it is too large Load diff

View file

@ -1,47 +0,0 @@
from typing import Optional, Dict, List, Tuple, Union
from . import name, rrset, tsig, rdatatype, entropy, edns, rdataclass, rcode
import hmac
class Message:
def to_wire(self, origin : Optional[name.Name]=None, max_size=0, **kw) -> bytes:
...
def find_rrset(self, section : List[rrset.RRset], name : name.Name, rdclass : int, rdtype : int,
covers=rdatatype.NONE, deleting : Optional[int]=None, create=False,
force_unique=False) -> rrset.RRset:
...
def __init__(self, id : Optional[int] =None) -> None:
self.id : int
self.flags = 0
self.sections : List[List[rrset.RRset]] = [[], [], [], []]
self.opt : rrset.RRset = None
self.request_payload = 0
self.keyring = None
self.tsig : rrset.RRset = None
self.request_mac = b''
self.xfr = False
self.origin = None
self.tsig_ctx = None
self.index : Dict[Tuple[rrset.RRset, name.Name, int, int, Union[int,str], int], rrset.RRset] = {}
def is_response(self, other : Message) -> bool:
...
def set_rcode(self, rcode : rcode.Rcode):
...
def from_text(a : str, idna_codec : Optional[name.IDNACodec] = None) -> Message:
...
def from_wire(wire, keyring : Optional[Dict[name.Name,bytes]] = None, request_mac = b'', xfr=False, origin=None,
tsig_ctx : Optional[Union[dns.tsig.HMACTSig, dns.tsig.GSSTSig]] = None, multi=False,
question_only=False, one_rr_per_rrset=False,
ignore_trailing=False) -> Message:
...
def make_response(query : Message, recursion_available=False, our_payload=8192,
fudge=300) -> Message:
...
def make_query(qname : Union[name.Name,str], rdtype : Union[str,int], rdclass : Union[int,str] =rdataclass.IN, use_edns : Optional[bool] = None,
want_dnssec=False, ednsflags : Optional[int] = None, payload : Optional[int] = None,
request_payload : Optional[int] = None, options : Optional[List[edns.Option]] = None) -> Message:
...

View file

@ -18,32 +18,61 @@
"""DNS Names.
"""
from typing import Any, Dict, Iterable, Optional, Tuple, Union
import copy
import struct
import encodings.idna # type: ignore
import encodings.idna # type: ignore
try:
import idna # type: ignore
import idna # type: ignore
have_idna_2008 = True
except ImportError: # pragma: no cover
have_idna_2008 = False
import dns.enum
import dns.wire
import dns.exception
import dns.immutable
# fullcompare() result values
#: The compared names have no relationship to each other.
NAMERELN_NONE = 0
#: the first name is a superdomain of the second.
NAMERELN_SUPERDOMAIN = 1
#: The first name is a subdomain of the second.
NAMERELN_SUBDOMAIN = 2
#: The compared names are equal.
NAMERELN_EQUAL = 3
#: The compared names have a common ancestor.
NAMERELN_COMMONANCESTOR = 4
CompressType = Dict["Name", int]
class NameRelation(dns.enum.IntEnum):
"""Name relation result from fullcompare()."""
# This is an IntEnum for backwards compatibility in case anyone
# has hardwired the constants.
#: The compared names have no relationship to each other.
NONE = 0
#: the first name is a superdomain of the second.
SUPERDOMAIN = 1
#: The first name is a subdomain of the second.
SUBDOMAIN = 2
#: The compared names are equal.
EQUAL = 3
#: The compared names have a common ancestor.
COMMONANCESTOR = 4
@classmethod
def _maximum(cls):
return cls.COMMONANCESTOR
@classmethod
def _short_name(cls):
return cls.__name__
# Backwards compatibility
NAMERELN_NONE = NameRelation.NONE
NAMERELN_SUPERDOMAIN = NameRelation.SUPERDOMAIN
NAMERELN_SUBDOMAIN = NameRelation.SUBDOMAIN
NAMERELN_EQUAL = NameRelation.EQUAL
NAMERELN_COMMONANCESTOR = NameRelation.COMMONANCESTOR
class EmptyLabel(dns.exception.SyntaxError):
@ -84,6 +113,7 @@ class NoParent(dns.exception.DNSException):
"""An attempt was made to get the parent of the root name
or the empty name."""
class NoIDNA2008(dns.exception.DNSException):
"""IDNA 2008 processing was requested but the idna module is not
available."""
@ -92,9 +122,47 @@ class NoIDNA2008(dns.exception.DNSException):
class IDNAException(dns.exception.DNSException):
"""IDNA processing raised an exception."""
supp_kwargs = {'idna_exception'}
supp_kwargs = {"idna_exception"}
fmt = "IDNA processing exception: {idna_exception}"
# We do this as otherwise mypy complains about unexpected keyword argument
# idna_exception
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
_escaped = b'"().;\\@$'
_escaped_text = '"().;\\@$'
def _escapify(label: Union[bytes, str]) -> str:
"""Escape the characters in label which need it.
@returns: the escaped string
@rtype: string"""
if isinstance(label, bytes):
# Ordinary DNS label mode. Escape special characters and values
# < 0x20 or > 0x7f.
text = ""
for c in label:
if c in _escaped:
text += "\\" + chr(c)
elif c > 0x20 and c < 0x7F:
text += chr(c)
else:
text += "\\%03d" % c
return text
# Unicode label mode. Escape only special characters and values < 0x20
text = ""
for uc in label:
if uc in _escaped_text:
text += "\\" + uc
elif uc <= "\x20":
text += "\\%03d" % ord(uc)
else:
text += uc
return text
class IDNACodec:
"""Abstract base class for IDNA encoder/decoders."""
@ -102,26 +170,28 @@ class IDNACodec:
def __init__(self):
pass
def is_idna(self, label):
return label.lower().startswith(b'xn--')
def is_idna(self, label: bytes) -> bool:
return label.lower().startswith(b"xn--")
def encode(self, label):
def encode(self, label: str) -> bytes:
raise NotImplementedError # pragma: no cover
def decode(self, label):
def decode(self, label: bytes) -> str:
# We do not apply any IDNA policy on decode.
if self.is_idna(label):
try:
label = label[4:].decode('punycode')
slabel = label[4:].decode("punycode")
return _escapify(slabel)
except Exception as e:
raise IDNAException(idna_exception=e)
return _escapify(label)
else:
return _escapify(label)
class IDNA2003Codec(IDNACodec):
"""IDNA 2003 encoder/decoder."""
def __init__(self, strict_decode=False):
def __init__(self, strict_decode: bool = False):
"""Initialize the IDNA 2003 encoder/decoder.
*strict_decode* is a ``bool``. If `True`, then IDNA2003 checking
@ -132,22 +202,22 @@ class IDNA2003Codec(IDNACodec):
super().__init__()
self.strict_decode = strict_decode
def encode(self, label):
def encode(self, label: str) -> bytes:
"""Encode *label*."""
if label == '':
return b''
if label == "":
return b""
try:
return encodings.idna.ToASCII(label)
except UnicodeError:
raise LabelTooLong
def decode(self, label):
def decode(self, label: bytes) -> str:
"""Decode *label*."""
if not self.strict_decode:
return super().decode(label)
if label == b'':
return ''
if label == b"":
return ""
try:
return _escapify(encodings.idna.ToUnicode(label))
except Exception as e:
@ -155,16 +225,20 @@ class IDNA2003Codec(IDNACodec):
class IDNA2008Codec(IDNACodec):
"""IDNA 2008 encoder/decoder.
"""
"""IDNA 2008 encoder/decoder."""
def __init__(self, uts_46=False, transitional=False,
allow_pure_ascii=False, strict_decode=False):
def __init__(
self,
uts_46: bool = False,
transitional: bool = False,
allow_pure_ascii: bool = False,
strict_decode: bool = False,
):
"""Initialize the IDNA 2008 encoder/decoder.
*uts_46* is a ``bool``. If True, apply Unicode IDNA
compatibility processing as described in Unicode Technical
Standard #46 (http://unicode.org/reports/tr46/).
Standard #46 (https://unicode.org/reports/tr46/).
If False, do not apply the mapping. The default is False.
*transitional* is a ``bool``: If True, use the
@ -188,11 +262,11 @@ class IDNA2008Codec(IDNACodec):
self.allow_pure_ascii = allow_pure_ascii
self.strict_decode = strict_decode
def encode(self, label):
if label == '':
return b''
def encode(self, label: str) -> bytes:
if label == "":
return b""
if self.allow_pure_ascii and is_all_ascii(label):
encoded = label.encode('ascii')
encoded = label.encode("ascii")
if len(encoded) > 63:
raise LabelTooLong
return encoded
@ -203,16 +277,16 @@ class IDNA2008Codec(IDNACodec):
label = idna.uts46_remap(label, False, self.transitional)
return idna.alabel(label)
except idna.IDNAError as e:
if e.args[0] == 'Label too long':
if e.args[0] == "Label too long":
raise LabelTooLong
else:
raise IDNAException(idna_exception=e)
def decode(self, label):
def decode(self, label: bytes) -> str:
if not self.strict_decode:
return super().decode(label)
if label == b'':
return ''
if label == b"":
return ""
if not have_idna_2008:
raise NoIDNA2008
try:
@ -223,8 +297,6 @@ class IDNA2008Codec(IDNACodec):
except (idna.IDNAError, UnicodeError) as e:
raise IDNAException(idna_exception=e)
_escaped = b'"().;\\@$'
_escaped_text = '"().;\\@$'
IDNA_2003_Practical = IDNA2003Codec(False)
IDNA_2003_Strict = IDNA2003Codec(True)
@ -235,35 +307,8 @@ IDNA_2008_Strict = IDNA2008Codec(False, False, False, True)
IDNA_2008_Transitional = IDNA2008Codec(True, True, False, False)
IDNA_2008 = IDNA_2008_Practical
def _escapify(label):
"""Escape the characters in label which need it.
@returns: the escaped string
@rtype: string"""
if isinstance(label, bytes):
# Ordinary DNS label mode. Escape special characters and values
# < 0x20 or > 0x7f.
text = ''
for c in label:
if c in _escaped:
text += '\\' + chr(c)
elif c > 0x20 and c < 0x7F:
text += chr(c)
else:
text += '\\%03d' % c
return text
# Unicode label mode. Escape only special characters and values < 0x20
text = ''
for c in label:
if c in _escaped_text:
text += '\\' + c
elif c <= '\x20':
text += '\\%03d' % ord(c)
else:
text += c
return text
def _validate_labels(labels):
def _validate_labels(labels: Tuple[bytes, ...]) -> None:
"""Check for empty labels in the middle of a label sequence,
labels that are too long, and for too many labels.
@ -284,7 +329,7 @@ def _validate_labels(labels):
total += ll + 1
if ll > 63:
raise LabelTooLong
if i < 0 and label == b'':
if i < 0 and label == b"":
i = j
j += 1
if total > 255:
@ -293,7 +338,7 @@ def _validate_labels(labels):
raise EmptyLabel
def _maybe_convert_to_binary(label):
def _maybe_convert_to_binary(label: Union[bytes, str]) -> bytes:
"""If label is ``str``, convert it to ``bytes``. If it is already
``bytes`` just return it.
@ -316,14 +361,13 @@ class Name:
of the class are immutable.
"""
__slots__ = ['labels']
__slots__ = ["labels"]
def __init__(self, labels):
"""*labels* is any iterable whose values are ``str`` or ``bytes``.
"""
def __init__(self, labels: Iterable[Union[bytes, str]]):
"""*labels* is any iterable whose values are ``str`` or ``bytes``."""
labels = [_maybe_convert_to_binary(x) for x in labels]
self.labels = tuple(labels)
blabels = [_maybe_convert_to_binary(x) for x in labels]
self.labels = tuple(blabels)
_validate_labels(self.labels)
def __copy__(self):
@ -334,29 +378,29 @@ class Name:
def __getstate__(self):
# Names can be pickled
return {'labels': self.labels}
return {"labels": self.labels}
def __setstate__(self, state):
super().__setattr__('labels', state['labels'])
super().__setattr__("labels", state["labels"])
_validate_labels(self.labels)
def is_absolute(self):
def is_absolute(self) -> bool:
"""Is the most significant label of this name the root label?
Returns a ``bool``.
"""
return len(self.labels) > 0 and self.labels[-1] == b''
return len(self.labels) > 0 and self.labels[-1] == b""
def is_wild(self):
def is_wild(self) -> bool:
"""Is this name wild? (I.e. Is the least significant label '*'?)
Returns a ``bool``.
"""
return len(self.labels) > 0 and self.labels[0] == b'*'
return len(self.labels) > 0 and self.labels[0] == b"*"
def __hash__(self):
def __hash__(self) -> int:
"""Return a case-insensitive hash of the name.
Returns an ``int``.
@ -368,14 +412,14 @@ class Name:
h += (h << 3) + c
return h
def fullcompare(self, other):
def fullcompare(self, other: "Name") -> Tuple[NameRelation, int, int]:
"""Compare two names, returning a 3-tuple
``(relation, order, nlabels)``.
*relation* describes the relation ship between the names,
and is one of: ``dns.name.NAMERELN_NONE``,
``dns.name.NAMERELN_SUPERDOMAIN``, ``dns.name.NAMERELN_SUBDOMAIN``,
``dns.name.NAMERELN_EQUAL``, or ``dns.name.NAMERELN_COMMONANCESTOR``.
and is one of: ``dns.name.NameRelation.NONE``,
``dns.name.NameRelation.SUPERDOMAIN``, ``dns.name.NameRelation.SUBDOMAIN``,
``dns.name.NameRelation.EQUAL``, or ``dns.name.NameRelation.COMMONANCESTOR``.
*order* is < 0 if *self* < *other*, > 0 if *self* > *other*, and ==
0 if *self* == *other*. A relative name is always less than an
@ -404,9 +448,9 @@ class Name:
oabs = other.is_absolute()
if sabs != oabs:
if sabs:
return (NAMERELN_NONE, 1, 0)
return (NameRelation.NONE, 1, 0)
else:
return (NAMERELN_NONE, -1, 0)
return (NameRelation.NONE, -1, 0)
l1 = len(self.labels)
l2 = len(other.labels)
ldiff = l1 - l2
@ -417,7 +461,7 @@ class Name:
order = 0
nlabels = 0
namereln = NAMERELN_NONE
namereln = NameRelation.NONE
while l > 0:
l -= 1
l1 -= 1
@ -427,52 +471,52 @@ class Name:
if label1 < label2:
order = -1
if nlabels > 0:
namereln = NAMERELN_COMMONANCESTOR
namereln = NameRelation.COMMONANCESTOR
return (namereln, order, nlabels)
elif label1 > label2:
order = 1
if nlabels > 0:
namereln = NAMERELN_COMMONANCESTOR
namereln = NameRelation.COMMONANCESTOR
return (namereln, order, nlabels)
nlabels += 1
order = ldiff
if ldiff < 0:
namereln = NAMERELN_SUPERDOMAIN
namereln = NameRelation.SUPERDOMAIN
elif ldiff > 0:
namereln = NAMERELN_SUBDOMAIN
namereln = NameRelation.SUBDOMAIN
else:
namereln = NAMERELN_EQUAL
namereln = NameRelation.EQUAL
return (namereln, order, nlabels)
def is_subdomain(self, other):
def is_subdomain(self, other: "Name") -> bool:
"""Is self a subdomain of other?
Note that the notion of subdomain includes equality, e.g.
"dnpython.org" is a subdomain of itself.
"dnspython.org" is a subdomain of itself.
Returns a ``bool``.
"""
(nr, _, _) = self.fullcompare(other)
if nr == NAMERELN_SUBDOMAIN or nr == NAMERELN_EQUAL:
if nr == NameRelation.SUBDOMAIN or nr == NameRelation.EQUAL:
return True
return False
def is_superdomain(self, other):
def is_superdomain(self, other: "Name") -> bool:
"""Is self a superdomain of other?
Note that the notion of superdomain includes equality, e.g.
"dnpython.org" is a superdomain of itself.
"dnspython.org" is a superdomain of itself.
Returns a ``bool``.
"""
(nr, _, _) = self.fullcompare(other)
if nr == NAMERELN_SUPERDOMAIN or nr == NAMERELN_EQUAL:
if nr == NameRelation.SUPERDOMAIN or nr == NameRelation.EQUAL:
return True
return False
def canonicalize(self):
def canonicalize(self) -> "Name":
"""Return a name which is equal to the current name, but is in
DNSSEC canonical form.
"""
@ -516,12 +560,12 @@ class Name:
return NotImplemented
def __repr__(self):
return '<DNS name ' + self.__str__() + '>'
return "<DNS name " + self.__str__() + ">"
def __str__(self):
return self.to_text(False)
def to_text(self, omit_final_dot=False):
def to_text(self, omit_final_dot: bool = False) -> str:
"""Convert name to DNS text format.
*omit_final_dot* is a ``bool``. If True, don't emit the final
@ -532,17 +576,19 @@ class Name:
"""
if len(self.labels) == 0:
return '@'
if len(self.labels) == 1 and self.labels[0] == b'':
return '.'
return "@"
if len(self.labels) == 1 and self.labels[0] == b"":
return "."
if omit_final_dot and self.is_absolute():
l = self.labels[:-1]
else:
l = self.labels
s = '.'.join(map(_escapify, l))
s = ".".join(map(_escapify, l))
return s
def to_unicode(self, omit_final_dot=False, idna_codec=None):
def to_unicode(
self, omit_final_dot: bool = False, idna_codec: Optional[IDNACodec] = None
) -> str:
"""Convert name to Unicode text format.
IDN ACE labels are converted to Unicode.
@ -561,18 +607,18 @@ class Name:
"""
if len(self.labels) == 0:
return '@'
if len(self.labels) == 1 and self.labels[0] == b'':
return '.'
return "@"
if len(self.labels) == 1 and self.labels[0] == b"":
return "."
if omit_final_dot and self.is_absolute():
l = self.labels[:-1]
else:
l = self.labels
if idna_codec is None:
idna_codec = IDNA_2003_Practical
return '.'.join([idna_codec.decode(x) for x in l])
return ".".join([idna_codec.decode(x) for x in l])
def to_digestable(self, origin=None):
def to_digestable(self, origin: Optional["Name"] = None) -> bytes:
"""Convert name to a format suitable for digesting in hashes.
The name is canonicalized and converted to uncompressed wire
@ -589,10 +635,17 @@ class Name:
Returns a ``bytes``.
"""
return self.to_wire(origin=origin, canonicalize=True)
digest = self.to_wire(origin=origin, canonicalize=True)
assert digest is not None
return digest
def to_wire(self, file=None, compress=None, origin=None,
canonicalize=False):
def to_wire(
self,
file: Optional[Any] = None,
compress: Optional[CompressType] = None,
origin: Optional["Name"] = None,
canonicalize: bool = False,
) -> Optional[bytes]:
"""Convert name to wire format, possibly compressing it.
*file* is the file where the name is emitted (typically an
@ -638,6 +691,7 @@ class Name:
out += label
return bytes(out)
labels: Iterable[bytes]
if not self.is_absolute():
if origin is None or not origin.is_absolute():
raise NeedAbsoluteNameOrOrigin
@ -654,24 +708,25 @@ class Name:
else:
pos = None
if pos is not None:
value = 0xc000 + pos
s = struct.pack('!H', value)
value = 0xC000 + pos
s = struct.pack("!H", value)
file.write(s)
break
else:
if compress is not None and len(n) > 1:
pos = file.tell()
if pos <= 0x3fff:
if pos <= 0x3FFF:
compress[n] = pos
l = len(label)
file.write(struct.pack('!B', l))
file.write(struct.pack("!B", l))
if l > 0:
if canonicalize:
file.write(label.lower())
else:
file.write(label)
return None
def __len__(self):
def __len__(self) -> int:
"""The length of the name (in labels).
Returns an ``int``.
@ -688,7 +743,7 @@ class Name:
def __sub__(self, other):
return self.relativize(other)
def split(self, depth):
def split(self, depth: int) -> Tuple["Name", "Name"]:
"""Split a name into a prefix and suffix names at the specified depth.
*depth* is an ``int`` specifying the number of labels in the suffix
@ -705,11 +760,10 @@ class Name:
elif depth == l:
return (dns.name.empty, self)
elif depth < 0 or depth > l:
raise ValueError(
'depth must be >= 0 and <= the length of the name')
return (Name(self[: -depth]), Name(self[-depth:]))
raise ValueError("depth must be >= 0 and <= the length of the name")
return (Name(self[:-depth]), Name(self[-depth:]))
def concatenate(self, other):
def concatenate(self, other: "Name") -> "Name":
"""Return a new name which is the concatenation of self and other.
Raises ``dns.name.AbsoluteConcatenation`` if the name is
@ -724,7 +778,7 @@ class Name:
labels.extend(list(other.labels))
return Name(labels)
def relativize(self, origin):
def relativize(self, origin: "Name") -> "Name":
"""If the name is a subdomain of *origin*, return a new name which is
the name relative to origin. Otherwise return the name.
@ -740,7 +794,7 @@ class Name:
else:
return self
def derelativize(self, origin):
def derelativize(self, origin: "Name") -> "Name":
"""If the name is a relative name, return a new name which is the
concatenation of the name and origin. Otherwise return the name.
@ -756,7 +810,9 @@ class Name:
else:
return self
def choose_relativity(self, origin=None, relativize=True):
def choose_relativity(
self, origin: Optional["Name"] = None, relativize: bool = True
) -> "Name":
"""Return a name with the relativity desired by the caller.
If *origin* is ``None``, then the name is returned.
@ -775,7 +831,7 @@ class Name:
else:
return self
def parent(self):
def parent(self) -> "Name":
"""Return the parent of the name.
For example, the parent of ``www.dnspython.org.`` is ``dnspython.org``.
@ -790,13 +846,17 @@ class Name:
raise NoParent
return Name(self.labels[1:])
#: The root name, '.'
root = Name([b''])
root = Name([b""])
#: The empty name.
empty = Name([])
def from_unicode(text, origin=root, idna_codec=None):
def from_unicode(
text: str, origin: Optional[Name] = root, idna_codec: Optional[IDNACodec] = None
) -> Name:
"""Convert unicode text into a Name object.
Labels are encoded in IDN ACE form according to rules specified by
@ -819,17 +879,17 @@ def from_unicode(text, origin=root, idna_codec=None):
if not (origin is None or isinstance(origin, Name)):
raise ValueError("origin must be a Name or None")
labels = []
label = ''
label = ""
escaping = False
edigits = 0
total = 0
if idna_codec is None:
idna_codec = IDNA_2003
if text == '@':
text = ''
if text == "@":
text = ""
if text:
if text in ['.', '\u3002', '\uff0e', '\uff61']:
return Name([b'']) # no Unicode "u" on this constant!
if text in [".", "\u3002", "\uff0e", "\uff61"]:
return Name([b""]) # no Unicode "u" on this constant!
for c in text:
if escaping:
if edigits == 0:
@ -848,12 +908,12 @@ def from_unicode(text, origin=root, idna_codec=None):
if edigits == 3:
escaping = False
label += chr(total)
elif c in ['.', '\u3002', '\uff0e', '\uff61']:
elif c in [".", "\u3002", "\uff0e", "\uff61"]:
if len(label) == 0:
raise EmptyLabel
labels.append(idna_codec.encode(label))
label = ''
elif c == '\\':
label = ""
elif c == "\\":
escaping = True
edigits = 0
total = 0
@ -864,22 +924,28 @@ def from_unicode(text, origin=root, idna_codec=None):
if len(label) > 0:
labels.append(idna_codec.encode(label))
else:
labels.append(b'')
labels.append(b"")
if (len(labels) == 0 or labels[-1] != b'') and origin is not None:
if (len(labels) == 0 or labels[-1] != b"") and origin is not None:
labels.extend(list(origin.labels))
return Name(labels)
def is_all_ascii(text):
def is_all_ascii(text: str) -> bool:
for c in text:
if ord(c) > 0x7f:
if ord(c) > 0x7F:
return False
return True
def from_text(text, origin=root, idna_codec=None):
def from_text(
text: Union[bytes, str],
origin: Optional[Name] = root,
idna_codec: Optional[IDNACodec] = None,
) -> Name:
"""Convert text into a Name object.
*text*, a ``str``, is the text to convert into a name.
*text*, a ``bytes`` or ``str``, is the text to convert into a name.
*origin*, a ``dns.name.Name``, specifies the origin to
append to non-absolute names. The default is the root name.
@ -903,23 +969,23 @@ def from_text(text, origin=root, idna_codec=None):
#
# then it's still "all ASCII" even though the domain name has
# codepoints > 127.
text = text.encode('ascii')
text = text.encode("ascii")
if not isinstance(text, bytes):
raise ValueError("input to from_text() must be a string")
if not (origin is None or isinstance(origin, Name)):
raise ValueError("origin must be a Name or None")
labels = []
label = b''
label = b""
escaping = False
edigits = 0
total = 0
if text == b'@':
text = b''
if text == b"@":
text = b""
if text:
if text == b'.':
return Name([b''])
if text == b".":
return Name([b""])
for c in text:
byte_ = struct.pack('!B', c)
byte_ = struct.pack("!B", c)
if escaping:
if edigits == 0:
if byte_.isdigit():
@ -936,13 +1002,13 @@ def from_text(text, origin=root, idna_codec=None):
edigits += 1
if edigits == 3:
escaping = False
label += struct.pack('!B', total)
elif byte_ == b'.':
label += struct.pack("!B", total)
elif byte_ == b".":
if len(label) == 0:
raise EmptyLabel
labels.append(label)
label = b''
elif byte_ == b'\\':
label = b""
elif byte_ == b"\\":
escaping = True
edigits = 0
total = 0
@ -953,13 +1019,16 @@ def from_text(text, origin=root, idna_codec=None):
if len(label) > 0:
labels.append(label)
else:
labels.append(b'')
if (len(labels) == 0 or labels[-1] != b'') and origin is not None:
labels.append(b"")
if (len(labels) == 0 or labels[-1] != b"") and origin is not None:
labels.extend(list(origin.labels))
return Name(labels)
def from_wire_parser(parser):
# we need 'dns.wire.Parser' quoted as dns.name and dns.wire depend on each other.
def from_wire_parser(parser: "dns.wire.Parser") -> Name:
"""Convert possibly compressed wire format into a Name.
*parser* is a dns.wire.Parser.
@ -980,7 +1049,7 @@ def from_wire_parser(parser):
if count < 64:
labels.append(parser.get_bytes(count))
elif count >= 192:
current = (count & 0x3f) * 256 + parser.get_uint8()
current = (count & 0x3F) * 256 + parser.get_uint8()
if current >= biggest_pointer:
raise BadPointer
biggest_pointer = current
@ -988,11 +1057,11 @@ def from_wire_parser(parser):
else:
raise BadLabelType
count = parser.get_uint8()
labels.append(b'')
labels.append(b"")
return Name(labels)
def from_wire(message, current):
def from_wire(message: bytes, current: int) -> Tuple[Name, int]:
"""Convert possibly compressed wire format into a Name.
*message* is a ``bytes`` containing an entire DNS message in DNS

View file

@ -1,40 +0,0 @@
from typing import Optional, Union, Tuple, Iterable, List
have_idna_2008: bool
class Name:
def is_subdomain(self, o : Name) -> bool: ...
def is_superdomain(self, o : Name) -> bool: ...
def __init__(self, labels : Iterable[Union[bytes,str]]) -> None:
self.labels : List[bytes]
def is_absolute(self) -> bool: ...
def is_wild(self) -> bool: ...
def fullcompare(self, other) -> Tuple[int,int,int]: ...
def canonicalize(self) -> Name: ...
def __eq__(self, other) -> bool: ...
def __ne__(self, other) -> bool: ...
def __lt__(self, other : Name) -> bool: ...
def __le__(self, other : Name) -> bool: ...
def __ge__(self, other : Name) -> bool: ...
def __gt__(self, other : Name) -> bool: ...
def to_text(self, omit_final_dot=False) -> str: ...
def to_unicode(self, omit_final_dot=False, idna_codec=None) -> str: ...
def to_digestable(self, origin=None) -> bytes: ...
def to_wire(self, file=None, compress=None, origin=None,
canonicalize=False) -> Optional[bytes]: ...
def __add__(self, other : Name) -> Name: ...
def __sub__(self, other : Name) -> Name: ...
def split(self, depth) -> List[Tuple[str,str]]: ...
def concatenate(self, other : Name) -> Name: ...
def relativize(self, origin) -> Name: ...
def derelativize(self, origin) -> Name: ...
def choose_relativity(self, origin : Optional[Name] = None, relativize=True) -> Name: ...
def parent(self) -> Name: ...
class IDNACodec:
pass
def from_text(text, origin : Optional[Name] = Name('.'), idna_codec : Optional[IDNACodec] = None) -> Name:
...
empty : Name

View file

@ -27,7 +27,8 @@
"""DNS name dictionary"""
from collections.abc import MutableMapping
# pylint seems to be confused about this one!
from collections.abc import MutableMapping # pylint: disable=no-name-in-module
import dns.name
@ -62,7 +63,7 @@ class NameDict(MutableMapping):
def __setitem__(self, key, value):
if not isinstance(key, dns.name.Name):
raise ValueError('NameDict key must be a name')
raise ValueError("NameDict key must be a name")
self.__store[key] = value
self.__update_max_depth(key)

View file

@ -17,12 +17,17 @@
"""DNS nodes. A node is a set of rdatasets."""
from typing import Any, Dict, Optional
import enum
import io
import dns.immutable
import dns.name
import dns.rdataclass
import dns.rdataset
import dns.rdatatype
import dns.rrset
import dns.renderer
@ -32,26 +37,28 @@ _cname_types = {
# "neutral" types can coexist with a CNAME and thus are not "other data"
_neutral_types = {
dns.rdatatype.NSEC, # RFC 4035 section 2.5
dns.rdatatype.NSEC, # RFC 4035 section 2.5
dns.rdatatype.NSEC3, # This is not likely to happen, but not impossible!
dns.rdatatype.KEY, # RFC 4035 section 2.5, RFC 3007
dns.rdatatype.KEY, # RFC 4035 section 2.5, RFC 3007
}
def _matches_type_or_its_signature(rdtypes, rdtype, covers):
return rdtype in rdtypes or \
(rdtype == dns.rdatatype.RRSIG and covers in rdtypes)
return rdtype in rdtypes or (rdtype == dns.rdatatype.RRSIG and covers in rdtypes)
@enum.unique
class NodeKind(enum.Enum):
"""Rdatasets in nodes
"""
REGULAR = 0 # a.k.a "other data"
"""Rdatasets in nodes"""
REGULAR = 0 # a.k.a "other data"
NEUTRAL = 1
CNAME = 2
@classmethod
def classify(cls, rdtype, covers):
def classify(
cls, rdtype: dns.rdatatype.RdataType, covers: dns.rdatatype.RdataType
) -> "NodeKind":
if _matches_type_or_its_signature(_cname_types, rdtype, covers):
return NodeKind.CNAME
elif _matches_type_or_its_signature(_neutral_types, rdtype, covers):
@ -60,7 +67,7 @@ class NodeKind(enum.Enum):
return NodeKind.REGULAR
@classmethod
def classify_rdataset(cls, rdataset):
def classify_rdataset(cls, rdataset: dns.rdataset.Rdataset) -> "NodeKind":
return cls.classify(rdataset.rdtype, rdataset.covers)
@ -81,19 +88,19 @@ class Node:
deleted.
"""
__slots__ = ['rdatasets']
__slots__ = ["rdatasets"]
def __init__(self):
# the set of rdatasets, represented as a list.
self.rdatasets = []
def to_text(self, name, **kw):
def to_text(self, name: dns.name.Name, **kw: Dict[str, Any]) -> str:
"""Convert a node to text format.
Each rdataset at the node is printed. Any keyword arguments
to this method are passed on to the rdataset's to_text() method.
*name*, a ``dns.name.Name`` or ``str``, the owner name of the
*name*, a ``dns.name.Name``, the owner name of the
rdatasets.
Returns a ``str``.
@ -103,12 +110,12 @@ class Node:
s = io.StringIO()
for rds in self.rdatasets:
if len(rds) > 0:
s.write(rds.to_text(name, **kw))
s.write('\n')
s.write(rds.to_text(name, **kw)) # type: ignore[arg-type]
s.write("\n")
return s.getvalue()[:-1]
def __repr__(self):
return '<DNS node ' + str(id(self)) + '>'
return "<DNS node " + str(id(self)) + ">"
def __eq__(self, other):
#
@ -144,27 +151,36 @@ class Node:
if len(self.rdatasets) > 0:
kind = NodeKind.classify_rdataset(rdataset)
if kind == NodeKind.CNAME:
self.rdatasets = [rds for rds in self.rdatasets if
NodeKind.classify_rdataset(rds) !=
NodeKind.REGULAR]
self.rdatasets = [
rds
for rds in self.rdatasets
if NodeKind.classify_rdataset(rds) != NodeKind.REGULAR
]
elif kind == NodeKind.REGULAR:
self.rdatasets = [rds for rds in self.rdatasets if
NodeKind.classify_rdataset(rds) !=
NodeKind.CNAME]
self.rdatasets = [
rds
for rds in self.rdatasets
if NodeKind.classify_rdataset(rds) != NodeKind.CNAME
]
# Otherwise the rdataset is NodeKind.NEUTRAL and we do not need to
# edit self.rdatasets.
self.rdatasets.append(rdataset)
def find_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE,
create=False):
def find_rdataset(
self,
rdclass: dns.rdataclass.RdataClass,
rdtype: dns.rdatatype.RdataType,
covers: dns.rdatatype.RdataType = dns.rdatatype.NONE,
create: bool = False,
) -> dns.rdataset.Rdataset:
"""Find an rdataset matching the specified properties in the
current node.
*rdclass*, an ``int``, the class of the rdataset.
*rdclass*, a ``dns.rdataclass.RdataClass``, the class of the rdataset.
*rdtype*, an ``int``, the type of the rdataset.
*rdtype*, a ``dns.rdatatype.RdataType``, the type of the rdataset.
*covers*, an ``int`` or ``None``, the covered type.
*covers*, a ``dns.rdatatype.RdataType``, the covered type.
Usually this value is ``dns.rdatatype.NONE``, but if the
rdtype is ``dns.rdatatype.SIG`` or ``dns.rdatatype.RRSIG``,
then the covers value will be the rdata type the SIG/RRSIG
@ -191,8 +207,13 @@ class Node:
self._append_rdataset(rds)
return rds
def get_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE,
create=False):
def get_rdataset(
self,
rdclass: dns.rdataclass.RdataClass,
rdtype: dns.rdatatype.RdataType,
covers: dns.rdatatype.RdataType = dns.rdatatype.NONE,
create: bool = False,
) -> Optional[dns.rdataset.Rdataset]:
"""Get an rdataset matching the specified properties in the
current node.
@ -223,7 +244,12 @@ class Node:
rds = None
return rds
def delete_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE):
def delete_rdataset(
self,
rdclass: dns.rdataclass.RdataClass,
rdtype: dns.rdatatype.RdataType,
covers: dns.rdatatype.RdataType = dns.rdatatype.NONE,
) -> None:
"""Delete the rdataset matching the specified properties in the
current node.
@ -240,7 +266,7 @@ class Node:
if rds is not None:
self.rdatasets.remove(rds)
def replace_rdataset(self, replacement):
def replace_rdataset(self, replacement: dns.rdataset.Rdataset) -> None:
"""Replace an rdataset.
It is not an error if there is no rdataset matching *replacement*.
@ -256,16 +282,17 @@ class Node:
"""
if not isinstance(replacement, dns.rdataset.Rdataset):
raise ValueError('replacement is not an rdataset')
raise ValueError("replacement is not an rdataset")
if isinstance(replacement, dns.rrset.RRset):
# RRsets are not good replacements as the match() method
# is not compatible.
replacement = replacement.to_rdataset()
self.delete_rdataset(replacement.rdclass, replacement.rdtype,
replacement.covers)
self.delete_rdataset(
replacement.rdclass, replacement.rdtype, replacement.covers
)
self._append_rdataset(replacement)
def classify(self):
def classify(self) -> NodeKind:
"""Classify a node.
A node which contains a CNAME or RRSIG(CNAME) is a
@ -286,7 +313,7 @@ class Node:
return kind
return NodeKind.NEUTRAL
def is_immutable(self):
def is_immutable(self) -> bool:
return False
@ -298,23 +325,38 @@ class ImmutableNode(Node):
[dns.rdataset.ImmutableRdataset(rds) for rds in node.rdatasets]
)
def find_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE,
create=False):
def find_rdataset(
self,
rdclass: dns.rdataclass.RdataClass,
rdtype: dns.rdatatype.RdataType,
covers: dns.rdatatype.RdataType = dns.rdatatype.NONE,
create: bool = False,
) -> dns.rdataset.Rdataset:
if create:
raise TypeError("immutable")
return super().find_rdataset(rdclass, rdtype, covers, False)
def get_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE,
create=False):
def get_rdataset(
self,
rdclass: dns.rdataclass.RdataClass,
rdtype: dns.rdatatype.RdataType,
covers: dns.rdatatype.RdataType = dns.rdatatype.NONE,
create: bool = False,
) -> Optional[dns.rdataset.Rdataset]:
if create:
raise TypeError("immutable")
return super().get_rdataset(rdclass, rdtype, covers, False)
def delete_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE):
def delete_rdataset(
self,
rdclass: dns.rdataclass.RdataClass,
rdtype: dns.rdatatype.RdataType,
covers: dns.rdatatype.RdataType = dns.rdatatype.NONE,
) -> None:
raise TypeError("immutable")
def replace_rdataset(self, replacement):
def replace_rdataset(self, replacement: dns.rdataset.Rdataset) -> None:
raise TypeError("immutable")
def is_immutable(self):
def is_immutable(self) -> bool:
return True

View file

@ -1,17 +0,0 @@
from typing import List, Optional, Union
from . import rdataset, rdatatype, name
class Node:
def __init__(self):
self.rdatasets : List[rdataset.Rdataset]
def to_text(self, name : Union[str,name.Name], **kw) -> str:
...
def find_rdataset(self, rdclass : int, rdtype : int, covers=rdatatype.NONE,
create=False) -> rdataset.Rdataset:
...
def get_rdataset(self, rdclass : int, rdtype : int, covers=rdatatype.NONE,
create=False) -> Optional[rdataset.Rdataset]:
...
def delete_rdataset(self, rdclass : int, rdtype : int, covers=rdatatype.NONE):
...
def replace_rdataset(self, replacement : rdataset.Rdataset) -> None:
...

View file

@ -20,6 +20,7 @@
import dns.enum
import dns.exception
class Opcode(dns.enum.IntEnum):
#: Query
QUERY = 0
@ -45,7 +46,7 @@ class UnknownOpcode(dns.exception.DNSException):
"""An DNS opcode is unknown."""
def from_text(text):
def from_text(text: str) -> Opcode:
"""Convert text into an opcode.
*text*, a ``str``, the textual opcode
@ -58,7 +59,7 @@ def from_text(text):
return Opcode.from_text(text)
def from_flags(flags):
def from_flags(flags: int) -> Opcode:
"""Extract an opcode from DNS message flags.
*flags*, an ``int``, the DNS flags.
@ -66,10 +67,10 @@ def from_flags(flags):
Returns an ``int``.
"""
return (flags & 0x7800) >> 11
return Opcode((flags & 0x7800) >> 11)
def to_flags(value):
def to_flags(value: Opcode) -> int:
"""Convert an opcode to a value suitable for ORing into DNS message
flags.
@ -81,7 +82,7 @@ def to_flags(value):
return (value << 11) & 0x7800
def to_text(value):
def to_text(value: Opcode) -> str:
"""Convert an opcode to text.
*value*, an ``int`` the opcode value,
@ -94,7 +95,7 @@ def to_text(value):
return Opcode.to_text(value)
def is_update(flags):
def is_update(flags: int) -> bool:
"""Is the opcode in flags UPDATE?
*flags*, an ``int``, the DNS message flags.
@ -104,6 +105,7 @@ def is_update(flags):
return from_flags(flags) == Opcode.UPDATE
### BEGIN generated Opcode constants
QUERY = Opcode.QUERY

File diff suppressed because it is too large Load diff

View file

@ -1,64 +0,0 @@
from typing import Optional, Union, Dict, Generator, Any
from . import tsig, rdatatype, rdataclass, name, message
from requests.sessions import Session
import socket
# If the ssl import works, then
#
# error: Name 'ssl' already defined (by an import)
#
# is expected and can be ignored.
try:
import ssl
except ImportError:
class ssl: # type: ignore
SSLContext : Dict = {}
have_doh: bool
def https(q : message.Message, where: str, timeout : Optional[float] = None,
port : Optional[int] = 443, source : Optional[str] = None,
source_port : Optional[int] = 0,
session: Optional[Session] = None,
path : Optional[str] = '/dns-query', post : Optional[bool] = True,
bootstrap_address : Optional[str] = None,
verify : Optional[bool] = True) -> message.Message:
pass
def tcp(q : message.Message, where : str, timeout : float = None, port=53,
af : Optional[int] = None, source : Optional[str] = None,
source_port : Optional[int] = 0,
one_rr_per_rrset : Optional[bool] = False,
ignore_trailing : Optional[bool] = False,
sock : Optional[socket.socket] = None) -> message.Message:
pass
def xfr(where : None, zone : Union[name.Name,str], rdtype=rdatatype.AXFR,
rdclass=rdataclass.IN,
timeout : Optional[float] = None, port=53,
keyring : Optional[Dict[name.Name, bytes]] = None,
keyname : Union[str,name.Name]= None, relativize=True,
lifetime : Optional[float] = None,
source : Optional[str] = None, source_port=0, serial=0,
use_udp : Optional[bool] = False,
keyalgorithm=tsig.default_algorithm) \
-> Generator[Any,Any,message.Message]:
pass
def udp(q : message.Message, where : str, timeout : Optional[float] = None,
port=53, source : Optional[str] = None, source_port : Optional[int] = 0,
ignore_unexpected : Optional[bool] = False,
one_rr_per_rrset : Optional[bool] = False,
ignore_trailing : Optional[bool] = False,
sock : Optional[socket.socket] = None) -> message.Message:
pass
def tls(q : message.Message, where : str, timeout : Optional[float] = None,
port=53, source : Optional[str] = None, source_port : Optional[int] = 0,
one_rr_per_rrset : Optional[bool] = False,
ignore_trailing : Optional[bool] = False,
sock : Optional[socket.socket] = None,
ssl_context: Optional[ssl.SSLContext] = None,
server_hostname: Optional[str] = None) -> message.Message:
pass

74
lib/dns/quic/__init__.py Normal file
View file

@ -0,0 +1,74 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
try:
import aioquic.quic.configuration # type: ignore
import dns.asyncbackend
from dns._asyncbackend import NullContext
from dns.quic._sync import SyncQuicManager, SyncQuicConnection, SyncQuicStream
from dns.quic._asyncio import (
AsyncioQuicManager,
AsyncioQuicConnection,
AsyncioQuicStream,
)
from dns.quic._common import AsyncQuicConnection, AsyncQuicManager
have_quic = True
def null_factory(
*args, # pylint: disable=unused-argument
**kwargs # pylint: disable=unused-argument
):
return NullContext(None)
def _asyncio_manager_factory(
context, *args, **kwargs # pylint: disable=unused-argument
):
return AsyncioQuicManager(*args, **kwargs)
# We have a context factory and a manager factory as for trio we need to have
# a nursery.
_async_factories = {"asyncio": (null_factory, _asyncio_manager_factory)}
try:
import trio
from dns.quic._trio import ( # pylint: disable=ungrouped-imports
TrioQuicManager,
TrioQuicConnection,
TrioQuicStream,
)
def _trio_context_factory():
return trio.open_nursery()
def _trio_manager_factory(context, *args, **kwargs):
return TrioQuicManager(context, *args, **kwargs)
_async_factories["trio"] = (_trio_context_factory, _trio_manager_factory)
except ImportError:
pass
def factories_for_backend(backend=None):
if backend is None:
backend = dns.asyncbackend.get_default_backend()
return _async_factories[backend.name()]
except ImportError:
have_quic = False
from typing import Any
class AsyncQuicStream: # type: ignore
pass
class AsyncQuicConnection: # type: ignore
async def make_stream(self) -> Any:
raise NotImplementedError
class SyncQuicStream: # type: ignore
pass
class SyncQuicConnection: # type: ignore
def make_stream(self) -> Any:
raise NotImplementedError

206
lib/dns/quic/_asyncio.py Normal file
View file

@ -0,0 +1,206 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
import asyncio
import socket
import ssl
import struct
import time
import aioquic.quic.configuration # type: ignore
import aioquic.quic.connection # type: ignore
import aioquic.quic.events # type: ignore
import dns.inet
import dns.asyncbackend
from dns.quic._common import (
BaseQuicStream,
AsyncQuicConnection,
AsyncQuicManager,
QUIC_MAX_DATAGRAM,
)
class AsyncioQuicStream(BaseQuicStream):
def __init__(self, connection, stream_id):
super().__init__(connection, stream_id)
self._wake_up = asyncio.Condition()
async def _wait_for_wake_up(self):
async with self._wake_up:
await self._wake_up.wait()
async def wait_for(self, amount, expiration):
timeout = self._timeout_from_expiration(expiration)
while True:
if self._buffer.have(amount):
return
self._expecting = amount
try:
await asyncio.wait_for(self._wait_for_wake_up(), timeout)
except Exception:
pass
self._expecting = 0
async def receive(self, timeout=None):
expiration = self._expiration_from_timeout(timeout)
await self.wait_for(2, expiration)
(size,) = struct.unpack("!H", self._buffer.get(2))
await self.wait_for(size, expiration)
return self._buffer.get(size)
async def send(self, datagram, is_end=False):
data = self._encapsulate(datagram)
await self._connection.write(self._stream_id, data, is_end)
async def _add_input(self, data, is_end):
if self._common_add_input(data, is_end):
async with self._wake_up:
self._wake_up.notify()
async def close(self):
self._close()
# Streams are async context managers
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.close()
async with self._wake_up:
self._wake_up.notify()
return False
class AsyncioQuicConnection(AsyncQuicConnection):
def __init__(self, connection, address, port, source, source_port, manager=None):
super().__init__(connection, address, port, source, source_port, manager)
self._socket = None
self._handshake_complete = asyncio.Event()
self._socket_created = asyncio.Event()
self._wake_timer = asyncio.Condition()
self._receiver_task = None
self._sender_task = None
async def _receiver(self):
try:
af = dns.inet.af_for_address(self._address)
backend = dns.asyncbackend.get_backend("asyncio")
self._socket = await backend.make_socket(
af, socket.SOCK_DGRAM, 0, self._source, self._peer
)
self._socket_created.set()
async with self._socket:
while not self._done:
(datagram, address) = await self._socket.recvfrom(
QUIC_MAX_DATAGRAM, None
)
if address[0] != self._peer[0] or address[1] != self._peer[1]:
continue
self._connection.receive_datagram(
datagram, self._peer[0], time.time()
)
# Wake up the timer in case the sender is sleeping, as there may be
# stuff to send now.
async with self._wake_timer:
self._wake_timer.notify_all()
except Exception:
pass
async def _wait_for_wake_timer(self):
async with self._wake_timer:
await self._wake_timer.wait()
async def _sender(self):
await self._socket_created.wait()
while not self._done:
datagrams = self._connection.datagrams_to_send(time.time())
for (datagram, address) in datagrams:
assert address == self._peer[0]
await self._socket.sendto(datagram, self._peer, None)
(expiration, interval) = self._get_timer_values()
try:
await asyncio.wait_for(self._wait_for_wake_timer(), interval)
except Exception:
pass
self._handle_timer(expiration)
await self._handle_events()
async def _handle_events(self):
count = 0
while True:
event = self._connection.next_event()
if event is None:
return
if isinstance(event, aioquic.quic.events.StreamDataReceived):
stream = self._streams.get(event.stream_id)
if stream:
await stream._add_input(event.data, event.end_stream)
elif isinstance(event, aioquic.quic.events.HandshakeCompleted):
self._handshake_complete.set()
elif isinstance(
event, aioquic.quic.events.ConnectionTerminated
) or isinstance(event, aioquic.quic.events.StreamReset):
self._done = True
self._receiver_task.cancel()
count += 1
if count > 10:
# yield
count = 0
await asyncio.sleep(0)
async def write(self, stream, data, is_end=False):
self._connection.send_stream_data(stream, data, is_end)
async with self._wake_timer:
self._wake_timer.notify_all()
def run(self):
if self._closed:
return
self._receiver_task = asyncio.Task(self._receiver())
self._sender_task = asyncio.Task(self._sender())
async def make_stream(self):
await self._handshake_complete.wait()
stream_id = self._connection.get_next_available_stream_id(False)
stream = AsyncioQuicStream(self, stream_id)
self._streams[stream_id] = stream
return stream
async def close(self):
if not self._closed:
self._manager.closed(self._peer[0], self._peer[1])
self._closed = True
self._connection.close()
async with self._wake_timer:
self._wake_timer.notify_all()
try:
await self._receiver_task
except asyncio.CancelledError:
pass
try:
await self._sender_task
except asyncio.CancelledError:
pass
class AsyncioQuicManager(AsyncQuicManager):
def __init__(self, conf=None, verify_mode=ssl.CERT_REQUIRED):
super().__init__(conf, verify_mode, AsyncioQuicConnection)
def connect(self, address, port=853, source=None, source_port=0):
(connection, start) = self._connect(address, port, source, source_port)
if start:
connection.run()
return connection
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
# Copy the itertor into a list as exiting things will mutate the connections
# table.
connections = list(self._connections.values())
for connection in connections:
await connection.close()
return False

180
lib/dns/quic/_common.py Normal file
View file

@ -0,0 +1,180 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
import socket
import struct
import time
from typing import Any
import aioquic.quic.configuration # type: ignore
import aioquic.quic.connection # type: ignore
import dns.inet
QUIC_MAX_DATAGRAM = 2048
class UnexpectedEOF(Exception):
pass
class Buffer:
def __init__(self):
self._buffer = b""
self._seen_end = False
def put(self, data, is_end):
if self._seen_end:
return
self._buffer += data
if is_end:
self._seen_end = True
def have(self, amount):
if len(self._buffer) >= amount:
return True
if self._seen_end:
raise UnexpectedEOF
return False
def seen_end(self):
return self._seen_end
def get(self, amount):
assert self.have(amount)
data = self._buffer[:amount]
self._buffer = self._buffer[amount:]
return data
class BaseQuicStream:
def __init__(self, connection, stream_id):
self._connection = connection
self._stream_id = stream_id
self._buffer = Buffer()
self._expecting = 0
def id(self):
return self._stream_id
def _expiration_from_timeout(self, timeout):
if timeout is not None:
expiration = time.time() + timeout
else:
expiration = None
return expiration
def _timeout_from_expiration(self, expiration):
if expiration is not None:
timeout = max(expiration - time.time(), 0.0)
else:
timeout = None
return timeout
# Subclass must implement receive() as sync / async and which returns a message
# or raises UnexpectedEOF.
def _encapsulate(self, datagram):
l = len(datagram)
return struct.pack("!H", l) + datagram
def _common_add_input(self, data, is_end):
self._buffer.put(data, is_end)
return self._expecting > 0 and self._buffer.have(self._expecting)
def _close(self):
self._connection.close_stream(self._stream_id)
self._buffer.put(b"", True) # send EOF in case we haven't seen it.
class BaseQuicConnection:
def __init__(
self, connection, address, port, source=None, source_port=0, manager=None
):
self._done = False
self._connection = connection
self._address = address
self._port = port
self._closed = False
self._manager = manager
self._streams = {}
self._af = dns.inet.af_for_address(address)
self._peer = dns.inet.low_level_address_tuple((address, port))
if source is None and source_port != 0:
if self._af == socket.AF_INET:
source = "0.0.0.0"
elif self._af == socket.AF_INET6:
source = "::"
else:
raise NotImplementedError
if source:
self._source = (source, source_port)
else:
self._source = None
def close_stream(self, stream_id):
del self._streams[stream_id]
def _get_timer_values(self, closed_is_special=True):
now = time.time()
expiration = self._connection.get_timer()
if expiration is None:
expiration = now + 3600 # arbitrary "big" value
interval = max(expiration - now, 0)
if self._closed and closed_is_special:
# lower sleep interval to avoid a race in the closing process
# which can lead to higher latency closing due to sleeping when
# we have events.
interval = min(interval, 0.05)
return (expiration, interval)
def _handle_timer(self, expiration):
now = time.time()
if expiration <= now:
self._connection.handle_timer(now)
class AsyncQuicConnection(BaseQuicConnection):
async def make_stream(self) -> Any:
pass
class BaseQuicManager:
def __init__(self, conf, verify_mode, connection_factory):
self._connections = {}
self._connection_factory = connection_factory
if conf is None:
verify_path = None
if isinstance(verify_mode, str):
verify_path = verify_mode
verify_mode = True
conf = aioquic.quic.configuration.QuicConfiguration(
alpn_protocols=["doq", "doq-i03"],
verify_mode=verify_mode,
)
if verify_path is not None:
conf.load_verify_locations(verify_path)
self._conf = conf
def _connect(self, address, port=853, source=None, source_port=0):
connection = self._connections.get((address, port))
if connection is not None:
return (connection, False)
qconn = aioquic.quic.connection.QuicConnection(configuration=self._conf)
qconn.connect(address, time.time())
connection = self._connection_factory(
qconn, address, port, source, source_port, self
)
self._connections[(address, port)] = connection
return (connection, True)
def closed(self, address, port):
try:
del self._connections[(address, port)]
except KeyError:
pass
class AsyncQuicManager(BaseQuicManager):
def connect(self, address, port=853, source=None, source_port=0):
raise NotImplementedError

214
lib/dns/quic/_sync.py Normal file
View file

@ -0,0 +1,214 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
import socket
import ssl
import selectors
import struct
import threading
import time
import aioquic.quic.configuration # type: ignore
import aioquic.quic.connection # type: ignore
import aioquic.quic.events # type: ignore
import dns.inet
from dns.quic._common import (
BaseQuicStream,
BaseQuicConnection,
BaseQuicManager,
QUIC_MAX_DATAGRAM,
)
# Avoid circularity with dns.query
if hasattr(selectors, "PollSelector"):
_selector_class = selectors.PollSelector # type: ignore
else:
_selector_class = selectors.SelectSelector # type: ignore
class SyncQuicStream(BaseQuicStream):
def __init__(self, connection, stream_id):
super().__init__(connection, stream_id)
self._wake_up = threading.Condition()
self._lock = threading.Lock()
def wait_for(self, amount, expiration):
timeout = self._timeout_from_expiration(expiration)
while True:
with self._lock:
if self._buffer.have(amount):
return
self._expecting = amount
with self._wake_up:
self._wake_up.wait(timeout)
self._expecting = 0
def receive(self, timeout=None):
expiration = self._expiration_from_timeout(timeout)
self.wait_for(2, expiration)
with self._lock:
(size,) = struct.unpack("!H", self._buffer.get(2))
self.wait_for(size, expiration)
with self._lock:
return self._buffer.get(size)
def send(self, datagram, is_end=False):
data = self._encapsulate(datagram)
self._connection.write(self._stream_id, data, is_end)
def _add_input(self, data, is_end):
if self._common_add_input(data, is_end):
with self._wake_up:
self._wake_up.notify()
def close(self):
with self._lock:
self._close()
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
with self._wake_up:
self._wake_up.notify()
return False
class SyncQuicConnection(BaseQuicConnection):
def __init__(self, connection, address, port, source, source_port, manager):
super().__init__(connection, address, port, source, source_port, manager)
self._socket = socket.socket(self._af, socket.SOCK_DGRAM, 0)
self._socket.connect(self._peer)
(self._send_wakeup, self._receive_wakeup) = socket.socketpair()
self._receive_wakeup.setblocking(False)
self._socket.setblocking(False)
if self._source is not None:
try:
self._socket.bind(
dns.inet.low_level_address_tuple(self._source, self._af)
)
except Exception:
self._socket.close()
raise
self._handshake_complete = threading.Event()
self._worker_thread = None
self._lock = threading.Lock()
def _read(self):
count = 0
while count < 10:
count += 1
try:
datagram = self._socket.recv(QUIC_MAX_DATAGRAM)
except BlockingIOError:
return
with self._lock:
self._connection.receive_datagram(datagram, self._peer[0], time.time())
def _drain_wakeup(self):
while True:
try:
self._receive_wakeup.recv(32)
except BlockingIOError:
return
def _worker(self):
sel = _selector_class()
sel.register(self._socket, selectors.EVENT_READ, self._read)
sel.register(self._receive_wakeup, selectors.EVENT_READ, self._drain_wakeup)
while not self._done:
(expiration, interval) = self._get_timer_values(False)
items = sel.select(interval)
for (key, _) in items:
key.data()
with self._lock:
self._handle_timer(expiration)
datagrams = self._connection.datagrams_to_send(time.time())
for (datagram, _) in datagrams:
try:
self._socket.send(datagram)
except BlockingIOError:
# we let QUIC handle any lossage
pass
self._handle_events()
def _handle_events(self):
while True:
with self._lock:
event = self._connection.next_event()
if event is None:
return
if isinstance(event, aioquic.quic.events.StreamDataReceived):
with self._lock:
stream = self._streams.get(event.stream_id)
if stream:
stream._add_input(event.data, event.end_stream)
elif isinstance(event, aioquic.quic.events.HandshakeCompleted):
self._handshake_complete.set()
elif isinstance(
event, aioquic.quic.events.ConnectionTerminated
) or isinstance(event, aioquic.quic.events.StreamReset):
with self._lock:
self._done = True
def write(self, stream, data, is_end=False):
with self._lock:
self._connection.send_stream_data(stream, data, is_end)
self._send_wakeup.send(b"\x01")
def run(self):
if self._closed:
return
self._worker_thread = threading.Thread(target=self._worker)
self._worker_thread.start()
def make_stream(self):
self._handshake_complete.wait()
with self._lock:
stream_id = self._connection.get_next_available_stream_id(False)
stream = SyncQuicStream(self, stream_id)
self._streams[stream_id] = stream
return stream
def close_stream(self, stream_id):
with self._lock:
super().close_stream(stream_id)
def close(self):
with self._lock:
if self._closed:
return
self._manager.closed(self._peer[0], self._peer[1])
self._closed = True
self._connection.close()
self._send_wakeup.send(b"\x01")
self._worker_thread.join()
class SyncQuicManager(BaseQuicManager):
def __init__(self, conf=None, verify_mode=ssl.CERT_REQUIRED):
super().__init__(conf, verify_mode, SyncQuicConnection)
self._lock = threading.Lock()
def connect(self, address, port=853, source=None, source_port=0):
with self._lock:
(connection, start) = self._connect(address, port, source, source_port)
if start:
connection.run()
return connection
def closed(self, address, port):
with self._lock:
super().closed(address, port)
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
# Copy the itertor into a list as exiting things will mutate the connections
# table.
connections = list(self._connections.values())
for connection in connections:
connection.close()
return False

170
lib/dns/quic/_trio.py Normal file
View file

@ -0,0 +1,170 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
import socket
import ssl
import struct
import time
import aioquic.quic.configuration # type: ignore
import aioquic.quic.connection # type: ignore
import aioquic.quic.events # type: ignore
import trio
import dns.inet
from dns._asyncbackend import NullContext
from dns.quic._common import (
BaseQuicStream,
AsyncQuicConnection,
AsyncQuicManager,
QUIC_MAX_DATAGRAM,
)
class TrioQuicStream(BaseQuicStream):
def __init__(self, connection, stream_id):
super().__init__(connection, stream_id)
self._wake_up = trio.Condition()
async def wait_for(self, amount):
while True:
if self._buffer.have(amount):
return
self._expecting = amount
async with self._wake_up:
await self._wake_up.wait()
self._expecting = 0
async def receive(self, timeout=None):
if timeout is None:
context = NullContext(None)
else:
context = trio.move_on_after(timeout)
with context:
await self.wait_for(2)
(size,) = struct.unpack("!H", self._buffer.get(2))
await self.wait_for(size)
return self._buffer.get(size)
async def send(self, datagram, is_end=False):
data = self._encapsulate(datagram)
await self._connection.write(self._stream_id, data, is_end)
async def _add_input(self, data, is_end):
if self._common_add_input(data, is_end):
async with self._wake_up:
self._wake_up.notify()
async def close(self):
self._close()
# Streams are async context managers
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.close()
async with self._wake_up:
self._wake_up.notify()
return False
class TrioQuicConnection(AsyncQuicConnection):
def __init__(self, connection, address, port, source, source_port, manager=None):
super().__init__(connection, address, port, source, source_port, manager)
self._socket = trio.socket.socket(self._af, socket.SOCK_DGRAM, 0)
if self._source:
trio.socket.bind(dns.inet.low_level_address_tuple(self._source, self._af))
self._handshake_complete = trio.Event()
self._run_done = trio.Event()
self._worker_scope = None
async def _worker(self):
await self._socket.connect(self._peer)
while not self._done:
(expiration, interval) = self._get_timer_values(False)
with trio.CancelScope(
deadline=trio.current_time() + interval
) as self._worker_scope:
datagram = await self._socket.recv(QUIC_MAX_DATAGRAM)
self._connection.receive_datagram(datagram, self._peer[0], time.time())
self._worker_scope = None
self._handle_timer(expiration)
datagrams = self._connection.datagrams_to_send(time.time())
for (datagram, _) in datagrams:
await self._socket.send(datagram)
await self._handle_events()
async def _handle_events(self):
count = 0
while True:
event = self._connection.next_event()
if event is None:
return
if isinstance(event, aioquic.quic.events.StreamDataReceived):
stream = self._streams.get(event.stream_id)
if stream:
await stream._add_input(event.data, event.end_stream)
elif isinstance(event, aioquic.quic.events.HandshakeCompleted):
self._handshake_complete.set()
elif isinstance(
event, aioquic.quic.events.ConnectionTerminated
) or isinstance(event, aioquic.quic.events.StreamReset):
self._done = True
self._socket.close()
count += 1
if count > 10:
# yield
count = 0
await trio.sleep(0)
async def write(self, stream, data, is_end=False):
self._connection.send_stream_data(stream, data, is_end)
if self._worker_scope is not None:
self._worker_scope.cancel()
async def run(self):
if self._closed:
return
async with trio.open_nursery() as nursery:
nursery.start_soon(self._worker)
self._run_done.set()
async def make_stream(self):
await self._handshake_complete.wait()
stream_id = self._connection.get_next_available_stream_id(False)
stream = TrioQuicStream(self, stream_id)
self._streams[stream_id] = stream
return stream
async def close(self):
if not self._closed:
self._manager.closed(self._peer[0], self._peer[1])
self._closed = True
self._connection.close()
if self._worker_scope is not None:
self._worker_scope.cancel()
await self._run_done.wait()
class TrioQuicManager(AsyncQuicManager):
def __init__(self, nursery, conf=None, verify_mode=ssl.CERT_REQUIRED):
super().__init__(conf, verify_mode, TrioQuicConnection)
self._nursery = nursery
def connect(self, address, port=853, source=None, source_port=0):
(connection, start) = self._connect(address, port, source, source_port)
if start:
self._nursery.start_soon(connection.run)
return connection
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
# Copy the itertor into a list as exiting things will mutate the connections
# table.
connections = list(self._connections.values())
for connection in connections:
await connection.close()
return False

View file

@ -17,9 +17,12 @@
"""DNS Result Codes."""
from typing import Tuple
import dns.enum
import dns.exception
class Rcode(dns.enum.IntEnum):
#: No error
NOERROR = 0
@ -77,20 +80,20 @@ class UnknownRcode(dns.exception.DNSException):
"""A DNS rcode is unknown."""
def from_text(text):
def from_text(text: str) -> Rcode:
"""Convert text into an rcode.
*text*, a ``str``, the textual rcode or an integer in textual form.
Raises ``dns.rcode.UnknownRcode`` if the rcode mnemonic is unknown.
Returns an ``int``.
Returns a ``dns.rcode.Rcode``.
"""
return Rcode.from_text(text)
def from_flags(flags, ednsflags):
def from_flags(flags: int, ednsflags: int) -> Rcode:
"""Return the rcode value encoded by flags and ednsflags.
*flags*, an ``int``, the DNS flags field.
@ -99,17 +102,17 @@ def from_flags(flags, ednsflags):
Raises ``ValueError`` if rcode is < 0 or > 4095
Returns an ``int``.
Returns a ``dns.rcode.Rcode``.
"""
value = (flags & 0x000f) | ((ednsflags >> 20) & 0xff0)
return value
value = (flags & 0x000F) | ((ednsflags >> 20) & 0xFF0)
return Rcode.make(value)
def to_flags(value):
def to_flags(value: Rcode) -> Tuple[int, int]:
"""Return a (flags, ednsflags) tuple which encodes the rcode.
*value*, an ``int``, the rcode.
*value*, a ``dns.rcode.Rcode``, the rcode.
Raises ``ValueError`` if rcode is < 0 or > 4095.
@ -117,16 +120,16 @@ def to_flags(value):
"""
if value < 0 or value > 4095:
raise ValueError('rcode must be >= 0 and <= 4095')
v = value & 0xf
ev = (value & 0xff0) << 20
raise ValueError("rcode must be >= 0 and <= 4095")
v = value & 0xF
ev = (value & 0xFF0) << 20
return (v, ev)
def to_text(value, tsig=False):
def to_text(value: Rcode, tsig: bool = False) -> str:
"""Convert rcode into text.
*value*, an ``int``, the rcode.
*value*, a ``dns.rcode.Rcode``, the rcode.
Raises ``ValueError`` if rcode is < 0 or > 4095.
@ -134,9 +137,10 @@ def to_text(value, tsig=False):
"""
if tsig and value == Rcode.BADVERS:
return 'BADSIG'
return "BADSIG"
return Rcode.to_text(value)
### BEGIN generated Rcode constants
NOERROR = Rcode.NOERROR

View file

@ -17,6 +17,8 @@
"""DNS rdata."""
from typing import Any, Dict, Optional, Tuple, Union
from importlib import import_module
import base64
import binascii
@ -55,21 +57,22 @@ class NoRelativeRdataOrdering(dns.exception.DNSException):
"""
def _wordbreak(data, chunksize=_chunksize, separator=b' '):
def _wordbreak(data, chunksize=_chunksize, separator=b" "):
"""Break a binary string into chunks of chunksize characters separated by
a space.
"""
if not chunksize:
return data.decode()
return separator.join([data[i:i + chunksize]
for i
in range(0, len(data), chunksize)]).decode()
return separator.join(
[data[i : i + chunksize] for i in range(0, len(data), chunksize)]
).decode()
# pylint: disable=unused-argument
def _hexify(data, chunksize=_chunksize, separator=b' ', **kw):
def _hexify(data, chunksize=_chunksize, separator=b" ", **kw):
"""Convert a binary string into its hex encoding, broken up into chunks
of chunksize characters separated by a separator.
"""
@ -77,17 +80,19 @@ def _hexify(data, chunksize=_chunksize, separator=b' ', **kw):
return _wordbreak(binascii.hexlify(data), chunksize, separator)
def _base64ify(data, chunksize=_chunksize, separator=b' ', **kw):
def _base64ify(data, chunksize=_chunksize, separator=b" ", **kw):
"""Convert a binary string into its base64 encoding, broken up into chunks
of chunksize characters separated by a separator.
"""
return _wordbreak(base64.b64encode(data), chunksize, separator)
# pylint: enable=unused-argument
__escaped = b'"\\'
def _escapify(qstring):
"""Escape the characters in a quoted string which need it."""
@ -96,14 +101,14 @@ def _escapify(qstring):
if not isinstance(qstring, bytearray):
qstring = bytearray(qstring)
text = ''
text = ""
for c in qstring:
if c in __escaped:
text += '\\' + chr(c)
text += "\\" + chr(c)
elif c >= 0x20 and c < 0x7F:
text += chr(c)
else:
text += '\\%03d' % c
text += "\\%03d" % c
return text
@ -114,9 +119,10 @@ def _truncate_bitmap(what):
for i in range(len(what) - 1, -1, -1):
if what[i] != 0:
return what[0: i + 1]
return what[0 : i + 1]
return what[0:1]
# So we don't have to edit all the rdata classes...
_constify = dns.immutable.constify
@ -125,7 +131,7 @@ _constify = dns.immutable.constify
class Rdata:
"""Base class for all DNS rdata types."""
__slots__ = ['rdclass', 'rdtype', 'rdcomment']
__slots__ = ["rdclass", "rdtype", "rdcomment"]
def __init__(self, rdclass, rdtype):
"""Initialize an rdata.
@ -140,8 +146,9 @@ class Rdata:
self.rdcomment = None
def _get_all_slots(self):
return itertools.chain.from_iterable(getattr(cls, '__slots__', [])
for cls in self.__class__.__mro__)
return itertools.chain.from_iterable(
getattr(cls, "__slots__", []) for cls in self.__class__.__mro__
)
def __getstate__(self):
# We used to try to do a tuple of all slots here, but it
@ -160,12 +167,12 @@ class Rdata:
def __setstate__(self, state):
for slot, val in state.items():
object.__setattr__(self, slot, val)
if not hasattr(self, 'rdcomment'):
if not hasattr(self, "rdcomment"):
# Pickled rdata from 2.0.x might not have a rdcomment, so add
# it if needed.
object.__setattr__(self, 'rdcomment', None)
object.__setattr__(self, "rdcomment", None)
def covers(self):
def covers(self) -> dns.rdatatype.RdataType:
"""Return the type a Rdata covers.
DNS SIG/RRSIG rdatas apply to a specific type; this type is
@ -174,12 +181,12 @@ class Rdata:
creating rdatasets, allowing the rdataset to contain only RRSIGs
of a particular type, e.g. RRSIG(NS).
Returns an ``int``.
Returns a ``dns.rdatatype.RdataType``.
"""
return dns.rdatatype.NONE
def extended_rdatatype(self):
def extended_rdatatype(self) -> int:
"""Return a 32-bit type value, the least significant 16 bits of
which are the ordinary DNS type, and the upper 16 bits of which are
the "covered" type, if any.
@ -189,7 +196,12 @@ class Rdata:
return self.covers() << 16 | self.rdtype
def to_text(self, origin=None, relativize=True, **kw):
def to_text(
self,
origin: Optional[dns.name.Name] = None,
relativize: bool = True,
**kw: Dict[str, Any]
) -> str:
"""Convert an rdata to text format.
Returns a ``str``.
@ -197,11 +209,22 @@ class Rdata:
raise NotImplementedError # pragma: no cover
def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
def _to_wire(
self,
file: Optional[Any],
compress: Optional[dns.name.CompressType] = None,
origin: Optional[dns.name.Name] = None,
canonicalize: bool = False,
) -> bytes:
raise NotImplementedError # pragma: no cover
def to_wire(self, file=None, compress=None, origin=None,
canonicalize=False):
def to_wire(
self,
file: Optional[Any] = None,
compress: Optional[dns.name.CompressType] = None,
origin: Optional[dns.name.Name] = None,
canonicalize: bool = False,
) -> bytes:
"""Convert an rdata to wire format.
Returns a ``bytes`` or ``None``.
@ -214,15 +237,18 @@ class Rdata:
self._to_wire(f, compress, origin, canonicalize)
return f.getvalue()
def to_generic(self, origin=None):
def to_generic(
self, origin: Optional[dns.name.Name] = None
) -> "dns.rdata.GenericRdata":
"""Creates a dns.rdata.GenericRdata equivalent of this rdata.
Returns a ``dns.rdata.GenericRdata``.
"""
return dns.rdata.GenericRdata(self.rdclass, self.rdtype,
self.to_wire(origin=origin))
return dns.rdata.GenericRdata(
self.rdclass, self.rdtype, self.to_wire(origin=origin)
)
def to_digestable(self, origin=None):
def to_digestable(self, origin: Optional[dns.name.Name] = None) -> bytes:
"""Convert rdata to a format suitable for digesting in hashes. This
is also the DNSSEC canonical form.
@ -234,12 +260,19 @@ class Rdata:
def __repr__(self):
covers = self.covers()
if covers == dns.rdatatype.NONE:
ctext = ''
ctext = ""
else:
ctext = '(' + dns.rdatatype.to_text(covers) + ')'
return '<DNS ' + dns.rdataclass.to_text(self.rdclass) + ' ' + \
dns.rdatatype.to_text(self.rdtype) + ctext + ' rdata: ' + \
str(self) + '>'
ctext = "(" + dns.rdatatype.to_text(covers) + ")"
return (
"<DNS "
+ dns.rdataclass.to_text(self.rdclass)
+ " "
+ dns.rdatatype.to_text(self.rdtype)
+ ctext
+ " rdata: "
+ str(self)
+ ">"
)
def __str__(self):
return self.to_text()
@ -320,27 +353,39 @@ class Rdata:
return not self.__eq__(other)
def __lt__(self, other):
if not isinstance(other, Rdata) or \
self.rdclass != other.rdclass or self.rdtype != other.rdtype:
if (
not isinstance(other, Rdata)
or self.rdclass != other.rdclass
or self.rdtype != other.rdtype
):
return NotImplemented
return self._cmp(other) < 0
def __le__(self, other):
if not isinstance(other, Rdata) or \
self.rdclass != other.rdclass or self.rdtype != other.rdtype:
if (
not isinstance(other, Rdata)
or self.rdclass != other.rdclass
or self.rdtype != other.rdtype
):
return NotImplemented
return self._cmp(other) <= 0
def __ge__(self, other):
if not isinstance(other, Rdata) or \
self.rdclass != other.rdclass or self.rdtype != other.rdtype:
if (
not isinstance(other, Rdata)
or self.rdclass != other.rdclass
or self.rdtype != other.rdtype
):
return NotImplemented
return self._cmp(other) >= 0
def __gt__(self, other):
if not isinstance(other, Rdata) or \
self.rdclass != other.rdclass or self.rdtype != other.rdtype:
if (
not isinstance(other, Rdata)
or self.rdclass != other.rdclass
or self.rdtype != other.rdtype
):
return NotImplemented
return self._cmp(other) > 0
@ -348,15 +393,28 @@ class Rdata:
return hash(self.to_digestable(dns.name.root))
@classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
relativize_to=None):
def from_text(
cls,
rdclass: dns.rdataclass.RdataClass,
rdtype: dns.rdatatype.RdataType,
tok: dns.tokenizer.Tokenizer,
origin: Optional[dns.name.Name] = None,
relativize: bool = True,
relativize_to: Optional[dns.name.Name] = None,
) -> "Rdata":
raise NotImplementedError # pragma: no cover
@classmethod
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
def from_wire_parser(
cls,
rdclass: dns.rdataclass.RdataClass,
rdtype: dns.rdatatype.RdataType,
parser: dns.wire.Parser,
origin: Optional[dns.name.Name] = None,
) -> "Rdata":
raise NotImplementedError # pragma: no cover
def replace(self, **kwargs):
def replace(self, **kwargs: Any) -> "Rdata":
"""
Create a new Rdata instance based on the instance replace was
invoked on. It is possible to pass different parameters to
@ -369,19 +427,25 @@ class Rdata:
"""
# Get the constructor parameters.
parameters = inspect.signature(self.__init__).parameters
parameters = inspect.signature(self.__init__).parameters # type: ignore
# Ensure that all of the arguments correspond to valid fields.
# Don't allow rdclass or rdtype to be changed, though.
for key in kwargs:
if key == 'rdcomment':
if key == "rdcomment":
continue
if key not in parameters:
raise AttributeError("'{}' object has no attribute '{}'"
.format(self.__class__.__name__, key))
if key in ('rdclass', 'rdtype'):
raise AttributeError("Cannot overwrite '{}' attribute '{}'"
.format(self.__class__.__name__, key))
raise AttributeError(
"'{}' object has no attribute '{}'".format(
self.__class__.__name__, key
)
)
if key in ("rdclass", "rdtype"):
raise AttributeError(
"Cannot overwrite '{}' attribute '{}'".format(
self.__class__.__name__, key
)
)
# Construct the parameter list. For each field, use the value in
# kwargs if present, and the current value otherwise.
@ -391,9 +455,9 @@ class Rdata:
rd = self.__class__(*args)
# The comment is not set in the constructor, so give it special
# handling.
rdcomment = kwargs.get('rdcomment', self.rdcomment)
rdcomment = kwargs.get("rdcomment", self.rdcomment)
if rdcomment is not None:
object.__setattr__(rd, 'rdcomment', rdcomment)
object.__setattr__(rd, "rdcomment", rdcomment)
return rd
# Type checking and conversion helpers. These are class methods as
@ -408,18 +472,26 @@ class Rdata:
return dns.rdatatype.RdataType.make(value)
@classmethod
def _as_bytes(cls, value, encode=False, max_length=None, empty_ok=True):
def _as_bytes(
cls,
value: Any,
encode: bool = False,
max_length: Optional[int] = None,
empty_ok: bool = True,
) -> bytes:
if encode and isinstance(value, str):
value = value.encode()
bvalue = value.encode()
elif isinstance(value, bytearray):
value = bytes(value)
elif not isinstance(value, bytes):
raise ValueError('not bytes')
if max_length is not None and len(value) > max_length:
raise ValueError('too long')
if not empty_ok and len(value) == 0:
raise ValueError('empty bytes not allowed')
return value
bvalue = bytes(value)
elif isinstance(value, bytes):
bvalue = value
else:
raise ValueError("not bytes")
if max_length is not None and len(bvalue) > max_length:
raise ValueError("too long")
if not empty_ok and len(bvalue) == 0:
raise ValueError("empty bytes not allowed")
return bvalue
@classmethod
def _as_name(cls, value):
@ -429,49 +501,49 @@ class Rdata:
if isinstance(value, str):
return dns.name.from_text(value)
elif not isinstance(value, dns.name.Name):
raise ValueError('not a name')
raise ValueError("not a name")
return value
@classmethod
def _as_uint8(cls, value):
if not isinstance(value, int):
raise ValueError('not an integer')
raise ValueError("not an integer")
if value < 0 or value > 255:
raise ValueError('not a uint8')
raise ValueError("not a uint8")
return value
@classmethod
def _as_uint16(cls, value):
if not isinstance(value, int):
raise ValueError('not an integer')
raise ValueError("not an integer")
if value < 0 or value > 65535:
raise ValueError('not a uint16')
raise ValueError("not a uint16")
return value
@classmethod
def _as_uint32(cls, value):
if not isinstance(value, int):
raise ValueError('not an integer')
raise ValueError("not an integer")
if value < 0 or value > 4294967295:
raise ValueError('not a uint32')
raise ValueError("not a uint32")
return value
@classmethod
def _as_uint48(cls, value):
if not isinstance(value, int):
raise ValueError('not an integer')
raise ValueError("not an integer")
if value < 0 or value > 281474976710655:
raise ValueError('not a uint48')
raise ValueError("not a uint48")
return value
@classmethod
def _as_int(cls, value, low=None, high=None):
if not isinstance(value, int):
raise ValueError('not an integer')
raise ValueError("not an integer")
if low is not None and value < low:
raise ValueError('value too small')
raise ValueError("value too small")
if high is not None and value > high:
raise ValueError('value too large')
raise ValueError("value too large")
return value
@classmethod
@ -483,7 +555,7 @@ class Rdata:
elif isinstance(value, bytes):
return dns.ipv4.inet_ntoa(value)
else:
raise ValueError('not an IPv4 address')
raise ValueError("not an IPv4 address")
@classmethod
def _as_ipv6_address(cls, value):
@ -494,14 +566,14 @@ class Rdata:
elif isinstance(value, bytes):
return dns.ipv6.inet_ntoa(value)
else:
raise ValueError('not an IPv6 address')
raise ValueError("not an IPv6 address")
@classmethod
def _as_bool(cls, value):
if isinstance(value, bool):
return value
else:
raise ValueError('not a boolean')
raise ValueError("not a boolean")
@classmethod
def _as_ttl(cls, value):
@ -510,7 +582,7 @@ class Rdata:
elif isinstance(value, str):
return dns.ttl.from_text(value)
else:
raise ValueError('not a TTL')
raise ValueError("not a TTL")
@classmethod
def _as_tuple(cls, value, as_value):
@ -532,6 +604,7 @@ class Rdata:
return items
@dns.immutable.immutable
class GenericRdata(Rdata):
"""Generic Rdata Class
@ -540,28 +613,32 @@ class GenericRdata(Rdata):
implementation. It implements the DNS "unknown RRs" scheme.
"""
__slots__ = ['data']
__slots__ = ["data"]
def __init__(self, rdclass, rdtype, data):
super().__init__(rdclass, rdtype)
object.__setattr__(self, 'data', data)
self.data = data
def to_text(self, origin=None, relativize=True, **kw):
return r'\# %d ' % len(self.data) + _hexify(self.data, **kw)
def to_text(
self,
origin: Optional[dns.name.Name] = None,
relativize: bool = True,
**kw: Dict[str, Any]
) -> str:
return r"\# %d " % len(self.data) + _hexify(self.data, **kw)
@classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
relativize_to=None):
def from_text(
cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
token = tok.get()
if not token.is_identifier() or token.value != r'\#':
raise dns.exception.SyntaxError(
r'generic rdata does not start with \#')
if not token.is_identifier() or token.value != r"\#":
raise dns.exception.SyntaxError(r"generic rdata does not start with \#")
length = tok.get_int()
hex = tok.concatenate_remaining_identifiers(True).encode()
data = binascii.unhexlify(hex)
if len(data) != length:
raise dns.exception.SyntaxError(
'generic rdata hex data has wrong length')
raise dns.exception.SyntaxError("generic rdata hex data has wrong length")
return cls(rdclass, rdtype, data)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
@ -571,8 +648,12 @@ class GenericRdata(Rdata):
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
return cls(rdclass, rdtype, parser.get_remaining())
_rdata_classes = {}
_module_prefix = 'dns.rdtypes'
_rdata_classes: Dict[
Tuple[dns.rdataclass.RdataClass, dns.rdatatype.RdataType], Any
] = {}
_module_prefix = "dns.rdtypes"
def get_rdata_class(rdclass, rdtype):
cls = _rdata_classes.get((rdclass, rdtype))
@ -581,16 +662,16 @@ def get_rdata_class(rdclass, rdtype):
if not cls:
rdclass_text = dns.rdataclass.to_text(rdclass)
rdtype_text = dns.rdatatype.to_text(rdtype)
rdtype_text = rdtype_text.replace('-', '_')
rdtype_text = rdtype_text.replace("-", "_")
try:
mod = import_module('.'.join([_module_prefix,
rdclass_text, rdtype_text]))
mod = import_module(
".".join([_module_prefix, rdclass_text, rdtype_text])
)
cls = getattr(mod, rdtype_text)
_rdata_classes[(rdclass, rdtype)] = cls
except ImportError:
try:
mod = import_module('.'.join([_module_prefix,
'ANY', rdtype_text]))
mod = import_module(".".join([_module_prefix, "ANY", rdtype_text]))
cls = getattr(mod, rdtype_text)
_rdata_classes[(dns.rdataclass.ANY, rdtype)] = cls
_rdata_classes[(rdclass, rdtype)] = cls
@ -602,8 +683,15 @@ def get_rdata_class(rdclass, rdtype):
return cls
def from_text(rdclass, rdtype, tok, origin=None, relativize=True,
relativize_to=None, idna_codec=None):
def from_text(
rdclass: Union[dns.rdataclass.RdataClass, str],
rdtype: Union[dns.rdatatype.RdataType, str],
tok: Union[dns.tokenizer.Tokenizer, str],
origin: Optional[dns.name.Name] = None,
relativize: bool = True,
relativize_to: Optional[dns.name.Name] = None,
idna_codec: Optional[dns.name.IDNACodec] = None,
) -> Rdata:
"""Build an rdata object from text format.
This function attempts to dynamically load a class which
@ -617,9 +705,9 @@ def from_text(rdclass, rdtype, tok, origin=None, relativize=True,
If *tok* is a ``str``, then a tokenizer is created and the string
is used as its input.
*rdclass*, an ``int``, the rdataclass.
*rdclass*, a ``dns.rdataclass.RdataClass`` or ``str``, the rdataclass.
*rdtype*, an ``int``, the rdatatype.
*rdtype*, a ``dns.rdatatype.RdataType`` or ``str``, the rdatatype.
*tok*, a ``dns.tokenizer.Tokenizer`` or a ``str``.
@ -651,17 +739,18 @@ def from_text(rdclass, rdtype, tok, origin=None, relativize=True,
# peek at first token
token = tok.get()
tok.unget(token)
if token.is_identifier() and \
token.value == r'\#':
if token.is_identifier() and token.value == r"\#":
#
# Known type using the generic syntax. Extract the
# wire form from the generic syntax, and then run
# from_wire on it.
#
grdata = GenericRdata.from_text(rdclass, rdtype, tok, origin,
relativize, relativize_to)
rdata = from_wire(rdclass, rdtype, grdata.data, 0,
len(grdata.data), origin)
grdata = GenericRdata.from_text(
rdclass, rdtype, tok, origin, relativize, relativize_to
)
rdata = from_wire(
rdclass, rdtype, grdata.data, 0, len(grdata.data), origin
)
#
# If this comparison isn't equal, then there must have been
# compressed names in the wire format, which is an error,
@ -669,19 +758,27 @@ def from_text(rdclass, rdtype, tok, origin=None, relativize=True,
#
rwire = rdata.to_wire()
if rwire != grdata.data:
raise dns.exception.SyntaxError('compressed data in '
'generic syntax form '
'of known rdatatype')
raise dns.exception.SyntaxError(
"compressed data in "
"generic syntax form "
"of known rdatatype"
)
if rdata is None:
rdata = cls.from_text(rdclass, rdtype, tok, origin, relativize,
relativize_to)
rdata = cls.from_text(
rdclass, rdtype, tok, origin, relativize, relativize_to
)
token = tok.get_eol_as_token()
if token.comment is not None:
object.__setattr__(rdata, 'rdcomment', token.comment)
object.__setattr__(rdata, "rdcomment", token.comment)
return rdata
def from_wire_parser(rdclass, rdtype, parser, origin=None):
def from_wire_parser(
rdclass: Union[dns.rdataclass.RdataClass, str],
rdtype: Union[dns.rdatatype.RdataType, str],
parser: dns.wire.Parser,
origin: Optional[dns.name.Name] = None,
) -> Rdata:
"""Build an rdata object from wire format
This function attempts to dynamically load a class which
@ -692,9 +789,9 @@ def from_wire_parser(rdclass, rdtype, parser, origin=None):
Once a class is chosen, its from_wire() class method is called
with the parameters to this function.
*rdclass*, an ``int``, the rdataclass.
*rdclass*, a ``dns.rdataclass.RdataClass`` or ``str``, the rdataclass.
*rdtype*, an ``int``, the rdatatype.
*rdtype*, a ``dns.rdatatype.RdataType`` or ``str``, the rdatatype.
*parser*, a ``dns.wire.Parser``, the parser, which should be
restricted to the rdata length.
@ -712,7 +809,14 @@ def from_wire_parser(rdclass, rdtype, parser, origin=None):
return cls.from_wire_parser(rdclass, rdtype, parser, origin)
def from_wire(rdclass, rdtype, wire, current, rdlen, origin=None):
def from_wire(
rdclass: Union[dns.rdataclass.RdataClass, str],
rdtype: Union[dns.rdatatype.RdataType, str],
wire: bytes,
current: int,
rdlen: int,
origin: Optional[dns.name.Name] = None,
) -> Rdata:
"""Build an rdata object from wire format
This function attempts to dynamically load a class which
@ -746,13 +850,21 @@ def from_wire(rdclass, rdtype, wire, current, rdlen, origin=None):
class RdatatypeExists(dns.exception.DNSException):
"""DNS rdatatype already exists."""
supp_kwargs = {'rdclass', 'rdtype'}
fmt = "The rdata type with class {rdclass:d} and rdtype {rdtype:d} " + \
"already exists."
supp_kwargs = {"rdclass", "rdtype"}
fmt = (
"The rdata type with class {rdclass:d} and rdtype {rdtype:d} "
+ "already exists."
)
def register_type(implementation, rdtype, rdtype_text, is_singleton=False,
rdclass=dns.rdataclass.IN):
def register_type(
implementation: Any,
rdtype: int,
rdtype_text: str,
is_singleton: bool = False,
rdclass: dns.rdataclass.RdataClass = dns.rdataclass.IN,
) -> None:
"""Dynamically register a module to handle an rdatatype.
*implementation*, a module implementing the type in the usual dnspython
@ -769,14 +881,16 @@ def register_type(implementation, rdtype, rdtype_text, is_singleton=False,
it applies to all classes.
"""
existing_cls = get_rdata_class(rdclass, rdtype)
if existing_cls != GenericRdata or dns.rdatatype.is_metatype(rdtype):
raise RdatatypeExists(rdclass=rdclass, rdtype=rdtype)
the_rdtype = dns.rdatatype.RdataType.make(rdtype)
existing_cls = get_rdata_class(rdclass, the_rdtype)
if existing_cls != GenericRdata or dns.rdatatype.is_metatype(the_rdtype):
raise RdatatypeExists(rdclass=rdclass, rdtype=the_rdtype)
try:
if dns.rdatatype.RdataType(rdtype).name != rdtype_text:
raise RdatatypeExists(rdclass=rdclass, rdtype=rdtype)
if dns.rdatatype.RdataType(the_rdtype).name != rdtype_text:
raise RdatatypeExists(rdclass=rdclass, rdtype=the_rdtype)
except ValueError:
pass
_rdata_classes[(rdclass, rdtype)] = getattr(implementation,
rdtype_text.replace('-', '_'))
dns.rdatatype.register_type(rdtype, rdtype_text, is_singleton)
_rdata_classes[(rdclass, the_rdtype)] = getattr(
implementation, rdtype_text.replace("-", "_")
)
dns.rdatatype.register_type(the_rdtype, rdtype_text, is_singleton)

View file

@ -1,19 +0,0 @@
from typing import Dict, Tuple, Any, Optional, BinaryIO
from .name import Name, IDNACodec
class Rdata:
def __init__(self):
self.address : str
def to_wire(self, file : Optional[BinaryIO], compress : Optional[Dict[Name,int]], origin : Optional[Name], canonicalize : Optional[bool]) -> Optional[bytes]:
...
@classmethod
def from_text(cls, rdclass : int, rdtype : int, tok, origin=None, relativize=True):
...
_rdata_modules : Dict[Tuple[Any,Rdata],Any]
def from_text(rdclass : int, rdtype : int, tok : Optional[str], origin : Optional[Name] = None,
relativize : bool = True, relativize_to : Optional[Name] = None,
idna_codec : Optional[IDNACodec] = None):
...
def from_wire(rdclass : int, rdtype : int, wire : bytes, current : int, rdlen : int, origin : Optional[Name] = None):
...

View file

@ -20,8 +20,10 @@
import dns.enum
import dns.exception
class RdataClass(dns.enum.IntEnum):
"""DNS Rdata Class"""
RESERVED0 = 0
IN = 1
INTERNET = IN
@ -56,7 +58,7 @@ class UnknownRdataclass(dns.exception.DNSException):
"""A DNS class is unknown."""
def from_text(text):
def from_text(text: str) -> RdataClass:
"""Convert text into a DNS rdata class value.
The input text can be a defined DNS RR class mnemonic or
@ -68,13 +70,13 @@ def from_text(text):
Raises ``ValueError`` if the rdata class value is not >= 0 and <= 65535.
Returns an ``int``.
Returns a ``dns.rdataclass.RdataClass``.
"""
return RdataClass.from_text(text)
def to_text(value):
def to_text(value: RdataClass) -> str:
"""Convert a DNS rdata class value to text.
If the value has a known mnemonic, it will be used, otherwise the
@ -88,18 +90,19 @@ def to_text(value):
return RdataClass.to_text(value)
def is_metaclass(rdclass):
def is_metaclass(rdclass: RdataClass) -> bool:
"""True if the specified class is a metaclass.
The currently defined metaclasses are ANY and NONE.
*rdclass* is an ``int``.
*rdclass* is a ``dns.rdataclass.RdataClass``.
"""
if rdclass in _metaclasses:
return True
return False
### BEGIN generated RdataClass constants
RESERVED0 = RdataClass.RESERVED0

View file

@ -17,16 +17,20 @@
"""DNS rdatasets (an rdataset is a set of rdatas of a given type and class)"""
from typing import Any, cast, Collection, Dict, List, Optional, Union
import io
import random
import struct
import dns.exception
import dns.immutable
import dns.name
import dns.rdatatype
import dns.rdataclass
import dns.rdata
import dns.set
import dns.ttl
# define SimpleSet here for backwards compatibility
SimpleSet = dns.set.Set
@ -45,24 +49,30 @@ class Rdataset(dns.set.Set):
"""A DNS rdataset."""
__slots__ = ['rdclass', 'rdtype', 'covers', 'ttl']
__slots__ = ["rdclass", "rdtype", "covers", "ttl"]
def __init__(self, rdclass, rdtype, covers=dns.rdatatype.NONE, ttl=0):
def __init__(
self,
rdclass: dns.rdataclass.RdataClass,
rdtype: dns.rdatatype.RdataType,
covers: dns.rdatatype.RdataType = dns.rdatatype.NONE,
ttl: int = 0,
):
"""Create a new rdataset of the specified class and type.
*rdclass*, an ``int``, the rdataclass.
*rdclass*, a ``dns.rdataclass.RdataClass``, the rdataclass.
*rdtype*, an ``int``, the rdatatype.
*rdtype*, an ``dns.rdatatype.RdataType``, the rdatatype.
*covers*, an ``int``, the covered rdatatype.
*covers*, an ``dns.rdatatype.RdataType``, the covered rdatatype.
*ttl*, an ``int``, the TTL.
"""
super().__init__()
self.rdclass = rdclass
self.rdtype = rdtype
self.covers = covers
self.rdtype: dns.rdatatype.RdataType = rdtype
self.covers: dns.rdatatype.RdataType = covers
self.ttl = ttl
def _clone(self):
@ -73,7 +83,7 @@ class Rdataset(dns.set.Set):
obj.ttl = self.ttl
return obj
def update_ttl(self, ttl):
def update_ttl(self, ttl: int) -> None:
"""Perform TTL minimization.
Set the TTL of the rdataset to be the lesser of the set's current
@ -88,7 +98,9 @@ class Rdataset(dns.set.Set):
elif ttl < self.ttl:
self.ttl = ttl
def add(self, rd, ttl=None): # pylint: disable=arguments-differ
def add( # pylint: disable=arguments-differ,arguments-renamed
self, rd: dns.rdata.Rdata, ttl: Optional[int] = None
) -> None:
"""Add the specified rdata to the rdataset.
If the optional *ttl* parameter is supplied, then
@ -115,8 +127,7 @@ class Rdataset(dns.set.Set):
raise IncompatibleTypes
if ttl is not None:
self.update_ttl(ttl)
if self.rdtype == dns.rdatatype.RRSIG or \
self.rdtype == dns.rdatatype.SIG:
if self.rdtype == dns.rdatatype.RRSIG or self.rdtype == dns.rdatatype.SIG:
covers = rd.covers()
if len(self) == 0 and self.covers == dns.rdatatype.NONE:
self.covers = covers
@ -147,19 +158,26 @@ class Rdataset(dns.set.Set):
def _rdata_repr(self):
def maybe_truncate(s):
if len(s) > 100:
return s[:100] + '...'
return s[:100] + "..."
return s
return '[%s]' % ', '.join('<%s>' % maybe_truncate(str(rr))
for rr in self)
return "[%s]" % ", ".join("<%s>" % maybe_truncate(str(rr)) for rr in self)
def __repr__(self):
if self.covers == 0:
ctext = ''
ctext = ""
else:
ctext = '(' + dns.rdatatype.to_text(self.covers) + ')'
return '<DNS ' + dns.rdataclass.to_text(self.rdclass) + ' ' + \
dns.rdatatype.to_text(self.rdtype) + ctext + \
' rdataset: ' + self._rdata_repr() + '>'
ctext = "(" + dns.rdatatype.to_text(self.covers) + ")"
return (
"<DNS "
+ dns.rdataclass.to_text(self.rdclass)
+ " "
+ dns.rdatatype.to_text(self.rdtype)
+ ctext
+ " rdataset: "
+ self._rdata_repr()
+ ">"
)
def __str__(self):
return self.to_text()
@ -167,17 +185,26 @@ class Rdataset(dns.set.Set):
def __eq__(self, other):
if not isinstance(other, Rdataset):
return False
if self.rdclass != other.rdclass or \
self.rdtype != other.rdtype or \
self.covers != other.covers:
if (
self.rdclass != other.rdclass
or self.rdtype != other.rdtype
or self.covers != other.covers
):
return False
return super().__eq__(other)
def __ne__(self, other):
return not self.__eq__(other)
def to_text(self, name=None, origin=None, relativize=True,
override_rdclass=None, want_comments=False, **kw):
def to_text(
self,
name: Optional[dns.name.Name] = None,
origin: Optional[dns.name.Name] = None,
relativize: bool = True,
override_rdclass: Optional[dns.rdataclass.RdataClass] = None,
want_comments: bool = False,
**kw: Dict[str, Any],
) -> str:
"""Convert the rdataset into DNS zone file format.
See ``dns.name.Name.choose_relativity`` for more information
@ -206,10 +233,10 @@ class Rdataset(dns.set.Set):
if name is not None:
name = name.choose_relativity(origin, relativize)
ntext = str(name)
pad = ' '
pad = " "
else:
ntext = ''
pad = ''
ntext = ""
pad = ""
s = io.StringIO()
if override_rdclass is not None:
rdclass = override_rdclass
@ -221,28 +248,46 @@ class Rdataset(dns.set.Set):
# some dynamic updates, so we don't need to print out the TTL
# (which is meaningless anyway).
#
s.write('{}{}{} {}\n'.format(ntext, pad,
dns.rdataclass.to_text(rdclass),
dns.rdatatype.to_text(self.rdtype)))
s.write(
"{}{}{} {}\n".format(
ntext,
pad,
dns.rdataclass.to_text(rdclass),
dns.rdatatype.to_text(self.rdtype),
)
)
else:
for rd in self:
extra = ''
extra = ""
if want_comments:
if rd.rdcomment:
extra = f' ;{rd.rdcomment}'
s.write('%s%s%d %s %s %s%s\n' %
(ntext, pad, self.ttl, dns.rdataclass.to_text(rdclass),
dns.rdatatype.to_text(self.rdtype),
rd.to_text(origin=origin, relativize=relativize,
**kw),
extra))
extra = f" ;{rd.rdcomment}"
s.write(
"%s%s%d %s %s %s%s\n"
% (
ntext,
pad,
self.ttl,
dns.rdataclass.to_text(rdclass),
dns.rdatatype.to_text(self.rdtype),
rd.to_text(origin=origin, relativize=relativize, **kw),
extra,
)
)
#
# We strip off the final \n for the caller's convenience in printing
#
return s.getvalue()[:-1]
def to_wire(self, name, file, compress=None, origin=None,
override_rdclass=None, want_shuffle=True):
def to_wire(
self,
name: dns.name.Name,
file: Any,
compress: Optional[dns.name.CompressType] = None,
origin: Optional[dns.name.Name] = None,
override_rdclass: Optional[dns.rdataclass.RdataClass] = None,
want_shuffle: bool = True,
) -> int:
"""Convert the rdataset to wire format.
*name*, a ``dns.name.Name`` is the owner name to use.
@ -279,6 +324,7 @@ class Rdataset(dns.set.Set):
file.write(stuff)
return 1
else:
l: Union[Rdataset, List[dns.rdata.Rdata]]
if want_shuffle:
l = list(self)
random.shuffle(l)
@ -286,8 +332,7 @@ class Rdataset(dns.set.Set):
l = self
for rd in l:
name.to_wire(file, compress, origin)
stuff = struct.pack("!HHIH", self.rdtype, rdclass,
self.ttl, 0)
stuff = struct.pack("!HHIH", self.rdtype, rdclass, self.ttl, 0)
file.write(stuff)
start = file.tell()
rd.to_wire(file, compress, origin)
@ -299,17 +344,20 @@ class Rdataset(dns.set.Set):
file.seek(0, io.SEEK_END)
return len(self)
def match(self, rdclass, rdtype, covers):
def match(
self,
rdclass: dns.rdataclass.RdataClass,
rdtype: dns.rdatatype.RdataType,
covers: dns.rdatatype.RdataType,
) -> bool:
"""Returns ``True`` if this rdataset matches the specified class,
type, and covers.
"""
if self.rdclass == rdclass and \
self.rdtype == rdtype and \
self.covers == covers:
if self.rdclass == rdclass and self.rdtype == rdtype and self.covers == covers:
return True
return False
def processing_order(self):
def processing_order(self) -> List[dns.rdata.Rdata]:
"""Return rdatas in a valid processing order according to the type's
specification. For example, MX records are in preference order from
lowest to highest preferences, with items of the same preference
@ -325,51 +373,56 @@ class Rdataset(dns.set.Set):
@dns.immutable.immutable
class ImmutableRdataset(Rdataset):
class ImmutableRdataset(Rdataset): # lgtm[py/missing-equals]
"""An immutable DNS rdataset."""
_clone_class = Rdataset
def __init__(self, rdataset):
def __init__(self, rdataset: Rdataset):
"""Create an immutable rdataset from the specified rdataset."""
super().__init__(rdataset.rdclass, rdataset.rdtype, rdataset.covers,
rdataset.ttl)
super().__init__(
rdataset.rdclass, rdataset.rdtype, rdataset.covers, rdataset.ttl
)
self.items = dns.immutable.Dict(rdataset.items)
def update_ttl(self, ttl):
raise TypeError('immutable')
raise TypeError("immutable")
def add(self, rd, ttl=None):
raise TypeError('immutable')
raise TypeError("immutable")
def union_update(self, other):
raise TypeError('immutable')
raise TypeError("immutable")
def intersection_update(self, other):
raise TypeError('immutable')
raise TypeError("immutable")
def update(self, other):
raise TypeError('immutable')
raise TypeError("immutable")
def __delitem__(self, i):
raise TypeError('immutable')
raise TypeError("immutable")
def __ior__(self, other):
raise TypeError('immutable')
# lgtm complains about these not raising ArithmeticError, but there is
# precedent for overrides of these methods in other classes to raise
# TypeError, and it seems like the better exception.
def __iand__(self, other):
raise TypeError('immutable')
def __ior__(self, other): # lgtm[py/unexpected-raise-in-special-method]
raise TypeError("immutable")
def __iadd__(self, other):
raise TypeError('immutable')
def __iand__(self, other): # lgtm[py/unexpected-raise-in-special-method]
raise TypeError("immutable")
def __isub__(self, other):
raise TypeError('immutable')
def __iadd__(self, other): # lgtm[py/unexpected-raise-in-special-method]
raise TypeError("immutable")
def __isub__(self, other): # lgtm[py/unexpected-raise-in-special-method]
raise TypeError("immutable")
def clear(self):
raise TypeError('immutable')
raise TypeError("immutable")
def __copy__(self):
return ImmutableRdataset(super().copy())
@ -386,9 +439,20 @@ class ImmutableRdataset(Rdataset):
def difference(self, other):
return ImmutableRdataset(super().difference(other))
def symmetric_difference(self, other):
return ImmutableRdataset(super().symmetric_difference(other))
def from_text_list(rdclass, rdtype, ttl, text_rdatas, idna_codec=None,
origin=None, relativize=True, relativize_to=None):
def from_text_list(
rdclass: Union[dns.rdataclass.RdataClass, str],
rdtype: Union[dns.rdatatype.RdataType, str],
ttl: int,
text_rdatas: Collection[str],
idna_codec: Optional[dns.name.IDNACodec] = None,
origin: Optional[dns.name.Name] = None,
relativize: bool = True,
relativize_to: Optional[dns.name.Name] = None,
) -> Rdataset:
"""Create an rdataset with the specified class, type, and TTL, and with
the specified list of rdatas in text format.
@ -407,28 +471,34 @@ def from_text_list(rdclass, rdtype, ttl, text_rdatas, idna_codec=None,
Returns a ``dns.rdataset.Rdataset`` object.
"""
rdclass = dns.rdataclass.RdataClass.make(rdclass)
rdtype = dns.rdatatype.RdataType.make(rdtype)
r = Rdataset(rdclass, rdtype)
the_rdclass = dns.rdataclass.RdataClass.make(rdclass)
the_rdtype = dns.rdatatype.RdataType.make(rdtype)
r = Rdataset(the_rdclass, the_rdtype)
r.update_ttl(ttl)
for t in text_rdatas:
rd = dns.rdata.from_text(r.rdclass, r.rdtype, t, origin, relativize,
relativize_to, idna_codec)
rd = dns.rdata.from_text(
r.rdclass, r.rdtype, t, origin, relativize, relativize_to, idna_codec
)
r.add(rd)
return r
def from_text(rdclass, rdtype, ttl, *text_rdatas):
def from_text(
rdclass: Union[dns.rdataclass.RdataClass, str],
rdtype: Union[dns.rdatatype.RdataType, str],
ttl: int,
*text_rdatas: Any,
) -> Rdataset:
"""Create an rdataset with the specified class, type, and TTL, and with
the specified rdatas in text format.
Returns a ``dns.rdataset.Rdataset`` object.
"""
return from_text_list(rdclass, rdtype, ttl, text_rdatas)
return from_text_list(rdclass, rdtype, ttl, cast(Collection[str], text_rdatas))
def from_rdata_list(ttl, rdatas):
def from_rdata_list(ttl: int, rdatas: Collection[dns.rdata.Rdata]) -> Rdataset:
"""Create an rdataset with the specified TTL, and with
the specified list of rdata objects.
@ -443,14 +513,15 @@ def from_rdata_list(ttl, rdatas):
r = Rdataset(rd.rdclass, rd.rdtype)
r.update_ttl(ttl)
r.add(rd)
assert r is not None
return r
def from_rdata(ttl, *rdatas):
def from_rdata(ttl: int, *rdatas: Any) -> Rdataset:
"""Create an rdataset with the specified TTL, and with
the specified rdata objects.
Returns a ``dns.rdataset.Rdataset`` object.
"""
return from_rdata_list(ttl, rdatas)
return from_rdata_list(ttl, cast(Collection[dns.rdata.Rdata], rdatas))

View file

@ -1,58 +0,0 @@
from typing import Optional, Dict, List, Union
from io import BytesIO
from . import exception, name, set, rdatatype, rdata, rdataset
class DifferingCovers(exception.DNSException):
"""An attempt was made to add a DNS SIG/RRSIG whose covered type
is not the same as that of the other rdatas in the rdataset."""
class IncompatibleTypes(exception.DNSException):
"""An attempt was made to add DNS RR data of an incompatible type."""
class Rdataset(set.Set):
def __init__(self, rdclass, rdtype, covers=rdatatype.NONE, ttl=0):
self.rdclass : int = rdclass
self.rdtype : int = rdtype
self.covers : int = covers
self.ttl : int = ttl
def update_ttl(self, ttl : int) -> None:
...
def add(self, rd : rdata.Rdata, ttl : Optional[int] =None):
...
def union_update(self, other : Rdataset):
...
def intersection_update(self, other : Rdataset):
...
def update(self, other : Rdataset):
...
def to_text(self, name : Optional[name.Name] =None, origin : Optional[name.Name] =None, relativize=True,
override_rdclass : Optional[int] =None, **kw) -> bytes:
...
def to_wire(self, name : Optional[name.Name], file : BytesIO, compress : Optional[Dict[name.Name, int]] = None, origin : Optional[name.Name] = None,
override_rdclass : Optional[int] = None, want_shuffle=True) -> int:
...
def match(self, rdclass : int, rdtype : int, covers : int) -> bool:
...
def from_text_list(rdclass : Union[int,str], rdtype : Union[int,str], ttl : int, text_rdatas : str, idna_codec : Optional[name.IDNACodec] = None) -> rdataset.Rdataset:
...
def from_text(rdclass : Union[int,str], rdtype : Union[int,str], ttl : int, *text_rdatas : str) -> rdataset.Rdataset:
...
def from_rdata_list(ttl : int, rdatas : List[rdata.Rdata]) -> rdataset.Rdataset:
...
def from_rdata(ttl : int, *rdatas : List[rdata.Rdata]) -> rdataset.Rdataset:
...

View file

@ -17,11 +17,15 @@
"""DNS Rdata Types."""
from typing import Dict
import dns.enum
import dns.exception
class RdataType(dns.enum.IntEnum):
"""DNS Rdata Type"""
TYPE0 = 0
NONE = 0
A = 1
@ -116,24 +120,47 @@ class RdataType(dns.enum.IntEnum):
def _prefix(cls):
return "TYPE"
@classmethod
def _extra_from_text(cls, text):
if text.find("-") >= 0:
try:
return cls[text.replace("-", "_")]
except KeyError:
pass
return _registered_by_text.get(text)
@classmethod
def _extra_to_text(cls, value, current_text):
if current_text is None:
return _registered_by_value.get(value)
if current_text.find("_") >= 0:
return current_text.replace("_", "-")
return current_text
@classmethod
def _unknown_exception_class(cls):
return UnknownRdatatype
_registered_by_text = {}
_registered_by_value = {}
_registered_by_text: Dict[str, RdataType] = {}
_registered_by_value: Dict[RdataType, str] = {}
_metatypes = {RdataType.OPT}
_singletons = {RdataType.SOA, RdataType.NXT, RdataType.DNAME,
RdataType.NSEC, RdataType.CNAME}
_singletons = {
RdataType.SOA,
RdataType.NXT,
RdataType.DNAME,
RdataType.NSEC,
RdataType.CNAME,
}
class UnknownRdatatype(dns.exception.DNSException):
"""DNS resource record type is unknown."""
def from_text(text):
def from_text(text: str) -> RdataType:
"""Convert text into a DNS rdata type value.
The input text can be a defined DNS RR type mnemonic or
@ -145,20 +172,13 @@ def from_text(text):
Raises ``ValueError`` if the rdata type value is not >= 0 and <= 65535.
Returns an ``int``.
Returns a ``dns.rdatatype.RdataType``.
"""
text = text.upper().replace('-', '_')
try:
return RdataType.from_text(text)
except UnknownRdatatype:
registered_type = _registered_by_text.get(text)
if registered_type:
return registered_type
raise
return RdataType.from_text(text)
def to_text(value):
def to_text(value: RdataType) -> str:
"""Convert a DNS rdata type value to text.
If the value has a known mnemonic, it will be used, otherwise the
@ -169,18 +189,13 @@ def to_text(value):
Returns a ``str``.
"""
text = RdataType.to_text(value)
if text.startswith("TYPE"):
registered_text = _registered_by_value.get(value)
if registered_text:
text = registered_text
return text.replace('_', '-')
return RdataType.to_text(value)
def is_metatype(rdtype):
def is_metatype(rdtype: RdataType) -> bool:
"""True if the specified type is a metatype.
*rdtype* is an ``int``.
*rdtype* is a ``dns.rdatatype.RdataType``.
The currently defined metatypes are TKEY, TSIG, IXFR, AXFR, MAILA,
MAILB, ANY, and OPT.
@ -191,7 +206,7 @@ def is_metatype(rdtype):
return (256 > rdtype >= 128) or rdtype in _metatypes
def is_singleton(rdtype):
def is_singleton(rdtype: RdataType) -> bool:
"""Is the specified type a singleton type?
Singleton types can only have a single rdata in an rdataset, or a single
@ -209,11 +224,14 @@ def is_singleton(rdtype):
return True
return False
# pylint: disable=redefined-outer-name
def register_type(rdtype, rdtype_text, is_singleton=False):
def register_type(
rdtype: RdataType, rdtype_text: str, is_singleton: bool = False
) -> None:
"""Dynamically register an rdatatype.
*rdtype*, an ``int``, the rdatatype to register.
*rdtype*, a ``dns.rdatatype.RdataType``, the rdatatype to register.
*rdtype_text*, a ``str``, the textual form of the rdatatype.
@ -226,6 +244,7 @@ def register_type(rdtype, rdtype_text, is_singleton=False):
if is_singleton:
_singletons.add(rdtype)
### BEGIN generated RdataType constants
TYPE0 = RdataType.TYPE0

View file

@ -23,7 +23,7 @@ import dns.rdtypes.util
class Relay(dns.rdtypes.util.Gateway):
name = 'AMTRELAY relay'
name = "AMTRELAY relay"
@property
def relay(self):
@ -37,10 +37,11 @@ class AMTRELAY(dns.rdata.Rdata):
# see: RFC 8777
__slots__ = ['precedence', 'discovery_optional', 'relay_type', 'relay']
__slots__ = ["precedence", "discovery_optional", "relay_type", "relay"]
def __init__(self, rdclass, rdtype, precedence, discovery_optional,
relay_type, relay):
def __init__(
self, rdclass, rdtype, precedence, discovery_optional, relay_type, relay
):
super().__init__(rdclass, rdtype)
relay = Relay(relay_type, relay)
self.precedence = self._as_uint8(precedence)
@ -50,37 +51,42 @@ class AMTRELAY(dns.rdata.Rdata):
def to_text(self, origin=None, relativize=True, **kw):
relay = Relay(self.relay_type, self.relay).to_text(origin, relativize)
return '%d %d %d %s' % (self.precedence, self.discovery_optional,
self.relay_type, relay)
return "%d %d %d %s" % (
self.precedence,
self.discovery_optional,
self.relay_type,
relay,
)
@classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
relativize_to=None):
def from_text(
cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
precedence = tok.get_uint8()
discovery_optional = tok.get_uint8()
if discovery_optional > 1:
raise dns.exception.SyntaxError('expecting 0 or 1')
raise dns.exception.SyntaxError("expecting 0 or 1")
discovery_optional = bool(discovery_optional)
relay_type = tok.get_uint8()
if relay_type > 0x7f:
raise dns.exception.SyntaxError('expecting an integer <= 127')
relay = Relay.from_text(relay_type, tok, origin, relativize,
relativize_to)
return cls(rdclass, rdtype, precedence, discovery_optional, relay_type,
relay.relay)
if relay_type > 0x7F:
raise dns.exception.SyntaxError("expecting an integer <= 127")
relay = Relay.from_text(relay_type, tok, origin, relativize, relativize_to)
return cls(
rdclass, rdtype, precedence, discovery_optional, relay_type, relay.relay
)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
relay_type = self.relay_type | (self.discovery_optional << 7)
header = struct.pack("!BB", self.precedence, relay_type)
file.write(header)
Relay(self.relay_type, self.relay).to_wire(file, compress, origin,
canonicalize)
Relay(self.relay_type, self.relay).to_wire(file, compress, origin, canonicalize)
@classmethod
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
(precedence, relay_type) = parser.get_struct('!BB')
(precedence, relay_type) = parser.get_struct("!BB")
discovery_optional = bool(relay_type >> 7)
relay_type &= 0x7f
relay_type &= 0x7F
relay = Relay.from_wire_parser(relay_type, parser, origin)
return cls(rdclass, rdtype, precedence, discovery_optional, relay_type,
relay.relay)
return cls(
rdclass, rdtype, precedence, discovery_optional, relay_type, relay.relay
)

View file

@ -30,7 +30,7 @@ class CAA(dns.rdata.Rdata):
# see: RFC 6844
__slots__ = ['flags', 'tag', 'value']
__slots__ = ["flags", "tag", "value"]
def __init__(self, rdclass, rdtype, flags, tag, value):
super().__init__(rdclass, rdtype)
@ -41,23 +41,26 @@ class CAA(dns.rdata.Rdata):
self.value = self._as_bytes(value)
def to_text(self, origin=None, relativize=True, **kw):
return '%u %s "%s"' % (self.flags,
dns.rdata._escapify(self.tag),
dns.rdata._escapify(self.value))
return '%u %s "%s"' % (
self.flags,
dns.rdata._escapify(self.tag),
dns.rdata._escapify(self.value),
)
@classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
relativize_to=None):
def from_text(
cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
flags = tok.get_uint8()
tag = tok.get_string().encode()
value = tok.get_string().encode()
return cls(rdclass, rdtype, flags, tag, value)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
file.write(struct.pack('!B', self.flags))
file.write(struct.pack("!B", self.flags))
l = len(self.tag)
assert l < 256
file.write(struct.pack('!B', l))
file.write(struct.pack("!B", l))
file.write(self.tag)
file.write(self.value)

View file

@ -15,13 +15,19 @@
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import dns.rdtypes.dnskeybase
import dns.rdtypes.dnskeybase # lgtm[py/import-and-import-from]
import dns.immutable
# pylint: disable=unused-import
from dns.rdtypes.dnskeybase import SEP, REVOKE, ZONE # noqa: F401
from dns.rdtypes.dnskeybase import (
SEP,
REVOKE,
ZONE,
) # noqa: F401 lgtm[py/unused-import]
# pylint: enable=unused-import
@dns.immutable.immutable
class CDNSKEY(dns.rdtypes.dnskeybase.DNSKEYBase):

View file

@ -20,34 +20,34 @@ import base64
import dns.exception
import dns.immutable
import dns.dnssec
import dns.dnssectypes
import dns.rdata
import dns.tokenizer
_ctype_by_value = {
1: 'PKIX',
2: 'SPKI',
3: 'PGP',
4: 'IPKIX',
5: 'ISPKI',
6: 'IPGP',
7: 'ACPKIX',
8: 'IACPKIX',
253: 'URI',
254: 'OID',
1: "PKIX",
2: "SPKI",
3: "PGP",
4: "IPKIX",
5: "ISPKI",
6: "IPGP",
7: "ACPKIX",
8: "IACPKIX",
253: "URI",
254: "OID",
}
_ctype_by_name = {
'PKIX': 1,
'SPKI': 2,
'PGP': 3,
'IPKIX': 4,
'ISPKI': 5,
'IPGP': 6,
'ACPKIX': 7,
'IACPKIX': 8,
'URI': 253,
'OID': 254,
"PKIX": 1,
"SPKI": 2,
"PGP": 3,
"IPKIX": 4,
"ISPKI": 5,
"IPGP": 6,
"ACPKIX": 7,
"IACPKIX": 8,
"URI": 253,
"OID": 254,
}
@ -72,10 +72,11 @@ class CERT(dns.rdata.Rdata):
# see RFC 4398
__slots__ = ['certificate_type', 'key_tag', 'algorithm', 'certificate']
__slots__ = ["certificate_type", "key_tag", "algorithm", "certificate"]
def __init__(self, rdclass, rdtype, certificate_type, key_tag, algorithm,
certificate):
def __init__(
self, rdclass, rdtype, certificate_type, key_tag, algorithm, certificate
):
super().__init__(rdclass, rdtype)
self.certificate_type = self._as_uint16(certificate_type)
self.key_tag = self._as_uint16(key_tag)
@ -84,24 +85,28 @@ class CERT(dns.rdata.Rdata):
def to_text(self, origin=None, relativize=True, **kw):
certificate_type = _ctype_to_text(self.certificate_type)
return "%s %d %s %s" % (certificate_type, self.key_tag,
dns.dnssec.algorithm_to_text(self.algorithm),
dns.rdata._base64ify(self.certificate, **kw))
return "%s %d %s %s" % (
certificate_type,
self.key_tag,
dns.dnssectypes.Algorithm.to_text(self.algorithm),
dns.rdata._base64ify(self.certificate, **kw),
)
@classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
relativize_to=None):
def from_text(
cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
certificate_type = _ctype_from_text(tok.get_string())
key_tag = tok.get_uint16()
algorithm = dns.dnssec.algorithm_from_text(tok.get_string())
algorithm = dns.dnssectypes.Algorithm.from_text(tok.get_string())
b64 = tok.concatenate_remaining_identifiers().encode()
certificate = base64.b64decode(b64)
return cls(rdclass, rdtype, certificate_type, key_tag,
algorithm, certificate)
return cls(rdclass, rdtype, certificate_type, key_tag, algorithm, certificate)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
prefix = struct.pack("!HHB", self.certificate_type, self.key_tag,
self.algorithm)
prefix = struct.pack(
"!HHB", self.certificate_type, self.key_tag, self.algorithm
)
file.write(prefix)
file.write(self.certificate)
@ -109,5 +114,4 @@ class CERT(dns.rdata.Rdata):
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
(certificate_type, key_tag, algorithm) = parser.get_struct("!HHB")
certificate = parser.get_remaining()
return cls(rdclass, rdtype, certificate_type, key_tag, algorithm,
certificate)
return cls(rdclass, rdtype, certificate_type, key_tag, algorithm, certificate)

View file

@ -27,7 +27,7 @@ import dns.rdtypes.util
@dns.immutable.immutable
class Bitmap(dns.rdtypes.util.Bitmap):
type_name = 'CSYNC'
type_name = "CSYNC"
@dns.immutable.immutable
@ -35,7 +35,7 @@ class CSYNC(dns.rdata.Rdata):
"""CSYNC record"""
__slots__ = ['serial', 'flags', 'windows']
__slots__ = ["serial", "flags", "windows"]
def __init__(self, rdclass, rdtype, serial, flags, windows):
super().__init__(rdclass, rdtype)
@ -47,18 +47,19 @@ class CSYNC(dns.rdata.Rdata):
def to_text(self, origin=None, relativize=True, **kw):
text = Bitmap(self.windows).to_text()
return '%d %d%s' % (self.serial, self.flags, text)
return "%d %d%s" % (self.serial, self.flags, text)
@classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
relativize_to=None):
def from_text(
cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
serial = tok.get_uint32()
flags = tok.get_uint16()
bitmap = Bitmap.from_text(tok)
return cls(rdclass, rdtype, serial, flags, bitmap)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
file.write(struct.pack('!IH', self.serial, self.flags))
file.write(struct.pack("!IH", self.serial, self.flags))
Bitmap(self.windows).to_wire(file)
@classmethod

View file

@ -15,13 +15,19 @@
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import dns.rdtypes.dnskeybase
import dns.rdtypes.dnskeybase # lgtm[py/import-and-import-from]
import dns.immutable
# pylint: disable=unused-import
from dns.rdtypes.dnskeybase import SEP, REVOKE, ZONE # noqa: F401
from dns.rdtypes.dnskeybase import (
SEP,
REVOKE,
ZONE,
) # noqa: F401 lgtm[py/unused-import]
# pylint: enable=unused-import
@dns.immutable.immutable
class DNSKEY(dns.rdtypes.dnskeybase.DNSKEYBase):

View file

@ -26,19 +26,19 @@ import dns.tokenizer
def _validate_float_string(what):
if len(what) == 0:
raise dns.exception.FormError
if what[0] == b'-'[0] or what[0] == b'+'[0]:
if what[0] == b"-"[0] or what[0] == b"+"[0]:
what = what[1:]
if what.isdigit():
return
try:
(left, right) = what.split(b'.')
(left, right) = what.split(b".")
except ValueError:
raise dns.exception.FormError
if left == b'' and right == b'':
if left == b"" and right == b"":
raise dns.exception.FormError
if not left == b'' and not left.decode().isdigit():
if not left == b"" and not left.decode().isdigit():
raise dns.exception.FormError
if not right == b'' and not right.decode().isdigit():
if not right == b"" and not right.decode().isdigit():
raise dns.exception.FormError
@ -49,18 +49,15 @@ class GPOS(dns.rdata.Rdata):
# see: RFC 1712
__slots__ = ['latitude', 'longitude', 'altitude']
__slots__ = ["latitude", "longitude", "altitude"]
def __init__(self, rdclass, rdtype, latitude, longitude, altitude):
super().__init__(rdclass, rdtype)
if isinstance(latitude, float) or \
isinstance(latitude, int):
if isinstance(latitude, float) or isinstance(latitude, int):
latitude = str(latitude)
if isinstance(longitude, float) or \
isinstance(longitude, int):
if isinstance(longitude, float) or isinstance(longitude, int):
longitude = str(longitude)
if isinstance(altitude, float) or \
isinstance(altitude, int):
if isinstance(altitude, float) or isinstance(altitude, int):
altitude = str(altitude)
latitude = self._as_bytes(latitude, True, 255)
longitude = self._as_bytes(longitude, True, 255)
@ -73,19 +70,20 @@ class GPOS(dns.rdata.Rdata):
self.altitude = altitude
flat = self.float_latitude
if flat < -90.0 or flat > 90.0:
raise dns.exception.FormError('bad latitude')
raise dns.exception.FormError("bad latitude")
flong = self.float_longitude
if flong < -180.0 or flong > 180.0:
raise dns.exception.FormError('bad longitude')
raise dns.exception.FormError("bad longitude")
def to_text(self, origin=None, relativize=True, **kw):
return '{} {} {}'.format(self.latitude.decode(),
self.longitude.decode(),
self.altitude.decode())
return "{} {} {}".format(
self.latitude.decode(), self.longitude.decode(), self.altitude.decode()
)
@classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
relativize_to=None):
def from_text(
cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
latitude = tok.get_string()
longitude = tok.get_string()
altitude = tok.get_string()
@ -94,15 +92,15 @@ class GPOS(dns.rdata.Rdata):
def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
l = len(self.latitude)
assert l < 256
file.write(struct.pack('!B', l))
file.write(struct.pack("!B", l))
file.write(self.latitude)
l = len(self.longitude)
assert l < 256
file.write(struct.pack('!B', l))
file.write(struct.pack("!B", l))
file.write(self.longitude)
l = len(self.altitude)
assert l < 256
file.write(struct.pack('!B', l))
file.write(struct.pack("!B", l))
file.write(self.altitude)
@classmethod

View file

@ -30,7 +30,7 @@ class HINFO(dns.rdata.Rdata):
# see: RFC 1035
__slots__ = ['cpu', 'os']
__slots__ = ["cpu", "os"]
def __init__(self, rdclass, rdtype, cpu, os):
super().__init__(rdclass, rdtype)
@ -38,12 +38,14 @@ class HINFO(dns.rdata.Rdata):
self.os = self._as_bytes(os, True, 255)
def to_text(self, origin=None, relativize=True, **kw):
return '"{}" "{}"'.format(dns.rdata._escapify(self.cpu),
dns.rdata._escapify(self.os))
return '"{}" "{}"'.format(
dns.rdata._escapify(self.cpu), dns.rdata._escapify(self.os)
)
@classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
relativize_to=None):
def from_text(
cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
cpu = tok.get_string(max_length=255)
os = tok.get_string(max_length=255)
return cls(rdclass, rdtype, cpu, os)
@ -51,11 +53,11 @@ class HINFO(dns.rdata.Rdata):
def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
l = len(self.cpu)
assert l < 256
file.write(struct.pack('!B', l))
file.write(struct.pack("!B", l))
file.write(self.cpu)
l = len(self.os)
assert l < 256
file.write(struct.pack('!B', l))
file.write(struct.pack("!B", l))
file.write(self.os)
@classmethod

View file

@ -32,7 +32,7 @@ class HIP(dns.rdata.Rdata):
# see: RFC 5205
__slots__ = ['hit', 'algorithm', 'key', 'servers']
__slots__ = ["hit", "algorithm", "key", "servers"]
def __init__(self, rdclass, rdtype, hit, algorithm, key, servers):
super().__init__(rdclass, rdtype)
@ -43,18 +43,19 @@ class HIP(dns.rdata.Rdata):
def to_text(self, origin=None, relativize=True, **kw):
hit = binascii.hexlify(self.hit).decode()
key = base64.b64encode(self.key).replace(b'\n', b'').decode()
text = ''
key = base64.b64encode(self.key).replace(b"\n", b"").decode()
text = ""
servers = []
for server in self.servers:
servers.append(server.choose_relativity(origin, relativize))
if len(servers) > 0:
text += (' ' + ' '.join((x.to_unicode() for x in servers)))
return '%u %s %s%s' % (self.algorithm, hit, key, text)
text += " " + " ".join((x.to_unicode() for x in servers))
return "%u %s %s%s" % (self.algorithm, hit, key, text)
@classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
relativize_to=None):
def from_text(
cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
algorithm = tok.get_uint8()
hit = binascii.unhexlify(tok.get_string().encode())
key = base64.b64decode(tok.get_string().encode())
@ -75,7 +76,7 @@ class HIP(dns.rdata.Rdata):
@classmethod
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
(lh, algorithm, lk) = parser.get_struct('!BBH')
(lh, algorithm, lk) = parser.get_struct("!BBH")
hit = parser.get_bytes(lh)
key = parser.get_bytes(lk)
servers = []

View file

@ -30,7 +30,7 @@ class ISDN(dns.rdata.Rdata):
# see: RFC 1183
__slots__ = ['address', 'subaddress']
__slots__ = ["address", "subaddress"]
def __init__(self, rdclass, rdtype, address, subaddress):
super().__init__(rdclass, rdtype)
@ -39,31 +39,33 @@ class ISDN(dns.rdata.Rdata):
def to_text(self, origin=None, relativize=True, **kw):
if self.subaddress:
return '"{}" "{}"'.format(dns.rdata._escapify(self.address),
dns.rdata._escapify(self.subaddress))
return '"{}" "{}"'.format(
dns.rdata._escapify(self.address), dns.rdata._escapify(self.subaddress)
)
else:
return '"%s"' % dns.rdata._escapify(self.address)
@classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
relativize_to=None):
def from_text(
cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
address = tok.get_string()
tokens = tok.get_remaining(max_tokens=1)
if len(tokens) >= 1:
subaddress = tokens[0].unescape().value
else:
subaddress = ''
subaddress = ""
return cls(rdclass, rdtype, address, subaddress)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
l = len(self.address)
assert l < 256
file.write(struct.pack('!B', l))
file.write(struct.pack("!B", l))
file.write(self.address)
l = len(self.subaddress)
if l > 0:
assert l < 256
file.write(struct.pack('!B', l))
file.write(struct.pack("!B", l))
file.write(self.subaddress)
@classmethod
@ -72,5 +74,5 @@ class ISDN(dns.rdata.Rdata):
if parser.remaining() > 0:
subaddress = parser.get_counted_bytes()
else:
subaddress = b''
subaddress = b""
return cls(rdclass, rdtype, address, subaddress)

View file

@ -3,6 +3,7 @@
import struct
import dns.immutable
import dns.rdata
@dns.immutable.immutable
@ -12,7 +13,7 @@ class L32(dns.rdata.Rdata):
# see: rfc6742.txt
__slots__ = ['preference', 'locator32']
__slots__ = ["preference", "locator32"]
def __init__(self, rdclass, rdtype, preference, locator32):
super().__init__(rdclass, rdtype)
@ -20,17 +21,18 @@ class L32(dns.rdata.Rdata):
self.locator32 = self._as_ipv4_address(locator32)
def to_text(self, origin=None, relativize=True, **kw):
return f'{self.preference} {self.locator32}'
return f"{self.preference} {self.locator32}"
@classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
relativize_to=None):
def from_text(
cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
preference = tok.get_uint16()
nodeid = tok.get_identifier()
return cls(rdclass, rdtype, preference, nodeid)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
file.write(struct.pack('!H', self.preference))
file.write(struct.pack("!H", self.preference))
file.write(dns.ipv4.inet_aton(self.locator32))
@classmethod

View file

@ -13,33 +13,33 @@ class L64(dns.rdata.Rdata):
# see: rfc6742.txt
__slots__ = ['preference', 'locator64']
__slots__ = ["preference", "locator64"]
def __init__(self, rdclass, rdtype, preference, locator64):
super().__init__(rdclass, rdtype)
self.preference = self._as_uint16(preference)
if isinstance(locator64, bytes):
if len(locator64) != 8:
raise ValueError('invalid locator64')
self.locator64 = dns.rdata._hexify(locator64, 4, b':')
raise ValueError("invalid locator64")
self.locator64 = dns.rdata._hexify(locator64, 4, b":")
else:
dns.rdtypes.util.parse_formatted_hex(locator64, 4, 4, ':')
dns.rdtypes.util.parse_formatted_hex(locator64, 4, 4, ":")
self.locator64 = locator64
def to_text(self, origin=None, relativize=True, **kw):
return f'{self.preference} {self.locator64}'
return f"{self.preference} {self.locator64}"
@classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
relativize_to=None):
def from_text(
cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
preference = tok.get_uint16()
locator64 = tok.get_identifier()
return cls(rdclass, rdtype, preference, locator64)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
file.write(struct.pack('!H', self.preference))
file.write(dns.rdtypes.util.parse_formatted_hex(self.locator64,
4, 4, ':'))
file.write(struct.pack("!H", self.preference))
file.write(dns.rdtypes.util.parse_formatted_hex(self.locator64, 4, 4, ":"))
@classmethod
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):

View file

@ -93,15 +93,15 @@ def _decode_size(what, desc):
def _check_coordinate_list(value, low, high):
if value[0] < low or value[0] > high:
raise ValueError(f'not in range [{low}, {high}]')
raise ValueError(f"not in range [{low}, {high}]")
if value[1] < 0 or value[1] > 59:
raise ValueError('bad minutes value')
raise ValueError("bad minutes value")
if value[2] < 0 or value[2] > 59:
raise ValueError('bad seconds value')
raise ValueError("bad seconds value")
if value[3] < 0 or value[3] > 999:
raise ValueError('bad milliseconds value')
raise ValueError("bad milliseconds value")
if value[4] != 1 and value[4] != -1:
raise ValueError('bad hemisphere value')
raise ValueError("bad hemisphere value")
@dns.immutable.immutable
@ -111,12 +111,26 @@ class LOC(dns.rdata.Rdata):
# see: RFC 1876
__slots__ = ['latitude', 'longitude', 'altitude', 'size',
'horizontal_precision', 'vertical_precision']
__slots__ = [
"latitude",
"longitude",
"altitude",
"size",
"horizontal_precision",
"vertical_precision",
]
def __init__(self, rdclass, rdtype, latitude, longitude, altitude,
size=_default_size, hprec=_default_hprec,
vprec=_default_vprec):
def __init__(
self,
rdclass,
rdtype,
latitude,
longitude,
altitude,
size=_default_size,
hprec=_default_hprec,
vprec=_default_vprec,
):
"""Initialize a LOC record instance.
The parameters I{latitude} and I{longitude} may be either a 4-tuple
@ -145,34 +159,44 @@ class LOC(dns.rdata.Rdata):
def to_text(self, origin=None, relativize=True, **kw):
if self.latitude[4] > 0:
lat_hemisphere = 'N'
lat_hemisphere = "N"
else:
lat_hemisphere = 'S'
lat_hemisphere = "S"
if self.longitude[4] > 0:
long_hemisphere = 'E'
long_hemisphere = "E"
else:
long_hemisphere = 'W'
long_hemisphere = "W"
text = "%d %d %d.%03d %s %d %d %d.%03d %s %0.2fm" % (
self.latitude[0], self.latitude[1],
self.latitude[2], self.latitude[3], lat_hemisphere,
self.longitude[0], self.longitude[1], self.longitude[2],
self.longitude[3], long_hemisphere,
self.altitude / 100.0
self.latitude[0],
self.latitude[1],
self.latitude[2],
self.latitude[3],
lat_hemisphere,
self.longitude[0],
self.longitude[1],
self.longitude[2],
self.longitude[3],
long_hemisphere,
self.altitude / 100.0,
)
# do not print default values
if self.size != _default_size or \
self.horizontal_precision != _default_hprec or \
self.vertical_precision != _default_vprec:
if (
self.size != _default_size
or self.horizontal_precision != _default_hprec
or self.vertical_precision != _default_vprec
):
text += " {:0.2f}m {:0.2f}m {:0.2f}m".format(
self.size / 100.0, self.horizontal_precision / 100.0,
self.vertical_precision / 100.0
self.size / 100.0,
self.horizontal_precision / 100.0,
self.vertical_precision / 100.0,
)
return text
@classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
relativize_to=None):
def from_text(
cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
latitude = [0, 0, 0, 0, 1]
longitude = [0, 0, 0, 0, 1]
size = _default_size
@ -184,16 +208,14 @@ class LOC(dns.rdata.Rdata):
if t.isdigit():
latitude[1] = int(t)
t = tok.get_string()
if '.' in t:
(seconds, milliseconds) = t.split('.')
if "." in t:
(seconds, milliseconds) = t.split(".")
if not seconds.isdigit():
raise dns.exception.SyntaxError(
'bad latitude seconds value')
raise dns.exception.SyntaxError("bad latitude seconds value")
latitude[2] = int(seconds)
l = len(milliseconds)
if l == 0 or l > 3 or not milliseconds.isdigit():
raise dns.exception.SyntaxError(
'bad latitude milliseconds value')
raise dns.exception.SyntaxError("bad latitude milliseconds value")
if l == 1:
m = 100
elif l == 2:
@ -205,26 +227,24 @@ class LOC(dns.rdata.Rdata):
elif t.isdigit():
latitude[2] = int(t)
t = tok.get_string()
if t == 'S':
if t == "S":
latitude[4] = -1
elif t != 'N':
raise dns.exception.SyntaxError('bad latitude hemisphere value')
elif t != "N":
raise dns.exception.SyntaxError("bad latitude hemisphere value")
longitude[0] = tok.get_int()
t = tok.get_string()
if t.isdigit():
longitude[1] = int(t)
t = tok.get_string()
if '.' in t:
(seconds, milliseconds) = t.split('.')
if "." in t:
(seconds, milliseconds) = t.split(".")
if not seconds.isdigit():
raise dns.exception.SyntaxError(
'bad longitude seconds value')
raise dns.exception.SyntaxError("bad longitude seconds value")
longitude[2] = int(seconds)
l = len(milliseconds)
if l == 0 or l > 3 or not milliseconds.isdigit():
raise dns.exception.SyntaxError(
'bad longitude milliseconds value')
raise dns.exception.SyntaxError("bad longitude milliseconds value")
if l == 1:
m = 100
elif l == 2:
@ -236,64 +256,75 @@ class LOC(dns.rdata.Rdata):
elif t.isdigit():
longitude[2] = int(t)
t = tok.get_string()
if t == 'W':
if t == "W":
longitude[4] = -1
elif t != 'E':
raise dns.exception.SyntaxError('bad longitude hemisphere value')
elif t != "E":
raise dns.exception.SyntaxError("bad longitude hemisphere value")
t = tok.get_string()
if t[-1] == 'm':
t = t[0: -1]
altitude = float(t) * 100.0 # m -> cm
if t[-1] == "m":
t = t[0:-1]
altitude = float(t) * 100.0 # m -> cm
tokens = tok.get_remaining(max_tokens=3)
if len(tokens) >= 1:
value = tokens[0].unescape().value
if value[-1] == 'm':
value = value[0: -1]
size = float(value) * 100.0 # m -> cm
if value[-1] == "m":
value = value[0:-1]
size = float(value) * 100.0 # m -> cm
if len(tokens) >= 2:
value = tokens[1].unescape().value
if value[-1] == 'm':
value = value[0: -1]
hprec = float(value) * 100.0 # m -> cm
if value[-1] == "m":
value = value[0:-1]
hprec = float(value) * 100.0 # m -> cm
if len(tokens) >= 3:
value = tokens[2].unescape().value
if value[-1] == 'm':
value = value[0: -1]
vprec = float(value) * 100.0 # m -> cm
if value[-1] == "m":
value = value[0:-1]
vprec = float(value) * 100.0 # m -> cm
# Try encoding these now so we raise if they are bad
_encode_size(size, "size")
_encode_size(hprec, "horizontal precision")
_encode_size(vprec, "vertical precision")
return cls(rdclass, rdtype, latitude, longitude, altitude,
size, hprec, vprec)
return cls(rdclass, rdtype, latitude, longitude, altitude, size, hprec, vprec)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
milliseconds = (self.latitude[0] * 3600000 +
self.latitude[1] * 60000 +
self.latitude[2] * 1000 +
self.latitude[3]) * self.latitude[4]
milliseconds = (
self.latitude[0] * 3600000
+ self.latitude[1] * 60000
+ self.latitude[2] * 1000
+ self.latitude[3]
) * self.latitude[4]
latitude = 0x80000000 + milliseconds
milliseconds = (self.longitude[0] * 3600000 +
self.longitude[1] * 60000 +
self.longitude[2] * 1000 +
self.longitude[3]) * self.longitude[4]
milliseconds = (
self.longitude[0] * 3600000
+ self.longitude[1] * 60000
+ self.longitude[2] * 1000
+ self.longitude[3]
) * self.longitude[4]
longitude = 0x80000000 + milliseconds
altitude = int(self.altitude) + 10000000
size = _encode_size(self.size, "size")
hprec = _encode_size(self.horizontal_precision, "horizontal precision")
vprec = _encode_size(self.vertical_precision, "vertical precision")
wire = struct.pack("!BBBBIII", 0, size, hprec, vprec, latitude,
longitude, altitude)
wire = struct.pack(
"!BBBBIII", 0, size, hprec, vprec, latitude, longitude, altitude
)
file.write(wire)
@classmethod
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
(version, size, hprec, vprec, latitude, longitude, altitude) = \
parser.get_struct("!BBBBIII")
(
version,
size,
hprec,
vprec,
latitude,
longitude,
altitude,
) = parser.get_struct("!BBBBIII")
if version != 0:
raise dns.exception.FormError("LOC version not zero")
if latitude < _MIN_LATITUDE or latitude > _MAX_LATITUDE:
@ -312,8 +343,7 @@ class LOC(dns.rdata.Rdata):
size = _decode_size(size, "size")
hprec = _decode_size(hprec, "horizontal precision")
vprec = _decode_size(vprec, "vertical precision")
return cls(rdclass, rdtype, latitude, longitude, altitude,
size, hprec, vprec)
return cls(rdclass, rdtype, latitude, longitude, altitude, size, hprec, vprec)
@property
def float_latitude(self):

View file

@ -3,6 +3,7 @@
import struct
import dns.immutable
import dns.rdata
@dns.immutable.immutable
@ -12,7 +13,7 @@ class LP(dns.rdata.Rdata):
# see: rfc6742.txt
__slots__ = ['preference', 'fqdn']
__slots__ = ["preference", "fqdn"]
def __init__(self, rdclass, rdtype, preference, fqdn):
super().__init__(rdclass, rdtype)
@ -21,17 +22,18 @@ class LP(dns.rdata.Rdata):
def to_text(self, origin=None, relativize=True, **kw):
fqdn = self.fqdn.choose_relativity(origin, relativize)
return '%d %s' % (self.preference, fqdn)
return "%d %s" % (self.preference, fqdn)
@classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
relativize_to=None):
def from_text(
cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
preference = tok.get_uint16()
fqdn = tok.get_name(origin, relativize, relativize_to)
return cls(rdclass, rdtype, preference, fqdn)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
file.write(struct.pack('!H', self.preference))
file.write(struct.pack("!H", self.preference))
self.fqdn.to_wire(file, compress, origin, canonicalize)
@classmethod

View file

@ -13,32 +13,33 @@ class NID(dns.rdata.Rdata):
# see: rfc6742.txt
__slots__ = ['preference', 'nodeid']
__slots__ = ["preference", "nodeid"]
def __init__(self, rdclass, rdtype, preference, nodeid):
super().__init__(rdclass, rdtype)
self.preference = self._as_uint16(preference)
if isinstance(nodeid, bytes):
if len(nodeid) != 8:
raise ValueError('invalid nodeid')
self.nodeid = dns.rdata._hexify(nodeid, 4, b':')
raise ValueError("invalid nodeid")
self.nodeid = dns.rdata._hexify(nodeid, 4, b":")
else:
dns.rdtypes.util.parse_formatted_hex(nodeid, 4, 4, ':')
dns.rdtypes.util.parse_formatted_hex(nodeid, 4, 4, ":")
self.nodeid = nodeid
def to_text(self, origin=None, relativize=True, **kw):
return f'{self.preference} {self.nodeid}'
return f"{self.preference} {self.nodeid}"
@classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
relativize_to=None):
def from_text(
cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
preference = tok.get_uint16()
nodeid = tok.get_identifier()
return cls(rdclass, rdtype, preference, nodeid)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
file.write(struct.pack('!H', self.preference))
file.write(dns.rdtypes.util.parse_formatted_hex(self.nodeid, 4, 4, ':'))
file.write(struct.pack("!H", self.preference))
file.write(dns.rdtypes.util.parse_formatted_hex(self.nodeid, 4, 4, ":"))
@classmethod
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):

View file

@ -25,7 +25,7 @@ import dns.rdtypes.util
@dns.immutable.immutable
class Bitmap(dns.rdtypes.util.Bitmap):
type_name = 'NSEC'
type_name = "NSEC"
@dns.immutable.immutable
@ -33,7 +33,7 @@ class NSEC(dns.rdata.Rdata):
"""NSEC record"""
__slots__ = ['next', 'windows']
__slots__ = ["next", "windows"]
def __init__(self, rdclass, rdtype, next, windows):
super().__init__(rdclass, rdtype)
@ -45,11 +45,12 @@ class NSEC(dns.rdata.Rdata):
def to_text(self, origin=None, relativize=True, **kw):
next = self.next.choose_relativity(origin, relativize)
text = Bitmap(self.windows).to_text()
return '{}{}'.format(next, text)
return "{}{}".format(next, text)
@classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
relativize_to=None):
def from_text(
cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
next = tok.get_name(origin, relativize, relativize_to)
windows = Bitmap.from_text(tok)
return cls(rdclass, rdtype, next, windows)

View file

@ -26,10 +26,12 @@ import dns.rdatatype
import dns.rdtypes.util
b32_hex_to_normal = bytes.maketrans(b'0123456789ABCDEFGHIJKLMNOPQRSTUV',
b'ABCDEFGHIJKLMNOPQRSTUVWXYZ234567')
b32_normal_to_hex = bytes.maketrans(b'ABCDEFGHIJKLMNOPQRSTUVWXYZ234567',
b'0123456789ABCDEFGHIJKLMNOPQRSTUV')
b32_hex_to_normal = bytes.maketrans(
b"0123456789ABCDEFGHIJKLMNOPQRSTUV", b"ABCDEFGHIJKLMNOPQRSTUVWXYZ234567"
)
b32_normal_to_hex = bytes.maketrans(
b"ABCDEFGHIJKLMNOPQRSTUVWXYZ234567", b"0123456789ABCDEFGHIJKLMNOPQRSTUV"
)
# hash algorithm constants
SHA1 = 1
@ -40,7 +42,7 @@ OPTOUT = 1
@dns.immutable.immutable
class Bitmap(dns.rdtypes.util.Bitmap):
type_name = 'NSEC3'
type_name = "NSEC3"
@dns.immutable.immutable
@ -48,10 +50,11 @@ class NSEC3(dns.rdata.Rdata):
"""NSEC3 record"""
__slots__ = ['algorithm', 'flags', 'iterations', 'salt', 'next', 'windows']
__slots__ = ["algorithm", "flags", "iterations", "salt", "next", "windows"]
def __init__(self, rdclass, rdtype, algorithm, flags, iterations, salt,
next, windows):
def __init__(
self, rdclass, rdtype, algorithm, flags, iterations, salt, next, windows
):
super().__init__(rdclass, rdtype)
self.algorithm = self._as_uint8(algorithm)
self.flags = self._as_uint8(flags)
@ -63,38 +66,41 @@ class NSEC3(dns.rdata.Rdata):
self.windows = tuple(windows.windows)
def to_text(self, origin=None, relativize=True, **kw):
next = base64.b32encode(self.next).translate(
b32_normal_to_hex).lower().decode()
if self.salt == b'':
salt = '-'
next = base64.b32encode(self.next).translate(b32_normal_to_hex).lower().decode()
if self.salt == b"":
salt = "-"
else:
salt = binascii.hexlify(self.salt).decode()
text = Bitmap(self.windows).to_text()
return '%u %u %u %s %s%s' % (self.algorithm, self.flags,
self.iterations, salt, next, text)
return "%u %u %u %s %s%s" % (
self.algorithm,
self.flags,
self.iterations,
salt,
next,
text,
)
@classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
relativize_to=None):
def from_text(
cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
algorithm = tok.get_uint8()
flags = tok.get_uint8()
iterations = tok.get_uint16()
salt = tok.get_string()
if salt == '-':
salt = b''
if salt == "-":
salt = b""
else:
salt = binascii.unhexlify(salt.encode('ascii'))
next = tok.get_string().encode(
'ascii').upper().translate(b32_hex_to_normal)
salt = binascii.unhexlify(salt.encode("ascii"))
next = tok.get_string().encode("ascii").upper().translate(b32_hex_to_normal)
next = base64.b32decode(next)
bitmap = Bitmap.from_text(tok)
return cls(rdclass, rdtype, algorithm, flags, iterations, salt, next,
bitmap)
return cls(rdclass, rdtype, algorithm, flags, iterations, salt, next, bitmap)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
l = len(self.salt)
file.write(struct.pack("!BBHB", self.algorithm, self.flags,
self.iterations, l))
file.write(struct.pack("!BBHB", self.algorithm, self.flags, self.iterations, l))
file.write(self.salt)
l = len(self.next)
file.write(struct.pack("!B", l))
@ -103,9 +109,8 @@ class NSEC3(dns.rdata.Rdata):
@classmethod
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
(algorithm, flags, iterations) = parser.get_struct('!BBH')
(algorithm, flags, iterations) = parser.get_struct("!BBH")
salt = parser.get_counted_bytes()
next = parser.get_counted_bytes()
bitmap = Bitmap.from_wire_parser(parser)
return cls(rdclass, rdtype, algorithm, flags, iterations, salt, next,
bitmap)
return cls(rdclass, rdtype, algorithm, flags, iterations, salt, next, bitmap)

View file

@ -28,7 +28,7 @@ class NSEC3PARAM(dns.rdata.Rdata):
"""NSEC3PARAM record"""
__slots__ = ['algorithm', 'flags', 'iterations', 'salt']
__slots__ = ["algorithm", "flags", "iterations", "salt"]
def __init__(self, rdclass, rdtype, algorithm, flags, iterations, salt):
super().__init__(rdclass, rdtype)
@ -38,34 +38,33 @@ class NSEC3PARAM(dns.rdata.Rdata):
self.salt = self._as_bytes(salt, True, 255)
def to_text(self, origin=None, relativize=True, **kw):
if self.salt == b'':
salt = '-'
if self.salt == b"":
salt = "-"
else:
salt = binascii.hexlify(self.salt).decode()
return '%u %u %u %s' % (self.algorithm, self.flags, self.iterations,
salt)
return "%u %u %u %s" % (self.algorithm, self.flags, self.iterations, salt)
@classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
relativize_to=None):
def from_text(
cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
algorithm = tok.get_uint8()
flags = tok.get_uint8()
iterations = tok.get_uint16()
salt = tok.get_string()
if salt == '-':
salt = ''
if salt == "-":
salt = ""
else:
salt = binascii.unhexlify(salt.encode())
return cls(rdclass, rdtype, algorithm, flags, iterations, salt)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
l = len(self.salt)
file.write(struct.pack("!BBHB", self.algorithm, self.flags,
self.iterations, l))
file.write(struct.pack("!BBHB", self.algorithm, self.flags, self.iterations, l))
file.write(self.salt)
@classmethod
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
(algorithm, flags, iterations) = parser.get_struct('!BBH')
(algorithm, flags, iterations) = parser.get_struct("!BBH")
salt = parser.get_counted_bytes()
return cls(rdclass, rdtype, algorithm, flags, iterations, salt)

View file

@ -22,6 +22,7 @@ import dns.immutable
import dns.rdata
import dns.tokenizer
@dns.immutable.immutable
class OPENPGPKEY(dns.rdata.Rdata):
@ -37,8 +38,9 @@ class OPENPGPKEY(dns.rdata.Rdata):
return dns.rdata._base64ify(self.key, chunksize=None, **kw)
@classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
relativize_to=None):
def from_text(
cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
b64 = tok.concatenate_remaining_identifiers().encode()
key = base64.b64decode(b64)
return cls(rdclass, rdtype, key)

View file

@ -26,12 +26,13 @@ import dns.rdata
# We don't implement from_text, and that's ok.
# pylint: disable=abstract-method
@dns.immutable.immutable
class OPT(dns.rdata.Rdata):
"""OPT record"""
__slots__ = ['options']
__slots__ = ["options"]
def __init__(self, rdclass, rdtype, options):
"""Initialize an OPT rdata.
@ -45,10 +46,12 @@ class OPT(dns.rdata.Rdata):
"""
super().__init__(rdclass, rdtype)
def as_option(option):
if not isinstance(option, dns.edns.Option):
raise ValueError('option is not a dns.edns.option')
raise ValueError("option is not a dns.edns.option")
return option
self.options = self._as_tuple(options, as_option)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
@ -58,13 +61,13 @@ class OPT(dns.rdata.Rdata):
file.write(owire)
def to_text(self, origin=None, relativize=True, **kw):
return ' '.join(opt.to_text() for opt in self.options)
return " ".join(opt.to_text() for opt in self.options)
@classmethod
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
options = []
while parser.remaining() > 0:
(otype, olen) = parser.get_struct('!HH')
(otype, olen) = parser.get_struct("!HH")
with parser.restrict_to(olen):
opt = dns.edns.option_from_wire_parser(otype, parser)
options.append(opt)

View file

@ -28,7 +28,7 @@ class RP(dns.rdata.Rdata):
# see: RFC 1183
__slots__ = ['mbox', 'txt']
__slots__ = ["mbox", "txt"]
def __init__(self, rdclass, rdtype, mbox, txt):
super().__init__(rdclass, rdtype)
@ -41,8 +41,9 @@ class RP(dns.rdata.Rdata):
return "{} {}".format(str(mbox), str(txt))
@classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
relativize_to=None):
def from_text(
cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
mbox = tok.get_name(origin, relativize, relativize_to)
txt = tok.get_name(origin, relativize, relativize_to)
return cls(rdclass, rdtype, mbox, txt)

View file

@ -20,7 +20,7 @@ import calendar
import struct
import time
import dns.dnssec
import dns.dnssectypes
import dns.immutable
import dns.exception
import dns.rdata
@ -43,12 +43,11 @@ def sigtime_to_posixtime(what):
hour = int(what[8:10])
minute = int(what[10:12])
second = int(what[12:14])
return calendar.timegm((year, month, day, hour, minute, second,
0, 0, 0))
return calendar.timegm((year, month, day, hour, minute, second, 0, 0, 0))
def posixtime_to_sigtime(what):
return time.strftime('%Y%m%d%H%M%S', time.gmtime(what))
return time.strftime("%Y%m%d%H%M%S", time.gmtime(what))
@dns.immutable.immutable
@ -56,16 +55,35 @@ class RRSIG(dns.rdata.Rdata):
"""RRSIG record"""
__slots__ = ['type_covered', 'algorithm', 'labels', 'original_ttl',
'expiration', 'inception', 'key_tag', 'signer',
'signature']
__slots__ = [
"type_covered",
"algorithm",
"labels",
"original_ttl",
"expiration",
"inception",
"key_tag",
"signer",
"signature",
]
def __init__(self, rdclass, rdtype, type_covered, algorithm, labels,
original_ttl, expiration, inception, key_tag, signer,
signature):
def __init__(
self,
rdclass,
rdtype,
type_covered,
algorithm,
labels,
original_ttl,
expiration,
inception,
key_tag,
signer,
signature,
):
super().__init__(rdclass, rdtype)
self.type_covered = self._as_rdatatype(type_covered)
self.algorithm = dns.dnssec.Algorithm.make(algorithm)
self.algorithm = dns.dnssectypes.Algorithm.make(algorithm)
self.labels = self._as_uint8(labels)
self.original_ttl = self._as_ttl(original_ttl)
self.expiration = self._as_uint32(expiration)
@ -78,7 +96,7 @@ class RRSIG(dns.rdata.Rdata):
return self.type_covered
def to_text(self, origin=None, relativize=True, **kw):
return '%s %d %d %d %s %s %d %s %s' % (
return "%s %d %d %d %s %s %d %s %s" % (
dns.rdatatype.to_text(self.type_covered),
self.algorithm,
self.labels,
@ -87,14 +105,15 @@ class RRSIG(dns.rdata.Rdata):
posixtime_to_sigtime(self.inception),
self.key_tag,
self.signer.choose_relativity(origin, relativize),
dns.rdata._base64ify(self.signature, **kw)
dns.rdata._base64ify(self.signature, **kw),
)
@classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
relativize_to=None):
def from_text(
cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
type_covered = dns.rdatatype.from_text(tok.get_string())
algorithm = dns.dnssec.algorithm_from_text(tok.get_string())
algorithm = dns.dnssectypes.Algorithm.from_text(tok.get_string())
labels = tok.get_int()
original_ttl = tok.get_ttl()
expiration = sigtime_to_posixtime(tok.get_string())
@ -103,22 +122,38 @@ class RRSIG(dns.rdata.Rdata):
signer = tok.get_name(origin, relativize, relativize_to)
b64 = tok.concatenate_remaining_identifiers().encode()
signature = base64.b64decode(b64)
return cls(rdclass, rdtype, type_covered, algorithm, labels,
original_ttl, expiration, inception, key_tag, signer,
signature)
return cls(
rdclass,
rdtype,
type_covered,
algorithm,
labels,
original_ttl,
expiration,
inception,
key_tag,
signer,
signature,
)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
header = struct.pack('!HBBIIIH', self.type_covered,
self.algorithm, self.labels,
self.original_ttl, self.expiration,
self.inception, self.key_tag)
header = struct.pack(
"!HBBIIIH",
self.type_covered,
self.algorithm,
self.labels,
self.original_ttl,
self.expiration,
self.inception,
self.key_tag,
)
file.write(header)
self.signer.to_wire(file, None, origin, canonicalize)
file.write(self.signature)
@classmethod
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
header = parser.get_struct('!HBBIIIH')
header = parser.get_struct("!HBBIIIH")
signer = parser.get_name(origin)
signature = parser.get_remaining()
return cls(rdclass, rdtype, *header, signer, signature)

View file

@ -30,11 +30,11 @@ class SOA(dns.rdata.Rdata):
# see: RFC 1035
__slots__ = ['mname', 'rname', 'serial', 'refresh', 'retry', 'expire',
'minimum']
__slots__ = ["mname", "rname", "serial", "refresh", "retry", "expire", "minimum"]
def __init__(self, rdclass, rdtype, mname, rname, serial, refresh, retry,
expire, minimum):
def __init__(
self, rdclass, rdtype, mname, rname, serial, refresh, retry, expire, minimum
):
super().__init__(rdclass, rdtype)
self.mname = self._as_name(mname)
self.rname = self._as_name(rname)
@ -47,13 +47,20 @@ class SOA(dns.rdata.Rdata):
def to_text(self, origin=None, relativize=True, **kw):
mname = self.mname.choose_relativity(origin, relativize)
rname = self.rname.choose_relativity(origin, relativize)
return '%s %s %d %d %d %d %d' % (
mname, rname, self.serial, self.refresh, self.retry,
self.expire, self.minimum)
return "%s %s %d %d %d %d %d" % (
mname,
rname,
self.serial,
self.refresh,
self.retry,
self.expire,
self.minimum,
)
@classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
relativize_to=None):
def from_text(
cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
mname = tok.get_name(origin, relativize, relativize_to)
rname = tok.get_name(origin, relativize, relativize_to)
serial = tok.get_uint32()
@ -61,18 +68,20 @@ class SOA(dns.rdata.Rdata):
retry = tok.get_ttl()
expire = tok.get_ttl()
minimum = tok.get_ttl()
return cls(rdclass, rdtype, mname, rname, serial, refresh, retry,
expire, minimum)
return cls(
rdclass, rdtype, mname, rname, serial, refresh, retry, expire, minimum
)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
self.mname.to_wire(file, compress, origin, canonicalize)
self.rname.to_wire(file, compress, origin, canonicalize)
five_ints = struct.pack('!IIIII', self.serial, self.refresh,
self.retry, self.expire, self.minimum)
five_ints = struct.pack(
"!IIIII", self.serial, self.refresh, self.retry, self.expire, self.minimum
)
file.write(five_ints)
@classmethod
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
mname = parser.get_name(origin)
rname = parser.get_name(origin)
return cls(rdclass, rdtype, mname, rname, *parser.get_struct('!IIIII'))
return cls(rdclass, rdtype, mname, rname, *parser.get_struct("!IIIII"))

View file

@ -30,10 +30,9 @@ class SSHFP(dns.rdata.Rdata):
# See RFC 4255
__slots__ = ['algorithm', 'fp_type', 'fingerprint']
__slots__ = ["algorithm", "fp_type", "fingerprint"]
def __init__(self, rdclass, rdtype, algorithm, fp_type,
fingerprint):
def __init__(self, rdclass, rdtype, algorithm, fp_type, fingerprint):
super().__init__(rdclass, rdtype)
self.algorithm = self._as_uint8(algorithm)
self.fp_type = self._as_uint8(fp_type)
@ -41,16 +40,17 @@ class SSHFP(dns.rdata.Rdata):
def to_text(self, origin=None, relativize=True, **kw):
kw = kw.copy()
chunksize = kw.pop('chunksize', 128)
return '%d %d %s' % (self.algorithm,
self.fp_type,
dns.rdata._hexify(self.fingerprint,
chunksize=chunksize,
**kw))
chunksize = kw.pop("chunksize", 128)
return "%d %d %s" % (
self.algorithm,
self.fp_type,
dns.rdata._hexify(self.fingerprint, chunksize=chunksize, **kw),
)
@classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
relativize_to=None):
def from_text(
cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
algorithm = tok.get_uint8()
fp_type = tok.get_uint8()
fingerprint = tok.concatenate_remaining_identifiers().encode()

View file

@ -18,7 +18,6 @@
import base64
import struct
import dns.dnssec
import dns.immutable
import dns.exception
import dns.rdata
@ -29,11 +28,28 @@ class TKEY(dns.rdata.Rdata):
"""TKEY Record"""
__slots__ = ['algorithm', 'inception', 'expiration', 'mode', 'error',
'key', 'other']
__slots__ = [
"algorithm",
"inception",
"expiration",
"mode",
"error",
"key",
"other",
]
def __init__(self, rdclass, rdtype, algorithm, inception, expiration,
mode, error, key, other=b''):
def __init__(
self,
rdclass,
rdtype,
algorithm,
inception,
expiration,
mode,
error,
key,
other=b"",
):
super().__init__(rdclass, rdtype)
self.algorithm = self._as_name(algorithm)
self.inception = self._as_uint32(inception)
@ -45,17 +61,23 @@ class TKEY(dns.rdata.Rdata):
def to_text(self, origin=None, relativize=True, **kw):
_algorithm = self.algorithm.choose_relativity(origin, relativize)
text = '%s %u %u %u %u %s' % (str(_algorithm), self.inception,
self.expiration, self.mode, self.error,
dns.rdata._base64ify(self.key, 0))
text = "%s %u %u %u %u %s" % (
str(_algorithm),
self.inception,
self.expiration,
self.mode,
self.error,
dns.rdata._base64ify(self.key, 0),
)
if len(self.other) > 0:
text += ' %s' % (dns.rdata._base64ify(self.other, 0))
text += " %s" % (dns.rdata._base64ify(self.other, 0))
return text
@classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
relativize_to=None):
def from_text(
cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
algorithm = tok.get_name(relativize=False)
inception = tok.get_uint32()
expiration = tok.get_uint32()
@ -66,13 +88,15 @@ class TKEY(dns.rdata.Rdata):
other_b64 = tok.concatenate_remaining_identifiers(True).encode()
other = base64.b64decode(other_b64)
return cls(rdclass, rdtype, algorithm, inception, expiration, mode,
error, key, other)
return cls(
rdclass, rdtype, algorithm, inception, expiration, mode, error, key, other
)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
self.algorithm.to_wire(file, compress, origin)
file.write(struct.pack("!IIHH", self.inception, self.expiration,
self.mode, self.error))
file.write(
struct.pack("!IIHH", self.inception, self.expiration, self.mode, self.error)
)
file.write(struct.pack("!H", len(self.key)))
file.write(self.key)
file.write(struct.pack("!H", len(self.other)))
@ -86,8 +110,9 @@ class TKEY(dns.rdata.Rdata):
key = parser.get_counted_bytes(2)
other = parser.get_counted_bytes(2)
return cls(rdclass, rdtype, algorithm, inception, expiration, mode,
error, key, other)
return cls(
rdclass, rdtype, algorithm, inception, expiration, mode, error, key, other
)
# Constants for the mode field - from RFC 2930:
# 2.5 The Mode Field

View file

@ -29,11 +29,28 @@ class TSIG(dns.rdata.Rdata):
"""TSIG record"""
__slots__ = ['algorithm', 'time_signed', 'fudge', 'mac',
'original_id', 'error', 'other']
__slots__ = [
"algorithm",
"time_signed",
"fudge",
"mac",
"original_id",
"error",
"other",
]
def __init__(self, rdclass, rdtype, algorithm, time_signed, fudge, mac,
original_id, error, other):
def __init__(
self,
rdclass,
rdtype,
algorithm,
time_signed,
fudge,
mac,
original_id,
error,
other,
):
"""Initialize a TSIG rdata.
*rdclass*, an ``int`` is the rdataclass of the Rdata.
@ -67,45 +84,60 @@ class TSIG(dns.rdata.Rdata):
def to_text(self, origin=None, relativize=True, **kw):
algorithm = self.algorithm.choose_relativity(origin, relativize)
error = dns.rcode.to_text(self.error, True)
text = f"{algorithm} {self.time_signed} {self.fudge} " + \
f"{len(self.mac)} {dns.rdata._base64ify(self.mac, 0)} " + \
f"{self.original_id} {error} {len(self.other)}"
text = (
f"{algorithm} {self.time_signed} {self.fudge} "
+ f"{len(self.mac)} {dns.rdata._base64ify(self.mac, 0)} "
+ f"{self.original_id} {error} {len(self.other)}"
)
if self.other:
text += f" {dns.rdata._base64ify(self.other, 0)}"
return text
@classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
relativize_to=None):
def from_text(
cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
algorithm = tok.get_name(relativize=False)
time_signed = tok.get_uint48()
fudge = tok.get_uint16()
mac_len = tok.get_uint16()
mac = base64.b64decode(tok.get_string())
if len(mac) != mac_len:
raise SyntaxError('invalid MAC')
raise SyntaxError("invalid MAC")
original_id = tok.get_uint16()
error = dns.rcode.from_text(tok.get_string())
other_len = tok.get_uint16()
if other_len > 0:
other = base64.b64decode(tok.get_string())
if len(other) != other_len:
raise SyntaxError('invalid other data')
raise SyntaxError("invalid other data")
else:
other = b''
return cls(rdclass, rdtype, algorithm, time_signed, fudge, mac,
original_id, error, other)
other = b""
return cls(
rdclass,
rdtype,
algorithm,
time_signed,
fudge,
mac,
original_id,
error,
other,
)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
self.algorithm.to_wire(file, None, origin, False)
file.write(struct.pack('!HIHH',
(self.time_signed >> 32) & 0xffff,
self.time_signed & 0xffffffff,
self.fudge,
len(self.mac)))
file.write(
struct.pack(
"!HIHH",
(self.time_signed >> 32) & 0xFFFF,
self.time_signed & 0xFFFFFFFF,
self.fudge,
len(self.mac),
)
)
file.write(self.mac)
file.write(struct.pack('!HHH', self.original_id, self.error,
len(self.other)))
file.write(struct.pack("!HHH", self.original_id, self.error, len(self.other)))
file.write(self.other)
@classmethod
@ -114,7 +146,16 @@ class TSIG(dns.rdata.Rdata):
time_signed = parser.get_uint48()
fudge = parser.get_uint16()
mac = parser.get_counted_bytes(2)
(original_id, error) = parser.get_struct('!HH')
(original_id, error) = parser.get_struct("!HH")
other = parser.get_counted_bytes(2)
return cls(rdclass, rdtype, algorithm, time_signed, fudge, mac,
original_id, error, other)
return cls(
rdclass,
rdtype,
algorithm,
time_signed,
fudge,
mac,
original_id,
error,
other,
)

View file

@ -32,7 +32,7 @@ class URI(dns.rdata.Rdata):
# see RFC 7553
__slots__ = ['priority', 'weight', 'target']
__slots__ = ["priority", "weight", "target"]
def __init__(self, rdclass, rdtype, priority, weight, target):
super().__init__(rdclass, rdtype)
@ -43,12 +43,12 @@ class URI(dns.rdata.Rdata):
raise dns.exception.SyntaxError("URI target cannot be empty")
def to_text(self, origin=None, relativize=True, **kw):
return '%d %d "%s"' % (self.priority, self.weight,
self.target.decode())
return '%d %d "%s"' % (self.priority, self.weight, self.target.decode())
@classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
relativize_to=None):
def from_text(
cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
priority = tok.get_uint16()
weight = tok.get_uint16()
target = tok.get().unescape()
@ -63,10 +63,10 @@ class URI(dns.rdata.Rdata):
@classmethod
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
(priority, weight) = parser.get_struct('!HH')
(priority, weight) = parser.get_struct("!HH")
target = parser.get_remaining()
if len(target) == 0:
raise dns.exception.FormError('URI target may not be empty')
raise dns.exception.FormError("URI target may not be empty")
return cls(rdclass, rdtype, priority, weight, target)
def _processing_priority(self):

View file

@ -30,7 +30,7 @@ class X25(dns.rdata.Rdata):
# see RFC 1183
__slots__ = ['address']
__slots__ = ["address"]
def __init__(self, rdclass, rdtype, address):
super().__init__(rdclass, rdtype)
@ -40,15 +40,16 @@ class X25(dns.rdata.Rdata):
return '"%s"' % dns.rdata._escapify(self.address)
@classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
relativize_to=None):
def from_text(
cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
address = tok.get_string()
return cls(rdclass, rdtype, address)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
l = len(self.address)
assert l < 256
file.write(struct.pack('!B', l))
file.write(struct.pack("!B", l))
file.write(self.address)
@classmethod

View file

@ -6,7 +6,7 @@ import binascii
import dns.immutable
import dns.rdata
import dns.rdatatype
import dns.zone
import dns.zonetypes
@dns.immutable.immutable
@ -16,35 +16,38 @@ class ZONEMD(dns.rdata.Rdata):
# See RFC 8976
__slots__ = ['serial', 'scheme', 'hash_algorithm', 'digest']
__slots__ = ["serial", "scheme", "hash_algorithm", "digest"]
def __init__(self, rdclass, rdtype, serial, scheme, hash_algorithm, digest):
super().__init__(rdclass, rdtype)
self.serial = self._as_uint32(serial)
self.scheme = dns.zone.DigestScheme.make(scheme)
self.hash_algorithm = dns.zone.DigestHashAlgorithm.make(hash_algorithm)
self.scheme = dns.zonetypes.DigestScheme.make(scheme)
self.hash_algorithm = dns.zonetypes.DigestHashAlgorithm.make(hash_algorithm)
self.digest = self._as_bytes(digest)
if self.scheme == 0: # reserved, RFC 8976 Sec. 5.2
raise ValueError('scheme 0 is reserved')
raise ValueError("scheme 0 is reserved")
if self.hash_algorithm == 0: # reserved, RFC 8976 Sec. 5.3
raise ValueError('hash_algorithm 0 is reserved')
raise ValueError("hash_algorithm 0 is reserved")
hasher = dns.zone._digest_hashers.get(self.hash_algorithm)
hasher = dns.zonetypes._digest_hashers.get(self.hash_algorithm)
if hasher and hasher().digest_size != len(self.digest):
raise ValueError('digest length inconsistent with hash algorithm')
raise ValueError("digest length inconsistent with hash algorithm")
def to_text(self, origin=None, relativize=True, **kw):
kw = kw.copy()
chunksize = kw.pop('chunksize', 128)
return '%d %d %d %s' % (self.serial, self.scheme, self.hash_algorithm,
dns.rdata._hexify(self.digest,
chunksize=chunksize,
**kw))
chunksize = kw.pop("chunksize", 128)
return "%d %d %d %s" % (
self.serial,
self.scheme,
self.hash_algorithm,
dns.rdata._hexify(self.digest, chunksize=chunksize, **kw),
)
@classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
relativize_to=None):
def from_text(
cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
serial = tok.get_uint32()
scheme = tok.get_uint8()
hash_algorithm = tok.get_uint8()
@ -53,8 +56,7 @@ class ZONEMD(dns.rdata.Rdata):
return cls(rdclass, rdtype, serial, scheme, hash_algorithm, digest)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
header = struct.pack("!IBB", self.serial, self.scheme,
self.hash_algorithm)
header = struct.pack("!IBB", self.serial, self.scheme, self.hash_algorithm)
file.write(header)
file.write(self.digest)

View file

@ -18,51 +18,51 @@
"""Class ANY (generic) rdata type classes."""
__all__ = [
'AFSDB',
'AMTRELAY',
'AVC',
'CAA',
'CDNSKEY',
'CDS',
'CERT',
'CNAME',
'CSYNC',
'DLV',
'DNAME',
'DNSKEY',
'DS',
'EUI48',
'EUI64',
'GPOS',
'HINFO',
'HIP',
'ISDN',
'L32',
'L64',
'LOC',
'LP',
'MX',
'NID',
'NINFO',
'NS',
'NSEC',
'NSEC3',
'NSEC3PARAM',
'OPENPGPKEY',
'OPT',
'PTR',
'RP',
'RRSIG',
'RT',
'SMIMEA',
'SOA',
'SPF',
'SSHFP',
'TKEY',
'TLSA',
'TSIG',
'TXT',
'URI',
'X25',
'ZONEMD',
"AFSDB",
"AMTRELAY",
"AVC",
"CAA",
"CDNSKEY",
"CDS",
"CERT",
"CNAME",
"CSYNC",
"DLV",
"DNAME",
"DNSKEY",
"DS",
"EUI48",
"EUI64",
"GPOS",
"HINFO",
"HIP",
"ISDN",
"L32",
"L64",
"LOC",
"LP",
"MX",
"NID",
"NINFO",
"NS",
"NSEC",
"NSEC3",
"NSEC3PARAM",
"OPENPGPKEY",
"OPT",
"PTR",
"RP",
"RRSIG",
"RT",
"SMIMEA",
"SOA",
"SPF",
"SSHFP",
"TKEY",
"TLSA",
"TSIG",
"TXT",
"URI",
"X25",
"ZONEMD",
]

View file

@ -20,6 +20,7 @@ import struct
import dns.rdtypes.mxbase
import dns.immutable
@dns.immutable.immutable
class A(dns.rdata.Rdata):
@ -28,7 +29,7 @@ class A(dns.rdata.Rdata):
# domain: the domain of the address
# address: the 16-bit address
__slots__ = ['domain', 'address']
__slots__ = ["domain", "address"]
def __init__(self, rdclass, rdtype, domain, address):
super().__init__(rdclass, rdtype)
@ -37,11 +38,12 @@ class A(dns.rdata.Rdata):
def to_text(self, origin=None, relativize=True, **kw):
domain = self.domain.choose_relativity(origin, relativize)
return '%s %o' % (domain, self.address)
return "%s %o" % (domain, self.address)
@classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
relativize_to=None):
def from_text(
cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
domain = tok.get_name(origin, relativize, relativize_to)
address = tok.get_uint16(base=8)
return cls(rdclass, rdtype, domain, address)

View file

@ -18,5 +18,5 @@
"""Class CH rdata type classes."""
__all__ = [
'A',
"A",
]

View file

@ -27,7 +27,7 @@ class A(dns.rdata.Rdata):
"""A record."""
__slots__ = ['address']
__slots__ = ["address"]
def __init__(self, rdclass, rdtype, address):
super().__init__(rdclass, rdtype)
@ -37,8 +37,9 @@ class A(dns.rdata.Rdata):
return self.address
@classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
relativize_to=None):
def from_text(
cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
address = tok.get_identifier()
return cls(rdclass, rdtype, address)

View file

@ -27,7 +27,7 @@ class AAAA(dns.rdata.Rdata):
"""AAAA record."""
__slots__ = ['address']
__slots__ = ["address"]
def __init__(self, rdclass, rdtype, address):
super().__init__(rdclass, rdtype)
@ -37,8 +37,9 @@ class AAAA(dns.rdata.Rdata):
return self.address
@classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
relativize_to=None):
def from_text(
cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
address = tok.get_identifier()
return cls(rdclass, rdtype, address)

View file

@ -26,12 +26,13 @@ import dns.ipv6
import dns.rdata
import dns.tokenizer
@dns.immutable.immutable
class APLItem:
"""An APL list item."""
__slots__ = ['family', 'negation', 'address', 'prefix']
__slots__ = ["family", "negation", "address", "prefix"]
def __init__(self, family, negation, address, prefix):
self.family = dns.rdata.Rdata._as_uint16(family)
@ -67,12 +68,12 @@ class APLItem:
if address[i] != 0:
last = i + 1
break
address = address[0: last]
address = address[0:last]
l = len(address)
assert l < 128
if self.negation:
l |= 0x80
header = struct.pack('!HBB', self.family, self.prefix, l)
header = struct.pack("!HBB", self.family, self.prefix, l)
file.write(header)
file.write(address)
@ -84,32 +85,33 @@ class APL(dns.rdata.Rdata):
# see: RFC 3123
__slots__ = ['items']
__slots__ = ["items"]
def __init__(self, rdclass, rdtype, items):
super().__init__(rdclass, rdtype)
for item in items:
if not isinstance(item, APLItem):
raise ValueError('item not an APLItem')
raise ValueError("item not an APLItem")
self.items = tuple(items)
def to_text(self, origin=None, relativize=True, **kw):
return ' '.join(map(str, self.items))
return " ".join(map(str, self.items))
@classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
relativize_to=None):
def from_text(
cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
items = []
for token in tok.get_remaining():
item = token.unescape().value
if item[0] == '!':
if item[0] == "!":
negation = True
item = item[1:]
else:
negation = False
(family, rest) = item.split(':', 1)
(family, rest) = item.split(":", 1)
family = int(family)
(address, prefix) = rest.split('/', 1)
(address, prefix) = rest.split("/", 1)
prefix = int(prefix)
item = APLItem(family, negation, address, prefix)
items.append(item)
@ -125,7 +127,7 @@ class APL(dns.rdata.Rdata):
items = []
while parser.remaining() > 0:
header = parser.get_struct('!HBB')
header = parser.get_struct("!HBB")
afdlen = header[2]
if afdlen > 127:
negation = True
@ -136,16 +138,16 @@ class APL(dns.rdata.Rdata):
l = len(address)
if header[0] == 1:
if l < 4:
address += b'\x00' * (4 - l)
address += b"\x00" * (4 - l)
elif header[0] == 2:
if l < 16:
address += b'\x00' * (16 - l)
address += b"\x00" * (16 - l)
else:
#
# This isn't really right according to the RFC, but it
# seems better than throwing an exception
#
address = codecs.encode(address, 'hex_codec')
address = codecs.encode(address, "hex_codec")
item = APLItem(header[0], negation, address, header[1])
items.append(item)
return cls(rdclass, rdtype, items)

View file

@ -19,6 +19,7 @@ import base64
import dns.exception
import dns.immutable
import dns.rdata
@dns.immutable.immutable
@ -28,7 +29,7 @@ class DHCID(dns.rdata.Rdata):
# see: RFC 4701
__slots__ = ['data']
__slots__ = ["data"]
def __init__(self, rdclass, rdtype, data):
super().__init__(rdclass, rdtype)
@ -38,8 +39,9 @@ class DHCID(dns.rdata.Rdata):
return dns.rdata._base64ify(self.data, **kw)
@classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
relativize_to=None):
def from_text(
cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
b64 = tok.concatenate_remaining_identifiers().encode()
data = base64.b64decode(b64)
return cls(rdclass, rdtype, data)

View file

@ -3,6 +3,7 @@
import dns.rdtypes.svcbbase
import dns.immutable
@dns.immutable.immutable
class HTTPS(dns.rdtypes.svcbbase.SVCBBase):
"""HTTPS record"""

View file

@ -24,7 +24,8 @@ import dns.rdtypes.util
class Gateway(dns.rdtypes.util.Gateway):
name = 'IPSECKEY gateway'
name = "IPSECKEY gateway"
@dns.immutable.immutable
class IPSECKEY(dns.rdata.Rdata):
@ -33,10 +34,11 @@ class IPSECKEY(dns.rdata.Rdata):
# see: RFC 4025
__slots__ = ['precedence', 'gateway_type', 'algorithm', 'gateway', 'key']
__slots__ = ["precedence", "gateway_type", "algorithm", "gateway", "key"]
def __init__(self, rdclass, rdtype, precedence, gateway_type, algorithm,
gateway, key):
def __init__(
self, rdclass, rdtype, precedence, gateway_type, algorithm, gateway, key
):
super().__init__(rdclass, rdtype)
gateway = Gateway(gateway_type, gateway)
self.precedence = self._as_uint8(precedence)
@ -46,38 +48,45 @@ class IPSECKEY(dns.rdata.Rdata):
self.key = self._as_bytes(key)
def to_text(self, origin=None, relativize=True, **kw):
gateway = Gateway(self.gateway_type, self.gateway).to_text(origin,
relativize)
return '%d %d %d %s %s' % (self.precedence, self.gateway_type,
self.algorithm, gateway,
dns.rdata._base64ify(self.key, **kw))
gateway = Gateway(self.gateway_type, self.gateway).to_text(origin, relativize)
return "%d %d %d %s %s" % (
self.precedence,
self.gateway_type,
self.algorithm,
gateway,
dns.rdata._base64ify(self.key, **kw),
)
@classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
relativize_to=None):
def from_text(
cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
precedence = tok.get_uint8()
gateway_type = tok.get_uint8()
algorithm = tok.get_uint8()
gateway = Gateway.from_text(gateway_type, tok, origin, relativize,
relativize_to)
gateway = Gateway.from_text(
gateway_type, tok, origin, relativize, relativize_to
)
b64 = tok.concatenate_remaining_identifiers().encode()
key = base64.b64decode(b64)
return cls(rdclass, rdtype, precedence, gateway_type, algorithm,
gateway.gateway, key)
return cls(
rdclass, rdtype, precedence, gateway_type, algorithm, gateway.gateway, key
)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
header = struct.pack("!BBB", self.precedence, self.gateway_type,
self.algorithm)
header = struct.pack("!BBB", self.precedence, self.gateway_type, self.algorithm)
file.write(header)
Gateway(self.gateway_type, self.gateway).to_wire(file, compress,
origin, canonicalize)
Gateway(self.gateway_type, self.gateway).to_wire(
file, compress, origin, canonicalize
)
file.write(self.key)
@classmethod
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
header = parser.get_struct('!BBB')
header = parser.get_struct("!BBB")
gateway_type = header[1]
gateway = Gateway.from_wire_parser(gateway_type, parser, origin)
key = parser.get_remaining()
return cls(rdclass, rdtype, header[0], gateway_type, header[2],
gateway.gateway, key)
return cls(
rdclass, rdtype, header[0], gateway_type, header[2], gateway.gateway, key
)

View file

@ -27,7 +27,7 @@ import dns.rdtypes.util
def _write_string(file, s):
l = len(s)
assert l < 256
file.write(struct.pack('!B', l))
file.write(struct.pack("!B", l))
file.write(s)
@ -38,11 +38,11 @@ class NAPTR(dns.rdata.Rdata):
# see: RFC 3403
__slots__ = ['order', 'preference', 'flags', 'service', 'regexp',
'replacement']
__slots__ = ["order", "preference", "flags", "service", "regexp", "replacement"]
def __init__(self, rdclass, rdtype, order, preference, flags, service,
regexp, replacement):
def __init__(
self, rdclass, rdtype, order, preference, flags, service, regexp, replacement
):
super().__init__(rdclass, rdtype)
self.flags = self._as_bytes(flags, True, 255)
self.service = self._as_bytes(service, True, 255)
@ -53,24 +53,28 @@ class NAPTR(dns.rdata.Rdata):
def to_text(self, origin=None, relativize=True, **kw):
replacement = self.replacement.choose_relativity(origin, relativize)
return '%d %d "%s" "%s" "%s" %s' % \
(self.order, self.preference,
dns.rdata._escapify(self.flags),
dns.rdata._escapify(self.service),
dns.rdata._escapify(self.regexp),
replacement)
return '%d %d "%s" "%s" "%s" %s' % (
self.order,
self.preference,
dns.rdata._escapify(self.flags),
dns.rdata._escapify(self.service),
dns.rdata._escapify(self.regexp),
replacement,
)
@classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
relativize_to=None):
def from_text(
cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
order = tok.get_uint16()
preference = tok.get_uint16()
flags = tok.get_string()
service = tok.get_string()
regexp = tok.get_string()
replacement = tok.get_name(origin, relativize, relativize_to)
return cls(rdclass, rdtype, order, preference, flags, service,
regexp, replacement)
return cls(
rdclass, rdtype, order, preference, flags, service, regexp, replacement
)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
two_ints = struct.pack("!HH", self.order, self.preference)
@ -82,14 +86,22 @@ class NAPTR(dns.rdata.Rdata):
@classmethod
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
(order, preference) = parser.get_struct('!HH')
(order, preference) = parser.get_struct("!HH")
strings = []
for _ in range(3):
s = parser.get_counted_bytes()
strings.append(s)
replacement = parser.get_name(origin)
return cls(rdclass, rdtype, order, preference, strings[0], strings[1],
strings[2], replacement)
return cls(
rdclass,
rdtype,
order,
preference,
strings[0],
strings[1],
strings[2],
replacement,
)
def _processing_priority(self):
return (self.order, self.preference)

View file

@ -30,7 +30,7 @@ class NSAP(dns.rdata.Rdata):
# see: RFC 1706
__slots__ = ['address']
__slots__ = ["address"]
def __init__(self, rdclass, rdtype, address):
super().__init__(rdclass, rdtype)
@ -40,14 +40,15 @@ class NSAP(dns.rdata.Rdata):
return "0x%s" % binascii.hexlify(self.address).decode()
@classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
relativize_to=None):
def from_text(
cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
address = tok.get_string()
if address[0:2] != '0x':
raise dns.exception.SyntaxError('string does not start with 0x')
address = address[2:].replace('.', '')
if address[0:2] != "0x":
raise dns.exception.SyntaxError("string does not start with 0x")
address = address[2:].replace(".", "")
if len(address) % 2 != 0:
raise dns.exception.SyntaxError('hexstring has odd length')
raise dns.exception.SyntaxError("hexstring has odd length")
address = binascii.unhexlify(address.encode())
return cls(rdclass, rdtype, address)

View file

@ -31,7 +31,7 @@ class PX(dns.rdata.Rdata):
# see: RFC 2163
__slots__ = ['preference', 'map822', 'mapx400']
__slots__ = ["preference", "map822", "mapx400"]
def __init__(self, rdclass, rdtype, preference, map822, mapx400):
super().__init__(rdclass, rdtype)
@ -42,11 +42,12 @@ class PX(dns.rdata.Rdata):
def to_text(self, origin=None, relativize=True, **kw):
map822 = self.map822.choose_relativity(origin, relativize)
mapx400 = self.mapx400.choose_relativity(origin, relativize)
return '%d %s %s' % (self.preference, map822, mapx400)
return "%d %s %s" % (self.preference, map822, mapx400)
@classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
relativize_to=None):
def from_text(
cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
preference = tok.get_uint16()
map822 = tok.get_name(origin, relativize, relativize_to)
mapx400 = tok.get_name(origin, relativize, relativize_to)

View file

@ -31,7 +31,7 @@ class SRV(dns.rdata.Rdata):
# see: RFC 2782
__slots__ = ['priority', 'weight', 'port', 'target']
__slots__ = ["priority", "weight", "port", "target"]
def __init__(self, rdclass, rdtype, priority, weight, port, target):
super().__init__(rdclass, rdtype)
@ -42,12 +42,12 @@ class SRV(dns.rdata.Rdata):
def to_text(self, origin=None, relativize=True, **kw):
target = self.target.choose_relativity(origin, relativize)
return '%d %d %d %s' % (self.priority, self.weight, self.port,
target)
return "%d %d %d %s" % (self.priority, self.weight, self.port, target)
@classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
relativize_to=None):
def from_text(
cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
priority = tok.get_uint16()
weight = tok.get_uint16()
port = tok.get_uint16()
@ -61,7 +61,7 @@ class SRV(dns.rdata.Rdata):
@classmethod
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
(priority, weight, port) = parser.get_struct('!HHH')
(priority, weight, port) = parser.get_struct("!HHH")
target = parser.get_name(origin)
return cls(rdclass, rdtype, priority, weight, port, target)

View file

@ -3,6 +3,7 @@
import dns.rdtypes.svcbbase
import dns.immutable
@dns.immutable.immutable
class SVCB(dns.rdtypes.svcbbase.SVCBBase):
"""SVCB record"""

View file

@ -23,13 +23,14 @@ import dns.immutable
import dns.rdata
try:
_proto_tcp = socket.getprotobyname('tcp')
_proto_udp = socket.getprotobyname('udp')
_proto_tcp = socket.getprotobyname("tcp")
_proto_udp = socket.getprotobyname("udp")
except OSError:
# Fall back to defaults in case /etc/protocols is unavailable.
_proto_tcp = 6
_proto_udp = 17
@dns.immutable.immutable
class WKS(dns.rdata.Rdata):
@ -37,7 +38,7 @@ class WKS(dns.rdata.Rdata):
# see: RFC 1035
__slots__ = ['address', 'protocol', 'bitmap']
__slots__ = ["address", "protocol", "bitmap"]
def __init__(self, rdclass, rdtype, address, protocol, bitmap):
super().__init__(rdclass, rdtype)
@ -51,12 +52,13 @@ class WKS(dns.rdata.Rdata):
for j in range(0, 8):
if byte & (0x80 >> j):
bits.append(str(i * 8 + j))
text = ' '.join(bits)
return '%s %d %s' % (self.address, self.protocol, text)
text = " ".join(bits)
return "%s %d %s" % (self.address, self.protocol, text)
@classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
relativize_to=None):
def from_text(
cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
address = tok.get_string()
protocol = tok.get_string()
if protocol.isdigit():
@ -87,7 +89,7 @@ class WKS(dns.rdata.Rdata):
def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
file.write(dns.ipv4.inet_aton(self.address))
protocol = struct.pack('!B', self.protocol)
protocol = struct.pack("!B", self.protocol)
file.write(protocol)
file.write(self.bitmap)

View file

@ -18,18 +18,18 @@
"""Class IN rdata type classes."""
__all__ = [
'A',
'AAAA',
'APL',
'DHCID',
'HTTPS',
'IPSECKEY',
'KX',
'NAPTR',
'NSAP',
'NSAP_PTR',
'PX',
'SRV',
'SVCB',
'WKS',
"A",
"AAAA",
"APL",
"DHCID",
"HTTPS",
"IPSECKEY",
"KX",
"NAPTR",
"NSAP",
"NSAP_PTR",
"PX",
"SRV",
"SVCB",
"WKS",
]

View file

@ -18,16 +18,16 @@
"""DNS rdata type classes"""
__all__ = [
'ANY',
'IN',
'CH',
'dnskeybase',
'dsbase',
'euibase',
'mxbase',
'nsbase',
'svcbbase',
'tlsabase',
'txtbase',
'util'
"ANY",
"IN",
"CH",
"dnskeybase",
"dsbase",
"euibase",
"mxbase",
"nsbase",
"svcbbase",
"tlsabase",
"txtbase",
"util",
]

View file

@ -21,11 +21,12 @@ import struct
import dns.exception
import dns.immutable
import dns.dnssec
import dns.dnssectypes
import dns.rdata
# wildcard import
__all__ = ["SEP", "REVOKE", "ZONE"] # noqa: F822
__all__ = ["SEP", "REVOKE", "ZONE"] # noqa: F822
class Flag(enum.IntFlag):
SEP = 0x0001
@ -38,22 +39,27 @@ class DNSKEYBase(dns.rdata.Rdata):
"""Base class for rdata that is like a DNSKEY record"""
__slots__ = ['flags', 'protocol', 'algorithm', 'key']
__slots__ = ["flags", "protocol", "algorithm", "key"]
def __init__(self, rdclass, rdtype, flags, protocol, algorithm, key):
super().__init__(rdclass, rdtype)
self.flags = self._as_uint16(flags)
self.protocol = self._as_uint8(protocol)
self.algorithm = dns.dnssec.Algorithm.make(algorithm)
self.algorithm = dns.dnssectypes.Algorithm.make(algorithm)
self.key = self._as_bytes(key)
def to_text(self, origin=None, relativize=True, **kw):
return '%d %d %d %s' % (self.flags, self.protocol, self.algorithm,
dns.rdata._base64ify(self.key, **kw))
return "%d %d %d %s" % (
self.flags,
self.protocol,
self.algorithm,
dns.rdata._base64ify(self.key, **kw),
)
@classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
relativize_to=None):
def from_text(
cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None
):
flags = tok.get_uint16()
protocol = tok.get_uint8()
algorithm = tok.get_string()
@ -68,10 +74,10 @@ class DNSKEYBase(dns.rdata.Rdata):
@classmethod
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
header = parser.get_struct('!HBB')
header = parser.get_struct("!HBB")
key = parser.get_remaining()
return cls(rdclass, rdtype, header[0], header[1], header[2],
key)
return cls(rdclass, rdtype, header[0], header[1], header[2], key)
### BEGIN generated Flag constants

Some files were not shown because too many files have changed in this diff Show more