Bump dnspython from 2.0.0 to 2.2.0 (#1618)

* Bump dnspython from 2.0.0 to 2.2.0

Bumps [dnspython]() from 2.0.0 to 2.2.0.

---
updated-dependencies:
- dependency-name: dnspython
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>

* Update dnspython==2.2.0

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: JonnyWong16 <9099342+JonnyWong16@users.noreply.github.com>

[skip ci]
This commit is contained in:
dependabot[bot] 2022-01-25 11:08:24 -08:00 committed by GitHub
parent 515a5d42d3
commit 3c93b5600f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
143 changed files with 7498 additions and 2054 deletions

View file

@ -27,6 +27,7 @@ __all__ = [
'entropy', 'entropy',
'exception', 'exception',
'flags', 'flags',
'immutable',
'inet', 'inet',
'ipv4', 'ipv4',
'ipv6', 'ipv6',
@ -48,14 +49,18 @@ __all__ = [
'serial', 'serial',
'set', 'set',
'tokenizer', 'tokenizer',
'transaction',
'tsig', 'tsig',
'tsigkeyring', 'tsigkeyring',
'ttl', 'ttl',
'rdtypes', 'rdtypes',
'update', 'update',
'version', 'version',
'versioned',
'wire', 'wire',
'xfr',
'zone', 'zone',
'zonefile',
] ]
from dns.version import version as __version__ # noqa from dns.version import version as __version__ # noqa

View file

@ -27,6 +27,12 @@ class Socket: # pragma: no cover
async def close(self): async def close(self):
pass pass
async def getpeername(self):
raise NotImplementedError
async def getsockname(self):
raise NotImplementedError
async def __aenter__(self): async def __aenter__(self):
return self return self
@ -36,18 +42,18 @@ class Socket: # pragma: no cover
class DatagramSocket(Socket): # pragma: no cover class DatagramSocket(Socket): # pragma: no cover
async def sendto(self, what, destination, timeout): async def sendto(self, what, destination, timeout):
pass raise NotImplementedError
async def recvfrom(self, size, timeout): async def recvfrom(self, size, timeout):
pass raise NotImplementedError
class StreamSocket(Socket): # pragma: no cover class StreamSocket(Socket): # pragma: no cover
async def sendall(self, what, destination, timeout): async def sendall(self, what, timeout):
pass raise NotImplementedError
async def recv(self, size, timeout): async def recv(self, size, timeout):
pass raise NotImplementedError
class Backend: # pragma: no cover class Backend: # pragma: no cover
@ -58,3 +64,6 @@ class Backend: # pragma: no cover
source=None, destination=None, timeout=None, source=None, destination=None, timeout=None,
ssl_context=None, server_hostname=None): ssl_context=None, server_hostname=None):
raise NotImplementedError raise NotImplementedError
def datagram_connection_required(self):
return False

View file

@ -4,11 +4,14 @@
import socket import socket
import asyncio import asyncio
import sys
import dns._asyncbackend import dns._asyncbackend
import dns.exception import dns.exception
_is_win32 = sys.platform == 'win32'
def _get_running_loop(): def _get_running_loop():
try: try:
return asyncio.get_running_loop() return asyncio.get_running_loop()
@ -25,16 +28,16 @@ class _DatagramProtocol:
self.transport = transport self.transport = transport
def datagram_received(self, data, addr): def datagram_received(self, data, addr):
if self.recvfrom: if self.recvfrom and not self.recvfrom.done():
self.recvfrom.set_result((data, addr)) self.recvfrom.set_result((data, addr))
self.recvfrom = None self.recvfrom = None
def error_received(self, exc): # pragma: no cover def error_received(self, exc): # pragma: no cover
if self.recvfrom: if self.recvfrom and not self.recvfrom.done():
self.recvfrom.set_exception(exc) self.recvfrom.set_exception(exc)
def connection_lost(self, exc): def connection_lost(self, exc):
if self.recvfrom: if self.recvfrom and not self.recvfrom.done():
self.recvfrom.set_exception(exc) self.recvfrom.set_exception(exc)
def close(self): def close(self):
@ -79,21 +82,19 @@ class DatagramSocket(dns._asyncbackend.DatagramSocket):
return self.transport.get_extra_info('sockname') return self.transport.get_extra_info('sockname')
class StreamSocket(dns._asyncbackend.DatagramSocket): class StreamSocket(dns._asyncbackend.StreamSocket):
def __init__(self, af, reader, writer): def __init__(self, af, reader, writer):
self.family = af self.family = af
self.reader = reader self.reader = reader
self.writer = writer self.writer = writer
async def sendall(self, what, timeout): async def sendall(self, what, timeout):
self.writer.write(what), self.writer.write(what)
return await _maybe_wait_for(self.writer.drain(), timeout) return await _maybe_wait_for(self.writer.drain(), timeout)
raise dns.exception.Timeout(timeout=timeout)
async def recv(self, count, timeout): async def recv(self, size, timeout):
return await _maybe_wait_for(self.reader.read(count), return await _maybe_wait_for(self.reader.read(size),
timeout) timeout)
raise dns.exception.Timeout(timeout=timeout)
async def close(self): async def close(self):
self.writer.close() self.writer.close()
@ -116,11 +117,16 @@ class Backend(dns._asyncbackend.Backend):
async def make_socket(self, af, socktype, proto=0, async def make_socket(self, af, socktype, proto=0,
source=None, destination=None, timeout=None, source=None, destination=None, timeout=None,
ssl_context=None, server_hostname=None): ssl_context=None, server_hostname=None):
if destination is None and socktype == socket.SOCK_DGRAM and \
_is_win32:
raise NotImplementedError('destinationless datagram sockets '
'are not supported by asyncio '
'on Windows')
loop = _get_running_loop() loop = _get_running_loop()
if socktype == socket.SOCK_DGRAM: if socktype == socket.SOCK_DGRAM:
transport, protocol = await loop.create_datagram_endpoint( transport, protocol = await loop.create_datagram_endpoint(
_DatagramProtocol, source, family=af, _DatagramProtocol, source, family=af,
proto=proto) proto=proto, remote_addr=destination)
return DatagramSocket(af, transport, protocol) return DatagramSocket(af, transport, protocol)
elif socktype == socket.SOCK_STREAM: elif socktype == socket.SOCK_STREAM:
(r, w) = await _maybe_wait_for( (r, w) = await _maybe_wait_for(
@ -138,3 +144,6 @@ class Backend(dns._asyncbackend.Backend):
async def sleep(self, interval): async def sleep(self, interval):
await asyncio.sleep(interval) await asyncio.sleep(interval)
def datagram_connection_required(self):
return _is_win32

View file

@ -21,6 +21,8 @@ def _maybe_timeout(timeout):
# for brevity # for brevity
_lltuple = dns.inet.low_level_address_tuple _lltuple = dns.inet.low_level_address_tuple
# pylint: disable=redefined-outer-name
class DatagramSocket(dns._asyncbackend.DatagramSocket): class DatagramSocket(dns._asyncbackend.DatagramSocket):
def __init__(self, socket): def __init__(self, socket):
@ -47,7 +49,7 @@ class DatagramSocket(dns._asyncbackend.DatagramSocket):
return self.socket.getsockname() return self.socket.getsockname()
class StreamSocket(dns._asyncbackend.DatagramSocket): class StreamSocket(dns._asyncbackend.StreamSocket):
def __init__(self, socket): def __init__(self, socket):
self.socket = socket self.socket = socket
self.family = socket.family self.family = socket.family

View file

@ -0,0 +1,84 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# This implementation of the immutable decorator is for python 3.6,
# which doesn't have Context Variables. This implementation is somewhat
# costly for classes with slots, as it adds a __dict__ to them.
import inspect
class _Immutable:
"""Immutable mixin class"""
# Note we MUST NOT have __slots__ as that causes
#
# TypeError: multiple bases have instance lay-out conflict
#
# when we get mixed in with another class with slots. When we
# get mixed into something with slots, it effectively adds __dict__ to
# the slots of the other class, which allows attribute setting to work,
# albeit at the cost of the dictionary.
def __setattr__(self, name, value):
if not hasattr(self, '_immutable_init') or \
self._immutable_init is not self:
raise TypeError("object doesn't support attribute assignment")
else:
super().__setattr__(name, value)
def __delattr__(self, name):
if not hasattr(self, '_immutable_init') or \
self._immutable_init is not self:
raise TypeError("object doesn't support attribute assignment")
else:
super().__delattr__(name)
def _immutable_init(f):
def nf(*args, **kwargs):
try:
# Are we already initializing an immutable class?
previous = args[0]._immutable_init
except AttributeError:
# We are the first!
previous = None
object.__setattr__(args[0], '_immutable_init', args[0])
try:
# call the actual __init__
f(*args, **kwargs)
finally:
if not previous:
# If we started the initialzation, establish immutability
# by removing the attribute that allows mutation
object.__delattr__(args[0], '_immutable_init')
nf.__signature__ = inspect.signature(f)
return nf
def immutable(cls):
if _Immutable in cls.__mro__:
# Some ancestor already has the mixin, so just make sure we keep
# following the __init__ protocol.
cls.__init__ = _immutable_init(cls.__init__)
if hasattr(cls, '__setstate__'):
cls.__setstate__ = _immutable_init(cls.__setstate__)
ncls = cls
else:
# Mixin the Immutable class and follow the __init__ protocol.
class ncls(_Immutable, cls):
@_immutable_init
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if hasattr(cls, '__setstate__'):
@_immutable_init
def __setstate__(self, *args, **kwargs):
super().__setstate__(*args, **kwargs)
# make ncls have the same name and module as cls
ncls.__name__ = cls.__name__
ncls.__qualname__ = cls.__qualname__
ncls.__module__ = cls.__module__
return ncls

75
lib/dns/_immutable_ctx.py Normal file
View file

@ -0,0 +1,75 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# This implementation of the immutable decorator requires python >=
# 3.7, and is significantly more storage efficient when making classes
# with slots immutable. It's also faster.
import contextvars
import inspect
_in__init__ = contextvars.ContextVar('_immutable_in__init__', default=False)
class _Immutable:
"""Immutable mixin class"""
# We set slots to the empty list to say "we don't have any attributes".
# We do this so that if we're mixed in with a class with __slots__, we
# don't cause a __dict__ to be added which would waste space.
__slots__ = ()
def __setattr__(self, name, value):
if _in__init__.get() is not self:
raise TypeError("object doesn't support attribute assignment")
else:
super().__setattr__(name, value)
def __delattr__(self, name):
if _in__init__.get() is not self:
raise TypeError("object doesn't support attribute assignment")
else:
super().__delattr__(name)
def _immutable_init(f):
def nf(*args, **kwargs):
previous = _in__init__.set(args[0])
try:
# call the actual __init__
f(*args, **kwargs)
finally:
_in__init__.reset(previous)
nf.__signature__ = inspect.signature(f)
return nf
def immutable(cls):
if _Immutable in cls.__mro__:
# Some ancestor already has the mixin, so just make sure we keep
# following the __init__ protocol.
cls.__init__ = _immutable_init(cls.__init__)
if hasattr(cls, '__setstate__'):
cls.__setstate__ = _immutable_init(cls.__setstate__)
ncls = cls
else:
# Mixin the Immutable class and follow the __init__ protocol.
class ncls(_Immutable, cls):
# We have to do the __slots__ declaration here too!
__slots__ = ()
@_immutable_init
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if hasattr(cls, '__setstate__'):
@_immutable_init
def __setstate__(self, *args, **kwargs):
super().__setstate__(*args, **kwargs)
# make ncls have the same name and module as cls
ncls.__name__ = cls.__name__
ncls.__qualname__ = cls.__qualname__
ncls.__module__ = cls.__module__
return ncls

View file

@ -21,6 +21,8 @@ def _maybe_timeout(timeout):
# for brevity # for brevity
_lltuple = dns.inet.low_level_address_tuple _lltuple = dns.inet.low_level_address_tuple
# pylint: disable=redefined-outer-name
class DatagramSocket(dns._asyncbackend.DatagramSocket): class DatagramSocket(dns._asyncbackend.DatagramSocket):
def __init__(self, socket): def __init__(self, socket):
@ -47,7 +49,7 @@ class DatagramSocket(dns._asyncbackend.DatagramSocket):
return self.socket.getsockname() return self.socket.getsockname()
class StreamSocket(dns._asyncbackend.DatagramSocket): class StreamSocket(dns._asyncbackend.StreamSocket):
def __init__(self, family, stream, tls=False): def __init__(self, family, stream, tls=False):
self.family = family self.family = family
self.stream = stream self.stream = stream

View file

@ -2,9 +2,12 @@
import dns.exception import dns.exception
# pylint: disable=unused-import
from dns._asyncbackend import Socket, DatagramSocket, \ from dns._asyncbackend import Socket, DatagramSocket, \
StreamSocket, Backend # noqa: StreamSocket, Backend # noqa:
# pylint: enable=unused-import
_default_backend = None _default_backend = None
@ -18,13 +21,14 @@ class AsyncLibraryNotFoundError(dns.exception.DNSException):
def get_backend(name): def get_backend(name):
"""Get the specified asychronous backend. """Get the specified asynchronous backend.
*name*, a ``str``, the name of the backend. Currently the "trio", *name*, a ``str``, the name of the backend. Currently the "trio",
"curio", and "asyncio" backends are available. "curio", and "asyncio" backends are available.
Raises NotImplementError if an unknown backend name is specified. Raises NotImplementError if an unknown backend name is specified.
""" """
# pylint: disable=import-outside-toplevel,redefined-outer-name
backend = _backends.get(name) backend = _backends.get(name)
if backend: if backend:
return backend return backend
@ -50,6 +54,7 @@ def sniff():
Returns the name of the library, or raises AsyncLibraryNotFoundError Returns the name of the library, or raises AsyncLibraryNotFoundError
if the library cannot be determined. if the library cannot be determined.
""" """
# pylint: disable=import-outside-toplevel
try: try:
if _no_sniffio: if _no_sniffio:
raise ImportError raise ImportError

13
lib/dns/asyncbackend.pyi Normal file
View file

@ -0,0 +1,13 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
class Backend:
...
def get_backend(name: str) -> Backend:
...
def sniff() -> str:
...
def get_default_backend() -> Backend:
...
def set_default_backend(name: str) -> Backend:
...

View file

@ -17,6 +17,7 @@
"""Talk to a DNS server.""" """Talk to a DNS server."""
import base64
import socket import socket
import struct import struct
import time import time
@ -30,8 +31,11 @@ import dns.rcode
import dns.rdataclass import dns.rdataclass
import dns.rdatatype import dns.rdatatype
from dns.query import _compute_times, _matches_destination, BadResponse, ssl from dns.query import _compute_times, _matches_destination, BadResponse, ssl, \
UDPMode, _have_httpx, _have_http2, NoDOH
if _have_httpx:
import httpx
# for brevity # for brevity
_lltuple = dns.inet.low_level_address_tuple _lltuple = dns.inet.low_level_address_tuple
@ -94,36 +98,8 @@ async def receive_udp(sock, destination=None, expiration=None,
*sock*, a ``dns.asyncbackend.DatagramSocket``. *sock*, a ``dns.asyncbackend.DatagramSocket``.
*destination*, a destination tuple appropriate for the address family See :py:func:`dns.query.receive_udp()` for the documentation of the other
of the socket, specifying where the message is expected to arrive from. parameters, exceptions, and return type of this method.
When receiving a response, this would be where the associated query was
sent.
*expiration*, a ``float`` or ``None``, the absolute time at which
a timeout exception should be raised. If ``None``, no timeout will
occur.
*ignore_unexpected*, a ``bool``. If ``True``, ignore responses from
unexpected sources.
*one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own
RRset.
*keyring*, a ``dict``, the keyring to use for TSIG.
*request_mac*, a ``bytes``, the MAC of the request (for TSIG).
*ignore_trailing*, a ``bool``. If ``True``, ignore trailing
junk at end of the received message.
*raise_on_truncation*, a ``bool``. If ``True``, raise an exception if
the TC bit is set.
Raises if the message is malformed, if network errors occur, of if
there is a timeout.
Returns a ``(dns.message.Message, float, tuple)`` tuple of the received
message, the received time, and the address where the message arrived from.
""" """
wire = b'' wire = b''
@ -145,34 +121,6 @@ async def udp(q, where, timeout=None, port=53, source=None, source_port=0,
backend=None): backend=None):
"""Return the response obtained after sending a query via UDP. """Return the response obtained after sending a query via UDP.
*q*, a ``dns.message.Message``, the query to send
*where*, a ``str`` containing an IPv4 or IPv6 address, where
to send the message.
*timeout*, a ``float`` or ``None``, the number of seconds to wait before the
query times out. If ``None``, the default, wait forever.
*port*, an ``int``, the port send the message to. The default is 53.
*source*, a ``str`` containing an IPv4 or IPv6 address, specifying
the source address. The default is the wildcard address.
*source_port*, an ``int``, the port from which to send the message.
The default is 0.
*ignore_unexpected*, a ``bool``. If ``True``, ignore responses from
unexpected sources.
*one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own
RRset.
*ignore_trailing*, a ``bool``. If ``True``, ignore trailing
junk at end of the received message.
*raise_on_truncation*, a ``bool``. If ``True``, raise an exception if
the TC bit is set.
*sock*, a ``dns.asyncbackend.DatagramSocket``, or ``None``, *sock*, a ``dns.asyncbackend.DatagramSocket``, or ``None``,
the socket to use for the query. If ``None``, the default, a the socket to use for the query. If ``None``, the default, a
socket is created. Note that if a socket is provided, the socket is created. Note that if a socket is provided, the
@ -181,7 +129,8 @@ async def udp(q, where, timeout=None, port=53, source=None, source_port=0,
*backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``, *backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``,
the default, then dnspython will use the default backend. the default, then dnspython will use the default backend.
Returns a ``dns.message.Message``. See :py:func:`dns.query.udp()` for the documentation of the other
parameters, exceptions, and return type of this method.
""" """
wire = q.to_wire() wire = q.to_wire()
(begin_time, expiration) = _compute_times(timeout) (begin_time, expiration) = _compute_times(timeout)
@ -196,7 +145,12 @@ async def udp(q, where, timeout=None, port=53, source=None, source_port=0,
if not backend: if not backend:
backend = dns.asyncbackend.get_default_backend() backend = dns.asyncbackend.get_default_backend()
stuple = _source_tuple(af, source, source_port) stuple = _source_tuple(af, source, source_port)
s = await backend.make_socket(af, socket.SOCK_DGRAM, 0, stuple) if backend.datagram_connection_required():
dtuple = (where, port)
else:
dtuple = None
s = await backend.make_socket(af, socket.SOCK_DGRAM, 0, stuple,
dtuple)
await send_udp(s, wire, destination, expiration) await send_udp(s, wire, destination, expiration)
(r, received_time, _) = await receive_udp(s, destination, expiration, (r, received_time, _) = await receive_udp(s, destination, expiration,
ignore_unexpected, ignore_unexpected,
@ -219,31 +173,6 @@ async def udp_with_fallback(q, where, timeout=None, port=53, source=None,
"""Return the response to the query, trying UDP first and falling back """Return the response to the query, trying UDP first and falling back
to TCP if UDP results in a truncated response. to TCP if UDP results in a truncated response.
*q*, a ``dns.message.Message``, the query to send
*where*, a ``str`` containing an IPv4 or IPv6 address, where
to send the message.
*timeout*, a ``float`` or ``None``, the number of seconds to wait before the
query times out. If ``None``, the default, wait forever.
*port*, an ``int``, the port send the message to. The default is 53.
*source*, a ``str`` containing an IPv4 or IPv6 address, specifying
the source address. The default is the wildcard address.
*source_port*, an ``int``, the port from which to send the message.
The default is 0.
*ignore_unexpected*, a ``bool``. If ``True``, ignore responses from
unexpected sources.
*one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own
RRset.
*ignore_trailing*, a ``bool``. If ``True``, ignore trailing
junk at end of the received message.
*udp_sock*, a ``dns.asyncbackend.DatagramSocket``, or ``None``, *udp_sock*, a ``dns.asyncbackend.DatagramSocket``, or ``None``,
the socket to use for the UDP query. If ``None``, the default, a the socket to use for the UDP query. If ``None``, the default, a
socket is created. Note that if a socket is provided the *source*, socket is created. Note that if a socket is provided the *source*,
@ -257,8 +186,9 @@ async def udp_with_fallback(q, where, timeout=None, port=53, source=None,
*backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``, *backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``,
the default, then dnspython will use the default backend. the default, then dnspython will use the default backend.
Returns a (``dns.message.Message``, tcp) tuple where tcp is ``True`` See :py:func:`dns.query.udp_with_fallback()` for the documentation
if and only if TCP was used. of the other parameters, exceptions, and return type of this
method.
""" """
try: try:
response = await udp(q, where, timeout, port, source, source_port, response = await udp(q, where, timeout, port, source, source_port,
@ -275,15 +205,10 @@ async def udp_with_fallback(q, where, timeout=None, port=53, source=None,
async def send_tcp(sock, what, expiration=None): async def send_tcp(sock, what, expiration=None):
"""Send a DNS message to the specified TCP socket. """Send a DNS message to the specified TCP socket.
*sock*, a ``socket``. *sock*, a ``dns.asyncbackend.StreamSocket``.
*what*, a ``bytes`` or ``dns.message.Message``, the message to send. See :py:func:`dns.query.send_tcp()` for the documentation of the other
parameters, exceptions, and return type of this method.
*expiration*, a ``float`` or ``None``, the absolute time at which
a timeout exception should be raised. If ``None``, no timeout will
occur.
Returns an ``(int, float)`` tuple of bytes sent and the sent time.
""" """
if isinstance(what, dns.message.Message): if isinstance(what, dns.message.Message):
@ -294,7 +219,7 @@ async def send_tcp(sock, what, expiration=None):
# onto the net # onto the net
tcpmsg = struct.pack("!H", l) + what tcpmsg = struct.pack("!H", l) + what
sent_time = time.time() sent_time = time.time()
await sock.sendall(tcpmsg, expiration) await sock.sendall(tcpmsg, _timeout(expiration, sent_time))
return (len(tcpmsg), sent_time) return (len(tcpmsg), sent_time)
@ -316,27 +241,10 @@ async def receive_tcp(sock, expiration=None, one_rr_per_rrset=False,
keyring=None, request_mac=b'', ignore_trailing=False): keyring=None, request_mac=b'', ignore_trailing=False):
"""Read a DNS message from a TCP socket. """Read a DNS message from a TCP socket.
*sock*, a ``socket``. *sock*, a ``dns.asyncbackend.StreamSocket``.
*expiration*, a ``float`` or ``None``, the absolute time at which See :py:func:`dns.query.receive_tcp()` for the documentation of the other
a timeout exception should be raised. If ``None``, no timeout will parameters, exceptions, and return type of this method.
occur.
*one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own
RRset.
*keyring*, a ``dict``, the keyring to use for TSIG.
*request_mac*, a ``bytes``, the MAC of the request (for TSIG).
*ignore_trailing*, a ``bool``. If ``True``, ignore trailing
junk at end of the received message.
Raises if the message is malformed, if network errors occur, of if
there is a timeout.
Returns a ``(dns.message.Message, float)`` tuple of the received message
and the received time.
""" """
ldata = await _read_exactly(sock, 2, expiration) ldata = await _read_exactly(sock, 2, expiration)
@ -354,28 +262,6 @@ async def tcp(q, where, timeout=None, port=53, source=None, source_port=0,
backend=None): backend=None):
"""Return the response obtained after sending a query via TCP. """Return the response obtained after sending a query via TCP.
*q*, a ``dns.message.Message``, the query to send
*where*, a ``str`` containing an IPv4 or IPv6 address, where
to send the message.
*timeout*, a ``float`` or ``None``, the number of seconds to wait before the
query times out. If ``None``, the default, wait forever.
*port*, an ``int``, the port send the message to. The default is 53.
*source*, a ``str`` containing an IPv4 or IPv6 address, specifying
the source address. The default is the wildcard address.
*source_port*, an ``int``, the port from which to send the message.
The default is 0.
*one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own
RRset.
*ignore_trailing*, a ``bool``. If ``True``, ignore trailing
junk at end of the received message.
*sock*, a ``dns.asyncbacket.StreamSocket``, or ``None``, the *sock*, a ``dns.asyncbacket.StreamSocket``, or ``None``, the
socket to use for the query. If ``None``, the default, a socket socket to use for the query. If ``None``, the default, a socket
is created. Note that if a socket is provided is created. Note that if a socket is provided
@ -384,7 +270,8 @@ async def tcp(q, where, timeout=None, port=53, source=None, source_port=0,
*backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``, *backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``,
the default, then dnspython will use the default backend. the default, then dnspython will use the default backend.
Returns a ``dns.message.Message``. See :py:func:`dns.query.tcp()` for the documentation of the other
parameters, exceptions, and return type of this method.
""" """
wire = q.to_wire() wire = q.to_wire()
@ -426,28 +313,6 @@ async def tls(q, where, timeout=None, port=853, source=None, source_port=0,
backend=None, ssl_context=None, server_hostname=None): backend=None, ssl_context=None, server_hostname=None):
"""Return the response obtained after sending a query via TLS. """Return the response obtained after sending a query via TLS.
*q*, a ``dns.message.Message``, the query to send
*where*, a ``str`` containing an IPv4 or IPv6 address, where
to send the message.
*timeout*, a ``float`` or ``None``, the number of seconds to wait before the
query times out. If ``None``, the default, wait forever.
*port*, an ``int``, the port send the message to. The default is 853.
*source*, a ``str`` containing an IPv4 or IPv6 address, specifying
the source address. The default is the wildcard address.
*source_port*, an ``int``, the port from which to send the message.
The default is 0.
*one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own
RRset.
*ignore_trailing*, a ``bool``. If ``True``, ignore trailing
junk at end of the received message.
*sock*, an ``asyncbackend.StreamSocket``, or ``None``, the socket *sock*, an ``asyncbackend.StreamSocket``, or ``None``, the socket
to use for the query. If ``None``, the default, a socket is to use for the query. If ``None``, the default, a socket is
created. Note that if a socket is provided, it must be a created. Note that if a socket is provided, it must be a
@ -458,15 +323,8 @@ async def tls(q, where, timeout=None, port=853, source=None, source_port=0,
*backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``, *backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``,
the default, then dnspython will use the default backend. the default, then dnspython will use the default backend.
*ssl_context*, an ``ssl.SSLContext``, the context to use when establishing See :py:func:`dns.query.tls()` for the documentation of the other
a TLS connection. If ``None``, the default, creates one with the default parameters, exceptions, and return type of this method.
configuration.
*server_hostname*, a ``str`` containing the server's hostname. The
default is ``None``, which means that no hostname is known, and if an
SSL context is created, hostname checking will be disabled.
Returns a ``dns.message.Message``.
""" """
# After 3.6 is no longer supported, this can use an AsyncExitStack. # After 3.6 is no longer supported, this can use an AsyncExitStack.
(begin_time, expiration) = _compute_times(timeout) (begin_time, expiration) = _compute_times(timeout)
@ -498,3 +356,168 @@ async def tls(q, where, timeout=None, port=853, source=None, source_port=0,
finally: finally:
if not sock and s: if not sock and s:
await s.close() await s.close()
async def https(q, where, timeout=None, port=443, source=None, source_port=0,
one_rr_per_rrset=False, ignore_trailing=False, client=None,
path='/dns-query', post=True, verify=True):
"""Return the response obtained after sending a query via DNS-over-HTTPS.
*client*, a ``httpx.AsyncClient``. If provided, the client to use for
the query.
Unlike the other dnspython async functions, a backend cannot be provided
in this function because httpx always auto-detects the async backend.
See :py:func:`dns.query.https()` for the documentation of the other
parameters, exceptions, and return type of this method.
"""
if not _have_httpx:
raise NoDOH('httpx is not available.') # pragma: no cover
wire = q.to_wire()
try:
af = dns.inet.af_for_address(where)
except ValueError:
af = None
transport = None
headers = {
"accept": "application/dns-message"
}
if af is not None:
if af == socket.AF_INET:
url = 'https://{}:{}{}'.format(where, port, path)
elif af == socket.AF_INET6:
url = 'https://[{}]:{}{}'.format(where, port, path)
else:
url = where
if source is not None:
transport = httpx.AsyncHTTPTransport(local_address=source[0])
# After 3.6 is no longer supported, this can use an AsyncExitStack
client_to_close = None
try:
if not client:
client = httpx.AsyncClient(http1=True, http2=_have_http2,
verify=verify, transport=transport)
client_to_close = client
# see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH
# GET and POST examples
if post:
headers.update({
"content-type": "application/dns-message",
"content-length": str(len(wire))
})
response = await client.post(url, headers=headers, content=wire,
timeout=timeout)
else:
wire = base64.urlsafe_b64encode(wire).rstrip(b"=")
wire = wire.decode() # httpx does a repr() if we give it bytes
response = await client.get(url, headers=headers, timeout=timeout,
params={"dns": wire})
finally:
if client_to_close:
await client.aclose()
# see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH
# status codes
if response.status_code < 200 or response.status_code > 299:
raise ValueError('{} responded with status code {}'
'\nResponse body: {}'.format(where,
response.status_code,
response.content))
r = dns.message.from_wire(response.content,
keyring=q.keyring,
request_mac=q.request_mac,
one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing)
r.time = response.elapsed
if not q.is_response(r):
raise BadResponse
return r
async def inbound_xfr(where, txn_manager, query=None,
port=53, timeout=None, lifetime=None, source=None,
source_port=0, udp_mode=UDPMode.NEVER, backend=None):
"""Conduct an inbound transfer and apply it via a transaction from the
txn_manager.
*backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``,
the default, then dnspython will use the default backend.
See :py:func:`dns.query.inbound_xfr()` for the documentation of
the other parameters, exceptions, and return type of this method.
"""
if query is None:
(query, serial) = dns.xfr.make_query(txn_manager)
else:
serial = dns.xfr.extract_serial_from_query(query)
rdtype = query.question[0].rdtype
is_ixfr = rdtype == dns.rdatatype.IXFR
origin = txn_manager.from_wire_origin()
wire = query.to_wire()
af = dns.inet.af_for_address(where)
stuple = _source_tuple(af, source, source_port)
dtuple = (where, port)
(_, expiration) = _compute_times(lifetime)
retry = True
while retry:
retry = False
if is_ixfr and udp_mode != UDPMode.NEVER:
sock_type = socket.SOCK_DGRAM
is_udp = True
else:
sock_type = socket.SOCK_STREAM
is_udp = False
if not backend:
backend = dns.asyncbackend.get_default_backend()
s = await backend.make_socket(af, sock_type, 0, stuple, dtuple,
_timeout(expiration))
async with s:
if is_udp:
await s.sendto(wire, dtuple, _timeout(expiration))
else:
tcpmsg = struct.pack("!H", len(wire)) + wire
await s.sendall(tcpmsg, expiration)
with dns.xfr.Inbound(txn_manager, rdtype, serial,
is_udp) as inbound:
done = False
tsig_ctx = None
while not done:
(_, mexpiration) = _compute_times(timeout)
if mexpiration is None or \
(expiration is not None and mexpiration > expiration):
mexpiration = expiration
if is_udp:
destination = _lltuple((where, port), af)
while True:
timeout = _timeout(mexpiration)
(rwire, from_address) = await s.recvfrom(65535,
timeout)
if _matches_destination(af, from_address,
destination, True):
break
else:
ldata = await _read_exactly(s, 2, mexpiration)
(l,) = struct.unpack("!H", ldata)
rwire = await _read_exactly(s, l, mexpiration)
is_ixfr = (rdtype == dns.rdatatype.IXFR)
r = dns.message.from_wire(rwire, keyring=query.keyring,
request_mac=query.mac, xfr=True,
origin=origin, tsig_ctx=tsig_ctx,
multi=(not is_udp),
one_rr_per_rrset=is_ixfr)
try:
done = inbound.process_message(r)
except dns.xfr.UseTCP:
assert is_udp # should not happen if we used TCP!
if udp_mode == UDPMode.ONLY:
raise
done = True
retry = True
udp_mode = UDPMode.NEVER
continue
tsig_ctx = r.tsig_ctx
if not retry and query.keyring and not r.had_tsig:
raise dns.exception.FormError("missing TSIG")

43
lib/dns/asyncquery.pyi Normal file
View file

@ -0,0 +1,43 @@
from typing import Optional, Union, Dict, Generator, Any
from . import tsig, rdatatype, rdataclass, name, message, asyncbackend
# If the ssl import works, then
#
# error: Name 'ssl' already defined (by an import)
#
# is expected and can be ignored.
try:
import ssl
except ImportError:
class ssl: # type: ignore
SSLContext : Dict = {}
async def udp(q : message.Message, where : str,
timeout : Optional[float] = None, port=53,
source : Optional[str] = None, source_port : Optional[int] = 0,
ignore_unexpected : Optional[bool] = False,
one_rr_per_rrset : Optional[bool] = False,
ignore_trailing : Optional[bool] = False,
sock : Optional[asyncbackend.DatagramSocket] = None,
backend : Optional[asyncbackend.Backend]) -> message.Message:
pass
async def tcp(q : message.Message, where : str, timeout : float = None, port=53,
af : Optional[int] = None, source : Optional[str] = None,
source_port : Optional[int] = 0,
one_rr_per_rrset : Optional[bool] = False,
ignore_trailing : Optional[bool] = False,
sock : Optional[asyncbackend.StreamSocket] = None,
backend : Optional[asyncbackend.Backend]) -> message.Message:
pass
async def tls(q : message.Message, where : str,
timeout : Optional[float] = None, port=53,
source : Optional[str] = None, source_port : Optional[int] = 0,
one_rr_per_rrset : Optional[bool] = False,
ignore_trailing : Optional[bool] = False,
sock : Optional[asyncbackend.StreamSocket] = None,
backend : Optional[asyncbackend.Backend],
ssl_context: Optional[ssl.SSLContext] = None,
server_hostname: Optional[str] = None) -> message.Message:
pass

View file

@ -34,7 +34,8 @@ _udp = dns.asyncquery.udp
_tcp = dns.asyncquery.tcp _tcp = dns.asyncquery.tcp
class Resolver(dns.resolver.Resolver): class Resolver(dns.resolver.BaseResolver):
"""Asynchronous DNS stub resolver."""
async def resolve(self, qname, rdtype=dns.rdatatype.A, async def resolve(self, qname, rdtype=dns.rdatatype.A,
rdclass=dns.rdataclass.IN, rdclass=dns.rdataclass.IN,
@ -43,53 +44,12 @@ class Resolver(dns.resolver.Resolver):
backend=None): backend=None):
"""Query nameservers asynchronously to find the answer to the question. """Query nameservers asynchronously to find the answer to the question.
The *qname*, *rdtype*, and *rdclass* parameters may be objects
of the appropriate type, or strings that can be converted into objects
of the appropriate type.
*qname*, a ``dns.name.Name`` or ``str``, the query name.
*rdtype*, an ``int`` or ``str``, the query type.
*rdclass*, an ``int`` or ``str``, the query class.
*tcp*, a ``bool``. If ``True``, use TCP to make the query.
*source*, a ``str`` or ``None``. If not ``None``, bind to this IP
address when making queries.
*raise_on_no_answer*, a ``bool``. If ``True``, raise
``dns.resolver.NoAnswer`` if there's no answer to the question.
*source_port*, an ``int``, the port from which to send the message.
*lifetime*, a ``float``, how many seconds a query should run
before timing out.
*search*, a ``bool`` or ``None``, determines whether the
search list configured in the system's resolver configuration
are used for relative names, and whether the resolver's domain
may be added to relative names. The default is ``None``,
which causes the value of the resolver's
``use_search_by_default`` attribute to be used.
*backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``, *backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``,
the default, then dnspython will use the default backend. the default, then dnspython will use the default backend.
Raises ``dns.resolver.NXDOMAIN`` if the query name does not exist. See :py:func:`dns.resolver.Resolver.resolve()` for the
documentation of the other parameters, exceptions, and return
Raises ``dns.resolver.YXDOMAIN`` if the query name is too long after type of this method.
DNAME substitution.
Raises ``dns.resolver.NoAnswer`` if *raise_on_no_answer* is
``True`` and the query name exists but has no RRset of the
desired type and class.
Raises ``dns.resolver.NoNameservers`` if no non-broken
nameservers are available to answer the question.
Returns a ``dns.resolver.Answer`` instance.
""" """
resolution = dns.resolver._Resolution(self, qname, rdtype, rdclass, tcp, resolution = dns.resolver._Resolution(self, qname, rdtype, rdclass, tcp,
@ -111,7 +71,8 @@ class Resolver(dns.resolver.Resolver):
(nameserver, port, tcp, backoff) = resolution.next_nameserver() (nameserver, port, tcp, backoff) = resolution.next_nameserver()
if backoff: if backoff:
await backend.sleep(backoff) await backend.sleep(backoff)
timeout = self._compute_timeout(start, lifetime) timeout = self._compute_timeout(start, lifetime,
resolution.errors)
try: try:
if dns.inet.is_address(nameserver): if dns.inet.is_address(nameserver):
if tcp: if tcp:
@ -126,8 +87,9 @@ class Resolver(dns.resolver.Resolver):
raise_on_truncation=True, raise_on_truncation=True,
backend=backend) backend=backend)
else: else:
# We don't do DoH yet. response = await dns.asyncquery.https(request,
raise NotImplementedError nameserver,
timeout=timeout)
except Exception as ex: except Exception as ex:
(_, done) = resolution.query_result(None, ex) (_, done) = resolution.query_result(None, ex)
continue continue
@ -139,11 +101,6 @@ class Resolver(dns.resolver.Resolver):
if answer is not None: if answer is not None:
return answer return answer
async def query(self, *args, **kwargs):
# We have to define something here as we don't want to inherit the
# parent's query().
raise NotImplementedError
async def resolve_address(self, ipaddr, *args, **kwargs): async def resolve_address(self, ipaddr, *args, **kwargs):
"""Use an asynchronous resolver to run a reverse query for PTR """Use an asynchronous resolver to run a reverse query for PTR
records. records.
@ -165,6 +122,30 @@ class Resolver(dns.resolver.Resolver):
rdclass=dns.rdataclass.IN, rdclass=dns.rdataclass.IN,
*args, **kwargs) *args, **kwargs)
# pylint: disable=redefined-outer-name
async def canonical_name(self, name):
"""Determine the canonical name of *name*.
The canonical name is the name the resolver uses for queries
after all CNAME and DNAME renamings have been applied.
*name*, a ``dns.name.Name`` or ``str``, the query name.
This method can raise any exception that ``resolve()`` can
raise, other than ``dns.resolver.NoAnswer`` and
``dns.resolver.NXDOMAIN``.
Returns a ``dns.name.Name``.
"""
try:
answer = await self.resolve(name, raise_on_no_answer=False)
canonical_name = answer.canonical_name
except dns.resolver.NXDOMAIN as e:
canonical_name = e.canonical_name
return canonical_name
default_resolver = None default_resolver = None
@ -188,52 +169,46 @@ def reset_default_resolver():
async def resolve(qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN, async def resolve(qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN,
tcp=False, source=None, raise_on_no_answer=True, tcp=False, source=None, raise_on_no_answer=True,
source_port=0, search=None, backend=None): source_port=0, lifetime=None, search=None, backend=None):
"""Query nameservers asynchronously to find the answer to the question. """Query nameservers asynchronously to find the answer to the question.
This is a convenience function that uses the default resolver This is a convenience function that uses the default resolver
object to make the query. object to make the query.
See ``dns.asyncresolver.Resolver.resolve`` for more information on the See :py:func:`dns.asyncresolver.Resolver.resolve` for more
parameters. information on the parameters.
""" """
return await get_default_resolver().resolve(qname, rdtype, rdclass, tcp, return await get_default_resolver().resolve(qname, rdtype, rdclass, tcp,
source, raise_on_no_answer, source, raise_on_no_answer,
source_port, search, backend) source_port, lifetime, search,
backend)
async def resolve_address(ipaddr, *args, **kwargs): async def resolve_address(ipaddr, *args, **kwargs):
"""Use a resolver to run a reverse query for PTR records. """Use a resolver to run a reverse query for PTR records.
See ``dns.asyncresolver.Resolver.resolve_address`` for more See :py:func:`dns.asyncresolver.Resolver.resolve_address` for more
information on the parameters. information on the parameters.
""" """
return await get_default_resolver().resolve_address(ipaddr, *args, **kwargs) return await get_default_resolver().resolve_address(ipaddr, *args, **kwargs)
async def canonical_name(name):
"""Determine the canonical name of *name*.
See :py:func:`dns.resolver.Resolver.canonical_name` for more
information on the parameters and possible exceptions.
"""
return await get_default_resolver().canonical_name(name)
async def zone_for_name(name, rdclass=dns.rdataclass.IN, tcp=False, async def zone_for_name(name, rdclass=dns.rdataclass.IN, tcp=False,
resolver=None, backend=None): resolver=None, backend=None):
"""Find the name of the zone which contains the specified name. """Find the name of the zone which contains the specified name.
*name*, an absolute ``dns.name.Name`` or ``str``, the query name. See :py:func:`dns.resolver.Resolver.zone_for_name` for more
information on the parameters and possible exceptions.
*rdclass*, an ``int``, the query class.
*tcp*, a ``bool``. If ``True``, use TCP to make the query.
*resolver*, a ``dns.asyncresolver.Resolver`` or ``None``, the
resolver to use. If ``None``, the default resolver is used.
*backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``,
the default, then dnspython will use the default backend.
Raises ``dns.resolver.NoRootSOA`` if there is no SOA RR at the DNS
root. (This is only likely to happen if you're using non-default
root servers in your network and they are misconfigured.)
Returns a ``dns.name.Name``.
""" """
if isinstance(name, str): if isinstance(name, str):

26
lib/dns/asyncresolver.pyi Normal file
View file

@ -0,0 +1,26 @@
from typing import Union, Optional, List, Any, Dict
from . import exception, rdataclass, name, rdatatype, asyncbackend
async def resolve(qname : str, rdtype : Union[int,str] = 0,
rdclass : Union[int,str] = 0,
tcp=False, source=None, raise_on_no_answer=True,
source_port=0, lifetime : Optional[float]=None,
search : Optional[bool]=None,
backend : Optional[asyncbackend.Backend]=None):
...
async def resolve_address(self, ipaddr: str,
*args: Any, **kwargs: Optional[Dict]):
...
class Resolver:
def __init__(self, filename : Optional[str] = '/etc/resolv.conf',
configure : Optional[bool] = True):
self.nameservers : List[str]
async def resolve(self, qname : str, rdtype : Union[int,str] = rdatatype.A,
rdclass : Union[int,str] = rdataclass.IN,
tcp : bool = False, source : Optional[str] = None,
raise_on_no_answer=True, source_port : int = 0,
lifetime : Optional[float]=None,
search : Optional[bool]=None,
backend : Optional[asyncbackend.Backend]=None):
...

View file

@ -64,9 +64,6 @@ class Algorithm(dns.enum.IntEnum):
return 255 return 255
globals().update(Algorithm.__members__)
def algorithm_from_text(text): def algorithm_from_text(text):
"""Convert text into a DNSSEC algorithm value. """Convert text into a DNSSEC algorithm value.
@ -169,23 +166,15 @@ def make_ds(name, key, algorithm, origin=None):
def _find_candidate_keys(keys, rrsig): def _find_candidate_keys(keys, rrsig):
candidate_keys = []
value = keys.get(rrsig.signer) value = keys.get(rrsig.signer)
if value is None:
return None
if isinstance(value, dns.node.Node): if isinstance(value, dns.node.Node):
try: rdataset = value.get_rdataset(dns.rdataclass.IN, dns.rdatatype.DNSKEY)
rdataset = value.find_rdataset(dns.rdataclass.IN,
dns.rdatatype.DNSKEY)
except KeyError:
return None
else: else:
rdataset = value rdataset = value
for rdata in rdataset: if rdataset is None:
if rdata.algorithm == rrsig.algorithm and \ return None
key_id(rdata) == rrsig.key_tag: return [rd for rd in rdataset if
candidate_keys.append(rdata) rd.algorithm == rrsig.algorithm and key_id(rd) == rrsig.key_tag]
return candidate_keys
def _is_rsa(algorithm): def _is_rsa(algorithm):
@ -254,6 +243,82 @@ def _bytes_to_long(b):
return int.from_bytes(b, 'big') return int.from_bytes(b, 'big')
def _validate_signature(sig, data, key, chosen_hash):
if _is_rsa(key.algorithm):
keyptr = key.key
(bytes_,) = struct.unpack('!B', keyptr[0:1])
keyptr = keyptr[1:]
if bytes_ == 0:
(bytes_,) = struct.unpack('!H', keyptr[0:2])
keyptr = keyptr[2:]
rsa_e = keyptr[0:bytes_]
rsa_n = keyptr[bytes_:]
try:
public_key = rsa.RSAPublicNumbers(
_bytes_to_long(rsa_e),
_bytes_to_long(rsa_n)).public_key(default_backend())
except ValueError:
raise ValidationFailure('invalid public key')
public_key.verify(sig, data, padding.PKCS1v15(), chosen_hash)
elif _is_dsa(key.algorithm):
keyptr = key.key
(t,) = struct.unpack('!B', keyptr[0:1])
keyptr = keyptr[1:]
octets = 64 + t * 8
dsa_q = keyptr[0:20]
keyptr = keyptr[20:]
dsa_p = keyptr[0:octets]
keyptr = keyptr[octets:]
dsa_g = keyptr[0:octets]
keyptr = keyptr[octets:]
dsa_y = keyptr[0:octets]
try:
public_key = dsa.DSAPublicNumbers(
_bytes_to_long(dsa_y),
dsa.DSAParameterNumbers(
_bytes_to_long(dsa_p),
_bytes_to_long(dsa_q),
_bytes_to_long(dsa_g))).public_key(default_backend())
except ValueError:
raise ValidationFailure('invalid public key')
public_key.verify(sig, data, chosen_hash)
elif _is_ecdsa(key.algorithm):
keyptr = key.key
if key.algorithm == Algorithm.ECDSAP256SHA256:
curve = ec.SECP256R1()
octets = 32
else:
curve = ec.SECP384R1()
octets = 48
ecdsa_x = keyptr[0:octets]
ecdsa_y = keyptr[octets:octets * 2]
try:
public_key = ec.EllipticCurvePublicNumbers(
curve=curve,
x=_bytes_to_long(ecdsa_x),
y=_bytes_to_long(ecdsa_y)).public_key(default_backend())
except ValueError:
raise ValidationFailure('invalid public key')
public_key.verify(sig, data, ec.ECDSA(chosen_hash))
elif _is_eddsa(key.algorithm):
keyptr = key.key
if key.algorithm == Algorithm.ED25519:
loader = ed25519.Ed25519PublicKey
else:
loader = ed448.Ed448PublicKey
try:
public_key = loader.from_public_bytes(keyptr)
except ValueError:
raise ValidationFailure('invalid public key')
public_key.verify(sig, data)
elif _is_gost(key.algorithm):
raise UnsupportedAlgorithm(
'algorithm "%s" not supported by dnspython' %
algorithm_to_text(key.algorithm))
else:
raise ValidationFailure('unknown algorithm %u' % key.algorithm)
def _validate_rrsig(rrset, rrsig, keys, origin=None, now=None): def _validate_rrsig(rrset, rrsig, keys, origin=None, now=None):
"""Validate an RRset against a single signature rdata, throwing an """Validate an RRset against a single signature rdata, throwing an
exception if validation is not successful. exception if validation is not successful.
@ -291,143 +356,69 @@ def _validate_rrsig(rrset, rrsig, keys, origin=None, now=None):
if candidate_keys is None: if candidate_keys is None:
raise ValidationFailure('unknown key') raise ValidationFailure('unknown key')
# For convenience, allow the rrset to be specified as a (name,
# rdataset) tuple as well as a proper rrset
if isinstance(rrset, tuple):
rrname = rrset[0]
rdataset = rrset[1]
else:
rrname = rrset.name
rdataset = rrset
if now is None:
now = time.time()
if rrsig.expiration < now:
raise ValidationFailure('expired')
if rrsig.inception > now:
raise ValidationFailure('not yet valid')
if _is_dsa(rrsig.algorithm):
sig_r = rrsig.signature[1:21]
sig_s = rrsig.signature[21:]
sig = utils.encode_dss_signature(_bytes_to_long(sig_r),
_bytes_to_long(sig_s))
elif _is_ecdsa(rrsig.algorithm):
if rrsig.algorithm == Algorithm.ECDSAP256SHA256:
octets = 32
else:
octets = 48
sig_r = rrsig.signature[0:octets]
sig_s = rrsig.signature[octets:]
sig = utils.encode_dss_signature(_bytes_to_long(sig_r),
_bytes_to_long(sig_s))
else:
sig = rrsig.signature
data = b''
data += rrsig.to_wire(origin=origin)[:18]
data += rrsig.signer.to_digestable(origin)
# Derelativize the name before considering labels.
rrname = rrname.derelativize(origin)
if len(rrname) - 1 < rrsig.labels:
raise ValidationFailure('owner name longer than RRSIG labels')
elif rrsig.labels < len(rrname) - 1:
suffix = rrname.split(rrsig.labels + 1)[1]
rrname = dns.name.from_text('*', suffix)
rrnamebuf = rrname.to_digestable()
rrfixed = struct.pack('!HHI', rdataset.rdtype, rdataset.rdclass,
rrsig.original_ttl)
rdatas = [rdata.to_digestable(origin) for rdata in rdataset]
for rdata in sorted(rdatas):
data += rrnamebuf
data += rrfixed
rrlen = struct.pack('!H', len(rdata))
data += rrlen
data += rdata
chosen_hash = _make_hash(rrsig.algorithm)
for candidate_key in candidate_keys: for candidate_key in candidate_keys:
# For convenience, allow the rrset to be specified as a (name,
# rdataset) tuple as well as a proper rrset
if isinstance(rrset, tuple):
rrname = rrset[0]
rdataset = rrset[1]
else:
rrname = rrset.name
rdataset = rrset
if now is None:
now = time.time()
if rrsig.expiration < now:
raise ValidationFailure('expired')
if rrsig.inception > now:
raise ValidationFailure('not yet valid')
if _is_rsa(rrsig.algorithm):
keyptr = candidate_key.key
(bytes_,) = struct.unpack('!B', keyptr[0:1])
keyptr = keyptr[1:]
if bytes_ == 0:
(bytes_,) = struct.unpack('!H', keyptr[0:2])
keyptr = keyptr[2:]
rsa_e = keyptr[0:bytes_]
rsa_n = keyptr[bytes_:]
try:
public_key = rsa.RSAPublicNumbers(
_bytes_to_long(rsa_e),
_bytes_to_long(rsa_n)).public_key(default_backend())
except ValueError:
raise ValidationFailure('invalid public key')
sig = rrsig.signature
elif _is_dsa(rrsig.algorithm):
keyptr = candidate_key.key
(t,) = struct.unpack('!B', keyptr[0:1])
keyptr = keyptr[1:]
octets = 64 + t * 8
dsa_q = keyptr[0:20]
keyptr = keyptr[20:]
dsa_p = keyptr[0:octets]
keyptr = keyptr[octets:]
dsa_g = keyptr[0:octets]
keyptr = keyptr[octets:]
dsa_y = keyptr[0:octets]
try:
public_key = dsa.DSAPublicNumbers(
_bytes_to_long(dsa_y),
dsa.DSAParameterNumbers(
_bytes_to_long(dsa_p),
_bytes_to_long(dsa_q),
_bytes_to_long(dsa_g))).public_key(default_backend())
except ValueError:
raise ValidationFailure('invalid public key')
sig_r = rrsig.signature[1:21]
sig_s = rrsig.signature[21:]
sig = utils.encode_dss_signature(_bytes_to_long(sig_r),
_bytes_to_long(sig_s))
elif _is_ecdsa(rrsig.algorithm):
keyptr = candidate_key.key
if rrsig.algorithm == Algorithm.ECDSAP256SHA256:
curve = ec.SECP256R1()
octets = 32
else:
curve = ec.SECP384R1()
octets = 48
ecdsa_x = keyptr[0:octets]
ecdsa_y = keyptr[octets:octets * 2]
try:
public_key = ec.EllipticCurvePublicNumbers(
curve=curve,
x=_bytes_to_long(ecdsa_x),
y=_bytes_to_long(ecdsa_y)).public_key(default_backend())
except ValueError:
raise ValidationFailure('invalid public key')
sig_r = rrsig.signature[0:octets]
sig_s = rrsig.signature[octets:]
sig = utils.encode_dss_signature(_bytes_to_long(sig_r),
_bytes_to_long(sig_s))
elif _is_eddsa(rrsig.algorithm):
keyptr = candidate_key.key
if rrsig.algorithm == Algorithm.ED25519:
loader = ed25519.Ed25519PublicKey
else:
loader = ed448.Ed448PublicKey
try:
public_key = loader.from_public_bytes(keyptr)
except ValueError:
raise ValidationFailure('invalid public key')
sig = rrsig.signature
elif _is_gost(rrsig.algorithm):
raise UnsupportedAlgorithm(
'algorithm "%s" not supported by dnspython' %
algorithm_to_text(rrsig.algorithm))
else:
raise ValidationFailure('unknown algorithm %u' % rrsig.algorithm)
data = b''
data += rrsig.to_wire(origin=origin)[:18]
data += rrsig.signer.to_digestable(origin)
if rrsig.labels < len(rrname) - 1:
suffix = rrname.split(rrsig.labels + 1)[1]
rrname = dns.name.from_text('*', suffix)
rrnamebuf = rrname.to_digestable(origin)
rrfixed = struct.pack('!HHI', rdataset.rdtype, rdataset.rdclass,
rrsig.original_ttl)
rrlist = sorted(rdataset)
for rr in rrlist:
data += rrnamebuf
data += rrfixed
rrdata = rr.to_digestable(origin)
rrlen = struct.pack('!H', len(rrdata))
data += rrlen
data += rrdata
chosen_hash = _make_hash(rrsig.algorithm)
try: try:
if _is_rsa(rrsig.algorithm): _validate_signature(sig, data, candidate_key, chosen_hash)
public_key.verify(sig, data, padding.PKCS1v15(), chosen_hash)
elif _is_dsa(rrsig.algorithm):
public_key.verify(sig, data, chosen_hash)
elif _is_ecdsa(rrsig.algorithm):
public_key.verify(sig, data, ec.ECDSA(chosen_hash))
elif _is_eddsa(rrsig.algorithm):
public_key.verify(sig, data)
else:
# Raise here for code clarity; this won't actually ever happen
# since if the algorithm is really unknown we'd already have
# raised an exception above
raise ValidationFailure('unknown algorithm %u' %
rrsig.algorithm) # pragma: no cover
# If we got here, we successfully verified so we can return
# without error
return return
except InvalidSignature: except (InvalidSignature, ValidationFailure):
# this happens on an individual validation failure # this happens on an individual validation failure
continue continue
# nothing verified -- raise failure: # nothing verified -- raise failure:
@ -546,7 +537,7 @@ def nsec3_hash(domain, salt, iterations, algorithm):
domain_encoded = domain.canonicalize().to_wire() domain_encoded = domain.canonicalize().to_wire()
digest = hashlib.sha1(domain_encoded + salt_encoded).digest() digest = hashlib.sha1(domain_encoded + salt_encoded).digest()
for i in range(iterations): for _ in range(iterations):
digest = hashlib.sha1(digest + salt_encoded).digest() digest = hashlib.sha1(digest + salt_encoded).digest()
output = base64.b32encode(digest).decode("utf-8") output = base64.b32encode(digest).decode("utf-8")
@ -579,3 +570,25 @@ else:
validate = _validate # type: ignore validate = _validate # type: ignore
validate_rrsig = _validate_rrsig # type: ignore validate_rrsig = _validate_rrsig # type: ignore
_have_pyca = True _have_pyca = True
### BEGIN generated Algorithm constants
RSAMD5 = Algorithm.RSAMD5
DH = Algorithm.DH
DSA = Algorithm.DSA
ECC = Algorithm.ECC
RSASHA1 = Algorithm.RSASHA1
DSANSEC3SHA1 = Algorithm.DSANSEC3SHA1
RSASHA1NSEC3SHA1 = Algorithm.RSASHA1NSEC3SHA1
RSASHA256 = Algorithm.RSASHA256
RSASHA512 = Algorithm.RSASHA512
ECCGOST = Algorithm.ECCGOST
ECDSAP256SHA256 = Algorithm.ECDSAP256SHA256
ECDSAP384SHA384 = Algorithm.ECDSAP384SHA384
ED25519 = Algorithm.ED25519
ED448 = Algorithm.ED448
INDIRECT = Algorithm.INDIRECT
PRIVATEDNS = Algorithm.PRIVATEDNS
PRIVATEOID = Algorithm.PRIVATEOID
### END generated Algorithm constants

21
lib/dns/dnssec.pyi Normal file
View file

@ -0,0 +1,21 @@
from typing import Union, Dict, Tuple, Optional
from . import rdataset, rrset, exception, name, rdtypes, rdata, node
import dns.rdtypes.ANY.DS as DS
import dns.rdtypes.ANY.DNSKEY as DNSKEY
_have_pyca : bool
def validate_rrsig(rrset : Union[Tuple[name.Name, rdataset.Rdataset], rrset.RRset], rrsig : rdata.Rdata, keys : Dict[name.Name, Union[node.Node, rdataset.Rdataset]], origin : Optional[name.Name] = None, now : Optional[int] = None) -> None:
...
def validate(rrset: Union[Tuple[name.Name, rdataset.Rdataset], rrset.RRset], rrsigset : Union[Tuple[name.Name, rdataset.Rdataset], rrset.RRset], keys : Dict[name.Name, Union[node.Node, rdataset.Rdataset]], origin=None, now=None) -> None:
...
class ValidationFailure(exception.DNSException):
...
def make_ds(name : name.Name, key : DNSKEY.DNSKEY, algorithm : str, origin : Optional[name.Name] = None) -> DS.DS:
...
def nsec3_hash(domain: str, salt: Optional[Union[str, bytes]], iterations: int, algo: int) -> str:
...

10
lib/dns/e164.pyi Normal file
View file

@ -0,0 +1,10 @@
from typing import Optional, Iterable
from . import name, resolver
def from_e164(text : str, origin=name.Name(".")) -> name.Name:
...
def to_e164(name : name.Name, origin : Optional[name.Name] = None, want_plus_prefix=True) -> str:
...
def query(number : str, domains : Iterable[str], resolver : Optional[resolver.Resolver] = None) -> resolver.Answer:
...

View file

@ -23,6 +23,8 @@ import struct
import dns.enum import dns.enum
import dns.inet import dns.inet
import dns.rdata
class OptionType(dns.enum.IntEnum): class OptionType(dns.enum.IntEnum):
#: NSID #: NSID
@ -45,12 +47,13 @@ class OptionType(dns.enum.IntEnum):
PADDING = 12 PADDING = 12
#: CHAIN #: CHAIN
CHAIN = 13 CHAIN = 13
#: EDE (extended-dns-error)
EDE = 15
@classmethod @classmethod
def _maximum(cls): def _maximum(cls):
return 65535 return 65535
globals().update(OptionType.__members__)
class Option: class Option:
@ -61,7 +64,7 @@ class Option:
*otype*, an ``int``, is the option type. *otype*, an ``int``, is the option type.
""" """
self.otype = otype self.otype = OptionType.make(otype)
def to_wire(self, file=None): def to_wire(self, file=None):
"""Convert an option to wire format. """Convert an option to wire format.
@ -149,7 +152,7 @@ class GenericOption(Option):
def __init__(self, otype, data): def __init__(self, otype, data):
super().__init__(otype) super().__init__(otype)
self.data = data self.data = dns.rdata.Rdata._as_bytes(data, True)
def to_wire(self, file=None): def to_wire(self, file=None):
if file: if file:
@ -186,12 +189,18 @@ class ECSOption(Option):
self.family = 2 self.family = 2
if srclen is None: if srclen is None:
srclen = 56 srclen = 56
address = dns.rdata.Rdata._as_ipv6_address(address)
srclen = dns.rdata.Rdata._as_int(srclen, 0, 128)
scopelen = dns.rdata.Rdata._as_int(scopelen, 0, 128)
elif af == socket.AF_INET: elif af == socket.AF_INET:
self.family = 1 self.family = 1
if srclen is None: if srclen is None:
srclen = 24 srclen = 24
else: address = dns.rdata.Rdata._as_ipv4_address(address)
raise ValueError('Bad ip family') srclen = dns.rdata.Rdata._as_int(srclen, 0, 32)
scopelen = dns.rdata.Rdata._as_int(scopelen, 0, 32)
else: # pragma: no cover (this will never happen)
raise ValueError('Bad address family')
self.address = address self.address = address
self.srclen = srclen self.srclen = srclen
@ -293,10 +302,95 @@ class ECSOption(Option):
return cls(addr, src, scope) return cls(addr, src, scope)
class EDECode(dns.enum.IntEnum):
OTHER = 0
UNSUPPORTED_DNSKEY_ALGORITHM = 1
UNSUPPORTED_DS_DIGEST_TYPE = 2
STALE_ANSWER = 3
FORGED_ANSWER = 4
DNSSEC_INDETERMINATE = 5
DNSSEC_BOGUS = 6
SIGNATURE_EXPIRED = 7
SIGNATURE_NOT_YET_VALID = 8
DNSKEY_MISSING = 9
RRSIGS_MISSING = 10
NO_ZONE_KEY_BIT_SET = 11
NSEC_MISSING = 12
CACHED_ERROR = 13
NOT_READY = 14
BLOCKED = 15
CENSORED = 16
FILTERED = 17
PROHIBITED = 18
STALE_NXDOMAIN_ANSWER = 19
NOT_AUTHORITATIVE = 20
NOT_SUPPORTED = 21
NO_REACHABLE_AUTHORITY = 22
NETWORK_ERROR = 23
INVALID_DATA = 24
@classmethod
def _maximum(cls):
return 65535
class EDEOption(Option):
"""Extended DNS Error (EDE, RFC8914)"""
def __init__(self, code, text=None):
"""*code*, a ``dns.edns.EDECode`` or ``str``, the info code of the
extended error.
*text*, a ``str`` or ``None``, specifying additional information about
the error.
"""
super().__init__(OptionType.EDE)
self.code = EDECode.make(code)
if text is not None and not isinstance(text, str):
raise ValueError('text must be string or None')
self.code = code
self.text = text
def to_text(self):
output = f'EDE {self.code}'
if self.text is not None:
output += f': {self.text}'
return output
def to_wire(self, file=None):
value = struct.pack('!H', self.code)
if self.text is not None:
value += self.text.encode('utf8')
if file:
file.write(value)
else:
return value
@classmethod
def from_wire_parser(cls, otype, parser):
code = parser.get_uint16()
text = parser.get_remaining()
if text:
if text[-1] == 0: # text MAY be null-terminated
text = text[:-1]
text = text.decode('utf8')
else:
text = None
return cls(code, text)
_type_to_class = { _type_to_class = {
OptionType.ECS: ECSOption OptionType.ECS: ECSOption,
OptionType.EDE: EDEOption,
} }
def get_option_class(otype): def get_option_class(otype):
"""Return the class for the specified option type. """Return the class for the specified option type.
@ -342,3 +436,29 @@ def option_from_wire(otype, wire, current, olen):
parser = dns.wire.Parser(wire, current) parser = dns.wire.Parser(wire, current)
with parser.restrict_to(olen): with parser.restrict_to(olen):
return option_from_wire_parser(otype, parser) return option_from_wire_parser(otype, parser)
def register_type(implementation, otype):
"""Register the implementation of an option type.
*implementation*, a ``class``, is a subclass of ``dns.edns.Option``.
*otype*, an ``int``, is the option type.
"""
_type_to_class[otype] = implementation
### BEGIN generated OptionType constants
NSID = OptionType.NSID
DAU = OptionType.DAU
DHU = OptionType.DHU
N3U = OptionType.N3U
ECS = OptionType.ECS
EXPIRE = OptionType.EXPIRE
COOKIE = OptionType.COOKIE
KEEPALIVE = OptionType.KEEPALIVE
PADDING = OptionType.PADDING
CHAIN = OptionType.CHAIN
EDE = OptionType.EDE
### END generated OptionType constants

10
lib/dns/entropy.pyi Normal file
View file

@ -0,0 +1,10 @@
from typing import Optional
from random import SystemRandom
system_random : Optional[SystemRandom]
def random_16() -> int:
pass
def between(first: int, last: int) -> int:
pass

View file

@ -75,7 +75,7 @@ class IntEnum(enum.IntEnum):
@classmethod @classmethod
def _maximum(cls): def _maximum(cls):
raise NotImplementedError raise NotImplementedError # pragma: no cover
@classmethod @classmethod
def _short_name(cls): def _short_name(cls):

View file

@ -126,3 +126,17 @@ class Timeout(DNSException):
"""The DNS operation timed out.""" """The DNS operation timed out."""
supp_kwargs = {'timeout'} supp_kwargs = {'timeout'}
fmt = "The DNS operation timed out after {timeout} seconds" fmt = "The DNS operation timed out after {timeout} seconds"
class ExceptionWrapper:
def __init__(self, exception_class):
self.exception_class = exception_class
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
if exc_type is not None and not isinstance(exc_val,
self.exception_class):
raise self.exception_class(str(exc_val)) from exc_val
return False

10
lib/dns/exception.pyi Normal file
View file

@ -0,0 +1,10 @@
from typing import Set, Optional, Dict
class DNSException(Exception):
supp_kwargs : Set[str]
kwargs : Optional[Dict]
fmt : Optional[str]
class SyntaxError(DNSException): ...
class FormError(DNSException): ...
class Timeout(DNSException): ...

View file

@ -37,8 +37,6 @@ class Flag(enum.IntFlag):
#: Checking Disabled #: Checking Disabled
CD = 0x0010 CD = 0x0010
globals().update(Flag.__members__)
# EDNS flags # EDNS flags
@ -47,9 +45,6 @@ class EDNSFlag(enum.IntFlag):
DO = 0x8000 DO = 0x8000
globals().update(EDNSFlag.__members__)
def _from_text(text, enum_class): def _from_text(text, enum_class):
flags = 0 flags = 0
tokens = text.split() tokens = text.split()
@ -104,3 +99,21 @@ def edns_to_text(flags):
""" """
return _to_text(flags, EDNSFlag) return _to_text(flags, EDNSFlag)
### BEGIN generated Flag constants
QR = Flag.QR
AA = Flag.AA
TC = Flag.TC
RD = Flag.RD
RA = Flag.RA
AD = Flag.AD
CD = Flag.CD
### END generated Flag constants
### BEGIN generated EDNSFlag constants
DO = EDNSFlag.DO
### END generated EDNSFlag constants

View file

@ -28,11 +28,12 @@ def from_text(text):
Returns a tuple of three ``int`` values ``(start, stop, step)``. Returns a tuple of three ``int`` values ``(start, stop, step)``.
""" """
# TODO, figure out the bounds on start, stop and step. start = -1
stop = -1
step = 1 step = 1
cur = '' cur = ''
state = 0 state = 0
# state 0 1 2 3 4 # state 0 1 2
# x - y / z # x - y / z
if text and text[0] == '-': if text and text[0] == '-':
@ -42,28 +43,27 @@ def from_text(text):
if c == '-' and state == 0: if c == '-' and state == 0:
start = int(cur) start = int(cur)
cur = '' cur = ''
state = 2 state = 1
elif c == '/': elif c == '/':
stop = int(cur) stop = int(cur)
cur = '' cur = ''
state = 4 state = 2
elif c.isdigit(): elif c.isdigit():
cur += c cur += c
else: else:
raise dns.exception.SyntaxError("Could not parse %s" % (c)) raise dns.exception.SyntaxError("Could not parse %s" % (c))
if state in (1, 3): if state == 0:
raise dns.exception.SyntaxError() raise dns.exception.SyntaxError("no stop value specified")
elif state == 1:
if state == 2:
stop = int(cur) stop = int(cur)
else:
if state == 4: assert state == 2
step = int(cur) step = int(cur)
assert step >= 1 assert step >= 1
assert start >= 0 assert start >= 0
assert start <= stop if start > stop:
# TODO, can start == stop? raise dns.exception.SyntaxError('start must be <= stop')
return (start, stop, step) return (start, stop, step)

70
lib/dns/immutable.py Normal file
View file

@ -0,0 +1,70 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
import collections.abc
import sys
# pylint: disable=unused-import
if sys.version_info >= (3, 7):
odict = dict
from dns._immutable_ctx import immutable
else:
# pragma: no cover
from collections import OrderedDict as odict
from dns._immutable_attr import immutable # noqa
# pylint: enable=unused-import
@immutable
class Dict(collections.abc.Mapping):
def __init__(self, dictionary, no_copy=False):
"""Make an immutable dictionary from the specified dictionary.
If *no_copy* is `True`, then *dictionary* will be wrapped instead
of copied. Only set this if you are sure there will be no external
references to the dictionary.
"""
if no_copy and isinstance(dictionary, odict):
self._odict = dictionary
else:
self._odict = odict(dictionary)
self._hash = None
def __getitem__(self, key):
return self._odict.__getitem__(key)
def __hash__(self): # pylint: disable=invalid-hash-returned
if self._hash is None:
h = 0
for key in sorted(self._odict.keys()):
h ^= hash(key)
object.__setattr__(self, '_hash', h)
# this does return an int, but pylint doesn't figure that out
return self._hash
def __len__(self):
return len(self._odict)
def __iter__(self):
return iter(self._odict)
def constify(o):
"""
Convert mutable types to immutable types.
"""
if isinstance(o, bytearray):
return bytes(o)
if isinstance(o, tuple):
try:
hash(o)
return o
except Exception:
return tuple(constify(elt) for elt in o)
if isinstance(o, list):
return tuple(constify(elt) for elt in o)
if isinstance(o, dict):
cdict = odict()
for k, v in o.items():
cdict[k] = constify(v)
return Dict(cdict, True)
return o

View file

@ -162,7 +162,7 @@ def low_level_address_tuple(high_tuple, af=None):
return (addrpart, port, 0, int(scope)) return (addrpart, port, 0, int(scope))
try: try:
return (addrpart, port, 0, socket.if_nametoindex(scope)) return (addrpart, port, 0, socket.if_nametoindex(scope))
except AttributeError: except AttributeError: # pragma: no cover (we can't really test this)
ai_flags = socket.AI_NUMERICHOST ai_flags = socket.AI_NUMERICHOST
((*_, tup), *_) = socket.getaddrinfo(address, port, flags=ai_flags) ((*_, tup), *_) = socket.getaddrinfo(address, port, flags=ai_flags)
return tup return tup

4
lib/dns/inet.pyi Normal file
View file

@ -0,0 +1,4 @@
from typing import Union
from socket import AddressFamily
AF_INET6 : Union[int, AddressFamily]

View file

@ -121,7 +121,13 @@ def inet_aton(text, ignore_scope=False):
elif l > 2: elif l > 2:
raise dns.exception.SyntaxError raise dns.exception.SyntaxError
if text == b'::': if text == b'':
raise dns.exception.SyntaxError
elif text.endswith(b':') and not text.endswith(b'::'):
raise dns.exception.SyntaxError
elif text.startswith(b':') and not text.startswith(b'::'):
raise dns.exception.SyntaxError
elif text == b'::':
text = b'0::' text = b'0::'
# #
# Get rid of the icky dot-quad syntax if we have it. # Get rid of the icky dot-quad syntax if we have it.
@ -129,9 +135,9 @@ def inet_aton(text, ignore_scope=False):
m = _v4_ending.match(text) m = _v4_ending.match(text)
if m is not None: if m is not None:
b = dns.ipv4.inet_aton(m.group(2)) b = dns.ipv4.inet_aton(m.group(2))
text = (u"{}:{:02x}{:02x}:{:02x}{:02x}".format(m.group(1).decode(), text = ("{}:{:02x}{:02x}:{:02x}{:02x}".format(m.group(1).decode(),
b[0], b[1], b[2], b[0], b[1], b[2],
b[3])).encode() b[3])).encode()
# #
# Try to turn '::<whatever>' into ':<whatever>'; if no match try to # Try to turn '::<whatever>' into ':<whatever>'; if no match try to
# turn '<whatever>::' into '<whatever>:' # turn '<whatever>::' into '<whatever>:'
@ -157,7 +163,7 @@ def inet_aton(text, ignore_scope=False):
if seen_empty: if seen_empty:
raise dns.exception.SyntaxError raise dns.exception.SyntaxError
seen_empty = True seen_empty = True
for i in range(0, 8 - l + 1): for _ in range(0, 8 - l + 1):
canonical.append(b'0000') canonical.append(b'0000')
else: else:
lc = len(c) lc = len(c)

View file

@ -35,6 +35,7 @@ import dns.rdataclass
import dns.rdatatype import dns.rdatatype
import dns.rrset import dns.rrset
import dns.renderer import dns.renderer
import dns.ttl
import dns.tsig import dns.tsig
import dns.rdtypes.ANY.OPT import dns.rdtypes.ANY.OPT
import dns.rdtypes.ANY.TSIG import dns.rdtypes.ANY.TSIG
@ -80,6 +81,21 @@ class Truncated(dns.exception.DNSException):
return self.kwargs['message'] return self.kwargs['message']
class NotQueryResponse(dns.exception.DNSException):
"""Message is not a response to a query."""
class ChainTooLong(dns.exception.DNSException):
"""The CNAME chain is too long."""
class AnswerForNXDOMAIN(dns.exception.DNSException):
"""The rcode is NXDOMAIN but an answer was found."""
class NoPreviousName(dns.exception.SyntaxError):
"""No previous name was known."""
class MessageSection(dns.enum.IntEnum): class MessageSection(dns.enum.IntEnum):
"""Message sections""" """Message sections"""
QUESTION = 0 QUESTION = 0
@ -91,8 +107,15 @@ class MessageSection(dns.enum.IntEnum):
def _maximum(cls): def _maximum(cls):
return 3 return 3
globals().update(MessageSection.__members__)
class MessageError:
def __init__(self, exception, offset):
self.exception = exception
self.offset = offset
DEFAULT_EDNS_PAYLOAD = 1232
MAX_CHAIN = 16
class Message: class Message:
"""A DNS message.""" """A DNS message."""
@ -115,6 +138,7 @@ class Message:
self.origin = None self.origin = None
self.tsig_ctx = None self.tsig_ctx = None
self.index = {} self.index = {}
self.errors = []
@property @property
def question(self): def question(self):
@ -169,10 +193,8 @@ class Message:
s = io.StringIO() s = io.StringIO()
s.write('id %d\n' % self.id) s.write('id %d\n' % self.id)
s.write('opcode %s\n' % s.write('opcode %s\n' % dns.opcode.to_text(self.opcode()))
dns.opcode.to_text(dns.opcode.from_flags(self.flags))) s.write('rcode %s\n' % dns.rcode.to_text(self.rcode()))
rc = dns.rcode.from_flags(self.flags, self.ednsflags)
s.write('rcode %s\n' % dns.rcode.to_text(rc))
s.write('flags %s\n' % dns.flags.to_text(self.flags)) s.write('flags %s\n' % dns.flags.to_text(self.flags))
if self.edns >= 0: if self.edns >= 0:
s.write('edns %s\n' % self.edns) s.write('edns %s\n' % self.edns)
@ -221,7 +243,8 @@ class Message:
return not self.__eq__(other) return not self.__eq__(other)
def is_response(self, other): def is_response(self, other):
"""Is *other* a response this message? """Is *other*, also a ``dns.message.Message``, a response to this
message?
Returns a ``bool``. Returns a ``bool``.
""" """
@ -231,9 +254,13 @@ class Message:
dns.opcode.from_flags(self.flags) != \ dns.opcode.from_flags(self.flags) != \
dns.opcode.from_flags(other.flags): dns.opcode.from_flags(other.flags):
return False return False
if dns.rcode.from_flags(other.flags, other.ednsflags) != \ if other.rcode() in {dns.rcode.FORMERR, dns.rcode.SERVFAIL,
dns.rcode.NOERROR: dns.rcode.NOTIMP, dns.rcode.REFUSED}:
return True # We don't check the question section in these cases if
# the other question section is empty, even though they
# still really ought to have a question section.
if len(other.question) == 0:
return True
if dns.opcode.is_update(self.flags): if dns.opcode.is_update(self.flags):
# This is assuming the "sender doesn't include anything # This is assuming the "sender doesn't include anything
# from the update", but we don't care to check the other # from the update", but we don't care to check the other
@ -330,7 +357,8 @@ class Message:
return rrset return rrset
else: else:
for rrset in section: for rrset in section:
if rrset.match(name, rdclass, rdtype, covers, deleting): if rrset.full_match(name, rdclass, rdtype, covers,
deleting):
return rrset return rrset
if not create: if not create:
raise KeyError raise KeyError
@ -403,8 +431,8 @@ class Message:
*multi*, a ``bool``, should be set to ``True`` if this message is *multi*, a ``bool``, should be set to ``True`` if this message is
part of a multiple message sequence. part of a multiple message sequence.
*tsig_ctx*, a ``hmac.HMAC`` object, the ongoing TSIG context, used *tsig_ctx*, a ``dns.tsig.HMACTSig`` or ``dns.tsig.GSSTSig`` object, the
when signing zone transfers. ongoing TSIG context, used when signing zone transfers.
Raises ``dns.exception.TooBig`` if *max_size* was exceeded. Raises ``dns.exception.TooBig`` if *max_size* was exceeded.
@ -467,8 +495,8 @@ class Message:
*key*, a ``dns.tsig.Key`` is the key to use. If a key is specified, *key*, a ``dns.tsig.Key`` is the key to use. If a key is specified,
the *keyring* and *algorithm* fields are not used. the *keyring* and *algorithm* fields are not used.
*keyring*, a ``dict`` or ``dns.tsig.Key``, is either the TSIG *keyring*, a ``dict``, ``callable`` or ``dns.tsig.Key``, is either
keyring or key to use. the TSIG keyring or key to use.
The format of a keyring dict is a mapping from TSIG key name, as The format of a keyring dict is a mapping from TSIG key name, as
``dns.name.Name`` to ``dns.tsig.Key`` or a TSIG secret, a ``bytes``. ``dns.name.Name`` to ``dns.tsig.Key`` or a TSIG secret, a ``bytes``.
@ -476,7 +504,9 @@ class Message:
used will be the first key in the *keyring*. Note that the order of used will be the first key in the *keyring*. Note that the order of
keys in a dictionary is not defined, so applications should supply a keys in a dictionary is not defined, so applications should supply a
keyname when a ``dict`` keyring is used, unless they know the keyring keyname when a ``dict`` keyring is used, unless they know the keyring
contains only one key. contains only one key. If a ``callable`` keyring is specified, the
callable will be called with the message and the keyname, and is
expected to return a key.
*keyname*, a ``dns.name.Name``, ``str`` or ``None``, the name of *keyname*, a ``dns.name.Name``, ``str`` or ``None``, the name of
thes TSIG key to use; defaults to ``None``. If *keyring* is a thes TSIG key to use; defaults to ``None``. If *keyring* is a
@ -497,7 +527,10 @@ class Message:
""" """
if isinstance(keyring, dns.tsig.Key): if isinstance(keyring, dns.tsig.Key):
self.keyring = keyring key = keyring
keyname = key.name
elif callable(keyring):
key = keyring(self, keyname)
else: else:
if isinstance(keyname, str): if isinstance(keyname, str):
keyname = dns.name.from_text(keyname) keyname = dns.name.from_text(keyname)
@ -506,7 +539,7 @@ class Message:
key = keyring[keyname] key = keyring[keyname]
if isinstance(key, bytes): if isinstance(key, bytes):
key = dns.tsig.Key(keyname, key, algorithm) key = dns.tsig.Key(keyname, key, algorithm)
self.keyring = key self.keyring = key
if original_id is None: if original_id is None:
original_id = self.id original_id = self.id
self.tsig = self._make_tsig(keyname, self.keyring.algorithm, 0, fudge, self.tsig = self._make_tsig(keyname, self.keyring.algorithm, 0, fudge,
@ -545,13 +578,13 @@ class Message:
return bool(self.tsig) return bool(self.tsig)
@staticmethod @staticmethod
def _make_opt(flags=0, payload=1280, options=None): def _make_opt(flags=0, payload=DEFAULT_EDNS_PAYLOAD, options=None):
opt = dns.rdtypes.ANY.OPT.OPT(payload, dns.rdatatype.OPT, opt = dns.rdtypes.ANY.OPT.OPT(payload, dns.rdatatype.OPT,
options or ()) options or ())
return dns.rrset.from_rdata(dns.name.root, int(flags), opt) return dns.rrset.from_rdata(dns.name.root, int(flags), opt)
def use_edns(self, edns=0, ednsflags=0, payload=1280, request_payload=None, def use_edns(self, edns=0, ednsflags=0, payload=DEFAULT_EDNS_PAYLOAD,
options=None): request_payload=None, options=None):
"""Configure EDNS behavior. """Configure EDNS behavior.
*edns*, an ``int``, is the EDNS level to use. Specifying *edns*, an ``int``, is the EDNS level to use. Specifying
@ -575,26 +608,21 @@ class Message:
if edns is None or edns is False: if edns is None or edns is False:
edns = -1 edns = -1
if edns is True: elif edns is True:
edns = 0 edns = 0
if request_payload is None:
request_payload = payload
if edns < 0: if edns < 0:
ednsflags = 0 self.opt = None
payload = 0 self.request_payload = 0
request_payload = 0
options = []
else: else:
# make sure the EDNS version in ednsflags agrees with edns # make sure the EDNS version in ednsflags agrees with edns
ednsflags &= 0xFF00FFFF ednsflags &= 0xFF00FFFF
ednsflags |= (edns << 16) ednsflags |= (edns << 16)
if options is None: if options is None:
options = [] options = []
if edns >= 0:
self.opt = self._make_opt(ednsflags, payload, options) self.opt = self._make_opt(ednsflags, payload, options)
else: if request_payload is None:
self.opt = None request_payload = payload
self.request_payload = request_payload self.request_payload = request_payload
@property @property
def edns(self): def edns(self):
@ -650,7 +678,7 @@ class Message:
Returns an ``int``. Returns an ``int``.
""" """
return dns.rcode.from_flags(self.flags, self.ednsflags) return dns.rcode.from_flags(int(self.flags), int(self.ednsflags))
def set_rcode(self, rcode): def set_rcode(self, rcode):
"""Set the rcode. """Set the rcode.
@ -668,7 +696,7 @@ class Message:
Returns an ``int``. Returns an ``int``.
""" """
return dns.opcode.from_flags(self.flags) return dns.opcode.from_flags(int(self.flags))
def set_opcode(self, opcode): def set_opcode(self, opcode):
"""Set the opcode. """Set the opcode.
@ -682,9 +710,13 @@ class Message:
# What the caller picked is fine. # What the caller picked is fine.
return value return value
# pylint: disable=unused-argument
def _parse_rr_header(self, section, name, rdclass, rdtype): def _parse_rr_header(self, section, name, rdclass, rdtype):
return (rdclass, rdtype, None, False) return (rdclass, rdtype, None, False)
# pylint: enable=unused-argument
def _parse_special_rr_header(self, section, count, position, def _parse_special_rr_header(self, section, count, position,
name, rdclass, rdtype): name, rdclass, rdtype):
if rdtype == dns.rdatatype.OPT: if rdtype == dns.rdatatype.OPT:
@ -699,14 +731,129 @@ class Message:
return (rdclass, rdtype, None, False) return (rdclass, rdtype, None, False)
class ChainingResult:
"""The result of a call to dns.message.QueryMessage.resolve_chaining().
The ``answer`` attribute is the answer RRSet, or ``None`` if it doesn't
exist.
The ``canonical_name`` attribute is the canonical name after all
chaining has been applied (this is the name as ``rrset.name`` in cases
where rrset is not ``None``).
The ``minimum_ttl`` attribute is the minimum TTL, i.e. the TTL to
use if caching the data. It is the smallest of all the CNAME TTLs
and either the answer TTL if it exists or the SOA TTL and SOA
minimum values for negative answers.
The ``cnames`` attribute is a list of all the CNAME RRSets followed to
get to the canonical name.
"""
def __init__(self, canonical_name, answer, minimum_ttl, cnames):
self.canonical_name = canonical_name
self.answer = answer
self.minimum_ttl = minimum_ttl
self.cnames = cnames
class QueryMessage(Message): class QueryMessage(Message):
pass def resolve_chaining(self):
"""Follow the CNAME chain in the response to determine the answer
RRset.
Raises ``dns.message.NotQueryResponse`` if the message is not
a response.
Raises ``dns.message.ChainTooLong`` if the CNAME chain is too long.
Raises ``dns.message.AnswerForNXDOMAIN`` if the rcode is NXDOMAIN
but an answer was found.
Raises ``dns.exception.FormError`` if the question count is not 1.
Returns a ChainingResult object.
"""
if self.flags & dns.flags.QR == 0:
raise NotQueryResponse
if len(self.question) != 1:
raise dns.exception.FormError
question = self.question[0]
qname = question.name
min_ttl = dns.ttl.MAX_TTL
answer = None
count = 0
cnames = []
while count < MAX_CHAIN:
try:
answer = self.find_rrset(self.answer, qname, question.rdclass,
question.rdtype)
min_ttl = min(min_ttl, answer.ttl)
break
except KeyError:
if question.rdtype != dns.rdatatype.CNAME:
try:
crrset = self.find_rrset(self.answer, qname,
question.rdclass,
dns.rdatatype.CNAME)
cnames.append(crrset)
min_ttl = min(min_ttl, crrset.ttl)
for rd in crrset:
qname = rd.target
break
count += 1
continue
except KeyError:
# Exit the chaining loop
break
else:
# Exit the chaining loop
break
if count >= MAX_CHAIN:
raise ChainTooLong
if self.rcode() == dns.rcode.NXDOMAIN and answer is not None:
raise AnswerForNXDOMAIN
if answer is None:
# Further minimize the TTL with NCACHE.
auname = qname
while True:
# Look for an SOA RR whose owner name is a superdomain
# of qname.
try:
srrset = self.find_rrset(self.authority, auname,
question.rdclass,
dns.rdatatype.SOA)
min_ttl = min(min_ttl, srrset.ttl, srrset[0].minimum)
break
except KeyError:
try:
auname = auname.parent()
except dns.name.NoParent:
break
return ChainingResult(qname, answer, min_ttl, cnames)
def canonical_name(self):
"""Return the canonical name of the first name in the question
section.
Raises ``dns.message.NotQueryResponse`` if the message is not
a response.
Raises ``dns.message.ChainTooLong`` if the CNAME chain is too long.
Raises ``dns.message.AnswerForNXDOMAIN`` if the rcode is NXDOMAIN
but an answer was found.
Raises ``dns.exception.FormError`` if the question count is not 1.
"""
return self.resolve_chaining().canonical_name
def _maybe_import_update(): def _maybe_import_update():
# We avoid circular imports by doing this here. We do it in another # We avoid circular imports by doing this here. We do it in another
# function as doing it in _message_factory_from_opcode() makes "dns" # function as doing it in _message_factory_from_opcode() makes "dns"
# a local symbol, and the first line fails :) # a local symbol, and the first line fails :)
# pylint: disable=redefined-outer-name,import-outside-toplevel,unused-import
import dns.update # noqa: F401 import dns.update # noqa: F401
@ -733,11 +880,14 @@ class _WireReader:
ignore_trailing: Ignore trailing junk at end of request? ignore_trailing: Ignore trailing junk at end of request?
multi: Is this message part of a multi-message sequence? multi: Is this message part of a multi-message sequence?
DNS dynamic updates. DNS dynamic updates.
continue_on_error: try to extract as much information as possible from
the message, accumulating MessageErrors in the *errors* attribute instead of
raising them.
""" """
def __init__(self, wire, initialize_message, question_only=False, def __init__(self, wire, initialize_message, question_only=False,
one_rr_per_rrset=False, ignore_trailing=False, one_rr_per_rrset=False, ignore_trailing=False,
keyring=None, multi=False): keyring=None, multi=False, continue_on_error=False):
self.parser = dns.wire.Parser(wire) self.parser = dns.wire.Parser(wire)
self.message = None self.message = None
self.initialize_message = initialize_message self.initialize_message = initialize_message
@ -746,6 +896,8 @@ class _WireReader:
self.ignore_trailing = ignore_trailing self.ignore_trailing = ignore_trailing
self.keyring = keyring self.keyring = keyring
self.multi = multi self.multi = multi
self.continue_on_error = continue_on_error
self.errors = []
def _get_question(self, section_number, qcount): def _get_question(self, section_number, qcount):
"""Read the next *qcount* records from the wire data and add them to """Read the next *qcount* records from the wire data and add them to
@ -753,7 +905,7 @@ class _WireReader:
""" """
section = self.message.sections[section_number] section = self.message.sections[section_number]
for i in range(qcount): for _ in range(qcount):
qname = self.parser.get_name(self.message.origin) qname = self.parser.get_name(self.message.origin)
(rdtype, rdclass) = self.parser.get_struct('!HH') (rdtype, rdclass) = self.parser.get_struct('!HH')
(rdclass, rdtype, _, _) = \ (rdclass, rdtype, _, _) = \
@ -762,11 +914,14 @@ class _WireReader:
self.message.find_rrset(section, qname, rdclass, rdtype, self.message.find_rrset(section, qname, rdclass, rdtype,
create=True, force_unique=True) create=True, force_unique=True)
def _add_error(self, e):
self.errors.append(MessageError(e, self.parser.current))
def _get_section(self, section_number, count): def _get_section(self, section_number, count):
"""Read the next I{count} records from the wire data and add them to """Read the next I{count} records from the wire data and add them to
the specified section. the specified section.
section: the section of the message to which to add records section_number: the section of the message to which to add records
count: the number of records to read count: the number of records to read
""" """
@ -789,53 +944,65 @@ class _WireReader:
(rdclass, rdtype, deleting, empty) = \ (rdclass, rdtype, deleting, empty) = \
self.message._parse_rr_header(section_number, self.message._parse_rr_header(section_number,
name, rdclass, rdtype) name, rdclass, rdtype)
if empty: try:
if rdlen > 0: rdata_start = self.parser.current
raise dns.exception.FormError if empty:
rd = None if rdlen > 0:
covers = dns.rdatatype.NONE raise dns.exception.FormError
else: rd = None
with self.parser.restrict_to(rdlen): covers = dns.rdatatype.NONE
rd = dns.rdata.from_wire_parser(rdclass, rdtype,
self.parser,
self.message.origin)
covers = rd.covers()
if self.message.xfr and rdtype == dns.rdatatype.SOA:
force_unique = True
if rdtype == dns.rdatatype.OPT:
self.message.opt = dns.rrset.from_rdata(name, ttl, rd)
elif rdtype == dns.rdatatype.TSIG:
if self.keyring is None:
raise UnknownTSIGKey('got signed message without keyring')
if isinstance(self.keyring, dict):
key = self.keyring.get(absolute_name)
if isinstance(key, bytes):
key = dns.tsig.Key(absolute_name, key, rd.algorithm)
else: else:
key = self.keyring with self.parser.restrict_to(rdlen):
if key is None: rd = dns.rdata.from_wire_parser(rdclass, rdtype,
raise UnknownTSIGKey("key '%s' unknown" % name) self.parser,
self.message.keyring = key self.message.origin)
self.message.tsig_ctx = \ covers = rd.covers()
dns.tsig.validate(self.parser.wire, if self.message.xfr and rdtype == dns.rdatatype.SOA:
key, force_unique = True
absolute_name, if rdtype == dns.rdatatype.OPT:
rd, self.message.opt = dns.rrset.from_rdata(name, ttl, rd)
int(time.time()), elif rdtype == dns.rdatatype.TSIG:
self.message.request_mac, if self.keyring is None:
rr_start, raise UnknownTSIGKey('got signed message without '
self.message.tsig_ctx, 'keyring')
self.multi) if isinstance(self.keyring, dict):
self.message.tsig = dns.rrset.from_rdata(absolute_name, 0, rd) key = self.keyring.get(absolute_name)
else: if isinstance(key, bytes):
rrset = self.message.find_rrset(section, name, key = dns.tsig.Key(absolute_name, key, rd.algorithm)
rdclass, rdtype, covers, elif callable(self.keyring):
deleting, True, key = self.keyring(self.message, absolute_name)
force_unique) else:
if rd is not None: key = self.keyring
if ttl > 0x7fffffff: if key is None:
ttl = 0 raise UnknownTSIGKey("key '%s' unknown" % name)
rrset.add(rd, ttl) self.message.keyring = key
self.message.tsig_ctx = \
dns.tsig.validate(self.parser.wire,
key,
absolute_name,
rd,
int(time.time()),
self.message.request_mac,
rr_start,
self.message.tsig_ctx,
self.multi)
self.message.tsig = dns.rrset.from_rdata(absolute_name, 0,
rd)
else:
rrset = self.message.find_rrset(section, name,
rdclass, rdtype, covers,
deleting, True,
force_unique)
if rd is not None:
if ttl > 0x7fffffff:
ttl = 0
rrset.add(rd, ttl)
except Exception as e:
if self.continue_on_error:
self._add_error(e)
self.parser.seek(rdata_start + rdlen)
else:
raise
def read(self): def read(self):
"""Read a wire format DNS message and build a dns.message.Message """Read a wire format DNS message and build a dns.message.Message
@ -847,73 +1014,86 @@ class _WireReader:
self.parser.get_struct('!HHHHHH') self.parser.get_struct('!HHHHHH')
factory = _message_factory_from_opcode(dns.opcode.from_flags(flags)) factory = _message_factory_from_opcode(dns.opcode.from_flags(flags))
self.message = factory(id=id) self.message = factory(id=id)
self.message.flags = flags self.message.flags = dns.flags.Flag(flags)
self.initialize_message(self.message) self.initialize_message(self.message)
self.one_rr_per_rrset = \ self.one_rr_per_rrset = \
self.message._get_one_rr_per_rrset(self.one_rr_per_rrset) self.message._get_one_rr_per_rrset(self.one_rr_per_rrset)
self._get_question(MessageSection.QUESTION, qcount) try:
if self.question_only: self._get_question(MessageSection.QUESTION, qcount)
return if self.question_only:
self._get_section(MessageSection.ANSWER, ancount) return self.message
self._get_section(MessageSection.AUTHORITY, aucount) self._get_section(MessageSection.ANSWER, ancount)
self._get_section(MessageSection.ADDITIONAL, adcount) self._get_section(MessageSection.AUTHORITY, aucount)
if not self.ignore_trailing and self.parser.remaining() != 0: self._get_section(MessageSection.ADDITIONAL, adcount)
raise TrailingJunk if not self.ignore_trailing and self.parser.remaining() != 0:
if self.multi and self.message.tsig_ctx and not self.message.had_tsig: raise TrailingJunk
self.message.tsig_ctx.update(self.parser.wire) if self.multi and self.message.tsig_ctx and \
not self.message.had_tsig:
self.message.tsig_ctx.update(self.parser.wire)
except Exception as e:
if self.continue_on_error:
self._add_error(e)
else:
raise
return self.message return self.message
def from_wire(wire, keyring=None, request_mac=b'', xfr=False, origin=None, def from_wire(wire, keyring=None, request_mac=b'', xfr=False, origin=None,
tsig_ctx=None, multi=False, tsig_ctx=None, multi=False,
question_only=False, one_rr_per_rrset=False, question_only=False, one_rr_per_rrset=False,
ignore_trailing=False, raise_on_truncation=False): ignore_trailing=False, raise_on_truncation=False,
"""Convert a DNS wire format message into a message continue_on_error=False):
object. """Convert a DNS wire format message into a message object.
*keyring*, a ``dns.tsig.Key`` or ``dict``, the key or keyring to use *keyring*, a ``dns.tsig.Key`` or ``dict``, the key or keyring to use if the
if the message is signed. message is signed.
*request_mac*, a ``bytes``. If the message is a response to a *request_mac*, a ``bytes``. If the message is a response to a TSIG-signed
TSIG-signed request, *request_mac* should be set to the MAC of request, *request_mac* should be set to the MAC of that request.
that request.
*xfr*, a ``bool``, should be set to ``True`` if this message is part of *xfr*, a ``bool``, should be set to ``True`` if this message is part of a
a zone transfer. zone transfer.
*origin*, a ``dns.name.Name`` or ``None``. If the message is part *origin*, a ``dns.name.Name`` or ``None``. If the message is part of a zone
of a zone transfer, *origin* should be the origin name of the transfer, *origin* should be the origin name of the zone. If not ``None``,
zone. If not ``None``, names will be relativized to the origin. names will be relativized to the origin.
*tsig_ctx*, a ``hmac.HMAC`` object, the ongoing TSIG context, used *tsig_ctx*, a ``dns.tsig.HMACTSig`` or ``dns.tsig.GSSTSig`` object, the
when validating zone transfers. ongoing TSIG context, used when validating zone transfers.
*multi*, a ``bool``, should be set to ``True`` if this message is *multi*, a ``bool``, should be set to ``True`` if this message is part of a
part of a multiple message sequence. multiple message sequence.
*question_only*, a ``bool``. If ``True``, read only up to *question_only*, a ``bool``. If ``True``, read only up to the end of the
the end of the question section. question section.
*one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own
own RRset. RRset.
*ignore_trailing*, a ``bool``. If ``True``, ignore trailing *ignore_trailing*, a ``bool``. If ``True``, ignore trailing junk at end of
junk at end of the message. the message.
*raise_on_truncation*, a ``bool``. If ``True``, raise an exception if *raise_on_truncation*, a ``bool``. If ``True``, raise an exception if the
the TC bit is set. TC bit is set.
*continue_on_error*, a ``bool``. If ``True``, try to continue parsing even
if errors occur. Erroneous rdata will be ignored. Errors will be
accumulated as a list of MessageError objects in the message's ``errors``
attribute. This option is recommended only for DNS analysis tools, or for
use in a server as part of an error handling path. The default is
``False``.
Raises ``dns.message.ShortHeader`` if the message is less than 12 octets Raises ``dns.message.ShortHeader`` if the message is less than 12 octets
long. long.
Raises ``dns.message.TrailingJunk`` if there were octets in the message Raises ``dns.message.TrailingJunk`` if there were octets in the message past
past the end of the proper DNS message, and *ignore_trailing* is ``False``. the end of the proper DNS message, and *ignore_trailing* is ``False``.
Raises ``dns.message.BadEDNS`` if an OPT record was in the Raises ``dns.message.BadEDNS`` if an OPT record was in the wrong section, or
wrong section, or occurred more than once. occurred more than once.
Raises ``dns.message.BadTSIG`` if a TSIG record was not the last Raises ``dns.message.BadTSIG`` if a TSIG record was not the last record of
record of the additional data section. the additional data section.
Raises ``dns.message.Truncated`` if the TC flag is set and Raises ``dns.message.Truncated`` if the TC flag is set and
*raise_on_truncation* is ``True``. *raise_on_truncation* is ``True``.
@ -928,7 +1108,8 @@ def from_wire(wire, keyring=None, request_mac=b'', xfr=False, origin=None,
message.tsig_ctx = tsig_ctx message.tsig_ctx = tsig_ctx
reader = _WireReader(wire, initialize_message, question_only, reader = _WireReader(wire, initialize_message, question_only,
one_rr_per_rrset, ignore_trailing, keyring, multi) one_rr_per_rrset, ignore_trailing, keyring, multi,
continue_on_error)
try: try:
m = reader.read() m = reader.read()
except dns.exception.FormError: except dns.exception.FormError:
@ -941,6 +1122,8 @@ def from_wire(wire, keyring=None, request_mac=b'', xfr=False, origin=None,
# have to do this check here too. # have to do this check here too.
if m.flags & dns.flags.TC and raise_on_truncation: if m.flags & dns.flags.TC and raise_on_truncation:
raise Truncated(message=m) raise Truncated(message=m)
if continue_on_error:
m.errors = reader.errors
return m return m
@ -971,12 +1154,12 @@ class _TextReader:
self.id = None self.id = None
self.edns = -1 self.edns = -1
self.ednsflags = 0 self.ednsflags = 0
self.payload = None self.payload = DEFAULT_EDNS_PAYLOAD
self.rcode = None self.rcode = None
self.opcode = dns.opcode.QUERY self.opcode = dns.opcode.QUERY
self.flags = 0 self.flags = 0
def _header_line(self, section): def _header_line(self, _):
"""Process one line from the text format header section.""" """Process one line from the text format header section."""
token = self.tok.get() token = self.tok.get()
@ -1028,6 +1211,8 @@ class _TextReader:
self.relativize, self.relativize,
self.relativize_to) self.relativize_to)
name = self.last_name name = self.last_name
if name is None:
raise NoPreviousName
token = self.tok.get() token = self.tok.get()
if not token.is_identifier(): if not token.is_identifier():
raise dns.exception.SyntaxError raise dns.exception.SyntaxError
@ -1062,6 +1247,8 @@ class _TextReader:
self.relativize, self.relativize,
self.relativize_to) self.relativize_to)
name = self.last_name name = self.last_name
if name is None:
raise NoPreviousName
token = self.tok.get() token = self.tok.get()
if not token.is_identifier(): if not token.is_identifier():
raise dns.exception.SyntaxError raise dns.exception.SyntaxError
@ -1092,6 +1279,8 @@ class _TextReader:
token = self.tok.get() token = self.tok.get()
if empty and not token.is_eol_or_eof(): if empty and not token.is_eol_or_eof():
raise dns.exception.SyntaxError raise dns.exception.SyntaxError
if not empty and token.is_eol_or_eof():
raise dns.exception.UnexpectedEnd
if not token.is_eol_or_eof(): if not token.is_eol_or_eof():
self.tok.unget(token) self.tok.unget(token)
rd = dns.rdata.from_text(rdclass, rdtype, self.tok, rd = dns.rdata.from_text(rdclass, rdtype, self.tok,
@ -1235,7 +1424,8 @@ def from_file(f, idna_codec=None, one_rr_per_rrset=False):
def make_query(qname, rdtype, rdclass=dns.rdataclass.IN, use_edns=None, def make_query(qname, rdtype, rdclass=dns.rdataclass.IN, use_edns=None,
want_dnssec=False, ednsflags=None, payload=None, want_dnssec=False, ednsflags=None, payload=None,
request_payload=None, options=None, idna_codec=None): request_payload=None, options=None, idna_codec=None,
id=None, flags=dns.flags.RD):
"""Make a query message. """Make a query message.
The query name, type, and class may all be specified either The query name, type, and class may all be specified either
@ -1252,7 +1442,9 @@ def make_query(qname, rdtype, rdclass=dns.rdataclass.IN, use_edns=None,
is class IN. is class IN.
*use_edns*, an ``int``, ``bool`` or ``None``. The EDNS level to use; the *use_edns*, an ``int``, ``bool`` or ``None``. The EDNS level to use; the
default is None (no EDNS). default is ``None``. If ``None``, EDNS will be enabled only if other
parameters (*ednsflags*, *payload*, *request_payload*, or *options*) are
set.
See the description of dns.message.Message.use_edns() for the possible See the description of dns.message.Message.use_edns() for the possible
values for use_edns and their meanings. values for use_edns and their meanings.
@ -1275,6 +1467,12 @@ def make_query(qname, rdtype, rdclass=dns.rdataclass.IN, use_edns=None,
encoder/decoder. If ``None``, the default IDNA 2003 encoder/decoder encoder/decoder. If ``None``, the default IDNA 2003 encoder/decoder
is used. is used.
*id*, an ``int`` or ``None``, the desired query id. The default is
``None``, which generates a random query id.
*flags*, an ``int``, the desired query flags. The default is
``dns.flags.RD``.
Returns a ``dns.message.QueryMessage`` Returns a ``dns.message.QueryMessage``
""" """
@ -1282,8 +1480,8 @@ def make_query(qname, rdtype, rdclass=dns.rdataclass.IN, use_edns=None,
qname = dns.name.from_text(qname, idna_codec=idna_codec) qname = dns.name.from_text(qname, idna_codec=idna_codec)
rdtype = dns.rdatatype.RdataType.make(rdtype) rdtype = dns.rdatatype.RdataType.make(rdtype)
rdclass = dns.rdataclass.RdataClass.make(rdclass) rdclass = dns.rdataclass.RdataClass.make(rdclass)
m = QueryMessage() m = QueryMessage(id=id)
m.flags |= dns.flags.RD m.flags = dns.flags.Flag(flags)
m.find_rrset(m.question, qname, rdclass, rdtype, create=True, m.find_rrset(m.question, qname, rdclass, rdtype, create=True,
force_unique=True) force_unique=True)
# only pass keywords on to use_edns if they have been set to a # only pass keywords on to use_edns if they have been set to a
@ -1292,20 +1490,14 @@ def make_query(qname, rdtype, rdclass=dns.rdataclass.IN, use_edns=None,
kwargs = {} kwargs = {}
if ednsflags is not None: if ednsflags is not None:
kwargs['ednsflags'] = ednsflags kwargs['ednsflags'] = ednsflags
if use_edns is None:
use_edns = 0
if payload is not None: if payload is not None:
kwargs['payload'] = payload kwargs['payload'] = payload
if use_edns is None:
use_edns = 0
if request_payload is not None: if request_payload is not None:
kwargs['request_payload'] = request_payload kwargs['request_payload'] = request_payload
if use_edns is None:
use_edns = 0
if options is not None: if options is not None:
kwargs['options'] = options kwargs['options'] = options
if use_edns is None: if kwargs and use_edns is None:
use_edns = 0 use_edns = 0
kwargs['edns'] = use_edns kwargs['edns'] = use_edns
m.use_edns(**kwargs) m.use_edns(**kwargs)
m.want_dnssec(want_dnssec) m.want_dnssec(want_dnssec)
@ -1355,3 +1547,12 @@ def make_response(query, recursion_available=False, our_payload=8192,
tsig_error, b'', query.keyalgorithm) tsig_error, b'', query.keyalgorithm)
response.request_mac = query.mac response.request_mac = query.mac
return response return response
### BEGIN generated MessageSection constants
QUESTION = MessageSection.QUESTION
ANSWER = MessageSection.ANSWER
AUTHORITY = MessageSection.AUTHORITY
ADDITIONAL = MessageSection.ADDITIONAL
### END generated MessageSection constants

47
lib/dns/message.pyi Normal file
View file

@ -0,0 +1,47 @@
from typing import Optional, Dict, List, Tuple, Union
from . import name, rrset, tsig, rdatatype, entropy, edns, rdataclass, rcode
import hmac
class Message:
def to_wire(self, origin : Optional[name.Name]=None, max_size=0, **kw) -> bytes:
...
def find_rrset(self, section : List[rrset.RRset], name : name.Name, rdclass : int, rdtype : int,
covers=rdatatype.NONE, deleting : Optional[int]=None, create=False,
force_unique=False) -> rrset.RRset:
...
def __init__(self, id : Optional[int] =None) -> None:
self.id : int
self.flags = 0
self.sections : List[List[rrset.RRset]] = [[], [], [], []]
self.opt : rrset.RRset = None
self.request_payload = 0
self.keyring = None
self.tsig : rrset.RRset = None
self.request_mac = b''
self.xfr = False
self.origin = None
self.tsig_ctx = None
self.index : Dict[Tuple[rrset.RRset, name.Name, int, int, Union[int,str], int], rrset.RRset] = {}
def is_response(self, other : Message) -> bool:
...
def set_rcode(self, rcode : rcode.Rcode):
...
def from_text(a : str, idna_codec : Optional[name.IDNACodec] = None) -> Message:
...
def from_wire(wire, keyring : Optional[Dict[name.Name,bytes]] = None, request_mac = b'', xfr=False, origin=None,
tsig_ctx : Optional[Union[dns.tsig.HMACTSig, dns.tsig.GSSTSig]] = None, multi=False,
question_only=False, one_rr_per_rrset=False,
ignore_trailing=False) -> Message:
...
def make_response(query : Message, recursion_available=False, our_payload=8192,
fudge=300) -> Message:
...
def make_query(qname : Union[name.Name,str], rdtype : Union[str,int], rdclass : Union[int,str] =rdataclass.IN, use_edns : Optional[bool] = None,
want_dnssec=False, ednsflags : Optional[int] = None, payload : Optional[int] = None,
request_payload : Optional[int] = None, options : Optional[List[edns.Option]] = None) -> Message:
...

View file

@ -30,6 +30,7 @@ except ImportError: # pragma: no cover
import dns.wire import dns.wire
import dns.exception import dns.exception
import dns.immutable
# fullcompare() result values # fullcompare() result values
@ -215,9 +216,10 @@ class IDNA2008Codec(IDNACodec):
if not have_idna_2008: if not have_idna_2008:
raise NoIDNA2008 raise NoIDNA2008
try: try:
ulabel = idna.ulabel(label)
if self.uts_46: if self.uts_46:
label = idna.uts46_remap(label, False, False) ulabel = idna.uts46_remap(ulabel, False, self.transitional)
return _escapify(idna.ulabel(label)) return _escapify(ulabel)
except (idna.IDNAError, UnicodeError) as e: except (idna.IDNAError, UnicodeError) as e:
raise IDNAException(idna_exception=e) raise IDNAException(idna_exception=e)
@ -304,6 +306,7 @@ def _maybe_convert_to_binary(label):
raise ValueError # pragma: no cover raise ValueError # pragma: no cover
@dns.immutable.immutable
class Name: class Name:
"""A DNS name. """A DNS name.
@ -320,17 +323,9 @@ class Name:
""" """
labels = [_maybe_convert_to_binary(x) for x in labels] labels = [_maybe_convert_to_binary(x) for x in labels]
super().__setattr__('labels', tuple(labels)) self.labels = tuple(labels)
_validate_labels(self.labels) _validate_labels(self.labels)
def __setattr__(self, name, value):
# Names are immutable
raise TypeError("object doesn't support attribute assignment")
def __delattr__(self, name):
# Names are immutable
raise TypeError("object doesn't support attribute deletion")
def __copy__(self): def __copy__(self):
return Name(self.labels) return Name(self.labels)
@ -458,7 +453,7 @@ class Name:
Returns a ``bool``. Returns a ``bool``.
""" """
(nr, o, nl) = self.fullcompare(other) (nr, _, _) = self.fullcompare(other)
if nr == NAMERELN_SUBDOMAIN or nr == NAMERELN_EQUAL: if nr == NAMERELN_SUBDOMAIN or nr == NAMERELN_EQUAL:
return True return True
return False return False
@ -472,7 +467,7 @@ class Name:
Returns a ``bool``. Returns a ``bool``.
""" """
(nr, o, nl) = self.fullcompare(other) (nr, _, _) = self.fullcompare(other)
if nr == NAMERELN_SUPERDOMAIN or nr == NAMERELN_EQUAL: if nr == NAMERELN_SUPERDOMAIN or nr == NAMERELN_EQUAL:
return True return True
return False return False

40
lib/dns/name.pyi Normal file
View file

@ -0,0 +1,40 @@
from typing import Optional, Union, Tuple, Iterable, List
have_idna_2008: bool
class Name:
def is_subdomain(self, o : Name) -> bool: ...
def is_superdomain(self, o : Name) -> bool: ...
def __init__(self, labels : Iterable[Union[bytes,str]]) -> None:
self.labels : List[bytes]
def is_absolute(self) -> bool: ...
def is_wild(self) -> bool: ...
def fullcompare(self, other) -> Tuple[int,int,int]: ...
def canonicalize(self) -> Name: ...
def __eq__(self, other) -> bool: ...
def __ne__(self, other) -> bool: ...
def __lt__(self, other : Name) -> bool: ...
def __le__(self, other : Name) -> bool: ...
def __ge__(self, other : Name) -> bool: ...
def __gt__(self, other : Name) -> bool: ...
def to_text(self, omit_final_dot=False) -> str: ...
def to_unicode(self, omit_final_dot=False, idna_codec=None) -> str: ...
def to_digestable(self, origin=None) -> bytes: ...
def to_wire(self, file=None, compress=None, origin=None,
canonicalize=False) -> Optional[bytes]: ...
def __add__(self, other : Name) -> Name: ...
def __sub__(self, other : Name) -> Name: ...
def split(self, depth) -> List[Tuple[str,str]]: ...
def concatenate(self, other : Name) -> Name: ...
def relativize(self, origin) -> Name: ...
def derelativize(self, origin) -> Name: ...
def choose_relativity(self, origin : Optional[Name] = None, relativize=True) -> Name: ...
def parent(self) -> Name: ...
class IDNACodec:
pass
def from_text(text, origin : Optional[Name] = Name('.'), idna_codec : Optional[IDNACodec] = None) -> Name:
...
empty : Name

View file

@ -85,7 +85,7 @@ class NameDict(MutableMapping):
return key in self.__store return key in self.__store
def get_deepest_match(self, name): def get_deepest_match(self, name):
"""Find the deepest match to *fname* in the dictionary. """Find the deepest match to *name* in the dictionary.
The deepest match is the longest name in the dictionary which is The deepest match is the longest name in the dictionary which is
a superdomain of *name*. Note that *superdomain* includes matching a superdomain of *name*. Note that *superdomain* includes matching

View file

@ -17,16 +17,69 @@
"""DNS nodes. A node is a set of rdatasets.""" """DNS nodes. A node is a set of rdatasets."""
import enum
import io import io
import dns.immutable
import dns.rdataset import dns.rdataset
import dns.rdatatype import dns.rdatatype
import dns.renderer import dns.renderer
_cname_types = {
dns.rdatatype.CNAME,
}
# "neutral" types can coexist with a CNAME and thus are not "other data"
_neutral_types = {
dns.rdatatype.NSEC, # RFC 4035 section 2.5
dns.rdatatype.NSEC3, # This is not likely to happen, but not impossible!
dns.rdatatype.KEY, # RFC 4035 section 2.5, RFC 3007
}
def _matches_type_or_its_signature(rdtypes, rdtype, covers):
return rdtype in rdtypes or \
(rdtype == dns.rdatatype.RRSIG and covers in rdtypes)
@enum.unique
class NodeKind(enum.Enum):
"""Rdatasets in nodes
"""
REGULAR = 0 # a.k.a "other data"
NEUTRAL = 1
CNAME = 2
@classmethod
def classify(cls, rdtype, covers):
if _matches_type_or_its_signature(_cname_types, rdtype, covers):
return NodeKind.CNAME
elif _matches_type_or_its_signature(_neutral_types, rdtype, covers):
return NodeKind.NEUTRAL
else:
return NodeKind.REGULAR
@classmethod
def classify_rdataset(cls, rdataset):
return cls.classify(rdataset.rdtype, rdataset.covers)
class Node: class Node:
"""A Node is a set of rdatasets.""" """A Node is a set of rdatasets.
A node is either a CNAME node or an "other data" node. A CNAME
node contains only CNAME, KEY, NSEC, and NSEC3 rdatasets along with their
covering RRSIG rdatasets. An "other data" node contains any
rdataset other than a CNAME or RRSIG(CNAME) rdataset. When
changes are made to a node, the CNAME or "other data" state is
always consistent with the update, i.e. the most recent change
wins. For example, if you have a node which contains a CNAME
rdataset, and then add an MX rdataset to it, then the CNAME
rdataset will be deleted. Likewise if you have a node containing
an MX rdataset and add a CNAME rdataset, the MX rdataset will be
deleted.
"""
__slots__ = ['rdatasets'] __slots__ = ['rdatasets']
@ -78,6 +131,30 @@ class Node:
def __iter__(self): def __iter__(self):
return iter(self.rdatasets) return iter(self.rdatasets)
def _append_rdataset(self, rdataset):
"""Append rdataset to the node with special handling for CNAME and
other data conditions.
Specifically, if the rdataset being appended has ``NodeKind.CNAME``,
then all rdatasets other than KEY, NSEC, NSEC3, and their covering
RRSIGs are deleted. If the rdataset being appended has
``NodeKind.REGULAR`` then CNAME and RRSIG(CNAME) are deleted.
"""
# Make having just one rdataset at the node fast.
if len(self.rdatasets) > 0:
kind = NodeKind.classify_rdataset(rdataset)
if kind == NodeKind.CNAME:
self.rdatasets = [rds for rds in self.rdatasets if
NodeKind.classify_rdataset(rds) !=
NodeKind.REGULAR]
elif kind == NodeKind.REGULAR:
self.rdatasets = [rds for rds in self.rdatasets if
NodeKind.classify_rdataset(rds) !=
NodeKind.CNAME]
# Otherwise the rdataset is NodeKind.NEUTRAL and we do not need to
# edit self.rdatasets.
self.rdatasets.append(rdataset)
def find_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE, def find_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE,
create=False): create=False):
"""Find an rdataset matching the specified properties in the """Find an rdataset matching the specified properties in the
@ -110,8 +187,8 @@ class Node:
return rds return rds
if not create: if not create:
raise KeyError raise KeyError
rds = dns.rdataset.Rdataset(rdclass, rdtype) rds = dns.rdataset.Rdataset(rdclass, rdtype, covers)
self.rdatasets.append(rds) self._append_rdataset(rds)
return rds return rds
def get_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE, def get_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE,
@ -180,6 +257,64 @@ class Node:
if not isinstance(replacement, dns.rdataset.Rdataset): if not isinstance(replacement, dns.rdataset.Rdataset):
raise ValueError('replacement is not an rdataset') raise ValueError('replacement is not an rdataset')
if isinstance(replacement, dns.rrset.RRset):
# RRsets are not good replacements as the match() method
# is not compatible.
replacement = replacement.to_rdataset()
self.delete_rdataset(replacement.rdclass, replacement.rdtype, self.delete_rdataset(replacement.rdclass, replacement.rdtype,
replacement.covers) replacement.covers)
self.rdatasets.append(replacement) self._append_rdataset(replacement)
def classify(self):
"""Classify a node.
A node which contains a CNAME or RRSIG(CNAME) is a
``NodeKind.CNAME`` node.
A node which contains only "neutral" types, i.e. types allowed to
co-exist with a CNAME, is a ``NodeKind.NEUTRAL`` node. The neutral
types are NSEC, NSEC3, KEY, and their associated RRSIGS. An empty node
is also considered neutral.
A node which contains some rdataset which is not a CNAME, RRSIG(CNAME),
or a neutral type is a a ``NodeKind.REGULAR`` node. Regular nodes are
also commonly referred to as "other data".
"""
for rdataset in self.rdatasets:
kind = NodeKind.classify(rdataset.rdtype, rdataset.covers)
if kind != NodeKind.NEUTRAL:
return kind
return NodeKind.NEUTRAL
def is_immutable(self):
return False
@dns.immutable.immutable
class ImmutableNode(Node):
def __init__(self, node):
super().__init__()
self.rdatasets = tuple(
[dns.rdataset.ImmutableRdataset(rds) for rds in node.rdatasets]
)
def find_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE,
create=False):
if create:
raise TypeError("immutable")
return super().find_rdataset(rdclass, rdtype, covers, False)
def get_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE,
create=False):
if create:
raise TypeError("immutable")
return super().get_rdataset(rdclass, rdtype, covers, False)
def delete_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE):
raise TypeError("immutable")
def replace_rdataset(self, replacement):
raise TypeError("immutable")
def is_immutable(self):
return True

17
lib/dns/node.pyi Normal file
View file

@ -0,0 +1,17 @@
from typing import List, Optional, Union
from . import rdataset, rdatatype, name
class Node:
def __init__(self):
self.rdatasets : List[rdataset.Rdataset]
def to_text(self, name : Union[str,name.Name], **kw) -> str:
...
def find_rdataset(self, rdclass : int, rdtype : int, covers=rdatatype.NONE,
create=False) -> rdataset.Rdataset:
...
def get_rdataset(self, rdclass : int, rdtype : int, covers=rdatatype.NONE,
create=False) -> Optional[rdataset.Rdataset]:
...
def delete_rdataset(self, rdclass : int, rdtype : int, covers=rdatatype.NONE):
...
def replace_rdataset(self, replacement : rdataset.Rdataset) -> None:
...

View file

@ -40,8 +40,6 @@ class Opcode(dns.enum.IntEnum):
def _unknown_exception_class(cls): def _unknown_exception_class(cls):
return UnknownOpcode return UnknownOpcode
globals().update(Opcode.__members__)
class UnknownOpcode(dns.exception.DNSException): class UnknownOpcode(dns.exception.DNSException):
"""An DNS opcode is unknown.""" """An DNS opcode is unknown."""
@ -105,3 +103,13 @@ def is_update(flags):
""" """
return from_flags(flags) == Opcode.UPDATE return from_flags(flags) == Opcode.UPDATE
### BEGIN generated Opcode constants
QUERY = Opcode.QUERY
IQUERY = Opcode.IQUERY
STATUS = Opcode.STATUS
NOTIFY = Opcode.NOTIFY
UPDATE = Opcode.UPDATE
### END generated Opcode constants

View file

@ -18,9 +18,10 @@
"""Talk to a DNS server.""" """Talk to a DNS server."""
import contextlib import contextlib
import enum
import errno import errno
import os import os
import select import selectors
import socket import socket
import struct import struct
import time import time
@ -35,14 +36,31 @@ import dns.rcode
import dns.rdataclass import dns.rdataclass
import dns.rdatatype import dns.rdatatype
import dns.serial import dns.serial
import dns.xfr
try: try:
import requests import requests
from requests_toolbelt.adapters.source import SourceAddressAdapter from requests_toolbelt.adapters.source import SourceAddressAdapter
from requests_toolbelt.adapters.host_header_ssl import HostHeaderSSLAdapter from requests_toolbelt.adapters.host_header_ssl import HostHeaderSSLAdapter
have_doh = True _have_requests = True
except ImportError: # pragma: no cover except ImportError: # pragma: no cover
have_doh = False _have_requests = False
_have_httpx = False
_have_http2 = False
try:
import httpx
_have_httpx = True
try:
# See if http2 support is available.
with httpx.Client(http2=True):
_have_http2 = True
except Exception:
pass
except ImportError: # pragma: no cover
pass
have_doh = _have_requests or _have_httpx
try: try:
import ssl import ssl
@ -73,20 +91,15 @@ class BadResponse(dns.exception.FormError):
"""A DNS query response does not respond to the question asked.""" """A DNS query response does not respond to the question asked."""
class TransferError(dns.exception.DNSException):
"""A zone transfer response got a non-zero rcode."""
def __init__(self, rcode):
message = 'Zone transfer error: %s' % dns.rcode.to_text(rcode)
super().__init__(message)
self.rcode = rcode
class NoDOH(dns.exception.DNSException): class NoDOH(dns.exception.DNSException):
"""DNS over HTTPS (DOH) was requested but the requests module is not """DNS over HTTPS (DOH) was requested but the requests module is not
available.""" available."""
# for backwards compatibility
TransferError = dns.xfr.TransferError
def _compute_times(timeout): def _compute_times(timeout):
now = time.time() now = time.time()
if timeout is None: if timeout is None:
@ -94,91 +107,49 @@ def _compute_times(timeout):
else: else:
return (now, now + timeout) return (now, now + timeout)
# This module can use either poll() or select() as the "polling backend".
#
# A backend function takes an fd, bools for readability, writablity, and
# error detection, and a timeout.
def _poll_for(fd, readable, writable, error, timeout): def _wait_for(fd, readable, writable, _, expiration):
"""Poll polling backend.""" # Use the selected selector class to wait for any of the specified
event_mask = 0
if readable:
event_mask |= select.POLLIN
if writable:
event_mask |= select.POLLOUT
if error:
event_mask |= select.POLLERR
pollable = select.poll()
pollable.register(fd, event_mask)
if timeout:
event_list = pollable.poll(timeout * 1000)
else:
event_list = pollable.poll()
return bool(event_list)
def _select_for(fd, readable, writable, error, timeout):
"""Select polling backend."""
rset, wset, xset = [], [], []
if readable:
rset = [fd]
if writable:
wset = [fd]
if error:
xset = [fd]
if timeout is None:
(rcount, wcount, xcount) = select.select(rset, wset, xset)
else:
(rcount, wcount, xcount) = select.select(rset, wset, xset, timeout)
return bool((rcount or wcount or xcount))
def _wait_for(fd, readable, writable, error, expiration):
# Use the selected polling backend to wait for any of the specified
# events. An "expiration" absolute time is converted into a relative # events. An "expiration" absolute time is converted into a relative
# timeout. # timeout.
#
# The unused parameter is 'error', which is always set when
# selecting for read or write, and we have no error-only selects.
done = False if readable and isinstance(fd, ssl.SSLSocket) and fd.pending() > 0:
while not done: return True
if expiration is None: sel = _selector_class()
timeout = None events = 0
else: if readable:
timeout = expiration - time.time() events |= selectors.EVENT_READ
if timeout <= 0.0: if writable:
raise dns.exception.Timeout events |= selectors.EVENT_WRITE
try: if events:
if isinstance(fd, ssl.SSLSocket) and readable and fd.pending() > 0: sel.register(fd, events)
return True if expiration is None:
if not _polling_backend(fd, readable, writable, error, timeout): timeout = None
raise dns.exception.Timeout else:
except OSError as e: # pragma: no cover timeout = expiration - time.time()
if e.args[0] != errno.EINTR: if timeout <= 0.0:
raise e raise dns.exception.Timeout
done = True if not sel.select(timeout):
raise dns.exception.Timeout
def _set_polling_backend(fn): def _set_selector_class(selector_class):
# Internal API. Do not use. # Internal API. Do not use.
global _polling_backend global _selector_class
_polling_backend = fn _selector_class = selector_class
if hasattr(select, 'poll'): if hasattr(selectors, 'PollSelector'):
# Prefer poll() on platforms that support it because it has no # Prefer poll() on platforms that support it because it has no
# limits on the maximum value of a file descriptor (plus it will # limits on the maximum value of a file descriptor (plus it will
# be more efficient for high values). # be more efficient for high values).
_polling_backend = _poll_for _selector_class = selectors.PollSelector
else: else:
_polling_backend = _select_for # pragma: no cover _selector_class = selectors.SelectSelector # pragma: no cover
def _wait_for_readable(s, expiration): def _wait_for_readable(s, expiration):
@ -303,8 +274,8 @@ def https(q, where, timeout=None, port=443, source=None, source_port=0,
*ignore_trailing*, a ``bool``. If ``True``, ignore trailing *ignore_trailing*, a ``bool``. If ``True``, ignore trailing
junk at end of the received message. junk at end of the received message.
*session*, a ``requests.session.Session``. If provided, the session to use *session*, an ``httpx.Client`` or ``requests.session.Session``. If
to send the queries. provided, the client/session to use to send the queries.
*path*, a ``str``. If *where* is an IP address, then *path* will be used to *path*, a ``str``. If *where* is an IP address, then *path* will be used to
construct the URL to send the DNS query to. construct the URL to send the DNS query to.
@ -320,37 +291,66 @@ def https(q, where, timeout=None, port=443, source=None, source_port=0,
""" """
if not have_doh: if not have_doh:
raise NoDOH # pragma: no cover raise NoDOH('Neither httpx nor requests is available.') # pragma: no cover
_httpx_ok = _have_httpx
wire = q.to_wire() wire = q.to_wire()
(af, destination, source) = _destination_and_source(where, port, (af, _, source) = _destination_and_source(where, port, source, source_port,
source, source_port, False)
False)
transport_adapter = None transport_adapter = None
transport = None
headers = { headers = {
"accept": "application/dns-message" "accept": "application/dns-message"
} }
try: if af is not None:
where_af = dns.inet.af_for_address(where) if af == socket.AF_INET:
if where_af == socket.AF_INET:
url = 'https://{}:{}{}'.format(where, port, path) url = 'https://{}:{}{}'.format(where, port, path)
elif where_af == socket.AF_INET6: elif af == socket.AF_INET6:
url = 'https://[{}]:{}{}'.format(where, port, path) url = 'https://[{}]:{}{}'.format(where, port, path)
except ValueError: elif bootstrap_address is not None:
if bootstrap_address is not None: _httpx_ok = False
split_url = urllib.parse.urlsplit(where) split_url = urllib.parse.urlsplit(where)
headers['Host'] = split_url.hostname headers['Host'] = split_url.hostname
url = where.replace(split_url.hostname, bootstrap_address) url = where.replace(split_url.hostname, bootstrap_address)
if _have_requests:
transport_adapter = HostHeaderSSLAdapter() transport_adapter = HostHeaderSSLAdapter()
else: else:
url = where url = where
if source is not None: if source is not None:
# set source port and source address # set source port and source address
transport_adapter = SourceAddressAdapter(source) if _have_httpx:
if source_port == 0:
transport = httpx.HTTPTransport(local_address=source[0])
else:
_httpx_ok = False
if _have_requests:
transport_adapter = SourceAddressAdapter(source)
if session:
if _have_httpx:
_is_httpx = isinstance(session, httpx.Client)
else:
_is_httpx = False
if _is_httpx and not _httpx_ok:
raise NoDOH('Session is httpx, but httpx cannot be used for '
'the requested operation.')
else:
_is_httpx = _httpx_ok
if not _httpx_ok and not _have_requests:
raise NoDOH('Cannot use httpx for this operation, and '
'requests is not available.')
with contextlib.ExitStack() as stack: with contextlib.ExitStack() as stack:
if not session: if not session:
session = stack.enter_context(requests.sessions.Session()) if _is_httpx:
session = stack.enter_context(httpx.Client(http1=True,
http2=_have_http2,
verify=verify,
transport=transport))
else:
session = stack.enter_context(requests.sessions.Session())
if transport_adapter: if transport_adapter:
session.mount(url, transport_adapter) session.mount(url, transport_adapter)
@ -362,13 +362,23 @@ def https(q, where, timeout=None, port=443, source=None, source_port=0,
"content-type": "application/dns-message", "content-type": "application/dns-message",
"content-length": str(len(wire)) "content-length": str(len(wire))
}) })
response = session.post(url, headers=headers, data=wire, if _is_httpx:
timeout=timeout, verify=verify) response = session.post(url, headers=headers, content=wire,
timeout=timeout)
else:
response = session.post(url, headers=headers, data=wire,
timeout=timeout, verify=verify)
else: else:
wire = base64.urlsafe_b64encode(wire).rstrip(b"=") wire = base64.urlsafe_b64encode(wire).rstrip(b"=")
response = session.get(url, headers=headers, if _is_httpx:
timeout=timeout, verify=verify, wire = wire.decode() # httpx does a repr() if we give it bytes
params={"dns": wire}) response = session.get(url, headers=headers,
timeout=timeout,
params={"dns": wire})
else:
response = session.get(url, headers=headers,
timeout=timeout, verify=verify,
params={"dns": wire})
# see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH # see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH
# status codes # status codes
@ -387,6 +397,33 @@ def https(q, where, timeout=None, port=443, source=None, source_port=0,
raise BadResponse raise BadResponse
return r return r
def _udp_recv(sock, max_size, expiration):
"""Reads a datagram from the socket.
A Timeout exception will be raised if the operation is not completed
by the expiration time.
"""
while True:
try:
return sock.recvfrom(max_size)
except BlockingIOError:
_wait_for_readable(sock, expiration)
def _udp_send(sock, data, destination, expiration):
"""Sends the specified datagram to destination over the socket.
A Timeout exception will be raised if the operation is not completed
by the expiration time.
"""
while True:
try:
if destination:
return sock.sendto(data, destination)
else:
return sock.send(data)
except BlockingIOError: # pragma: no cover
_wait_for_writable(sock, expiration)
def send_udp(sock, what, destination, expiration=None): def send_udp(sock, what, destination, expiration=None):
"""Send a DNS message to the specified UDP socket. """Send a DNS message to the specified UDP socket.
@ -406,9 +443,8 @@ def send_udp(sock, what, destination, expiration=None):
if isinstance(what, dns.message.Message): if isinstance(what, dns.message.Message):
what = what.to_wire() what = what.to_wire()
_wait_for_writable(sock, expiration)
sent_time = time.time() sent_time = time.time()
n = sock.sendto(what, destination) n = _udp_send(sock, what, destination, expiration)
return (n, sent_time) return (n, sent_time)
@ -458,9 +494,8 @@ def receive_udp(sock, destination=None, expiration=None,
""" """
wire = b'' wire = b''
while 1: while True:
_wait_for_readable(sock, expiration) (wire, from_address) = _udp_recv(sock, 65535, expiration)
(wire, from_address) = sock.recvfrom(65535)
if _matches_destination(sock.family, from_address, destination, if _matches_destination(sock.family, from_address, destination,
ignore_unexpected): ignore_unexpected):
break break
@ -571,7 +606,7 @@ def udp_with_fallback(q, where, timeout=None, port=53, source=None,
if a socket is provided, it must be a nonblocking datagram socket, if a socket is provided, it must be a nonblocking datagram socket,
and the *source* and *source_port* are ignored for the UDP query. and the *source* and *source_port* are ignored for the UDP query.
*tcp_sock*, a ``socket.socket``, or ``None``, the socket to use for the *tcp_sock*, a ``socket.socket``, or ``None``, the connected socket to use for the
TCP query. If ``None``, the default, a socket is created. Note that TCP query. If ``None``, the default, a socket is created. Note that
if a socket is provided, it must be a nonblocking connected stream if a socket is provided, it must be a nonblocking connected stream
socket, and *where*, *source* and *source_port* are ignored for the TCP socket, and *where*, *source* and *source_port* are ignored for the TCP
@ -598,18 +633,16 @@ def _net_read(sock, count, expiration):
""" """
s = b'' s = b''
while count > 0: while count > 0:
_wait_for_readable(sock, expiration)
try: try:
n = sock.recv(count) n = sock.recv(count)
except ssl.SSLWantReadError: # pragma: no cover if n == b'':
continue raise EOFError
count -= len(n)
s += n
except (BlockingIOError, ssl.SSLWantReadError):
_wait_for_readable(sock, expiration)
except ssl.SSLWantWriteError: # pragma: no cover except ssl.SSLWantWriteError: # pragma: no cover
_wait_for_writable(sock, expiration) _wait_for_writable(sock, expiration)
continue
if n == b'':
raise EOFError
count = count - len(n)
s = s + n
return s return s
@ -621,14 +654,12 @@ def _net_write(sock, data, expiration):
current = 0 current = 0
l = len(data) l = len(data)
while current < l: while current < l:
_wait_for_writable(sock, expiration)
try: try:
current += sock.send(data[current:]) current += sock.send(data[current:])
except (BlockingIOError, ssl.SSLWantWriteError):
_wait_for_writable(sock, expiration)
except ssl.SSLWantReadError: # pragma: no cover except ssl.SSLWantReadError: # pragma: no cover
_wait_for_readable(sock, expiration) _wait_for_readable(sock, expiration)
continue
except ssl.SSLWantWriteError: # pragma: no cover
continue
def send_tcp(sock, what, expiration=None): def send_tcp(sock, what, expiration=None):
@ -652,7 +683,6 @@ def send_tcp(sock, what, expiration=None):
# avoid writev() or doing a short write that would get pushed # avoid writev() or doing a short write that would get pushed
# onto the net # onto the net
tcpmsg = struct.pack("!H", l) + what tcpmsg = struct.pack("!H", l) + what
_wait_for_writable(sock, expiration)
sent_time = time.time() sent_time = time.time()
_net_write(sock, tcpmsg, expiration) _net_write(sock, tcpmsg, expiration)
return (len(tcpmsg), sent_time) return (len(tcpmsg), sent_time)
@ -730,7 +760,7 @@ def tcp(q, where, timeout=None, port=53, source=None, source_port=0,
*ignore_trailing*, a ``bool``. If ``True``, ignore trailing *ignore_trailing*, a ``bool``. If ``True``, ignore trailing
junk at end of the received message. junk at end of the received message.
*sock*, a ``socket.socket``, or ``None``, the socket to use for the *sock*, a ``socket.socket``, or ``None``, the connected socket to use for the
query. If ``None``, the default, a socket is created. Note that query. If ``None``, the default, a socket is created. Note that
if a socket is provided, it must be a nonblocking connected stream if a socket is provided, it must be a nonblocking connected stream
socket, and *where*, *port*, *source* and *source_port* are ignored. socket, and *where*, *port*, *source* and *source_port* are ignored.
@ -742,11 +772,6 @@ def tcp(q, where, timeout=None, port=53, source=None, source_port=0,
(begin_time, expiration) = _compute_times(timeout) (begin_time, expiration) = _compute_times(timeout)
with contextlib.ExitStack() as stack: with contextlib.ExitStack() as stack:
if sock: if sock:
#
# Verify that the socket is connected, as if it's not connected,
# it's not writable, and the polling in send_tcp() will time out or
# hang forever.
sock.getpeername()
s = sock s = sock
else: else:
(af, destination, source) = _destination_and_source(where, port, (af, destination, source) = _destination_and_source(where, port,
@ -926,8 +951,7 @@ def xfr(where, zone, rdtype=dns.rdatatype.AXFR, rdclass=dns.rdataclass.IN,
_connect(s, destination, expiration) _connect(s, destination, expiration)
l = len(wire) l = len(wire)
if use_udp: if use_udp:
_wait_for_writable(s, expiration) _udp_send(s, wire, None, expiration)
s.send(wire)
else: else:
tcpmsg = struct.pack("!H", l) + wire tcpmsg = struct.pack("!H", l) + wire
_net_write(s, tcpmsg, expiration) _net_write(s, tcpmsg, expiration)
@ -948,8 +972,7 @@ def xfr(where, zone, rdtype=dns.rdatatype.AXFR, rdclass=dns.rdataclass.IN,
(expiration is not None and mexpiration > expiration): (expiration is not None and mexpiration > expiration):
mexpiration = expiration mexpiration = expiration
if use_udp: if use_udp:
_wait_for_readable(s, expiration) (wire, _) = _udp_recv(s, 65535, mexpiration)
(wire, from_address) = s.recvfrom(65535)
else: else:
ldata = _net_read(s, 2, mexpiration) ldata = _net_read(s, 2, mexpiration)
(l,) = struct.unpack("!H", ldata) (l,) = struct.unpack("!H", ldata)
@ -1016,3 +1039,116 @@ def xfr(where, zone, rdtype=dns.rdatatype.AXFR, rdclass=dns.rdataclass.IN,
if done and q.keyring and not r.had_tsig: if done and q.keyring and not r.had_tsig:
raise dns.exception.FormError("missing TSIG") raise dns.exception.FormError("missing TSIG")
yield r yield r
class UDPMode(enum.IntEnum):
"""How should UDP be used in an IXFR from :py:func:`inbound_xfr()`?
NEVER means "never use UDP; always use TCP"
TRY_FIRST means "try to use UDP but fall back to TCP if needed"
ONLY means "raise ``dns.xfr.UseTCP`` if trying UDP does not succeed"
"""
NEVER = 0
TRY_FIRST = 1
ONLY = 2
def inbound_xfr(where, txn_manager, query=None,
port=53, timeout=None, lifetime=None, source=None,
source_port=0, udp_mode=UDPMode.NEVER):
"""Conduct an inbound transfer and apply it via a transaction from the
txn_manager.
*where*, a ``str`` containing an IPv4 or IPv6 address, where
to send the message.
*txn_manager*, a ``dns.transaction.TransactionManager``, the txn_manager
for this transfer (typically a ``dns.zone.Zone``).
*query*, the query to send. If not supplied, a default query is
constructed using information from the *txn_manager*.
*port*, an ``int``, the port send the message to. The default is 53.
*timeout*, a ``float``, the number of seconds to wait for each
response message. If None, the default, wait forever.
*lifetime*, a ``float``, the total number of seconds to spend
doing the transfer. If ``None``, the default, then there is no
limit on the time the transfer may take.
*source*, a ``str`` containing an IPv4 or IPv6 address, specifying
the source address. The default is the wildcard address.
*source_port*, an ``int``, the port from which to send the message.
The default is 0.
*udp_mode*, a ``dns.query.UDPMode``, determines how UDP is used
for IXFRs. The default is ``dns.UDPMode.NEVER``, i.e. only use
TCP. Other possibilites are ``dns.UDPMode.TRY_FIRST``, which
means "try UDP but fallback to TCP if needed", and
``dns.UDPMode.ONLY``, which means "try UDP and raise
``dns.xfr.UseTCP`` if it does not succeeed.
Raises on errors.
"""
if query is None:
(query, serial) = dns.xfr.make_query(txn_manager)
else:
serial = dns.xfr.extract_serial_from_query(query)
rdtype = query.question[0].rdtype
is_ixfr = rdtype == dns.rdatatype.IXFR
origin = txn_manager.from_wire_origin()
wire = query.to_wire()
(af, destination, source) = _destination_and_source(where, port,
source, source_port)
(_, expiration) = _compute_times(lifetime)
retry = True
while retry:
retry = False
if is_ixfr and udp_mode != UDPMode.NEVER:
sock_type = socket.SOCK_DGRAM
is_udp = True
else:
sock_type = socket.SOCK_STREAM
is_udp = False
with _make_socket(af, sock_type, source) as s:
_connect(s, destination, expiration)
if is_udp:
_udp_send(s, wire, None, expiration)
else:
tcpmsg = struct.pack("!H", len(wire)) + wire
_net_write(s, tcpmsg, expiration)
with dns.xfr.Inbound(txn_manager, rdtype, serial,
is_udp) as inbound:
done = False
tsig_ctx = None
while not done:
(_, mexpiration) = _compute_times(timeout)
if mexpiration is None or \
(expiration is not None and mexpiration > expiration):
mexpiration = expiration
if is_udp:
(rwire, _) = _udp_recv(s, 65535, mexpiration)
else:
ldata = _net_read(s, 2, mexpiration)
(l,) = struct.unpack("!H", ldata)
rwire = _net_read(s, l, mexpiration)
r = dns.message.from_wire(rwire, keyring=query.keyring,
request_mac=query.mac, xfr=True,
origin=origin, tsig_ctx=tsig_ctx,
multi=(not is_udp),
one_rr_per_rrset=is_ixfr)
try:
done = inbound.process_message(r)
except dns.xfr.UseTCP:
assert is_udp # should not happen if we used TCP!
if udp_mode == UDPMode.ONLY:
raise
done = True
retry = True
udp_mode = UDPMode.NEVER
continue
tsig_ctx = r.tsig_ctx
if not retry and query.keyring and not r.had_tsig:
raise dns.exception.FormError("missing TSIG")

64
lib/dns/query.pyi Normal file
View file

@ -0,0 +1,64 @@
from typing import Optional, Union, Dict, Generator, Any
from . import tsig, rdatatype, rdataclass, name, message
from requests.sessions import Session
import socket
# If the ssl import works, then
#
# error: Name 'ssl' already defined (by an import)
#
# is expected and can be ignored.
try:
import ssl
except ImportError:
class ssl: # type: ignore
SSLContext : Dict = {}
have_doh: bool
def https(q : message.Message, where: str, timeout : Optional[float] = None,
port : Optional[int] = 443, source : Optional[str] = None,
source_port : Optional[int] = 0,
session: Optional[Session] = None,
path : Optional[str] = '/dns-query', post : Optional[bool] = True,
bootstrap_address : Optional[str] = None,
verify : Optional[bool] = True) -> message.Message:
pass
def tcp(q : message.Message, where : str, timeout : float = None, port=53,
af : Optional[int] = None, source : Optional[str] = None,
source_port : Optional[int] = 0,
one_rr_per_rrset : Optional[bool] = False,
ignore_trailing : Optional[bool] = False,
sock : Optional[socket.socket] = None) -> message.Message:
pass
def xfr(where : None, zone : Union[name.Name,str], rdtype=rdatatype.AXFR,
rdclass=rdataclass.IN,
timeout : Optional[float] = None, port=53,
keyring : Optional[Dict[name.Name, bytes]] = None,
keyname : Union[str,name.Name]= None, relativize=True,
lifetime : Optional[float] = None,
source : Optional[str] = None, source_port=0, serial=0,
use_udp : Optional[bool] = False,
keyalgorithm=tsig.default_algorithm) \
-> Generator[Any,Any,message.Message]:
pass
def udp(q : message.Message, where : str, timeout : Optional[float] = None,
port=53, source : Optional[str] = None, source_port : Optional[int] = 0,
ignore_unexpected : Optional[bool] = False,
one_rr_per_rrset : Optional[bool] = False,
ignore_trailing : Optional[bool] = False,
sock : Optional[socket.socket] = None) -> message.Message:
pass
def tls(q : message.Message, where : str, timeout : Optional[float] = None,
port=53, source : Optional[str] = None, source_port : Optional[int] = 0,
one_rr_per_rrset : Optional[bool] = False,
ignore_trailing : Optional[bool] = False,
sock : Optional[socket.socket] = None,
ssl_context: Optional[ssl.SSLContext] = None,
server_hostname: Optional[str] = None) -> message.Message:
pass

View file

@ -72,7 +72,6 @@ class Rcode(dns.enum.IntEnum):
def _unknown_exception_class(cls): def _unknown_exception_class(cls):
return UnknownRcode return UnknownRcode
globals().update(Rcode.__members__)
class UnknownRcode(dns.exception.DNSException): class UnknownRcode(dns.exception.DNSException):
"""A DNS rcode is unknown.""" """A DNS rcode is unknown."""
@ -104,8 +103,6 @@ def from_flags(flags, ednsflags):
""" """
value = (flags & 0x000f) | ((ednsflags >> 20) & 0xff0) value = (flags & 0x000f) | ((ednsflags >> 20) & 0xff0)
if value < 0 or value > 4095:
raise ValueError('rcode must be >= 0 and <= 4095')
return value return value
@ -139,3 +136,29 @@ def to_text(value, tsig=False):
if tsig and value == Rcode.BADVERS: if tsig and value == Rcode.BADVERS:
return 'BADSIG' return 'BADSIG'
return Rcode.to_text(value) return Rcode.to_text(value)
### BEGIN generated Rcode constants
NOERROR = Rcode.NOERROR
FORMERR = Rcode.FORMERR
SERVFAIL = Rcode.SERVFAIL
NXDOMAIN = Rcode.NXDOMAIN
NOTIMP = Rcode.NOTIMP
REFUSED = Rcode.REFUSED
YXDOMAIN = Rcode.YXDOMAIN
YXRRSET = Rcode.YXRRSET
NXRRSET = Rcode.NXRRSET
NOTAUTH = Rcode.NOTAUTH
NOTZONE = Rcode.NOTZONE
DSOTYPENI = Rcode.DSOTYPENI
BADVERS = Rcode.BADVERS
BADSIG = Rcode.BADSIG
BADKEY = Rcode.BADKEY
BADTIME = Rcode.BADTIME
BADMODE = Rcode.BADMODE
BADNAME = Rcode.BADNAME
BADALG = Rcode.BADALG
BADTRUNC = Rcode.BADTRUNC
BADCOOKIE = Rcode.BADCOOKIE
### END generated Rcode constants

View file

@ -23,43 +23,68 @@ import binascii
import io import io
import inspect import inspect
import itertools import itertools
import random
import dns.wire import dns.wire
import dns.exception import dns.exception
import dns.immutable
import dns.ipv4
import dns.ipv6
import dns.name import dns.name
import dns.rdataclass import dns.rdataclass
import dns.rdatatype import dns.rdatatype
import dns.tokenizer import dns.tokenizer
import dns.ttl
_chunksize = 32 _chunksize = 32
# We currently allow comparisons for rdata with relative names for backwards
# compatibility, but in the future we will not, as these kinds of comparisons
# can lead to subtle bugs if code is not carefully written.
#
# This switch allows the future behavior to be turned on so code can be
# tested with it.
_allow_relative_comparisons = True
def _wordbreak(data, chunksize=_chunksize):
class NoRelativeRdataOrdering(dns.exception.DNSException):
"""An attempt was made to do an ordered comparison of one or more
rdata with relative names. The only reliable way of sorting rdata
is to use non-relativized rdata.
"""
def _wordbreak(data, chunksize=_chunksize, separator=b' '):
"""Break a binary string into chunks of chunksize characters separated by """Break a binary string into chunks of chunksize characters separated by
a space. a space.
""" """
if not chunksize: if not chunksize:
return data.decode() return data.decode()
return b' '.join([data[i:i + chunksize] return separator.join([data[i:i + chunksize]
for i for i
in range(0, len(data), chunksize)]).decode() in range(0, len(data), chunksize)]).decode()
def _hexify(data, chunksize=_chunksize): # pylint: disable=unused-argument
def _hexify(data, chunksize=_chunksize, separator=b' ', **kw):
"""Convert a binary string into its hex encoding, broken up into chunks """Convert a binary string into its hex encoding, broken up into chunks
of chunksize characters separated by a space. of chunksize characters separated by a separator.
""" """
return _wordbreak(binascii.hexlify(data), chunksize) return _wordbreak(binascii.hexlify(data), chunksize, separator)
def _base64ify(data, chunksize=_chunksize): def _base64ify(data, chunksize=_chunksize, separator=b' ', **kw):
"""Convert a binary string into its base64 encoding, broken up into chunks """Convert a binary string into its base64 encoding, broken up into chunks
of chunksize characters separated by a space. of chunksize characters separated by a separator.
""" """
return _wordbreak(base64.b64encode(data), chunksize) return _wordbreak(base64.b64encode(data), chunksize, separator)
# pylint: enable=unused-argument
__escaped = b'"\\' __escaped = b'"\\'
@ -92,26 +117,15 @@ def _truncate_bitmap(what):
return what[0: i + 1] return what[0: i + 1]
return what[0:1] return what[0:1]
def _constify(o): # So we don't have to edit all the rdata classes...
""" _constify = dns.immutable.constify
Convert mutable types to immutable types.
"""
if isinstance(o, bytearray):
return bytes(o)
if isinstance(o, tuple):
try:
hash(o)
return o
except Exception:
return tuple(_constify(elt) for elt in o)
if isinstance(o, list):
return tuple(_constify(elt) for elt in o)
return o
@dns.immutable.immutable
class Rdata: class Rdata:
"""Base class for all DNS rdata types.""" """Base class for all DNS rdata types."""
__slots__ = ['rdclass', 'rdtype'] __slots__ = ['rdclass', 'rdtype', 'rdcomment']
def __init__(self, rdclass, rdtype): def __init__(self, rdclass, rdtype):
"""Initialize an rdata. """Initialize an rdata.
@ -121,16 +135,9 @@ class Rdata:
*rdtype*, an ``int`` is the rdatatype of the Rdata. *rdtype*, an ``int`` is the rdatatype of the Rdata.
""" """
object.__setattr__(self, 'rdclass', rdclass) self.rdclass = self._as_rdataclass(rdclass)
object.__setattr__(self, 'rdtype', rdtype) self.rdtype = self._as_rdatatype(rdtype)
self.rdcomment = None
def __setattr__(self, name, value):
# Rdatas are immutable
raise TypeError("object doesn't support attribute assignment")
def __delattr__(self, name):
# Rdatas are immutable
raise TypeError("object doesn't support attribute deletion")
def _get_all_slots(self): def _get_all_slots(self):
return itertools.chain.from_iterable(getattr(cls, '__slots__', []) return itertools.chain.from_iterable(getattr(cls, '__slots__', [])
@ -153,6 +160,10 @@ class Rdata:
def __setstate__(self, state): def __setstate__(self, state):
for slot, val in state.items(): for slot, val in state.items():
object.__setattr__(self, slot, val) object.__setattr__(self, slot, val)
if not hasattr(self, 'rdcomment'):
# Pickled rdata from 2.0.x might not have a rdcomment, so add
# it if needed.
object.__setattr__(self, 'rdcomment', None)
def covers(self): def covers(self):
"""Return the type a Rdata covers. """Return the type a Rdata covers.
@ -184,10 +195,10 @@ class Rdata:
Returns a ``str``. Returns a ``str``.
""" """
raise NotImplementedError raise NotImplementedError # pragma: no cover
def _to_wire(self, file, compress=None, origin=None, canonicalize=False): def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
raise NotImplementedError raise NotImplementedError # pragma: no cover
def to_wire(self, file=None, compress=None, origin=None, def to_wire(self, file=None, compress=None, origin=None,
canonicalize=False): canonicalize=False):
@ -237,12 +248,42 @@ class Rdata:
"""Compare an rdata with another rdata of the same rdtype and """Compare an rdata with another rdata of the same rdtype and
rdclass. rdclass.
Return < 0 if self < other in the DNSSEC ordering, 0 if self For rdata with only absolute names:
== other, and > 0 if self > other. Return < 0 if self < other in the DNSSEC ordering, 0 if self
== other, and > 0 if self > other.
For rdata with at least one relative names:
The rdata sorts before any rdata with only absolute names.
When compared with another relative rdata, all names are
made absolute as if they were relative to the root, as the
proper origin is not available. While this creates a stable
ordering, it is NOT guaranteed to be the DNSSEC ordering.
In the future, all ordering comparisons for rdata with
relative names will be disallowed.
""" """
our = self.to_digestable(dns.name.root) try:
their = other.to_digestable(dns.name.root) our = self.to_digestable()
our_relative = False
except dns.name.NeedAbsoluteNameOrOrigin:
if _allow_relative_comparisons:
our = self.to_digestable(dns.name.root)
our_relative = True
try:
their = other.to_digestable()
their_relative = False
except dns.name.NeedAbsoluteNameOrOrigin:
if _allow_relative_comparisons:
their = other.to_digestable(dns.name.root)
their_relative = True
if _allow_relative_comparisons:
if our_relative != their_relative:
# For the purpose of comparison, all rdata with at least one
# relative name is less than an rdata with only absolute names.
if our_relative:
return -1
else:
return 1
elif our_relative or their_relative:
raise NoRelativeRdataOrdering
if our == their: if our == their:
return 0 return 0
elif our > their: elif our > their:
@ -255,14 +296,28 @@ class Rdata:
return False return False
if self.rdclass != other.rdclass or self.rdtype != other.rdtype: if self.rdclass != other.rdclass or self.rdtype != other.rdtype:
return False return False
return self._cmp(other) == 0 our_relative = False
their_relative = False
try:
our = self.to_digestable()
except dns.name.NeedAbsoluteNameOrOrigin:
our = self.to_digestable(dns.name.root)
our_relative = True
try:
their = other.to_digestable()
except dns.name.NeedAbsoluteNameOrOrigin:
their = other.to_digestable(dns.name.root)
their_relative = True
if our_relative != their_relative:
return False
return our == their
def __ne__(self, other): def __ne__(self, other):
if not isinstance(other, Rdata): if not isinstance(other, Rdata):
return True return True
if self.rdclass != other.rdclass or self.rdtype != other.rdtype: if self.rdclass != other.rdclass or self.rdtype != other.rdtype:
return True return True
return self._cmp(other) != 0 return not self.__eq__(other)
def __lt__(self, other): def __lt__(self, other):
if not isinstance(other, Rdata) or \ if not isinstance(other, Rdata) or \
@ -295,11 +350,11 @@ class Rdata:
@classmethod @classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
relativize_to=None): relativize_to=None):
raise NotImplementedError raise NotImplementedError # pragma: no cover
@classmethod @classmethod
def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None): def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
raise NotImplementedError raise NotImplementedError # pragma: no cover
def replace(self, **kwargs): def replace(self, **kwargs):
""" """
@ -319,6 +374,8 @@ class Rdata:
# Ensure that all of the arguments correspond to valid fields. # Ensure that all of the arguments correspond to valid fields.
# Don't allow rdclass or rdtype to be changed, though. # Don't allow rdclass or rdtype to be changed, though.
for key in kwargs: for key in kwargs:
if key == 'rdcomment':
continue
if key not in parameters: if key not in parameters:
raise AttributeError("'{}' object has no attribute '{}'" raise AttributeError("'{}' object has no attribute '{}'"
.format(self.__class__.__name__, key)) .format(self.__class__.__name__, key))
@ -331,13 +388,149 @@ class Rdata:
args = (kwargs.get(key, getattr(self, key)) for key in parameters) args = (kwargs.get(key, getattr(self, key)) for key in parameters)
# Create, validate, and return the new object. # Create, validate, and return the new object.
#
# Note that if we make constructors do validation in the future,
# this validation can go away.
rd = self.__class__(*args) rd = self.__class__(*args)
dns.rdata.from_text(rd.rdclass, rd.rdtype, rd.to_text()) # The comment is not set in the constructor, so give it special
# handling.
rdcomment = kwargs.get('rdcomment', self.rdcomment)
if rdcomment is not None:
object.__setattr__(rd, 'rdcomment', rdcomment)
return rd return rd
# Type checking and conversion helpers. These are class methods as
# they don't touch object state and may be useful to others.
@classmethod
def _as_rdataclass(cls, value):
return dns.rdataclass.RdataClass.make(value)
@classmethod
def _as_rdatatype(cls, value):
return dns.rdatatype.RdataType.make(value)
@classmethod
def _as_bytes(cls, value, encode=False, max_length=None, empty_ok=True):
if encode and isinstance(value, str):
value = value.encode()
elif isinstance(value, bytearray):
value = bytes(value)
elif not isinstance(value, bytes):
raise ValueError('not bytes')
if max_length is not None and len(value) > max_length:
raise ValueError('too long')
if not empty_ok and len(value) == 0:
raise ValueError('empty bytes not allowed')
return value
@classmethod
def _as_name(cls, value):
# Note that proper name conversion (e.g. with origin and IDNA
# awareness) is expected to be done via from_text. This is just
# a simple thing for people invoking the constructor directly.
if isinstance(value, str):
return dns.name.from_text(value)
elif not isinstance(value, dns.name.Name):
raise ValueError('not a name')
return value
@classmethod
def _as_uint8(cls, value):
if not isinstance(value, int):
raise ValueError('not an integer')
if value < 0 or value > 255:
raise ValueError('not a uint8')
return value
@classmethod
def _as_uint16(cls, value):
if not isinstance(value, int):
raise ValueError('not an integer')
if value < 0 or value > 65535:
raise ValueError('not a uint16')
return value
@classmethod
def _as_uint32(cls, value):
if not isinstance(value, int):
raise ValueError('not an integer')
if value < 0 or value > 4294967295:
raise ValueError('not a uint32')
return value
@classmethod
def _as_uint48(cls, value):
if not isinstance(value, int):
raise ValueError('not an integer')
if value < 0 or value > 281474976710655:
raise ValueError('not a uint48')
return value
@classmethod
def _as_int(cls, value, low=None, high=None):
if not isinstance(value, int):
raise ValueError('not an integer')
if low is not None and value < low:
raise ValueError('value too small')
if high is not None and value > high:
raise ValueError('value too large')
return value
@classmethod
def _as_ipv4_address(cls, value):
if isinstance(value, str):
# call to check validity
dns.ipv4.inet_aton(value)
return value
elif isinstance(value, bytes):
return dns.ipv4.inet_ntoa(value)
else:
raise ValueError('not an IPv4 address')
@classmethod
def _as_ipv6_address(cls, value):
if isinstance(value, str):
# call to check validity
dns.ipv6.inet_aton(value)
return value
elif isinstance(value, bytes):
return dns.ipv6.inet_ntoa(value)
else:
raise ValueError('not an IPv6 address')
@classmethod
def _as_bool(cls, value):
if isinstance(value, bool):
return value
else:
raise ValueError('not a boolean')
@classmethod
def _as_ttl(cls, value):
if isinstance(value, int):
return cls._as_int(value, 0, dns.ttl.MAX_TTL)
elif isinstance(value, str):
return dns.ttl.from_text(value)
else:
raise ValueError('not a TTL')
@classmethod
def _as_tuple(cls, value, as_value):
try:
# For user convenience, if value is a singleton of the list
# element type, wrap it in a tuple.
return (as_value(value),)
except Exception:
# Otherwise, check each element of the iterable *value*
# against *as_value*.
return tuple(as_value(v) for v in value)
# Processing order
@classmethod
def _processing_order(cls, iterable):
items = list(iterable)
random.shuffle(items)
return items
class GenericRdata(Rdata): class GenericRdata(Rdata):
@ -354,7 +547,7 @@ class GenericRdata(Rdata):
object.__setattr__(self, 'data', data) object.__setattr__(self, 'data', data)
def to_text(self, origin=None, relativize=True, **kw): def to_text(self, origin=None, relativize=True, **kw):
return r'\# %d ' % len(self.data) + _hexify(self.data) return r'\# %d ' % len(self.data) + _hexify(self.data, **kw)
@classmethod @classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
@ -364,13 +557,7 @@ class GenericRdata(Rdata):
raise dns.exception.SyntaxError( raise dns.exception.SyntaxError(
r'generic rdata does not start with \#') r'generic rdata does not start with \#')
length = tok.get_int() length = tok.get_int()
chunks = [] hex = tok.concatenate_remaining_identifiers().encode()
while 1:
token = tok.get()
if token.is_eol_or_eof():
break
chunks.append(token.value.encode())
hex = b''.join(chunks)
data = binascii.unhexlify(hex) data = binascii.unhexlify(hex)
if len(data) != length: if len(data) != length:
raise dns.exception.SyntaxError( raise dns.exception.SyntaxError(
@ -453,29 +640,45 @@ def from_text(rdclass, rdtype, tok, origin=None, relativize=True,
Returns an instance of the chosen Rdata subclass. Returns an instance of the chosen Rdata subclass.
""" """
if isinstance(tok, str): if isinstance(tok, str):
tok = dns.tokenizer.Tokenizer(tok, idna_codec=idna_codec) tok = dns.tokenizer.Tokenizer(tok, idna_codec=idna_codec)
rdclass = dns.rdataclass.RdataClass.make(rdclass) rdclass = dns.rdataclass.RdataClass.make(rdclass)
rdtype = dns.rdatatype.RdataType.make(rdtype) rdtype = dns.rdatatype.RdataType.make(rdtype)
cls = get_rdata_class(rdclass, rdtype) cls = get_rdata_class(rdclass, rdtype)
if cls != GenericRdata: with dns.exception.ExceptionWrapper(dns.exception.SyntaxError):
# peek at first token rdata = None
token = tok.get() if cls != GenericRdata:
tok.unget(token) # peek at first token
if token.is_identifier() and \ token = tok.get()
token.value == r'\#': tok.unget(token)
# if token.is_identifier() and \
# Known type using the generic syntax. Extract the token.value == r'\#':
# wire form from the generic syntax, and then run #
# from_wire on it. # Known type using the generic syntax. Extract the
# # wire form from the generic syntax, and then run
rdata = GenericRdata.from_text(rdclass, rdtype, tok, origin, # from_wire on it.
relativize, relativize_to) #
return from_wire(rdclass, rdtype, rdata.data, 0, len(rdata.data), grdata = GenericRdata.from_text(rdclass, rdtype, tok, origin,
origin) relativize, relativize_to)
return cls.from_text(rdclass, rdtype, tok, origin, relativize, rdata = from_wire(rdclass, rdtype, grdata.data, 0,
relativize_to) len(grdata.data), origin)
#
# If this comparison isn't equal, then there must have been
# compressed names in the wire format, which is an error,
# there being no reasonable context to decompress with.
#
rwire = rdata.to_wire()
if rwire != grdata.data:
raise dns.exception.SyntaxError('compressed data in '
'generic syntax form '
'of known rdatatype')
if rdata is None:
rdata = cls.from_text(rdclass, rdtype, tok, origin, relativize,
relativize_to)
token = tok.get_eol_as_token()
if token.comment is not None:
object.__setattr__(rdata, 'rdcomment', token.comment)
return rdata
def from_wire_parser(rdclass, rdtype, parser, origin=None): def from_wire_parser(rdclass, rdtype, parser, origin=None):
@ -505,7 +708,8 @@ def from_wire_parser(rdclass, rdtype, parser, origin=None):
rdclass = dns.rdataclass.RdataClass.make(rdclass) rdclass = dns.rdataclass.RdataClass.make(rdclass)
rdtype = dns.rdatatype.RdataType.make(rdtype) rdtype = dns.rdatatype.RdataType.make(rdtype)
cls = get_rdata_class(rdclass, rdtype) cls = get_rdata_class(rdclass, rdtype)
return cls.from_wire_parser(rdclass, rdtype, parser, origin) with dns.exception.ExceptionWrapper(dns.exception.FormError):
return cls.from_wire_parser(rdclass, rdtype, parser, origin)
def from_wire(rdclass, rdtype, wire, current, rdlen, origin=None): def from_wire(rdclass, rdtype, wire, current, rdlen, origin=None):
@ -543,7 +747,7 @@ def from_wire(rdclass, rdtype, wire, current, rdlen, origin=None):
class RdatatypeExists(dns.exception.DNSException): class RdatatypeExists(dns.exception.DNSException):
"""DNS rdatatype already exists.""" """DNS rdatatype already exists."""
supp_kwargs = {'rdclass', 'rdtype'} supp_kwargs = {'rdclass', 'rdtype'}
fmt = "The rdata type with class {rdclass} and rdtype {rdtype} " + \ fmt = "The rdata type with class {rdclass:d} and rdtype {rdtype:d} " + \
"already exists." "already exists."

19
lib/dns/rdata.pyi Normal file
View file

@ -0,0 +1,19 @@
from typing import Dict, Tuple, Any, Optional, BinaryIO
from .name import Name, IDNACodec
class Rdata:
def __init__(self):
self.address : str
def to_wire(self, file : Optional[BinaryIO], compress : Optional[Dict[Name,int]], origin : Optional[Name], canonicalize : Optional[bool]) -> Optional[bytes]:
...
@classmethod
def from_text(cls, rdclass : int, rdtype : int, tok, origin=None, relativize=True):
...
_rdata_modules : Dict[Tuple[Any,Rdata],Any]
def from_text(rdclass : int, rdtype : int, tok : Optional[str], origin : Optional[Name] = None,
relativize : bool = True, relativize_to : Optional[Name] = None,
idna_codec : Optional[IDNACodec] = None):
...
def from_wire(rdclass : int, rdtype : int, wire : bytes, current : int, rdlen : int, origin : Optional[Name] = None):
...

View file

@ -48,7 +48,6 @@ class RdataClass(dns.enum.IntEnum):
def _unknown_exception_class(cls): def _unknown_exception_class(cls):
return UnknownRdataclass return UnknownRdataclass
globals().update(RdataClass.__members__)
_metaclasses = {RdataClass.NONE, RdataClass.ANY} _metaclasses = {RdataClass.NONE, RdataClass.ANY}
@ -100,3 +99,17 @@ def is_metaclass(rdclass):
if rdclass in _metaclasses: if rdclass in _metaclasses:
return True return True
return False return False
### BEGIN generated RdataClass constants
RESERVED0 = RdataClass.RESERVED0
IN = RdataClass.IN
INTERNET = RdataClass.INTERNET
CH = RdataClass.CH
CHAOS = RdataClass.CHAOS
HS = RdataClass.HS
HESIOD = RdataClass.HESIOD
NONE = RdataClass.NONE
ANY = RdataClass.ANY
### END generated RdataClass constants

View file

@ -22,6 +22,7 @@ import random
import struct import struct
import dns.exception import dns.exception
import dns.immutable
import dns.rdatatype import dns.rdatatype
import dns.rdataclass import dns.rdataclass
import dns.rdata import dns.rdata
@ -79,15 +80,15 @@ class Rdataset(dns.set.Set):
TTL or the specified TTL. If the set contains no rdatas, set the TTL TTL or the specified TTL. If the set contains no rdatas, set the TTL
to the specified TTL. to the specified TTL.
*ttl*, an ``int``. *ttl*, an ``int`` or ``str``.
""" """
ttl = dns.ttl.make(ttl)
if len(self) == 0: if len(self) == 0:
self.ttl = ttl self.ttl = ttl
elif ttl < self.ttl: elif ttl < self.ttl:
self.ttl = ttl self.ttl = ttl
def add(self, rd, ttl=None): def add(self, rd, ttl=None): # pylint: disable=arguments-differ
"""Add the specified rdata to the rdataset. """Add the specified rdata to the rdataset.
If the optional *ttl* parameter is supplied, then If the optional *ttl* parameter is supplied, then
@ -176,8 +177,8 @@ class Rdataset(dns.set.Set):
return not self.__eq__(other) return not self.__eq__(other)
def to_text(self, name=None, origin=None, relativize=True, def to_text(self, name=None, origin=None, relativize=True,
override_rdclass=None, **kw): override_rdclass=None, want_comments=False, **kw):
"""Convert the rdataset into DNS master file format. """Convert the rdataset into DNS zone file format.
See ``dns.name.Name.choose_relativity`` for more information See ``dns.name.Name.choose_relativity`` for more information
on how *origin* and *relativize* determine the way names on how *origin* and *relativize* determine the way names
@ -194,6 +195,12 @@ class Rdataset(dns.set.Set):
*relativize*, a ``bool``. If ``True``, names will be relativized *relativize*, a ``bool``. If ``True``, names will be relativized
to *origin*. to *origin*.
*override_rdclass*, a ``dns.rdataclass.RdataClass`` or ``None``.
If not ``None``, use this class instead of the Rdataset's class.
*want_comments*, a ``bool``. If ``True``, emit comments for rdata
which have them. The default is ``False``.
""" """
if name is not None: if name is not None:
@ -219,11 +226,16 @@ class Rdataset(dns.set.Set):
dns.rdatatype.to_text(self.rdtype))) dns.rdatatype.to_text(self.rdtype)))
else: else:
for rd in self: for rd in self:
s.write('%s%s%d %s %s %s\n' % extra = ''
if want_comments:
if rd.rdcomment:
extra = f' ;{rd.rdcomment}'
s.write('%s%s%d %s %s %s%s\n' %
(ntext, pad, self.ttl, dns.rdataclass.to_text(rdclass), (ntext, pad, self.ttl, dns.rdataclass.to_text(rdclass),
dns.rdatatype.to_text(self.rdtype), dns.rdatatype.to_text(self.rdtype),
rd.to_text(origin=origin, relativize=relativize, rd.to_text(origin=origin, relativize=relativize,
**kw))) **kw),
extra))
# #
# We strip off the final \n for the caller's convenience in printing # We strip off the final \n for the caller's convenience in printing
# #
@ -260,7 +272,7 @@ class Rdataset(dns.set.Set):
want_shuffle = False want_shuffle = False
else: else:
rdclass = self.rdclass rdclass = self.rdclass
file.seek(0, 2) file.seek(0, io.SEEK_END)
if len(self) == 0: if len(self) == 0:
name.to_wire(file, compress, origin) name.to_wire(file, compress, origin)
stuff = struct.pack("!HHIH", self.rdtype, rdclass, 0, 0) stuff = struct.pack("!HHIH", self.rdtype, rdclass, 0, 0)
@ -284,7 +296,7 @@ class Rdataset(dns.set.Set):
file.seek(start - 2) file.seek(start - 2)
stuff = struct.pack("!H", end - start) stuff = struct.pack("!H", end - start)
file.write(stuff) file.write(stuff)
file.seek(0, 2) file.seek(0, io.SEEK_END)
return len(self) return len(self)
def match(self, rdclass, rdtype, covers): def match(self, rdclass, rdtype, covers):
@ -297,8 +309,86 @@ class Rdataset(dns.set.Set):
return True return True
return False return False
def processing_order(self):
"""Return rdatas in a valid processing order according to the type's
specification. For example, MX records are in preference order from
lowest to highest preferences, with items of the same perference
shuffled.
def from_text_list(rdclass, rdtype, ttl, text_rdatas, idna_codec=None): For types that do not define a processing order, the rdatas are
simply shuffled.
"""
if len(self) == 0:
return []
else:
return self[0]._processing_order(iter(self))
@dns.immutable.immutable
class ImmutableRdataset(Rdataset):
"""An immutable DNS rdataset."""
_clone_class = Rdataset
def __init__(self, rdataset):
"""Create an immutable rdataset from the specified rdataset."""
super().__init__(rdataset.rdclass, rdataset.rdtype, rdataset.covers,
rdataset.ttl)
self.items = dns.immutable.Dict(rdataset.items)
def update_ttl(self, ttl):
raise TypeError('immutable')
def add(self, rd, ttl=None):
raise TypeError('immutable')
def union_update(self, other):
raise TypeError('immutable')
def intersection_update(self, other):
raise TypeError('immutable')
def update(self, other):
raise TypeError('immutable')
def __delitem__(self, i):
raise TypeError('immutable')
def __ior__(self, other):
raise TypeError('immutable')
def __iand__(self, other):
raise TypeError('immutable')
def __iadd__(self, other):
raise TypeError('immutable')
def __isub__(self, other):
raise TypeError('immutable')
def clear(self):
raise TypeError('immutable')
def __copy__(self):
return ImmutableRdataset(super().copy())
def copy(self):
return ImmutableRdataset(super().copy())
def union(self, other):
return ImmutableRdataset(super().union(other))
def intersection(self, other):
return ImmutableRdataset(super().intersection(other))
def difference(self, other):
return ImmutableRdataset(super().difference(other))
def from_text_list(rdclass, rdtype, ttl, text_rdatas, idna_codec=None,
origin=None, relativize=True, relativize_to=None):
"""Create an rdataset with the specified class, type, and TTL, and with """Create an rdataset with the specified class, type, and TTL, and with
the specified list of rdatas in text format. the specified list of rdatas in text format.
@ -306,6 +396,14 @@ def from_text_list(rdclass, rdtype, ttl, text_rdatas, idna_codec=None):
encoder/decoder to use; if ``None``, the default IDNA 2003 encoder/decoder to use; if ``None``, the default IDNA 2003
encoder/decoder is used. encoder/decoder is used.
*origin*, a ``dns.name.Name`` (or ``None``), the
origin to use for relative names.
*relativize*, a ``bool``. If true, name will be relativized.
*relativize_to*, a ``dns.name.Name`` (or ``None``), the origin to use
when relativizing names. If not set, the *origin* value will be used.
Returns a ``dns.rdataset.Rdataset`` object. Returns a ``dns.rdataset.Rdataset`` object.
""" """
@ -314,7 +412,8 @@ def from_text_list(rdclass, rdtype, ttl, text_rdatas, idna_codec=None):
r = Rdataset(rdclass, rdtype) r = Rdataset(rdclass, rdtype)
r.update_ttl(ttl) r.update_ttl(ttl)
for t in text_rdatas: for t in text_rdatas:
rd = dns.rdata.from_text(r.rdclass, r.rdtype, t, idna_codec=idna_codec) rd = dns.rdata.from_text(r.rdclass, r.rdtype, t, origin, relativize,
relativize_to, idna_codec)
r.add(rd) r.add(rd)
return r return r

58
lib/dns/rdataset.pyi Normal file
View file

@ -0,0 +1,58 @@
from typing import Optional, Dict, List, Union
from io import BytesIO
from . import exception, name, set, rdatatype, rdata, rdataset
class DifferingCovers(exception.DNSException):
"""An attempt was made to add a DNS SIG/RRSIG whose covered type
is not the same as that of the other rdatas in the rdataset."""
class IncompatibleTypes(exception.DNSException):
"""An attempt was made to add DNS RR data of an incompatible type."""
class Rdataset(set.Set):
def __init__(self, rdclass, rdtype, covers=rdatatype.NONE, ttl=0):
self.rdclass : int = rdclass
self.rdtype : int = rdtype
self.covers : int = covers
self.ttl : int = ttl
def update_ttl(self, ttl : int) -> None:
...
def add(self, rd : rdata.Rdata, ttl : Optional[int] =None):
...
def union_update(self, other : Rdataset):
...
def intersection_update(self, other : Rdataset):
...
def update(self, other : Rdataset):
...
def to_text(self, name : Optional[name.Name] =None, origin : Optional[name.Name] =None, relativize=True,
override_rdclass : Optional[int] =None, **kw) -> bytes:
...
def to_wire(self, name : Optional[name.Name], file : BytesIO, compress : Optional[Dict[name.Name, int]] = None, origin : Optional[name.Name] = None,
override_rdclass : Optional[int] = None, want_shuffle=True) -> int:
...
def match(self, rdclass : int, rdtype : int, covers : int) -> bool:
...
def from_text_list(rdclass : Union[int,str], rdtype : Union[int,str], ttl : int, text_rdatas : str, idna_codec : Optional[name.IDNACodec] = None) -> rdataset.Rdataset:
...
def from_text(rdclass : Union[int,str], rdtype : Union[int,str], ttl : int, *text_rdatas : str) -> rdataset.Rdataset:
...
def from_rdata_list(ttl : int, rdatas : List[rdata.Rdata]) -> rdataset.Rdataset:
...
def from_rdata(ttl : int, *rdatas : List[rdata.Rdata]) -> rdataset.Rdataset:
...

View file

@ -72,14 +72,22 @@ class RdataType(dns.enum.IntEnum):
NSEC3 = 50 NSEC3 = 50
NSEC3PARAM = 51 NSEC3PARAM = 51
TLSA = 52 TLSA = 52
SMIMEA = 53
HIP = 55 HIP = 55
NINFO = 56 NINFO = 56
CDS = 59 CDS = 59
CDNSKEY = 60 CDNSKEY = 60
OPENPGPKEY = 61 OPENPGPKEY = 61
CSYNC = 62 CSYNC = 62
ZONEMD = 63
SVCB = 64
HTTPS = 65
SPF = 99 SPF = 99
UNSPEC = 103 UNSPEC = 103
NID = 104
L32 = 105
L64 = 106
LP = 107
EUI48 = 108 EUI48 = 108
EUI64 = 109 EUI64 = 109
TKEY = 249 TKEY = 249
@ -92,7 +100,7 @@ class RdataType(dns.enum.IntEnum):
URI = 256 URI = 256
CAA = 257 CAA = 257
AVC = 258 AVC = 258
AMTRELAY = 259 AMTRELAY = 260
TA = 32768 TA = 32768
DLV = 32769 DLV = 32769
@ -115,8 +123,6 @@ class RdataType(dns.enum.IntEnum):
_registered_by_text = {} _registered_by_text = {}
_registered_by_value = {} _registered_by_value = {}
globals().update(RdataType.__members__)
_metatypes = {RdataType.OPT} _metatypes = {RdataType.OPT}
_singletons = {RdataType.SOA, RdataType.NXT, RdataType.DNAME, _singletons = {RdataType.SOA, RdataType.NXT, RdataType.DNAME,
@ -219,3 +225,89 @@ def register_type(rdtype, rdtype_text, is_singleton=False):
_registered_by_value[rdtype] = rdtype_text _registered_by_value[rdtype] = rdtype_text
if is_singleton: if is_singleton:
_singletons.add(rdtype) _singletons.add(rdtype)
### BEGIN generated RdataType constants
TYPE0 = RdataType.TYPE0
NONE = RdataType.NONE
A = RdataType.A
NS = RdataType.NS
MD = RdataType.MD
MF = RdataType.MF
CNAME = RdataType.CNAME
SOA = RdataType.SOA
MB = RdataType.MB
MG = RdataType.MG
MR = RdataType.MR
NULL = RdataType.NULL
WKS = RdataType.WKS
PTR = RdataType.PTR
HINFO = RdataType.HINFO
MINFO = RdataType.MINFO
MX = RdataType.MX
TXT = RdataType.TXT
RP = RdataType.RP
AFSDB = RdataType.AFSDB
X25 = RdataType.X25
ISDN = RdataType.ISDN
RT = RdataType.RT
NSAP = RdataType.NSAP
NSAP_PTR = RdataType.NSAP_PTR
SIG = RdataType.SIG
KEY = RdataType.KEY
PX = RdataType.PX
GPOS = RdataType.GPOS
AAAA = RdataType.AAAA
LOC = RdataType.LOC
NXT = RdataType.NXT
SRV = RdataType.SRV
NAPTR = RdataType.NAPTR
KX = RdataType.KX
CERT = RdataType.CERT
A6 = RdataType.A6
DNAME = RdataType.DNAME
OPT = RdataType.OPT
APL = RdataType.APL
DS = RdataType.DS
SSHFP = RdataType.SSHFP
IPSECKEY = RdataType.IPSECKEY
RRSIG = RdataType.RRSIG
NSEC = RdataType.NSEC
DNSKEY = RdataType.DNSKEY
DHCID = RdataType.DHCID
NSEC3 = RdataType.NSEC3
NSEC3PARAM = RdataType.NSEC3PARAM
TLSA = RdataType.TLSA
SMIMEA = RdataType.SMIMEA
HIP = RdataType.HIP
NINFO = RdataType.NINFO
CDS = RdataType.CDS
CDNSKEY = RdataType.CDNSKEY
OPENPGPKEY = RdataType.OPENPGPKEY
CSYNC = RdataType.CSYNC
ZONEMD = RdataType.ZONEMD
SVCB = RdataType.SVCB
HTTPS = RdataType.HTTPS
SPF = RdataType.SPF
UNSPEC = RdataType.UNSPEC
NID = RdataType.NID
L32 = RdataType.L32
L64 = RdataType.L64
LP = RdataType.LP
EUI48 = RdataType.EUI48
EUI64 = RdataType.EUI64
TKEY = RdataType.TKEY
TSIG = RdataType.TSIG
IXFR = RdataType.IXFR
AXFR = RdataType.AXFR
MAILB = RdataType.MAILB
MAILA = RdataType.MAILA
ANY = RdataType.ANY
URI = RdataType.URI
CAA = RdataType.CAA
AVC = RdataType.AVC
AMTRELAY = RdataType.AMTRELAY
TA = RdataType.TA
DLV = RdataType.DLV
### END generated RdataType constants

View file

@ -16,8 +16,10 @@
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import dns.rdtypes.mxbase import dns.rdtypes.mxbase
import dns.immutable
@dns.immutable.immutable
class AFSDB(dns.rdtypes.mxbase.UncompressedDowncasingMX): class AFSDB(dns.rdtypes.mxbase.UncompressedDowncasingMX):
"""AFSDB record""" """AFSDB record"""

View file

@ -18,12 +18,19 @@
import struct import struct
import dns.exception import dns.exception
import dns.immutable
import dns.rdtypes.util import dns.rdtypes.util
class Relay(dns.rdtypes.util.Gateway): class Relay(dns.rdtypes.util.Gateway):
name = 'AMTRELAY relay' name = 'AMTRELAY relay'
@property
def relay(self):
return self.gateway
@dns.immutable.immutable
class AMTRELAY(dns.rdata.Rdata): class AMTRELAY(dns.rdata.Rdata):
"""AMTRELAY record""" """AMTRELAY record"""
@ -35,11 +42,11 @@ class AMTRELAY(dns.rdata.Rdata):
def __init__(self, rdclass, rdtype, precedence, discovery_optional, def __init__(self, rdclass, rdtype, precedence, discovery_optional,
relay_type, relay): relay_type, relay):
super().__init__(rdclass, rdtype) super().__init__(rdclass, rdtype)
Relay(relay_type, relay).check() relay = Relay(relay_type, relay)
object.__setattr__(self, 'precedence', precedence) self.precedence = self._as_uint8(precedence)
object.__setattr__(self, 'discovery_optional', discovery_optional) self.discovery_optional = self._as_bool(discovery_optional)
object.__setattr__(self, 'relay_type', relay_type) self.relay_type = relay.type
object.__setattr__(self, 'relay', relay) self.relay = relay.relay
def to_text(self, origin=None, relativize=True, **kw): def to_text(self, origin=None, relativize=True, **kw):
relay = Relay(self.relay_type, self.relay).to_text(origin, relativize) relay = Relay(self.relay_type, self.relay).to_text(origin, relativize)
@ -57,10 +64,10 @@ class AMTRELAY(dns.rdata.Rdata):
relay_type = tok.get_uint8() relay_type = tok.get_uint8()
if relay_type > 0x7f: if relay_type > 0x7f:
raise dns.exception.SyntaxError('expecting an integer <= 127') raise dns.exception.SyntaxError('expecting an integer <= 127')
relay = Relay(relay_type).from_text(tok, origin, relativize, relay = Relay.from_text(relay_type, tok, origin, relativize,
relativize_to) relativize_to)
return cls(rdclass, rdtype, precedence, discovery_optional, relay_type, return cls(rdclass, rdtype, precedence, discovery_optional, relay_type,
relay) relay.relay)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False): def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
relay_type = self.relay_type | (self.discovery_optional << 7) relay_type = self.relay_type | (self.discovery_optional << 7)
@ -74,6 +81,6 @@ class AMTRELAY(dns.rdata.Rdata):
(precedence, relay_type) = parser.get_struct('!BB') (precedence, relay_type) = parser.get_struct('!BB')
discovery_optional = bool(relay_type >> 7) discovery_optional = bool(relay_type >> 7)
relay_type &= 0x7f relay_type &= 0x7f
relay = Relay(relay_type).from_wire_parser(parser, origin) relay = Relay.from_wire_parser(relay_type, parser, origin)
return cls(rdclass, rdtype, precedence, discovery_optional, relay_type, return cls(rdclass, rdtype, precedence, discovery_optional, relay_type,
relay) relay.relay)

View file

@ -16,8 +16,10 @@
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import dns.rdtypes.txtbase import dns.rdtypes.txtbase
import dns.immutable
@dns.immutable.immutable
class AVC(dns.rdtypes.txtbase.TXTBase): class AVC(dns.rdtypes.txtbase.TXTBase):
"""AVC record""" """AVC record"""

View file

@ -18,10 +18,12 @@
import struct import struct
import dns.exception import dns.exception
import dns.immutable
import dns.rdata import dns.rdata
import dns.tokenizer import dns.tokenizer
@dns.immutable.immutable
class CAA(dns.rdata.Rdata): class CAA(dns.rdata.Rdata):
"""CAA (Certification Authority Authorization) record""" """CAA (Certification Authority Authorization) record"""
@ -32,9 +34,11 @@ class CAA(dns.rdata.Rdata):
def __init__(self, rdclass, rdtype, flags, tag, value): def __init__(self, rdclass, rdtype, flags, tag, value):
super().__init__(rdclass, rdtype) super().__init__(rdclass, rdtype)
object.__setattr__(self, 'flags', flags) self.flags = self._as_uint8(flags)
object.__setattr__(self, 'tag', tag) self.tag = self._as_bytes(tag, True, 255)
object.__setattr__(self, 'value', value) if not tag.isalnum():
raise ValueError("tag is not alphanumeric")
self.value = self._as_bytes(value)
def to_text(self, origin=None, relativize=True, **kw): def to_text(self, origin=None, relativize=True, **kw):
return '%u %s "%s"' % (self.flags, return '%u %s "%s"' % (self.flags,
@ -46,10 +50,6 @@ class CAA(dns.rdata.Rdata):
relativize_to=None): relativize_to=None):
flags = tok.get_uint8() flags = tok.get_uint8()
tag = tok.get_string().encode() tag = tok.get_string().encode()
if len(tag) > 255:
raise dns.exception.SyntaxError("tag too long")
if not tag.isalnum():
raise dns.exception.SyntaxError("tag is not alphanumeric")
value = tok.get_string().encode() value = tok.get_string().encode()
return cls(rdclass, rdtype, flags, tag, value) return cls(rdclass, rdtype, flags, tag, value)

View file

@ -16,9 +16,13 @@
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import dns.rdtypes.dnskeybase import dns.rdtypes.dnskeybase
import dns.immutable
# pylint: disable=unused-import
from dns.rdtypes.dnskeybase import SEP, REVOKE, ZONE # noqa: F401 from dns.rdtypes.dnskeybase import SEP, REVOKE, ZONE # noqa: F401
# pylint: enable=unused-import
@dns.immutable.immutable
class CDNSKEY(dns.rdtypes.dnskeybase.DNSKEYBase): class CDNSKEY(dns.rdtypes.dnskeybase.DNSKEYBase):
"""CDNSKEY record""" """CDNSKEY record"""

View file

@ -16,8 +16,15 @@
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import dns.rdtypes.dsbase import dns.rdtypes.dsbase
import dns.immutable
@dns.immutable.immutable
class CDS(dns.rdtypes.dsbase.DSBase): class CDS(dns.rdtypes.dsbase.DSBase):
"""CDS record""" """CDS record"""
_digest_length_by_type = {
**dns.rdtypes.dsbase.DSBase._digest_length_by_type,
0: 1, # delete, RFC 8078 Sec. 4 (including Errata ID 5049)
}

View file

@ -19,6 +19,7 @@ import struct
import base64 import base64
import dns.exception import dns.exception
import dns.immutable
import dns.dnssec import dns.dnssec
import dns.rdata import dns.rdata
import dns.tokenizer import dns.tokenizer
@ -27,6 +28,11 @@ _ctype_by_value = {
1: 'PKIX', 1: 'PKIX',
2: 'SPKI', 2: 'SPKI',
3: 'PGP', 3: 'PGP',
4: 'IPKIX',
5: 'ISPKI',
6: 'IPGP',
7: 'ACPKIX',
8: 'IACPKIX',
253: 'URI', 253: 'URI',
254: 'OID', 254: 'OID',
} }
@ -35,6 +41,11 @@ _ctype_by_name = {
'PKIX': 1, 'PKIX': 1,
'SPKI': 2, 'SPKI': 2,
'PGP': 3, 'PGP': 3,
'IPKIX': 4,
'ISPKI': 5,
'IPGP': 6,
'ACPKIX': 7,
'IACPKIX': 8,
'URI': 253, 'URI': 253,
'OID': 254, 'OID': 254,
} }
@ -54,27 +65,28 @@ def _ctype_to_text(what):
return str(what) return str(what)
@dns.immutable.immutable
class CERT(dns.rdata.Rdata): class CERT(dns.rdata.Rdata):
"""CERT record""" """CERT record"""
# see RFC 2538 # see RFC 4398
__slots__ = ['certificate_type', 'key_tag', 'algorithm', 'certificate'] __slots__ = ['certificate_type', 'key_tag', 'algorithm', 'certificate']
def __init__(self, rdclass, rdtype, certificate_type, key_tag, algorithm, def __init__(self, rdclass, rdtype, certificate_type, key_tag, algorithm,
certificate): certificate):
super().__init__(rdclass, rdtype) super().__init__(rdclass, rdtype)
object.__setattr__(self, 'certificate_type', certificate_type) self.certificate_type = self._as_uint16(certificate_type)
object.__setattr__(self, 'key_tag', key_tag) self.key_tag = self._as_uint16(key_tag)
object.__setattr__(self, 'algorithm', algorithm) self.algorithm = self._as_uint8(algorithm)
object.__setattr__(self, 'certificate', certificate) self.certificate = self._as_bytes(certificate)
def to_text(self, origin=None, relativize=True, **kw): def to_text(self, origin=None, relativize=True, **kw):
certificate_type = _ctype_to_text(self.certificate_type) certificate_type = _ctype_to_text(self.certificate_type)
return "%s %d %s %s" % (certificate_type, self.key_tag, return "%s %d %s %s" % (certificate_type, self.key_tag,
dns.dnssec.algorithm_to_text(self.algorithm), dns.dnssec.algorithm_to_text(self.algorithm),
dns.rdata._base64ify(self.certificate)) dns.rdata._base64ify(self.certificate, **kw))
@classmethod @classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
@ -82,8 +94,6 @@ class CERT(dns.rdata.Rdata):
certificate_type = _ctype_from_text(tok.get_string()) certificate_type = _ctype_from_text(tok.get_string())
key_tag = tok.get_uint16() key_tag = tok.get_uint16()
algorithm = dns.dnssec.algorithm_from_text(tok.get_string()) algorithm = dns.dnssec.algorithm_from_text(tok.get_string())
if algorithm < 0 or algorithm > 255:
raise dns.exception.SyntaxError("bad algorithm type")
b64 = tok.concatenate_remaining_identifiers().encode() b64 = tok.concatenate_remaining_identifiers().encode()
certificate = base64.b64decode(b64) certificate = base64.b64decode(b64)
return cls(rdclass, rdtype, certificate_type, key_tag, return cls(rdclass, rdtype, certificate_type, key_tag,

View file

@ -16,8 +16,10 @@
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import dns.rdtypes.nsbase import dns.rdtypes.nsbase
import dns.immutable
@dns.immutable.immutable
class CNAME(dns.rdtypes.nsbase.NSBase): class CNAME(dns.rdtypes.nsbase.NSBase):
"""CNAME record """CNAME record

View file

@ -18,16 +18,19 @@
import struct import struct
import dns.exception import dns.exception
import dns.immutable
import dns.rdata import dns.rdata
import dns.rdatatype import dns.rdatatype
import dns.name import dns.name
import dns.rdtypes.util import dns.rdtypes.util
@dns.immutable.immutable
class Bitmap(dns.rdtypes.util.Bitmap): class Bitmap(dns.rdtypes.util.Bitmap):
type_name = 'CSYNC' type_name = 'CSYNC'
@dns.immutable.immutable
class CSYNC(dns.rdata.Rdata): class CSYNC(dns.rdata.Rdata):
"""CSYNC record""" """CSYNC record"""
@ -36,9 +39,11 @@ class CSYNC(dns.rdata.Rdata):
def __init__(self, rdclass, rdtype, serial, flags, windows): def __init__(self, rdclass, rdtype, serial, flags, windows):
super().__init__(rdclass, rdtype) super().__init__(rdclass, rdtype)
object.__setattr__(self, 'serial', serial) self.serial = self._as_uint32(serial)
object.__setattr__(self, 'flags', flags) self.flags = self._as_uint16(flags)
object.__setattr__(self, 'windows', dns.rdata._constify(windows)) if not isinstance(windows, Bitmap):
windows = Bitmap(windows)
self.windows = tuple(windows.windows)
def to_text(self, origin=None, relativize=True, **kw): def to_text(self, origin=None, relativize=True, **kw):
text = Bitmap(self.windows).to_text() text = Bitmap(self.windows).to_text()
@ -49,8 +54,8 @@ class CSYNC(dns.rdata.Rdata):
relativize_to=None): relativize_to=None):
serial = tok.get_uint32() serial = tok.get_uint32()
flags = tok.get_uint16() flags = tok.get_uint16()
windows = Bitmap().from_text(tok) bitmap = Bitmap.from_text(tok)
return cls(rdclass, rdtype, serial, flags, windows) return cls(rdclass, rdtype, serial, flags, bitmap)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False): def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
file.write(struct.pack('!IH', self.serial, self.flags)) file.write(struct.pack('!IH', self.serial, self.flags))
@ -59,5 +64,5 @@ class CSYNC(dns.rdata.Rdata):
@classmethod @classmethod
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
(serial, flags) = parser.get_struct("!IH") (serial, flags) = parser.get_struct("!IH")
windows = Bitmap().from_wire_parser(parser) bitmap = Bitmap.from_wire_parser(parser)
return cls(rdclass, rdtype, serial, flags, windows) return cls(rdclass, rdtype, serial, flags, bitmap)

View file

@ -16,8 +16,10 @@
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import dns.rdtypes.dsbase import dns.rdtypes.dsbase
import dns.immutable
@dns.immutable.immutable
class DLV(dns.rdtypes.dsbase.DSBase): class DLV(dns.rdtypes.dsbase.DSBase):
"""DLV record""" """DLV record"""

View file

@ -16,8 +16,10 @@
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import dns.rdtypes.nsbase import dns.rdtypes.nsbase
import dns.immutable
@dns.immutable.immutable
class DNAME(dns.rdtypes.nsbase.UncompressedNS): class DNAME(dns.rdtypes.nsbase.UncompressedNS):
"""DNAME record""" """DNAME record"""

View file

@ -16,9 +16,13 @@
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import dns.rdtypes.dnskeybase import dns.rdtypes.dnskeybase
import dns.immutable
# pylint: disable=unused-import
from dns.rdtypes.dnskeybase import SEP, REVOKE, ZONE # noqa: F401 from dns.rdtypes.dnskeybase import SEP, REVOKE, ZONE # noqa: F401
# pylint: enable=unused-import
@dns.immutable.immutable
class DNSKEY(dns.rdtypes.dnskeybase.DNSKEYBase): class DNSKEY(dns.rdtypes.dnskeybase.DNSKEYBase):
"""DNSKEY record""" """DNSKEY record"""

View file

@ -16,8 +16,10 @@
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import dns.rdtypes.dsbase import dns.rdtypes.dsbase
import dns.immutable
@dns.immutable.immutable
class DS(dns.rdtypes.dsbase.DSBase): class DS(dns.rdtypes.dsbase.DSBase):
"""DS record""" """DS record"""

View file

@ -17,8 +17,10 @@
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import dns.rdtypes.euibase import dns.rdtypes.euibase
import dns.immutable
@dns.immutable.immutable
class EUI48(dns.rdtypes.euibase.EUIBase): class EUI48(dns.rdtypes.euibase.EUIBase):
"""EUI48 record""" """EUI48 record"""

View file

@ -17,8 +17,10 @@
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import dns.rdtypes.euibase import dns.rdtypes.euibase
import dns.immutable
@dns.immutable.immutable
class EUI64(dns.rdtypes.euibase.EUIBase): class EUI64(dns.rdtypes.euibase.EUIBase):
"""EUI64 record""" """EUI64 record"""

View file

@ -18,6 +18,7 @@
import struct import struct
import dns.exception import dns.exception
import dns.immutable
import dns.rdata import dns.rdata
import dns.tokenizer import dns.tokenizer
@ -41,12 +42,7 @@ def _validate_float_string(what):
raise dns.exception.FormError raise dns.exception.FormError
def _sanitize(value): @dns.immutable.immutable
if isinstance(value, str):
return value.encode()
return value
class GPOS(dns.rdata.Rdata): class GPOS(dns.rdata.Rdata):
"""GPOS record""" """GPOS record"""
@ -66,15 +62,15 @@ class GPOS(dns.rdata.Rdata):
if isinstance(altitude, float) or \ if isinstance(altitude, float) or \
isinstance(altitude, int): isinstance(altitude, int):
altitude = str(altitude) altitude = str(altitude)
latitude = _sanitize(latitude) latitude = self._as_bytes(latitude, True, 255)
longitude = _sanitize(longitude) longitude = self._as_bytes(longitude, True, 255)
altitude = _sanitize(altitude) altitude = self._as_bytes(altitude, True, 255)
_validate_float_string(latitude) _validate_float_string(latitude)
_validate_float_string(longitude) _validate_float_string(longitude)
_validate_float_string(altitude) _validate_float_string(altitude)
object.__setattr__(self, 'latitude', latitude) self.latitude = latitude
object.__setattr__(self, 'longitude', longitude) self.longitude = longitude
object.__setattr__(self, 'altitude', altitude) self.altitude = altitude
flat = self.float_latitude flat = self.float_latitude
if flat < -90.0 or flat > 90.0: if flat < -90.0 or flat > 90.0:
raise dns.exception.FormError('bad latitude') raise dns.exception.FormError('bad latitude')
@ -93,7 +89,6 @@ class GPOS(dns.rdata.Rdata):
latitude = tok.get_string() latitude = tok.get_string()
longitude = tok.get_string() longitude = tok.get_string()
altitude = tok.get_string() altitude = tok.get_string()
tok.get_eol()
return cls(rdclass, rdtype, latitude, longitude, altitude) return cls(rdclass, rdtype, latitude, longitude, altitude)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False): def _to_wire(self, file, compress=None, origin=None, canonicalize=False):

View file

@ -18,10 +18,12 @@
import struct import struct
import dns.exception import dns.exception
import dns.immutable
import dns.rdata import dns.rdata
import dns.tokenizer import dns.tokenizer
@dns.immutable.immutable
class HINFO(dns.rdata.Rdata): class HINFO(dns.rdata.Rdata):
"""HINFO record""" """HINFO record"""
@ -32,14 +34,8 @@ class HINFO(dns.rdata.Rdata):
def __init__(self, rdclass, rdtype, cpu, os): def __init__(self, rdclass, rdtype, cpu, os):
super().__init__(rdclass, rdtype) super().__init__(rdclass, rdtype)
if isinstance(cpu, str): self.cpu = self._as_bytes(cpu, True, 255)
object.__setattr__(self, 'cpu', cpu.encode()) self.os = self._as_bytes(os, True, 255)
else:
object.__setattr__(self, 'cpu', cpu)
if isinstance(os, str):
object.__setattr__(self, 'os', os.encode())
else:
object.__setattr__(self, 'os', os)
def to_text(self, origin=None, relativize=True, **kw): def to_text(self, origin=None, relativize=True, **kw):
return '"{}" "{}"'.format(dns.rdata._escapify(self.cpu), return '"{}" "{}"'.format(dns.rdata._escapify(self.cpu),
@ -50,7 +46,6 @@ class HINFO(dns.rdata.Rdata):
relativize_to=None): relativize_to=None):
cpu = tok.get_string(max_length=255) cpu = tok.get_string(max_length=255)
os = tok.get_string(max_length=255) os = tok.get_string(max_length=255)
tok.get_eol()
return cls(rdclass, rdtype, cpu, os) return cls(rdclass, rdtype, cpu, os)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False): def _to_wire(self, file, compress=None, origin=None, canonicalize=False):

View file

@ -20,10 +20,12 @@ import base64
import binascii import binascii
import dns.exception import dns.exception
import dns.immutable
import dns.rdata import dns.rdata
import dns.rdatatype import dns.rdatatype
@dns.immutable.immutable
class HIP(dns.rdata.Rdata): class HIP(dns.rdata.Rdata):
"""HIP record""" """HIP record"""
@ -34,10 +36,10 @@ class HIP(dns.rdata.Rdata):
def __init__(self, rdclass, rdtype, hit, algorithm, key, servers): def __init__(self, rdclass, rdtype, hit, algorithm, key, servers):
super().__init__(rdclass, rdtype) super().__init__(rdclass, rdtype)
object.__setattr__(self, 'hit', hit) self.hit = self._as_bytes(hit, True, 255)
object.__setattr__(self, 'algorithm', algorithm) self.algorithm = self._as_uint8(algorithm)
object.__setattr__(self, 'key', key) self.key = self._as_bytes(key, True)
object.__setattr__(self, 'servers', dns.rdata._constify(servers)) self.servers = self._as_tuple(servers, self._as_name)
def to_text(self, origin=None, relativize=True, **kw): def to_text(self, origin=None, relativize=True, **kw):
hit = binascii.hexlify(self.hit).decode() hit = binascii.hexlify(self.hit).decode()
@ -55,14 +57,9 @@ class HIP(dns.rdata.Rdata):
relativize_to=None): relativize_to=None):
algorithm = tok.get_uint8() algorithm = tok.get_uint8()
hit = binascii.unhexlify(tok.get_string().encode()) hit = binascii.unhexlify(tok.get_string().encode())
if len(hit) > 255:
raise dns.exception.SyntaxError("HIT too long")
key = base64.b64decode(tok.get_string().encode()) key = base64.b64decode(tok.get_string().encode())
servers = [] servers = []
while 1: for token in tok.get_remaining():
token = tok.get()
if token.is_eol_or_eof():
break
server = tok.as_name(token, origin, relativize, relativize_to) server = tok.as_name(token, origin, relativize, relativize_to)
servers.append(server) servers.append(server)
return cls(rdclass, rdtype, hit, algorithm, key, servers) return cls(rdclass, rdtype, hit, algorithm, key, servers)

View file

@ -18,10 +18,12 @@
import struct import struct
import dns.exception import dns.exception
import dns.immutable
import dns.rdata import dns.rdata
import dns.tokenizer import dns.tokenizer
@dns.immutable.immutable
class ISDN(dns.rdata.Rdata): class ISDN(dns.rdata.Rdata):
"""ISDN record""" """ISDN record"""
@ -32,14 +34,8 @@ class ISDN(dns.rdata.Rdata):
def __init__(self, rdclass, rdtype, address, subaddress): def __init__(self, rdclass, rdtype, address, subaddress):
super().__init__(rdclass, rdtype) super().__init__(rdclass, rdtype)
if isinstance(address, str): self.address = self._as_bytes(address, True, 255)
object.__setattr__(self, 'address', address.encode()) self.subaddress = self._as_bytes(subaddress, True, 255)
else:
object.__setattr__(self, 'address', address)
if isinstance(address, str):
object.__setattr__(self, 'subaddress', subaddress.encode())
else:
object.__setattr__(self, 'subaddress', subaddress)
def to_text(self, origin=None, relativize=True, **kw): def to_text(self, origin=None, relativize=True, **kw):
if self.subaddress: if self.subaddress:
@ -52,14 +48,11 @@ class ISDN(dns.rdata.Rdata):
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
relativize_to=None): relativize_to=None):
address = tok.get_string() address = tok.get_string()
t = tok.get() tokens = tok.get_remaining(max_tokens=1)
if not t.is_eol_or_eof(): if len(tokens) >= 1:
tok.unget(t) subaddress = tokens[0].unescape().value
subaddress = tok.get_string()
else: else:
tok.unget(t)
subaddress = '' subaddress = ''
tok.get_eol()
return cls(rdclass, rdtype, address, subaddress) return cls(rdclass, rdtype, address, subaddress)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False): def _to_wire(self, file, compress=None, origin=None, canonicalize=False):

View file

@ -0,0 +1,40 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
import struct
import dns.immutable
@dns.immutable.immutable
class L32(dns.rdata.Rdata):
"""L32 record"""
# see: rfc6742.txt
__slots__ = ['preference', 'locator32']
def __init__(self, rdclass, rdtype, preference, locator32):
super().__init__(rdclass, rdtype)
self.preference = self._as_uint16(preference)
self.locator32 = self._as_ipv4_address(locator32)
def to_text(self, origin=None, relativize=True, **kw):
return f'{self.preference} {self.locator32}'
@classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
relativize_to=None):
preference = tok.get_uint16()
nodeid = tok.get_identifier()
return cls(rdclass, rdtype, preference, nodeid)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
file.write(struct.pack('!H', self.preference))
file.write(dns.ipv4.inet_aton(self.locator32))
@classmethod
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
preference = parser.get_uint16()
locator32 = parser.get_remaining()
return cls(rdclass, rdtype, preference, locator32)

View file

@ -0,0 +1,48 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
import struct
import dns.immutable
import dns.rdtypes.util
@dns.immutable.immutable
class L64(dns.rdata.Rdata):
"""L64 record"""
# see: rfc6742.txt
__slots__ = ['preference', 'locator64']
def __init__(self, rdclass, rdtype, preference, locator64):
super().__init__(rdclass, rdtype)
self.preference = self._as_uint16(preference)
if isinstance(locator64, bytes):
if len(locator64) != 8:
raise ValueError('invalid locator64')
self.locator64 = dns.rdata._hexify(locator64, 4, b':')
else:
dns.rdtypes.util.parse_formatted_hex(locator64, 4, 4, ':')
self.locator64 = locator64
def to_text(self, origin=None, relativize=True, **kw):
return f'{self.preference} {self.locator64}'
@classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
relativize_to=None):
preference = tok.get_uint16()
locator64 = tok.get_identifier()
return cls(rdclass, rdtype, preference, locator64)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
file.write(struct.pack('!H', self.preference))
file.write(dns.rdtypes.util.parse_formatted_hex(self.locator64,
4, 4, ':'))
@classmethod
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
preference = parser.get_uint16()
locator64 = parser.get_remaining()
return cls(rdclass, rdtype, preference, locator64)

View file

@ -18,6 +18,7 @@
import struct import struct
import dns.exception import dns.exception
import dns.immutable
import dns.rdata import dns.rdata
@ -34,17 +35,13 @@ _MIN_LATITUDE = 0x80000000 - 90 * 3600000
_MAX_LONGITUDE = 0x80000000 + 180 * 3600000 _MAX_LONGITUDE = 0x80000000 + 180 * 3600000
_MIN_LONGITUDE = 0x80000000 - 180 * 3600000 _MIN_LONGITUDE = 0x80000000 - 180 * 3600000
# pylint complains about division since we don't have a from __future__ for
# it, but we don't care about python 2 warnings, so turn them off.
#
# pylint: disable=old-division
def _exponent_of(what, desc): def _exponent_of(what, desc):
if what == 0: if what == 0:
return 0 return 0
exp = None exp = None
for (i, pow) in enumerate(_pows): for (i, pow) in enumerate(_pows):
if what // pow == 0: if what < pow:
exp = i - 1 exp = i - 1
break break
if exp is None or exp < 0: if exp is None or exp < 0:
@ -58,7 +55,7 @@ def _float_to_tuple(what):
what *= -1 what *= -1
else: else:
sign = 1 sign = 1
what = round(what * 3600000) # pylint: disable=round-builtin what = round(what * 3600000)
degrees = int(what // 3600000) degrees = int(what // 3600000)
what -= degrees * 3600000 what -= degrees * 3600000
minutes = int(what // 60000) minutes = int(what // 60000)
@ -94,6 +91,20 @@ def _decode_size(what, desc):
return base * pow(10, exponent) return base * pow(10, exponent)
def _check_coordinate_list(value, low, high):
if value[0] < low or value[0] > high:
raise ValueError(f'not in range [{low}, {high}]')
if value[1] < 0 or value[1] > 59:
raise ValueError('bad minutes value')
if value[2] < 0 or value[2] > 59:
raise ValueError('bad seconds value')
if value[3] < 0 or value[3] > 999:
raise ValueError('bad milliseconds value')
if value[4] != 1 and value[4] != -1:
raise ValueError('bad hemisphere value')
@dns.immutable.immutable
class LOC(dns.rdata.Rdata): class LOC(dns.rdata.Rdata):
"""LOC record""" """LOC record"""
@ -119,16 +130,18 @@ class LOC(dns.rdata.Rdata):
latitude = float(latitude) latitude = float(latitude)
if isinstance(latitude, float): if isinstance(latitude, float):
latitude = _float_to_tuple(latitude) latitude = _float_to_tuple(latitude)
object.__setattr__(self, 'latitude', dns.rdata._constify(latitude)) _check_coordinate_list(latitude, -90, 90)
self.latitude = tuple(latitude)
if isinstance(longitude, int): if isinstance(longitude, int):
longitude = float(longitude) longitude = float(longitude)
if isinstance(longitude, float): if isinstance(longitude, float):
longitude = _float_to_tuple(longitude) longitude = _float_to_tuple(longitude)
object.__setattr__(self, 'longitude', dns.rdata._constify(longitude)) _check_coordinate_list(longitude, -180, 180)
object.__setattr__(self, 'altitude', float(altitude)) self.longitude = tuple(longitude)
object.__setattr__(self, 'size', float(size)) self.altitude = float(altitude)
object.__setattr__(self, 'horizontal_precision', float(hprec)) self.size = float(size)
object.__setattr__(self, 'vertical_precision', float(vprec)) self.horizontal_precision = float(hprec)
self.vertical_precision = float(vprec)
def to_text(self, origin=None, relativize=True, **kw): def to_text(self, origin=None, relativize=True, **kw):
if self.latitude[4] > 0: if self.latitude[4] > 0:
@ -167,13 +180,9 @@ class LOC(dns.rdata.Rdata):
vprec = _default_vprec vprec = _default_vprec
latitude[0] = tok.get_int() latitude[0] = tok.get_int()
if latitude[0] > 90:
raise dns.exception.SyntaxError('latitude >= 90')
t = tok.get_string() t = tok.get_string()
if t.isdigit(): if t.isdigit():
latitude[1] = int(t) latitude[1] = int(t)
if latitude[1] >= 60:
raise dns.exception.SyntaxError('latitude minutes >= 60')
t = tok.get_string() t = tok.get_string()
if '.' in t: if '.' in t:
(seconds, milliseconds) = t.split('.') (seconds, milliseconds) = t.split('.')
@ -181,8 +190,6 @@ class LOC(dns.rdata.Rdata):
raise dns.exception.SyntaxError( raise dns.exception.SyntaxError(
'bad latitude seconds value') 'bad latitude seconds value')
latitude[2] = int(seconds) latitude[2] = int(seconds)
if latitude[2] >= 60:
raise dns.exception.SyntaxError('latitude seconds >= 60')
l = len(milliseconds) l = len(milliseconds)
if l == 0 or l > 3 or not milliseconds.isdigit(): if l == 0 or l > 3 or not milliseconds.isdigit():
raise dns.exception.SyntaxError( raise dns.exception.SyntaxError(
@ -204,13 +211,9 @@ class LOC(dns.rdata.Rdata):
raise dns.exception.SyntaxError('bad latitude hemisphere value') raise dns.exception.SyntaxError('bad latitude hemisphere value')
longitude[0] = tok.get_int() longitude[0] = tok.get_int()
if longitude[0] > 180:
raise dns.exception.SyntaxError('longitude > 180')
t = tok.get_string() t = tok.get_string()
if t.isdigit(): if t.isdigit():
longitude[1] = int(t) longitude[1] = int(t)
if longitude[1] >= 60:
raise dns.exception.SyntaxError('longitude minutes >= 60')
t = tok.get_string() t = tok.get_string()
if '.' in t: if '.' in t:
(seconds, milliseconds) = t.split('.') (seconds, milliseconds) = t.split('.')
@ -218,8 +221,6 @@ class LOC(dns.rdata.Rdata):
raise dns.exception.SyntaxError( raise dns.exception.SyntaxError(
'bad longitude seconds value') 'bad longitude seconds value')
longitude[2] = int(seconds) longitude[2] = int(seconds)
if longitude[2] >= 60:
raise dns.exception.SyntaxError('longitude seconds >= 60')
l = len(milliseconds) l = len(milliseconds)
if l == 0 or l > 3 or not milliseconds.isdigit(): if l == 0 or l > 3 or not milliseconds.isdigit():
raise dns.exception.SyntaxError( raise dns.exception.SyntaxError(
@ -245,25 +246,22 @@ class LOC(dns.rdata.Rdata):
t = t[0: -1] t = t[0: -1]
altitude = float(t) * 100.0 # m -> cm altitude = float(t) * 100.0 # m -> cm
token = tok.get().unescape() tokens = tok.get_remaining(max_tokens=3)
if not token.is_eol_or_eof(): if len(tokens) >= 1:
value = token.value value = tokens[0].unescape().value
if value[-1] == 'm': if value[-1] == 'm':
value = value[0: -1] value = value[0: -1]
size = float(value) * 100.0 # m -> cm size = float(value) * 100.0 # m -> cm
token = tok.get().unescape() if len(tokens) >= 2:
if not token.is_eol_or_eof(): value = tokens[1].unescape().value
value = token.value
if value[-1] == 'm': if value[-1] == 'm':
value = value[0: -1] value = value[0: -1]
hprec = float(value) * 100.0 # m -> cm hprec = float(value) * 100.0 # m -> cm
token = tok.get().unescape() if len(tokens) >= 3:
if not token.is_eol_or_eof(): value = tokens[2].unescape().value
value = token.value
if value[-1] == 'm': if value[-1] == 'm':
value = value[0: -1] value = value[0: -1]
vprec = float(value) * 100.0 # m -> cm vprec = float(value) * 100.0 # m -> cm
tok.get_eol()
# Try encoding these now so we raise if they are bad # Try encoding these now so we raise if they are bad
_encode_size(size, "size") _encode_size(size, "size")
@ -296,6 +294,8 @@ class LOC(dns.rdata.Rdata):
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
(version, size, hprec, vprec, latitude, longitude, altitude) = \ (version, size, hprec, vprec, latitude, longitude, altitude) = \
parser.get_struct("!BBBBIII") parser.get_struct("!BBBBIII")
if version != 0:
raise dns.exception.FormError("LOC version not zero")
if latitude < _MIN_LATITUDE or latitude > _MAX_LATITUDE: if latitude < _MIN_LATITUDE or latitude > _MAX_LATITUDE:
raise dns.exception.FormError("bad latitude") raise dns.exception.FormError("bad latitude")
if latitude > 0x80000000: if latitude > 0x80000000:

41
lib/dns/rdtypes/ANY/LP.py Normal file
View file

@ -0,0 +1,41 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
import struct
import dns.immutable
@dns.immutable.immutable
class LP(dns.rdata.Rdata):
"""LP record"""
# see: rfc6742.txt
__slots__ = ['preference', 'fqdn']
def __init__(self, rdclass, rdtype, preference, fqdn):
super().__init__(rdclass, rdtype)
self.preference = self._as_uint16(preference)
self.fqdn = self._as_name(fqdn)
def to_text(self, origin=None, relativize=True, **kw):
fqdn = self.fqdn.choose_relativity(origin, relativize)
return '%d %s' % (self.preference, fqdn)
@classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
relativize_to=None):
preference = tok.get_uint16()
fqdn = tok.get_name(origin, relativize, relativize_to)
return cls(rdclass, rdtype, preference, fqdn)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
file.write(struct.pack('!H', self.preference))
self.fqdn.to_wire(file, compress, origin, canonicalize)
@classmethod
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
preference = parser.get_uint16()
fqdn = parser.get_name(origin)
return cls(rdclass, rdtype, preference, fqdn)

View file

@ -16,8 +16,10 @@
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import dns.rdtypes.mxbase import dns.rdtypes.mxbase
import dns.immutable
@dns.immutable.immutable
class MX(dns.rdtypes.mxbase.MXBase): class MX(dns.rdtypes.mxbase.MXBase):
"""MX record""" """MX record"""

View file

@ -0,0 +1,47 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
import struct
import dns.immutable
import dns.rdtypes.util
@dns.immutable.immutable
class NID(dns.rdata.Rdata):
"""NID record"""
# see: rfc6742.txt
__slots__ = ['preference', 'nodeid']
def __init__(self, rdclass, rdtype, preference, nodeid):
super().__init__(rdclass, rdtype)
self.preference = self._as_uint16(preference)
if isinstance(nodeid, bytes):
if len(nodeid) != 8:
raise ValueError('invalid nodeid')
self.nodeid = dns.rdata._hexify(nodeid, 4, b':')
else:
dns.rdtypes.util.parse_formatted_hex(nodeid, 4, 4, ':')
self.nodeid = nodeid
def to_text(self, origin=None, relativize=True, **kw):
return f'{self.preference} {self.nodeid}'
@classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
relativize_to=None):
preference = tok.get_uint16()
nodeid = tok.get_identifier()
return cls(rdclass, rdtype, preference, nodeid)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
file.write(struct.pack('!H', self.preference))
file.write(dns.rdtypes.util.parse_formatted_hex(self.nodeid, 4, 4, ':'))
@classmethod
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
preference = parser.get_uint16()
nodeid = parser.get_remaining()
return cls(rdclass, rdtype, preference, nodeid)

View file

@ -16,8 +16,10 @@
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import dns.rdtypes.txtbase import dns.rdtypes.txtbase
import dns.immutable
@dns.immutable.immutable
class NINFO(dns.rdtypes.txtbase.TXTBase): class NINFO(dns.rdtypes.txtbase.TXTBase):
"""NINFO record""" """NINFO record"""

View file

@ -16,8 +16,10 @@
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import dns.rdtypes.nsbase import dns.rdtypes.nsbase
import dns.immutable
@dns.immutable.immutable
class NS(dns.rdtypes.nsbase.NSBase): class NS(dns.rdtypes.nsbase.NSBase):
"""NS record""" """NS record"""

View file

@ -16,16 +16,19 @@
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import dns.exception import dns.exception
import dns.immutable
import dns.rdata import dns.rdata
import dns.rdatatype import dns.rdatatype
import dns.name import dns.name
import dns.rdtypes.util import dns.rdtypes.util
@dns.immutable.immutable
class Bitmap(dns.rdtypes.util.Bitmap): class Bitmap(dns.rdtypes.util.Bitmap):
type_name = 'NSEC' type_name = 'NSEC'
@dns.immutable.immutable
class NSEC(dns.rdata.Rdata): class NSEC(dns.rdata.Rdata):
"""NSEC record""" """NSEC record"""
@ -34,8 +37,10 @@ class NSEC(dns.rdata.Rdata):
def __init__(self, rdclass, rdtype, next, windows): def __init__(self, rdclass, rdtype, next, windows):
super().__init__(rdclass, rdtype) super().__init__(rdclass, rdtype)
object.__setattr__(self, 'next', next) self.next = self._as_name(next)
object.__setattr__(self, 'windows', dns.rdata._constify(windows)) if not isinstance(windows, Bitmap):
windows = Bitmap(windows)
self.windows = tuple(windows.windows)
def to_text(self, origin=None, relativize=True, **kw): def to_text(self, origin=None, relativize=True, **kw):
next = self.next.choose_relativity(origin, relativize) next = self.next.choose_relativity(origin, relativize)
@ -46,15 +51,17 @@ class NSEC(dns.rdata.Rdata):
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
relativize_to=None): relativize_to=None):
next = tok.get_name(origin, relativize, relativize_to) next = tok.get_name(origin, relativize, relativize_to)
windows = Bitmap().from_text(tok) windows = Bitmap.from_text(tok)
return cls(rdclass, rdtype, next, windows) return cls(rdclass, rdtype, next, windows)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False): def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
# Note that NSEC downcasing, originally mandated by RFC 4034
# section 6.2 was removed by RFC 6840 section 5.1.
self.next.to_wire(file, None, origin, False) self.next.to_wire(file, None, origin, False)
Bitmap(self.windows).to_wire(file) Bitmap(self.windows).to_wire(file)
@classmethod @classmethod
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
next = parser.get_name(origin) next = parser.get_name(origin)
windows = Bitmap().from_wire_parser(parser) bitmap = Bitmap.from_wire_parser(parser)
return cls(rdclass, rdtype, next, windows) return cls(rdclass, rdtype, next, bitmap)

View file

@ -20,6 +20,7 @@ import binascii
import struct import struct
import dns.exception import dns.exception
import dns.immutable
import dns.rdata import dns.rdata
import dns.rdatatype import dns.rdatatype
import dns.rdtypes.util import dns.rdtypes.util
@ -37,10 +38,12 @@ SHA1 = 1
OPTOUT = 1 OPTOUT = 1
@dns.immutable.immutable
class Bitmap(dns.rdtypes.util.Bitmap): class Bitmap(dns.rdtypes.util.Bitmap):
type_name = 'NSEC3' type_name = 'NSEC3'
@dns.immutable.immutable
class NSEC3(dns.rdata.Rdata): class NSEC3(dns.rdata.Rdata):
"""NSEC3 record""" """NSEC3 record"""
@ -50,15 +53,14 @@ class NSEC3(dns.rdata.Rdata):
def __init__(self, rdclass, rdtype, algorithm, flags, iterations, salt, def __init__(self, rdclass, rdtype, algorithm, flags, iterations, salt,
next, windows): next, windows):
super().__init__(rdclass, rdtype) super().__init__(rdclass, rdtype)
object.__setattr__(self, 'algorithm', algorithm) self.algorithm = self._as_uint8(algorithm)
object.__setattr__(self, 'flags', flags) self.flags = self._as_uint8(flags)
object.__setattr__(self, 'iterations', iterations) self.iterations = self._as_uint16(iterations)
if isinstance(salt, str): self.salt = self._as_bytes(salt, True, 255)
object.__setattr__(self, 'salt', salt.encode()) self.next = self._as_bytes(next, True, 255)
else: if not isinstance(windows, Bitmap):
object.__setattr__(self, 'salt', salt) windows = Bitmap(windows)
object.__setattr__(self, 'next', next) self.windows = tuple(windows.windows)
object.__setattr__(self, 'windows', dns.rdata._constify(windows))
def to_text(self, origin=None, relativize=True, **kw): def to_text(self, origin=None, relativize=True, **kw):
next = base64.b32encode(self.next).translate( next = base64.b32encode(self.next).translate(
@ -85,9 +87,9 @@ class NSEC3(dns.rdata.Rdata):
next = tok.get_string().encode( next = tok.get_string().encode(
'ascii').upper().translate(b32_hex_to_normal) 'ascii').upper().translate(b32_hex_to_normal)
next = base64.b32decode(next) next = base64.b32decode(next)
windows = Bitmap().from_text(tok) bitmap = Bitmap.from_text(tok)
return cls(rdclass, rdtype, algorithm, flags, iterations, salt, next, return cls(rdclass, rdtype, algorithm, flags, iterations, salt, next,
windows) bitmap)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False): def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
l = len(self.salt) l = len(self.salt)
@ -104,6 +106,6 @@ class NSEC3(dns.rdata.Rdata):
(algorithm, flags, iterations) = parser.get_struct('!BBH') (algorithm, flags, iterations) = parser.get_struct('!BBH')
salt = parser.get_counted_bytes() salt = parser.get_counted_bytes()
next = parser.get_counted_bytes() next = parser.get_counted_bytes()
windows = Bitmap().from_wire_parser(parser) bitmap = Bitmap.from_wire_parser(parser)
return cls(rdclass, rdtype, algorithm, flags, iterations, salt, next, return cls(rdclass, rdtype, algorithm, flags, iterations, salt, next,
windows) bitmap)

View file

@ -19,9 +19,11 @@ import struct
import binascii import binascii
import dns.exception import dns.exception
import dns.immutable
import dns.rdata import dns.rdata
@dns.immutable.immutable
class NSEC3PARAM(dns.rdata.Rdata): class NSEC3PARAM(dns.rdata.Rdata):
"""NSEC3PARAM record""" """NSEC3PARAM record"""
@ -30,13 +32,10 @@ class NSEC3PARAM(dns.rdata.Rdata):
def __init__(self, rdclass, rdtype, algorithm, flags, iterations, salt): def __init__(self, rdclass, rdtype, algorithm, flags, iterations, salt):
super().__init__(rdclass, rdtype) super().__init__(rdclass, rdtype)
object.__setattr__(self, 'algorithm', algorithm) self.algorithm = self._as_uint8(algorithm)
object.__setattr__(self, 'flags', flags) self.flags = self._as_uint8(flags)
object.__setattr__(self, 'iterations', iterations) self.iterations = self._as_uint16(iterations)
if isinstance(salt, str): self.salt = self._as_bytes(salt, True, 255)
object.__setattr__(self, 'salt', salt.encode())
else:
object.__setattr__(self, 'salt', salt)
def to_text(self, origin=None, relativize=True, **kw): def to_text(self, origin=None, relativize=True, **kw):
if self.salt == b'': if self.salt == b'':
@ -57,7 +56,6 @@ class NSEC3PARAM(dns.rdata.Rdata):
salt = '' salt = ''
else: else:
salt = binascii.unhexlify(salt.encode()) salt = binascii.unhexlify(salt.encode())
tok.get_eol()
return cls(rdclass, rdtype, algorithm, flags, iterations, salt) return cls(rdclass, rdtype, algorithm, flags, iterations, salt)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False): def _to_wire(self, file, compress=None, origin=None, canonicalize=False):

View file

@ -18,9 +18,11 @@
import base64 import base64
import dns.exception import dns.exception
import dns.immutable
import dns.rdata import dns.rdata
import dns.tokenizer import dns.tokenizer
@dns.immutable.immutable
class OPENPGPKEY(dns.rdata.Rdata): class OPENPGPKEY(dns.rdata.Rdata):
"""OPENPGPKEY record""" """OPENPGPKEY record"""
@ -29,10 +31,10 @@ class OPENPGPKEY(dns.rdata.Rdata):
def __init__(self, rdclass, rdtype, key): def __init__(self, rdclass, rdtype, key):
super().__init__(rdclass, rdtype) super().__init__(rdclass, rdtype)
object.__setattr__(self, 'key', key) self.key = self._as_bytes(key)
def to_text(self, origin=None, relativize=True, **kw): def to_text(self, origin=None, relativize=True, **kw):
return dns.rdata._base64ify(self.key) return dns.rdata._base64ify(self.key, chunksize=None, **kw)
@classmethod @classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,

View file

@ -18,10 +18,15 @@
import struct import struct
import dns.edns import dns.edns
import dns.immutable
import dns.exception import dns.exception
import dns.rdata import dns.rdata
# We don't implement from_text, and that's ok.
# pylint: disable=abstract-method
@dns.immutable.immutable
class OPT(dns.rdata.Rdata): class OPT(dns.rdata.Rdata):
"""OPT record""" """OPT record"""
@ -40,7 +45,11 @@ class OPT(dns.rdata.Rdata):
""" """
super().__init__(rdclass, rdtype) super().__init__(rdclass, rdtype)
object.__setattr__(self, 'options', dns.rdata._constify(options)) def as_option(option):
if not isinstance(option, dns.edns.Option):
raise ValueError('option is not a dns.edns.option')
return option
self.options = self._as_tuple(options, as_option)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False): def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
for opt in self.options: for opt in self.options:

View file

@ -16,8 +16,10 @@
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import dns.rdtypes.nsbase import dns.rdtypes.nsbase
import dns.immutable
@dns.immutable.immutable
class PTR(dns.rdtypes.nsbase.NSBase): class PTR(dns.rdtypes.nsbase.NSBase):
"""PTR record""" """PTR record"""

View file

@ -16,10 +16,12 @@
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import dns.exception import dns.exception
import dns.immutable
import dns.rdata import dns.rdata
import dns.name import dns.name
@dns.immutable.immutable
class RP(dns.rdata.Rdata): class RP(dns.rdata.Rdata):
"""RP record""" """RP record"""
@ -30,8 +32,8 @@ class RP(dns.rdata.Rdata):
def __init__(self, rdclass, rdtype, mbox, txt): def __init__(self, rdclass, rdtype, mbox, txt):
super().__init__(rdclass, rdtype) super().__init__(rdclass, rdtype)
object.__setattr__(self, 'mbox', mbox) self.mbox = self._as_name(mbox)
object.__setattr__(self, 'txt', txt) self.txt = self._as_name(txt)
def to_text(self, origin=None, relativize=True, **kw): def to_text(self, origin=None, relativize=True, **kw):
mbox = self.mbox.choose_relativity(origin, relativize) mbox = self.mbox.choose_relativity(origin, relativize)
@ -43,7 +45,6 @@ class RP(dns.rdata.Rdata):
relativize_to=None): relativize_to=None):
mbox = tok.get_name(origin, relativize, relativize_to) mbox = tok.get_name(origin, relativize, relativize_to)
txt = tok.get_name(origin, relativize, relativize_to) txt = tok.get_name(origin, relativize, relativize_to)
tok.get_eol()
return cls(rdclass, rdtype, mbox, txt) return cls(rdclass, rdtype, mbox, txt)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False): def _to_wire(self, file, compress=None, origin=None, canonicalize=False):

View file

@ -21,6 +21,7 @@ import struct
import time import time
import dns.dnssec import dns.dnssec
import dns.immutable
import dns.exception import dns.exception
import dns.rdata import dns.rdata
import dns.rdatatype import dns.rdatatype
@ -50,6 +51,7 @@ def posixtime_to_sigtime(what):
return time.strftime('%Y%m%d%H%M%S', time.gmtime(what)) return time.strftime('%Y%m%d%H%M%S', time.gmtime(what))
@dns.immutable.immutable
class RRSIG(dns.rdata.Rdata): class RRSIG(dns.rdata.Rdata):
"""RRSIG record""" """RRSIG record"""
@ -62,15 +64,15 @@ class RRSIG(dns.rdata.Rdata):
original_ttl, expiration, inception, key_tag, signer, original_ttl, expiration, inception, key_tag, signer,
signature): signature):
super().__init__(rdclass, rdtype) super().__init__(rdclass, rdtype)
object.__setattr__(self, 'type_covered', type_covered) self.type_covered = self._as_rdatatype(type_covered)
object.__setattr__(self, 'algorithm', algorithm) self.algorithm = dns.dnssec.Algorithm.make(algorithm)
object.__setattr__(self, 'labels', labels) self.labels = self._as_uint8(labels)
object.__setattr__(self, 'original_ttl', original_ttl) self.original_ttl = self._as_ttl(original_ttl)
object.__setattr__(self, 'expiration', expiration) self.expiration = self._as_uint32(expiration)
object.__setattr__(self, 'inception', inception) self.inception = self._as_uint32(inception)
object.__setattr__(self, 'key_tag', key_tag) self.key_tag = self._as_uint16(key_tag)
object.__setattr__(self, 'signer', signer) self.signer = self._as_name(signer)
object.__setattr__(self, 'signature', signature) self.signature = self._as_bytes(signature)
def covers(self): def covers(self):
return self.type_covered return self.type_covered
@ -85,7 +87,7 @@ class RRSIG(dns.rdata.Rdata):
posixtime_to_sigtime(self.inception), posixtime_to_sigtime(self.inception),
self.key_tag, self.key_tag,
self.signer.choose_relativity(origin, relativize), self.signer.choose_relativity(origin, relativize),
dns.rdata._base64ify(self.signature) dns.rdata._base64ify(self.signature, **kw)
) )
@classmethod @classmethod

View file

@ -16,8 +16,10 @@
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import dns.rdtypes.mxbase import dns.rdtypes.mxbase
import dns.immutable
@dns.immutable.immutable
class RT(dns.rdtypes.mxbase.UncompressedDowncasingMX): class RT(dns.rdtypes.mxbase.UncompressedDowncasingMX):
"""RT record""" """RT record"""

View file

@ -0,0 +1,9 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
import dns.immutable
import dns.rdtypes.tlsabase
@dns.immutable.immutable
class SMIMEA(dns.rdtypes.tlsabase.TLSABase):
"""SMIMEA record"""

View file

@ -18,10 +18,12 @@
import struct import struct
import dns.exception import dns.exception
import dns.immutable
import dns.rdata import dns.rdata
import dns.name import dns.name
@dns.immutable.immutable
class SOA(dns.rdata.Rdata): class SOA(dns.rdata.Rdata):
"""SOA record""" """SOA record"""
@ -34,13 +36,13 @@ class SOA(dns.rdata.Rdata):
def __init__(self, rdclass, rdtype, mname, rname, serial, refresh, retry, def __init__(self, rdclass, rdtype, mname, rname, serial, refresh, retry,
expire, minimum): expire, minimum):
super().__init__(rdclass, rdtype) super().__init__(rdclass, rdtype)
object.__setattr__(self, 'mname', mname) self.mname = self._as_name(mname)
object.__setattr__(self, 'rname', rname) self.rname = self._as_name(rname)
object.__setattr__(self, 'serial', serial) self.serial = self._as_uint32(serial)
object.__setattr__(self, 'refresh', refresh) self.refresh = self._as_ttl(refresh)
object.__setattr__(self, 'retry', retry) self.retry = self._as_ttl(retry)
object.__setattr__(self, 'expire', expire) self.expire = self._as_ttl(expire)
object.__setattr__(self, 'minimum', minimum) self.minimum = self._as_ttl(minimum)
def to_text(self, origin=None, relativize=True, **kw): def to_text(self, origin=None, relativize=True, **kw):
mname = self.mname.choose_relativity(origin, relativize) mname = self.mname.choose_relativity(origin, relativize)
@ -59,7 +61,6 @@ class SOA(dns.rdata.Rdata):
retry = tok.get_ttl() retry = tok.get_ttl()
expire = tok.get_ttl() expire = tok.get_ttl()
minimum = tok.get_ttl() minimum = tok.get_ttl()
tok.get_eol()
return cls(rdclass, rdtype, mname, rname, serial, refresh, retry, return cls(rdclass, rdtype, mname, rname, serial, refresh, retry,
expire, minimum) expire, minimum)

View file

@ -16,8 +16,10 @@
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import dns.rdtypes.txtbase import dns.rdtypes.txtbase
import dns.immutable
@dns.immutable.immutable
class SPF(dns.rdtypes.txtbase.TXTBase): class SPF(dns.rdtypes.txtbase.TXTBase):
"""SPF record""" """SPF record"""

View file

@ -19,9 +19,11 @@ import struct
import binascii import binascii
import dns.rdata import dns.rdata
import dns.immutable
import dns.rdatatype import dns.rdatatype
@dns.immutable.immutable
class SSHFP(dns.rdata.Rdata): class SSHFP(dns.rdata.Rdata):
"""SSHFP record""" """SSHFP record"""
@ -33,15 +35,18 @@ class SSHFP(dns.rdata.Rdata):
def __init__(self, rdclass, rdtype, algorithm, fp_type, def __init__(self, rdclass, rdtype, algorithm, fp_type,
fingerprint): fingerprint):
super().__init__(rdclass, rdtype) super().__init__(rdclass, rdtype)
object.__setattr__(self, 'algorithm', algorithm) self.algorithm = self._as_uint8(algorithm)
object.__setattr__(self, 'fp_type', fp_type) self.fp_type = self._as_uint8(fp_type)
object.__setattr__(self, 'fingerprint', fingerprint) self.fingerprint = self._as_bytes(fingerprint, True)
def to_text(self, origin=None, relativize=True, **kw): def to_text(self, origin=None, relativize=True, **kw):
kw = kw.copy()
chunksize = kw.pop('chunksize', 128)
return '%d %d %s' % (self.algorithm, return '%d %d %s' % (self.algorithm,
self.fp_type, self.fp_type,
dns.rdata._hexify(self.fingerprint, dns.rdata._hexify(self.fingerprint,
chunksize=128)) chunksize=chunksize,
**kw))
@classmethod @classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,

118
lib/dns/rdtypes/ANY/TKEY.py Normal file
View file

@ -0,0 +1,118 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# Copyright (C) 2004-2007, 2009-2011 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import base64
import struct
import dns.dnssec
import dns.immutable
import dns.exception
import dns.rdata
@dns.immutable.immutable
class TKEY(dns.rdata.Rdata):
"""TKEY Record"""
__slots__ = ['algorithm', 'inception', 'expiration', 'mode', 'error',
'key', 'other']
def __init__(self, rdclass, rdtype, algorithm, inception, expiration,
mode, error, key, other=b''):
super().__init__(rdclass, rdtype)
self.algorithm = self._as_name(algorithm)
self.inception = self._as_uint32(inception)
self.expiration = self._as_uint32(expiration)
self.mode = self._as_uint16(mode)
self.error = self._as_uint16(error)
self.key = self._as_bytes(key)
self.other = self._as_bytes(other)
def to_text(self, origin=None, relativize=True, **kw):
_algorithm = self.algorithm.choose_relativity(origin, relativize)
text = '%s %u %u %u %u %s' % (str(_algorithm), self.inception,
self.expiration, self.mode, self.error,
dns.rdata._base64ify(self.key, 0))
if len(self.other) > 0:
text += ' %s' % (dns.rdata._base64ify(self.other, 0))
return text
@classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
relativize_to=None):
algorithm = tok.get_name(relativize=False)
inception = tok.get_uint32()
expiration = tok.get_uint32()
mode = tok.get_uint16()
error = tok.get_uint16()
key_b64 = tok.get_string().encode()
key = base64.b64decode(key_b64)
other_b64 = tok.concatenate_remaining_identifiers().encode()
other = base64.b64decode(other_b64)
return cls(rdclass, rdtype, algorithm, inception, expiration, mode,
error, key, other)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
self.algorithm.to_wire(file, compress, origin)
file.write(struct.pack("!IIHH", self.inception, self.expiration,
self.mode, self.error))
file.write(struct.pack("!H", len(self.key)))
file.write(self.key)
file.write(struct.pack("!H", len(self.other)))
if len(self.other) > 0:
file.write(self.other)
@classmethod
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
algorithm = parser.get_name(origin)
inception, expiration, mode, error = parser.get_struct("!IIHH")
key = parser.get_counted_bytes(2)
other = parser.get_counted_bytes(2)
return cls(rdclass, rdtype, algorithm, inception, expiration, mode,
error, key, other)
# Constants for the mode field - from RFC 2930:
# 2.5 The Mode Field
#
# The mode field specifies the general scheme for key agreement or
# the purpose of the TKEY DNS message. Servers and resolvers
# supporting this specification MUST implement the Diffie-Hellman key
# agreement mode and the key deletion mode for queries. All other
# modes are OPTIONAL. A server supporting TKEY that receives a TKEY
# request with a mode it does not support returns the BADMODE error.
# The following values of the Mode octet are defined, available, or
# reserved:
#
# Value Description
# ----- -----------
# 0 - reserved, see section 7
# 1 server assignment
# 2 Diffie-Hellman exchange
# 3 GSS-API negotiation
# 4 resolver assignment
# 5 key deletion
# 6-65534 - available, see section 7
# 65535 - reserved, see section 7
SERVER_ASSIGNMENT = 1
DIFFIE_HELLMAN_EXCHANGE = 2
GSSAPI_NEGOTIATION = 3
RESOLVER_ASSIGNMENT = 4
KEY_DELETION = 5

View file

@ -1,67 +1,10 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# Copyright (C) 2005-2007, 2009-2011 Nominum, Inc. import dns.immutable
# import dns.rdtypes.tlsabase
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import struct
import binascii
import dns.rdata
import dns.rdatatype
class TLSA(dns.rdata.Rdata): @dns.immutable.immutable
class TLSA(dns.rdtypes.tlsabase.TLSABase):
"""TLSA record""" """TLSA record"""
# see: RFC 6698
__slots__ = ['usage', 'selector', 'mtype', 'cert']
def __init__(self, rdclass, rdtype, usage, selector,
mtype, cert):
super().__init__(rdclass, rdtype)
object.__setattr__(self, 'usage', usage)
object.__setattr__(self, 'selector', selector)
object.__setattr__(self, 'mtype', mtype)
object.__setattr__(self, 'cert', cert)
def to_text(self, origin=None, relativize=True, **kw):
return '%d %d %d %s' % (self.usage,
self.selector,
self.mtype,
dns.rdata._hexify(self.cert,
chunksize=128))
@classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
relativize_to=None):
usage = tok.get_uint8()
selector = tok.get_uint8()
mtype = tok.get_uint8()
cert = tok.concatenate_remaining_identifiers().encode()
cert = binascii.unhexlify(cert)
return cls(rdclass, rdtype, usage, selector, mtype, cert)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
header = struct.pack("!BBB", self.usage, self.selector, self.mtype)
file.write(header)
file.write(self.cert)
@classmethod
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
header = parser.get_struct("BBB")
cert = parser.get_remaining()
return cls(rdclass, rdtype, header[0], header[1], header[2], cert)

View file

@ -15,12 +15,16 @@
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import base64
import struct import struct
import dns.exception import dns.exception
import dns.immutable
import dns.rcode
import dns.rdata import dns.rdata
@dns.immutable.immutable
class TSIG(dns.rdata.Rdata): class TSIG(dns.rdata.Rdata):
"""TSIG record""" """TSIG record"""
@ -52,20 +56,45 @@ class TSIG(dns.rdata.Rdata):
""" """
super().__init__(rdclass, rdtype) super().__init__(rdclass, rdtype)
object.__setattr__(self, 'algorithm', algorithm) self.algorithm = self._as_name(algorithm)
object.__setattr__(self, 'time_signed', time_signed) self.time_signed = self._as_uint48(time_signed)
object.__setattr__(self, 'fudge', fudge) self.fudge = self._as_uint16(fudge)
object.__setattr__(self, 'mac', dns.rdata._constify(mac)) self.mac = self._as_bytes(mac)
object.__setattr__(self, 'original_id', original_id) self.original_id = self._as_uint16(original_id)
object.__setattr__(self, 'error', error) self.error = dns.rcode.Rcode.make(error)
object.__setattr__(self, 'other', dns.rdata._constify(other)) self.other = self._as_bytes(other)
def to_text(self, origin=None, relativize=True, **kw): def to_text(self, origin=None, relativize=True, **kw):
algorithm = self.algorithm.choose_relativity(origin, relativize) algorithm = self.algorithm.choose_relativity(origin, relativize)
return f"{algorithm} {self.fudge} {self.time_signed} " + \ error = dns.rcode.to_text(self.error, True)
text = f"{algorithm} {self.time_signed} {self.fudge} " + \
f"{len(self.mac)} {dns.rdata._base64ify(self.mac, 0)} " + \ f"{len(self.mac)} {dns.rdata._base64ify(self.mac, 0)} " + \
f"{self.original_id} {self.error} " + \ f"{self.original_id} {error} {len(self.other)}"
f"{len(self.other)} {dns.rdata._base64ify(self.other, 0)}" if self.other:
text += f" {dns.rdata._base64ify(self.other, 0)}"
return text
@classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
relativize_to=None):
algorithm = tok.get_name(relativize=False)
time_signed = tok.get_uint48()
fudge = tok.get_uint16()
mac_len = tok.get_uint16()
mac = base64.b64decode(tok.get_string())
if len(mac) != mac_len:
raise SyntaxError('invalid MAC')
original_id = tok.get_uint16()
error = dns.rcode.from_text(tok.get_string())
other_len = tok.get_uint16()
if other_len > 0:
other = base64.b64decode(tok.get_string())
if len(other) != other_len:
raise SyntaxError('invalid other data')
else:
other = b''
return cls(rdclass, rdtype, algorithm, time_signed, fudge, mac,
original_id, error, other)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False): def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
self.algorithm.to_wire(file, None, origin, False) self.algorithm.to_wire(file, None, origin, False)
@ -81,9 +110,9 @@ class TSIG(dns.rdata.Rdata):
@classmethod @classmethod
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
algorithm = parser.get_name(origin) algorithm = parser.get_name()
(time_hi, time_lo, fudge) = parser.get_struct('!HIH') time_signed = parser.get_uint48()
time_signed = (time_hi << 32) + time_lo fudge = parser.get_uint16()
mac = parser.get_counted_bytes(2) mac = parser.get_counted_bytes(2)
(original_id, error) = parser.get_struct('!HH') (original_id, error) = parser.get_struct('!HH')
other = parser.get_counted_bytes(2) other = parser.get_counted_bytes(2)

View file

@ -16,8 +16,10 @@
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import dns.rdtypes.txtbase import dns.rdtypes.txtbase
import dns.immutable
@dns.immutable.immutable
class TXT(dns.rdtypes.txtbase.TXTBase): class TXT(dns.rdtypes.txtbase.TXTBase):
"""TXT record""" """TXT record"""

View file

@ -19,10 +19,13 @@
import struct import struct
import dns.exception import dns.exception
import dns.immutable
import dns.rdata import dns.rdata
import dns.rdtypes.util
import dns.name import dns.name
@dns.immutable.immutable
class URI(dns.rdata.Rdata): class URI(dns.rdata.Rdata):
"""URI record""" """URI record"""
@ -33,14 +36,11 @@ class URI(dns.rdata.Rdata):
def __init__(self, rdclass, rdtype, priority, weight, target): def __init__(self, rdclass, rdtype, priority, weight, target):
super().__init__(rdclass, rdtype) super().__init__(rdclass, rdtype)
object.__setattr__(self, 'priority', priority) self.priority = self._as_uint16(priority)
object.__setattr__(self, 'weight', weight) self.weight = self._as_uint16(weight)
if len(target) < 1: self.target = self._as_bytes(target, True)
if len(self.target) == 0:
raise dns.exception.SyntaxError("URI target cannot be empty") raise dns.exception.SyntaxError("URI target cannot be empty")
if isinstance(target, str):
object.__setattr__(self, 'target', target.encode())
else:
object.__setattr__(self, 'target', target)
def to_text(self, origin=None, relativize=True, **kw): def to_text(self, origin=None, relativize=True, **kw):
return '%d %d "%s"' % (self.priority, self.weight, return '%d %d "%s"' % (self.priority, self.weight,
@ -54,7 +54,6 @@ class URI(dns.rdata.Rdata):
target = tok.get().unescape() target = tok.get().unescape()
if not (target.is_quoted_string() or target.is_identifier()): if not (target.is_quoted_string() or target.is_identifier()):
raise dns.exception.SyntaxError("URI target must be a string") raise dns.exception.SyntaxError("URI target must be a string")
tok.get_eol()
return cls(rdclass, rdtype, priority, weight, target.value) return cls(rdclass, rdtype, priority, weight, target.value)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False): def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
@ -69,3 +68,13 @@ class URI(dns.rdata.Rdata):
if len(target) == 0: if len(target) == 0:
raise dns.exception.FormError('URI target may not be empty') raise dns.exception.FormError('URI target may not be empty')
return cls(rdclass, rdtype, priority, weight, target) return cls(rdclass, rdtype, priority, weight, target)
def _processing_priority(self):
return self.priority
def _processing_weight(self):
return self.weight
@classmethod
def _processing_order(cls, iterable):
return dns.rdtypes.util.weighted_processing_order(iterable)

View file

@ -18,10 +18,12 @@
import struct import struct
import dns.exception import dns.exception
import dns.immutable
import dns.rdata import dns.rdata
import dns.tokenizer import dns.tokenizer
@dns.immutable.immutable
class X25(dns.rdata.Rdata): class X25(dns.rdata.Rdata):
"""X25 record""" """X25 record"""
@ -32,10 +34,7 @@ class X25(dns.rdata.Rdata):
def __init__(self, rdclass, rdtype, address): def __init__(self, rdclass, rdtype, address):
super().__init__(rdclass, rdtype) super().__init__(rdclass, rdtype)
if isinstance(address, str): self.address = self._as_bytes(address, True, 255)
object.__setattr__(self, 'address', address.encode())
else:
object.__setattr__(self, 'address', address)
def to_text(self, origin=None, relativize=True, **kw): def to_text(self, origin=None, relativize=True, **kw):
return '"%s"' % dns.rdata._escapify(self.address) return '"%s"' % dns.rdata._escapify(self.address)
@ -44,7 +43,6 @@ class X25(dns.rdata.Rdata):
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
relativize_to=None): relativize_to=None):
address = tok.get_string() address = tok.get_string()
tok.get_eol()
return cls(rdclass, rdtype, address) return cls(rdclass, rdtype, address)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False): def _to_wire(self, file, compress=None, origin=None, canonicalize=False):

View file

@ -0,0 +1,65 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
import struct
import binascii
import dns.immutable
import dns.rdata
import dns.rdatatype
import dns.zone
@dns.immutable.immutable
class ZONEMD(dns.rdata.Rdata):
"""ZONEMD record"""
# See RFC 8976
__slots__ = ['serial', 'scheme', 'hash_algorithm', 'digest']
def __init__(self, rdclass, rdtype, serial, scheme, hash_algorithm, digest):
super().__init__(rdclass, rdtype)
self.serial = self._as_uint32(serial)
self.scheme = dns.zone.DigestScheme.make(scheme)
self.hash_algorithm = dns.zone.DigestHashAlgorithm.make(hash_algorithm)
self.digest = self._as_bytes(digest)
if self.scheme == 0: # reserved, RFC 8976 Sec. 5.2
raise ValueError('scheme 0 is reserved')
if self.hash_algorithm == 0: # reserved, RFC 8976 Sec. 5.3
raise ValueError('hash_algorithm 0 is reserved')
hasher = dns.zone._digest_hashers.get(self.hash_algorithm)
if hasher and hasher().digest_size != len(self.digest):
raise ValueError('digest length inconsistent with hash algorithm')
def to_text(self, origin=None, relativize=True, **kw):
kw = kw.copy()
chunksize = kw.pop('chunksize', 128)
return '%d %d %d %s' % (self.serial, self.scheme, self.hash_algorithm,
dns.rdata._hexify(self.digest,
chunksize=chunksize,
**kw))
@classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
relativize_to=None):
serial = tok.get_uint32()
scheme = tok.get_uint8()
hash_algorithm = tok.get_uint8()
digest = tok.concatenate_remaining_identifiers().encode()
digest = binascii.unhexlify(digest)
return cls(rdclass, rdtype, serial, scheme, hash_algorithm, digest)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
header = struct.pack("!IBB", self.serial, self.scheme,
self.hash_algorithm)
file.write(header)
file.write(self.digest)
@classmethod
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
header = parser.get_struct("!IBB")
digest = parser.get_remaining()
return cls(rdclass, rdtype, header[0], header[1], header[2], digest)

View file

@ -19,6 +19,7 @@
__all__ = [ __all__ = [
'AFSDB', 'AFSDB',
'AMTRELAY',
'AVC', 'AVC',
'CAA', 'CAA',
'CDNSKEY', 'CDNSKEY',
@ -38,6 +39,7 @@ __all__ = [
'ISDN', 'ISDN',
'LOC', 'LOC',
'MX', 'MX',
'NINFO',
'NS', 'NS',
'NSEC', 'NSEC',
'NSEC3', 'NSEC3',
@ -48,12 +50,15 @@ __all__ = [
'RP', 'RP',
'RRSIG', 'RRSIG',
'RT', 'RT',
'SMIMEA',
'SOA', 'SOA',
'SPF', 'SPF',
'SSHFP', 'SSHFP',
'TKEY',
'TLSA', 'TLSA',
'TSIG', 'TSIG',
'TXT', 'TXT',
'URI', 'URI',
'X25', 'X25',
'ZONEMD',
] ]

View file

@ -15,9 +15,12 @@
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import dns.rdtypes.mxbase
import struct import struct
import dns.rdtypes.mxbase
import dns.immutable
@dns.immutable.immutable
class A(dns.rdata.Rdata): class A(dns.rdata.Rdata):
"""A record for Chaosnet""" """A record for Chaosnet"""
@ -29,8 +32,8 @@ class A(dns.rdata.Rdata):
def __init__(self, rdclass, rdtype, domain, address): def __init__(self, rdclass, rdtype, domain, address):
super().__init__(rdclass, rdtype) super().__init__(rdclass, rdtype)
object.__setattr__(self, 'domain', domain) self.domain = self._as_name(domain)
object.__setattr__(self, 'address', address) self.address = self._as_uint16(address)
def to_text(self, origin=None, relativize=True, **kw): def to_text(self, origin=None, relativize=True, **kw):
domain = self.domain.choose_relativity(origin, relativize) domain = self.domain.choose_relativity(origin, relativize)
@ -41,7 +44,6 @@ class A(dns.rdata.Rdata):
relativize_to=None): relativize_to=None):
domain = tok.get_name(origin, relativize, relativize_to) domain = tok.get_name(origin, relativize, relativize_to)
address = tok.get_uint16(base=8) address = tok.get_uint16(base=8)
tok.get_eol()
return cls(rdclass, rdtype, domain, address) return cls(rdclass, rdtype, domain, address)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False): def _to_wire(self, file, compress=None, origin=None, canonicalize=False):

View file

@ -16,11 +16,13 @@
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import dns.exception import dns.exception
import dns.immutable
import dns.ipv4 import dns.ipv4
import dns.rdata import dns.rdata
import dns.tokenizer import dns.tokenizer
@dns.immutable.immutable
class A(dns.rdata.Rdata): class A(dns.rdata.Rdata):
"""A record.""" """A record."""
@ -29,9 +31,7 @@ class A(dns.rdata.Rdata):
def __init__(self, rdclass, rdtype, address): def __init__(self, rdclass, rdtype, address):
super().__init__(rdclass, rdtype) super().__init__(rdclass, rdtype)
# check that it's OK self.address = self._as_ipv4_address(address)
dns.ipv4.inet_aton(address)
object.__setattr__(self, 'address', address)
def to_text(self, origin=None, relativize=True, **kw): def to_text(self, origin=None, relativize=True, **kw):
return self.address return self.address
@ -40,7 +40,6 @@ class A(dns.rdata.Rdata):
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
relativize_to=None): relativize_to=None):
address = tok.get_identifier() address = tok.get_identifier()
tok.get_eol()
return cls(rdclass, rdtype, address) return cls(rdclass, rdtype, address)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False): def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
@ -48,5 +47,5 @@ class A(dns.rdata.Rdata):
@classmethod @classmethod
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
address = dns.ipv4.inet_ntoa(parser.get_remaining()) address = parser.get_remaining()
return cls(rdclass, rdtype, address) return cls(rdclass, rdtype, address)

View file

@ -16,11 +16,13 @@
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import dns.exception import dns.exception
import dns.immutable
import dns.ipv6 import dns.ipv6
import dns.rdata import dns.rdata
import dns.tokenizer import dns.tokenizer
@dns.immutable.immutable
class AAAA(dns.rdata.Rdata): class AAAA(dns.rdata.Rdata):
"""AAAA record.""" """AAAA record."""
@ -29,9 +31,7 @@ class AAAA(dns.rdata.Rdata):
def __init__(self, rdclass, rdtype, address): def __init__(self, rdclass, rdtype, address):
super().__init__(rdclass, rdtype) super().__init__(rdclass, rdtype)
# check that it's OK self.address = self._as_ipv6_address(address)
dns.ipv6.inet_aton(address)
object.__setattr__(self, 'address', address)
def to_text(self, origin=None, relativize=True, **kw): def to_text(self, origin=None, relativize=True, **kw):
return self.address return self.address
@ -40,7 +40,6 @@ class AAAA(dns.rdata.Rdata):
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
relativize_to=None): relativize_to=None):
address = tok.get_identifier() address = tok.get_identifier()
tok.get_eol()
return cls(rdclass, rdtype, address) return cls(rdclass, rdtype, address)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False): def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
@ -48,5 +47,5 @@ class AAAA(dns.rdata.Rdata):
@classmethod @classmethod
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
address = dns.ipv6.inet_ntoa(parser.get_remaining()) address = parser.get_remaining()
return cls(rdclass, rdtype, address) return cls(rdclass, rdtype, address)

View file

@ -20,11 +20,13 @@ import codecs
import struct import struct
import dns.exception import dns.exception
import dns.immutable
import dns.ipv4 import dns.ipv4
import dns.ipv6 import dns.ipv6
import dns.rdata import dns.rdata
import dns.tokenizer import dns.tokenizer
@dns.immutable.immutable
class APLItem: class APLItem:
"""An APL list item.""" """An APL list item."""
@ -32,10 +34,17 @@ class APLItem:
__slots__ = ['family', 'negation', 'address', 'prefix'] __slots__ = ['family', 'negation', 'address', 'prefix']
def __init__(self, family, negation, address, prefix): def __init__(self, family, negation, address, prefix):
self.family = family self.family = dns.rdata.Rdata._as_uint16(family)
self.negation = negation self.negation = dns.rdata.Rdata._as_bool(negation)
self.address = address if self.family == 1:
self.prefix = prefix self.address = dns.rdata.Rdata._as_ipv4_address(address)
self.prefix = dns.rdata.Rdata._as_int(prefix, 0, 32)
elif self.family == 2:
self.address = dns.rdata.Rdata._as_ipv6_address(address)
self.prefix = dns.rdata.Rdata._as_int(prefix, 0, 128)
else:
self.address = dns.rdata.Rdata._as_bytes(address, max_length=127)
self.prefix = dns.rdata.Rdata._as_uint8(prefix)
def __str__(self): def __str__(self):
if self.negation: if self.negation:
@ -68,6 +77,7 @@ class APLItem:
file.write(address) file.write(address)
@dns.immutable.immutable
class APL(dns.rdata.Rdata): class APL(dns.rdata.Rdata):
"""APL record.""" """APL record."""
@ -78,7 +88,10 @@ class APL(dns.rdata.Rdata):
def __init__(self, rdclass, rdtype, items): def __init__(self, rdclass, rdtype, items):
super().__init__(rdclass, rdtype) super().__init__(rdclass, rdtype)
object.__setattr__(self, 'items', dns.rdata._constify(items)) for item in items:
if not isinstance(item, APLItem):
raise ValueError('item not an APLItem')
self.items = tuple(items)
def to_text(self, origin=None, relativize=True, **kw): def to_text(self, origin=None, relativize=True, **kw):
return ' '.join(map(str, self.items)) return ' '.join(map(str, self.items))
@ -87,11 +100,8 @@ class APL(dns.rdata.Rdata):
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
relativize_to=None): relativize_to=None):
items = [] items = []
while True: for token in tok.get_remaining():
token = tok.get().unescape() item = token.unescape().value
if token.is_eol_or_eof():
break
item = token.value
if item[0] == '!': if item[0] == '!':
negation = True negation = True
item = item[1:] item = item[1:]
@ -127,11 +137,9 @@ class APL(dns.rdata.Rdata):
if header[0] == 1: if header[0] == 1:
if l < 4: if l < 4:
address += b'\x00' * (4 - l) address += b'\x00' * (4 - l)
address = dns.ipv4.inet_ntoa(address)
elif header[0] == 2: elif header[0] == 2:
if l < 16: if l < 16:
address += b'\x00' * (16 - l) address += b'\x00' * (16 - l)
address = dns.ipv6.inet_ntoa(address)
else: else:
# #
# This isn't really right according to the RFC, but it # This isn't really right according to the RFC, but it

View file

@ -18,8 +18,10 @@
import base64 import base64
import dns.exception import dns.exception
import dns.immutable
@dns.immutable.immutable
class DHCID(dns.rdata.Rdata): class DHCID(dns.rdata.Rdata):
"""DHCID record""" """DHCID record"""
@ -30,10 +32,10 @@ class DHCID(dns.rdata.Rdata):
def __init__(self, rdclass, rdtype, data): def __init__(self, rdclass, rdtype, data):
super().__init__(rdclass, rdtype) super().__init__(rdclass, rdtype)
object.__setattr__(self, 'data', data) self.data = self._as_bytes(data)
def to_text(self, origin=None, relativize=True, **kw): def to_text(self, origin=None, relativize=True, **kw):
return dns.rdata._base64ify(self.data) return dns.rdata._base64ify(self.data, **kw)
@classmethod @classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,

View file

@ -0,0 +1,8 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
import dns.rdtypes.svcbbase
import dns.immutable
@dns.immutable.immutable
class HTTPS(dns.rdtypes.svcbbase.SVCBBase):
"""HTTPS record"""

View file

@ -19,12 +19,14 @@ import struct
import base64 import base64
import dns.exception import dns.exception
import dns.immutable
import dns.rdtypes.util import dns.rdtypes.util
class Gateway(dns.rdtypes.util.Gateway): class Gateway(dns.rdtypes.util.Gateway):
name = 'IPSECKEY gateway' name = 'IPSECKEY gateway'
@dns.immutable.immutable
class IPSECKEY(dns.rdata.Rdata): class IPSECKEY(dns.rdata.Rdata):
"""IPSECKEY record""" """IPSECKEY record"""
@ -36,19 +38,19 @@ class IPSECKEY(dns.rdata.Rdata):
def __init__(self, rdclass, rdtype, precedence, gateway_type, algorithm, def __init__(self, rdclass, rdtype, precedence, gateway_type, algorithm,
gateway, key): gateway, key):
super().__init__(rdclass, rdtype) super().__init__(rdclass, rdtype)
Gateway(gateway_type, gateway).check() gateway = Gateway(gateway_type, gateway)
object.__setattr__(self, 'precedence', precedence) self.precedence = self._as_uint8(precedence)
object.__setattr__(self, 'gateway_type', gateway_type) self.gateway_type = gateway.type
object.__setattr__(self, 'algorithm', algorithm) self.algorithm = self._as_uint8(algorithm)
object.__setattr__(self, 'gateway', gateway) self.gateway = gateway.gateway
object.__setattr__(self, 'key', key) self.key = self._as_bytes(key)
def to_text(self, origin=None, relativize=True, **kw): def to_text(self, origin=None, relativize=True, **kw):
gateway = Gateway(self.gateway_type, self.gateway).to_text(origin, gateway = Gateway(self.gateway_type, self.gateway).to_text(origin,
relativize) relativize)
return '%d %d %d %s %s' % (self.precedence, self.gateway_type, return '%d %d %d %s %s' % (self.precedence, self.gateway_type,
self.algorithm, gateway, self.algorithm, gateway,
dns.rdata._base64ify(self.key)) dns.rdata._base64ify(self.key, **kw))
@classmethod @classmethod
def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True,
@ -56,12 +58,12 @@ class IPSECKEY(dns.rdata.Rdata):
precedence = tok.get_uint8() precedence = tok.get_uint8()
gateway_type = tok.get_uint8() gateway_type = tok.get_uint8()
algorithm = tok.get_uint8() algorithm = tok.get_uint8()
gateway = Gateway(gateway_type).from_text(tok, origin, relativize, gateway = Gateway.from_text(gateway_type, tok, origin, relativize,
relativize_to) relativize_to)
b64 = tok.concatenate_remaining_identifiers().encode() b64 = tok.concatenate_remaining_identifiers().encode()
key = base64.b64decode(b64) key = base64.b64decode(b64)
return cls(rdclass, rdtype, precedence, gateway_type, algorithm, return cls(rdclass, rdtype, precedence, gateway_type, algorithm,
gateway, key) gateway.gateway, key)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False): def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
header = struct.pack("!BBB", self.precedence, self.gateway_type, header = struct.pack("!BBB", self.precedence, self.gateway_type,
@ -75,7 +77,7 @@ class IPSECKEY(dns.rdata.Rdata):
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
header = parser.get_struct('!BBB') header = parser.get_struct('!BBB')
gateway_type = header[1] gateway_type = header[1]
gateway = Gateway(gateway_type).from_wire_parser(parser, origin) gateway = Gateway.from_wire_parser(gateway_type, parser, origin)
key = parser.get_remaining() key = parser.get_remaining()
return cls(rdclass, rdtype, header[0], gateway_type, header[2], return cls(rdclass, rdtype, header[0], gateway_type, header[2],
gateway, key) gateway.gateway, key)

View file

@ -16,8 +16,10 @@
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import dns.rdtypes.mxbase import dns.rdtypes.mxbase
import dns.immutable
@dns.immutable.immutable
class KX(dns.rdtypes.mxbase.UncompressedDowncasingMX): class KX(dns.rdtypes.mxbase.UncompressedDowncasingMX):
"""KX record""" """KX record"""

Some files were not shown because too many files have changed in this diff Show more