From c0aa4e49968b283d54a44669314a4ecb17fa697a Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 24 Aug 2023 12:05:11 -0700 Subject: [PATCH 01/10] Bump dnspython from 2.3.0 to 2.4.2 (#2123) * Bump dnspython from 2.3.0 to 2.4.2 Bumps [dnspython](https://github.com/rthalley/dnspython) from 2.3.0 to 2.4.2. - [Release notes](https://github.com/rthalley/dnspython/releases) - [Changelog](https://github.com/rthalley/dnspython/blob/master/doc/whatsnew.rst) - [Commits](https://github.com/rthalley/dnspython/compare/v2.3.0...v2.4.2) --- updated-dependencies: - dependency-name: dnspython dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] * Update dnspython==2.4.2 --------- Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: JonnyWong16 <9099342+JonnyWong16@users.noreply.github.com> [skip ci] --- lib/dns/__init__.py | 1 + lib/dns/_asyncbackend.py | 14 + lib/dns/_asyncio_backend.py | 118 ++++- lib/dns/_curio_backend.py | 122 ------ lib/dns/_ddr.py | 154 +++++++ lib/dns/_immutable_ctx.py | 1 - lib/dns/_trio_backend.py | 116 ++++- lib/dns/asyncbackend.py | 23 +- lib/dns/asyncquery.py | 77 ++-- lib/dns/asyncresolver.py | 245 +++++++++-- lib/dns/dnssec.py | 671 +++++++++++++---------------- lib/dns/dnssecalgs/__init__.py | 121 ++++++ lib/dns/dnssecalgs/base.py | 84 ++++ lib/dns/dnssecalgs/cryptography.py | 68 +++ lib/dns/dnssecalgs/dsa.py | 101 +++++ lib/dns/dnssecalgs/ecdsa.py | 89 ++++ lib/dns/dnssecalgs/eddsa.py | 65 +++ lib/dns/dnssecalgs/rsa.py | 119 +++++ lib/dns/edns.py | 11 +- lib/dns/entropy.py | 6 +- lib/dns/enum.py | 34 +- lib/dns/exception.py | 16 + lib/dns/flags.py | 3 +- lib/dns/immutable.py | 3 +- lib/dns/inet.py | 13 +- lib/dns/ipv4.py | 3 +- lib/dns/ipv6.py | 5 +- lib/dns/message.py | 102 +++-- lib/dns/name.py | 9 +- lib/dns/nameserver.py | 329 ++++++++++++++ 
lib/dns/node.py | 6 +- lib/dns/query.py | 260 ++++++----- lib/dns/quic/__init__.py | 7 +- lib/dns/quic/_asyncio.py | 45 +- lib/dns/quic/_common.py | 10 +- lib/dns/quic/_sync.py | 68 +-- lib/dns/quic/_trio.py | 69 +-- lib/dns/rdata.py | 27 +- lib/dns/rdataset.py | 13 +- lib/dns/rdtypes/ANY/AFSDB.py | 2 +- lib/dns/rdtypes/ANY/AVC.py | 2 +- lib/dns/rdtypes/ANY/CDNSKEY.py | 8 +- lib/dns/rdtypes/ANY/CDS.py | 2 +- lib/dns/rdtypes/ANY/CERT.py | 4 +- lib/dns/rdtypes/ANY/CNAME.py | 2 +- lib/dns/rdtypes/ANY/CSYNC.py | 2 +- lib/dns/rdtypes/ANY/DLV.py | 2 +- lib/dns/rdtypes/ANY/DNAME.py | 2 +- lib/dns/rdtypes/ANY/DNSKEY.py | 8 +- lib/dns/rdtypes/ANY/DS.py | 2 +- lib/dns/rdtypes/ANY/EUI48.py | 2 +- lib/dns/rdtypes/ANY/EUI64.py | 2 +- lib/dns/rdtypes/ANY/HIP.py | 2 +- lib/dns/rdtypes/ANY/LOC.py | 3 +- lib/dns/rdtypes/ANY/MX.py | 2 +- lib/dns/rdtypes/ANY/NINFO.py | 2 +- lib/dns/rdtypes/ANY/NS.py | 2 +- lib/dns/rdtypes/ANY/NSEC.py | 2 +- lib/dns/rdtypes/ANY/NSEC3.py | 6 +- lib/dns/rdtypes/ANY/NSEC3PARAM.py | 2 +- lib/dns/rdtypes/ANY/OPT.py | 3 +- lib/dns/rdtypes/ANY/PTR.py | 2 +- lib/dns/rdtypes/ANY/RP.py | 2 +- lib/dns/rdtypes/ANY/RRSIG.py | 2 +- lib/dns/rdtypes/ANY/RT.py | 2 +- lib/dns/rdtypes/ANY/SOA.py | 2 +- lib/dns/rdtypes/ANY/SPF.py | 2 +- lib/dns/rdtypes/ANY/SSHFP.py | 4 +- lib/dns/rdtypes/ANY/TKEY.py | 2 +- lib/dns/rdtypes/ANY/TXT.py | 2 +- lib/dns/rdtypes/ANY/URI.py | 2 +- lib/dns/rdtypes/ANY/ZONEMD.py | 2 +- lib/dns/rdtypes/CH/A.py | 2 +- lib/dns/rdtypes/IN/APL.py | 1 - lib/dns/rdtypes/IN/HTTPS.py | 2 +- lib/dns/rdtypes/IN/IPSECKEY.py | 2 +- lib/dns/rdtypes/IN/KX.py | 2 +- lib/dns/rdtypes/IN/NSAP_PTR.py | 2 +- lib/dns/rdtypes/IN/PX.py | 2 +- lib/dns/rdtypes/IN/SRV.py | 2 +- lib/dns/rdtypes/IN/SVCB.py | 2 +- lib/dns/rdtypes/IN/WKS.py | 2 +- lib/dns/rdtypes/dnskeybase.py | 4 +- lib/dns/rdtypes/dsbase.py | 4 +- lib/dns/rdtypes/euibase.py | 2 +- lib/dns/rdtypes/mxbase.py | 2 +- lib/dns/rdtypes/nsbase.py | 2 +- lib/dns/rdtypes/svcbbase.py | 1 + lib/dns/rdtypes/tlsabase.py | 
4 +- lib/dns/rdtypes/txtbase.py | 3 +- lib/dns/rdtypes/util.py | 25 +- lib/dns/renderer.py | 3 +- lib/dns/resolver.py | 537 +++++++++++++++++------ lib/dns/reversename.py | 4 +- lib/dns/rrset.py | 10 +- lib/dns/tokenizer.py | 3 +- lib/dns/transaction.py | 32 +- lib/dns/tsig.py | 6 +- lib/dns/tsigkeyring.py | 7 +- lib/dns/update.py | 7 +- lib/dns/version.py | 6 +- lib/dns/versioned.py | 6 +- lib/dns/win32util.py | 8 +- lib/dns/wire.py | 3 +- lib/dns/xfr.py | 2 +- lib/dns/zone.py | 55 +-- lib/dns/zonefile.py | 48 ++- requirements.txt | 2 +- 108 files changed, 2985 insertions(+), 1136 deletions(-) delete mode 100644 lib/dns/_curio_backend.py create mode 100644 lib/dns/_ddr.py create mode 100644 lib/dns/dnssecalgs/__init__.py create mode 100644 lib/dns/dnssecalgs/base.py create mode 100644 lib/dns/dnssecalgs/cryptography.py create mode 100644 lib/dns/dnssecalgs/dsa.py create mode 100644 lib/dns/dnssecalgs/ecdsa.py create mode 100644 lib/dns/dnssecalgs/eddsa.py create mode 100644 lib/dns/dnssecalgs/rsa.py create mode 100644 lib/dns/nameserver.py diff --git a/lib/dns/__init__.py b/lib/dns/__init__.py index 9abdf018..a4249b9e 100644 --- a/lib/dns/__init__.py +++ b/lib/dns/__init__.py @@ -22,6 +22,7 @@ __all__ = [ "asyncquery", "asyncresolver", "dnssec", + "dnssecalgs", "dnssectypes", "e164", "edns", diff --git a/lib/dns/_asyncbackend.py b/lib/dns/_asyncbackend.py index ff24604f..49f14fed 100644 --- a/lib/dns/_asyncbackend.py +++ b/lib/dns/_asyncbackend.py @@ -35,6 +35,9 @@ class Socket: # pragma: no cover async def getsockname(self): raise NotImplementedError + async def getpeercert(self, timeout): + raise NotImplementedError + async def __aenter__(self): return self @@ -61,6 +64,11 @@ class StreamSocket(Socket): # pragma: no cover raise NotImplementedError +class NullTransport: + async def connect_tcp(self, host, port, timeout, local_address): + raise NotImplementedError + + class Backend: # pragma: no cover def name(self): return "unknown" @@ -83,3 +91,9 @@ class 
Backend: # pragma: no cover async def sleep(self, interval): raise NotImplementedError + + def get_transport_class(self): + raise NotImplementedError + + async def wait_for(self, awaitable, timeout): + raise NotImplementedError diff --git a/lib/dns/_asyncio_backend.py b/lib/dns/_asyncio_backend.py index 82a06249..2631228e 100644 --- a/lib/dns/_asyncio_backend.py +++ b/lib/dns/_asyncio_backend.py @@ -2,14 +2,13 @@ """asyncio library query support""" -import socket import asyncio +import socket import sys import dns._asyncbackend import dns.exception - _is_win32 = sys.platform == "win32" @@ -38,14 +37,21 @@ class _DatagramProtocol: def connection_lost(self, exc): if self.recvfrom and not self.recvfrom.done(): - self.recvfrom.set_exception(exc) + if exc is None: + # EOF we triggered. Is there a better way to do this? + try: + raise EOFError + except EOFError as e: + self.recvfrom.set_exception(e) + else: + self.recvfrom.set_exception(exc) def close(self): self.transport.close() async def _maybe_wait_for(awaitable, timeout): - if timeout: + if timeout is not None: try: return await asyncio.wait_for(awaitable, timeout) except asyncio.TimeoutError: @@ -85,6 +91,9 @@ class DatagramSocket(dns._asyncbackend.DatagramSocket): async def getsockname(self): return self.transport.get_extra_info("sockname") + async def getpeercert(self, timeout): + raise NotImplementedError + class StreamSocket(dns._asyncbackend.StreamSocket): def __init__(self, af, reader, writer): @@ -101,10 +110,6 @@ class StreamSocket(dns._asyncbackend.StreamSocket): async def close(self): self.writer.close() - try: - await self.writer.wait_closed() - except AttributeError: # pragma: no cover - pass async def getpeername(self): return self.writer.get_extra_info("peername") @@ -112,6 +117,97 @@ class StreamSocket(dns._asyncbackend.StreamSocket): async def getsockname(self): return self.writer.get_extra_info("sockname") + async def getpeercert(self, timeout): + return self.writer.get_extra_info("peercert") + + 
+try: + import anyio + import httpcore + import httpcore._backends.anyio + import httpx + + _CoreAsyncNetworkBackend = httpcore.AsyncNetworkBackend + _CoreAnyIOStream = httpcore._backends.anyio.AnyIOStream + + from dns.query import _compute_times, _expiration_for_this_attempt, _remaining + + class _NetworkBackend(_CoreAsyncNetworkBackend): + def __init__(self, resolver, local_port, bootstrap_address, family): + super().__init__() + self._local_port = local_port + self._resolver = resolver + self._bootstrap_address = bootstrap_address + self._family = family + if local_port != 0: + raise NotImplementedError( + "the asyncio transport for HTTPX cannot set the local port" + ) + + async def connect_tcp( + self, host, port, timeout, local_address, socket_options=None + ): # pylint: disable=signature-differs + addresses = [] + _, expiration = _compute_times(timeout) + if dns.inet.is_address(host): + addresses.append(host) + elif self._bootstrap_address is not None: + addresses.append(self._bootstrap_address) + else: + timeout = _remaining(expiration) + family = self._family + if local_address: + family = dns.inet.af_for_address(local_address) + answers = await self._resolver.resolve_name( + host, family=family, lifetime=timeout + ) + addresses = answers.addresses() + for address in addresses: + try: + attempt_expiration = _expiration_for_this_attempt(2.0, expiration) + timeout = _remaining(attempt_expiration) + with anyio.fail_after(timeout): + stream = await anyio.connect_tcp( + remote_host=address, + remote_port=port, + local_host=local_address, + ) + return _CoreAnyIOStream(stream) + except Exception: + pass + raise httpcore.ConnectError + + async def connect_unix_socket( + self, path, timeout, socket_options=None + ): # pylint: disable=signature-differs + raise NotImplementedError + + async def sleep(self, seconds): # pylint: disable=signature-differs + await anyio.sleep(seconds) + + class _HTTPTransport(httpx.AsyncHTTPTransport): + def __init__( + self, + *args, + 
local_port=0, + bootstrap_address=None, + resolver=None, + family=socket.AF_UNSPEC, + **kwargs, + ): + if resolver is None: + # pylint: disable=import-outside-toplevel,redefined-outer-name + import dns.asyncresolver + + resolver = dns.asyncresolver.Resolver() + super().__init__(*args, **kwargs) + self._pool._network_backend = _NetworkBackend( + resolver, local_port, bootstrap_address, family + ) + +except ImportError: + _HTTPTransport = dns._asyncbackend.NullTransport # type: ignore + class Backend(dns._asyncbackend.Backend): def name(self): @@ -171,3 +267,9 @@ class Backend(dns._asyncbackend.Backend): def datagram_connection_required(self): return _is_win32 + + def get_transport_class(self): + return _HTTPTransport + + async def wait_for(self, awaitable, timeout): + return await _maybe_wait_for(awaitable, timeout) diff --git a/lib/dns/_curio_backend.py b/lib/dns/_curio_backend.py deleted file mode 100644 index 765d6471..00000000 --- a/lib/dns/_curio_backend.py +++ /dev/null @@ -1,122 +0,0 @@ -# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license - -"""curio async I/O library query support""" - -import socket -import curio -import curio.socket # type: ignore - -import dns._asyncbackend -import dns.exception -import dns.inet - - -def _maybe_timeout(timeout): - if timeout: - return curio.ignore_after(timeout) - else: - return dns._asyncbackend.NullContext() - - -# for brevity -_lltuple = dns.inet.low_level_address_tuple - -# pylint: disable=redefined-outer-name - - -class DatagramSocket(dns._asyncbackend.DatagramSocket): - def __init__(self, socket): - super().__init__(socket.family) - self.socket = socket - - async def sendto(self, what, destination, timeout): - async with _maybe_timeout(timeout): - return await self.socket.sendto(what, destination) - raise dns.exception.Timeout( - timeout=timeout - ) # pragma: no cover lgtm[py/unreachable-statement] - - async def recvfrom(self, size, timeout): - async with _maybe_timeout(timeout): - return 
await self.socket.recvfrom(size) - raise dns.exception.Timeout(timeout=timeout) # lgtm[py/unreachable-statement] - - async def close(self): - await self.socket.close() - - async def getpeername(self): - return self.socket.getpeername() - - async def getsockname(self): - return self.socket.getsockname() - - -class StreamSocket(dns._asyncbackend.StreamSocket): - def __init__(self, socket): - self.socket = socket - self.family = socket.family - - async def sendall(self, what, timeout): - async with _maybe_timeout(timeout): - return await self.socket.sendall(what) - raise dns.exception.Timeout(timeout=timeout) # lgtm[py/unreachable-statement] - - async def recv(self, size, timeout): - async with _maybe_timeout(timeout): - return await self.socket.recv(size) - raise dns.exception.Timeout(timeout=timeout) # lgtm[py/unreachable-statement] - - async def close(self): - await self.socket.close() - - async def getpeername(self): - return self.socket.getpeername() - - async def getsockname(self): - return self.socket.getsockname() - - -class Backend(dns._asyncbackend.Backend): - def name(self): - return "curio" - - async def make_socket( - self, - af, - socktype, - proto=0, - source=None, - destination=None, - timeout=None, - ssl_context=None, - server_hostname=None, - ): - if socktype == socket.SOCK_DGRAM: - s = curio.socket.socket(af, socktype, proto) - try: - if source: - s.bind(_lltuple(source, af)) - except Exception: # pragma: no cover - await s.close() - raise - return DatagramSocket(s) - elif socktype == socket.SOCK_STREAM: - if source: - source_addr = _lltuple(source, af) - else: - source_addr = None - async with _maybe_timeout(timeout): - s = await curio.open_connection( - destination[0], - destination[1], - ssl=ssl_context, - source_addr=source_addr, - server_hostname=server_hostname, - ) - return StreamSocket(s) - raise NotImplementedError( - "unsupported socket " + f"type {socktype}" - ) # pragma: no cover - - async def sleep(self, interval): - await 
curio.sleep(interval) diff --git a/lib/dns/_ddr.py b/lib/dns/_ddr.py new file mode 100644 index 00000000..bf5c11eb --- /dev/null +++ b/lib/dns/_ddr.py @@ -0,0 +1,154 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license +# +# Support for Discovery of Designated Resolvers + +import socket +import time +from urllib.parse import urlparse + +import dns.asyncbackend +import dns.inet +import dns.name +import dns.nameserver +import dns.query +import dns.rdtypes.svcbbase + +# The special name of the local resolver when using DDR +_local_resolver_name = dns.name.from_text("_dns.resolver.arpa") + + +# +# Processing is split up into I/O independent and I/O dependent parts to +# make supporting sync and async versions easy. +# + + +class _SVCBInfo: + def __init__(self, bootstrap_address, port, hostname, nameservers): + self.bootstrap_address = bootstrap_address + self.port = port + self.hostname = hostname + self.nameservers = nameservers + + def ddr_check_certificate(self, cert): + """Verify that the _SVCBInfo's address is in the cert's subjectAltName (SAN)""" + for name, value in cert["subjectAltName"]: + if name == "IP Address" and value == self.bootstrap_address: + return True + return False + + def make_tls_context(self): + ssl = dns.query.ssl + ctx = ssl.create_default_context() + ctx.minimum_version = ssl.TLSVersion.TLSv1_2 + return ctx + + def ddr_tls_check_sync(self, lifetime): + ctx = self.make_tls_context() + expiration = time.time() + lifetime + with socket.create_connection( + (self.bootstrap_address, self.port), lifetime + ) as s: + with ctx.wrap_socket(s, server_hostname=self.hostname) as ts: + ts.settimeout(dns.query._remaining(expiration)) + ts.do_handshake() + cert = ts.getpeercert() + return self.ddr_check_certificate(cert) + + async def ddr_tls_check_async(self, lifetime, backend=None): + if backend is None: + backend = dns.asyncbackend.get_default_backend() + ctx = self.make_tls_context() + expiration = time.time() + lifetime + 
async with await backend.make_socket( + dns.inet.af_for_address(self.bootstrap_address), + socket.SOCK_STREAM, + 0, + None, + (self.bootstrap_address, self.port), + lifetime, + ctx, + self.hostname, + ) as ts: + cert = await ts.getpeercert(dns.query._remaining(expiration)) + return self.ddr_check_certificate(cert) + + +def _extract_nameservers_from_svcb(answer): + bootstrap_address = answer.nameserver + if not dns.inet.is_address(bootstrap_address): + return [] + infos = [] + for rr in answer.rrset.processing_order(): + nameservers = [] + param = rr.params.get(dns.rdtypes.svcbbase.ParamKey.ALPN) + if param is None: + continue + alpns = set(param.ids) + host = rr.target.to_text(omit_final_dot=True) + port = None + param = rr.params.get(dns.rdtypes.svcbbase.ParamKey.PORT) + if param is not None: + port = param.port + # For now we ignore address hints and address resolution and always use the + # bootstrap address + if b"h2" in alpns: + param = rr.params.get(dns.rdtypes.svcbbase.ParamKey.DOHPATH) + if param is None or not param.value.endswith(b"{?dns}"): + continue + path = param.value[:-6].decode() + if not path.startswith("/"): + path = "/" + path + if port is None: + port = 443 + url = f"https://{host}:{port}{path}" + # check the URL + try: + urlparse(url) + nameservers.append(dns.nameserver.DoHNameserver(url, bootstrap_address)) + except Exception: + # continue processing other ALPN types + pass + if b"dot" in alpns: + if port is None: + port = 853 + nameservers.append( + dns.nameserver.DoTNameserver(bootstrap_address, port, host) + ) + if b"doq" in alpns: + if port is None: + port = 853 + nameservers.append( + dns.nameserver.DoQNameserver(bootstrap_address, port, True, host) + ) + if len(nameservers) > 0: + infos.append(_SVCBInfo(bootstrap_address, port, host, nameservers)) + return infos + + +def _get_nameservers_sync(answer, lifetime): + """Return a list of TLS-validated resolver nameservers extracted from an SVCB + answer.""" + nameservers = [] + infos = 
_extract_nameservers_from_svcb(answer) + for info in infos: + try: + if info.ddr_tls_check_sync(lifetime): + nameservers.extend(info.nameservers) + except Exception: + pass + return nameservers + + +async def _get_nameservers_async(answer, lifetime): + """Return a list of TLS-validated resolver nameservers extracted from an SVCB + answer.""" + nameservers = [] + infos = _extract_nameservers_from_svcb(answer) + for info in infos: + try: + if await info.ddr_tls_check_async(lifetime): + nameservers.extend(info.nameservers) + except Exception: + pass + return nameservers diff --git a/lib/dns/_immutable_ctx.py b/lib/dns/_immutable_ctx.py index 63c0a2d3..ae7a33bf 100644 --- a/lib/dns/_immutable_ctx.py +++ b/lib/dns/_immutable_ctx.py @@ -7,7 +7,6 @@ import contextvars import inspect - _in__init__ = contextvars.ContextVar("_immutable_in__init__", default=False) diff --git a/lib/dns/_trio_backend.py b/lib/dns/_trio_backend.py index b0c02103..4d9fb820 100644 --- a/lib/dns/_trio_backend.py +++ b/lib/dns/_trio_backend.py @@ -3,6 +3,7 @@ """trio async I/O library query support""" import socket + import trio import trio.socket # type: ignore @@ -12,7 +13,7 @@ import dns.inet def _maybe_timeout(timeout): - if timeout: + if timeout is not None: return trio.move_on_after(timeout) else: return dns._asyncbackend.NullContext() @@ -50,6 +51,9 @@ class DatagramSocket(dns._asyncbackend.DatagramSocket): async def getsockname(self): return self.socket.getsockname() + async def getpeercert(self, timeout): + raise NotImplementedError + class StreamSocket(dns._asyncbackend.StreamSocket): def __init__(self, family, stream, tls=False): @@ -82,6 +86,100 @@ class StreamSocket(dns._asyncbackend.StreamSocket): else: return self.stream.socket.getsockname() + async def getpeercert(self, timeout): + if self.tls: + with _maybe_timeout(timeout): + await self.stream.do_handshake() + return self.stream.getpeercert() + else: + raise NotImplementedError + + +try: + import httpcore + import 
httpcore._backends.trio + import httpx + + _CoreAsyncNetworkBackend = httpcore.AsyncNetworkBackend + _CoreTrioStream = httpcore._backends.trio.TrioStream + + from dns.query import _compute_times, _expiration_for_this_attempt, _remaining + + class _NetworkBackend(_CoreAsyncNetworkBackend): + def __init__(self, resolver, local_port, bootstrap_address, family): + super().__init__() + self._local_port = local_port + self._resolver = resolver + self._bootstrap_address = bootstrap_address + self._family = family + + async def connect_tcp( + self, host, port, timeout, local_address, socket_options=None + ): # pylint: disable=signature-differs + addresses = [] + _, expiration = _compute_times(timeout) + if dns.inet.is_address(host): + addresses.append(host) + elif self._bootstrap_address is not None: + addresses.append(self._bootstrap_address) + else: + timeout = _remaining(expiration) + family = self._family + if local_address: + family = dns.inet.af_for_address(local_address) + answers = await self._resolver.resolve_name( + host, family=family, lifetime=timeout + ) + addresses = answers.addresses() + for address in addresses: + try: + af = dns.inet.af_for_address(address) + if local_address is not None or self._local_port != 0: + source = (local_address, self._local_port) + else: + source = None + destination = (address, port) + attempt_expiration = _expiration_for_this_attempt(2.0, expiration) + timeout = _remaining(attempt_expiration) + sock = await Backend().make_socket( + af, socket.SOCK_STREAM, 0, source, destination, timeout + ) + return _CoreTrioStream(sock.stream) + except Exception: + continue + raise httpcore.ConnectError + + async def connect_unix_socket( + self, path, timeout, socket_options=None + ): # pylint: disable=signature-differs + raise NotImplementedError + + async def sleep(self, seconds): # pylint: disable=signature-differs + await trio.sleep(seconds) + + class _HTTPTransport(httpx.AsyncHTTPTransport): + def __init__( + self, + *args, + 
local_port=0, + bootstrap_address=None, + resolver=None, + family=socket.AF_UNSPEC, + **kwargs, + ): + if resolver is None: + # pylint: disable=import-outside-toplevel,redefined-outer-name + import dns.asyncresolver + + resolver = dns.asyncresolver.Resolver() + super().__init__(*args, **kwargs) + self._pool._network_backend = _NetworkBackend( + resolver, local_port, bootstrap_address, family + ) + +except ImportError: + _HTTPTransport = dns._asyncbackend.NullTransport # type: ignore + class Backend(dns._asyncbackend.Backend): def name(self): @@ -104,8 +202,14 @@ class Backend(dns._asyncbackend.Backend): if source: await s.bind(_lltuple(source, af)) if socktype == socket.SOCK_STREAM: + connected = False with _maybe_timeout(timeout): await s.connect(_lltuple(destination, af)) + connected = True + if not connected: + raise dns.exception.Timeout( + timeout=timeout + ) # lgtm[py/unreachable-statement] except Exception: # pragma: no cover s.close() raise @@ -130,3 +234,13 @@ class Backend(dns._asyncbackend.Backend): async def sleep(self, interval): await trio.sleep(interval) + + def get_transport_class(self): + return _HTTPTransport + + async def wait_for(self, awaitable, timeout): + with _maybe_timeout(timeout): + return await awaitable + raise dns.exception.Timeout( + timeout=timeout + ) # pragma: no cover lgtm[py/unreachable-statement] diff --git a/lib/dns/asyncbackend.py b/lib/dns/asyncbackend.py index c7565a99..07d50e1e 100644 --- a/lib/dns/asyncbackend.py +++ b/lib/dns/asyncbackend.py @@ -5,13 +5,12 @@ from typing import Dict import dns.exception # pylint: disable=unused-import - -from dns._asyncbackend import ( - Socket, - DatagramSocket, - StreamSocket, +from dns._asyncbackend import ( # noqa: F401 lgtm[py/unused-import] Backend, -) # noqa: F401 lgtm[py/unused-import] + DatagramSocket, + Socket, + StreamSocket, +) # pylint: enable=unused-import @@ -30,8 +29,8 @@ class AsyncLibraryNotFoundError(dns.exception.DNSException): def get_backend(name: str) -> Backend: 
"""Get the specified asynchronous backend. - *name*, a ``str``, the name of the backend. Currently the "trio", - "curio", and "asyncio" backends are available. + *name*, a ``str``, the name of the backend. Currently the "trio" + and "asyncio" backends are available. Raises NotImplementError if an unknown backend name is specified. """ @@ -43,10 +42,6 @@ def get_backend(name: str) -> Backend: import dns._trio_backend backend = dns._trio_backend.Backend() - elif name == "curio": - import dns._curio_backend - - backend = dns._curio_backend.Backend() elif name == "asyncio": import dns._asyncio_backend @@ -73,9 +68,7 @@ def sniff() -> str: try: return sniffio.current_async_library() except sniffio.AsyncLibraryNotFoundError: - raise AsyncLibraryNotFoundError( - "sniffio cannot determine " + "async library" - ) + raise AsyncLibraryNotFoundError("sniffio cannot determine async library") except ImportError: import asyncio diff --git a/lib/dns/asyncquery.py b/lib/dns/asyncquery.py index 459c611d..ecf9c1a5 100644 --- a/lib/dns/asyncquery.py +++ b/lib/dns/asyncquery.py @@ -17,39 +17,38 @@ """Talk to a DNS server.""" -from typing import Any, Dict, Optional, Tuple, Union - import base64 import contextlib import socket import struct import time +from typing import Any, Dict, Optional, Tuple, Union import dns.asyncbackend import dns.exception import dns.inet -import dns.name import dns.message +import dns.name import dns.quic import dns.rcode import dns.rdataclass import dns.rdatatype import dns.transaction - from dns._asyncbackend import NullContext from dns.query import ( - _compute_times, - _matches_destination, BadResponse, - ssl, - UDPMode, - _have_httpx, - _have_http2, NoDOH, NoDOQ, + UDPMode, + _compute_times, + _have_http2, + _matches_destination, + _remaining, + have_doh, + ssl, ) -if _have_httpx: +if have_doh: import httpx # for brevity @@ -73,7 +72,7 @@ def _source_tuple(af, address, port): def _timeout(expiration, now=None): - if expiration: + if expiration is not 
None: if not now: now = time.time() return max(expiration - now, 0) @@ -445,9 +444,6 @@ async def tls( ssl_context.minimum_version = ssl.TLSVersion.TLSv1_2 if server_hostname is None: ssl_context.check_hostname = False - else: - ssl_context = None - server_hostname = None af = dns.inet.af_for_address(where) stuple = _source_tuple(af, source, source_port) dtuple = (where, port) @@ -495,6 +491,9 @@ async def https( path: str = "/dns-query", post: bool = True, verify: Union[bool, str] = True, + bootstrap_address: Optional[str] = None, + resolver: Optional["dns.asyncresolver.Resolver"] = None, + family: Optional[int] = socket.AF_UNSPEC, ) -> dns.message.Message: """Return the response obtained after sending a query via DNS-over-HTTPS. @@ -508,8 +507,10 @@ async def https( parameters, exceptions, and return type of this method. """ - if not _have_httpx: - raise NoDOH("httpx is not available.") # pragma: no cover + if not have_doh: + raise NoDOH # pragma: no cover + if client and not isinstance(client, httpx.AsyncClient): + raise ValueError("session parameter must be an httpx.AsyncClient") wire = q.to_wire() try: @@ -518,15 +519,32 @@ async def https( af = None transport = None headers = {"accept": "application/dns-message"} - if af is not None: + if af is not None and dns.inet.is_address(where): if af == socket.AF_INET: url = "https://{}:{}{}".format(where, port, path) elif af == socket.AF_INET6: url = "https://[{}]:{}{}".format(where, port, path) else: url = where - if source is not None: - transport = httpx.AsyncHTTPTransport(local_address=source[0]) + + backend = dns.asyncbackend.get_default_backend() + + if source is None: + local_address = None + local_port = 0 + else: + local_address = source + local_port = source_port + transport = backend.get_transport_class()( + local_address=local_address, + http1=True, + http2=_have_http2, + verify=verify, + local_port=local_port, + bootstrap_address=bootstrap_address, + resolver=resolver, + family=family, + ) if client: cm: 
contextlib.AbstractAsyncContextManager = NullContext(client) @@ -545,14 +563,14 @@ async def https( "content-length": str(len(wire)), } ) - response = await the_client.post( - url, headers=headers, content=wire, timeout=timeout + response = await backend.wait_for( + the_client.post(url, headers=headers, content=wire), timeout ) else: wire = base64.urlsafe_b64encode(wire).rstrip(b"=") twire = wire.decode() # httpx does a repr() if we give it bytes - response = await the_client.get( - url, headers=headers, timeout=timeout, params={"dns": twire} + response = await backend.wait_for( + the_client.get(url, headers=headers, params={"dns": twire}), timeout ) # see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH @@ -690,6 +708,7 @@ async def quic( connection: Optional[dns.quic.AsyncQuicConnection] = None, verify: Union[bool, str] = True, backend: Optional[dns.asyncbackend.Backend] = None, + server_hostname: Optional[str] = None, ) -> dns.message.Message: """Return the response obtained after sending an asynchronous query via DNS-over-QUIC. 
@@ -715,14 +734,16 @@ async def quic( (cfactory, mfactory) = dns.quic.factories_for_backend(backend) async with cfactory() as context: - async with mfactory(context, verify_mode=verify) as the_manager: + async with mfactory( + context, verify_mode=verify, server_name=server_hostname + ) as the_manager: if not connection: the_connection = the_manager.connect(where, port, source, source_port) - start = time.time() - stream = await the_connection.make_stream() + (start, expiration) = _compute_times(timeout) + stream = await the_connection.make_stream(timeout) async with stream: await stream.send(wire, True) - wire = await stream.receive(timeout) + wire = await stream.receive(_remaining(expiration)) finish = time.time() r = dns.message.from_wire( wire, diff --git a/lib/dns/asyncresolver.py b/lib/dns/asyncresolver.py index 506530e2..8f5e062a 100644 --- a/lib/dns/asyncresolver.py +++ b/lib/dns/asyncresolver.py @@ -17,10 +17,11 @@ """Asynchronous DNS stub resolver.""" -from typing import Any, Dict, Optional, Union - +import socket import time +from typing import Any, Dict, List, Optional, Union +import dns._ddr import dns.asyncbackend import dns.asyncquery import dns.exception @@ -31,8 +32,7 @@ import dns.rdatatype import dns.resolver # lgtm[py/import-and-import-from] # import some resolver symbols for brevity -from dns.resolver import NXDOMAIN, NoAnswer, NotAbsolute, NoRootSOA - +from dns.resolver import NXDOMAIN, NoAnswer, NoRootSOA, NotAbsolute # for indentation purposes below _udp = dns.asyncquery.udp @@ -83,37 +83,19 @@ class Resolver(dns.resolver.BaseResolver): assert request is not None # needed for type checking done = False while not done: - (nameserver, port, tcp, backoff) = resolution.next_nameserver() + (nameserver, tcp, backoff) = resolution.next_nameserver() if backoff: await backend.sleep(backoff) timeout = self._compute_timeout(start, lifetime, resolution.errors) try: - if dns.inet.is_address(nameserver): - if tcp: - response = await _tcp( - request, - 
nameserver, - timeout, - port, - source, - source_port, - backend=backend, - ) - else: - response = await _udp( - request, - nameserver, - timeout, - port, - source, - source_port, - raise_on_truncation=True, - backend=backend, - ) - else: - response = await dns.asyncquery.https( - request, nameserver, timeout=timeout - ) + response = await nameserver.async_query( + request, + timeout=timeout, + source=source, + source_port=source_port, + max_size=tcp, + backend=backend, + ) except Exception as ex: (_, done) = resolution.query_result(None, ex) continue @@ -153,6 +135,73 @@ class Resolver(dns.resolver.BaseResolver): dns.reversename.from_address(ipaddr), *args, **modified_kwargs ) + async def resolve_name( + self, + name: Union[dns.name.Name, str], + family: int = socket.AF_UNSPEC, + **kwargs: Any, + ) -> dns.resolver.HostAnswers: + """Use an asynchronous resolver to query for address records. + + This utilizes the resolve() method to perform A and/or AAAA lookups on + the specified name. + + *qname*, a ``dns.name.Name`` or ``str``, the name to resolve. + + *family*, an ``int``, the address family. If socket.AF_UNSPEC + (the default), both A and AAAA records will be retrieved. + + All other arguments that can be passed to the resolve() function + except for rdtype and rdclass are also supported by this + function. + """ + # We make a modified kwargs for type checking happiness, as otherwise + # we get a legit warning about possibly having rdtype and rdclass + # in the kwargs more than once. 
+ modified_kwargs: Dict[str, Any] = {} + modified_kwargs.update(kwargs) + modified_kwargs.pop("rdtype", None) + modified_kwargs["rdclass"] = dns.rdataclass.IN + + if family == socket.AF_INET: + v4 = await self.resolve(name, dns.rdatatype.A, **modified_kwargs) + return dns.resolver.HostAnswers.make(v4=v4) + elif family == socket.AF_INET6: + v6 = await self.resolve(name, dns.rdatatype.AAAA, **modified_kwargs) + return dns.resolver.HostAnswers.make(v6=v6) + elif family != socket.AF_UNSPEC: + raise NotImplementedError(f"unknown address family {family}") + + raise_on_no_answer = modified_kwargs.pop("raise_on_no_answer", True) + lifetime = modified_kwargs.pop("lifetime", None) + start = time.time() + v6 = await self.resolve( + name, + dns.rdatatype.AAAA, + raise_on_no_answer=False, + lifetime=self._compute_timeout(start, lifetime), + **modified_kwargs, + ) + # Note that setting name ensures we query the same name + # for A as we did for AAAA. (This is just in case search lists + # are active by default in the resolver configuration and + # we might be talking to a server that says NXDOMAIN when it + # wants to say NOERROR no data. + name = v6.qname + v4 = await self.resolve( + name, + dns.rdatatype.A, + raise_on_no_answer=False, + lifetime=self._compute_timeout(start, lifetime), + **modified_kwargs, + ) + answers = dns.resolver.HostAnswers.make( + v6=v6, v4=v4, add_empty=not raise_on_no_answer + ) + if not answers: + raise NoAnswer(response=v6.response) + return answers + # pylint: disable=redefined-outer-name async def canonical_name(self, name: Union[dns.name.Name, str]) -> dns.name.Name: @@ -176,6 +225,37 @@ class Resolver(dns.resolver.BaseResolver): canonical_name = e.canonical_name return canonical_name + async def try_ddr(self, lifetime: float = 5.0) -> None: + """Try to update the resolver's nameservers using Discovery of Designated + Resolvers (DDR). If successful, the resolver will subsequently use + DNS-over-HTTPS or DNS-over-TLS for future queries. 
+ + *lifetime*, a float, is the maximum time to spend attempting DDR. The default + is 5 seconds. + + If the SVCB query is successful and results in a non-empty list of nameservers, + then the resolver's nameservers are set to the returned servers in priority + order. + + The current implementation does not use any address hints from the SVCB record, + nor does it resolve addresses for the SCVB target name, rather it assumes that + the bootstrap nameserver will always be one of the addresses and uses it. + A future revision to the code may offer fuller support. The code verifies that + the bootstrap nameserver is in the Subject Alternative Name field of the + TLS certficate. + """ + try: + expiration = time.time() + lifetime + answer = await self.resolve( + dns._ddr._local_resolver_name, "svcb", lifetime=lifetime + ) + timeout = dns.query._remaining(expiration) + nameservers = await dns._ddr._get_nameservers_async(answer, timeout) + if len(nameservers) > 0: + self.nameservers = nameservers + except Exception: + pass + default_resolver = None @@ -246,6 +326,18 @@ async def resolve_address( return await get_default_resolver().resolve_address(ipaddr, *args, **kwargs) +async def resolve_name( + name: Union[dns.name.Name, str], family: int = socket.AF_UNSPEC, **kwargs: Any +) -> dns.resolver.HostAnswers: + """Use a resolver to asynchronously query for address records. + + See :py:func:`dns.asyncresolver.Resolver.resolve_name` for more + information on the parameters. + """ + + return await get_default_resolver().resolve_name(name, family, **kwargs) + + async def canonical_name(name: Union[dns.name.Name, str]) -> dns.name.Name: """Determine the canonical name of *name*. @@ -256,6 +348,16 @@ async def canonical_name(name: Union[dns.name.Name, str]) -> dns.name.Name: return await get_default_resolver().canonical_name(name) +async def try_ddr(timeout: float = 5.0) -> None: + """Try to update the default resolver's nameservers using Discovery of Designated + Resolvers (DDR). 
If successful, the resolver will subsequently use + DNS-over-HTTPS or DNS-over-TLS for future queries. + + See :py:func:`dns.resolver.Resolver.try_ddr` for more information. + """ + return await get_default_resolver().try_ddr(timeout) + + async def zone_for_name( name: Union[dns.name.Name, str], rdclass: dns.rdataclass.RdataClass = dns.rdataclass.IN, @@ -290,3 +392,84 @@ async def zone_for_name( name = name.parent() except dns.name.NoParent: # pragma: no cover raise NoRootSOA + + +async def make_resolver_at( + where: Union[dns.name.Name, str], + port: int = 53, + family: int = socket.AF_UNSPEC, + resolver: Optional[Resolver] = None, +) -> Resolver: + """Make a stub resolver using the specified destination as the full resolver. + + *where*, a ``dns.name.Name`` or ``str`` the domain name or IP address of the + full resolver. + + *port*, an ``int``, the port to use. If not specified, the default is 53. + + *family*, an ``int``, the address family to use. This parameter is used if + *where* is not an address. The default is ``socket.AF_UNSPEC`` in which case + the first address returned by ``resolve_name()`` will be used, otherwise the + first address of the specified family will be used. + + *resolver*, a ``dns.asyncresolver.Resolver`` or ``None``, the resolver to use for + resolution of hostnames. If not specified, the default resolver will be used. + + Returns a ``dns.resolver.Resolver`` or raises an exception. 
+ """ + if resolver is None: + resolver = get_default_resolver() + nameservers: List[Union[str, dns.nameserver.Nameserver]] = [] + if isinstance(where, str) and dns.inet.is_address(where): + nameservers.append(dns.nameserver.Do53Nameserver(where, port)) + else: + answers = await resolver.resolve_name(where, family) + for address in answers.addresses(): + nameservers.append(dns.nameserver.Do53Nameserver(address, port)) + res = dns.asyncresolver.Resolver(configure=False) + res.nameservers = nameservers + return res + + +async def resolve_at( + where: Union[dns.name.Name, str], + qname: Union[dns.name.Name, str], + rdtype: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.A, + rdclass: Union[dns.rdataclass.RdataClass, str] = dns.rdataclass.IN, + tcp: bool = False, + source: Optional[str] = None, + raise_on_no_answer: bool = True, + source_port: int = 0, + lifetime: Optional[float] = None, + search: Optional[bool] = None, + backend: Optional[dns.asyncbackend.Backend] = None, + port: int = 53, + family: int = socket.AF_UNSPEC, + resolver: Optional[Resolver] = None, +) -> dns.resolver.Answer: + """Query nameservers to find the answer to the question. + + This is a convenience function that calls ``dns.asyncresolver.make_resolver_at()`` + to make a resolver, and then uses it to resolve the query. + + See ``dns.asyncresolver.Resolver.resolve`` for more information on the resolution + parameters, and ``dns.asyncresolver.make_resolver_at`` for information about the + resolver parameters *where*, *port*, *family*, and *resolver*. + + If making more than one query, it is more efficient to call + ``dns.asyncresolver.make_resolver_at()`` and then use that resolver for the queries + instead of calling ``resolve_at()`` multiple times. 
+ """ + res = await make_resolver_at(where, port, family, resolver) + return await res.resolve( + qname, + rdtype, + rdclass, + tcp, + source, + raise_on_no_answer, + source_port, + lifetime, + search, + backend, + ) diff --git a/lib/dns/dnssec.py b/lib/dns/dnssec.py index 5dc26223..2949f619 100644 --- a/lib/dns/dnssec.py +++ b/lib/dns/dnssec.py @@ -17,50 +17,44 @@ """Common DNSSEC-related functions and constants.""" -from typing import Any, cast, Dict, List, Optional, Set, Tuple, Union +import base64 +import contextlib +import functools import hashlib -import math import struct import time -import base64 from datetime import datetime - -from dns.dnssectypes import Algorithm, DSDigest, NSEC3Hash +from typing import Callable, Dict, List, Optional, Set, Tuple, Union, cast import dns.exception import dns.name import dns.node -import dns.rdataset import dns.rdata -import dns.rdatatype import dns.rdataclass +import dns.rdataset +import dns.rdatatype import dns.rrset +import dns.transaction +import dns.zone +from dns.dnssectypes import Algorithm, DSDigest, NSEC3Hash +from dns.exception import ( # pylint: disable=W0611 + AlgorithmKeyMismatch, + DeniedByPolicy, + UnsupportedAlgorithm, + ValidationFailure, +) from dns.rdtypes.ANY.CDNSKEY import CDNSKEY from dns.rdtypes.ANY.CDS import CDS from dns.rdtypes.ANY.DNSKEY import DNSKEY from dns.rdtypes.ANY.DS import DS +from dns.rdtypes.ANY.NSEC import NSEC, Bitmap +from dns.rdtypes.ANY.NSEC3PARAM import NSEC3PARAM from dns.rdtypes.ANY.RRSIG import RRSIG, sigtime_to_posixtime from dns.rdtypes.dnskeybase import Flag - -class UnsupportedAlgorithm(dns.exception.DNSException): - """The DNSSEC algorithm is not supported.""" - - -class AlgorithmKeyMismatch(UnsupportedAlgorithm): - """The DNSSEC algorithm is not supported for the given key type.""" - - -class ValidationFailure(dns.exception.DNSException): - """The DNSSEC signature is invalid.""" - - -class DeniedByPolicy(dns.exception.DNSException): - """Denied by DNSSEC policy.""" - - 
PublicKey = Union[ + "GenericPublicKey", "rsa.RSAPublicKey", "ec.EllipticCurvePublicKey", "ed25519.Ed25519PublicKey", @@ -68,12 +62,15 @@ PublicKey = Union[ ] PrivateKey = Union[ + "GenericPrivateKey", "rsa.RSAPrivateKey", "ec.EllipticCurvePrivateKey", "ed25519.Ed25519PrivateKey", "ed448.Ed448PrivateKey", ] +RRsetSigner = Callable[[dns.transaction.Transaction, dns.rrset.RRset], None] + def algorithm_from_text(text: str) -> Algorithm: """Convert text into a DNSSEC algorithm value. @@ -308,113 +305,13 @@ def _find_candidate_keys( return [ cast(DNSKEY, rd) for rd in rdataset - if rd.algorithm == rrsig.algorithm and key_id(rd) == rrsig.key_tag + if rd.algorithm == rrsig.algorithm + and key_id(rd) == rrsig.key_tag + and (rd.flags & Flag.ZONE) == Flag.ZONE # RFC 4034 2.1.1 + and rd.protocol == 3 # RFC 4034 2.1.2 ] -def _is_rsa(algorithm: int) -> bool: - return algorithm in ( - Algorithm.RSAMD5, - Algorithm.RSASHA1, - Algorithm.RSASHA1NSEC3SHA1, - Algorithm.RSASHA256, - Algorithm.RSASHA512, - ) - - -def _is_dsa(algorithm: int) -> bool: - return algorithm in (Algorithm.DSA, Algorithm.DSANSEC3SHA1) - - -def _is_ecdsa(algorithm: int) -> bool: - return algorithm in (Algorithm.ECDSAP256SHA256, Algorithm.ECDSAP384SHA384) - - -def _is_eddsa(algorithm: int) -> bool: - return algorithm in (Algorithm.ED25519, Algorithm.ED448) - - -def _is_gost(algorithm: int) -> bool: - return algorithm == Algorithm.ECCGOST - - -def _is_md5(algorithm: int) -> bool: - return algorithm == Algorithm.RSAMD5 - - -def _is_sha1(algorithm: int) -> bool: - return algorithm in ( - Algorithm.DSA, - Algorithm.RSASHA1, - Algorithm.DSANSEC3SHA1, - Algorithm.RSASHA1NSEC3SHA1, - ) - - -def _is_sha256(algorithm: int) -> bool: - return algorithm in (Algorithm.RSASHA256, Algorithm.ECDSAP256SHA256) - - -def _is_sha384(algorithm: int) -> bool: - return algorithm == Algorithm.ECDSAP384SHA384 - - -def _is_sha512(algorithm: int) -> bool: - return algorithm == Algorithm.RSASHA512 - - -def 
_ensure_algorithm_key_combination(algorithm: int, key: PublicKey) -> None: - """Ensure algorithm is valid for key type, throwing an exception on - mismatch.""" - if isinstance(key, rsa.RSAPublicKey): - if _is_rsa(algorithm): - return - raise AlgorithmKeyMismatch('algorithm "%s" not valid for RSA key' % algorithm) - if isinstance(key, dsa.DSAPublicKey): - if _is_dsa(algorithm): - return - raise AlgorithmKeyMismatch('algorithm "%s" not valid for DSA key' % algorithm) - if isinstance(key, ec.EllipticCurvePublicKey): - if _is_ecdsa(algorithm): - return - raise AlgorithmKeyMismatch('algorithm "%s" not valid for ECDSA key' % algorithm) - if isinstance(key, ed25519.Ed25519PublicKey): - if algorithm == Algorithm.ED25519: - return - raise AlgorithmKeyMismatch( - 'algorithm "%s" not valid for ED25519 key' % algorithm - ) - if isinstance(key, ed448.Ed448PublicKey): - if algorithm == Algorithm.ED448: - return - raise AlgorithmKeyMismatch('algorithm "%s" not valid for ED448 key' % algorithm) - - raise TypeError("unsupported key type") - - -def _make_hash(algorithm: int) -> Any: - if _is_md5(algorithm): - return hashes.MD5() - if _is_sha1(algorithm): - return hashes.SHA1() - if _is_sha256(algorithm): - return hashes.SHA256() - if _is_sha384(algorithm): - return hashes.SHA384() - if _is_sha512(algorithm): - return hashes.SHA512() - if algorithm == Algorithm.ED25519: - return hashes.SHA512() - if algorithm == Algorithm.ED448: - return hashes.SHAKE256(114) - - raise ValidationFailure("unknown hash for algorithm %u" % algorithm) - - -def _bytes_to_long(b: bytes) -> int: - return int.from_bytes(b, "big") - - def _get_rrname_rdataset( rrset: Union[dns.rrset.RRset, Tuple[dns.name.Name, dns.rdataset.Rdataset]], ) -> Tuple[dns.name.Name, dns.rdataset.Rdataset]: @@ -424,85 +321,13 @@ def _get_rrname_rdataset( return rrset.name, rrset -def _validate_signature(sig: bytes, data: bytes, key: DNSKEY, chosen_hash: Any) -> None: - keyptr: bytes - if _is_rsa(key.algorithm): - # we ignore because 
mypy is confused and thinks key.key is a str for unknown - # reasons. - keyptr = key.key - (bytes_,) = struct.unpack("!B", keyptr[0:1]) - keyptr = keyptr[1:] - if bytes_ == 0: - (bytes_,) = struct.unpack("!H", keyptr[0:2]) - keyptr = keyptr[2:] - rsa_e = keyptr[0:bytes_] - rsa_n = keyptr[bytes_:] - try: - rsa_public_key = rsa.RSAPublicNumbers( - _bytes_to_long(rsa_e), _bytes_to_long(rsa_n) - ).public_key(default_backend()) - except ValueError: - raise ValidationFailure("invalid public key") - rsa_public_key.verify(sig, data, padding.PKCS1v15(), chosen_hash) - elif _is_dsa(key.algorithm): - keyptr = key.key - (t,) = struct.unpack("!B", keyptr[0:1]) - keyptr = keyptr[1:] - octets = 64 + t * 8 - dsa_q = keyptr[0:20] - keyptr = keyptr[20:] - dsa_p = keyptr[0:octets] - keyptr = keyptr[octets:] - dsa_g = keyptr[0:octets] - keyptr = keyptr[octets:] - dsa_y = keyptr[0:octets] - try: - dsa_public_key = dsa.DSAPublicNumbers( # type: ignore - _bytes_to_long(dsa_y), - dsa.DSAParameterNumbers( - _bytes_to_long(dsa_p), _bytes_to_long(dsa_q), _bytes_to_long(dsa_g) - ), - ).public_key(default_backend()) - except ValueError: - raise ValidationFailure("invalid public key") - dsa_public_key.verify(sig, data, chosen_hash) - elif _is_ecdsa(key.algorithm): - keyptr = key.key - curve: Any - if key.algorithm == Algorithm.ECDSAP256SHA256: - curve = ec.SECP256R1() - octets = 32 - else: - curve = ec.SECP384R1() - octets = 48 - ecdsa_x = keyptr[0:octets] - ecdsa_y = keyptr[octets : octets * 2] - try: - ecdsa_public_key = ec.EllipticCurvePublicNumbers( - curve=curve, x=_bytes_to_long(ecdsa_x), y=_bytes_to_long(ecdsa_y) - ).public_key(default_backend()) - except ValueError: - raise ValidationFailure("invalid public key") - ecdsa_public_key.verify(sig, data, ec.ECDSA(chosen_hash)) - elif _is_eddsa(key.algorithm): - keyptr = key.key - loader: Any - if key.algorithm == Algorithm.ED25519: - loader = ed25519.Ed25519PublicKey - else: - loader = ed448.Ed448PublicKey - try: - eddsa_public_key = 
loader.from_public_bytes(keyptr) - except ValueError: - raise ValidationFailure("invalid public key") - eddsa_public_key.verify(sig, data) - elif _is_gost(key.algorithm): - raise UnsupportedAlgorithm( - 'algorithm "%s" not supported by dnspython' - % algorithm_to_text(key.algorithm) - ) - else: - raise ValidationFailure("unknown algorithm %u" % key.algorithm) +def _validate_signature(sig: bytes, data: bytes, key: DNSKEY) -> None: + public_cls = get_algorithm_cls_from_dnskey(key).public_cls + try: + public_key = public_cls.from_dnskey(key) + except ValueError: + raise ValidationFailure("invalid public key") + public_key.verify(sig, data) def _validate_rrsig( @@ -559,29 +384,13 @@ def _validate_rrsig( if rrsig.inception > now: raise ValidationFailure("not yet valid") - if _is_dsa(rrsig.algorithm): - sig_r = rrsig.signature[1:21] - sig_s = rrsig.signature[21:] - sig = utils.encode_dss_signature(_bytes_to_long(sig_r), _bytes_to_long(sig_s)) - elif _is_ecdsa(rrsig.algorithm): - if rrsig.algorithm == Algorithm.ECDSAP256SHA256: - octets = 32 - else: - octets = 48 - sig_r = rrsig.signature[0:octets] - sig_s = rrsig.signature[octets:] - sig = utils.encode_dss_signature(_bytes_to_long(sig_r), _bytes_to_long(sig_s)) - else: - sig = rrsig.signature - data = _make_rrsig_signature_data(rrset, rrsig, origin) - chosen_hash = _make_hash(rrsig.algorithm) for candidate_key in candidate_keys: if not policy.ok_to_validate(candidate_key): continue try: - _validate_signature(sig, data, candidate_key, chosen_hash) + _validate_signature(rrsig.signature, data, candidate_key) return except (InvalidSignature, ValidationFailure): # this happens on an individual validation failure @@ -673,6 +482,7 @@ def _sign( lifetime: Optional[int] = None, verify: bool = False, policy: Optional[Policy] = None, + origin: Optional[dns.name.Name] = None, ) -> RRSIG: """Sign RRset using private key. @@ -708,6 +518,10 @@ def _sign( *policy*, a ``dns.dnssec.Policy`` or ``None``. 
If ``None``, the default policy, ``dns.dnssec.default_policy`` is used; this policy defaults to that of RFC 8624. + *origin*, a ``dns.name.Name`` or ``None``. If ``None``, the default, then all + names in the rrset (including its owner name) must be absolute; otherwise the + specified origin will be used to make names absolute when signing. + Raises ``DeniedByPolicy`` if the signature is denied by policy. """ @@ -735,16 +549,26 @@ def _sign( if expiration is not None: rrsig_expiration = to_timestamp(expiration) elif lifetime is not None: - rrsig_expiration = int(time.time()) + lifetime + rrsig_expiration = rrsig_inception + lifetime else: raise ValueError("expiration or lifetime must be specified") + # Derelativize now because we need a correct labels length for the + # rrsig_template. + if origin is not None: + rrname = rrname.derelativize(origin) + labels = len(rrname) - 1 + + # Adjust labels appropriately for wildcards. + if rrname.is_wild(): + labels -= 1 + rrsig_template = RRSIG( rdclass=rdclass, rdtype=dns.rdatatype.RRSIG, type_covered=rdtype, algorithm=dnskey.algorithm, - labels=len(rrname) - 1, + labels=labels, original_ttl=original_ttl, expiration=rrsig_expiration, inception=rrsig_inception, @@ -753,63 +577,18 @@ def _sign( signature=b"", ) - data = dns.dnssec._make_rrsig_signature_data(rrset, rrsig_template) - chosen_hash = _make_hash(rrsig_template.algorithm) - signature = None + data = dns.dnssec._make_rrsig_signature_data(rrset, rrsig_template, origin) - if isinstance(private_key, rsa.RSAPrivateKey): - if not _is_rsa(dnskey.algorithm): - raise ValueError("Invalid DNSKEY algorithm for RSA key") - signature = private_key.sign(data, padding.PKCS1v15(), chosen_hash) - if verify: - private_key.public_key().verify( - signature, data, padding.PKCS1v15(), chosen_hash - ) - elif isinstance(private_key, dsa.DSAPrivateKey): - if not _is_dsa(dnskey.algorithm): - raise ValueError("Invalid DNSKEY algorithm for DSA key") - public_dsa_key = private_key.public_key() - 
if public_dsa_key.key_size > 1024: - raise ValueError("DSA key size overflow") - der_signature = private_key.sign(data, chosen_hash) - if verify: - public_dsa_key.verify(der_signature, data, chosen_hash) - dsa_r, dsa_s = utils.decode_dss_signature(der_signature) - dsa_t = (public_dsa_key.key_size // 8 - 64) // 8 - octets = 20 - signature = ( - struct.pack("!B", dsa_t) - + int.to_bytes(dsa_r, length=octets, byteorder="big") - + int.to_bytes(dsa_s, length=octets, byteorder="big") - ) - elif isinstance(private_key, ec.EllipticCurvePrivateKey): - if not _is_ecdsa(dnskey.algorithm): - raise ValueError("Invalid DNSKEY algorithm for EC key") - der_signature = private_key.sign(data, ec.ECDSA(chosen_hash)) - if verify: - private_key.public_key().verify(der_signature, data, ec.ECDSA(chosen_hash)) - if dnskey.algorithm == Algorithm.ECDSAP256SHA256: - octets = 32 - else: - octets = 48 - dsa_r, dsa_s = utils.decode_dss_signature(der_signature) - signature = int.to_bytes(dsa_r, length=octets, byteorder="big") + int.to_bytes( - dsa_s, length=octets, byteorder="big" - ) - elif isinstance(private_key, ed25519.Ed25519PrivateKey): - if dnskey.algorithm != Algorithm.ED25519: - raise ValueError("Invalid DNSKEY algorithm for ED25519 key") - signature = private_key.sign(data) - if verify: - private_key.public_key().verify(signature, data) - elif isinstance(private_key, ed448.Ed448PrivateKey): - if dnskey.algorithm != Algorithm.ED448: - raise ValueError("Invalid DNSKEY algorithm for ED448 key") - signature = private_key.sign(data) - if verify: - private_key.public_key().verify(signature, data) + if isinstance(private_key, GenericPrivateKey): + signing_key = private_key else: - raise TypeError("Unsupported key algorithm") + try: + private_cls = get_algorithm_cls_from_dnskey(dnskey) + signing_key = private_cls(key=private_key) + except UnsupportedAlgorithm: + raise TypeError("Unsupported key algorithm") + + signature = signing_key.sign(data, verify) return cast(RRSIG, 
rrsig_template.replace(signature=signature)) @@ -858,9 +637,12 @@ def _make_rrsig_signature_data( raise ValidationFailure("relative RR name without an origin specified") rrname = rrname.derelativize(origin) - if len(rrname) - 1 < rrsig.labels: + name_len = len(rrname) + if rrname.is_wild() and rrsig.labels != name_len - 2: + raise ValidationFailure("wild owner name has wrong label length") + if name_len - 1 < rrsig.labels: raise ValidationFailure("owner name longer than RRSIG labels") - elif rrsig.labels < len(rrname) - 1: + elif rrsig.labels < name_len - 1: suffix = rrname.split(rrsig.labels + 1)[1] rrname = dns.name.from_text("*", suffix) rrnamebuf = rrname.to_digestable() @@ -884,9 +666,8 @@ def _make_dnskey( ) -> DNSKEY: """Convert a public key to DNSKEY Rdata - *public_key*, the public key to convert, a - ``cryptography.hazmat.primitives.asymmetric`` public key class applicable - for DNSSEC. + *public_key*, a ``PublicKey`` (``GenericPublicKey`` or + ``cryptography.hazmat.primitives.asymmetric``) to convert. *algorithm*, a ``str`` or ``int`` specifying the DNSKEY algorithm. @@ -902,72 +683,13 @@ def _make_dnskey( Return DNSKEY ``Rdata``. 
""" - def encode_rsa_public_key(public_key: "rsa.RSAPublicKey") -> bytes: - """Encode a public key per RFC 3110, section 2.""" - pn = public_key.public_numbers() - _exp_len = math.ceil(int.bit_length(pn.e) / 8) - exp = int.to_bytes(pn.e, length=_exp_len, byteorder="big") - if _exp_len > 255: - exp_header = b"\0" + struct.pack("!H", _exp_len) - else: - exp_header = struct.pack("!B", _exp_len) - if pn.n.bit_length() < 512 or pn.n.bit_length() > 4096: - raise ValueError("unsupported RSA key length") - return exp_header + exp + pn.n.to_bytes((pn.n.bit_length() + 7) // 8, "big") + algorithm = Algorithm.make(algorithm) - def encode_dsa_public_key(public_key: "dsa.DSAPublicKey") -> bytes: - """Encode a public key per RFC 2536, section 2.""" - pn = public_key.public_numbers() - dsa_t = (public_key.key_size // 8 - 64) // 8 - if dsa_t > 8: - raise ValueError("unsupported DSA key size") - octets = 64 + dsa_t * 8 - res = struct.pack("!B", dsa_t) - res += pn.parameter_numbers.q.to_bytes(20, "big") - res += pn.parameter_numbers.p.to_bytes(octets, "big") - res += pn.parameter_numbers.g.to_bytes(octets, "big") - res += pn.y.to_bytes(octets, "big") - return res - - def encode_ecdsa_public_key(public_key: "ec.EllipticCurvePublicKey") -> bytes: - """Encode a public key per RFC 6605, section 4.""" - pn = public_key.public_numbers() - if isinstance(public_key.curve, ec.SECP256R1): - return pn.x.to_bytes(32, "big") + pn.y.to_bytes(32, "big") - elif isinstance(public_key.curve, ec.SECP384R1): - return pn.x.to_bytes(48, "big") + pn.y.to_bytes(48, "big") - else: - raise ValueError("unsupported ECDSA curve") - - the_algorithm = Algorithm.make(algorithm) - - _ensure_algorithm_key_combination(the_algorithm, public_key) - - if isinstance(public_key, rsa.RSAPublicKey): - key_bytes = encode_rsa_public_key(public_key) - elif isinstance(public_key, dsa.DSAPublicKey): - key_bytes = encode_dsa_public_key(public_key) - elif isinstance(public_key, ec.EllipticCurvePublicKey): - key_bytes = 
encode_ecdsa_public_key(public_key) - elif isinstance(public_key, ed25519.Ed25519PublicKey): - key_bytes = public_key.public_bytes( - encoding=serialization.Encoding.Raw, format=serialization.PublicFormat.Raw - ) - elif isinstance(public_key, ed448.Ed448PublicKey): - key_bytes = public_key.public_bytes( - encoding=serialization.Encoding.Raw, format=serialization.PublicFormat.Raw - ) + if isinstance(public_key, GenericPublicKey): + return public_key.to_dnskey(flags=flags, protocol=protocol) else: - raise TypeError("unsupported key algorithm") - - return DNSKEY( - rdclass=dns.rdataclass.IN, - rdtype=dns.rdatatype.DNSKEY, - flags=flags, - protocol=protocol, - algorithm=the_algorithm, - key=key_bytes, - ) + public_cls = get_algorithm_cls(algorithm).public_cls + return public_cls(key=public_key).to_dnskey(flags=flags, protocol=protocol) def _make_cdnskey( @@ -1216,23 +938,252 @@ def dnskey_rdataset_to_cdnskey_rdataset( return dns.rdataset.from_rdata_list(rdataset.ttl, res) +def default_rrset_signer( + txn: dns.transaction.Transaction, + rrset: dns.rrset.RRset, + signer: dns.name.Name, + ksks: List[Tuple[PrivateKey, DNSKEY]], + zsks: List[Tuple[PrivateKey, DNSKEY]], + inception: Optional[Union[datetime, str, int, float]] = None, + expiration: Optional[Union[datetime, str, int, float]] = None, + lifetime: Optional[int] = None, + policy: Optional[Policy] = None, + origin: Optional[dns.name.Name] = None, +) -> None: + """Default RRset signer""" + + if rrset.rdtype in set( + [ + dns.rdatatype.RdataType.DNSKEY, + dns.rdatatype.RdataType.CDS, + dns.rdatatype.RdataType.CDNSKEY, + ] + ): + keys = ksks + else: + keys = zsks + + for private_key, dnskey in keys: + rrsig = dns.dnssec.sign( + rrset=rrset, + private_key=private_key, + dnskey=dnskey, + inception=inception, + expiration=expiration, + lifetime=lifetime, + signer=signer, + policy=policy, + origin=origin, + ) + txn.add(rrset.name, rrset.ttl, rrsig) + + +def sign_zone( + zone: dns.zone.Zone, + txn: 
Optional[dns.transaction.Transaction] = None, + keys: Optional[List[Tuple[PrivateKey, DNSKEY]]] = None, + add_dnskey: bool = True, + dnskey_ttl: Optional[int] = None, + inception: Optional[Union[datetime, str, int, float]] = None, + expiration: Optional[Union[datetime, str, int, float]] = None, + lifetime: Optional[int] = None, + nsec3: Optional[NSEC3PARAM] = None, + rrset_signer: Optional[RRsetSigner] = None, + policy: Optional[Policy] = None, +) -> None: + """Sign zone. + + *zone*, a ``dns.zone.Zone``, the zone to sign. + + *txn*, a ``dns.transaction.Transaction``, an optional transaction to use for + signing. + + *keys*, a list of (``PrivateKey``, ``DNSKEY``) tuples, to use for signing. KSK/ZSK + roles are assigned automatically if the SEP flag is used, otherwise all RRsets are + signed by all keys. + + *add_dnskey*, a ``bool``. If ``True``, the default, all specified DNSKEYs are + automatically added to the zone on signing. + + *dnskey_ttl*, a``int``, specifies the TTL for DNSKEY RRs. If not specified the TTL + of the existing DNSKEY RRset used or the TTL of the SOA RRset. + + *inception*, a ``datetime``, ``str``, ``int``, ``float`` or ``None``, the signature + inception time. If ``None``, the current time is used. If a ``str``, the format is + "YYYYMMDDHHMMSS" or alternatively the number of seconds since the UNIX epoch in text + form; this is the same the RRSIG rdata's text form. Values of type `int` or `float` + are interpreted as seconds since the UNIX epoch. + + *expiration*, a ``datetime``, ``str``, ``int``, ``float`` or ``None``, the signature + expiration time. If ``None``, the expiration time will be the inception time plus + the value of the *lifetime* parameter. See the description of *inception* above for + how the various parameter types are interpreted. + + *lifetime*, an ``int`` or ``None``, the signature lifetime in seconds. This + parameter is only meaningful if *expiration* is ``None``. 
+ + *nsec3*, a ``NSEC3PARAM`` Rdata, configures signing using NSEC3. Not yet + implemented. + + *rrset_signer*, a ``Callable``, an optional function for signing RRsets. The + function requires two arguments: transaction and RRset. If the not specified, + ``dns.dnssec.default_rrset_signer`` will be used. + + Returns ``None``. + """ + + ksks = [] + zsks = [] + + # if we have both KSKs and ZSKs, split by SEP flag. if not, sign all + # records with all keys + if keys: + for key in keys: + if key[1].flags & Flag.SEP: + ksks.append(key) + else: + zsks.append(key) + if not ksks: + ksks = keys + if not zsks: + zsks = keys + else: + keys = [] + + if txn: + cm: contextlib.AbstractContextManager = contextlib.nullcontext(txn) + else: + cm = zone.writer() + + with cm as _txn: + if add_dnskey: + if dnskey_ttl is None: + dnskey = _txn.get(zone.origin, dns.rdatatype.DNSKEY) + if dnskey: + dnskey_ttl = dnskey.ttl + else: + soa = _txn.get(zone.origin, dns.rdatatype.SOA) + dnskey_ttl = soa.ttl + for _, dnskey in keys: + _txn.add(zone.origin, dnskey_ttl, dnskey) + + if nsec3: + raise NotImplementedError("Signing with NSEC3 not yet implemented") + else: + _rrset_signer = rrset_signer or functools.partial( + default_rrset_signer, + signer=zone.origin, + ksks=ksks, + zsks=zsks, + inception=inception, + expiration=expiration, + lifetime=lifetime, + policy=policy, + origin=zone.origin, + ) + return _sign_zone_nsec(zone, _txn, _rrset_signer) + + +def _sign_zone_nsec( + zone: dns.zone.Zone, + txn: dns.transaction.Transaction, + rrset_signer: Optional[RRsetSigner] = None, +) -> None: + """NSEC zone signer""" + + def _txn_add_nsec( + txn: dns.transaction.Transaction, + name: dns.name.Name, + next_secure: Optional[dns.name.Name], + rdclass: dns.rdataclass.RdataClass, + ttl: int, + rrset_signer: Optional[RRsetSigner] = None, + ) -> None: + """NSEC zone signer helper""" + mandatory_types = set( + [dns.rdatatype.RdataType.RRSIG, dns.rdatatype.RdataType.NSEC] + ) + node = txn.get_node(name) + if 
node and next_secure: + types = ( + set([rdataset.rdtype for rdataset in node.rdatasets]) | mandatory_types + ) + windows = Bitmap.from_rdtypes(list(types)) + rrset = dns.rrset.from_rdata( + name, + ttl, + NSEC( + rdclass=rdclass, + rdtype=dns.rdatatype.RdataType.NSEC, + next=next_secure, + windows=windows, + ), + ) + txn.add(rrset) + if rrset_signer: + rrset_signer(txn, rrset) + + rrsig_ttl = zone.get_soa().minimum + delegation = None + last_secure = None + + for name in sorted(txn.iterate_names()): + if delegation and name.is_subdomain(delegation): + # names below delegations are not secure + continue + elif txn.get(name, dns.rdatatype.NS) and name != zone.origin: + # inside delegation + delegation = name + else: + # outside delegation + delegation = None + + if rrset_signer: + node = txn.get_node(name) + if node: + for rdataset in node.rdatasets: + if rdataset.rdtype == dns.rdatatype.RRSIG: + # do not sign RRSIGs + continue + elif delegation and rdataset.rdtype != dns.rdatatype.DS: + # do not sign delegations except DS records + continue + else: + rrset = dns.rrset.from_rdata(name, rdataset.ttl, *rdataset) + rrset_signer(txn, rrset) + + # We need "is not None" as the empty name is False because its length is 0. 
+ if last_secure is not None: + _txn_add_nsec(txn, last_secure, name, zone.rdclass, rrsig_ttl, rrset_signer) + last_secure = name + + if last_secure: + _txn_add_nsec( + txn, last_secure, zone.origin, zone.rdclass, rrsig_ttl, rrset_signer + ) + + def _need_pyca(*args, **kwargs): raise ImportError( - "DNSSEC validation requires " + "python cryptography" + "DNSSEC validation requires python cryptography" ) # pragma: no cover try: from cryptography.exceptions import InvalidSignature - from cryptography.hazmat.backends import default_backend - from cryptography.hazmat.primitives import hashes, serialization - from cryptography.hazmat.primitives.asymmetric import padding - from cryptography.hazmat.primitives.asymmetric import utils - from cryptography.hazmat.primitives.asymmetric import dsa - from cryptography.hazmat.primitives.asymmetric import ec - from cryptography.hazmat.primitives.asymmetric import ed25519 - from cryptography.hazmat.primitives.asymmetric import ed448 - from cryptography.hazmat.primitives.asymmetric import rsa + from cryptography.hazmat.primitives.asymmetric import dsa # pylint: disable=W0611 + from cryptography.hazmat.primitives.asymmetric import ec # pylint: disable=W0611 + from cryptography.hazmat.primitives.asymmetric import ed448 # pylint: disable=W0611 + from cryptography.hazmat.primitives.asymmetric import rsa # pylint: disable=W0611 + from cryptography.hazmat.primitives.asymmetric import ( # pylint: disable=W0611 + ed25519, + ) + + from dns.dnssecalgs import ( # pylint: disable=C0412 + get_algorithm_cls, + get_algorithm_cls_from_dnskey, + ) + from dns.dnssecalgs.base import GenericPrivateKey, GenericPublicKey except ImportError: # pragma: no cover validate = _need_pyca validate_rrsig = _need_pyca diff --git a/lib/dns/dnssecalgs/__init__.py b/lib/dns/dnssecalgs/__init__.py new file mode 100644 index 00000000..d1ffd519 --- /dev/null +++ b/lib/dns/dnssecalgs/__init__.py @@ -0,0 +1,121 @@ +from typing import Dict, Optional, Tuple, Type, Union + 
+import dns.name + +try: + from dns.dnssecalgs.base import GenericPrivateKey + from dns.dnssecalgs.dsa import PrivateDSA, PrivateDSANSEC3SHA1 + from dns.dnssecalgs.ecdsa import PrivateECDSAP256SHA256, PrivateECDSAP384SHA384 + from dns.dnssecalgs.eddsa import PrivateED448, PrivateED25519 + from dns.dnssecalgs.rsa import ( + PrivateRSAMD5, + PrivateRSASHA1, + PrivateRSASHA1NSEC3SHA1, + PrivateRSASHA256, + PrivateRSASHA512, + ) + + _have_cryptography = True +except ImportError: + _have_cryptography = False + +from dns.dnssectypes import Algorithm +from dns.exception import UnsupportedAlgorithm +from dns.rdtypes.ANY.DNSKEY import DNSKEY + +AlgorithmPrefix = Optional[Union[bytes, dns.name.Name]] + +algorithms: Dict[Tuple[Algorithm, AlgorithmPrefix], Type[GenericPrivateKey]] = {} +if _have_cryptography: + algorithms.update( + { + (Algorithm.RSAMD5, None): PrivateRSAMD5, + (Algorithm.DSA, None): PrivateDSA, + (Algorithm.RSASHA1, None): PrivateRSASHA1, + (Algorithm.DSANSEC3SHA1, None): PrivateDSANSEC3SHA1, + (Algorithm.RSASHA1NSEC3SHA1, None): PrivateRSASHA1NSEC3SHA1, + (Algorithm.RSASHA256, None): PrivateRSASHA256, + (Algorithm.RSASHA512, None): PrivateRSASHA512, + (Algorithm.ECDSAP256SHA256, None): PrivateECDSAP256SHA256, + (Algorithm.ECDSAP384SHA384, None): PrivateECDSAP384SHA384, + (Algorithm.ED25519, None): PrivateED25519, + (Algorithm.ED448, None): PrivateED448, + } + ) + + +def get_algorithm_cls( + algorithm: Union[int, str], prefix: AlgorithmPrefix = None +) -> Type[GenericPrivateKey]: + """Get Private Key class from Algorithm. + + *algorithm*, a ``str`` or ``int`` specifying the DNSKEY algorithm. + + Raises ``UnsupportedAlgorithm`` if the algorithm is unknown. 
+
+    Returns a ``dns.dnssecalgs.GenericPrivateKey``
+    """
+    algorithm = Algorithm.make(algorithm)
+    cls = algorithms.get((algorithm, prefix))
+    if cls:
+        return cls
+    raise UnsupportedAlgorithm(
+        'algorithm "%s" not supported by dnspython' % Algorithm.to_text(algorithm)
+    )
+
+
+def get_algorithm_cls_from_dnskey(dnskey: DNSKEY) -> Type[GenericPrivateKey]:
+    """Get Private Key class from DNSKEY.
+
+    *dnskey*, a ``DNSKEY`` to get Algorithm class for.
+
+    Raises ``UnsupportedAlgorithm`` if the algorithm is unknown.
+
+    Returns a ``dns.dnssecalgs.GenericPrivateKey``
+    """
+    prefix: AlgorithmPrefix = None
+    if dnskey.algorithm == Algorithm.PRIVATEDNS:
+        prefix, _ = dns.name.from_wire(dnskey.key, 0)
+    elif dnskey.algorithm == Algorithm.PRIVATEOID:
+        length = int(dnskey.key[0])
+        prefix = dnskey.key[0 : length + 1]
+    return get_algorithm_cls(dnskey.algorithm, prefix)
+
+
+def register_algorithm_cls(
+    algorithm: Union[int, str],
+    algorithm_cls: Type[GenericPrivateKey],
+    name: Optional[Union[dns.name.Name, str]] = None,
+    oid: Optional[bytes] = None,
+) -> None:
+    """Register Algorithm Private Key class.
+
+    *algorithm*, a ``str`` or ``int`` specifying the DNSKEY algorithm.
+
+    *algorithm_cls*: A `GenericPrivateKey` class.
+
+    *name*, an optional ``dns.name.Name`` or ``str``, for PRIVATEDNS algorithms.
+
+    *oid*: an optional BER-encoded `bytes` for PRIVATEOID algorithms.
+
+    Raises ``ValueError`` if a name or oid is specified incorrectly.
+ """ + if not issubclass(algorithm_cls, GenericPrivateKey): + raise TypeError("Invalid algorithm class") + algorithm = Algorithm.make(algorithm) + prefix: AlgorithmPrefix = None + if algorithm == Algorithm.PRIVATEDNS: + if name is None: + raise ValueError("Name required for PRIVATEDNS algorithms") + if isinstance(name, str): + name = dns.name.from_text(name) + prefix = name + elif algorithm == Algorithm.PRIVATEOID: + if oid is None: + raise ValueError("OID required for PRIVATEOID algorithms") + prefix = bytes([len(oid)]) + oid + elif name: + raise ValueError("Name only supported for PRIVATEDNS algorithm") + elif oid: + raise ValueError("OID only supported for PRIVATEOID algorithm") + algorithms[(algorithm, prefix)] = algorithm_cls diff --git a/lib/dns/dnssecalgs/base.py b/lib/dns/dnssecalgs/base.py new file mode 100644 index 00000000..e990575a --- /dev/null +++ b/lib/dns/dnssecalgs/base.py @@ -0,0 +1,84 @@ +from abc import ABC, abstractmethod # pylint: disable=no-name-in-module +from typing import Any, Optional, Type + +import dns.rdataclass +import dns.rdatatype +from dns.dnssectypes import Algorithm +from dns.exception import AlgorithmKeyMismatch +from dns.rdtypes.ANY.DNSKEY import DNSKEY +from dns.rdtypes.dnskeybase import Flag + + +class GenericPublicKey(ABC): + algorithm: Algorithm + + @abstractmethod + def __init__(self, key: Any) -> None: + pass + + @abstractmethod + def verify(self, signature: bytes, data: bytes) -> None: + """Verify signed DNSSEC data""" + + @abstractmethod + def encode_key_bytes(self) -> bytes: + """Encode key as bytes for DNSKEY""" + + @classmethod + def _ensure_algorithm_key_combination(cls, key: DNSKEY) -> None: + if key.algorithm != cls.algorithm: + raise AlgorithmKeyMismatch + + def to_dnskey(self, flags: int = Flag.ZONE, protocol: int = 3) -> DNSKEY: + """Return public key as DNSKEY""" + return DNSKEY( + rdclass=dns.rdataclass.IN, + rdtype=dns.rdatatype.DNSKEY, + flags=flags, + protocol=protocol, + algorithm=self.algorithm, + 
key=self.encode_key_bytes(), + ) + + @classmethod + @abstractmethod + def from_dnskey(cls, key: DNSKEY) -> "GenericPublicKey": + """Create public key from DNSKEY""" + + @classmethod + @abstractmethod + def from_pem(cls, public_pem: bytes) -> "GenericPublicKey": + """Create public key from PEM-encoded SubjectPublicKeyInfo as specified + in RFC 5280""" + + @abstractmethod + def to_pem(self) -> bytes: + """Return public-key as PEM-encoded SubjectPublicKeyInfo as specified + in RFC 5280""" + + +class GenericPrivateKey(ABC): + public_cls: Type[GenericPublicKey] + + @abstractmethod + def __init__(self, key: Any) -> None: + pass + + @abstractmethod + def sign(self, data: bytes, verify: bool = False) -> bytes: + """Sign DNSSEC data""" + + @abstractmethod + def public_key(self) -> "GenericPublicKey": + """Return public key instance""" + + @classmethod + @abstractmethod + def from_pem( + cls, private_pem: bytes, password: Optional[bytes] = None + ) -> "GenericPrivateKey": + """Create private key from PEM-encoded PKCS#8""" + + @abstractmethod + def to_pem(self, password: Optional[bytes] = None) -> bytes: + """Return private key as PEM-encoded PKCS#8""" diff --git a/lib/dns/dnssecalgs/cryptography.py b/lib/dns/dnssecalgs/cryptography.py new file mode 100644 index 00000000..5a31a812 --- /dev/null +++ b/lib/dns/dnssecalgs/cryptography.py @@ -0,0 +1,68 @@ +from typing import Any, Optional, Type + +from cryptography.hazmat.primitives import serialization + +from dns.dnssecalgs.base import GenericPrivateKey, GenericPublicKey +from dns.exception import AlgorithmKeyMismatch + + +class CryptographyPublicKey(GenericPublicKey): + key: Any = None + key_cls: Any = None + + def __init__(self, key: Any) -> None: # pylint: disable=super-init-not-called + if self.key_cls is None: + raise TypeError("Undefined private key class") + if not isinstance( # pylint: disable=isinstance-second-argument-not-valid-type + key, self.key_cls + ): + raise AlgorithmKeyMismatch + self.key = key + + 
@classmethod + def from_pem(cls, public_pem: bytes) -> "GenericPublicKey": + key = serialization.load_pem_public_key(public_pem) + return cls(key=key) + + def to_pem(self) -> bytes: + return self.key.public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) + + +class CryptographyPrivateKey(GenericPrivateKey): + key: Any = None + key_cls: Any = None + public_cls: Type[CryptographyPublicKey] + + def __init__(self, key: Any) -> None: # pylint: disable=super-init-not-called + if self.key_cls is None: + raise TypeError("Undefined private key class") + if not isinstance( # pylint: disable=isinstance-second-argument-not-valid-type + key, self.key_cls + ): + raise AlgorithmKeyMismatch + self.key = key + + def public_key(self) -> "CryptographyPublicKey": + return self.public_cls(key=self.key.public_key()) + + @classmethod + def from_pem( + cls, private_pem: bytes, password: Optional[bytes] = None + ) -> "GenericPrivateKey": + key = serialization.load_pem_private_key(private_pem, password=password) + return cls(key=key) + + def to_pem(self, password: Optional[bytes] = None) -> bytes: + encryption_algorithm: serialization.KeySerializationEncryption + if password: + encryption_algorithm = serialization.BestAvailableEncryption(password) + else: + encryption_algorithm = serialization.NoEncryption() + return self.key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=encryption_algorithm, + ) diff --git a/lib/dns/dnssecalgs/dsa.py b/lib/dns/dnssecalgs/dsa.py new file mode 100644 index 00000000..0fe4690d --- /dev/null +++ b/lib/dns/dnssecalgs/dsa.py @@ -0,0 +1,101 @@ +import struct + +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives.asymmetric import dsa, utils + +from dns.dnssecalgs.cryptography import CryptographyPrivateKey, CryptographyPublicKey +from 
dns.dnssectypes import Algorithm +from dns.rdtypes.ANY.DNSKEY import DNSKEY + + +class PublicDSA(CryptographyPublicKey): + key: dsa.DSAPublicKey + key_cls = dsa.DSAPublicKey + algorithm = Algorithm.DSA + chosen_hash = hashes.SHA1() + + def verify(self, signature: bytes, data: bytes) -> None: + sig_r = signature[1:21] + sig_s = signature[21:] + sig = utils.encode_dss_signature( + int.from_bytes(sig_r, "big"), int.from_bytes(sig_s, "big") + ) + self.key.verify(sig, data, self.chosen_hash) + + def encode_key_bytes(self) -> bytes: + """Encode a public key per RFC 2536, section 2.""" + pn = self.key.public_numbers() + dsa_t = (self.key.key_size // 8 - 64) // 8 + if dsa_t > 8: + raise ValueError("unsupported DSA key size") + octets = 64 + dsa_t * 8 + res = struct.pack("!B", dsa_t) + res += pn.parameter_numbers.q.to_bytes(20, "big") + res += pn.parameter_numbers.p.to_bytes(octets, "big") + res += pn.parameter_numbers.g.to_bytes(octets, "big") + res += pn.y.to_bytes(octets, "big") + return res + + @classmethod + def from_dnskey(cls, key: DNSKEY) -> "PublicDSA": + cls._ensure_algorithm_key_combination(key) + keyptr = key.key + (t,) = struct.unpack("!B", keyptr[0:1]) + keyptr = keyptr[1:] + octets = 64 + t * 8 + dsa_q = keyptr[0:20] + keyptr = keyptr[20:] + dsa_p = keyptr[0:octets] + keyptr = keyptr[octets:] + dsa_g = keyptr[0:octets] + keyptr = keyptr[octets:] + dsa_y = keyptr[0:octets] + return cls( + key=dsa.DSAPublicNumbers( # type: ignore + int.from_bytes(dsa_y, "big"), + dsa.DSAParameterNumbers( + int.from_bytes(dsa_p, "big"), + int.from_bytes(dsa_q, "big"), + int.from_bytes(dsa_g, "big"), + ), + ).public_key(default_backend()), + ) + + +class PrivateDSA(CryptographyPrivateKey): + key: dsa.DSAPrivateKey + key_cls = dsa.DSAPrivateKey + public_cls = PublicDSA + + def sign(self, data: bytes, verify: bool = False) -> bytes: + """Sign using a private key per RFC 2536, section 3.""" + public_dsa_key = self.key.public_key() + if public_dsa_key.key_size > 1024: + raise 
ValueError("DSA key size overflow") + der_signature = self.key.sign(data, self.public_cls.chosen_hash) + dsa_r, dsa_s = utils.decode_dss_signature(der_signature) + dsa_t = (public_dsa_key.key_size // 8 - 64) // 8 + octets = 20 + signature = ( + struct.pack("!B", dsa_t) + + int.to_bytes(dsa_r, length=octets, byteorder="big") + + int.to_bytes(dsa_s, length=octets, byteorder="big") + ) + if verify: + self.public_key().verify(signature, data) + return signature + + @classmethod + def generate(cls, key_size: int) -> "PrivateDSA": + return cls( + key=dsa.generate_private_key(key_size=key_size), + ) + + +class PublicDSANSEC3SHA1(PublicDSA): + algorithm = Algorithm.DSANSEC3SHA1 + + +class PrivateDSANSEC3SHA1(PrivateDSA): + public_cls = PublicDSANSEC3SHA1 diff --git a/lib/dns/dnssecalgs/ecdsa.py b/lib/dns/dnssecalgs/ecdsa.py new file mode 100644 index 00000000..a31d79f2 --- /dev/null +++ b/lib/dns/dnssecalgs/ecdsa.py @@ -0,0 +1,89 @@ +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives.asymmetric import ec, utils + +from dns.dnssecalgs.cryptography import CryptographyPrivateKey, CryptographyPublicKey +from dns.dnssectypes import Algorithm +from dns.rdtypes.ANY.DNSKEY import DNSKEY + + +class PublicECDSA(CryptographyPublicKey): + key: ec.EllipticCurvePublicKey + key_cls = ec.EllipticCurvePublicKey + algorithm: Algorithm + chosen_hash: hashes.HashAlgorithm + curve: ec.EllipticCurve + octets: int + + def verify(self, signature: bytes, data: bytes) -> None: + sig_r = signature[0 : self.octets] + sig_s = signature[self.octets :] + sig = utils.encode_dss_signature( + int.from_bytes(sig_r, "big"), int.from_bytes(sig_s, "big") + ) + self.key.verify(sig, data, ec.ECDSA(self.chosen_hash)) + + def encode_key_bytes(self) -> bytes: + """Encode a public key per RFC 6605, section 4.""" + pn = self.key.public_numbers() + return pn.x.to_bytes(self.octets, "big") + pn.y.to_bytes(self.octets, "big") + 
+ @classmethod + def from_dnskey(cls, key: DNSKEY) -> "PublicECDSA": + cls._ensure_algorithm_key_combination(key) + ecdsa_x = key.key[0 : cls.octets] + ecdsa_y = key.key[cls.octets : cls.octets * 2] + return cls( + key=ec.EllipticCurvePublicNumbers( + curve=cls.curve, + x=int.from_bytes(ecdsa_x, "big"), + y=int.from_bytes(ecdsa_y, "big"), + ).public_key(default_backend()), + ) + + +class PrivateECDSA(CryptographyPrivateKey): + key: ec.EllipticCurvePrivateKey + key_cls = ec.EllipticCurvePrivateKey + public_cls = PublicECDSA + + def sign(self, data: bytes, verify: bool = False) -> bytes: + """Sign using a private key per RFC 6605, section 4.""" + der_signature = self.key.sign(data, ec.ECDSA(self.public_cls.chosen_hash)) + dsa_r, dsa_s = utils.decode_dss_signature(der_signature) + signature = int.to_bytes( + dsa_r, length=self.public_cls.octets, byteorder="big" + ) + int.to_bytes(dsa_s, length=self.public_cls.octets, byteorder="big") + if verify: + self.public_key().verify(signature, data) + return signature + + @classmethod + def generate(cls) -> "PrivateECDSA": + return cls( + key=ec.generate_private_key( + curve=cls.public_cls.curve, backend=default_backend() + ), + ) + + +class PublicECDSAP256SHA256(PublicECDSA): + algorithm = Algorithm.ECDSAP256SHA256 + chosen_hash = hashes.SHA256() + curve = ec.SECP256R1() + octets = 32 + + +class PrivateECDSAP256SHA256(PrivateECDSA): + public_cls = PublicECDSAP256SHA256 + + +class PublicECDSAP384SHA384(PublicECDSA): + algorithm = Algorithm.ECDSAP384SHA384 + chosen_hash = hashes.SHA384() + curve = ec.SECP384R1() + octets = 48 + + +class PrivateECDSAP384SHA384(PrivateECDSA): + public_cls = PublicECDSAP384SHA384 diff --git a/lib/dns/dnssecalgs/eddsa.py b/lib/dns/dnssecalgs/eddsa.py new file mode 100644 index 00000000..70505342 --- /dev/null +++ b/lib/dns/dnssecalgs/eddsa.py @@ -0,0 +1,65 @@ +from typing import Type + +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import 
ed448, ed25519 + +from dns.dnssecalgs.cryptography import CryptographyPrivateKey, CryptographyPublicKey +from dns.dnssectypes import Algorithm +from dns.rdtypes.ANY.DNSKEY import DNSKEY + + +class PublicEDDSA(CryptographyPublicKey): + def verify(self, signature: bytes, data: bytes) -> None: + self.key.verify(signature, data) + + def encode_key_bytes(self) -> bytes: + """Encode a public key per RFC 8080, section 3.""" + return self.key.public_bytes( + encoding=serialization.Encoding.Raw, format=serialization.PublicFormat.Raw + ) + + @classmethod + def from_dnskey(cls, key: DNSKEY) -> "PublicEDDSA": + cls._ensure_algorithm_key_combination(key) + return cls( + key=cls.key_cls.from_public_bytes(key.key), + ) + + +class PrivateEDDSA(CryptographyPrivateKey): + public_cls: Type[PublicEDDSA] + + def sign(self, data: bytes, verify: bool = False) -> bytes: + """Sign using a private key per RFC 8080, section 4.""" + signature = self.key.sign(data) + if verify: + self.public_key().verify(signature, data) + return signature + + @classmethod + def generate(cls) -> "PrivateEDDSA": + return cls(key=cls.key_cls.generate()) + + +class PublicED25519(PublicEDDSA): + key: ed25519.Ed25519PublicKey + key_cls = ed25519.Ed25519PublicKey + algorithm = Algorithm.ED25519 + + +class PrivateED25519(PrivateEDDSA): + key: ed25519.Ed25519PrivateKey + key_cls = ed25519.Ed25519PrivateKey + public_cls = PublicED25519 + + +class PublicED448(PublicEDDSA): + key: ed448.Ed448PublicKey + key_cls = ed448.Ed448PublicKey + algorithm = Algorithm.ED448 + + +class PrivateED448(PrivateEDDSA): + key: ed448.Ed448PrivateKey + key_cls = ed448.Ed448PrivateKey + public_cls = PublicED448 diff --git a/lib/dns/dnssecalgs/rsa.py b/lib/dns/dnssecalgs/rsa.py new file mode 100644 index 00000000..e95dcf1d --- /dev/null +++ b/lib/dns/dnssecalgs/rsa.py @@ -0,0 +1,119 @@ +import math +import struct + +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import hashes +from 
cryptography.hazmat.primitives.asymmetric import padding, rsa + +from dns.dnssecalgs.cryptography import CryptographyPrivateKey, CryptographyPublicKey +from dns.dnssectypes import Algorithm +from dns.rdtypes.ANY.DNSKEY import DNSKEY + + +class PublicRSA(CryptographyPublicKey): + key: rsa.RSAPublicKey + key_cls = rsa.RSAPublicKey + algorithm: Algorithm + chosen_hash: hashes.HashAlgorithm + + def verify(self, signature: bytes, data: bytes) -> None: + self.key.verify(signature, data, padding.PKCS1v15(), self.chosen_hash) + + def encode_key_bytes(self) -> bytes: + """Encode a public key per RFC 3110, section 2.""" + pn = self.key.public_numbers() + _exp_len = math.ceil(int.bit_length(pn.e) / 8) + exp = int.to_bytes(pn.e, length=_exp_len, byteorder="big") + if _exp_len > 255: + exp_header = b"\0" + struct.pack("!H", _exp_len) + else: + exp_header = struct.pack("!B", _exp_len) + if pn.n.bit_length() < 512 or pn.n.bit_length() > 4096: + raise ValueError("unsupported RSA key length") + return exp_header + exp + pn.n.to_bytes((pn.n.bit_length() + 7) // 8, "big") + + @classmethod + def from_dnskey(cls, key: DNSKEY) -> "PublicRSA": + cls._ensure_algorithm_key_combination(key) + keyptr = key.key + (bytes_,) = struct.unpack("!B", keyptr[0:1]) + keyptr = keyptr[1:] + if bytes_ == 0: + (bytes_,) = struct.unpack("!H", keyptr[0:2]) + keyptr = keyptr[2:] + rsa_e = keyptr[0:bytes_] + rsa_n = keyptr[bytes_:] + return cls( + key=rsa.RSAPublicNumbers( + int.from_bytes(rsa_e, "big"), int.from_bytes(rsa_n, "big") + ).public_key(default_backend()) + ) + + +class PrivateRSA(CryptographyPrivateKey): + key: rsa.RSAPrivateKey + key_cls = rsa.RSAPrivateKey + public_cls = PublicRSA + default_public_exponent = 65537 + + def sign(self, data: bytes, verify: bool = False) -> bytes: + """Sign using a private key per RFC 3110, section 3.""" + signature = self.key.sign(data, padding.PKCS1v15(), self.public_cls.chosen_hash) + if verify: + self.public_key().verify(signature, data) + return signature + + 
@classmethod + def generate(cls, key_size: int) -> "PrivateRSA": + return cls( + key=rsa.generate_private_key( + public_exponent=cls.default_public_exponent, + key_size=key_size, + backend=default_backend(), + ) + ) + + +class PublicRSAMD5(PublicRSA): + algorithm = Algorithm.RSAMD5 + chosen_hash = hashes.MD5() + + +class PrivateRSAMD5(PrivateRSA): + public_cls = PublicRSAMD5 + + +class PublicRSASHA1(PublicRSA): + algorithm = Algorithm.RSASHA1 + chosen_hash = hashes.SHA1() + + +class PrivateRSASHA1(PrivateRSA): + public_cls = PublicRSASHA1 + + +class PublicRSASHA1NSEC3SHA1(PublicRSA): + algorithm = Algorithm.RSASHA1NSEC3SHA1 + chosen_hash = hashes.SHA1() + + +class PrivateRSASHA1NSEC3SHA1(PrivateRSA): + public_cls = PublicRSASHA1NSEC3SHA1 + + +class PublicRSASHA256(PublicRSA): + algorithm = Algorithm.RSASHA256 + chosen_hash = hashes.SHA256() + + +class PrivateRSASHA256(PrivateRSA): + public_cls = PublicRSASHA256 + + +class PublicRSASHA512(PublicRSA): + algorithm = Algorithm.RSASHA512 + chosen_hash = hashes.SHA512() + + +class PrivateRSASHA512(PrivateRSA): + public_cls = PublicRSASHA512 diff --git a/lib/dns/edns.py b/lib/dns/edns.py index 64436cde..f05baac4 100644 --- a/lib/dns/edns.py +++ b/lib/dns/edns.py @@ -17,11 +17,10 @@ """EDNS Options""" -from typing import Any, Dict, Optional, Union - import math import socket import struct +from typing import Any, Dict, Optional, Union import dns.enum import dns.inet @@ -380,7 +379,7 @@ class EDEOption(Option): # lgtm[py/missing-equals] def from_wire_parser( cls, otype: Union[OptionType, str], parser: "dns.wire.Parser" ) -> Option: - the_code = EDECode.make(parser.get_uint16()) + code = EDECode.make(parser.get_uint16()) text = parser.get_remaining() if text: @@ -390,7 +389,7 @@ class EDEOption(Option): # lgtm[py/missing-equals] else: btext = None - return cls(the_code, btext) + return cls(code, btext) _type_to_class: Dict[OptionType, Any] = { @@ -424,8 +423,8 @@ def option_from_wire_parser( Returns an instance of a subclass 
of ``dns.edns.Option``. """ - the_otype = OptionType.make(otype) - cls = get_option_class(the_otype) + otype = OptionType.make(otype) + cls = get_option_class(otype) return cls.from_wire_parser(otype, parser) diff --git a/lib/dns/entropy.py b/lib/dns/entropy.py index 5e1f5e23..4dcdc627 100644 --- a/lib/dns/entropy.py +++ b/lib/dns/entropy.py @@ -15,17 +15,15 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. -from typing import Any, Optional - -import os import hashlib +import os import random import threading import time +from typing import Any, Optional class EntropyPool: - # This is an entropy pool for Python implementations that do not # have a working SystemRandom. I'm not sure there are any, but # leaving this code doesn't hurt anything as the library code diff --git a/lib/dns/enum.py b/lib/dns/enum.py index b5a4aed8..71461f17 100644 --- a/lib/dns/enum.py +++ b/lib/dns/enum.py @@ -16,18 +16,31 @@ # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 
import enum +from typing import Type, TypeVar, Union + +TIntEnum = TypeVar("TIntEnum", bound="IntEnum") class IntEnum(enum.IntEnum): @classmethod - def _check_value(cls, value): - max = cls._maximum() - if value < 0 or value > max: - name = cls._short_name() - raise ValueError(f"{name} must be between >= 0 and <= {max}") + def _missing_(cls, value): + cls._check_value(value) + val = int.__new__(cls, value) + val._name_ = cls._extra_to_text(value, None) or f"{cls._prefix()}{value}" + val._value_ = value + return val @classmethod - def from_text(cls, text): + def _check_value(cls, value): + max = cls._maximum() + if not isinstance(value, int): + raise TypeError + if value < 0 or value > max: + name = cls._short_name() + raise ValueError(f"{name} must be an int between >= 0 and <= {max}") + + @classmethod + def from_text(cls: Type[TIntEnum], text: str) -> TIntEnum: text = text.upper() try: return cls[text] @@ -47,7 +60,7 @@ class IntEnum(enum.IntEnum): raise cls._unknown_exception_class() @classmethod - def to_text(cls, value): + def to_text(cls: Type[TIntEnum], value: int) -> str: cls._check_value(value) try: text = cls(value).name @@ -59,7 +72,7 @@ class IntEnum(enum.IntEnum): return text @classmethod - def make(cls, value): + def make(cls: Type[TIntEnum], value: Union[int, str]) -> TIntEnum: """Convert text or a value into an enumerated type, if possible. *value*, the ``int`` or ``str`` to convert. 
@@ -76,10 +89,7 @@ class IntEnum(enum.IntEnum): if isinstance(value, str): return cls.from_text(value) cls._check_value(value) - try: - return cls(value) - except ValueError: - return value + return cls(value) @classmethod def _maximum(cls): diff --git a/lib/dns/exception.py b/lib/dns/exception.py index 4b1481d1..6982373d 100644 --- a/lib/dns/exception.py +++ b/lib/dns/exception.py @@ -140,6 +140,22 @@ class Timeout(DNSException): super().__init__(*args, **kwargs) +class UnsupportedAlgorithm(DNSException): + """The DNSSEC algorithm is not supported.""" + + +class AlgorithmKeyMismatch(UnsupportedAlgorithm): + """The DNSSEC algorithm is not supported for the given key type.""" + + +class ValidationFailure(DNSException): + """The DNSSEC signature is invalid.""" + + +class DeniedByPolicy(DNSException): + """Denied by DNSSEC policy.""" + + class ExceptionWrapper: def __init__(self, exception_class): self.exception_class = exception_class diff --git a/lib/dns/flags.py b/lib/dns/flags.py index b21b8e3b..4c60be13 100644 --- a/lib/dns/flags.py +++ b/lib/dns/flags.py @@ -17,9 +17,8 @@ """DNS Message Flags.""" -from typing import Any - import enum +from typing import Any # Standard DNS flags diff --git a/lib/dns/immutable.py b/lib/dns/immutable.py index 38fbe597..cab8d6fb 100644 --- a/lib/dns/immutable.py +++ b/lib/dns/immutable.py @@ -1,8 +1,7 @@ # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license -from typing import Any - import collections.abc +from typing import Any from dns._immutable_ctx import immutable diff --git a/lib/dns/inet.py b/lib/dns/inet.py index 11180c96..02e925c6 100644 --- a/lib/dns/inet.py +++ b/lib/dns/inet.py @@ -17,14 +17,12 @@ """Generic Internet address helper functions.""" -from typing import Any, Optional, Tuple - import socket +from typing import Any, Optional, Tuple import dns.ipv4 import dns.ipv6 - # We assume that AF_INET and AF_INET6 are always defined. 
We keep # these here for the benefit of any old code (unlikely though that # is!). @@ -171,3 +169,12 @@ def low_level_address_tuple( return tup else: raise NotImplementedError(f"unknown address family {af}") + + +def any_for_af(af): + """Return the 'any' address for the specified address family.""" + if af == socket.AF_INET: + return "0.0.0.0" + elif af == socket.AF_INET6: + return "::" + raise NotImplementedError(f"unknown address family {af}") diff --git a/lib/dns/ipv4.py b/lib/dns/ipv4.py index b8e148f3..f549150a 100644 --- a/lib/dns/ipv4.py +++ b/lib/dns/ipv4.py @@ -17,9 +17,8 @@ """IPv4 helper functions.""" -from typing import Union - import struct +from typing import Union import dns.exception diff --git a/lib/dns/ipv6.py b/lib/dns/ipv6.py index fbd49623..0cc3d868 100644 --- a/lib/dns/ipv6.py +++ b/lib/dns/ipv6.py @@ -17,10 +17,9 @@ """IPv6 helper functions.""" -from typing import List, Union - -import re import binascii +import re +from typing import List, Union import dns.exception import dns.ipv4 diff --git a/lib/dns/message.py b/lib/dns/message.py index 8250db3b..daae6363 100644 --- a/lib/dns/message.py +++ b/lib/dns/message.py @@ -17,30 +17,29 @@ """DNS Messages""" -from typing import Any, Dict, List, Optional, Tuple, Union - import contextlib import io import time +from typing import Any, Dict, List, Optional, Tuple, Union -import dns.wire import dns.edns +import dns.entropy import dns.enum import dns.exception import dns.flags import dns.name import dns.opcode -import dns.entropy import dns.rcode import dns.rdata import dns.rdataclass import dns.rdatatype -import dns.rrset -import dns.renderer -import dns.ttl -import dns.tsig import dns.rdtypes.ANY.OPT import dns.rdtypes.ANY.TSIG +import dns.renderer +import dns.rrset +import dns.tsig +import dns.ttl +import dns.wire class ShortHeader(dns.exception.FormError): @@ -135,7 +134,7 @@ IndexKeyType = Tuple[ Optional[dns.rdataclass.RdataClass], ] IndexType = Dict[IndexKeyType, dns.rrset.RRset] -SectionType = 
Union[int, List[dns.rrset.RRset]] +SectionType = Union[int, str, List[dns.rrset.RRset]] class Message: @@ -231,7 +230,7 @@ class Message: s.write("payload %d\n" % self.payload) for opt in self.options: s.write("option %s\n" % opt.to_text()) - for (name, which) in self._section_enum.__members__.items(): + for name, which in self._section_enum.__members__.items(): s.write(f";{name}\n") for rrset in self.section_from_number(which): s.write(rrset.to_text(origin, relativize, **kw)) @@ -348,27 +347,29 @@ class Message: deleting: Optional[dns.rdataclass.RdataClass] = None, create: bool = False, force_unique: bool = False, + idna_codec: Optional[dns.name.IDNACodec] = None, ) -> dns.rrset.RRset: """Find the RRset with the given attributes in the specified section. - *section*, an ``int`` section number, or one of the section - attributes of this message. This specifies the + *section*, an ``int`` section number, a ``str`` section name, or one of + the section attributes of this message. This specifies the the section of the message to search. For example:: my_message.find_rrset(my_message.answer, name, rdclass, rdtype) my_message.find_rrset(dns.message.ANSWER, name, rdclass, rdtype) + my_message.find_rrset("ANSWER", name, rdclass, rdtype) - *name*, a ``dns.name.Name``, the name of the RRset. + *name*, a ``dns.name.Name`` or ``str``, the name of the RRset. - *rdclass*, an ``int``, the class of the RRset. + *rdclass*, an ``int`` or ``str``, the class of the RRset. - *rdtype*, an ``int``, the type of the RRset. + *rdtype*, an ``int`` or ``str``, the type of the RRset. - *covers*, an ``int`` or ``None``, the covers value of the RRset. - The default is ``None``. + *covers*, an ``int`` or ``str``, the covers value of the RRset. + The default is ``dns.rdatatype.NONE``. - *deleting*, an ``int`` or ``None``, the deleting value of the RRset. - The default is ``None``. + *deleting*, an ``int``, ``str``, or ``None``, the deleting value of the + RRset. The default is ``None``. 
*create*, a ``bool``. If ``True``, create the RRset if it is not found. The created RRset is appended to *section*. @@ -378,6 +379,10 @@ class Message: already. The default is ``False``. This is useful when creating DDNS Update messages, as order matters for them. + *idna_codec*, a ``dns.name.IDNACodec``, specifies the IDNA + encoder/decoder. If ``None``, the default IDNA 2003 encoder/decoder + is used. + Raises ``KeyError`` if the RRset was not found and create was ``False``. @@ -386,10 +391,19 @@ class Message: if isinstance(section, int): section_number = section - the_section = self.section_from_number(section_number) + section = self.section_from_number(section_number) + elif isinstance(section, str): + section_number = MessageSection.from_text(section) + section = self.section_from_number(section_number) else: section_number = self.section_number(section) - the_section = section + if isinstance(name, str): + name = dns.name.from_text(name, idna_codec=idna_codec) + rdtype = dns.rdatatype.RdataType.make(rdtype) + rdclass = dns.rdataclass.RdataClass.make(rdclass) + covers = dns.rdatatype.RdataType.make(covers) + if deleting is not None: + deleting = dns.rdataclass.RdataClass.make(deleting) key = (section_number, name, rdclass, rdtype, covers, deleting) if not force_unique: if self.index is not None: @@ -397,13 +411,13 @@ class Message: if rrset is not None: return rrset else: - for rrset in the_section: + for rrset in section: if rrset.full_match(name, rdclass, rdtype, covers, deleting): return rrset if not create: raise KeyError rrset = dns.rrset.RRset(name, rdclass, rdtype, covers, deleting) - the_section.append(rrset) + section.append(rrset) if self.index is not None: self.index[key] = rrset return rrset @@ -418,29 +432,31 @@ class Message: deleting: Optional[dns.rdataclass.RdataClass] = None, create: bool = False, force_unique: bool = False, + idna_codec: Optional[dns.name.IDNACodec] = None, ) -> Optional[dns.rrset.RRset]: """Get the RRset with the given 
attributes in the specified section. If the RRset is not found, None is returned. - *section*, an ``int`` section number, or one of the section - attributes of this message. This specifies the + *section*, an ``int`` section number, a ``str`` section name, or one of + the section attributes of this message. This specifies the the section of the message to search. For example:: my_message.get_rrset(my_message.answer, name, rdclass, rdtype) my_message.get_rrset(dns.message.ANSWER, name, rdclass, rdtype) + my_message.get_rrset("ANSWER", name, rdclass, rdtype) - *name*, a ``dns.name.Name``, the name of the RRset. + *name*, a ``dns.name.Name`` or ``str``, the name of the RRset. - *rdclass*, an ``int``, the class of the RRset. + *rdclass*, an ``int`` or ``str``, the class of the RRset. - *rdtype*, an ``int``, the type of the RRset. + *rdtype*, an ``int`` or ``str``, the type of the RRset. - *covers*, an ``int`` or ``None``, the covers value of the RRset. - The default is ``None``. + *covers*, an ``int`` or ``str``, the covers value of the RRset. + The default is ``dns.rdatatype.NONE``. - *deleting*, an ``int`` or ``None``, the deleting value of the RRset. - The default is ``None``. + *deleting*, an ``int``, ``str``, or ``None``, the deleting value of the + RRset. The default is ``None``. *create*, a ``bool``. If ``True``, create the RRset if it is not found. The created RRset is appended to *section*. @@ -450,12 +466,24 @@ class Message: already. The default is ``False``. This is useful when creating DDNS Update messages, as order matters for them. + *idna_codec*, a ``dns.name.IDNACodec``, specifies the IDNA + encoder/decoder. If ``None``, the default IDNA 2003 encoder/decoder + is used. + Returns a ``dns.rrset.RRset object`` or ``None``. 
""" try: rrset = self.find_rrset( - section, name, rdclass, rdtype, covers, deleting, create, force_unique + section, + name, + rdclass, + rdtype, + covers, + deleting, + create, + force_unique, + idna_codec, ) except KeyError: rrset = None @@ -1708,13 +1736,11 @@ def make_query( if isinstance(qname, str): qname = dns.name.from_text(qname, idna_codec=idna_codec) - the_rdtype = dns.rdatatype.RdataType.make(rdtype) - the_rdclass = dns.rdataclass.RdataClass.make(rdclass) + rdtype = dns.rdatatype.RdataType.make(rdtype) + rdclass = dns.rdataclass.RdataClass.make(rdclass) m = QueryMessage(id=id) m.flags = dns.flags.Flag(flags) - m.find_rrset( - m.question, qname, the_rdclass, the_rdtype, create=True, force_unique=True - ) + m.find_rrset(m.question, qname, rdclass, rdtype, create=True, force_unique=True) # only pass keywords on to use_edns if they have been set to a # non-None value. Setting a field will turn EDNS on if it hasn't # been configured. diff --git a/lib/dns/name.py b/lib/dns/name.py index 612af021..f452bfed 100644 --- a/lib/dns/name.py +++ b/lib/dns/name.py @@ -18,12 +18,10 @@ """DNS Names. 
""" -from typing import Any, Dict, Iterable, Optional, Tuple, Union - import copy -import struct - import encodings.idna # type: ignore +import struct +from typing import Any, Dict, Iterable, Optional, Tuple, Union try: import idna # type: ignore @@ -33,10 +31,9 @@ except ImportError: # pragma: no cover have_idna_2008 = False import dns.enum -import dns.wire import dns.exception import dns.immutable - +import dns.wire CompressType = Dict["Name", int] diff --git a/lib/dns/nameserver.py b/lib/dns/nameserver.py new file mode 100644 index 00000000..5910139e --- /dev/null +++ b/lib/dns/nameserver.py @@ -0,0 +1,329 @@ +from typing import Optional, Union +from urllib.parse import urlparse + +import dns.asyncbackend +import dns.asyncquery +import dns.inet +import dns.message +import dns.query + + +class Nameserver: + def __init__(self): + pass + + def __str__(self): + raise NotImplementedError + + def kind(self) -> str: + raise NotImplementedError + + def is_always_max_size(self) -> bool: + raise NotImplementedError + + def answer_nameserver(self) -> str: + raise NotImplementedError + + def answer_port(self) -> int: + raise NotImplementedError + + def query( + self, + request: dns.message.QueryMessage, + timeout: float, + source: Optional[str], + source_port: int, + max_size: bool, + one_rr_per_rrset: bool = False, + ignore_trailing: bool = False, + ) -> dns.message.Message: + raise NotImplementedError + + async def async_query( + self, + request: dns.message.QueryMessage, + timeout: float, + source: Optional[str], + source_port: int, + max_size: bool, + backend: dns.asyncbackend.Backend, + one_rr_per_rrset: bool = False, + ignore_trailing: bool = False, + ) -> dns.message.Message: + raise NotImplementedError + + +class AddressAndPortNameserver(Nameserver): + def __init__(self, address: str, port: int): + super().__init__() + self.address = address + self.port = port + + def kind(self) -> str: + raise NotImplementedError + + def is_always_max_size(self) -> bool: + return 
False + + def __str__(self): + ns_kind = self.kind() + return f"{ns_kind}:{self.address}@{self.port}" + + def answer_nameserver(self) -> str: + return self.address + + def answer_port(self) -> int: + return self.port + + +class Do53Nameserver(AddressAndPortNameserver): + def __init__(self, address: str, port: int = 53): + super().__init__(address, port) + + def kind(self): + return "Do53" + + def query( + self, + request: dns.message.QueryMessage, + timeout: float, + source: Optional[str], + source_port: int, + max_size: bool, + one_rr_per_rrset: bool = False, + ignore_trailing: bool = False, + ) -> dns.message.Message: + if max_size: + response = dns.query.tcp( + request, + self.address, + timeout=timeout, + port=self.port, + source=source, + source_port=source_port, + one_rr_per_rrset=one_rr_per_rrset, + ignore_trailing=ignore_trailing, + ) + else: + response = dns.query.udp( + request, + self.address, + timeout=timeout, + port=self.port, + source=source, + source_port=source_port, + raise_on_truncation=True, + one_rr_per_rrset=one_rr_per_rrset, + ignore_trailing=ignore_trailing, + ) + return response + + async def async_query( + self, + request: dns.message.QueryMessage, + timeout: float, + source: Optional[str], + source_port: int, + max_size: bool, + backend: dns.asyncbackend.Backend, + one_rr_per_rrset: bool = False, + ignore_trailing: bool = False, + ) -> dns.message.Message: + if max_size: + response = await dns.asyncquery.tcp( + request, + self.address, + timeout=timeout, + port=self.port, + source=source, + source_port=source_port, + backend=backend, + one_rr_per_rrset=one_rr_per_rrset, + ignore_trailing=ignore_trailing, + ) + else: + response = await dns.asyncquery.udp( + request, + self.address, + timeout=timeout, + port=self.port, + source=source, + source_port=source_port, + raise_on_truncation=True, + backend=backend, + one_rr_per_rrset=one_rr_per_rrset, + ignore_trailing=ignore_trailing, + ) + return response + + +class DoHNameserver(Nameserver): + 
def __init__(self, url: str, bootstrap_address: Optional[str] = None): + super().__init__() + self.url = url + self.bootstrap_address = bootstrap_address + + def kind(self): + return "DoH" + + def is_always_max_size(self) -> bool: + return True + + def __str__(self): + return self.url + + def answer_nameserver(self) -> str: + return self.url + + def answer_port(self) -> int: + port = urlparse(self.url).port + if port is None: + port = 443 + return port + + def query( + self, + request: dns.message.QueryMessage, + timeout: float, + source: Optional[str], + source_port: int, + max_size: bool = False, + one_rr_per_rrset: bool = False, + ignore_trailing: bool = False, + ) -> dns.message.Message: + return dns.query.https( + request, + self.url, + timeout=timeout, + bootstrap_address=self.bootstrap_address, + one_rr_per_rrset=one_rr_per_rrset, + ignore_trailing=ignore_trailing, + ) + + async def async_query( + self, + request: dns.message.QueryMessage, + timeout: float, + source: Optional[str], + source_port: int, + max_size: bool, + backend: dns.asyncbackend.Backend, + one_rr_per_rrset: bool = False, + ignore_trailing: bool = False, + ) -> dns.message.Message: + return await dns.asyncquery.https( + request, + self.url, + timeout=timeout, + one_rr_per_rrset=one_rr_per_rrset, + ignore_trailing=ignore_trailing, + ) + + +class DoTNameserver(AddressAndPortNameserver): + def __init__(self, address: str, port: int = 853, hostname: Optional[str] = None): + super().__init__(address, port) + self.hostname = hostname + + def kind(self): + return "DoT" + + def query( + self, + request: dns.message.QueryMessage, + timeout: float, + source: Optional[str], + source_port: int, + max_size: bool = False, + one_rr_per_rrset: bool = False, + ignore_trailing: bool = False, + ) -> dns.message.Message: + return dns.query.tls( + request, + self.address, + port=self.port, + timeout=timeout, + one_rr_per_rrset=one_rr_per_rrset, + ignore_trailing=ignore_trailing, + server_hostname=self.hostname, 
+ ) + + async def async_query( + self, + request: dns.message.QueryMessage, + timeout: float, + source: Optional[str], + source_port: int, + max_size: bool, + backend: dns.asyncbackend.Backend, + one_rr_per_rrset: bool = False, + ignore_trailing: bool = False, + ) -> dns.message.Message: + return await dns.asyncquery.tls( + request, + self.address, + port=self.port, + timeout=timeout, + one_rr_per_rrset=one_rr_per_rrset, + ignore_trailing=ignore_trailing, + server_hostname=self.hostname, + ) + + +class DoQNameserver(AddressAndPortNameserver): + def __init__( + self, + address: str, + port: int = 853, + verify: Union[bool, str] = True, + server_hostname: Optional[str] = None, + ): + super().__init__(address, port) + self.verify = verify + self.server_hostname = server_hostname + + def kind(self): + return "DoQ" + + def query( + self, + request: dns.message.QueryMessage, + timeout: float, + source: Optional[str], + source_port: int, + max_size: bool = False, + one_rr_per_rrset: bool = False, + ignore_trailing: bool = False, + ) -> dns.message.Message: + return dns.query.quic( + request, + self.address, + port=self.port, + timeout=timeout, + one_rr_per_rrset=one_rr_per_rrset, + ignore_trailing=ignore_trailing, + verify=self.verify, + server_hostname=self.server_hostname, + ) + + async def async_query( + self, + request: dns.message.QueryMessage, + timeout: float, + source: Optional[str], + source_port: int, + max_size: bool, + backend: dns.asyncbackend.Backend, + one_rr_per_rrset: bool = False, + ignore_trailing: bool = False, + ) -> dns.message.Message: + return await dns.asyncquery.quic( + request, + self.address, + port=self.port, + timeout=timeout, + one_rr_per_rrset=one_rr_per_rrset, + ignore_trailing=ignore_trailing, + verify=self.verify, + server_hostname=self.server_hostname, + ) diff --git a/lib/dns/node.py b/lib/dns/node.py index 22bbe7cb..c670243c 100644 --- a/lib/dns/node.py +++ b/lib/dns/node.py @@ -17,19 +17,17 @@ """DNS nodes. 
A node is a set of rdatasets.""" -from typing import Any, Dict, Optional - import enum import io +from typing import Any, Dict, Optional import dns.immutable import dns.name import dns.rdataclass import dns.rdataset import dns.rdatatype -import dns.rrset import dns.renderer - +import dns.rrset _cname_types = { dns.rdatatype.CNAME, diff --git a/lib/dns/query.py b/lib/dns/query.py index b4cd69f7..0d711251 100644 --- a/lib/dns/query.py +++ b/lib/dns/query.py @@ -17,8 +17,6 @@ """Talk to a DNS server.""" -from typing import Any, Dict, Optional, Tuple, Union - import base64 import contextlib import enum @@ -28,12 +26,12 @@ import selectors import socket import struct import time -import urllib.parse +from typing import Any, Dict, Optional, Tuple, Union import dns.exception import dns.inet -import dns.name import dns.message +import dns.name import dns.quic import dns.rcode import dns.rdataclass @@ -43,20 +41,32 @@ import dns.transaction import dns.tsig import dns.xfr -try: - import requests - from requests_toolbelt.adapters.source import SourceAddressAdapter - from requests_toolbelt.adapters.host_header_ssl import HostHeaderSSLAdapter - _have_requests = True -except ImportError: # pragma: no cover - _have_requests = False +def _remaining(expiration): + if expiration is None: + return None + timeout = expiration - time.time() + if timeout <= 0.0: + raise dns.exception.Timeout + return timeout + + +def _expiration_for_this_attempt(timeout, expiration): + if expiration is None: + return None + return min(time.time() + timeout, expiration) + _have_httpx = False _have_http2 = False try: + import httpcore + import httpcore._backends.sync import httpx + _CoreNetworkBackend = httpcore.NetworkBackend + _CoreSyncStream = httpcore._backends.sync.SyncStream + _have_httpx = True try: # See if http2 support is available. 
@@ -64,10 +74,87 @@ try: _have_http2 = True except Exception: pass -except ImportError: # pragma: no cover - pass -have_doh = _have_requests or _have_httpx + class _NetworkBackend(_CoreNetworkBackend): + def __init__(self, resolver, local_port, bootstrap_address, family): + super().__init__() + self._local_port = local_port + self._resolver = resolver + self._bootstrap_address = bootstrap_address + self._family = family + + def connect_tcp( + self, host, port, timeout, local_address, socket_options=None + ): # pylint: disable=signature-differs + addresses = [] + _, expiration = _compute_times(timeout) + if dns.inet.is_address(host): + addresses.append(host) + elif self._bootstrap_address is not None: + addresses.append(self._bootstrap_address) + else: + timeout = _remaining(expiration) + family = self._family + if local_address: + family = dns.inet.af_for_address(local_address) + answers = self._resolver.resolve_name( + host, family=family, lifetime=timeout + ) + addresses = answers.addresses() + for address in addresses: + af = dns.inet.af_for_address(address) + if local_address is not None or self._local_port != 0: + source = dns.inet.low_level_address_tuple( + (local_address, self._local_port), af + ) + else: + source = None + sock = _make_socket(af, socket.SOCK_STREAM, source) + attempt_expiration = _expiration_for_this_attempt(2.0, expiration) + try: + _connect( + sock, + dns.inet.low_level_address_tuple((address, port), af), + attempt_expiration, + ) + return _CoreSyncStream(sock) + except Exception: + pass + raise httpcore.ConnectError + + def connect_unix_socket( + self, path, timeout, socket_options=None + ): # pylint: disable=signature-differs + raise NotImplementedError + + class _HTTPTransport(httpx.HTTPTransport): + def __init__( + self, + *args, + local_port=0, + bootstrap_address=None, + resolver=None, + family=socket.AF_UNSPEC, + **kwargs, + ): + if resolver is None: + # pylint: disable=import-outside-toplevel,redefined-outer-name + import 
dns.resolver + + resolver = dns.resolver.Resolver() + super().__init__(*args, **kwargs) + self._pool._network_backend = _NetworkBackend( + resolver, local_port, bootstrap_address, family + ) + +except ImportError: # pragma: no cover + + class _HTTPTransport: # type: ignore + def connect_tcp(self, host, port, timeout, local_address): + raise NotImplementedError + + +have_doh = _have_httpx try: import ssl @@ -88,7 +175,7 @@ except ImportError: # pragma: no cover @classmethod def create_default_context(cls, *args, **kwargs): - raise Exception("no ssl support") + raise Exception("no ssl support") # pylint: disable=broad-exception-raised # Function used to create a socket. Can be overridden if needed in special @@ -105,7 +192,7 @@ class BadResponse(dns.exception.FormError): class NoDOH(dns.exception.DNSException): - """DNS over HTTPS (DOH) was requested but the requests module is not + """DNS over HTTPS (DOH) was requested but the httpx module is not available.""" @@ -230,7 +317,7 @@ def _destination_and_source( # We know the destination af, so source had better agree! if saf != af: raise ValueError( - "different address families for source " + "and destination" + "different address families for source and destination" ) else: # We didn't know the destination af, but we know the source, @@ -240,11 +327,10 @@ def _destination_and_source( # Caller has specified a source_port but not an address, so we # need to return a source, and we need to use the appropriate # wildcard address as the address. - if af == socket.AF_INET: - source = "0.0.0.0" - elif af == socket.AF_INET6: - source = "::" - else: + try: + source = dns.inet.any_for_af(af) + except Exception: + # we catch this and raise ValueError for backwards compatibility raise ValueError("source_port specified but address family is unknown") # Convert high-level (address, port) tuples into low-level address # tuples. 
@@ -289,6 +375,8 @@ def https( post: bool = True, bootstrap_address: Optional[str] = None, verify: Union[bool, str] = True, + resolver: Optional["dns.resolver.Resolver"] = None, + family: Optional[int] = socket.AF_UNSPEC, ) -> dns.message.Message: """Return the response obtained after sending a query via DNS-over-HTTPS. @@ -314,91 +402,78 @@ def https( *ignore_trailing*, a ``bool``. If ``True``, ignore trailing junk at end of the received message. - *session*, an ``httpx.Client`` or ``requests.session.Session``. If provided, the - client/session to use to send the queries. + *session*, an ``httpx.Client``. If provided, the client session to use to send the + queries. *path*, a ``str``. If *where* is an IP address, then *path* will be used to construct the URL to send the DNS query to. *post*, a ``bool``. If ``True``, the default, POST method will be used. - *bootstrap_address*, a ``str``, the IP address to use to bypass the system's DNS - resolver. + *bootstrap_address*, a ``str``, the IP address to use to bypass resolution. *verify*, a ``bool`` or ``str``. If a ``True``, then TLS certificate verification of the server is done using the default CA bundle; if ``False``, then no verification is done; if a `str` then it specifies the path to a certificate file or directory which will be used for verification. + *resolver*, a ``dns.resolver.Resolver`` or ``None``, the resolver to use for + resolution of hostnames in URLs. If not specified, a new resolver with a default + configuration will be used; note this is *not* the default resolver as that resolver + might have been configured to use DoH causing a chicken-and-egg problem. This + parameter only has an effect if the HTTP library is httpx. + + *family*, an ``int``, the address family. If socket.AF_UNSPEC (the default), both A + and AAAA records will be retrieved. + Returns a ``dns.message.Message``. 
""" if not have_doh: - raise NoDOH("Neither httpx nor requests is available.") # pragma: no cover - - _httpx_ok = _have_httpx + raise NoDOH # pragma: no cover + if session and not isinstance(session, httpx.Client): + raise ValueError("session parameter must be an httpx.Client") wire = q.to_wire() - (af, _, source) = _destination_and_source(where, port, source, source_port, False) - transport_adapter = None + (af, _, the_source) = _destination_and_source( + where, port, source, source_port, False + ) transport = None headers = {"accept": "application/dns-message"} - if af is not None: + if af is not None and dns.inet.is_address(where): if af == socket.AF_INET: url = "https://{}:{}{}".format(where, port, path) elif af == socket.AF_INET6: url = "https://[{}]:{}{}".format(where, port, path) - elif bootstrap_address is not None: - _httpx_ok = False - split_url = urllib.parse.urlsplit(where) - if split_url.hostname is None: - raise ValueError("DoH URL has no hostname") - headers["Host"] = split_url.hostname - url = where.replace(split_url.hostname, bootstrap_address) - if _have_requests: - transport_adapter = HostHeaderSSLAdapter() else: url = where - if source is not None: - # set source port and source address - if _have_httpx: - if source_port == 0: - transport = httpx.HTTPTransport(local_address=source[0], verify=verify) - else: - _httpx_ok = False - if _have_requests: - transport_adapter = SourceAddressAdapter(source) - if session: - if _have_httpx: - _is_httpx = isinstance(session, httpx.Client) - else: - _is_httpx = False - if _is_httpx and not _httpx_ok: - raise NoDOH( - "Session is httpx, but httpx cannot be used for " - "the requested operation." - ) + # set source port and source address + + if the_source is None: + local_address = None + local_port = 0 else: - _is_httpx = _httpx_ok - - if not _httpx_ok and not _have_requests: - raise NoDOH( - "Cannot use httpx for this operation, and requests is not available." 
- ) + local_address = the_source[0] + local_port = the_source[1] + transport = _HTTPTransport( + local_address=local_address, + http1=True, + http2=_have_http2, + verify=verify, + local_port=local_port, + bootstrap_address=bootstrap_address, + resolver=resolver, + family=family, + ) if session: cm: contextlib.AbstractContextManager = contextlib.nullcontext(session) - elif _is_httpx: + else: cm = httpx.Client( http1=True, http2=_have_http2, verify=verify, transport=transport ) - else: - cm = requests.sessions.Session() with cm as session: - if transport_adapter and not _is_httpx: - session.mount(url, transport_adapter) - # see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH # GET and POST examples if post: @@ -408,29 +483,13 @@ def https( "content-length": str(len(wire)), } ) - if _is_httpx: - response = session.post( - url, headers=headers, content=wire, timeout=timeout - ) - else: - response = session.post( - url, headers=headers, data=wire, timeout=timeout, verify=verify - ) + response = session.post(url, headers=headers, content=wire, timeout=timeout) else: wire = base64.urlsafe_b64encode(wire).rstrip(b"=") - if _is_httpx: - twire = wire.decode() # httpx does a repr() if we give it bytes - response = session.get( - url, headers=headers, timeout=timeout, params={"dns": twire} - ) - else: - response = session.get( - url, - headers=headers, - timeout=timeout, - verify=verify, - params={"dns": wire}, - ) + twire = wire.decode() # httpx does a repr() if we give it bytes + response = session.get( + url, headers=headers, timeout=timeout, params={"dns": twire} + ) # see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH # status codes @@ -1070,6 +1129,7 @@ def quic( ignore_trailing: bool = False, connection: Optional[dns.quic.SyncQuicConnection] = None, verify: Union[bool, str] = True, + server_hostname: Optional[str] = None, ) -> dns.message.Message: """Return the response obtained after sending a query via DNS-over-QUIC. 
@@ -1101,6 +1161,10 @@ def quic( verification is done; if a `str` then it specifies the path to a certificate file or directory which will be used for verification. + *server_hostname*, a ``str`` containing the server's hostname. The + default is ``None``, which means that no hostname is known, and if an + SSL context is created, hostname checking will be disabled. + Returns a ``dns.message.Message``. """ @@ -1115,16 +1179,18 @@ def quic( manager: contextlib.AbstractContextManager = contextlib.nullcontext(None) the_connection = connection else: - manager = dns.quic.SyncQuicManager(verify_mode=verify) + manager = dns.quic.SyncQuicManager( + verify_mode=verify, server_name=server_hostname + ) the_manager = manager # for type checking happiness with manager: if not connection: the_connection = the_manager.connect(where, port, source, source_port) - start = time.time() - with the_connection.make_stream() as stream: + (start, expiration) = _compute_times(timeout) + with the_connection.make_stream(timeout) as stream: stream.send(wire, True) - wire = stream.receive(timeout) + wire = stream.receive(_remaining(expiration)) finish = time.time() r = dns.message.from_wire( wire, diff --git a/lib/dns/quic/__init__.py b/lib/dns/quic/__init__.py index f48ecf57..69813f9f 100644 --- a/lib/dns/quic/__init__.py +++ b/lib/dns/quic/__init__.py @@ -5,13 +5,13 @@ try: import dns.asyncbackend from dns._asyncbackend import NullContext - from dns.quic._sync import SyncQuicManager, SyncQuicConnection, SyncQuicStream from dns.quic._asyncio import ( - AsyncioQuicManager, AsyncioQuicConnection, + AsyncioQuicManager, AsyncioQuicStream, ) from dns.quic._common import AsyncQuicConnection, AsyncQuicManager + from dns.quic._sync import SyncQuicConnection, SyncQuicManager, SyncQuicStream have_quic = True @@ -33,9 +33,10 @@ try: try: import trio + from dns.quic._trio import ( # pylint: disable=ungrouped-imports - TrioQuicManager, TrioQuicConnection, + TrioQuicManager, TrioQuicStream, ) diff --git 
a/lib/dns/quic/_asyncio.py b/lib/dns/quic/_asyncio.py index 0a2e220d..e1c52339 100644 --- a/lib/dns/quic/_asyncio.py +++ b/lib/dns/quic/_asyncio.py @@ -9,14 +9,16 @@ import time import aioquic.quic.configuration # type: ignore import aioquic.quic.connection # type: ignore import aioquic.quic.events # type: ignore -import dns.inet -import dns.asyncbackend +import dns.asyncbackend +import dns.exception +import dns.inet from dns.quic._common import ( - BaseQuicStream, + QUIC_MAX_DATAGRAM, AsyncQuicConnection, AsyncQuicManager, - QUIC_MAX_DATAGRAM, + BaseQuicStream, + UnexpectedEOF, ) @@ -30,15 +32,15 @@ class AsyncioQuicStream(BaseQuicStream): await self._wake_up.wait() async def wait_for(self, amount, expiration): - timeout = self._timeout_from_expiration(expiration) while True: + timeout = self._timeout_from_expiration(expiration) if self._buffer.have(amount): return self._expecting = amount try: await asyncio.wait_for(self._wait_for_wake_up(), timeout) - except Exception: - pass + except TimeoutError: + raise dns.exception.Timeout self._expecting = 0 async def receive(self, timeout=None): @@ -86,8 +88,10 @@ class AsyncioQuicConnection(AsyncQuicConnection): try: af = dns.inet.af_for_address(self._address) backend = dns.asyncbackend.get_backend("asyncio") + # Note that peer is a low-level address tuple, but make_socket() wants + # a high-level address tuple, so we convert. 
self._socket = await backend.make_socket( - af, socket.SOCK_DGRAM, 0, self._source, self._peer + af, socket.SOCK_DGRAM, 0, self._source, (self._peer[0], self._peer[1]) ) self._socket_created.set() async with self._socket: @@ -106,6 +110,11 @@ class AsyncioQuicConnection(AsyncQuicConnection): self._wake_timer.notify_all() except Exception: pass + finally: + self._done = True + async with self._wake_timer: + self._wake_timer.notify_all() + self._handshake_complete.set() async def _wait_for_wake_timer(self): async with self._wake_timer: @@ -115,7 +124,7 @@ class AsyncioQuicConnection(AsyncQuicConnection): await self._socket_created.wait() while not self._done: datagrams = self._connection.datagrams_to_send(time.time()) - for (datagram, address) in datagrams: + for datagram, address in datagrams: assert address == self._peer[0] await self._socket.sendto(datagram, self._peer, None) (expiration, interval) = self._get_timer_values() @@ -160,8 +169,13 @@ class AsyncioQuicConnection(AsyncQuicConnection): self._receiver_task = asyncio.Task(self._receiver()) self._sender_task = asyncio.Task(self._sender()) - async def make_stream(self): - await self._handshake_complete.wait() + async def make_stream(self, timeout=None): + try: + await asyncio.wait_for(self._handshake_complete.wait(), timeout) + except TimeoutError: + raise dns.exception.Timeout + if self._done: + raise UnexpectedEOF stream_id = self._connection.get_next_available_stream_id(False) stream = AsyncioQuicStream(self, stream_id) self._streams[stream_id] = stream @@ -172,6 +186,9 @@ class AsyncioQuicConnection(AsyncQuicConnection): self._manager.closed(self._peer[0], self._peer[1]) self._closed = True self._connection.close() + # sender might be blocked on this, so set it + self._socket_created.set() + await self._socket.close() async with self._wake_timer: self._wake_timer.notify_all() try: @@ -185,8 +202,8 @@ class AsyncioQuicConnection(AsyncQuicConnection): class AsyncioQuicManager(AsyncQuicManager): - def 
__init__(self, conf=None, verify_mode=ssl.CERT_REQUIRED): - super().__init__(conf, verify_mode, AsyncioQuicConnection) + def __init__(self, conf=None, verify_mode=ssl.CERT_REQUIRED, server_name=None): + super().__init__(conf, verify_mode, AsyncioQuicConnection, server_name) def connect(self, address, port=853, source=None, source_port=0): (connection, start) = self._connect(address, port, source, source_port) @@ -198,7 +215,7 @@ class AsyncioQuicManager(AsyncQuicManager): return self async def __aexit__(self, exc_type, exc_val, exc_tb): - # Copy the itertor into a list as exiting things will mutate the connections + # Copy the iterator into a list as exiting things will mutate the connections # table. connections = list(self._connections.values()) for connection in connections: diff --git a/lib/dns/quic/_common.py b/lib/dns/quic/_common.py index d8f6f7fd..38ec103f 100644 --- a/lib/dns/quic/_common.py +++ b/lib/dns/quic/_common.py @@ -3,13 +3,12 @@ import socket import struct import time - -from typing import Any +from typing import Any, Optional import aioquic.quic.configuration # type: ignore import aioquic.quic.connection # type: ignore -import dns.inet +import dns.inet QUIC_MAX_DATAGRAM = 2048 @@ -135,12 +134,12 @@ class BaseQuicConnection: class AsyncQuicConnection(BaseQuicConnection): - async def make_stream(self) -> Any: + async def make_stream(self, timeout: Optional[float] = None) -> Any: pass class BaseQuicManager: - def __init__(self, conf, verify_mode, connection_factory): + def __init__(self, conf, verify_mode, connection_factory, server_name=None): self._connections = {} self._connection_factory = connection_factory if conf is None: @@ -151,6 +150,7 @@ class BaseQuicManager: conf = aioquic.quic.configuration.QuicConfiguration( alpn_protocols=["doq", "doq-i03"], verify_mode=verify_mode, + server_name=server_name, ) if verify_path is not None: conf.load_verify_locations(verify_path) diff --git a/lib/dns/quic/_sync.py b/lib/dns/quic/_sync.py index 
be005ba9..e944784d 100644 --- a/lib/dns/quic/_sync.py +++ b/lib/dns/quic/_sync.py @@ -1,8 +1,8 @@ # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license +import selectors import socket import ssl -import selectors import struct import threading import time @@ -10,13 +10,15 @@ import time import aioquic.quic.configuration # type: ignore import aioquic.quic.connection # type: ignore import aioquic.quic.events # type: ignore -import dns.inet +import dns.exception +import dns.inet from dns.quic._common import ( - BaseQuicStream, + QUIC_MAX_DATAGRAM, BaseQuicConnection, BaseQuicManager, - QUIC_MAX_DATAGRAM, + BaseQuicStream, + UnexpectedEOF, ) # Avoid circularity with dns.query @@ -33,14 +35,15 @@ class SyncQuicStream(BaseQuicStream): self._lock = threading.Lock() def wait_for(self, amount, expiration): - timeout = self._timeout_from_expiration(expiration) while True: + timeout = self._timeout_from_expiration(expiration) with self._lock: if self._buffer.have(amount): return self._expecting = amount with self._wake_up: - self._wake_up.wait(timeout) + if not self._wake_up.wait(timeout): + raise dns.exception.Timeout self._expecting = 0 def receive(self, timeout=None): @@ -114,24 +117,30 @@ class SyncQuicConnection(BaseQuicConnection): return def _worker(self): - sel = _selector_class() - sel.register(self._socket, selectors.EVENT_READ, self._read) - sel.register(self._receive_wakeup, selectors.EVENT_READ, self._drain_wakeup) - while not self._done: - (expiration, interval) = self._get_timer_values(False) - items = sel.select(interval) - for (key, _) in items: - key.data() + try: + sel = _selector_class() + sel.register(self._socket, selectors.EVENT_READ, self._read) + sel.register(self._receive_wakeup, selectors.EVENT_READ, self._drain_wakeup) + while not self._done: + (expiration, interval) = self._get_timer_values(False) + items = sel.select(interval) + for key, _ in items: + key.data() + with self._lock: + self._handle_timer(expiration) + datagrams 
= self._connection.datagrams_to_send(time.time()) + for datagram, _ in datagrams: + try: + self._socket.send(datagram) + except BlockingIOError: + # we let QUIC handle any lossage + pass + self._handle_events() + finally: with self._lock: - self._handle_timer(expiration) - datagrams = self._connection.datagrams_to_send(time.time()) - for (datagram, _) in datagrams: - try: - self._socket.send(datagram) - except BlockingIOError: - # we let QUIC handle any lossage - pass - self._handle_events() + self._done = True + # Ensure anyone waiting for this gets woken up. + self._handshake_complete.set() def _handle_events(self): while True: @@ -163,9 +172,12 @@ class SyncQuicConnection(BaseQuicConnection): self._worker_thread = threading.Thread(target=self._worker) self._worker_thread.start() - def make_stream(self): - self._handshake_complete.wait() + def make_stream(self, timeout=None): + if not self._handshake_complete.wait(timeout): + raise dns.exception.Timeout with self._lock: + if self._done: + raise UnexpectedEOF stream_id = self._connection.get_next_available_stream_id(False) stream = SyncQuicStream(self, stream_id) self._streams[stream_id] = stream @@ -187,8 +199,8 @@ class SyncQuicConnection(BaseQuicConnection): class SyncQuicManager(BaseQuicManager): - def __init__(self, conf=None, verify_mode=ssl.CERT_REQUIRED): - super().__init__(conf, verify_mode, SyncQuicConnection) + def __init__(self, conf=None, verify_mode=ssl.CERT_REQUIRED, server_name=None): + super().__init__(conf, verify_mode, SyncQuicConnection, server_name) self._lock = threading.Lock() def connect(self, address, port=853, source=None, source_port=0): @@ -206,7 +218,7 @@ class SyncQuicManager(BaseQuicManager): return self def __exit__(self, exc_type, exc_val, exc_tb): - # Copy the itertor into a list as exiting things will mutate the connections + # Copy the iterator into a list as exiting things will mutate the connections # table. 
connections = list(self._connections.values()) for connection in connections: diff --git a/lib/dns/quic/_trio.py b/lib/dns/quic/_trio.py index 1e47a5a6..ee07e4f6 100644 --- a/lib/dns/quic/_trio.py +++ b/lib/dns/quic/_trio.py @@ -10,13 +10,15 @@ import aioquic.quic.connection # type: ignore import aioquic.quic.events # type: ignore import trio +import dns.exception import dns.inet from dns._asyncbackend import NullContext from dns.quic._common import ( - BaseQuicStream, + QUIC_MAX_DATAGRAM, AsyncQuicConnection, AsyncQuicManager, - QUIC_MAX_DATAGRAM, + BaseQuicStream, + UnexpectedEOF, ) @@ -44,6 +46,7 @@ class TrioQuicStream(BaseQuicStream): (size,) = struct.unpack("!H", self._buffer.get(2)) await self.wait_for(size) return self._buffer.get(size) + raise dns.exception.Timeout async def send(self, datagram, is_end=False): data = self._encapsulate(datagram) @@ -80,20 +83,26 @@ class TrioQuicConnection(AsyncQuicConnection): self._worker_scope = None async def _worker(self): - await self._socket.connect(self._peer) - while not self._done: - (expiration, interval) = self._get_timer_values(False) - with trio.CancelScope( - deadline=trio.current_time() + interval - ) as self._worker_scope: - datagram = await self._socket.recv(QUIC_MAX_DATAGRAM) - self._connection.receive_datagram(datagram, self._peer[0], time.time()) - self._worker_scope = None - self._handle_timer(expiration) - datagrams = self._connection.datagrams_to_send(time.time()) - for (datagram, _) in datagrams: - await self._socket.send(datagram) - await self._handle_events() + try: + await self._socket.connect(self._peer) + while not self._done: + (expiration, interval) = self._get_timer_values(False) + with trio.CancelScope( + deadline=trio.current_time() + interval + ) as self._worker_scope: + datagram = await self._socket.recv(QUIC_MAX_DATAGRAM) + self._connection.receive_datagram( + datagram, self._peer[0], time.time() + ) + self._worker_scope = None + self._handle_timer(expiration) + datagrams = 
self._connection.datagrams_to_send(time.time()) + for datagram, _ in datagrams: + await self._socket.send(datagram) + await self._handle_events() + finally: + self._done = True + self._handshake_complete.set() async def _handle_events(self): count = 0 @@ -130,12 +139,20 @@ class TrioQuicConnection(AsyncQuicConnection): nursery.start_soon(self._worker) self._run_done.set() - async def make_stream(self): - await self._handshake_complete.wait() - stream_id = self._connection.get_next_available_stream_id(False) - stream = TrioQuicStream(self, stream_id) - self._streams[stream_id] = stream - return stream + async def make_stream(self, timeout=None): + if timeout is None: + context = NullContext(None) + else: + context = trio.move_on_after(timeout) + with context: + await self._handshake_complete.wait() + if self._done: + raise UnexpectedEOF + stream_id = self._connection.get_next_available_stream_id(False) + stream = TrioQuicStream(self, stream_id) + self._streams[stream_id] = stream + return stream + raise dns.exception.Timeout async def close(self): if not self._closed: @@ -148,8 +165,10 @@ class TrioQuicConnection(AsyncQuicConnection): class TrioQuicManager(AsyncQuicManager): - def __init__(self, nursery, conf=None, verify_mode=ssl.CERT_REQUIRED): - super().__init__(conf, verify_mode, TrioQuicConnection) + def __init__( + self, nursery, conf=None, verify_mode=ssl.CERT_REQUIRED, server_name=None + ): + super().__init__(conf, verify_mode, TrioQuicConnection, server_name) self._nursery = nursery def connect(self, address, port=853, source=None, source_port=0): @@ -162,7 +181,7 @@ class TrioQuicManager(AsyncQuicManager): return self async def __aexit__(self, exc_type, exc_val, exc_tb): - # Copy the itertor into a list as exiting things will mutate the connections + # Copy the iterator into a list as exiting things will mutate the connections # table. 
connections = list(self._connections.values()) for connection in connections: diff --git a/lib/dns/rdata.py b/lib/dns/rdata.py index 1dd6ed90..0d262e8d 100644 --- a/lib/dns/rdata.py +++ b/lib/dns/rdata.py @@ -17,17 +17,15 @@ """DNS rdata.""" -from typing import Any, Dict, Optional, Tuple, Union - -from importlib import import_module import base64 import binascii -import io import inspect +import io import itertools import random +from importlib import import_module +from typing import Any, Dict, Optional, Tuple, Union -import dns.wire import dns.exception import dns.immutable import dns.ipv4 @@ -37,6 +35,7 @@ import dns.rdataclass import dns.rdatatype import dns.tokenizer import dns.ttl +import dns.wire _chunksize = 32 @@ -358,7 +357,6 @@ class Rdata: or self.rdclass != other.rdclass or self.rdtype != other.rdtype ): - return NotImplemented return self._cmp(other) < 0 @@ -881,16 +879,11 @@ def register_type( it applies to all classes. """ - the_rdtype = dns.rdatatype.RdataType.make(rdtype) - existing_cls = get_rdata_class(rdclass, the_rdtype) - if existing_cls != GenericRdata or dns.rdatatype.is_metatype(the_rdtype): - raise RdatatypeExists(rdclass=rdclass, rdtype=the_rdtype) - try: - if dns.rdatatype.RdataType(the_rdtype).name != rdtype_text: - raise RdatatypeExists(rdclass=rdclass, rdtype=the_rdtype) - except ValueError: - pass - _rdata_classes[(rdclass, the_rdtype)] = getattr( + rdtype = dns.rdatatype.RdataType.make(rdtype) + existing_cls = get_rdata_class(rdclass, rdtype) + if existing_cls != GenericRdata or dns.rdatatype.is_metatype(rdtype): + raise RdatatypeExists(rdclass=rdclass, rdtype=rdtype) + _rdata_classes[(rdclass, rdtype)] = getattr( implementation, rdtype_text.replace("-", "_") ) - dns.rdatatype.register_type(the_rdtype, rdtype_text, is_singleton) + dns.rdatatype.register_type(rdtype, rdtype_text, is_singleton) diff --git a/lib/dns/rdataset.py b/lib/dns/rdataset.py index c0ede425..31124afc 100644 --- a/lib/dns/rdataset.py +++ b/lib/dns/rdataset.py @@ 
-17,18 +17,17 @@ """DNS rdatasets (an rdataset is a set of rdatas of a given type and class)""" -from typing import Any, cast, Collection, Dict, List, Optional, Union - import io import random import struct +from typing import Any, Collection, Dict, List, Optional, Union, cast import dns.exception import dns.immutable import dns.name -import dns.rdatatype -import dns.rdataclass import dns.rdata +import dns.rdataclass +import dns.rdatatype import dns.set import dns.ttl @@ -471,9 +470,9 @@ def from_text_list( Returns a ``dns.rdataset.Rdataset`` object. """ - the_rdclass = dns.rdataclass.RdataClass.make(rdclass) - the_rdtype = dns.rdatatype.RdataType.make(rdtype) - r = Rdataset(the_rdclass, the_rdtype) + rdclass = dns.rdataclass.RdataClass.make(rdclass) + rdtype = dns.rdatatype.RdataType.make(rdtype) + r = Rdataset(rdclass, rdtype) r.update_ttl(ttl) for t in text_rdatas: rd = dns.rdata.from_text( diff --git a/lib/dns/rdtypes/ANY/AFSDB.py b/lib/dns/rdtypes/ANY/AFSDB.py index d7838e7e..3d287f6e 100644 --- a/lib/dns/rdtypes/ANY/AFSDB.py +++ b/lib/dns/rdtypes/ANY/AFSDB.py @@ -15,8 +15,8 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. -import dns.rdtypes.mxbase import dns.immutable +import dns.rdtypes.mxbase @dns.immutable.immutable diff --git a/lib/dns/rdtypes/ANY/AVC.py b/lib/dns/rdtypes/ANY/AVC.py index 11e026d0..766d5e2d 100644 --- a/lib/dns/rdtypes/ANY/AVC.py +++ b/lib/dns/rdtypes/ANY/AVC.py @@ -15,8 +15,8 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 
-import dns.rdtypes.txtbase import dns.immutable +import dns.rdtypes.txtbase @dns.immutable.immutable diff --git a/lib/dns/rdtypes/ANY/CDNSKEY.py b/lib/dns/rdtypes/ANY/CDNSKEY.py index 869523fb..38b8a8da 100644 --- a/lib/dns/rdtypes/ANY/CDNSKEY.py +++ b/lib/dns/rdtypes/ANY/CDNSKEY.py @@ -15,15 +15,15 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. -import dns.rdtypes.dnskeybase # lgtm[py/import-and-import-from] import dns.immutable +import dns.rdtypes.dnskeybase # lgtm[py/import-and-import-from] # pylint: disable=unused-import -from dns.rdtypes.dnskeybase import ( - SEP, +from dns.rdtypes.dnskeybase import ( # noqa: F401 lgtm[py/unused-import] REVOKE, + SEP, ZONE, -) # noqa: F401 lgtm[py/unused-import] +) # pylint: enable=unused-import diff --git a/lib/dns/rdtypes/ANY/CDS.py b/lib/dns/rdtypes/ANY/CDS.py index 094de12b..2ff42d9a 100644 --- a/lib/dns/rdtypes/ANY/CDS.py +++ b/lib/dns/rdtypes/ANY/CDS.py @@ -15,8 +15,8 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. -import dns.rdtypes.dsbase import dns.immutable +import dns.rdtypes.dsbase @dns.immutable.immutable diff --git a/lib/dns/rdtypes/ANY/CERT.py b/lib/dns/rdtypes/ANY/CERT.py index 1b0cbeca..30fe863f 100644 --- a/lib/dns/rdtypes/ANY/CERT.py +++ b/lib/dns/rdtypes/ANY/CERT.py @@ -15,12 +15,12 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 
-import struct import base64 +import struct +import dns.dnssectypes import dns.exception import dns.immutable -import dns.dnssectypes import dns.rdata import dns.tokenizer diff --git a/lib/dns/rdtypes/ANY/CNAME.py b/lib/dns/rdtypes/ANY/CNAME.py index a4fcfa88..759adb90 100644 --- a/lib/dns/rdtypes/ANY/CNAME.py +++ b/lib/dns/rdtypes/ANY/CNAME.py @@ -15,8 +15,8 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. -import dns.rdtypes.nsbase import dns.immutable +import dns.rdtypes.nsbase @dns.immutable.immutable diff --git a/lib/dns/rdtypes/ANY/CSYNC.py b/lib/dns/rdtypes/ANY/CSYNC.py index f819c08c..315da9ff 100644 --- a/lib/dns/rdtypes/ANY/CSYNC.py +++ b/lib/dns/rdtypes/ANY/CSYNC.py @@ -19,9 +19,9 @@ import struct import dns.exception import dns.immutable +import dns.name import dns.rdata import dns.rdatatype -import dns.name import dns.rdtypes.util diff --git a/lib/dns/rdtypes/ANY/DLV.py b/lib/dns/rdtypes/ANY/DLV.py index 947dc42e..632e90f8 100644 --- a/lib/dns/rdtypes/ANY/DLV.py +++ b/lib/dns/rdtypes/ANY/DLV.py @@ -15,8 +15,8 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. -import dns.rdtypes.dsbase import dns.immutable +import dns.rdtypes.dsbase @dns.immutable.immutable diff --git a/lib/dns/rdtypes/ANY/DNAME.py b/lib/dns/rdtypes/ANY/DNAME.py index f4984b55..556bff59 100644 --- a/lib/dns/rdtypes/ANY/DNAME.py +++ b/lib/dns/rdtypes/ANY/DNAME.py @@ -15,8 +15,8 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 
-import dns.rdtypes.nsbase import dns.immutable +import dns.rdtypes.nsbase @dns.immutable.immutable diff --git a/lib/dns/rdtypes/ANY/DNSKEY.py b/lib/dns/rdtypes/ANY/DNSKEY.py index 50fa05b7..f1a63062 100644 --- a/lib/dns/rdtypes/ANY/DNSKEY.py +++ b/lib/dns/rdtypes/ANY/DNSKEY.py @@ -15,15 +15,15 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. -import dns.rdtypes.dnskeybase # lgtm[py/import-and-import-from] import dns.immutable +import dns.rdtypes.dnskeybase # lgtm[py/import-and-import-from] # pylint: disable=unused-import -from dns.rdtypes.dnskeybase import ( - SEP, +from dns.rdtypes.dnskeybase import ( # noqa: F401 lgtm[py/unused-import] REVOKE, + SEP, ZONE, -) # noqa: F401 lgtm[py/unused-import] +) # pylint: enable=unused-import diff --git a/lib/dns/rdtypes/ANY/DS.py b/lib/dns/rdtypes/ANY/DS.py index 3f6c3ee8..097ecfa0 100644 --- a/lib/dns/rdtypes/ANY/DS.py +++ b/lib/dns/rdtypes/ANY/DS.py @@ -15,8 +15,8 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. -import dns.rdtypes.dsbase import dns.immutable +import dns.rdtypes.dsbase @dns.immutable.immutable diff --git a/lib/dns/rdtypes/ANY/EUI48.py b/lib/dns/rdtypes/ANY/EUI48.py index 0ab88ad0..7e4e1ff3 100644 --- a/lib/dns/rdtypes/ANY/EUI48.py +++ b/lib/dns/rdtypes/ANY/EUI48.py @@ -16,8 +16,8 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 
-import dns.rdtypes.euibase import dns.immutable +import dns.rdtypes.euibase @dns.immutable.immutable diff --git a/lib/dns/rdtypes/ANY/EUI64.py b/lib/dns/rdtypes/ANY/EUI64.py index c42957ef..68b5820f 100644 --- a/lib/dns/rdtypes/ANY/EUI64.py +++ b/lib/dns/rdtypes/ANY/EUI64.py @@ -16,8 +16,8 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. -import dns.rdtypes.euibase import dns.immutable +import dns.rdtypes.euibase @dns.immutable.immutable diff --git a/lib/dns/rdtypes/ANY/HIP.py b/lib/dns/rdtypes/ANY/HIP.py index 01fec822..a20aa1e5 100644 --- a/lib/dns/rdtypes/ANY/HIP.py +++ b/lib/dns/rdtypes/ANY/HIP.py @@ -15,9 +15,9 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. -import struct import base64 import binascii +import struct import dns.exception import dns.immutable diff --git a/lib/dns/rdtypes/ANY/LOC.py b/lib/dns/rdtypes/ANY/LOC.py index 52c97532..783d54af 100644 --- a/lib/dns/rdtypes/ANY/LOC.py +++ b/lib/dns/rdtypes/ANY/LOC.py @@ -21,7 +21,6 @@ import dns.exception import dns.immutable import dns.rdata - _pows = tuple(10**i for i in range(0, 11)) # default values are in centimeters @@ -40,7 +39,7 @@ def _exponent_of(what, desc): if what == 0: return 0 exp = None - for (i, pow) in enumerate(_pows): + for i, pow in enumerate(_pows): if what < pow: exp = i - 1 break diff --git a/lib/dns/rdtypes/ANY/MX.py b/lib/dns/rdtypes/ANY/MX.py index a697ea45..1f9df21f 100644 --- a/lib/dns/rdtypes/ANY/MX.py +++ b/lib/dns/rdtypes/ANY/MX.py @@ -15,8 +15,8 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 
-import dns.rdtypes.mxbase import dns.immutable +import dns.rdtypes.mxbase @dns.immutable.immutable diff --git a/lib/dns/rdtypes/ANY/NINFO.py b/lib/dns/rdtypes/ANY/NINFO.py index d53e9676..55bc5614 100644 --- a/lib/dns/rdtypes/ANY/NINFO.py +++ b/lib/dns/rdtypes/ANY/NINFO.py @@ -15,8 +15,8 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. -import dns.rdtypes.txtbase import dns.immutable +import dns.rdtypes.txtbase @dns.immutable.immutable diff --git a/lib/dns/rdtypes/ANY/NS.py b/lib/dns/rdtypes/ANY/NS.py index a0cc232a..fe453f0d 100644 --- a/lib/dns/rdtypes/ANY/NS.py +++ b/lib/dns/rdtypes/ANY/NS.py @@ -15,8 +15,8 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. -import dns.rdtypes.nsbase import dns.immutable +import dns.rdtypes.nsbase @dns.immutable.immutable diff --git a/lib/dns/rdtypes/ANY/NSEC.py b/lib/dns/rdtypes/ANY/NSEC.py index 7af7b77f..a2d98fa7 100644 --- a/lib/dns/rdtypes/ANY/NSEC.py +++ b/lib/dns/rdtypes/ANY/NSEC.py @@ -17,9 +17,9 @@ import dns.exception import dns.immutable +import dns.name import dns.rdata import dns.rdatatype -import dns.name import dns.rdtypes.util diff --git a/lib/dns/rdtypes/ANY/NSEC3.py b/lib/dns/rdtypes/ANY/NSEC3.py index 6eae16e0..d32fe169 100644 --- a/lib/dns/rdtypes/ANY/NSEC3.py +++ b/lib/dns/rdtypes/ANY/NSEC3.py @@ -25,7 +25,6 @@ import dns.rdata import dns.rdatatype import dns.rdtypes.util - b32_hex_to_normal = bytes.maketrans( b"0123456789ABCDEFGHIJKLMNOPQRSTUV", b"ABCDEFGHIJKLMNOPQRSTUVWXYZ234567" ) @@ -67,6 +66,7 @@ class NSEC3(dns.rdata.Rdata): def to_text(self, origin=None, relativize=True, **kw): next = base64.b32encode(self.next).translate(b32_normal_to_hex).lower().decode() + next = next.rstrip("=") if self.salt == b"": salt = "-" else: @@ -94,6 +94,10 @@ class NSEC3(dns.rdata.Rdata): else: salt = 
binascii.unhexlify(salt.encode("ascii")) next = tok.get_string().encode("ascii").upper().translate(b32_hex_to_normal) + if next.endswith(b"="): + raise binascii.Error("Incorrect padding") + if len(next) % 8 != 0: + next += b"=" * (8 - len(next) % 8) next = base64.b32decode(next) bitmap = Bitmap.from_text(tok) return cls(rdclass, rdtype, algorithm, flags, iterations, salt, next, bitmap) diff --git a/lib/dns/rdtypes/ANY/NSEC3PARAM.py b/lib/dns/rdtypes/ANY/NSEC3PARAM.py index 1b7269a0..1a0c0e08 100644 --- a/lib/dns/rdtypes/ANY/NSEC3PARAM.py +++ b/lib/dns/rdtypes/ANY/NSEC3PARAM.py @@ -15,8 +15,8 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. -import struct import binascii +import struct import dns.exception import dns.immutable diff --git a/lib/dns/rdtypes/ANY/OPT.py b/lib/dns/rdtypes/ANY/OPT.py index 36d4c7c6..d70e5373 100644 --- a/lib/dns/rdtypes/ANY/OPT.py +++ b/lib/dns/rdtypes/ANY/OPT.py @@ -18,11 +18,10 @@ import struct import dns.edns -import dns.immutable import dns.exception +import dns.immutable import dns.rdata - # We don't implement from_text, and that's ok. # pylint: disable=abstract-method diff --git a/lib/dns/rdtypes/ANY/PTR.py b/lib/dns/rdtypes/ANY/PTR.py index 265bed03..7fd5547d 100644 --- a/lib/dns/rdtypes/ANY/PTR.py +++ b/lib/dns/rdtypes/ANY/PTR.py @@ -15,8 +15,8 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 
-import dns.rdtypes.nsbase import dns.immutable +import dns.rdtypes.nsbase @dns.immutable.immutable diff --git a/lib/dns/rdtypes/ANY/RP.py b/lib/dns/rdtypes/ANY/RP.py index c0c316b5..9c64c6e2 100644 --- a/lib/dns/rdtypes/ANY/RP.py +++ b/lib/dns/rdtypes/ANY/RP.py @@ -17,8 +17,8 @@ import dns.exception import dns.immutable -import dns.rdata import dns.name +import dns.rdata @dns.immutable.immutable diff --git a/lib/dns/rdtypes/ANY/RRSIG.py b/lib/dns/rdtypes/ANY/RRSIG.py index 3d5ad0f3..11605026 100644 --- a/lib/dns/rdtypes/ANY/RRSIG.py +++ b/lib/dns/rdtypes/ANY/RRSIG.py @@ -21,8 +21,8 @@ import struct import time import dns.dnssectypes -import dns.immutable import dns.exception +import dns.immutable import dns.rdata import dns.rdatatype diff --git a/lib/dns/rdtypes/ANY/RT.py b/lib/dns/rdtypes/ANY/RT.py index 8d9c6bd0..950f2a06 100644 --- a/lib/dns/rdtypes/ANY/RT.py +++ b/lib/dns/rdtypes/ANY/RT.py @@ -15,8 +15,8 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. -import dns.rdtypes.mxbase import dns.immutable +import dns.rdtypes.mxbase @dns.immutable.immutable diff --git a/lib/dns/rdtypes/ANY/SOA.py b/lib/dns/rdtypes/ANY/SOA.py index 6f6fe58b..bde55e15 100644 --- a/lib/dns/rdtypes/ANY/SOA.py +++ b/lib/dns/rdtypes/ANY/SOA.py @@ -19,8 +19,8 @@ import struct import dns.exception import dns.immutable -import dns.rdata import dns.name +import dns.rdata @dns.immutable.immutable diff --git a/lib/dns/rdtypes/ANY/SPF.py b/lib/dns/rdtypes/ANY/SPF.py index 1190e0de..c403589a 100644 --- a/lib/dns/rdtypes/ANY/SPF.py +++ b/lib/dns/rdtypes/ANY/SPF.py @@ -15,8 +15,8 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 
-import dns.rdtypes.txtbase import dns.immutable +import dns.rdtypes.txtbase @dns.immutable.immutable diff --git a/lib/dns/rdtypes/ANY/SSHFP.py b/lib/dns/rdtypes/ANY/SSHFP.py index 58ffcbbc..67805452 100644 --- a/lib/dns/rdtypes/ANY/SSHFP.py +++ b/lib/dns/rdtypes/ANY/SSHFP.py @@ -15,11 +15,11 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. -import struct import binascii +import struct -import dns.rdata import dns.immutable +import dns.rdata import dns.rdatatype diff --git a/lib/dns/rdtypes/ANY/TKEY.py b/lib/dns/rdtypes/ANY/TKEY.py index 070f03af..d5f5fc45 100644 --- a/lib/dns/rdtypes/ANY/TKEY.py +++ b/lib/dns/rdtypes/ANY/TKEY.py @@ -18,8 +18,8 @@ import base64 import struct -import dns.immutable import dns.exception +import dns.immutable import dns.rdata diff --git a/lib/dns/rdtypes/ANY/TXT.py b/lib/dns/rdtypes/ANY/TXT.py index cc4b6611..f4e61930 100644 --- a/lib/dns/rdtypes/ANY/TXT.py +++ b/lib/dns/rdtypes/ANY/TXT.py @@ -15,8 +15,8 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 
-import dns.rdtypes.txtbase import dns.immutable +import dns.rdtypes.txtbase @dns.immutable.immutable diff --git a/lib/dns/rdtypes/ANY/URI.py b/lib/dns/rdtypes/ANY/URI.py index b4c95a3b..7463e277 100644 --- a/lib/dns/rdtypes/ANY/URI.py +++ b/lib/dns/rdtypes/ANY/URI.py @@ -20,9 +20,9 @@ import struct import dns.exception import dns.immutable +import dns.name import dns.rdata import dns.rdtypes.util -import dns.name @dns.immutable.immutable diff --git a/lib/dns/rdtypes/ANY/ZONEMD.py b/lib/dns/rdtypes/ANY/ZONEMD.py index 1f86ba49..3062843b 100644 --- a/lib/dns/rdtypes/ANY/ZONEMD.py +++ b/lib/dns/rdtypes/ANY/ZONEMD.py @@ -1,7 +1,7 @@ # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license -import struct import binascii +import struct import dns.immutable import dns.rdata diff --git a/lib/dns/rdtypes/CH/A.py b/lib/dns/rdtypes/CH/A.py index 9905c7c9..e457f38a 100644 --- a/lib/dns/rdtypes/CH/A.py +++ b/lib/dns/rdtypes/CH/A.py @@ -17,8 +17,8 @@ import struct -import dns.rdtypes.mxbase import dns.immutable +import dns.rdtypes.mxbase @dns.immutable.immutable diff --git a/lib/dns/rdtypes/IN/APL.py b/lib/dns/rdtypes/IN/APL.py index 05e1689f..f1bb01db 100644 --- a/lib/dns/rdtypes/IN/APL.py +++ b/lib/dns/rdtypes/IN/APL.py @@ -124,7 +124,6 @@ class APL(dns.rdata.Rdata): @classmethod def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): - items = [] while parser.remaining() > 0: header = parser.get_struct("!HBB") diff --git a/lib/dns/rdtypes/IN/HTTPS.py b/lib/dns/rdtypes/IN/HTTPS.py index 7797fbaf..15464cbd 100644 --- a/lib/dns/rdtypes/IN/HTTPS.py +++ b/lib/dns/rdtypes/IN/HTTPS.py @@ -1,7 +1,7 @@ # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license -import dns.rdtypes.svcbbase import dns.immutable +import dns.rdtypes.svcbbase @dns.immutable.immutable diff --git a/lib/dns/rdtypes/IN/IPSECKEY.py b/lib/dns/rdtypes/IN/IPSECKEY.py index 1255739f..8bb2bcb6 100644 --- a/lib/dns/rdtypes/IN/IPSECKEY.py +++ 
b/lib/dns/rdtypes/IN/IPSECKEY.py @@ -15,8 +15,8 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. -import struct import base64 +import struct import dns.exception import dns.immutable diff --git a/lib/dns/rdtypes/IN/KX.py b/lib/dns/rdtypes/IN/KX.py index c27e9215..a03d1d51 100644 --- a/lib/dns/rdtypes/IN/KX.py +++ b/lib/dns/rdtypes/IN/KX.py @@ -15,8 +15,8 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. -import dns.rdtypes.mxbase import dns.immutable +import dns.rdtypes.mxbase @dns.immutable.immutable diff --git a/lib/dns/rdtypes/IN/NSAP_PTR.py b/lib/dns/rdtypes/IN/NSAP_PTR.py index 57dadd47..0a18fdce 100644 --- a/lib/dns/rdtypes/IN/NSAP_PTR.py +++ b/lib/dns/rdtypes/IN/NSAP_PTR.py @@ -15,8 +15,8 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 
-import dns.rdtypes.nsbase import dns.immutable +import dns.rdtypes.nsbase @dns.immutable.immutable diff --git a/lib/dns/rdtypes/IN/PX.py b/lib/dns/rdtypes/IN/PX.py index b2216d6b..5c0aa81e 100644 --- a/lib/dns/rdtypes/IN/PX.py +++ b/lib/dns/rdtypes/IN/PX.py @@ -19,9 +19,9 @@ import struct import dns.exception import dns.immutable +import dns.name import dns.rdata import dns.rdtypes.util -import dns.name @dns.immutable.immutable diff --git a/lib/dns/rdtypes/IN/SRV.py b/lib/dns/rdtypes/IN/SRV.py index 8b0b6bf7..84c54007 100644 --- a/lib/dns/rdtypes/IN/SRV.py +++ b/lib/dns/rdtypes/IN/SRV.py @@ -19,9 +19,9 @@ import struct import dns.exception import dns.immutable +import dns.name import dns.rdata import dns.rdtypes.util -import dns.name @dns.immutable.immutable diff --git a/lib/dns/rdtypes/IN/SVCB.py b/lib/dns/rdtypes/IN/SVCB.py index 9a1ad101..ff3e9327 100644 --- a/lib/dns/rdtypes/IN/SVCB.py +++ b/lib/dns/rdtypes/IN/SVCB.py @@ -1,7 +1,7 @@ # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license -import dns.rdtypes.svcbbase import dns.immutable +import dns.rdtypes.svcbbase @dns.immutable.immutable diff --git a/lib/dns/rdtypes/IN/WKS.py b/lib/dns/rdtypes/IN/WKS.py index a671e203..26d287a3 100644 --- a/lib/dns/rdtypes/IN/WKS.py +++ b/lib/dns/rdtypes/IN/WKS.py @@ -18,8 +18,8 @@ import socket import struct -import dns.ipv4 import dns.immutable +import dns.ipv4 import dns.rdata try: diff --git a/lib/dns/rdtypes/dnskeybase.py b/lib/dns/rdtypes/dnskeybase.py index 1d17f70f..3bfcf860 100644 --- a/lib/dns/rdtypes/dnskeybase.py +++ b/lib/dns/rdtypes/dnskeybase.py @@ -19,9 +19,9 @@ import base64 import enum import struct +import dns.dnssectypes import dns.exception import dns.immutable -import dns.dnssectypes import dns.rdata # wildcard import @@ -43,7 +43,7 @@ class DNSKEYBase(dns.rdata.Rdata): def __init__(self, rdclass, rdtype, flags, protocol, algorithm, key): super().__init__(rdclass, rdtype) - self.flags = self._as_uint16(flags) + self.flags = 
Flag(self._as_uint16(flags)) self.protocol = self._as_uint8(protocol) self.algorithm = dns.dnssectypes.Algorithm.make(algorithm) self.key = self._as_bytes(key) diff --git a/lib/dns/rdtypes/dsbase.py b/lib/dns/rdtypes/dsbase.py index b6032b0f..1ad0b7a5 100644 --- a/lib/dns/rdtypes/dsbase.py +++ b/lib/dns/rdtypes/dsbase.py @@ -15,8 +15,8 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. -import struct import binascii +import struct import dns.dnssectypes import dns.immutable @@ -44,7 +44,7 @@ class DSBase(dns.rdata.Rdata): super().__init__(rdclass, rdtype) self.key_tag = self._as_uint16(key_tag) self.algorithm = dns.dnssectypes.Algorithm.make(algorithm) - self.digest_type = self._as_uint8(digest_type) + self.digest_type = dns.dnssectypes.DSDigest.make(self._as_uint8(digest_type)) self.digest = self._as_bytes(digest) try: if len(self.digest) != self._digest_length_by_type[self.digest_type]: diff --git a/lib/dns/rdtypes/euibase.py b/lib/dns/rdtypes/euibase.py index e524aea9..4c4068b2 100644 --- a/lib/dns/rdtypes/euibase.py +++ b/lib/dns/rdtypes/euibase.py @@ -16,8 +16,8 @@ import binascii -import dns.rdata import dns.immutable +import dns.rdata @dns.immutable.immutable diff --git a/lib/dns/rdtypes/mxbase.py b/lib/dns/rdtypes/mxbase.py index b4b9b088..a6bae078 100644 --- a/lib/dns/rdtypes/mxbase.py +++ b/lib/dns/rdtypes/mxbase.py @@ -21,8 +21,8 @@ import struct import dns.exception import dns.immutable -import dns.rdata import dns.name +import dns.rdata import dns.rdtypes.util diff --git a/lib/dns/rdtypes/nsbase.py b/lib/dns/rdtypes/nsbase.py index ba7a2ab7..56d94235 100644 --- a/lib/dns/rdtypes/nsbase.py +++ b/lib/dns/rdtypes/nsbase.py @@ -19,8 +19,8 @@ import dns.exception import dns.immutable -import dns.rdata import dns.name +import dns.rdata @dns.immutable.immutable diff --git a/lib/dns/rdtypes/svcbbase.py b/lib/dns/rdtypes/svcbbase.py index 8d6fb1c6..ba5b53d2 100644 --- 
a/lib/dns/rdtypes/svcbbase.py +++ b/lib/dns/rdtypes/svcbbase.py @@ -34,6 +34,7 @@ class ParamKey(dns.enum.IntEnum): IPV4HINT = 4 ECH = 5 IPV6HINT = 6 + DOHPATH = 7 @classmethod def _maximum(cls): diff --git a/lib/dns/rdtypes/tlsabase.py b/lib/dns/rdtypes/tlsabase.py index a3fdc354..4cdb7ab3 100644 --- a/lib/dns/rdtypes/tlsabase.py +++ b/lib/dns/rdtypes/tlsabase.py @@ -15,11 +15,11 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. -import struct import binascii +import struct -import dns.rdata import dns.immutable +import dns.rdata import dns.rdatatype diff --git a/lib/dns/rdtypes/txtbase.py b/lib/dns/rdtypes/txtbase.py index d4cb9bb2..fdbfb646 100644 --- a/lib/dns/rdtypes/txtbase.py +++ b/lib/dns/rdtypes/txtbase.py @@ -17,9 +17,8 @@ """TXT-like base class.""" -from typing import Any, Dict, Iterable, Optional, Tuple, Union - import struct +from typing import Any, Dict, Iterable, Optional, Tuple, Union import dns.exception import dns.immutable diff --git a/lib/dns/rdtypes/util.py b/lib/dns/rdtypes/util.py index 74596f05..54908fdc 100644 --- a/lib/dns/rdtypes/util.py +++ b/lib/dns/rdtypes/util.py @@ -18,6 +18,7 @@ import collections import random import struct +from typing import Any, List import dns.exception import dns.ipv4 @@ -119,7 +120,7 @@ class Bitmap: def __init__(self, windows=None): last_window = -1 self.windows = windows - for (window, bitmap) in self.windows: + for window, bitmap in self.windows: if not isinstance(window, int): raise ValueError(f"bad {self.type_name} window type") if window <= last_window: @@ -132,11 +133,11 @@ class Bitmap: if len(bitmap) == 0 or len(bitmap) > 32: raise ValueError(f"bad {self.type_name} octets") - def to_text(self): + def to_text(self) -> str: text = "" - for (window, bitmap) in self.windows: + for window, bitmap in self.windows: bits = [] - for (i, byte) in enumerate(bitmap): + for i, byte in enumerate(bitmap): for j in range(0, 
8): if byte & (0x80 >> j): rdtype = window * 256 + i * 8 + j @@ -145,14 +146,18 @@ class Bitmap: return text @classmethod - def from_text(cls, tok): + def from_text(cls, tok: "dns.tokenizer.Tokenizer") -> "Bitmap": rdtypes = [] for token in tok.get_remaining(): rdtype = dns.rdatatype.from_text(token.unescape().value) if rdtype == 0: raise dns.exception.SyntaxError(f"{cls.type_name} with bit 0") rdtypes.append(rdtype) - rdtypes.sort() + return cls.from_rdtypes(rdtypes) + + @classmethod + def from_rdtypes(cls, rdtypes: List[dns.rdatatype.RdataType]) -> "Bitmap": + rdtypes = sorted(rdtypes) window = 0 octets = 0 prior_rdtype = 0 @@ -177,13 +182,13 @@ class Bitmap: windows.append((window, bytes(bitmap[0:octets]))) return cls(windows) - def to_wire(self, file): - for (window, bitmap) in self.windows: + def to_wire(self, file: Any) -> None: + for window, bitmap in self.windows: file.write(struct.pack("!BB", window, len(bitmap))) file.write(bitmap) @classmethod - def from_wire_parser(cls, parser): + def from_wire_parser(cls, parser: "dns.wire.Parser") -> "Bitmap": windows = [] while parser.remaining() > 0: window = parser.get_uint8() @@ -226,7 +231,7 @@ def weighted_processing_order(iterable): total = sum(rdata._processing_weight() or _no_weight for rdata in rdatas) while len(rdatas) > 1: r = random.uniform(0, total) - for (n, rdata) in enumerate(rdatas): + for n, rdata in enumerate(rdatas): weight = rdata._processing_weight() or _no_weight if weight > r: break diff --git a/lib/dns/renderer.py b/lib/dns/renderer.py index 3c495f61..53e7c0f6 100644 --- a/lib/dns/renderer.py +++ b/lib/dns/renderer.py @@ -19,14 +19,13 @@ import contextlib import io -import struct import random +import struct import time import dns.exception import dns.tsig - QUESTION = 0 ANSWER = 1 AUTHORITY = 2 diff --git a/lib/dns/resolver.py b/lib/dns/resolver.py index a5b66c1d..f08f824d 100644 --- a/lib/dns/resolver.py +++ b/lib/dns/resolver.py @@ -17,29 +17,31 @@ """DNS stub resolver.""" -from typing 
import Any, Dict, List, Optional, Tuple, Union - -from urllib.parse import urlparse import contextlib +import random import socket import sys import threading import time -import random import warnings +from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple, Union +from urllib.parse import urlparse -import dns.exception +import dns._ddr import dns.edns +import dns.exception import dns.flags import dns.inet import dns.ipv4 import dns.ipv6 import dns.message import dns.name +import dns.nameserver import dns.query import dns.rcode import dns.rdataclass import dns.rdatatype +import dns.rdtypes.svcbbase import dns.reversename import dns.tsig @@ -72,7 +74,7 @@ class NXDOMAIN(dns.exception.DNSException): kwargs = dict(qnames=qnames, responses=responses) return kwargs - def __str__(self): + def __str__(self) -> str: if "qnames" not in self.kwargs: return super().__str__() qnames = self.kwargs["qnames"] @@ -140,7 +142,11 @@ class YXDOMAIN(dns.exception.DNSException): ErrorTuple = Tuple[ - Optional[str], bool, int, Union[Exception, str], Optional[dns.message.Message] + Optional[str], + bool, + int, + Union[Exception, str], + Optional[dns.message.Message], ] @@ -148,11 +154,7 @@ def _errors_to_text(errors: List[ErrorTuple]) -> List[str]: """Turn a resolution errors trace into a list of text.""" texts = [] for err in errors: - texts.append( - "Server {} {} port {} answered {}".format( - err[0], "TCP" if err[1] else "UDP", err[2], err[3] - ) - ) + texts.append("Server {} answered {}".format(err[0], err[3])) return texts @@ -184,7 +186,7 @@ Timeout = LifetimeTimeout class NoAnswer(dns.exception.DNSException): """The DNS response does not contain an answer to the question.""" - fmt = "The DNS response does not contain an answer " + "to the question: {query}" + fmt = "The DNS response does not contain an answer to the question: {query}" supp_kwargs = {"response"} # We do this as otherwise mypy complains about unexpected keyword argument @@ -264,7 +266,7 @@ class 
Answer: response: dns.message.QueryMessage, nameserver: Optional[str] = None, port: Optional[int] = None, - ): + ) -> None: self.qname = qname self.rdtype = rdtype self.rdclass = rdclass @@ -292,7 +294,7 @@ class Answer: else: raise AttributeError(attr) - def __len__(self): + def __len__(self) -> int: return self.rrset and len(self.rrset) or 0 def __iter__(self): @@ -309,14 +311,67 @@ class Answer: del self.rrset[i] +class Answers(dict): + """A dict of DNS stub resolver answers, indexed by type.""" + + +class HostAnswers(Answers): + """A dict of DNS stub resolver answers to a host name lookup, indexed by + type. + """ + + @classmethod + def make( + cls, + v6: Optional[Answer] = None, + v4: Optional[Answer] = None, + add_empty: bool = True, + ) -> "HostAnswers": + answers = HostAnswers() + if v6 is not None and (add_empty or v6.rrset): + answers[dns.rdatatype.AAAA] = v6 + if v4 is not None and (add_empty or v4.rrset): + answers[dns.rdatatype.A] = v4 + return answers + + # Returns pairs of (address, family) from this result, potentiallys + # filtering by address family. + def addresses_and_families( + self, family: int = socket.AF_UNSPEC + ) -> Iterator[Tuple[str, int]]: + if family == socket.AF_UNSPEC: + yield from self.addresses_and_families(socket.AF_INET6) + yield from self.addresses_and_families(socket.AF_INET) + return + elif family == socket.AF_INET6: + answer = self.get(dns.rdatatype.AAAA) + elif family == socket.AF_INET: + answer = self.get(dns.rdatatype.A) + else: + raise NotImplementedError(f"unknown address family {family}") + if answer: + for rdata in answer: + yield (rdata.address, family) + + # Returns addresses from this result, potentially filtering by + # address family. + def addresses(self, family: int = socket.AF_UNSPEC) -> Iterator[str]: + return (pair[0] for pair in self.addresses_and_families(family)) + + # Returns the canonical name from this result. 
+ def canonical_name(self) -> dns.name.Name: + answer = self.get(dns.rdatatype.AAAA, self.get(dns.rdatatype.A)) + return answer.canonical_name + + class CacheStatistics: """Cache Statistics""" - def __init__(self, hits=0, misses=0): + def __init__(self, hits: int = 0, misses: int = 0) -> None: self.hits = hits self.misses = misses - def reset(self): + def reset(self) -> None: self.hits = 0 self.misses = 0 @@ -325,7 +380,7 @@ class CacheStatistics: class CacheBase: - def __init__(self): + def __init__(self) -> None: self.lock = threading.Lock() self.statistics = CacheStatistics() @@ -361,7 +416,7 @@ CacheKey = Tuple[dns.name.Name, dns.rdatatype.RdataType, dns.rdataclass.RdataCla class Cache(CacheBase): """Simple thread-safe DNS answer cache.""" - def __init__(self, cleaning_interval: float = 300.0): + def __init__(self, cleaning_interval: float = 300.0) -> None: """*cleaning_interval*, a ``float`` is the number of seconds between periodic cleanings. """ @@ -377,7 +432,7 @@ class Cache(CacheBase): now = time.time() if self.next_cleaning <= now: keys_to_delete = [] - for (k, v) in self.data.items(): + for k, v in self.data.items(): if v.expiration <= now: keys_to_delete.append(k) for k in keys_to_delete: @@ -447,13 +502,13 @@ class LRUCacheNode: self.prev = self self.next = self - def link_after(self, node): + def link_after(self, node: "LRUCacheNode") -> None: self.prev = node self.next = node.next node.next.prev = self node.next = self - def unlink(self): + def unlink(self) -> None: self.next.prev = self.prev self.prev.next = self.next @@ -468,7 +523,7 @@ class LRUCache(CacheBase): for a new one. """ - def __init__(self, max_size: int = 100000): + def __init__(self, max_size: int = 100000) -> None: """*max_size*, an ``int``, is the maximum number of nodes to cache; it must be greater than 0. 
""" @@ -590,30 +645,29 @@ class _Resolution: tcp: bool, raise_on_no_answer: bool, search: Optional[bool], - ): + ) -> None: if isinstance(qname, str): qname = dns.name.from_text(qname, None) - the_rdtype = dns.rdatatype.RdataType.make(rdtype) - if dns.rdatatype.is_metatype(the_rdtype): + rdtype = dns.rdatatype.RdataType.make(rdtype) + if dns.rdatatype.is_metatype(rdtype): raise NoMetaqueries - the_rdclass = dns.rdataclass.RdataClass.make(rdclass) - if dns.rdataclass.is_metaclass(the_rdclass): + rdclass = dns.rdataclass.RdataClass.make(rdclass) + if dns.rdataclass.is_metaclass(rdclass): raise NoMetaqueries self.resolver = resolver self.qnames_to_try = resolver._get_qnames_to_try(qname, search) self.qnames = self.qnames_to_try[:] - self.rdtype = the_rdtype - self.rdclass = the_rdclass + self.rdtype = rdtype + self.rdclass = rdclass self.tcp = tcp self.raise_on_no_answer = raise_on_no_answer self.nxdomain_responses: Dict[dns.name.Name, dns.message.QueryMessage] = {} # Initialize other things to help analysis tools self.qname = dns.name.empty - self.nameservers: List[str] = [] - self.current_nameservers: List[str] = [] + self.nameservers: List[dns.nameserver.Nameserver] = [] + self.current_nameservers: List[dns.nameserver.Nameserver] = [] self.errors: List[ErrorTuple] = [] - self.nameserver: Optional[str] = None - self.port = 0 + self.nameserver: Optional[dns.nameserver.Nameserver] = None self.tcp_attempt = False self.retry_with_tcp = False self.request: Optional[dns.message.QueryMessage] = None @@ -670,7 +724,11 @@ class _Resolution: if self.resolver.flags is not None: request.flags = self.resolver.flags - self.nameservers = self.resolver.nameservers[:] + self.nameservers = self.resolver._enrich_nameservers( + self.resolver._nameservers, + self.resolver.nameserver_ports, + self.resolver.port, + ) if self.resolver.rotate: random.shuffle(self.nameservers) self.current_nameservers = self.nameservers[:] @@ -690,12 +748,13 @@ class _Resolution: # raise 
NXDOMAIN(qnames=self.qnames_to_try, responses=self.nxdomain_responses) - def next_nameserver(self) -> Tuple[str, int, bool, float]: + def next_nameserver(self) -> Tuple[dns.nameserver.Nameserver, bool, float]: if self.retry_with_tcp: assert self.nameserver is not None + assert not self.nameserver.is_always_max_size() self.tcp_attempt = True self.retry_with_tcp = False - return (self.nameserver, self.port, True, 0) + return (self.nameserver, True, 0) backoff = 0.0 if not self.current_nameservers: @@ -707,11 +766,8 @@ class _Resolution: self.backoff = min(self.backoff * 2, 2) self.nameserver = self.current_nameservers.pop(0) - self.port = self.resolver.nameserver_ports.get( - self.nameserver, self.resolver.port - ) - self.tcp_attempt = self.tcp - return (self.nameserver, self.port, self.tcp_attempt, backoff) + self.tcp_attempt = self.tcp or self.nameserver.is_always_max_size() + return (self.nameserver, self.tcp_attempt, backoff) def query_result( self, response: Optional[dns.message.Message], ex: Optional[Exception] @@ -724,7 +780,13 @@ class _Resolution: # Exception during I/O or from_wire() assert response is None self.errors.append( - (self.nameserver, self.tcp_attempt, self.port, ex, response) + ( + str(self.nameserver), + self.tcp_attempt, + self.nameserver.answer_port(), + ex, + response, + ) ) if ( isinstance(ex, dns.exception.FormError) @@ -752,12 +814,18 @@ class _Resolution: self.rdtype, self.rdclass, response, - self.nameserver, - self.port, + self.nameserver.answer_nameserver(), + self.nameserver.answer_port(), ) except Exception as e: self.errors.append( - (self.nameserver, self.tcp_attempt, self.port, e, response) + ( + str(self.nameserver), + self.tcp_attempt, + self.nameserver.answer_port(), + e, + response, + ) ) # The nameserver is no good, take it out of the mix. 
self.nameservers.remove(self.nameserver) @@ -776,7 +844,13 @@ class _Resolution: ) except Exception as e: self.errors.append( - (self.nameserver, self.tcp_attempt, self.port, e, response) + ( + str(self.nameserver), + self.tcp_attempt, + self.nameserver.answer_port(), + e, + response, + ) ) # The nameserver is no good, take it out of the mix. self.nameservers.remove(self.nameserver) @@ -792,7 +866,13 @@ class _Resolution: elif rcode == dns.rcode.YXDOMAIN: yex = YXDOMAIN() self.errors.append( - (self.nameserver, self.tcp_attempt, self.port, yex, response) + ( + str(self.nameserver), + self.tcp_attempt, + self.nameserver.answer_port(), + yex, + response, + ) ) raise yex else: @@ -804,9 +884,9 @@ class _Resolution: self.nameservers.remove(self.nameserver) self.errors.append( ( - self.nameserver, + str(self.nameserver), self.tcp_attempt, - self.port, + self.nameserver.answer_port(), dns.rcode.to_text(rcode), response, ) @@ -840,8 +920,11 @@ class BaseResolver: retry_servfail: bool rotate: bool ndots: Optional[int] + _nameservers: Sequence[Union[str, dns.nameserver.Nameserver]] - def __init__(self, filename: str = "/etc/resolv.conf", configure: bool = True): + def __init__( + self, filename: str = "/etc/resolv.conf", configure: bool = True + ) -> None: """*filename*, a ``str`` or file object, specifying a file in standard /etc/resolv.conf format. This parameter is meaningful only when *configure* is true and the platform is POSIX. 
@@ -860,13 +943,13 @@ class BaseResolver: elif filename: self.read_resolv_conf(filename) - def reset(self): + def reset(self) -> None: """Reset all resolver configuration to the defaults.""" self.domain = dns.name.Name(dns.name.from_text(socket.gethostname())[1:]) if len(self.domain) == 0: self.domain = dns.name.root - self.nameservers = [] + self._nameservers = [] self.nameserver_ports = {} self.port = 53 self.search = [] @@ -903,6 +986,7 @@ class BaseResolver: """ + nameservers = [] if isinstance(f, str): try: cm: contextlib.AbstractContextManager = open(f) @@ -922,7 +1006,7 @@ class BaseResolver: continue if tokens[0] == "nameserver": - self.nameservers.append(tokens[1]) + nameservers.append(tokens[1]) elif tokens[0] == "domain": self.domain = dns.name.from_text(tokens[1]) # domain and search are exclusive @@ -950,8 +1034,11 @@ class BaseResolver: self.ndots = int(opt.split(":")[1]) except (ValueError, IndexError): pass - if len(self.nameservers) == 0: + if len(nameservers) == 0: raise NoResolverConfiguration("no nameservers") + # Assigning directly instead of appending means we invoke the + # setter logic, with additonal checking and enrichment. + self.nameservers = nameservers def read_registry(self) -> None: """Extract resolver configuration from the Windows registry.""" @@ -1086,34 +1173,64 @@ class BaseResolver: self.flags = flags - @property - def nameservers(self) -> List[str]: - return self._nameservers - - @nameservers.setter - def nameservers(self, nameservers: List[str]) -> None: - """ - *nameservers*, a ``list`` of nameservers. - - Raises ``ValueError`` if *nameservers* is anything other than a - ``list``. 
- """ + @classmethod + def _enrich_nameservers( + cls, + nameservers: Sequence[Union[str, dns.nameserver.Nameserver]], + nameserver_ports: Dict[str, int], + default_port: int, + ) -> List[dns.nameserver.Nameserver]: + enriched_nameservers = [] if isinstance(nameservers, list): for nameserver in nameservers: - if not dns.inet.is_address(nameserver): + enriched_nameserver: dns.nameserver.Nameserver + if isinstance(nameserver, dns.nameserver.Nameserver): + enriched_nameserver = nameserver + elif dns.inet.is_address(nameserver): + port = nameserver_ports.get(nameserver, default_port) + enriched_nameserver = dns.nameserver.Do53Nameserver( + nameserver, port + ) + else: try: if urlparse(nameserver).scheme != "https": raise NotImplementedError except Exception: raise ValueError( - f"nameserver {nameserver} is not an " - "IP address or valid https URL" + f"nameserver {nameserver} is not a " + "dns.nameserver.Nameserver instance or text form, " + "IP address, nor a valid https URL" ) - self._nameservers = nameservers + enriched_nameserver = dns.nameserver.DoHNameserver(nameserver) + enriched_nameservers.append(enriched_nameserver) else: raise ValueError( - "nameservers must be a list (not a {})".format(type(nameservers)) + "nameservers must be a list or tuple (not a {})".format( + type(nameservers) + ) ) + return enriched_nameservers + + @property + def nameservers( + self, + ) -> Sequence[Union[str, dns.nameserver.Nameserver]]: + return self._nameservers + + @nameservers.setter + def nameservers( + self, nameservers: Sequence[Union[str, dns.nameserver.Nameserver]] + ) -> None: + """ + *nameservers*, a ``list`` of nameservers, where a nameserver is either + a string interpretable as a nameserver, or a ``dns.nameserver.Nameserver`` + instance. + + Raises ``ValueError`` if *nameservers* is not a list of nameservers. 
+ """ + # We just call _enrich_nameservers() for checking + self._enrich_nameservers(nameservers, self.nameserver_ports, self.port) + self._nameservers = nameservers class Resolver(BaseResolver): @@ -1198,33 +1315,18 @@ class Resolver(BaseResolver): assert request is not None # needed for type checking done = False while not done: - (nameserver, port, tcp, backoff) = resolution.next_nameserver() + (nameserver, tcp, backoff) = resolution.next_nameserver() if backoff: time.sleep(backoff) timeout = self._compute_timeout(start, lifetime, resolution.errors) try: - if dns.inet.is_address(nameserver): - if tcp: - response = dns.query.tcp( - request, - nameserver, - timeout=timeout, - port=port, - source=source, - source_port=source_port, - ) - else: - response = dns.query.udp( - request, - nameserver, - timeout=timeout, - port=port, - source=source, - source_port=source_port, - raise_on_truncation=True, - ) - else: - response = dns.query.https(request, nameserver, timeout=timeout) + response = nameserver.query( + request, + timeout=timeout, + source=source, + source_port=source_port, + max_size=tcp, + ) except Exception as ex: (_, done) = resolution.query_result(None, ex) continue @@ -1293,7 +1395,72 @@ class Resolver(BaseResolver): modified_kwargs["rdclass"] = dns.rdataclass.IN return self.resolve( dns.reversename.from_address(ipaddr), *args, **modified_kwargs - ) # type: ignore[arg-type] + ) + + def resolve_name( + self, + name: Union[dns.name.Name, str], + family: int = socket.AF_UNSPEC, + **kwargs: Any, + ) -> HostAnswers: + """Use a resolver to query for address records. + + This utilizes the resolve() method to perform A and/or AAAA lookups on + the specified name. + + *qname*, a ``dns.name.Name`` or ``str``, the name to resolve. + + *family*, an ``int``, the address family. If socket.AF_UNSPEC + (the default), both A and AAAA records will be retrieved. 
+ + All other arguments that can be passed to the resolve() function + except for rdtype and rdclass are also supported by this + function. + """ + # We make a modified kwargs for type checking happiness, as otherwise + # we get a legit warning about possibly having rdtype and rdclass + # in the kwargs more than once. + modified_kwargs: Dict[str, Any] = {} + modified_kwargs.update(kwargs) + modified_kwargs.pop("rdtype", None) + modified_kwargs["rdclass"] = dns.rdataclass.IN + + if family == socket.AF_INET: + v4 = self.resolve(name, dns.rdatatype.A, **modified_kwargs) + return HostAnswers.make(v4=v4) + elif family == socket.AF_INET6: + v6 = self.resolve(name, dns.rdatatype.AAAA, **modified_kwargs) + return HostAnswers.make(v6=v6) + elif family != socket.AF_UNSPEC: + raise NotImplementedError(f"unknown address family {family}") + + raise_on_no_answer = modified_kwargs.pop("raise_on_no_answer", True) + lifetime = modified_kwargs.pop("lifetime", None) + start = time.time() + v6 = self.resolve( + name, + dns.rdatatype.AAAA, + raise_on_no_answer=False, + lifetime=self._compute_timeout(start, lifetime), + **modified_kwargs, + ) + # Note that setting name ensures we query the same name + # for A as we did for AAAA. (This is just in case search lists + # are active by default in the resolver configuration and + # we might be talking to a server that says NXDOMAIN when it + # wants to say NOERROR no data. 
+ name = v6.qname + v4 = self.resolve( + name, + dns.rdatatype.A, + raise_on_no_answer=False, + lifetime=self._compute_timeout(start, lifetime), + **modified_kwargs, + ) + answers = HostAnswers.make(v6=v6, v4=v4, add_empty=not raise_on_no_answer) + if not answers: + raise NoAnswer(response=v6.response) + return answers # pylint: disable=redefined-outer-name @@ -1320,6 +1487,37 @@ class Resolver(BaseResolver): # pylint: enable=redefined-outer-name + def try_ddr(self, lifetime: float = 5.0) -> None: + """Try to update the resolver's nameservers using Discovery of Designated + Resolvers (DDR). If successful, the resolver will subsequently use + DNS-over-HTTPS or DNS-over-TLS for future queries. + + *lifetime*, a float, is the maximum time to spend attempting DDR. The default + is 5 seconds. + + If the SVCB query is successful and results in a non-empty list of nameservers, + then the resolver's nameservers are set to the returned servers in priority + order. + + The current implementation does not use any address hints from the SVCB record, + nor does it resolve addresses for the SCVB target name, rather it assumes that + the bootstrap nameserver will always be one of the addresses and uses it. + A future revision to the code may offer fuller support. The code verifies that + the bootstrap nameserver is in the Subject Alternative Name field of the + TLS certficate. + """ + try: + expiration = time.time() + lifetime + answer = self.resolve( + dns._ddr._local_resolver_name, "SVCB", lifetime=lifetime + ) + timeout = dns.query._remaining(expiration) + nameservers = dns._ddr._get_nameservers_sync(answer, timeout) + if len(nameservers) > 0: + self.nameservers = nameservers + except Exception: + pass + #: The default resolver. default_resolver: Optional[Resolver] = None @@ -1333,7 +1531,7 @@ def get_default_resolver() -> Resolver: return default_resolver -def reset_default_resolver(): +def reset_default_resolver() -> None: """Re-initialize default resolver. 
Note that the resolver configuration (i.e. /etc/resolv.conf on UNIX @@ -1355,7 +1553,6 @@ def resolve( lifetime: Optional[float] = None, search: Optional[bool] = None, ) -> Answer: # pragma: no cover - """Query nameservers to find the answer to the question. This is a convenience function that uses the default resolver @@ -1421,6 +1618,18 @@ def resolve_address(ipaddr: str, *args: Any, **kwargs: Any) -> Answer: return get_default_resolver().resolve_address(ipaddr, *args, **kwargs) +def resolve_name( + name: Union[dns.name.Name, str], family: int = socket.AF_UNSPEC, **kwargs: Any +) -> HostAnswers: + """Use a resolver to query for address records. + + See ``dns.resolver.Resolver.resolve_name`` for more information on the + parameters. + """ + + return get_default_resolver().resolve_name(name, family, **kwargs) + + def canonical_name(name: Union[dns.name.Name, str]) -> dns.name.Name: """Determine the canonical name of *name*. @@ -1431,6 +1640,16 @@ def canonical_name(name: Union[dns.name.Name, str]) -> dns.name.Name: return get_default_resolver().canonical_name(name) +def try_ddr(lifetime: float = 5.0) -> None: + """Try to update the default resolver's nameservers using Discovery of Designated + Resolvers (DDR). If successful, the resolver will subsequently use + DNS-over-HTTPS or DNS-over-TLS for future queries. + + See :py:func:`dns.resolver.Resolver.try_ddr` for more information. 
+ """ + return get_default_resolver().try_ddr(lifetime) + + def zone_for_name( name: Union[dns.name.Name, str], rdclass: dns.rdataclass.RdataClass = dns.rdataclass.IN, @@ -1478,7 +1697,7 @@ def zone_for_name( while 1: try: rlifetime: Optional[float] - if expiration: + if expiration is not None: rlifetime = expiration - time.time() if rlifetime <= 0: rlifetime = 0 @@ -1516,6 +1735,83 @@ def zone_for_name( raise NoRootSOA +def make_resolver_at( + where: Union[dns.name.Name, str], + port: int = 53, + family: int = socket.AF_UNSPEC, + resolver: Optional[Resolver] = None, +) -> Resolver: + """Make a stub resolver using the specified destination as the full resolver. + + *where*, a ``dns.name.Name`` or ``str`` the domain name or IP address of the + full resolver. + + *port*, an ``int``, the port to use. If not specified, the default is 53. + + *family*, an ``int``, the address family to use. This parameter is used if + *where* is not an address. The default is ``socket.AF_UNSPEC`` in which case + the first address returned by ``resolve_name()`` will be used, otherwise the + first address of the specified family will be used. + + *resolver*, a ``dns.resolver.Resolver`` or ``None``, the resolver to use for + resolution of hostnames. If not specified, the default resolver will be used. + + Returns a ``dns.resolver.Resolver`` or raises an exception. 
+ """ + if resolver is None: + resolver = get_default_resolver() + nameservers: List[Union[str, dns.nameserver.Nameserver]] = [] + if isinstance(where, str) and dns.inet.is_address(where): + nameservers.append(dns.nameserver.Do53Nameserver(where, port)) + else: + for address in resolver.resolve_name(where, family).addresses(): + nameservers.append(dns.nameserver.Do53Nameserver(address, port)) + res = dns.resolver.Resolver(configure=False) + res.nameservers = nameservers + return res + + +def resolve_at( + where: Union[dns.name.Name, str], + qname: Union[dns.name.Name, str], + rdtype: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.A, + rdclass: Union[dns.rdataclass.RdataClass, str] = dns.rdataclass.IN, + tcp: bool = False, + source: Optional[str] = None, + raise_on_no_answer: bool = True, + source_port: int = 0, + lifetime: Optional[float] = None, + search: Optional[bool] = None, + port: int = 53, + family: int = socket.AF_UNSPEC, + resolver: Optional[Resolver] = None, +) -> Answer: + """Query nameservers to find the answer to the question. + + This is a convenience function that calls ``dns.resolver.make_resolver_at()`` to + make a resolver, and then uses it to resolve the query. + + See ``dns.resolver.Resolver.resolve`` for more information on the resolution + parameters, and ``dns.resolver.make_resolver_at`` for information about the resolver + parameters *where*, *port*, *family*, and *resolver*. + + If making more than one query, it is more efficient to call + ``dns.resolver.make_resolver_at()`` and then use that resolver for the queries + instead of calling ``resolve_at()`` multiple times. + """ + return make_resolver_at(where, port, family, resolver).resolve( + qname, + rdtype, + rdclass, + tcp, + source, + raise_on_no_answer, + source_port, + lifetime, + search, + ) + + # # Support for overriding the system resolver for all python code in the # running process. 
@@ -1559,8 +1855,7 @@ def _getaddrinfo( ) if host is None and service is None: raise socket.gaierror(socket.EAI_NONAME, "Name or service not known") - v6addrs = [] - v4addrs = [] + addrs = [] canonical_name = None # pylint: disable=redefined-outer-name # Is host None or an address literal? If so, use the system's # getaddrinfo(). @@ -1576,24 +1871,9 @@ def _getaddrinfo( pass # Something needs resolution! try: - if family == socket.AF_INET6 or family == socket.AF_UNSPEC: - v6 = _resolver.resolve(host, dns.rdatatype.AAAA, raise_on_no_answer=False) - # Note that setting host ensures we query the same name - # for A as we did for AAAA. (This is just in case search lists - # are active by default in the resolver configuration and - # we might be talking to a server that says NXDOMAIN when it - # wants to say NOERROR no data. - host = v6.qname - canonical_name = v6.canonical_name.to_text(True) - if v6.rrset is not None: - for rdata in v6.rrset: - v6addrs.append(rdata.address) - if family == socket.AF_INET or family == socket.AF_UNSPEC: - v4 = _resolver.resolve(host, dns.rdatatype.A, raise_on_no_answer=False) - canonical_name = v4.canonical_name.to_text(True) - if v4.rrset is not None: - for rdata in v4.rrset: - v4addrs.append(rdata.address) + answers = _resolver.resolve_name(host, family) + addrs = answers.addresses_and_families() + canonical_name = answers.canonical_name().to_text(True) except dns.resolver.NXDOMAIN: raise socket.gaierror(socket.EAI_NONAME, "Name or service not known") except Exception: @@ -1625,20 +1905,11 @@ def _getaddrinfo( cname = canonical_name else: cname = "" - if family == socket.AF_INET6 or family == socket.AF_UNSPEC: - for addr in v6addrs: - for socktype in socktypes: - for proto in _protocols_for_socktype[socktype]: - tuples.append( - (socket.AF_INET6, socktype, proto, cname, (addr, port, 0, 0)) - ) - if family == socket.AF_INET or family == socket.AF_UNSPEC: - for addr in v4addrs: - for socktype in socktypes: - for proto in 
_protocols_for_socktype[socktype]: - tuples.append( - (socket.AF_INET, socktype, proto, cname, (addr, port)) - ) + for addr, af in addrs: + for socktype in socktypes: + for proto in _protocols_for_socktype[socktype]: + addr_tuple = dns.inet.low_level_address_tuple((addr, port), af) + tuples.append((af, socktype, proto, cname, addr_tuple)) if len(tuples) == 0: raise socket.gaierror(socket.EAI_NONAME, "Name or service not known") return tuples diff --git a/lib/dns/reversename.py b/lib/dns/reversename.py index eb6a3b6b..8236c711 100644 --- a/lib/dns/reversename.py +++ b/lib/dns/reversename.py @@ -19,9 +19,9 @@ import binascii -import dns.name -import dns.ipv6 import dns.ipv4 +import dns.ipv6 +import dns.name ipv4_reverse_domain = dns.name.from_text("in-addr.arpa.") ipv6_reverse_domain = dns.name.from_text("ip6.arpa.") diff --git a/lib/dns/rrset.py b/lib/dns/rrset.py index 3f22a90c..350de13e 100644 --- a/lib/dns/rrset.py +++ b/lib/dns/rrset.py @@ -17,11 +17,11 @@ """DNS RRsets (an RRset is a named rdataset)""" -from typing import Any, cast, Collection, Dict, Optional, Union +from typing import Any, Collection, Dict, Optional, Union, cast import dns.name -import dns.rdataset import dns.rdataclass +import dns.rdataset import dns.renderer @@ -214,9 +214,9 @@ def from_text_list( if isinstance(name, str): name = dns.name.from_text(name, None, idna_codec=idna_codec) - the_rdclass = dns.rdataclass.RdataClass.make(rdclass) - the_rdtype = dns.rdatatype.RdataType.make(rdtype) - r = RRset(name, the_rdclass, the_rdtype) + rdclass = dns.rdataclass.RdataClass.make(rdclass) + rdtype = dns.rdatatype.RdataType.make(rdtype) + r = RRset(name, rdclass, rdtype) r.update_ttl(ttl) for t in text_rdatas: rd = dns.rdata.from_text( diff --git a/lib/dns/tokenizer.py b/lib/dns/tokenizer.py index 0551578a..454cac4a 100644 --- a/lib/dns/tokenizer.py +++ b/lib/dns/tokenizer.py @@ -17,10 +17,9 @@ """Tokenize DNS zone file format""" -from typing import Any, Optional, List, Tuple - import io import sys 
+from typing import Any, List, Optional, Tuple import dns.exception import dns.name diff --git a/lib/dns/transaction.py b/lib/dns/transaction.py index c4a9e1f6..21dea775 100644 --- a/lib/dns/transaction.py +++ b/lib/dns/transaction.py @@ -1,8 +1,7 @@ # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license -from typing import Any, Callable, List, Optional, Tuple, Union - import collections +from typing import Any, Callable, Iterator, List, Optional, Tuple, Union import dns.exception import dns.name @@ -357,6 +356,27 @@ class Transaction: """ self._check_delete_name.append(check) + def iterate_rdatasets( + self, + ) -> Iterator[Tuple[dns.name.Name, dns.rdataset.Rdataset]]: + """Iterate all the rdatasets in the transaction, returning + (`dns.name.Name`, `dns.rdataset.Rdataset`) tuples. + + Note that as is usual with python iterators, adding or removing items + while iterating will invalidate the iterator and may raise `RuntimeError` + or fail to iterate over all entries.""" + self._check_ended() + return self._iterate_rdatasets() + + def iterate_names(self) -> Iterator[dns.name.Name]: + """Iterate all the names in the transaction. 
+ + Note that as is usual with python iterators, adding or removing names + while iterating will invalidate the iterator and may raise `RuntimeError` + or fail to iterate over all entries.""" + self._check_ended() + return self._iterate_names() + # # Helper methods # @@ -416,7 +436,7 @@ class Transaction: rdataset = rrset.to_rdataset() else: raise TypeError( - f"{method} requires a name or RRset " + "as the first argument" + f"{method} requires a name or RRset as the first argument" ) if rdataset.rdclass != self.manager.get_class(): raise ValueError(f"{method} has objects of wrong RdataClass") @@ -475,7 +495,7 @@ class Transaction: name = rdataset.name else: raise TypeError( - f"{method} requires a name or RRset " + "as the first argument" + f"{method} requires a name or RRset as the first argument" ) self._raise_if_not_empty(method, args) if rdataset: @@ -610,6 +630,10 @@ class Transaction: """Return an iterator that yields (name, rdataset) tuples.""" raise NotImplementedError # pragma: no cover + def _iterate_names(self): + """Return an iterator that yields a name.""" + raise NotImplementedError # pragma: no cover + def _get_node(self, name): """Return the node at *name*, if any. 
diff --git a/lib/dns/tsig.py b/lib/dns/tsig.py index 2476fdfb..58760f5f 100644 --- a/lib/dns/tsig.py +++ b/lib/dns/tsig.py @@ -23,9 +23,9 @@ import hmac import struct import dns.exception -import dns.rdataclass import dns.name import dns.rcode +import dns.rdataclass class BadTime(dns.exception.DNSException): @@ -187,9 +187,7 @@ class HMACTSig: try: hashinfo = self._hashes[algorithm] except KeyError: - raise NotImplementedError( - f"TSIG algorithm {algorithm} " + "is not supported" - ) + raise NotImplementedError(f"TSIG algorithm {algorithm} is not supported") # create the HMAC context if isinstance(hashinfo, tuple): diff --git a/lib/dns/tsigkeyring.py b/lib/dns/tsigkeyring.py index 6adba284..1010a79f 100644 --- a/lib/dns/tsigkeyring.py +++ b/lib/dns/tsigkeyring.py @@ -17,9 +17,8 @@ """A place to store TSIG keys.""" -from typing import Any, Dict - import base64 +from typing import Any, Dict import dns.name import dns.tsig @@ -33,7 +32,7 @@ def from_text(textring: Dict[str, Any]) -> Dict[dns.name.Name, dns.tsig.Key]: @rtype: dict""" keyring = {} - for (name, value) in textring.items(): + for name, value in textring.items(): kname = dns.name.from_text(name) if isinstance(value, str): keyring[kname] = dns.tsig.Key(kname, value).secret @@ -55,7 +54,7 @@ def to_text(keyring: Dict[dns.name.Name, Any]) -> Dict[str, Any]: def b64encode(secret): return base64.encodebytes(secret).decode().rstrip() - for (name, key) in keyring.items(): + for name, key in keyring.items(): tname = name.to_text() if isinstance(key, bytes): textring[tname] = b64encode(key) diff --git a/lib/dns/update.py b/lib/dns/update.py index 647e5b19..bf1157ac 100644 --- a/lib/dns/update.py +++ b/lib/dns/update.py @@ -24,8 +24,8 @@ import dns.name import dns.opcode import dns.rdata import dns.rdataclass -import dns.rdatatype import dns.rdataset +import dns.rdatatype import dns.tsig @@ -43,7 +43,6 @@ class UpdateSection(dns.enum.IntEnum): class UpdateMessage(dns.message.Message): # lgtm[py/missing-equals] - # 
ignore the mypy error here as we mean to use a different enum _section_enum = UpdateSection # type: ignore @@ -336,12 +335,12 @@ class UpdateMessage(dns.message.Message): # lgtm[py/missing-equals] True, ) else: - the_rdtype = dns.rdatatype.RdataType.make(rdtype) + rdtype = dns.rdatatype.RdataType.make(rdtype) self.find_rrset( self.prerequisite, name, dns.rdataclass.NONE, - the_rdtype, + rdtype, dns.rdatatype.NONE, None, True, diff --git a/lib/dns/version.py b/lib/dns/version.py index 89d4cf1a..1f1fbf2d 100644 --- a/lib/dns/version.py +++ b/lib/dns/version.py @@ -20,13 +20,13 @@ #: MAJOR MAJOR = 2 #: MINOR -MINOR = 3 +MINOR = 4 #: MICRO -MICRO = 0 +MICRO = 2 #: RELEASELEVEL RELEASELEVEL = 0x0F #: SERIAL -SERIAL = 1 +SERIAL = 0 if RELEASELEVEL == 0x0F: # pragma: no cover lgtm[py/unreachable-statement] #: version diff --git a/lib/dns/versioned.py b/lib/dns/versioned.py index 02e24122..fd78e674 100644 --- a/lib/dns/versioned.py +++ b/lib/dns/versioned.py @@ -2,10 +2,9 @@ """DNS Versioned Zones.""" -from typing import Callable, Deque, Optional, Set, Union - import collections import threading +from typing import Callable, Deque, Optional, Set, Union import dns.exception import dns.immutable @@ -32,7 +31,6 @@ Transaction = dns.zone.Transaction class Zone(dns.zone.Zone): # lgtm[py/missing-equals] - __slots__ = [ "_versions", "_versions_lock", @@ -152,7 +150,7 @@ class Zone(dns.zone.Zone): # lgtm[py/missing-equals] # # This is not a problem with Threading module threads as # they cannot be canceled, but could be an issue with trio - # or curio tasks when we do the async version of writer(). + # tasks when we do the async version of writer(). # I.e. 
we'd need to do something like: # # try: diff --git a/lib/dns/win32util.py b/lib/dns/win32util.py index ac314750..b2ca61da 100644 --- a/lib/dns/win32util.py +++ b/lib/dns/win32util.py @@ -1,7 +1,6 @@ import sys if sys.platform == "win32": - from typing import Any import dns.name @@ -18,6 +17,7 @@ if sys.platform == "win32": try: import threading + import pythoncom # pylint: disable=import-error import wmi # pylint: disable=import-error @@ -206,7 +206,7 @@ if sys.platform == "win32": lm = winreg.ConnectRegistry(None, winreg.HKEY_LOCAL_MACHINE) try: tcp_params = winreg.OpenKey( - lm, r"SYSTEM\CurrentControlSet" r"\Services\Tcpip\Parameters" + lm, r"SYSTEM\CurrentControlSet\Services\Tcpip\Parameters" ) try: self._config_fromkey(tcp_params, True) @@ -214,9 +214,7 @@ if sys.platform == "win32": tcp_params.Close() interfaces = winreg.OpenKey( lm, - r"SYSTEM\CurrentControlSet" - r"\Services\Tcpip\Parameters" - r"\Interfaces", + r"SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces", ) try: i = 0 diff --git a/lib/dns/wire.py b/lib/dns/wire.py index cadf1686..9f9b1573 100644 --- a/lib/dns/wire.py +++ b/lib/dns/wire.py @@ -1,9 +1,8 @@ # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license -from typing import Iterator, Optional, Tuple - import contextlib import struct +from typing import Iterator, Optional, Tuple import dns.exception import dns.name diff --git a/lib/dns/xfr.py b/lib/dns/xfr.py index bb165888..dd247d33 100644 --- a/lib/dns/xfr.py +++ b/lib/dns/xfr.py @@ -21,9 +21,9 @@ import dns.exception import dns.message import dns.name import dns.rcode -import dns.serial import dns.rdataset import dns.rdatatype +import dns.serial import dns.transaction import dns.tsig import dns.zone diff --git a/lib/dns/zone.py b/lib/dns/zone.py index cc8268da..9e763f5f 100644 --- a/lib/dns/zone.py +++ b/lib/dns/zone.py @@ -17,30 +17,29 @@ """DNS Zones.""" -from typing import Any, Dict, Iterator, Iterable, List, Optional, Set, Tuple, Union - import 
contextlib import io import os import struct +from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Tuple, Union import dns.exception +import dns.grange import dns.immutable import dns.name import dns.node -import dns.rdataclass -import dns.rdatatype import dns.rdata +import dns.rdataclass import dns.rdataset +import dns.rdatatype import dns.rdtypes.ANY.SOA import dns.rdtypes.ANY.ZONEMD import dns.rrset import dns.tokenizer import dns.transaction import dns.ttl -import dns.grange import dns.zonefile -from dns.zonetypes import DigestScheme, DigestHashAlgorithm, _digest_hashers +from dns.zonetypes import DigestHashAlgorithm, DigestScheme, _digest_hashers class BadZone(dns.exception.DNSException): @@ -321,11 +320,11 @@ class Zone(dns.transaction.TransactionManager): Returns a ``dns.rdataset.Rdataset``. """ - the_name = self._validate_name(name) - the_rdtype = dns.rdatatype.RdataType.make(rdtype) - the_covers = dns.rdatatype.RdataType.make(covers) - node = self.find_node(the_name, create) - return node.find_rdataset(self.rdclass, the_rdtype, the_covers, create) + name = self._validate_name(name) + rdtype = dns.rdatatype.RdataType.make(rdtype) + covers = dns.rdatatype.RdataType.make(covers) + node = self.find_node(name, create) + return node.find_rdataset(self.rdclass, rdtype, covers, create) def get_rdataset( self, @@ -404,14 +403,14 @@ class Zone(dns.transaction.TransactionManager): types were aggregated into a single RRSIG rdataset. 
""" - the_name = self._validate_name(name) - the_rdtype = dns.rdatatype.RdataType.make(rdtype) - the_covers = dns.rdatatype.RdataType.make(covers) - node = self.get_node(the_name) + name = self._validate_name(name) + rdtype = dns.rdatatype.RdataType.make(rdtype) + covers = dns.rdatatype.RdataType.make(covers) + node = self.get_node(name) if node is not None: - node.delete_rdataset(self.rdclass, the_rdtype, the_covers) + node.delete_rdataset(self.rdclass, rdtype, covers) if len(node) == 0: - self.delete_node(the_name) + self.delete_node(name) def replace_rdataset( self, name: Union[dns.name.Name, str], replacement: dns.rdataset.Rdataset @@ -484,10 +483,10 @@ class Zone(dns.transaction.TransactionManager): """ vname = self._validate_name(name) - the_rdtype = dns.rdatatype.RdataType.make(rdtype) - the_covers = dns.rdatatype.RdataType.make(covers) - rdataset = self.nodes[vname].find_rdataset(self.rdclass, the_rdtype, the_covers) - rrset = dns.rrset.RRset(vname, self.rdclass, the_rdtype, the_covers) + rdtype = dns.rdatatype.RdataType.make(rdtype) + covers = dns.rdatatype.RdataType.make(covers) + rdataset = self.nodes[vname].find_rdataset(self.rdclass, rdtype, covers) + rrset = dns.rrset.RRset(vname, self.rdclass, rdtype, covers) rrset.update(rdataset) return rrset @@ -565,7 +564,7 @@ class Zone(dns.transaction.TransactionManager): rdtype = dns.rdatatype.RdataType.make(rdtype) covers = dns.rdatatype.RdataType.make(covers) - for (name, node) in self.items(): + for name, node in self.items(): for rds in node: if rdtype == dns.rdatatype.ANY or ( rds.rdtype == rdtype and rds.covers == covers @@ -597,7 +596,7 @@ class Zone(dns.transaction.TransactionManager): rdtype = dns.rdatatype.RdataType.make(rdtype) covers = dns.rdatatype.RdataType.make(covers) - for (name, node) in self.items(): + for name, node in self.items(): for rds in node: if rdtype == dns.rdatatype.ANY or ( rds.rdtype == rdtype and rds.covers == covers @@ -795,7 +794,7 @@ class 
Zone(dns.transaction.TransactionManager): assert self.origin is not None origin_name = self.origin hasher = hashinfo() - for (name, node) in sorted(self.items()): + for name, node in sorted(self.items()): rrnamebuf = name.to_digestable(self.origin) for rdataset in sorted(node, key=lambda rds: (rds.rdtype, rds.covers)): if name == origin_name and dns.rdatatype.ZONEMD in ( @@ -997,6 +996,9 @@ class Version: return None return node.get_rdataset(self.zone.rdclass, rdtype, covers) + def keys(self): + return self.nodes.keys() + def items(self): return self.nodes.items() @@ -1143,10 +1145,13 @@ class Transaction(dns.transaction.Transaction): self.version.origin = origin def _iterate_rdatasets(self): - for (name, node) in self.version.items(): + for name, node in self.version.items(): for rdataset in node: yield (name, rdataset) + def _iterate_names(self): + return self.version.keys() + def _get_node(self, name): return self.version.get_node(name) diff --git a/lib/dns/zonefile.py b/lib/dns/zonefile.py index 1a53f5bc..27f04924 100644 --- a/lib/dns/zonefile.py +++ b/lib/dns/zonefile.py @@ -17,23 +17,22 @@ """DNS Zones.""" -from typing import Any, Iterable, List, Optional, Set, Tuple, Union - import re import sys +from typing import Any, Iterable, List, Optional, Set, Tuple, Union import dns.exception +import dns.grange import dns.name import dns.node +import dns.rdata import dns.rdataclass import dns.rdatatype -import dns.rdata import dns.rdtypes.ANY.SOA import dns.rrset import dns.tokenizer import dns.transaction import dns.ttl -import dns.grange class UnknownOrigin(dns.exception.DNSException): @@ -191,10 +190,6 @@ class Reader: self.last_ttl_known = True token = None except dns.ttl.BadTTL: - if self.default_ttl_known: - ttl = self.default_ttl - elif self.last_ttl_known: - ttl = self.last_ttl self.tok.unget(token) # Class @@ -212,6 +207,22 @@ class Reader: if rdclass != self.zone_rdclass: raise dns.exception.SyntaxError("RR class is not zone's class") + if ttl is None: + # 
support for syntax + token = self._get_identifier() + ttl = None + try: + ttl = dns.ttl.from_text(token.value) + self.last_ttl = ttl + self.last_ttl_known = True + token = None + except dns.ttl.BadTTL: + if self.default_ttl_known: + ttl = self.default_ttl + elif self.last_ttl_known: + ttl = self.last_ttl + self.tok.unget(token) + # Type if self.force_rdtype is not None: rdtype = self.force_rdtype @@ -581,7 +592,7 @@ class RRsetsReaderTransaction(dns.transaction.Transaction): pass def _name_exists(self, name): - for (n, _, _) in self.rdatasets: + for n, _, _ in self.rdatasets: if n == name: return True return False @@ -606,6 +617,9 @@ class RRsetsReaderTransaction(dns.transaction.Transaction): def _iterate_rdatasets(self): raise NotImplementedError # pragma: no cover + def _iterate_names(self): + raise NotImplementedError # pragma: no cover + class RRSetsReaderManager(dns.transaction.TransactionManager): def __init__( @@ -707,26 +721,26 @@ def read_rrsets( if isinstance(default_ttl, str): default_ttl = dns.ttl.from_text(default_ttl) if rdclass is not None: - the_rdclass = dns.rdataclass.RdataClass.make(rdclass) + rdclass = dns.rdataclass.RdataClass.make(rdclass) else: - the_rdclass = None - the_default_rdclass = dns.rdataclass.RdataClass.make(default_rdclass) + rdclass = None + default_rdclass = dns.rdataclass.RdataClass.make(default_rdclass) if rdtype is not None: - the_rdtype = dns.rdatatype.RdataType.make(rdtype) + rdtype = dns.rdatatype.RdataType.make(rdtype) else: - the_rdtype = None + rdtype = None manager = RRSetsReaderManager(origin, relativize, default_rdclass) with manager.writer(True) as txn: tok = dns.tokenizer.Tokenizer(text, "", idna_codec=idna_codec) reader = Reader( tok, - the_default_rdclass, + default_rdclass, txn, allow_directives=False, force_name=name, force_ttl=ttl, - force_rdclass=the_rdclass, - force_rdtype=the_rdtype, + force_rdclass=rdclass, + force_rdtype=rdtype, default_ttl=default_ttl, ) reader.read() diff --git a/requirements.txt 
b/requirements.txt index 03b49577..d4e4442f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,7 +11,7 @@ cheroot==9.0.0 cherrypy==18.8.0 cloudinary==1.34.0 distro==1.8.0 -dnspython==2.3.0 +dnspython==2.4.2 facebook-sdk==3.1.0 future==0.18.3 ga4mp==2.0.4 From 69d052f758511c5b4523ea5210a433ce14f8c30d Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 24 Aug 2023 12:05:25 -0700 Subject: [PATCH 02/10] Bump zipp from 3.15.0 to 3.16.2 (#2124) * Bump zipp from 3.15.0 to 3.16.2 Bumps [zipp](https://github.com/jaraco/zipp) from 3.15.0 to 3.16.2. - [Release notes](https://github.com/jaraco/zipp/releases) - [Changelog](https://github.com/jaraco/zipp/blob/main/NEWS.rst) - [Commits](https://github.com/jaraco/zipp/compare/v3.15.0...v3.16.2) --- updated-dependencies: - dependency-name: zipp dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] * Update zipp==3.16.2 --------- Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: JonnyWong16 <9099342+JonnyWong16@users.noreply.github.com> [skip ci] --- lib/zipp/__init__.py | 30 ++++++++++++------------------ lib/zipp/glob.py | 40 ++++++++++++++++++++++++++++++++++++++++ lib/zipp/py310compat.py | 5 ++--- requirements.txt | 2 +- 4 files changed, 55 insertions(+), 22 deletions(-) create mode 100644 lib/zipp/glob.py diff --git a/lib/zipp/__init__.py b/lib/zipp/__init__.py index ddfa0a64..3354c2bb 100644 --- a/lib/zipp/__init__.py +++ b/lib/zipp/__init__.py @@ -5,9 +5,9 @@ import itertools import contextlib import pathlib import re -import fnmatch from .py310compat import text_encoding +from .glob import translate __all__ = ['Path'] @@ -298,21 +298,24 @@ class Path: encoding, args, kwargs = _extract_text_encoding(*args, **kwargs) return io.TextIOWrapper(stream, encoding, *args, **kwargs) + def _base(self): + return 
pathlib.PurePosixPath(self.at or self.root.filename) + @property def name(self): - return pathlib.Path(self.at).name or self.filename.name + return self._base().name @property def suffix(self): - return pathlib.Path(self.at).suffix or self.filename.suffix + return self._base().suffix @property def suffixes(self): - return pathlib.Path(self.at).suffixes or self.filename.suffixes + return self._base().suffixes @property def stem(self): - return pathlib.Path(self.at).stem or self.filename.stem + return self._base().stem @property def filename(self): @@ -349,7 +352,7 @@ class Path: return filter(self._is_child, subs) def match(self, path_pattern): - return pathlib.Path(self.at).match(path_pattern) + return pathlib.PurePosixPath(self.at).match(path_pattern) def is_symlink(self): """ @@ -357,22 +360,13 @@ class Path: """ return False - def _descendants(self): - for child in self.iterdir(): - yield child - if child.is_dir(): - yield from child._descendants() - def glob(self, pattern): if not pattern: raise ValueError(f"Unacceptable pattern: {pattern!r}") - matches = re.compile(fnmatch.translate(pattern)).fullmatch - return ( - child - for child in self._descendants() - if matches(str(child.relative_to(self))) - ) + prefix = re.escape(self.at) + matches = re.compile(prefix + translate(pattern)).fullmatch + return map(self._next, filter(matches, self.root.namelist())) def rglob(self, pattern): return self.glob(f'**/{pattern}') diff --git a/lib/zipp/glob.py b/lib/zipp/glob.py new file mode 100644 index 00000000..4a2e665e --- /dev/null +++ b/lib/zipp/glob.py @@ -0,0 +1,40 @@ +import re + + +def translate(pattern): + r""" + Given a glob pattern, produce a regex that matches it. + + >>> translate('*.txt') + '[^/]*\\.txt' + >>> translate('a?txt') + 'a.txt' + >>> translate('**/*') + '.*/[^/]*' + """ + return ''.join(map(replace, separate(pattern))) + + +def separate(pattern): + """ + Separate out character sets to avoid translating their contents. 
+ + >>> [m.group(0) for m in separate('*.txt')] + ['*.txt'] + >>> [m.group(0) for m in separate('a[?]txt')] + ['a', '[?]', 'txt'] + """ + return re.finditer(r'([^\[]+)|(?P[\[].*?[\]])|([\[][^\]]*$)', pattern) + + +def replace(match): + """ + Perform the replacements for a match from :func:`separate`. + """ + + return match.group('set') or ( + re.escape(match.group(0)) + .replace('\\*\\*', r'.*') + .replace('\\*', r'[^/]*') + .replace('\\?', r'.') + ) diff --git a/lib/zipp/py310compat.py b/lib/zipp/py310compat.py index 8244124c..d5ca53e0 100644 --- a/lib/zipp/py310compat.py +++ b/lib/zipp/py310compat.py @@ -2,9 +2,8 @@ import sys import io -te_impl = 'lambda encoding, stacklevel=2, /: encoding' -te_impl_37 = te_impl.replace(', /', '') -_text_encoding = eval(te_impl) if sys.version_info > (3, 8) else eval(te_impl_37) +def _text_encoding(encoding, stacklevel=2, /): # pragma: no cover + return encoding text_encoding = ( diff --git a/requirements.txt b/requirements.txt index d4e4442f..7157e7b6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -49,7 +49,7 @@ urllib3<2 webencodings==0.5.1 websocket-client==1.6.2 xmltodict==0.13.0 -zipp==3.15.0 +zipp==3.16.2 # configobj==5.1.0 # sgmllib3k==1.0.0 From 9383d5120c4232388f5a12e013a531bdca587587 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 24 Aug 2023 12:05:41 -0700 Subject: [PATCH 03/10] Bump portend from 3.1.0 to 3.2.0 (#2125) Bumps [portend](https://github.com/jaraco/portend) from 3.1.0 to 3.2.0. - [Release notes](https://github.com/jaraco/portend/releases) - [Changelog](https://github.com/jaraco/portend/blob/main/NEWS.rst) - [Commits](https://github.com/jaraco/portend/compare/v3.1.0...v3.2.0) --- updated-dependencies: - dependency-name: portend dependency-type: direct:production update-type: version-update:semver-minor ... 
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> [skip ci] --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 7157e7b6..82b2e93d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -29,7 +29,7 @@ musicbrainzngs==0.7.1 packaging==23.1 paho-mqtt==1.6.1 plexapi==4.13.4 -portend==3.1.0 +portend==3.2.0 profilehooks==1.12.0 PyJWT==2.8.0 pyparsing==3.0.9 From 72f1ce786543d832a49b6fd1cff5cbd1597ef468 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 24 Aug 2023 12:05:53 -0700 Subject: [PATCH 04/10] Bump importlib-resources from 5.12.0 to 6.0.1 (#2126) * Bump importlib-resources from 5.12.0 to 6.0.1 Bumps [importlib-resources](https://github.com/python/importlib_resources) from 5.12.0 to 6.0.1. - [Release notes](https://github.com/python/importlib_resources/releases) - [Changelog](https://github.com/python/importlib_resources/blob/main/NEWS.rst) - [Commits](https://github.com/python/importlib_resources/compare/v5.12.0...v6.0.1) --- updated-dependencies: - dependency-name: importlib-resources dependency-type: direct:production update-type: version-update:semver-major ... 
Signed-off-by: dependabot[bot] * Update importlib-resources==6.0.1 --------- Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: JonnyWong16 <9099342+JonnyWong16@users.noreply.github.com> [skip ci] --- lib/importlib_resources/__init__.py | 19 ----- lib/importlib_resources/_legacy.py | 120 ---------------------------- package/requirements-package.txt | 2 +- requirements.txt | 2 +- 4 files changed, 2 insertions(+), 141 deletions(-) delete mode 100644 lib/importlib_resources/_legacy.py diff --git a/lib/importlib_resources/__init__.py b/lib/importlib_resources/__init__.py index 34e3a995..e6b60c18 100644 --- a/lib/importlib_resources/__init__.py +++ b/lib/importlib_resources/__init__.py @@ -6,31 +6,12 @@ from ._common import ( Package, ) -from ._legacy import ( - contents, - open_binary, - read_binary, - open_text, - read_text, - is_resource, - path, - Resource, -) - from .abc import ResourceReader __all__ = [ 'Package', - 'Resource', 'ResourceReader', 'as_file', - 'contents', 'files', - 'is_resource', - 'open_binary', - 'open_text', - 'path', - 'read_binary', - 'read_text', ] diff --git a/lib/importlib_resources/_legacy.py b/lib/importlib_resources/_legacy.py deleted file mode 100644 index b1ea8105..00000000 --- a/lib/importlib_resources/_legacy.py +++ /dev/null @@ -1,120 +0,0 @@ -import functools -import os -import pathlib -import types -import warnings - -from typing import Union, Iterable, ContextManager, BinaryIO, TextIO, Any - -from . import _common - -Package = Union[types.ModuleType, str] -Resource = str - - -def deprecated(func): - @functools.wraps(func) - def wrapper(*args, **kwargs): - warnings.warn( - f"{func.__name__} is deprecated. Use files() instead. 
" - "Refer to https://importlib-resources.readthedocs.io" - "/en/latest/using.html#migrating-from-legacy for migration advice.", - DeprecationWarning, - stacklevel=2, - ) - return func(*args, **kwargs) - - return wrapper - - -def normalize_path(path: Any) -> str: - """Normalize a path by ensuring it is a string. - - If the resulting string contains path separators, an exception is raised. - """ - str_path = str(path) - parent, file_name = os.path.split(str_path) - if parent: - raise ValueError(f'{path!r} must be only a file name') - return file_name - - -@deprecated -def open_binary(package: Package, resource: Resource) -> BinaryIO: - """Return a file-like object opened for binary reading of the resource.""" - return (_common.files(package) / normalize_path(resource)).open('rb') - - -@deprecated -def read_binary(package: Package, resource: Resource) -> bytes: - """Return the binary contents of the resource.""" - return (_common.files(package) / normalize_path(resource)).read_bytes() - - -@deprecated -def open_text( - package: Package, - resource: Resource, - encoding: str = 'utf-8', - errors: str = 'strict', -) -> TextIO: - """Return a file-like object opened for text reading of the resource.""" - return (_common.files(package) / normalize_path(resource)).open( - 'r', encoding=encoding, errors=errors - ) - - -@deprecated -def read_text( - package: Package, - resource: Resource, - encoding: str = 'utf-8', - errors: str = 'strict', -) -> str: - """Return the decoded string of the resource. - - The decoding-related arguments have the same semantics as those of - bytes.decode(). - """ - with open_text(package, resource, encoding, errors) as fp: - return fp.read() - - -@deprecated -def contents(package: Package) -> Iterable[str]: - """Return an iterable of entries in `package`. - - Note that not all entries are resources. Specifically, directories are - not considered resources. Use `is_resource()` on each entry returned here - to check if it is a resource or not. 
- """ - return [path.name for path in _common.files(package).iterdir()] - - -@deprecated -def is_resource(package: Package, name: str) -> bool: - """True if `name` is a resource inside `package`. - - Directories are *not* resources. - """ - resource = normalize_path(name) - return any( - traversable.name == resource and traversable.is_file() - for traversable in _common.files(package).iterdir() - ) - - -@deprecated -def path( - package: Package, - resource: Resource, -) -> ContextManager[pathlib.Path]: - """A context manager providing a file path object to the resource. - - If the resource does not already exist on its own on the file system, - a temporary file will be created. If the file was created, the file - will be deleted upon exiting the context manager (no exception is - raised if the file was deleted prior to the context manager - exiting). - """ - return _common.as_file(_common.files(package) / normalize_path(resource)) diff --git a/package/requirements-package.txt b/package/requirements-package.txt index 5847d597..e0bd159b 100644 --- a/package/requirements-package.txt +++ b/package/requirements-package.txt @@ -1,6 +1,6 @@ apscheduler==3.10.1 importlib-metadata==6.8.0 -importlib-resources==5.12.0 +importlib-resources==6.0.1 pyinstaller==5.13.0 pyopenssl==23.2.0 pycryptodomex==3.18.0 diff --git a/requirements.txt b/requirements.txt index 82b2e93d..5f3de400 100644 --- a/requirements.txt +++ b/requirements.txt @@ -20,7 +20,7 @@ html5lib==1.1 httpagentparser==1.9.5 idna==3.4 importlib-metadata==6.8.0 -importlib-resources==5.12.0 +importlib-resources==6.0.1 git+https://github.com/Tautulli/ipwhois.git@master#egg=ipwhois IPy==1.01 Mako==1.2.4 From 9423f65a9024c68c5ab6b3bcc2cce52f824dac6f Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 24 Aug 2023 12:06:08 -0700 Subject: [PATCH 05/10] Bump backports-functools-lru-cache from 1.6.4 to 1.6.6 (#2127) * Bump backports-functools-lru-cache from 1.6.4 to 
1.6.6 Bumps [backports-functools-lru-cache](https://github.com/jaraco/backports.functools_lru_cache) from 1.6.4 to 1.6.6. - [Release notes](https://github.com/jaraco/backports.functools_lru_cache/releases) - [Changelog](https://github.com/jaraco/backports.functools_lru_cache/blob/main/NEWS.rst) - [Commits](https://github.com/jaraco/backports.functools_lru_cache/compare/v1.6.4...v1.6.6) --- updated-dependencies: - dependency-name: backports-functools-lru-cache dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] * Update backports-functools-lru-cache==1.6.6 --------- Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: JonnyWong16 <9099342+JonnyWong16@users.noreply.github.com> [skip ci] --- lib/backports/__init__.py | 6 +----- lib/backports/functools_lru_cache.py | 1 - requirements.txt | 2 +- 3 files changed, 2 insertions(+), 7 deletions(-) diff --git a/lib/backports/__init__.py b/lib/backports/__init__.py index b09491b2..0d1f7edf 100644 --- a/lib/backports/__init__.py +++ b/lib/backports/__init__.py @@ -1,5 +1 @@ -# A Python "namespace package" http://www.python.org/dev/peps/pep-0382/ -# This always goes inside of a namespace package's __init__.py -from pkgutil import extend_path - -__path__ = extend_path(__path__, __name__) # type: ignore +__path__ = __import__('pkgutil').extend_path(__path__, __name__) # type: ignore diff --git a/lib/backports/functools_lru_cache.py b/lib/backports/functools_lru_cache.py index 8be4515f..1b83fe99 100644 --- a/lib/backports/functools_lru_cache.py +++ b/lib/backports/functools_lru_cache.py @@ -89,7 +89,6 @@ def lru_cache(maxsize=100, typed=False): # noqa: C901 # to allow the implementation to change (including a possible C version). 
def decorating_function(user_function): - cache = dict() stats = [0, 0] # make statistics updateable non-locally HITS, MISSES = 0, 1 # names for the stats fields diff --git a/requirements.txt b/requirements.txt index 5f3de400..e8dbb20d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,7 +2,7 @@ appdirs==1.4.4 apscheduler==3.10.1 arrow==1.2.3 backports.csv==1.0.7 -backports.functools-lru-cache==1.6.4 +backports.functools-lru-cache==1.6.6 backports.zoneinfo==0.2.1;python_version<"3.9" beautifulsoup4==4.12.2 bleach==6.0.0 From 4033114175244ddb80c6bcf565cf4a12b1731ce2 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 24 Aug 2023 12:09:15 -0700 Subject: [PATCH 06/10] Bump cheroot from 9.0.0 to 10.0.0 (#2128) * Bump cheroot from 9.0.0 to 10.0.0 Bumps [cheroot](https://github.com/cherrypy/cheroot) from 9.0.0 to 10.0.0. - [Release notes](https://github.com/cherrypy/cheroot/releases) - [Changelog](https://github.com/cherrypy/cheroot/blob/main/CHANGES.rst) - [Commits](https://github.com/cherrypy/cheroot/compare/v9.0.0...v10.0.0) --- updated-dependencies: - dependency-name: cheroot dependency-type: direct:production update-type: version-update:semver-major ... 
Signed-off-by: dependabot[bot] * Update cheroot==10.0.0 --------- Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: JonnyWong16 <9099342+JonnyWong16@users.noreply.github.com> [skip ci] --- lib/cheroot/_compat.py | 1 + lib/cheroot/_compat.pyi | 1 + lib/cheroot/connections.py | 7 +- lib/cheroot/errors.py | 7 +- lib/cheroot/server.py | 46 +++++++++++- lib/cheroot/server.pyi | 7 +- lib/cheroot/ssl/__init__.pyi | 4 +- lib/cheroot/test/_pytest_plugin.py | 17 ----- lib/cheroot/test/test_conn.py | 3 +- lib/cheroot/test/test_core.py | 6 +- lib/cheroot/test/test_errors.py | 3 +- lib/cheroot/test/test_server.py | 117 +++++++++++++++++++++++++++++ lib/cheroot/test/test_ssl.py | 12 --- lib/cheroot/test/webtest.py | 5 +- lib/cheroot/testing.py | 4 +- lib/cheroot/workers/threadpool.py | 54 +++++++------ lib/cheroot/wsgi.py | 2 + lib/cheroot/wsgi.pyi | 2 +- requirements.txt | 2 +- 19 files changed, 215 insertions(+), 85 deletions(-) diff --git a/lib/cheroot/_compat.py b/lib/cheroot/_compat.py index 20c993de..dbe5c6d2 100644 --- a/lib/cheroot/_compat.py +++ b/lib/cheroot/_compat.py @@ -24,6 +24,7 @@ SYS_PLATFORM = platform.system() IS_WINDOWS = SYS_PLATFORM == 'Windows' IS_LINUX = SYS_PLATFORM == 'Linux' IS_MACOS = SYS_PLATFORM == 'Darwin' +IS_SOLARIS = SYS_PLATFORM == 'SunOS' PLATFORM_ARCH = platform.machine() IS_PPC = PLATFORM_ARCH.startswith('ppc') diff --git a/lib/cheroot/_compat.pyi b/lib/cheroot/_compat.pyi index 023bad8c..67d93cf6 100644 --- a/lib/cheroot/_compat.pyi +++ b/lib/cheroot/_compat.pyi @@ -10,6 +10,7 @@ SYS_PLATFORM: str IS_WINDOWS: bool IS_LINUX: bool IS_MACOS: bool +IS_SOLARIS: bool PLATFORM_ARCH: str IS_PPC: bool diff --git a/lib/cheroot/connections.py b/lib/cheroot/connections.py index 9b6366e5..9346bc6a 100644 --- a/lib/cheroot/connections.py +++ b/lib/cheroot/connections.py @@ -274,8 +274,7 @@ class ConnectionManager: # One of the reason on why a socket could cause an error # is 
that the socket is already closed, ignore the # socket error if we try to close it at this point. - # This is equivalent to OSError in Py3 - with suppress(socket.error): + with suppress(OSError): conn.close() def _from_server_socket(self, server_socket): # noqa: C901 # FIXME @@ -308,7 +307,7 @@ class ConnectionManager: wfile = mf(s, 'wb', io.DEFAULT_BUFFER_SIZE) try: wfile.write(''.join(buf).encode('ISO-8859-1')) - except socket.error as ex: + except OSError as ex: if ex.args[0] not in errors.socket_errors_to_ignore: raise return @@ -343,7 +342,7 @@ class ConnectionManager: # notice keyboard interrupts on Win32, which don't interrupt # accept() by default return - except socket.error as ex: + except OSError as ex: if self.server.stats['Enabled']: self.server.stats['Socket Errors'] += 1 if ex.args[0] in errors.socket_error_eintr: diff --git a/lib/cheroot/errors.py b/lib/cheroot/errors.py index 046263ad..f6b588c2 100644 --- a/lib/cheroot/errors.py +++ b/lib/cheroot/errors.py @@ -77,9 +77,4 @@ Refs: * https://docs.microsoft.com/windows/win32/api/winsock/nf-winsock-shutdown """ -try: # py3 - acceptable_sock_shutdown_exceptions = ( - BrokenPipeError, ConnectionResetError, - ) -except NameError: # py2 - acceptable_sock_shutdown_exceptions = () +acceptable_sock_shutdown_exceptions = (BrokenPipeError, ConnectionResetError) diff --git a/lib/cheroot/server.py b/lib/cheroot/server.py index 6b8e37a9..bceeb2c9 100644 --- a/lib/cheroot/server.py +++ b/lib/cheroot/server.py @@ -1572,6 +1572,9 @@ class HTTPServer: ``PEERCREDS``-provided IDs. """ + reuse_port = False + """If True, set SO_REUSEPORT on the socket.""" + keep_alive_conn_limit = 10 """Maximum number of waiting keep-alive connections that will be kept open. @@ -1581,6 +1584,7 @@ class HTTPServer: self, bind_addr, gateway, minthreads=10, maxthreads=-1, server_name=None, peercreds_enabled=False, peercreds_resolve_enabled=False, + reuse_port=False, ): """Initialize HTTPServer instance. 
@@ -1591,6 +1595,8 @@ class HTTPServer: maxthreads (int): maximum number of threads for HTTP thread pool server_name (str): web server name to be advertised via Server HTTP header + reuse_port (bool): if True SO_REUSEPORT option would be set to + socket """ self.bind_addr = bind_addr self.gateway = gateway @@ -1606,6 +1612,7 @@ class HTTPServer: self.peercreds_resolve_enabled = ( peercreds_resolve_enabled and peercreds_enabled ) + self.reuse_port = reuse_port self.clear_stats() def clear_stats(self): @@ -1880,6 +1887,7 @@ class HTTPServer: self.bind_addr, family, type, proto, self.nodelay, self.ssl_adapter, + self.reuse_port, ) sock = self.socket = self.bind_socket(sock, self.bind_addr) self.bind_addr = self.resolve_real_bind_addr(sock) @@ -1911,9 +1919,6 @@ class HTTPServer: 'remove() argument 1 must be encoded ' 'string without null bytes, not unicode' not in err_msg - and 'embedded NUL character' not in err_msg # py34 - and 'argument must be a ' - 'string without NUL characters' not in err_msg # pypy2 ): raise except ValueError as val_err: @@ -1931,6 +1936,7 @@ class HTTPServer: bind_addr=bind_addr, family=socket.AF_UNIX, type=socket.SOCK_STREAM, proto=0, nodelay=self.nodelay, ssl_adapter=self.ssl_adapter, + reuse_port=self.reuse_port, ) try: @@ -1971,7 +1977,36 @@ class HTTPServer: return sock @staticmethod - def prepare_socket(bind_addr, family, type, proto, nodelay, ssl_adapter): + def _make_socket_reusable(socket_, bind_addr): + host, port = bind_addr[:2] + IS_EPHEMERAL_PORT = port == 0 + + if socket_.family not in (socket.AF_INET, socket.AF_INET6): + raise ValueError('Cannot reuse a non-IP socket') + + if IS_EPHEMERAL_PORT: + raise ValueError('Cannot reuse an ephemeral port (0)') + + # Most BSD kernels implement SO_REUSEPORT the way that only the + # latest listener can read from socket. Some of BSD kernels also + # have SO_REUSEPORT_LB that works similarly to SO_REUSEPORT + # in Linux. 
+ if hasattr(socket, 'SO_REUSEPORT_LB'): + socket_.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT_LB, 1) + elif hasattr(socket, 'SO_REUSEPORT'): + socket_.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) + elif IS_WINDOWS: + socket_.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + else: + raise NotImplementedError( + 'Current platform does not support port reuse', + ) + + @classmethod + def prepare_socket( + cls, bind_addr, family, type, proto, nodelay, ssl_adapter, + reuse_port=False, + ): """Create and prepare the socket object.""" sock = socket.socket(family, type, proto) connections.prevent_socket_inheritance(sock) @@ -1979,6 +2014,9 @@ class HTTPServer: host, port = bind_addr[:2] IS_EPHEMERAL_PORT = port == 0 + if reuse_port: + cls._make_socket_reusable(socket_=sock, bind_addr=bind_addr) + if not (IS_WINDOWS or IS_EPHEMERAL_PORT): """Enable SO_REUSEADDR for the current socket. diff --git a/lib/cheroot/server.pyi b/lib/cheroot/server.pyi index 864adff4..ecbe2f27 100644 --- a/lib/cheroot/server.pyi +++ b/lib/cheroot/server.pyi @@ -130,9 +130,10 @@ class HTTPServer: ssl_adapter: Any peercreds_enabled: bool peercreds_resolve_enabled: bool + reuse_port: bool keep_alive_conn_limit: int requests: Any - def __init__(self, bind_addr, gateway, minthreads: int = ..., maxthreads: int = ..., server_name: Any | None = ..., peercreds_enabled: bool = ..., peercreds_resolve_enabled: bool = ...) -> None: ... + def __init__(self, bind_addr, gateway, minthreads: int = ..., maxthreads: int = ..., server_name: Any | None = ..., peercreds_enabled: bool = ..., peercreds_resolve_enabled: bool = ..., reuse_port: bool = ...) -> None: ... stats: Any def clear_stats(self): ... def runtime(self): ... @@ -152,7 +153,9 @@ class HTTPServer: def bind(self, family, type, proto: int = ...): ... def bind_unix_socket(self, bind_addr): ... @staticmethod - def prepare_socket(bind_addr, family, type, proto, nodelay, ssl_adapter): ... 
+ def _make_socket_reusable(socket_, bind_addr) -> None: ... + @classmethod + def prepare_socket(cls, bind_addr, family, type, proto, nodelay, ssl_adapter, reuse_port: bool = ...): ... @staticmethod def bind_socket(socket_, bind_addr): ... @staticmethod diff --git a/lib/cheroot/ssl/__init__.pyi b/lib/cheroot/ssl/__init__.pyi index a9807660..4801fbdd 100644 --- a/lib/cheroot/ssl/__init__.pyi +++ b/lib/cheroot/ssl/__init__.pyi @@ -1,7 +1,7 @@ -from abc import abstractmethod +from abc import abstractmethod, ABCMeta from typing import Any -class Adapter(): +class Adapter(metaclass=ABCMeta): certificate: Any private_key: Any certificate_chain: Any diff --git a/lib/cheroot/test/_pytest_plugin.py b/lib/cheroot/test/_pytest_plugin.py index 8ff3b02c..61f2efe1 100644 --- a/lib/cheroot/test/_pytest_plugin.py +++ b/lib/cheroot/test/_pytest_plugin.py @@ -4,11 +4,7 @@ Contains hooks, which are tightly bound to the Cheroot framework itself, useless for end-users' app testing. """ -from __future__ import absolute_import, division, print_function -__metaclass__ = type - import pytest -import six pytest_version = tuple(map(int, pytest.__version__.split('.'))) @@ -45,16 +41,3 @@ def pytest_load_initial_conftests(early_config, parser, args): 'type=SocketKind.SOCK_STREAM, proto=.:' 'pytest.PytestUnraisableExceptionWarning:_pytest.unraisableexception', )) - - if six.PY2: - return - - # NOTE: `ResourceWarning` does not exist under Python 2 and so using - # NOTE: it in warning filters results in an `_OptionError` exception - # NOTE: being raised. - early_config._inicache['filterwarnings'].extend(( - # FIXME: Try to figure out what causes this and ensure that the socket - # FIXME: gets closed. 
- 'ignore:unclosed = resource_limit for fn in native_process_conn.filenos) +@pytest.mark.skipif( + not hasattr(socket, 'SO_REUSEPORT'), + reason='socket.SO_REUSEPORT is not supported on this platform', +) +@pytest.mark.parametrize( + 'ip_addr', + ( + ANY_INTERFACE_IPV4, + ANY_INTERFACE_IPV6, + ), +) +def test_reuse_port(http_server, ip_addr, mocker): + """Check that port initialized externally can be reused.""" + family = socket.getaddrinfo(ip_addr, EPHEMERAL_PORT)[0][0] + s = socket.socket(family) + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) + s.bind((ip_addr, EPHEMERAL_PORT)) + server = HTTPServer( + bind_addr=s.getsockname()[:2], gateway=Gateway, reuse_port=True, + ) + spy = mocker.spy(server, 'prepare') + server.prepare() + server.stop() + s.close() + assert spy.spy_exception is None + + ISSUE511 = IS_MACOS @@ -439,3 +469,90 @@ def many_open_sockets(request, resource_limit): # Close our open resources for test_socket in test_sockets: test_socket.close() + + +@pytest.mark.parametrize( + ('minthreads', 'maxthreads', 'inited_maxthreads'), + ( + ( + # NOTE: The docstring only mentions -1 to mean "no max", but other + # NOTE: negative numbers should also work. + 1, + -2, + float('inf'), + ), + (1, -1, float('inf')), + (1, 1, 1), + (1, 2, 2), + (1, float('inf'), float('inf')), + (2, -2, float('inf')), + (2, -1, float('inf')), + (2, 2, 2), + (2, float('inf'), float('inf')), + ), +) +def test_threadpool_threadrange_set(minthreads, maxthreads, inited_maxthreads): + """Test setting the number of threads in a ThreadPool. + + The ThreadPool should properly set the min+max number of the threads to use + in the pool if those limits are valid. 
+ """ + tp = ThreadPool( + server=None, + min=minthreads, + max=maxthreads, + ) + assert tp.min == minthreads + assert tp.max == inited_maxthreads + + +@pytest.mark.parametrize( + ('minthreads', 'maxthreads', 'error'), + ( + (-1, -1, 'min=-1 must be > 0'), + (-1, 0, 'min=-1 must be > 0'), + (-1, 1, 'min=-1 must be > 0'), + (-1, 2, 'min=-1 must be > 0'), + (0, -1, 'min=0 must be > 0'), + (0, 0, 'min=0 must be > 0'), + (0, 1, 'min=0 must be > 0'), + (0, 2, 'min=0 must be > 0'), + (1, 0, 'Expected an integer or the infinity value for the `max` argument but got 0.'), + (1, 0.5, 'Expected an integer or the infinity value for the `max` argument but got 0.5.'), + (2, 0, 'Expected an integer or the infinity value for the `max` argument but got 0.'), + (2, '1', "Expected an integer or the infinity value for the `max` argument but got '1'."), + (2, 1, 'max=1 must be > min=2'), + ), +) +def test_threadpool_invalid_threadrange(minthreads, maxthreads, error): + """Test that a ThreadPool rejects invalid min/max values. + + The ThreadPool should raise an error with the proper message when + initialized with an invalid min+max number of threads. + """ + with pytest.raises((ValueError, TypeError), match=error): + ThreadPool( + server=None, + min=minthreads, + max=maxthreads, + ) + + +def test_threadpool_multistart_validation(monkeypatch): + """Test for ThreadPool multi-start behavior. 
+ + Tests that when calling start() on a ThreadPool multiple times raises a + :exc:`RuntimeError` + """ + # replace _spawn_worker with a function that returns a placeholder to avoid + # actually starting any threads + monkeypatch.setattr( + ThreadPool, + '_spawn_worker', + lambda _: types.SimpleNamespace(ready=True), + ) + + tp = ThreadPool(server=None) + tp.start() + with pytest.raises(RuntimeError, match='Threadpools can only be started once.'): + tp.start() diff --git a/lib/cheroot/test/test_ssl.py b/lib/cheroot/test/test_ssl.py index c55e156f..1900e20d 100644 --- a/lib/cheroot/test/test_ssl.py +++ b/lib/cheroot/test/test_ssl.py @@ -55,17 +55,6 @@ _stdlib_to_openssl_verify = { } -fails_under_py3 = pytest.mark.xfail( - reason='Fails under Python 3+', -) - - -fails_under_py3_in_pypy = pytest.mark.xfail( - IS_PYPY, - reason='Fails under PyPy3', -) - - missing_ipv6 = pytest.mark.skipif( not _probe_ipv6_sock('::1'), reason='' @@ -556,7 +545,6 @@ def test_ssl_env( # noqa: C901 # FIXME # builtin ssl environment generation may use a loopback socket # ensure no ResourceWarning was raised during the test - # NOTE: python 2.7 does not emit ResourceWarning for ssl sockets if IS_PYPY: # NOTE: PyPy doesn't have ResourceWarning # Ref: https://doc.pypy.org/en/latest/cpython_differences.html diff --git a/lib/cheroot/test/webtest.py b/lib/cheroot/test/webtest.py index 1630c8ef..eafa2dd6 100644 --- a/lib/cheroot/test/webtest.py +++ b/lib/cheroot/test/webtest.py @@ -463,16 +463,13 @@ def shb(response): return resp_status_line, response.getheaders(), response.read() -# def openURL(*args, raise_subcls=(), **kwargs): -# py27 compatible signature: -def openURL(*args, **kwargs): +def openURL(*args, raise_subcls=(), **kwargs): """ Open a URL, retrying when it fails. Specify ``raise_subcls`` (class or tuple of classes) to exclude those socket.error subclasses from being suppressed and retried. 
""" - raise_subcls = kwargs.pop('raise_subcls', ()) opener = functools.partial(_open_url_once, *args, **kwargs) def on_exception(): diff --git a/lib/cheroot/testing.py b/lib/cheroot/testing.py index 169142bf..3e404e59 100644 --- a/lib/cheroot/testing.py +++ b/lib/cheroot/testing.py @@ -119,9 +119,7 @@ def _probe_ipv6_sock(interface): try: with closing(socket.socket(family=socket.AF_INET6)) as sock: sock.bind((interface, 0)) - except (OSError, socket.error) as sock_err: - # In Python 3 socket.error is an alias for OSError - # In Python 2 socket.error is a subclass of IOError + except OSError as sock_err: if sock_err.errno != errno.EADDRNOTAVAIL: raise else: diff --git a/lib/cheroot/workers/threadpool.py b/lib/cheroot/workers/threadpool.py index 2a9878dc..3437d9bd 100644 --- a/lib/cheroot/workers/threadpool.py +++ b/lib/cheroot/workers/threadpool.py @@ -151,12 +151,33 @@ class ThreadPool: server (cheroot.server.HTTPServer): web server object receiving this request min (int): minimum number of worker threads - max (int): maximum number of worker threads + max (int): maximum number of worker threads (-1/inf for no max) accepted_queue_size (int): maximum number of active requests in queue accepted_queue_timeout (int): timeout for putting request into queue + + :raises ValueError: if the min/max values are invalid + :raises TypeError: if the max is not an integer or inf """ + if min < 1: + raise ValueError(f'min={min!s} must be > 0') + + if max == float('inf'): + pass + elif not isinstance(max, int) or max == 0: + raise TypeError( + 'Expected an integer or the infinity value for the `max` ' + f'argument but got {max!r}.', + ) + elif max < 0: + max = float('inf') + + if max < min: + raise ValueError( + f'max={max!s} must be > min={min!s} (or infinity for no max)', + ) + self.server = server self.min = min self.max = max @@ -167,18 +188,13 @@ class ThreadPool: self._pending_shutdowns = collections.deque() def start(self): - """Start the pool of threads.""" - for _ in 
range(self.min): - self._threads.append(WorkerThread(self.server)) - for worker in self._threads: - worker.name = ( - 'CP Server {worker_name!s}'. - format(worker_name=worker.name) - ) - worker.start() - for worker in self._threads: - while not worker.ready: - time.sleep(.1) + """Start the pool of threads. + + :raises RuntimeError: if the pool is already started + """ + if self._threads: + raise RuntimeError('Threadpools can only be started once.') + self.grow(self.min) @property def idle(self): # noqa: D401; irrelevant for properties @@ -206,17 +222,13 @@ class ThreadPool: def grow(self, amount): """Spawn new worker threads (not above self.max).""" - if self.max > 0: - budget = max(self.max - len(self._threads), 0) - else: - # self.max <= 0 indicates no maximum - budget = float('inf') - + budget = max(self.max - len(self._threads), 0) n_new = min(amount, budget) workers = [self._spawn_worker() for i in range(n_new)] - while not all(worker.ready for worker in workers): - time.sleep(.1) + for worker in workers: + while not worker.ready: + time.sleep(.1) self._threads.extend(workers) def _spawn_worker(self): diff --git a/lib/cheroot/wsgi.py b/lib/cheroot/wsgi.py index 82faca3e..1dbe10ee 100644 --- a/lib/cheroot/wsgi.py +++ b/lib/cheroot/wsgi.py @@ -43,6 +43,7 @@ class Server(server.HTTPServer): max=-1, request_queue_size=5, timeout=10, shutdown_timeout=5, accepted_queue_size=-1, accepted_queue_timeout=10, peercreds_enabled=False, peercreds_resolve_enabled=False, + reuse_port=False, ): """Initialize WSGI Server instance. 
@@ -69,6 +70,7 @@ class Server(server.HTTPServer): server_name=server_name, peercreds_enabled=peercreds_enabled, peercreds_resolve_enabled=peercreds_resolve_enabled, + reuse_port=reuse_port, ) self.wsgi_app = wsgi_app self.request_queue_size = request_queue_size diff --git a/lib/cheroot/wsgi.pyi b/lib/cheroot/wsgi.pyi index 96075633..f96a18f9 100644 --- a/lib/cheroot/wsgi.pyi +++ b/lib/cheroot/wsgi.pyi @@ -8,7 +8,7 @@ class Server(server.HTTPServer): timeout: Any shutdown_timeout: Any requests: Any - def __init__(self, bind_addr, wsgi_app, numthreads: int = ..., server_name: Any | None = ..., max: int = ..., request_queue_size: int = ..., timeout: int = ..., shutdown_timeout: int = ..., accepted_queue_size: int = ..., accepted_queue_timeout: int = ..., peercreds_enabled: bool = ..., peercreds_resolve_enabled: bool = ...) -> None: ... + def __init__(self, bind_addr, wsgi_app, numthreads: int = ..., server_name: Any | None = ..., max: int = ..., request_queue_size: int = ..., timeout: int = ..., shutdown_timeout: int = ..., accepted_queue_size: int = ..., accepted_queue_timeout: int = ..., peercreds_enabled: bool = ..., peercreds_resolve_enabled: bool = ..., reuse_port: bool = ...) -> None: ... @property def numthreads(self): ... @numthreads.setter diff --git a/requirements.txt b/requirements.txt index e8dbb20d..5abbaaa8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,7 +7,7 @@ backports.zoneinfo==0.2.1;python_version<"3.9" beautifulsoup4==4.12.2 bleach==6.0.0 certifi==2023.7.22 -cheroot==9.0.0 +cheroot==10.0.0 cherrypy==18.8.0 cloudinary==1.34.0 distro==1.8.0 From 371d35433c38c92177ca09beb0ab7dd9d02842cb Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 24 Aug 2023 12:09:28 -0700 Subject: [PATCH 07/10] Bump tokenize-rt from 5.0.0 to 5.2.0 (#2129) * Bump tokenize-rt from 5.0.0 to 5.2.0 Bumps [tokenize-rt](https://github.com/asottile/tokenize-rt) from 5.0.0 to 5.2.0. 
- [Commits](https://github.com/asottile/tokenize-rt/compare/v5.0.0...v5.2.0) --- updated-dependencies: - dependency-name: tokenize-rt dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] * Update tokenize-rt==5.2.0 --------- Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: JonnyWong16 <9099342+JonnyWong16@users.noreply.github.com> [skip ci] --- lib/tokenize_rt.py | 5 ++++- requirements.txt | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/lib/tokenize_rt.py b/lib/tokenize_rt.py index f0922c16..ae91cc40 100644 --- a/lib/tokenize_rt.py +++ b/lib/tokenize_rt.py @@ -18,7 +18,7 @@ if ( # pragma: no branch callable(getattr(tokenize, '_compile', None)) ): # pragma: <3.10 cover from functools import lru_cache - tokenize._compile = lru_cache()(tokenize._compile) + tokenize._compile = lru_cache(tokenize._compile) ESCAPED_NL = 'ESCAPED_NL' UNIMPORTANT_WS = 'UNIMPORTANT_WS' @@ -40,6 +40,9 @@ class Token(NamedTuple): def offset(self) -> Offset: return Offset(self.line, self.utf8_byte_offset) + def matches(self, *, name: str, src: str) -> bool: + return self.name == name and self.src == src + _string_re = re.compile('^([^\'"]*)(.*)$', re.DOTALL) _escaped_nl_re = re.compile(r'\\(\n|\r\n|\r)') diff --git a/requirements.txt b/requirements.txt index 5abbaaa8..8fd57e4d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -42,7 +42,7 @@ rumps==0.4.0; platform_system == "Darwin" simplejson==3.19.1 six==1.16.0 tempora==5.5.0 -tokenize-rt==5.0.0 +tokenize-rt==5.2.0 tzdata==2023.3 tzlocal==4.2 urllib3<2 From d0c7f25a3f9361efcef3ecaeb70b340d897ddcac Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 24 Aug 2023 12:09:39 -0700 Subject: [PATCH 08/10] Bump markupsafe from 2.1.2 to 2.1.3 (#2130) * Bump markupsafe from 2.1.2 to 2.1.3 Bumps 
[markupsafe](https://github.com/pallets/markupsafe) from 2.1.2 to 2.1.3. - [Release notes](https://github.com/pallets/markupsafe/releases) - [Changelog](https://github.com/pallets/markupsafe/blob/main/CHANGES.rst) - [Commits](https://github.com/pallets/markupsafe/compare/2.1.2...2.1.3) --- updated-dependencies: - dependency-name: markupsafe dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] * Update markupsave==2.1.3 --------- Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: JonnyWong16 <9099342+JonnyWong16@users.noreply.github.com> [skip ci] --- lib/markupsafe/__init__.py | 109 ++++++++++++++++++++----------------- requirements.txt | 2 +- 2 files changed, 60 insertions(+), 51 deletions(-) diff --git a/lib/markupsafe/__init__.py b/lib/markupsafe/__init__.py index 7166b192..21d31960 100644 --- a/lib/markupsafe/__init__.py +++ b/lib/markupsafe/__init__.py @@ -1,6 +1,7 @@ import functools import re import string +import sys import typing as t if t.TYPE_CHECKING: @@ -10,23 +11,23 @@ if t.TYPE_CHECKING: def __html__(self) -> str: pass + _P = te.ParamSpec("_P") -__version__ = "2.1.2" + +__version__ = "2.1.3" _strip_comments_re = re.compile(r"", re.DOTALL) _strip_tags_re = re.compile(r"<.*?>", re.DOTALL) -def _simple_escaping_wrapper(name: str) -> t.Callable[..., "Markup"]: - orig = getattr(str, name) - - @functools.wraps(orig) - def wrapped(self: "Markup", *args: t.Any, **kwargs: t.Any) -> "Markup": - args = _escape_argspec(list(args), enumerate(args), self.escape) # type: ignore +def _simple_escaping_wrapper(func: "t.Callable[_P, str]") -> "t.Callable[_P, Markup]": + @functools.wraps(func) + def wrapped(self: "Markup", *args: "_P.args", **kwargs: "_P.kwargs") -> "Markup": + arg_list = _escape_argspec(list(args), enumerate(args), self.escape) _escape_argspec(kwargs, kwargs.items(), self.escape) - return self.__class__(orig(self, 
*args, **kwargs)) + return self.__class__(func(self, *arg_list, **kwargs)) # type: ignore[arg-type] - return wrapped + return wrapped # type: ignore[return-value] class Markup(str): @@ -69,7 +70,7 @@ class Markup(str): def __new__( cls, base: t.Any = "", encoding: t.Optional[str] = None, errors: str = "strict" - ) -> "Markup": + ) -> "te.Self": if hasattr(base, "__html__"): base = base.__html__() @@ -78,22 +79,22 @@ class Markup(str): return super().__new__(cls, base, encoding, errors) - def __html__(self) -> "Markup": + def __html__(self) -> "te.Self": return self - def __add__(self, other: t.Union[str, "HasHTML"]) -> "Markup": + def __add__(self, other: t.Union[str, "HasHTML"]) -> "te.Self": if isinstance(other, str) or hasattr(other, "__html__"): return self.__class__(super().__add__(self.escape(other))) return NotImplemented - def __radd__(self, other: t.Union[str, "HasHTML"]) -> "Markup": + def __radd__(self, other: t.Union[str, "HasHTML"]) -> "te.Self": if isinstance(other, str) or hasattr(other, "__html__"): return self.escape(other).__add__(self) return NotImplemented - def __mul__(self, num: "te.SupportsIndex") -> "Markup": + def __mul__(self, num: "te.SupportsIndex") -> "te.Self": if isinstance(num, int): return self.__class__(super().__mul__(num)) @@ -101,7 +102,7 @@ class Markup(str): __rmul__ = __mul__ - def __mod__(self, arg: t.Any) -> "Markup": + def __mod__(self, arg: t.Any) -> "te.Self": if isinstance(arg, tuple): # a tuple of arguments, each wrapped arg = tuple(_MarkupEscapeHelper(x, self.escape) for x in arg) @@ -117,26 +118,28 @@ class Markup(str): def __repr__(self) -> str: return f"{self.__class__.__name__}({super().__repr__()})" - def join(self, seq: t.Iterable[t.Union[str, "HasHTML"]]) -> "Markup": + def join(self, seq: t.Iterable[t.Union[str, "HasHTML"]]) -> "te.Self": return self.__class__(super().join(map(self.escape, seq))) join.__doc__ = str.join.__doc__ - def split( # type: ignore + def split( # type: ignore[override] self, sep: 
t.Optional[str] = None, maxsplit: int = -1 - ) -> t.List["Markup"]: + ) -> t.List["te.Self"]: return [self.__class__(v) for v in super().split(sep, maxsplit)] split.__doc__ = str.split.__doc__ - def rsplit( # type: ignore + def rsplit( # type: ignore[override] self, sep: t.Optional[str] = None, maxsplit: int = -1 - ) -> t.List["Markup"]: + ) -> t.List["te.Self"]: return [self.__class__(v) for v in super().rsplit(sep, maxsplit)] rsplit.__doc__ = str.rsplit.__doc__ - def splitlines(self, keepends: bool = False) -> t.List["Markup"]: # type: ignore + def splitlines( # type: ignore[override] + self, keepends: bool = False + ) -> t.List["te.Self"]: return [self.__class__(v) for v in super().splitlines(keepends)] splitlines.__doc__ = str.splitlines.__doc__ @@ -163,10 +166,10 @@ class Markup(str): value = _strip_comments_re.sub("", self) value = _strip_tags_re.sub("", value) value = " ".join(value.split()) - return Markup(value).unescape() + return self.__class__(value).unescape() @classmethod - def escape(cls, s: t.Any) -> "Markup": + def escape(cls, s: t.Any) -> "te.Self": """Escape a string. Calls :func:`escape` and ensures that for subclasses the correct type is returned. 
""" @@ -175,45 +178,51 @@ class Markup(str): if rv.__class__ is not cls: return cls(rv) - return rv + return rv # type: ignore[return-value] - for method in ( - "__getitem__", - "capitalize", - "title", - "lower", - "upper", - "replace", - "ljust", - "rjust", - "lstrip", - "rstrip", - "center", - "strip", - "translate", - "expandtabs", - "swapcase", - "zfill", - ): - locals()[method] = _simple_escaping_wrapper(method) + __getitem__ = _simple_escaping_wrapper(str.__getitem__) + capitalize = _simple_escaping_wrapper(str.capitalize) + title = _simple_escaping_wrapper(str.title) + lower = _simple_escaping_wrapper(str.lower) + upper = _simple_escaping_wrapper(str.upper) + replace = _simple_escaping_wrapper(str.replace) + ljust = _simple_escaping_wrapper(str.ljust) + rjust = _simple_escaping_wrapper(str.rjust) + lstrip = _simple_escaping_wrapper(str.lstrip) + rstrip = _simple_escaping_wrapper(str.rstrip) + center = _simple_escaping_wrapper(str.center) + strip = _simple_escaping_wrapper(str.strip) + translate = _simple_escaping_wrapper(str.translate) + expandtabs = _simple_escaping_wrapper(str.expandtabs) + swapcase = _simple_escaping_wrapper(str.swapcase) + zfill = _simple_escaping_wrapper(str.zfill) + casefold = _simple_escaping_wrapper(str.casefold) - del method + if sys.version_info >= (3, 9): + removeprefix = _simple_escaping_wrapper(str.removeprefix) + removesuffix = _simple_escaping_wrapper(str.removesuffix) - def partition(self, sep: str) -> t.Tuple["Markup", "Markup", "Markup"]: + def partition(self, sep: str) -> t.Tuple["te.Self", "te.Self", "te.Self"]: l, s, r = super().partition(self.escape(sep)) cls = self.__class__ return cls(l), cls(s), cls(r) - def rpartition(self, sep: str) -> t.Tuple["Markup", "Markup", "Markup"]: + def rpartition(self, sep: str) -> t.Tuple["te.Self", "te.Self", "te.Self"]: l, s, r = super().rpartition(self.escape(sep)) cls = self.__class__ return cls(l), cls(s), cls(r) - def format(self, *args: t.Any, **kwargs: t.Any) -> "Markup": + def 
format(self, *args: t.Any, **kwargs: t.Any) -> "te.Self": formatter = EscapeFormatter(self.escape) return self.__class__(formatter.vformat(self, args, kwargs)) - def __html_format__(self, format_spec: str) -> "Markup": + def format_map( # type: ignore[override] + self, map: t.Mapping[str, t.Any] + ) -> "te.Self": + formatter = EscapeFormatter(self.escape) + return self.__class__(formatter.vformat(self, (), map)) + + def __html_format__(self, format_spec: str) -> "te.Self": if format_spec: raise ValueError("Unsupported format specification for Markup.") @@ -268,8 +277,8 @@ class _MarkupEscapeHelper: self.obj = obj self.escape = escape - def __getitem__(self, item: t.Any) -> "_MarkupEscapeHelper": - return _MarkupEscapeHelper(self.obj[item], self.escape) + def __getitem__(self, item: t.Any) -> "te.Self": + return self.__class__(self.obj[item], self.escape) def __str__(self) -> str: return str(self.escape(self.obj)) diff --git a/requirements.txt b/requirements.txt index 8fd57e4d..dc879875 100644 --- a/requirements.txt +++ b/requirements.txt @@ -24,7 +24,7 @@ importlib-resources==6.0.1 git+https://github.com/Tautulli/ipwhois.git@master#egg=ipwhois IPy==1.01 Mako==1.2.4 -MarkupSafe==2.1.2 +MarkupSafe==2.1.3 musicbrainzngs==0.7.1 packaging==23.1 paho-mqtt==1.6.1 From 3debeada2a040b3a1be5749c2df39d603f97e931 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 24 Aug 2023 12:09:51 -0700 Subject: [PATCH 09/10] Bump pyparsing from 3.0.9 to 3.1.1 (#2131) * Bump pyparsing from 3.0.9 to 3.1.1 Bumps [pyparsing](https://github.com/pyparsing/pyparsing) from 3.0.9 to 3.1.1. 
- [Release notes](https://github.com/pyparsing/pyparsing/releases) - [Changelog](https://github.com/pyparsing/pyparsing/blob/master/CHANGES) - [Commits](https://github.com/pyparsing/pyparsing/compare/pyparsing_3.0.9...3.1.1) --- updated-dependencies: - dependency-name: pyparsing dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] * Update pyparsing==3.1.1 --------- Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: JonnyWong16 <9099342+JonnyWong16@users.noreply.github.com> [skip ci] --- lib/pyparsing/__init__.py | 78 +- lib/pyparsing/actions.py | 34 +- lib/pyparsing/common.py | 58 +- lib/pyparsing/core.py | 1295 ++++++++++++++++++----------- lib/pyparsing/diagram/__init__.py | 32 +- lib/pyparsing/exceptions.py | 64 +- lib/pyparsing/helpers.py | 196 +++-- lib/pyparsing/results.py | 128 ++- lib/pyparsing/testing.py | 24 +- lib/pyparsing/unicode.py | 103 +-- lib/pyparsing/util.py | 89 +- requirements.txt | 2 +- 12 files changed, 1306 insertions(+), 797 deletions(-) diff --git a/lib/pyparsing/__init__.py b/lib/pyparsing/__init__.py index 7802ff15..3dbc3cf8 100644 --- a/lib/pyparsing/__init__.py +++ b/lib/pyparsing/__init__.py @@ -56,7 +56,7 @@ self-explanatory class names, and the use of :class:`'+'`, :class:`'|'`, :class:`'^'` and :class:`'&'` operators. The :class:`ParseResults` object returned from -:class:`ParserElement.parseString` can be +:class:`ParserElement.parse_string` can be accessed as a nested list, a dictionary, or an object with named attributes. @@ -85,11 +85,11 @@ classes inherit from. 
Use the docstrings for examples of how to: and :class:`'&'` operators to combine simple expressions into more complex ones - associate names with your parsed results using - :class:`ParserElement.setResultsName` + :class:`ParserElement.set_results_name` - access the parsed data, which is returned as a :class:`ParseResults` object - - find some helpful expression short-cuts like :class:`delimitedList` - and :class:`oneOf` + - find some helpful expression short-cuts like :class:`DelimitedList` + and :class:`one_of` - find more useful common expressions in the :class:`pyparsing_common` namespace class """ @@ -106,30 +106,22 @@ class version_info(NamedTuple): @property def __version__(self): return ( - "{}.{}.{}".format(self.major, self.minor, self.micro) + f"{self.major}.{self.minor}.{self.micro}" + ( - "{}{}{}".format( - "r" if self.releaselevel[0] == "c" else "", - self.releaselevel[0], - self.serial, - ), + f"{'r' if self.releaselevel[0] == 'c' else ''}{self.releaselevel[0]}{self.serial}", "", )[self.releaselevel == "final"] ) def __str__(self): - return "{} {} / {}".format(__name__, self.__version__, __version_time__) + return f"{__name__} {self.__version__} / {__version_time__}" def __repr__(self): - return "{}.{}({})".format( - __name__, - type(self).__name__, - ", ".join("{}={!r}".format(*nv) for nv in zip(self._fields, self)), - ) + return f"{__name__}.{type(self).__name__}({', '.join('{}={!r}'.format(*nv) for nv in zip(self._fields, self))})" -__version_info__ = version_info(3, 0, 9, "final", 0) -__version_time__ = "05 May 2022 07:02 UTC" +__version_info__ = version_info(3, 1, 1, "final", 1) +__version_time__ = "29 Jul 2023 22:27 UTC" __version__ = __version_info__.__version__ __versionTime__ = __version_time__ __author__ = "Paul McGuire " @@ -139,9 +131,9 @@ from .exceptions import * from .actions import * from .core import __diag__, __compat__ from .results import * -from .core import * +from .core import * # type: ignore[misc, assignment] from .core import 
_builtin_exprs as core_builtin_exprs -from .helpers import * +from .helpers import * # type: ignore[misc, assignment] from .helpers import _builtin_exprs as helper_builtin_exprs from .unicode import unicode_set, UnicodeRangeList, pyparsing_unicode as unicode @@ -153,11 +145,11 @@ from .common import ( # define backward compat synonyms if "pyparsing_unicode" not in globals(): - pyparsing_unicode = unicode + pyparsing_unicode = unicode # type: ignore[misc] if "pyparsing_common" not in globals(): - pyparsing_common = common + pyparsing_common = common # type: ignore[misc] if "pyparsing_test" not in globals(): - pyparsing_test = testing + pyparsing_test = testing # type: ignore[misc] core_builtin_exprs += common_builtin_exprs + helper_builtin_exprs @@ -174,7 +166,9 @@ __all__ = [ "CaselessKeyword", "CaselessLiteral", "CharsNotIn", + "CloseMatch", "Combine", + "DelimitedList", "Dict", "Each", "Empty", @@ -227,9 +221,11 @@ __all__ = [ "alphas8bit", "any_close_tag", "any_open_tag", + "autoname_elements", "c_style_comment", "col", "common_html_entity", + "condition_as_parse_action", "counted_array", "cpp_style_comment", "dbl_quoted_string", @@ -241,6 +237,7 @@ __all__ = [ "html_comment", "identchars", "identbodychars", + "infix_notation", "java_style_comment", "line", "line_end", @@ -255,8 +252,12 @@ __all__ = [ "null_debug_action", "nums", "one_of", + "original_text_for", "printables", "punc8bit", + "pyparsing_common", + "pyparsing_test", + "pyparsing_unicode", "python_style_comment", "quoted_string", "remove_quotes", @@ -267,28 +268,20 @@ __all__ = [ "srange", "string_end", "string_start", + "token_map", "trace_parse_action", + "ungroup", + "unicode_set", "unicode_string", "with_attribute", - "indentedBlock", - "original_text_for", - "ungroup", - "infix_notation", - "locatedExpr", "with_class", - "CloseMatch", - "token_map", - "pyparsing_common", - "pyparsing_unicode", - "unicode_set", - "condition_as_parse_action", - "pyparsing_test", # pre-PEP8 compatibility names 
"__versionTime__", "anyCloseTag", "anyOpenTag", "cStyleComment", "commonHTMLEntity", + "conditionAsParseAction", "countedArray", "cppStyleComment", "dblQuotedString", @@ -296,9 +289,12 @@ __all__ = [ "delimitedList", "dictOf", "htmlComment", + "indentedBlock", + "infixNotation", "javaStyleComment", "lineEnd", "lineStart", + "locatedExpr", "makeHTMLTags", "makeXMLTags", "matchOnlyAtCol", @@ -308,6 +304,7 @@ __all__ = [ "nullDebugAction", "oneOf", "opAssoc", + "originalTextFor", "pythonStyleComment", "quotedString", "removeQuotes", @@ -317,15 +314,12 @@ __all__ = [ "sglQuotedString", "stringEnd", "stringStart", + "tokenMap", "traceParseAction", "unicodeString", "withAttribute", - "indentedBlock", - "originalTextFor", - "infixNotation", - "locatedExpr", "withClass", - "tokenMap", - "conditionAsParseAction", - "autoname_elements", + "common", + "unicode", + "testing", ] diff --git a/lib/pyparsing/actions.py b/lib/pyparsing/actions.py index f72c66e7..ca6e4c6a 100644 --- a/lib/pyparsing/actions.py +++ b/lib/pyparsing/actions.py @@ -1,7 +1,7 @@ # actions.py from .exceptions import ParseException -from .util import col +from .util import col, replaced_by_pep8 class OnlyOnce: @@ -38,7 +38,7 @@ def match_only_at_col(n): def verify_col(strg, locn, toks): if col(locn, strg) != n: - raise ParseException(strg, locn, "matched token not at column {}".format(n)) + raise ParseException(strg, locn, f"matched token not at column {n}") return verify_col @@ -148,15 +148,13 @@ def with_attribute(*args, **attr_dict): raise ParseException( s, l, - "attribute {!r} has value {!r}, must be {!r}".format( - attrName, tokens[attrName], attrValue - ), + f"attribute {attrName!r} has value {tokens[attrName]!r}, must be {attrValue!r}", ) return pa -with_attribute.ANY_VALUE = object() +with_attribute.ANY_VALUE = object() # type: ignore [attr-defined] def with_class(classname, namespace=""): @@ -195,13 +193,25 @@ def with_class(classname, namespace=""): 1 4 0 1 0 1,3 2,3 1,1 """ - classattr = 
"{}:class".format(namespace) if namespace else "class" + classattr = f"{namespace}:class" if namespace else "class" return with_attribute(**{classattr: classname}) # pre-PEP8 compatibility symbols -replaceWith = replace_with -removeQuotes = remove_quotes -withAttribute = with_attribute -withClass = with_class -matchOnlyAtCol = match_only_at_col +# fmt: off +@replaced_by_pep8(replace_with) +def replaceWith(): ... + +@replaced_by_pep8(remove_quotes) +def removeQuotes(): ... + +@replaced_by_pep8(with_attribute) +def withAttribute(): ... + +@replaced_by_pep8(with_class) +def withClass(): ... + +@replaced_by_pep8(match_only_at_col) +def matchOnlyAtCol(): ... + +# fmt: on diff --git a/lib/pyparsing/common.py b/lib/pyparsing/common.py index 1859fb79..7a666b27 100644 --- a/lib/pyparsing/common.py +++ b/lib/pyparsing/common.py @@ -1,6 +1,6 @@ # common.py from .core import * -from .helpers import delimited_list, any_open_tag, any_close_tag +from .helpers import DelimitedList, any_open_tag, any_close_tag from datetime import datetime @@ -22,17 +22,17 @@ class pyparsing_common: Parse actions: - - :class:`convertToInteger` - - :class:`convertToFloat` - - :class:`convertToDate` - - :class:`convertToDatetime` - - :class:`stripHTMLTags` - - :class:`upcaseTokens` - - :class:`downcaseTokens` + - :class:`convert_to_integer` + - :class:`convert_to_float` + - :class:`convert_to_date` + - :class:`convert_to_datetime` + - :class:`strip_html_tags` + - :class:`upcase_tokens` + - :class:`downcase_tokens` Example:: - pyparsing_common.number.runTests(''' + pyparsing_common.number.run_tests(''' # any int or real number, returned as the appropriate type 100 -100 @@ -42,7 +42,7 @@ class pyparsing_common: 1e-12 ''') - pyparsing_common.fnumber.runTests(''' + pyparsing_common.fnumber.run_tests(''' # any int or real number, returned as float 100 -100 @@ -52,19 +52,19 @@ class pyparsing_common: 1e-12 ''') - pyparsing_common.hex_integer.runTests(''' + pyparsing_common.hex_integer.run_tests(''' # hex 
numbers 100 FF ''') - pyparsing_common.fraction.runTests(''' + pyparsing_common.fraction.run_tests(''' # fractions 1/2 -3/4 ''') - pyparsing_common.mixed_integer.runTests(''' + pyparsing_common.mixed_integer.run_tests(''' # mixed fractions 1 1/2 @@ -73,8 +73,8 @@ class pyparsing_common: ''') import uuid - pyparsing_common.uuid.setParseAction(tokenMap(uuid.UUID)) - pyparsing_common.uuid.runTests(''' + pyparsing_common.uuid.set_parse_action(token_map(uuid.UUID)) + pyparsing_common.uuid.run_tests(''' # uuid 12345678-1234-5678-1234-567812345678 ''') @@ -260,8 +260,8 @@ class pyparsing_common: Example:: date_expr = pyparsing_common.iso8601_date.copy() - date_expr.setParseAction(pyparsing_common.convertToDate()) - print(date_expr.parseString("1999-12-31")) + date_expr.set_parse_action(pyparsing_common.convert_to_date()) + print(date_expr.parse_string("1999-12-31")) prints:: @@ -287,8 +287,8 @@ class pyparsing_common: Example:: dt_expr = pyparsing_common.iso8601_datetime.copy() - dt_expr.setParseAction(pyparsing_common.convertToDatetime()) - print(dt_expr.parseString("1999-12-31T23:59:59.999")) + dt_expr.set_parse_action(pyparsing_common.convert_to_datetime()) + print(dt_expr.parse_string("1999-12-31T23:59:59.999")) prints:: @@ -326,9 +326,9 @@ class pyparsing_common: # strip HTML links from normal text text = 'More info at the pyparsing wiki page' - td, td_end = makeHTMLTags("TD") - table_text = td + SkipTo(td_end).setParseAction(pyparsing_common.stripHTMLTags)("body") + td_end - print(table_text.parseString(text).body) + td, td_end = make_html_tags("TD") + table_text = td + SkipTo(td_end).set_parse_action(pyparsing_common.strip_html_tags)("body") + td_end + print(table_text.parse_string(text).body) Prints:: @@ -348,7 +348,7 @@ class pyparsing_common: .streamline() .set_name("commaItem") ) - comma_separated_list = delimited_list( + comma_separated_list = DelimitedList( Opt(quoted_string.copy() | _commasepitem, default="") ).set_name("comma separated list") """Predefined 
expression of 1 or more printable words or quoted strings, separated by commas.""" @@ -363,7 +363,7 @@ class pyparsing_common: url = Regex( # https://mathiasbynens.be/demo/url-regex # https://gist.github.com/dperini/729294 - r"^" + + r"(?P" + # protocol identifier (optional) # short syntax // still required r"(?:(?:(?Phttps?|ftp):)?\/\/)" + @@ -405,18 +405,26 @@ class pyparsing_common: r"(\?(?P[^#]*))?" + # fragment (optional) r"(#(?P\S*))?" + - r"$" + r")" ).set_name("url") + """URL (http/https/ftp scheme)""" # fmt: on # pre-PEP8 compatibility names convertToInteger = convert_to_integer + """Deprecated - use :class:`convert_to_integer`""" convertToFloat = convert_to_float + """Deprecated - use :class:`convert_to_float`""" convertToDate = convert_to_date + """Deprecated - use :class:`convert_to_date`""" convertToDatetime = convert_to_datetime + """Deprecated - use :class:`convert_to_datetime`""" stripHTMLTags = strip_html_tags + """Deprecated - use :class:`strip_html_tags`""" upcaseTokens = upcase_tokens + """Deprecated - use :class:`upcase_tokens`""" downcaseTokens = downcase_tokens + """Deprecated - use :class:`downcase_tokens`""" _builtin_exprs = [ diff --git a/lib/pyparsing/core.py b/lib/pyparsing/core.py index 9acba3f3..73514ed0 100644 --- a/lib/pyparsing/core.py +++ b/lib/pyparsing/core.py @@ -1,19 +1,22 @@ # # core.py # + +from collections import deque import os import typing from typing import ( - NamedTuple, - Union, - Callable, Any, + Callable, Generator, - Tuple, List, - TextIO, - Set, + NamedTuple, Sequence, + Set, + TextIO, + Tuple, + Union, + cast, ) from abc import ABC, abstractmethod from enum import Enum @@ -40,6 +43,7 @@ from .util import ( _flatten, LRUMemo as _LRUMemo, UnboundedMemo as _UnboundedMemo, + replaced_by_pep8, ) from .exceptions import * from .actions import * @@ -134,6 +138,7 @@ class __diag__(__config_flags): class Diagnostics(Enum): """ Diagnostic configuration (all default to disabled) + - 
``warn_multiple_tokens_in_named_alternation`` - flag to enable warnings when a results name is defined on a :class:`MatchFirst` or :class:`Or` expression with one or more :class:`And` subexpressions - ``warn_ungrouped_named_tokens_in_collection`` - flag to enable warnings when a results @@ -228,6 +233,8 @@ _single_arg_builtins = { } _generatorType = types.GeneratorType +ParseImplReturnType = Tuple[int, Any] +PostParseReturnType = Union[ParseResults, Sequence[ParseResults]] ParseAction = Union[ Callable[[], Any], Callable[[ParseResults], Any], @@ -256,7 +263,7 @@ hexnums = nums + "ABCDEFabcdef" alphanums = alphas + nums printables = "".join([c for c in string.printable if c not in string.whitespace]) -_trim_arity_call_line: traceback.StackSummary = None +_trim_arity_call_line: traceback.StackSummary = None # type: ignore[assignment] def _trim_arity(func, max_limit=3): @@ -269,11 +276,6 @@ def _trim_arity(func, max_limit=3): limit = 0 found_arity = False - def extract_tb(tb, limit=0): - frames = traceback.extract_tb(tb, limit=limit) - frame_summary = frames[-1] - return [frame_summary[:2]] - # synthesize what would be returned by traceback.extract_stack at the call to # user's parse action 'func', so that we don't incur call penalty at parse time @@ -297,8 +299,10 @@ def _trim_arity(func, max_limit=3): raise else: tb = te.__traceback__ + frames = traceback.extract_tb(tb, limit=2) + frame_summary = frames[-1] trim_arity_type_error = ( - extract_tb(tb, limit=2)[-1][:2] == pa_call_line_synth + [frame_summary[:2]][-1][:2] == pa_call_line_synth ) del tb @@ -320,7 +324,7 @@ def _trim_arity(func, max_limit=3): def condition_as_parse_action( - fn: ParseCondition, message: str = None, fatal: bool = False + fn: ParseCondition, message: typing.Optional[str] = None, fatal: bool = False ) -> ParseAction: """ Function to convert a simple predicate function that returns ``True`` or ``False`` @@ -353,15 +357,9 @@ def _default_start_debug_action( cache_hit_str = "*" if cache_hit else 
"" print( ( - "{}Match {} at loc {}({},{})\n {}\n {}^".format( - cache_hit_str, - expr, - loc, - lineno(loc, instring), - col(loc, instring), - line(loc, instring), - " " * (col(loc, instring) - 1), - ) + f"{cache_hit_str}Match {expr} at loc {loc}({lineno(loc, instring)},{col(loc, instring)})\n" + f" {line(loc, instring)}\n" + f" {' ' * (col(loc, instring) - 1)}^" ) ) @@ -375,7 +373,7 @@ def _default_success_debug_action( cache_hit: bool = False, ): cache_hit_str = "*" if cache_hit else "" - print("{}Matched {} -> {}".format(cache_hit_str, expr, toks.as_list())) + print(f"{cache_hit_str}Matched {expr} -> {toks.as_list()}") def _default_exception_debug_action( @@ -386,11 +384,7 @@ def _default_exception_debug_action( cache_hit: bool = False, ): cache_hit_str = "*" if cache_hit else "" - print( - "{}Match {} failed, {} raised: {}".format( - cache_hit_str, expr, type(exc).__name__, exc - ) - ) + print(f"{cache_hit_str}Match {expr} failed, {type(exc).__name__} raised: {exc}") def null_debug_action(*args): @@ -402,7 +396,7 @@ class ParserElement(ABC): DEFAULT_WHITE_CHARS: str = " \n\t\r" verbose_stacktrace: bool = False - _literalStringClass: typing.Optional[type] = None + _literalStringClass: type = None # type: ignore[assignment] @staticmethod def set_default_whitespace_chars(chars: str) -> None: @@ -447,6 +441,18 @@ class ParserElement(ABC): """ ParserElement._literalStringClass = cls + @classmethod + def using_each(cls, seq, **class_kwargs): + """ + Yields a sequence of class(obj, **class_kwargs) for obj in seq. 
+ + Example:: + + LPAR, RPAR, LBRACE, RBRACE, SEMI = Suppress.using_each("(){};") + + """ + yield from (cls(obj, **class_kwargs) for obj in seq) + class DebugActions(NamedTuple): debug_try: typing.Optional[DebugStartAction] debug_match: typing.Optional[DebugSuccessAction] @@ -455,9 +461,9 @@ class ParserElement(ABC): def __init__(self, savelist: bool = False): self.parseAction: List[ParseAction] = list() self.failAction: typing.Optional[ParseFailAction] = None - self.customName = None - self._defaultName = None - self.resultsName = None + self.customName: str = None # type: ignore[assignment] + self._defaultName: typing.Optional[str] = None + self.resultsName: str = None # type: ignore[assignment] self.saveAsList = savelist self.skipWhitespace = True self.whiteChars = set(ParserElement.DEFAULT_WHITE_CHARS) @@ -490,12 +496,29 @@ class ParserElement(ABC): base.suppress_warning(Diagnostics.warn_on_parse_using_empty_Forward) # statement would normally raise a warning, but is now suppressed - print(base.parseString("x")) + print(base.parse_string("x")) """ self.suppress_warnings_.append(warning_type) return self + def visit_all(self): + """General-purpose method to yield all expressions and sub-expressions + in a grammar. Typically just for internal use. + """ + to_visit = deque([self]) + seen = set() + while to_visit: + cur = to_visit.popleft() + + # guard against looping forever through recursive grammars + if cur in seen: + continue + seen.add(cur) + + to_visit.extend(cur.recurse()) + yield cur + def copy(self) -> "ParserElement": """ Make a copy of this :class:`ParserElement`. 
Useful for defining @@ -585,11 +608,11 @@ class ParserElement(ABC): pdb.set_trace() return _parseMethod(instring, loc, doActions, callPreParse) - breaker._originalParseMethod = _parseMethod - self._parse = breaker + breaker._originalParseMethod = _parseMethod # type: ignore [attr-defined] + self._parse = breaker # type: ignore [assignment] else: if hasattr(self._parse, "_originalParseMethod"): - self._parse = self._parse._originalParseMethod + self._parse = self._parse._originalParseMethod # type: ignore [attr-defined, assignment] return self def set_parse_action(self, *fns: ParseAction, **kwargs) -> "ParserElement": @@ -601,9 +624,9 @@ class ParserElement(ABC): Each parse action ``fn`` is a callable method with 0-3 arguments, called as ``fn(s, loc, toks)`` , ``fn(loc, toks)`` , ``fn(toks)`` , or just ``fn()`` , where: - - s = the original string being parsed (see note below) - - loc = the location of the matching substring - - toks = a list of the matched tokens, packaged as a :class:`ParseResults` object + - ``s`` = the original string being parsed (see note below) + - ``loc`` = the location of the matching substring + - ``toks`` = a list of the matched tokens, packaged as a :class:`ParseResults` object The parsed tokens are passed to the parse action as ParseResults. They can be modified in place using list-style append, extend, and pop operations to update @@ -621,7 +644,7 @@ class ParserElement(ABC): Optional keyword arguments: - - call_during_try = (default= ``False``) indicate if parse action should be run during + - ``call_during_try`` = (default= ``False``) indicate if parse action should be run during lookaheads and alternate testing. For parse actions that have side effects, it is important to only call the parse action once it is determined that it is being called as part of a successful parse. 
For parse actions that perform additional @@ -697,10 +720,10 @@ class ParserElement(ABC): Optional keyword arguments: - - message = define a custom message to be used in the raised exception - - fatal = if True, will raise ParseFatalException to stop parsing immediately; otherwise will raise + - ``message`` = define a custom message to be used in the raised exception + - ``fatal`` = if True, will raise ParseFatalException to stop parsing immediately; otherwise will raise ParseException - - call_during_try = boolean to indicate if this method should be called during internal tryParse calls, + - ``call_during_try`` = boolean to indicate if this method should be called during internal tryParse calls, default=False Example:: @@ -716,7 +739,9 @@ class ParserElement(ABC): for fn in fns: self.parseAction.append( condition_as_parse_action( - fn, message=kwargs.get("message"), fatal=kwargs.get("fatal", False) + fn, + message=str(kwargs.get("message")), + fatal=bool(kwargs.get("fatal", False)), ) ) @@ -731,30 +756,38 @@ class ParserElement(ABC): Fail acton fn is a callable function that takes the arguments ``fn(s, loc, expr, err)`` where: - - s = string being parsed - - loc = location where expression match was attempted and failed - - expr = the parse expression that failed - - err = the exception thrown + - ``s`` = string being parsed + - ``loc`` = location where expression match was attempted and failed + - ``expr`` = the parse expression that failed + - ``err`` = the exception thrown The function returns no value. 
It may throw :class:`ParseFatalException` if it is desired to stop parsing immediately.""" self.failAction = fn return self - def _skipIgnorables(self, instring, loc): + def _skipIgnorables(self, instring: str, loc: int) -> int: + if not self.ignoreExprs: + return loc exprsFound = True + ignore_expr_fns = [e._parse for e in self.ignoreExprs] + last_loc = loc while exprsFound: exprsFound = False - for e in self.ignoreExprs: + for ignore_fn in ignore_expr_fns: try: while 1: - loc, dummy = e._parse(instring, loc) + loc, dummy = ignore_fn(instring, loc) exprsFound = True except ParseException: pass + # check if all ignore exprs matched but didn't actually advance the parse location + if loc == last_loc: + break + last_loc = loc return loc - def preParse(self, instring, loc): + def preParse(self, instring: str, loc: int) -> int: if self.ignoreExprs: loc = self._skipIgnorables(instring, loc) @@ -830,7 +863,7 @@ class ParserElement(ABC): try: for fn in self.parseAction: try: - tokens = fn(instring, tokens_start, ret_tokens) + tokens = fn(instring, tokens_start, ret_tokens) # type: ignore [call-arg, arg-type] except IndexError as parse_action_exc: exc = ParseException("exception raised in parse action") raise exc from parse_action_exc @@ -853,7 +886,7 @@ class ParserElement(ABC): else: for fn in self.parseAction: try: - tokens = fn(instring, tokens_start, ret_tokens) + tokens = fn(instring, tokens_start, ret_tokens) # type: ignore [call-arg, arg-type] except IndexError as parse_action_exc: exc = ParseException("exception raised in parse action") raise exc from parse_action_exc @@ -875,17 +908,24 @@ class ParserElement(ABC): return loc, ret_tokens - def try_parse(self, instring: str, loc: int, raise_fatal: bool = False) -> int: + def try_parse( + self, + instring: str, + loc: int, + *, + raise_fatal: bool = False, + do_actions: bool = False, + ) -> int: try: - return self._parse(instring, loc, doActions=False)[0] + return self._parse(instring, loc, doActions=do_actions)[0] 
except ParseFatalException: if raise_fatal: raise raise ParseException(instring, loc, self.errmsg, self) - def can_parse_next(self, instring: str, loc: int) -> bool: + def can_parse_next(self, instring: str, loc: int, do_actions: bool = False) -> bool: try: - self.try_parse(instring, loc) + self.try_parse(instring, loc, do_actions=do_actions) except (ParseException, IndexError): return False else: @@ -897,10 +937,23 @@ class ParserElement(ABC): Tuple[int, "Forward", bool], Tuple[int, Union[ParseResults, Exception]] ] = {} + class _CacheType(dict): + """ + class to help type checking + """ + + not_in_cache: bool + + def get(self, *args): + ... + + def set(self, *args): + ... + # argument cache for optimizing repeated calls when backtracking through recursive expressions packrat_cache = ( - {} - ) # this is set later by enabled_packrat(); this is here so that reset_cache() doesn't fail + _CacheType() + ) # set later by enable_packrat(); this is here so that reset_cache() doesn't fail packrat_cache_lock = RLock() packrat_cache_stats = [0, 0] @@ -930,24 +983,25 @@ class ParserElement(ABC): ParserElement.packrat_cache_stats[HIT] += 1 if self.debug and self.debugActions.debug_try: try: - self.debugActions.debug_try(instring, loc, self, cache_hit=True) + self.debugActions.debug_try(instring, loc, self, cache_hit=True) # type: ignore [call-arg] except TypeError: pass if isinstance(value, Exception): if self.debug and self.debugActions.debug_fail: try: self.debugActions.debug_fail( - instring, loc, self, value, cache_hit=True + instring, loc, self, value, cache_hit=True # type: ignore [call-arg] ) except TypeError: pass raise value + value = cast(Tuple[int, ParseResults, int], value) loc_, result, endloc = value[0], value[1].copy(), value[2] if self.debug and self.debugActions.debug_match: try: self.debugActions.debug_match( - instring, loc_, endloc, self, result, cache_hit=True + instring, loc_, endloc, self, result, cache_hit=True # type: ignore [call-arg] ) except 
TypeError: pass @@ -1009,7 +1063,7 @@ class ParserElement(ABC): Parameters: - - cache_size_limit - (default=``None``) - memoize at most this many + - ``cache_size_limit`` - (default=``None``) - memoize at most this many ``Forward`` elements during matching; if ``None`` (the default), memoize all ``Forward`` elements. @@ -1022,15 +1076,17 @@ class ParserElement(ABC): elif ParserElement._packratEnabled: raise RuntimeError("Packrat and Bounded Recursion are not compatible") if cache_size_limit is None: - ParserElement.recursion_memos = _UnboundedMemo() + ParserElement.recursion_memos = _UnboundedMemo() # type: ignore[assignment] elif cache_size_limit > 0: - ParserElement.recursion_memos = _LRUMemo(capacity=cache_size_limit) + ParserElement.recursion_memos = _LRUMemo(capacity=cache_size_limit) # type: ignore[assignment] else: raise NotImplementedError("Memo size of %s" % cache_size_limit) ParserElement._left_recursion_enabled = True @staticmethod - def enable_packrat(cache_size_limit: int = 128, *, force: bool = False) -> None: + def enable_packrat( + cache_size_limit: Union[int, None] = 128, *, force: bool = False + ) -> None: """ Enables "packrat" parsing, which adds memoizing to the parsing logic. Repeated parse attempts at the same string location (which happens @@ -1040,7 +1096,7 @@ class ParserElement(ABC): Parameters: - - cache_size_limit - (default= ``128``) - if an integer value is provided + - ``cache_size_limit`` - (default= ``128``) - if an integer value is provided will limit the size of the packrat cache; if None is passed, then the cache size will be unbounded; if 0 is passed, the cache will be effectively disabled. 
@@ -1070,7 +1126,7 @@ class ParserElement(ABC): if cache_size_limit is None: ParserElement.packrat_cache = _UnboundedCache() else: - ParserElement.packrat_cache = _FifoCache(cache_size_limit) + ParserElement.packrat_cache = _FifoCache(cache_size_limit) # type: ignore[assignment] ParserElement._parse = ParserElement._parseCache def parse_string( @@ -1088,7 +1144,7 @@ class ParserElement(ABC): an object with attributes if the given parser includes results names. If the input string is required to match the entire grammar, ``parse_all`` flag must be set to ``True``. This - is also equivalent to ending the grammar with :class:`StringEnd`(). + is also equivalent to ending the grammar with :class:`StringEnd`\\ (). To report proper column numbers, ``parse_string`` operates on a copy of the input string where all tabs are converted to spaces (8 spaces per tab, as per the default in ``string.expandtabs``). If the input string @@ -1198,7 +1254,9 @@ class ParserElement(ABC): try: while loc <= instrlen and matches < maxMatches: try: - preloc = preparseFn(instring, loc) + preloc: int = preparseFn(instring, loc) + nextLoc: int + tokens: ParseResults nextLoc, tokens = parseFn(instring, preloc, callPreParse=False) except ParseException: loc = preloc + 1 @@ -1352,7 +1410,7 @@ class ParserElement(ABC): def __add__(self, other) -> "ParserElement": """ Implementation of ``+`` operator - returns :class:`And`. Adding strings to a :class:`ParserElement` - converts them to :class:`Literal`s by default. + converts them to :class:`Literal`\\ s by default. Example:: @@ -1364,11 +1422,11 @@ class ParserElement(ABC): Hello, World! -> ['Hello', ',', 'World', '!'] - ``...`` may be used as a parse expression as a short form of :class:`SkipTo`. + ``...`` may be used as a parse expression as a short form of :class:`SkipTo`:: Literal('start') + ... 
+ Literal('end') - is equivalent to: + is equivalent to:: Literal('start') + SkipTo('end')("_skipped*") + Literal('end') @@ -1382,11 +1440,7 @@ class ParserElement(ABC): if isinstance(other, str_type): other = self._literalStringClass(other) if not isinstance(other, ParserElement): - raise TypeError( - "Cannot combine element of type {} with ParserElement".format( - type(other).__name__ - ) - ) + return NotImplemented return And([self, other]) def __radd__(self, other) -> "ParserElement": @@ -1399,11 +1453,7 @@ class ParserElement(ABC): if isinstance(other, str_type): other = self._literalStringClass(other) if not isinstance(other, ParserElement): - raise TypeError( - "Cannot combine element of type {} with ParserElement".format( - type(other).__name__ - ) - ) + return NotImplemented return other + self def __sub__(self, other) -> "ParserElement": @@ -1413,11 +1463,7 @@ class ParserElement(ABC): if isinstance(other, str_type): other = self._literalStringClass(other) if not isinstance(other, ParserElement): - raise TypeError( - "Cannot combine element of type {} with ParserElement".format( - type(other).__name__ - ) - ) + return NotImplemented return self + And._ErrorStop() + other def __rsub__(self, other) -> "ParserElement": @@ -1427,11 +1473,7 @@ class ParserElement(ABC): if isinstance(other, str_type): other = self._literalStringClass(other) if not isinstance(other, ParserElement): - raise TypeError( - "Cannot combine element of type {} with ParserElement".format( - type(other).__name__ - ) - ) + return NotImplemented return other - self def __mul__(self, other) -> "ParserElement": @@ -1440,11 +1482,12 @@ class ParserElement(ABC): ``expr + expr + expr``. Expressions may also be multiplied by a 2-integer tuple, similar to ``{min, max}`` multipliers in regular expressions. 
Tuples may also include ``None`` as in: + - ``expr*(n, None)`` or ``expr*(n, )`` is equivalent - to ``expr*n + ZeroOrMore(expr)`` - (read as "at least n instances of ``expr``") + to ``expr*n + ZeroOrMore(expr)`` + (read as "at least n instances of ``expr``") - ``expr*(None, n)`` is equivalent to ``expr*(0, n)`` - (read as "0 to n instances of ``expr``") + (read as "0 to n instances of ``expr``") - ``expr*(None, None)`` is equivalent to ``ZeroOrMore(expr)`` - ``expr*(1, None)`` is equivalent to ``OneOrMore(expr)`` @@ -1477,17 +1520,9 @@ class ParserElement(ABC): minElements, optElements = other optElements -= minElements else: - raise TypeError( - "cannot multiply ParserElement and ({}) objects".format( - ",".join(type(item).__name__ for item in other) - ) - ) + return NotImplemented else: - raise TypeError( - "cannot multiply ParserElement and {} objects".format( - type(other).__name__ - ) - ) + return NotImplemented if minElements < 0: raise ValueError("cannot multiply ParserElement by negative value") @@ -1531,13 +1566,12 @@ class ParserElement(ABC): return _PendingSkip(self, must_skip=True) if isinstance(other, str_type): + # `expr | ""` is equivalent to `Opt(expr)` + if other == "": + return Opt(self) other = self._literalStringClass(other) if not isinstance(other, ParserElement): - raise TypeError( - "Cannot combine element of type {} with ParserElement".format( - type(other).__name__ - ) - ) + return NotImplemented return MatchFirst([self, other]) def __ror__(self, other) -> "ParserElement": @@ -1547,11 +1581,7 @@ class ParserElement(ABC): if isinstance(other, str_type): other = self._literalStringClass(other) if not isinstance(other, ParserElement): - raise TypeError( - "Cannot combine element of type {} with ParserElement".format( - type(other).__name__ - ) - ) + return NotImplemented return other | self def __xor__(self, other) -> "ParserElement": @@ -1561,11 +1591,7 @@ class ParserElement(ABC): if isinstance(other, str_type): other = 
self._literalStringClass(other) if not isinstance(other, ParserElement): - raise TypeError( - "Cannot combine element of type {} with ParserElement".format( - type(other).__name__ - ) - ) + return NotImplemented return Or([self, other]) def __rxor__(self, other) -> "ParserElement": @@ -1575,11 +1601,7 @@ class ParserElement(ABC): if isinstance(other, str_type): other = self._literalStringClass(other) if not isinstance(other, ParserElement): - raise TypeError( - "Cannot combine element of type {} with ParserElement".format( - type(other).__name__ - ) - ) + return NotImplemented return other ^ self def __and__(self, other) -> "ParserElement": @@ -1589,11 +1611,7 @@ class ParserElement(ABC): if isinstance(other, str_type): other = self._literalStringClass(other) if not isinstance(other, ParserElement): - raise TypeError( - "Cannot combine element of type {} with ParserElement".format( - type(other).__name__ - ) - ) + return NotImplemented return Each([self, other]) def __rand__(self, other) -> "ParserElement": @@ -1603,11 +1621,7 @@ class ParserElement(ABC): if isinstance(other, str_type): other = self._literalStringClass(other) if not isinstance(other, ParserElement): - raise TypeError( - "Cannot combine element of type {} with ParserElement".format( - type(other).__name__ - ) - ) + return NotImplemented return other & self def __invert__(self) -> "ParserElement": @@ -1636,38 +1650,58 @@ class ParserElement(ABC): ``None`` may be used in place of ``...``. - Note that ``expr[..., n]`` and ``expr[m, n]``do not raise an exception - if more than ``n`` ``expr``s exist in the input stream. If this behavior is + Note that ``expr[..., n]`` and ``expr[m, n]`` do not raise an exception + if more than ``n`` ``expr``\\ s exist in the input stream. If this behavior is desired, then write ``expr[..., n] + ~expr``. 
+ + For repetition with a stop_on expression, use slice notation: + + - ``expr[...: end_expr]`` and ``expr[0, ...: end_expr]`` are equivalent to ``ZeroOrMore(expr, stop_on=end_expr)`` + - ``expr[1, ...: end_expr]`` is equivalent to ``OneOrMore(expr, stop_on=end_expr)`` + """ + stop_on_defined = False + stop_on = NoMatch() + if isinstance(key, slice): + key, stop_on = key.start, key.stop + if key is None: + key = ... + stop_on_defined = True + elif isinstance(key, tuple) and isinstance(key[-1], slice): + key, stop_on = (key[0], key[1].start), key[1].stop + stop_on_defined = True + # convert single arg keys to tuples + if isinstance(key, str_type): + key = (key,) try: - if isinstance(key, str_type): - key = (key,) iter(key) except TypeError: key = (key, key) if len(key) > 2: raise TypeError( - "only 1 or 2 index arguments supported ({}{})".format( - key[:5], "... [{}]".format(len(key)) if len(key) > 5 else "" - ) + f"only 1 or 2 index arguments supported ({key[:5]}{f'... [{len(key)}]' if len(key) > 5 else ''})" ) # clip to 2 elements ret = self * tuple(key[:2]) + ret = typing.cast(_MultipleMatch, ret) + + if stop_on_defined: + ret.stopOn(stop_on) + return ret - def __call__(self, name: str = None) -> "ParserElement": + def __call__(self, name: typing.Optional[str] = None) -> "ParserElement": """ Shortcut for :class:`set_results_name`, with ``list_all_matches=False``. If ``name`` is given with a trailing ``'*'`` character, then ``list_all_matches`` will be passed as ``True``. - If ``name` is omitted, same as calling :class:`copy`. + If ``name`` is omitted, same as calling :class:`copy`. 
Example:: @@ -1775,17 +1809,18 @@ class ParserElement(ABC): should have the signature ``fn(input_string: str, location: int, expression: ParserElement, exception: Exception, cache_hit: bool)`` """ self.debugActions = self.DebugActions( - start_action or _default_start_debug_action, - success_action or _default_success_debug_action, - exception_action or _default_exception_debug_action, + start_action or _default_start_debug_action, # type: ignore[truthy-function] + success_action or _default_success_debug_action, # type: ignore[truthy-function] + exception_action or _default_exception_debug_action, # type: ignore[truthy-function] ) self.debug = True return self - def set_debug(self, flag: bool = True) -> "ParserElement": + def set_debug(self, flag: bool = True, recurse: bool = False) -> "ParserElement": """ Enable display of debugging messages while doing pattern matching. Set ``flag`` to ``True`` to enable, ``False`` to disable. + Set ``recurse`` to ``True`` to set the debug flag on this expression and all sub-expressions. Example:: @@ -1819,6 +1854,11 @@ class ParserElement(ABC): which makes debugging and exception messages easier to understand - for instance, the default name created for the :class:`Word` expression without calling ``set_name`` is ``"W:(A-Za-z)"``. """ + if recurse: + for expr in self.visit_all(): + expr.set_debug(flag, recurse=False) + return self + if flag: self.set_debug_actions( _default_start_debug_action, @@ -1836,7 +1876,7 @@ class ParserElement(ABC): return self._defaultName @abstractmethod - def _generateDefaultName(self): + def _generateDefaultName(self) -> str: """ Child classes must define this method, which defines how the ``default_name`` is set. """ @@ -1844,7 +1884,9 @@ class ParserElement(ABC): def set_name(self, name: str) -> "ParserElement": """ Define name for this expression, makes debugging and exception messages clearer. 
+ Example:: + Word(nums).parse_string("ABC") # -> Exception: Expected W:(0-9) (at char 0), (line:1, col:1) Word(nums).set_name("integer").parse_string("ABC") # -> Exception: Expected integer (at char 0), (line:1, col:1) """ @@ -1870,7 +1912,7 @@ class ParserElement(ABC): self._defaultName = None return self - def recurse(self) -> Sequence["ParserElement"]: + def recurse(self) -> List["ParserElement"]: return [] def _checkRecursion(self, parseElementList): @@ -1882,6 +1924,11 @@ class ParserElement(ABC): """ Check defined expressions for valid structure, check for infinite recursive definitions. """ + warnings.warn( + "ParserElement.validate() is deprecated, and should not be used to check for left recursion", + DeprecationWarning, + stacklevel=2, + ) self._checkRecursion([]) def parse_file( @@ -1899,8 +1946,10 @@ class ParserElement(ABC): """ parseAll = parseAll or parse_all try: + file_or_filename = typing.cast(TextIO, file_or_filename) file_contents = file_or_filename.read() except AttributeError: + file_or_filename = typing.cast(str, file_or_filename) with open(file_or_filename, "r", encoding=encoding) as f: file_contents = f.read() try: @@ -1932,6 +1981,7 @@ class ParserElement(ABC): inline microtests of sub expressions while building up larger parser. 
Parameters: + - ``test_string`` - to test against this expression for a match - ``parse_all`` - (default= ``True``) - flag to pass to :class:`parse_string` when running tests @@ -1955,7 +2005,7 @@ class ParserElement(ABC): full_dump: bool = True, print_results: bool = True, failure_tests: bool = False, - post_parse: Callable[[str, ParseResults], str] = None, + post_parse: typing.Optional[Callable[[str, ParseResults], str]] = None, file: typing.Optional[TextIO] = None, with_line_numbers: bool = False, *, @@ -1963,7 +2013,7 @@ class ParserElement(ABC): fullDump: bool = True, printResults: bool = True, failureTests: bool = False, - postParse: Callable[[str, ParseResults], str] = None, + postParse: typing.Optional[Callable[[str, ParseResults], str]] = None, ) -> Tuple[bool, List[Tuple[str, Union[ParseResults, Exception]]]]: """ Execute the parse expression on a series of test strings, showing each @@ -1971,6 +2021,7 @@ class ParserElement(ABC): run a parse expression against a list of sample strings. 
Parameters: + - ``tests`` - a list of separate test strings, or a multiline string of test strings - ``parse_all`` - (default= ``True``) - flag to pass to :class:`parse_string` when running tests - ``comment`` - (default= ``'#'``) - expression for indicating embedded comments in the test @@ -2067,22 +2118,27 @@ class ParserElement(ABC): failureTests = failureTests or failure_tests postParse = postParse or post_parse if isinstance(tests, str_type): + tests = typing.cast(str, tests) line_strip = type(tests).strip tests = [line_strip(test_line) for test_line in tests.rstrip().splitlines()] - if isinstance(comment, str_type): - comment = Literal(comment) + comment_specified = comment is not None + if comment_specified: + if isinstance(comment, str_type): + comment = typing.cast(str, comment) + comment = Literal(comment) + comment = typing.cast(ParserElement, comment) if file is None: file = sys.stdout print_ = file.write result: Union[ParseResults, Exception] - allResults = [] - comments = [] + allResults: List[Tuple[str, Union[ParseResults, Exception]]] = [] + comments: List[str] = [] success = True NL = Literal(r"\n").add_parse_action(replace_with("\n")).ignore(quoted_string) BOM = "\ufeff" for t in tests: - if comment is not None and comment.matches(t, False) or comments and not t: + if comment_specified and comment.matches(t, False) or comments and not t: comments.append( pyparsing_test.with_line_numbers(t) if with_line_numbers else t ) @@ -2107,7 +2163,7 @@ class ParserElement(ABC): success = success and failureTests result = pe except Exception as exc: - out.append("FAIL-EXCEPTION: {}: {}".format(type(exc).__name__, exc)) + out.append(f"FAIL-EXCEPTION: {type(exc).__name__}: {exc}") if ParserElement.verbose_stacktrace: out.extend(traceback.format_tb(exc.__traceback__)) success = success and failureTests @@ -2127,9 +2183,7 @@ class ParserElement(ABC): except Exception as e: out.append(result.dump(full=fullDump)) out.append( - "{} failed: {}: {}".format( - 
postParse.__name__, type(e).__name__, e - ) + f"{postParse.__name__} failed: {type(e).__name__}: {e}" ) else: out.append(result.dump(full=fullDump)) @@ -2148,19 +2202,28 @@ class ParserElement(ABC): vertical: int = 3, show_results_names: bool = False, show_groups: bool = False, + embed: bool = False, **kwargs, ) -> None: """ Create a railroad diagram for the parser. Parameters: - - output_html (str or file-like object) - output target for generated + + - ``output_html`` (str or file-like object) - output target for generated diagram HTML - - vertical (int) - threshold for formatting multiple alternatives vertically + - ``vertical`` (int) - threshold for formatting multiple alternatives vertically instead of horizontally (default=3) - - show_results_names - bool flag whether diagram should show annotations for + - ``show_results_names`` - bool flag whether diagram should show annotations for defined results names - - show_groups - bool flag whether groups should be highlighted with an unlabeled surrounding box + - ``show_groups`` - bool flag whether groups should be highlighted with an unlabeled surrounding box + - ``embed`` - bool flag whether generated HTML should omit , , and tags to embed + the resulting HTML in an enclosing HTML source + - ``head`` - str containing additional HTML to insert into the section of the generated code; + can be used to insert custom CSS styling + - ``body`` - str containing additional HTML to insert at the beginning of the section of the + generated code + Additional diagram-formatting keyword arguments can also be included; see railroad.Diagram class. 
""" @@ -2183,38 +2246,93 @@ class ParserElement(ABC): ) if isinstance(output_html, (str, Path)): with open(output_html, "w", encoding="utf-8") as diag_file: - diag_file.write(railroad_to_html(railroad)) + diag_file.write(railroad_to_html(railroad, embed=embed, **kwargs)) else: # we were passed a file-like object, just write to it - output_html.write(railroad_to_html(railroad)) + output_html.write(railroad_to_html(railroad, embed=embed, **kwargs)) + + # Compatibility synonyms + # fmt: off + @staticmethod + @replaced_by_pep8(inline_literals_using) + def inlineLiteralsUsing(): ... + + @staticmethod + @replaced_by_pep8(set_default_whitespace_chars) + def setDefaultWhitespaceChars(): ... + + @replaced_by_pep8(set_results_name) + def setResultsName(self): ... + + @replaced_by_pep8(set_break) + def setBreak(self): ... + + @replaced_by_pep8(set_parse_action) + def setParseAction(self): ... + + @replaced_by_pep8(add_parse_action) + def addParseAction(self): ... + + @replaced_by_pep8(add_condition) + def addCondition(self): ... + + @replaced_by_pep8(set_fail_action) + def setFailAction(self): ... + + @replaced_by_pep8(try_parse) + def tryParse(self): ... + + @staticmethod + @replaced_by_pep8(enable_left_recursion) + def enableLeftRecursion(): ... + + @staticmethod + @replaced_by_pep8(enable_packrat) + def enablePackrat(): ... + + @replaced_by_pep8(parse_string) + def parseString(self): ... + + @replaced_by_pep8(scan_string) + def scanString(self): ... + + @replaced_by_pep8(transform_string) + def transformString(self): ... + + @replaced_by_pep8(search_string) + def searchString(self): ... + + @replaced_by_pep8(ignore_whitespace) + def ignoreWhitespace(self): ... + + @replaced_by_pep8(leave_whitespace) + def leaveWhitespace(self): ... + + @replaced_by_pep8(set_whitespace_chars) + def setWhitespaceChars(self): ... + + @replaced_by_pep8(parse_with_tabs) + def parseWithTabs(self): ... + + @replaced_by_pep8(set_debug_actions) + def setDebugActions(self): ... 
+ + @replaced_by_pep8(set_debug) + def setDebug(self): ... + + @replaced_by_pep8(set_name) + def setName(self): ... + + @replaced_by_pep8(parse_file) + def parseFile(self): ... + + @replaced_by_pep8(run_tests) + def runTests(self): ... - setDefaultWhitespaceChars = set_default_whitespace_chars - inlineLiteralsUsing = inline_literals_using - setResultsName = set_results_name - setBreak = set_break - setParseAction = set_parse_action - addParseAction = add_parse_action - addCondition = add_condition - setFailAction = set_fail_action - tryParse = try_parse canParseNext = can_parse_next resetCache = reset_cache - enableLeftRecursion = enable_left_recursion - enablePackrat = enable_packrat - parseString = parse_string - scanString = scan_string - searchString = search_string - transformString = transform_string - setWhitespaceChars = set_whitespace_chars - parseWithTabs = parse_with_tabs - setDebugActions = set_debug_actions - setDebug = set_debug defaultName = default_name - setName = set_name - parseFile = parse_file - runTests = run_tests - ignoreWhitespace = ignore_whitespace - leaveWhitespace = leave_whitespace + # fmt: on class _PendingSkip(ParserElement): @@ -2225,7 +2343,7 @@ class _PendingSkip(ParserElement): self.anchor = expr self.must_skip = must_skip - def _generateDefaultName(self): + def _generateDefaultName(self) -> str: return str(self.anchor + Empty()).replace("Empty", "...") def __add__(self, other) -> "ParserElement": @@ -2266,21 +2384,10 @@ class Token(ParserElement): def __init__(self): super().__init__(savelist=False) - def _generateDefaultName(self): + def _generateDefaultName(self) -> str: return type(self).__name__ -class Empty(Token): - """ - An empty token, will always match. - """ - - def __init__(self): - super().__init__() - self.mayReturnEmpty = True - self.mayIndexError = False - - class NoMatch(Token): """ A token that will never match. @@ -2312,25 +2419,33 @@ class Literal(Token): use :class:`Keyword` or :class:`CaselessKeyword`. 
""" + def __new__(cls, match_string: str = "", *, matchString: str = ""): + # Performance tuning: select a subclass with optimized parseImpl + if cls is Literal: + match_string = matchString or match_string + if not match_string: + return super().__new__(Empty) + if len(match_string) == 1: + return super().__new__(_SingleCharLiteral) + + # Default behavior + return super().__new__(cls) + + # Needed to make copy.copy() work correctly if we customize __new__ + def __getnewargs__(self): + return (self.match,) + def __init__(self, match_string: str = "", *, matchString: str = ""): super().__init__() match_string = matchString or match_string self.match = match_string self.matchLen = len(match_string) - try: - self.firstMatchChar = match_string[0] - except IndexError: - raise ValueError("null string passed to Literal; use Empty() instead") + self.firstMatchChar = match_string[:1] self.errmsg = "Expected " + self.name self.mayReturnEmpty = False self.mayIndexError = False - # Performance tuning: modify __class__ to select - # a parseImpl optimized for single-character check - if self.matchLen == 1 and type(self) is Literal: - self.__class__ = _SingleCharLiteral - - def _generateDefaultName(self): + def _generateDefaultName(self) -> str: return repr(self.match) def parseImpl(self, instring, loc, doActions=True): @@ -2341,6 +2456,23 @@ class Literal(Token): raise ParseException(instring, loc, self.errmsg, self) +class Empty(Literal): + """ + An empty token, will always match. 
+ """ + + def __init__(self, match_string="", *, matchString=""): + super().__init__("") + self.mayReturnEmpty = True + self.mayIndexError = False + + def _generateDefaultName(self) -> str: + return "Empty" + + def parseImpl(self, instring, loc, doActions=True): + return loc, [] + + class _SingleCharLiteral(Literal): def parseImpl(self, instring, loc, doActions=True): if instring[loc] == self.firstMatchChar: @@ -2354,8 +2486,8 @@ ParserElement._literalStringClass = Literal class Keyword(Token): """ Token to exactly match a specified string as a keyword, that is, - it must be immediately followed by a non-keyword character. Compare - with :class:`Literal`: + it must be immediately preceded and followed by whitespace or + non-keyword characters. Compare with :class:`Literal`: - ``Literal("if")`` will match the leading ``'if'`` in ``'ifAndOnlyIf'``. @@ -2365,7 +2497,7 @@ class Keyword(Token): Accepts two optional constructor arguments in addition to the keyword string: - - ``identChars`` is a string of characters that would be valid + - ``ident_chars`` is a string of characters that would be valid identifier characters, defaulting to all alphanumerics + "_" and "$" - ``caseless`` allows case-insensitive matching, default is ``False``. 
@@ -2400,7 +2532,7 @@ class Keyword(Token): self.firstMatchChar = match_string[0] except IndexError: raise ValueError("null string passed to Keyword; use Empty() instead") - self.errmsg = "Expected {} {}".format(type(self).__name__, self.name) + self.errmsg = f"Expected {type(self).__name__} {self.name}" self.mayReturnEmpty = False self.mayIndexError = False self.caseless = caseless @@ -2409,7 +2541,7 @@ class Keyword(Token): identChars = identChars.upper() self.identChars = set(identChars) - def _generateDefaultName(self): + def _generateDefaultName(self) -> str: return repr(self.match) def parseImpl(self, instring, loc, doActions=True): @@ -2559,7 +2691,7 @@ class CloseMatch(Token): def __init__( self, match_string: str, - max_mismatches: int = None, + max_mismatches: typing.Optional[int] = None, *, maxMismatches: int = 1, caseless=False, @@ -2568,15 +2700,13 @@ class CloseMatch(Token): super().__init__() self.match_string = match_string self.maxMismatches = maxMismatches - self.errmsg = "Expected {!r} (with up to {} mismatches)".format( - self.match_string, self.maxMismatches - ) + self.errmsg = f"Expected {self.match_string!r} (with up to {self.maxMismatches} mismatches)" self.caseless = caseless self.mayIndexError = False self.mayReturnEmpty = False - def _generateDefaultName(self): - return "{}:{!r}".format(type(self).__name__, self.match_string) + def _generateDefaultName(self) -> str: + return f"{type(self).__name__}:{self.match_string!r}" def parseImpl(self, instring, loc, doActions=True): start = loc @@ -2612,7 +2742,9 @@ class CloseMatch(Token): class Word(Token): """Token for matching words composed of allowed character sets. 
+ Parameters: + - ``init_chars`` - string of all characters that should be used to match as a word; "ABC" will match "AAA", "ABAB", "CBAC", etc.; if ``body_chars`` is also specified, then this is the string of @@ -2697,26 +2829,24 @@ class Word(Token): super().__init__() if not initChars: raise ValueError( - "invalid {}, initChars cannot be empty string".format( - type(self).__name__ - ) + f"invalid {type(self).__name__}, initChars cannot be empty string" ) - initChars = set(initChars) - self.initChars = initChars + initChars_set = set(initChars) if excludeChars: - excludeChars = set(excludeChars) - initChars -= excludeChars + excludeChars_set = set(excludeChars) + initChars_set -= excludeChars_set if bodyChars: - bodyChars = set(bodyChars) - excludeChars - self.initCharsOrig = "".join(sorted(initChars)) + bodyChars = "".join(set(bodyChars) - excludeChars_set) + self.initChars = initChars_set + self.initCharsOrig = "".join(sorted(initChars_set)) if bodyChars: - self.bodyCharsOrig = "".join(sorted(bodyChars)) self.bodyChars = set(bodyChars) + self.bodyCharsOrig = "".join(sorted(bodyChars)) else: - self.bodyCharsOrig = "".join(sorted(initChars)) - self.bodyChars = set(initChars) + self.bodyChars = initChars_set + self.bodyCharsOrig = self.initCharsOrig self.maxSpecified = max > 0 @@ -2725,6 +2855,11 @@ class Word(Token): "cannot specify a minimum length < 1; use Opt(Word()) if zero-length word is permitted" ) + if self.maxSpecified and min > max: + raise ValueError( + f"invalid args, if min and max both specified min must be <= max (min={min}, max={max})" + ) + self.minLen = min if max > 0: @@ -2733,62 +2868,64 @@ class Word(Token): self.maxLen = _MAX_INT if exact > 0: + min = max = exact self.maxLen = exact self.minLen = exact self.errmsg = "Expected " + self.name self.mayIndexError = False self.asKeyword = asKeyword + if self.asKeyword: + self.errmsg += " as a keyword" # see if we can make a regex for this Word - if " " not in self.initChars | self.bodyChars and 
(min == 1 and exact == 0): + if " " not in (self.initChars | self.bodyChars): + if len(self.initChars) == 1: + re_leading_fragment = re.escape(self.initCharsOrig) + else: + re_leading_fragment = f"[{_collapse_string_to_ranges(self.initChars)}]" + if self.bodyChars == self.initChars: - if max == 0: + if max == 0 and self.minLen == 1: repeat = "+" elif max == 1: repeat = "" else: - repeat = "{{{},{}}}".format( - self.minLen, "" if self.maxLen == _MAX_INT else self.maxLen - ) - self.reString = "[{}]{}".format( - _collapse_string_to_ranges(self.initChars), - repeat, - ) - elif len(self.initChars) == 1: - if max == 0: - repeat = "*" - else: - repeat = "{{0,{}}}".format(max - 1) - self.reString = "{}[{}]{}".format( - re.escape(self.initCharsOrig), - _collapse_string_to_ranges(self.bodyChars), - repeat, - ) + if self.minLen != self.maxLen: + repeat = f"{{{self.minLen},{'' if self.maxLen == _MAX_INT else self.maxLen}}}" + else: + repeat = f"{{{self.minLen}}}" + self.reString = f"{re_leading_fragment}{repeat}" else: - if max == 0: - repeat = "*" - elif max == 2: + if max == 1: + re_body_fragment = "" repeat = "" else: - repeat = "{{0,{}}}".format(max - 1) - self.reString = "[{}][{}]{}".format( - _collapse_string_to_ranges(self.initChars), - _collapse_string_to_ranges(self.bodyChars), - repeat, - ) + re_body_fragment = f"[{_collapse_string_to_ranges(self.bodyChars)}]" + if max == 0 and self.minLen == 1: + repeat = "*" + elif max == 2: + repeat = "?" 
if min <= 1 else "" + else: + if min != max: + repeat = f"{{{min - 1 if min > 0 else ''},{max - 1 if max > 0 else ''}}}" + else: + repeat = f"{{{min - 1 if min > 0 else ''}}}" + + self.reString = f"{re_leading_fragment}{re_body_fragment}{repeat}" + if self.asKeyword: - self.reString = r"\b" + self.reString + r"\b" + self.reString = rf"\b{self.reString}\b" try: self.re = re.compile(self.reString) except re.error: - self.re = None + self.re = None # type: ignore[assignment] else: self.re_match = self.re.match - self.__class__ = _WordRegex + self.parseImpl = self.parseImpl_regex # type: ignore[assignment] - def _generateDefaultName(self): + def _generateDefaultName(self) -> str: def charsAsStr(s): max_repr_len = 16 s = _collapse_string_to_ranges(s, re_escape=False) @@ -2798,11 +2935,9 @@ class Word(Token): return s if self.initChars != self.bodyChars: - base = "W:({}, {})".format( - charsAsStr(self.initChars), charsAsStr(self.bodyChars) - ) + base = f"W:({charsAsStr(self.initChars)}, {charsAsStr(self.bodyChars)})" else: - base = "W:({})".format(charsAsStr(self.initChars)) + base = f"W:({charsAsStr(self.initChars)})" # add length specification if self.minLen > 1 or self.maxLen != _MAX_INT: @@ -2810,11 +2945,11 @@ class Word(Token): if self.minLen == 1: return base[2:] else: - return base + "{{{}}}".format(self.minLen) + return base + f"{{{self.minLen}}}" elif self.maxLen == _MAX_INT: - return base + "{{{},...}}".format(self.minLen) + return base + f"{{{self.minLen},...}}" else: - return base + "{{{},{}}}".format(self.minLen, self.maxLen) + return base + f"{{{self.minLen},{self.maxLen}}}" return base def parseImpl(self, instring, loc, doActions=True): @@ -2849,9 +2984,7 @@ class Word(Token): return loc, instring[start:loc] - -class _WordRegex(Word): - def parseImpl(self, instring, loc, doActions=True): + def parseImpl_regex(self, instring, loc, doActions=True): result = self.re_match(instring, loc) if not result: raise ParseException(instring, loc, self.errmsg, self) @@ 
-2860,7 +2993,7 @@ class _WordRegex(Word): return loc, result.group() -class Char(_WordRegex): +class Char(Word): """A short-cut class for defining :class:`Word` ``(characters, exact=1)``, when defining a match of any single character in a string of characters. @@ -2878,13 +3011,8 @@ class Char(_WordRegex): asKeyword = asKeyword or as_keyword excludeChars = excludeChars or exclude_chars super().__init__( - charset, exact=1, asKeyword=asKeyword, excludeChars=excludeChars + charset, exact=1, as_keyword=asKeyword, exclude_chars=excludeChars ) - self.reString = "[{}]".format(_collapse_string_to_ranges(self.initChars)) - if asKeyword: - self.reString = r"\b{}\b".format(self.reString) - self.re = re.compile(self.reString) - self.re_match = self.re.match class Regex(Token): @@ -2954,9 +3082,9 @@ class Regex(Token): self.asGroupList = asGroupList self.asMatch = asMatch if self.asGroupList: - self.parseImpl = self.parseImplAsGroupList + self.parseImpl = self.parseImplAsGroupList # type: ignore [assignment] if self.asMatch: - self.parseImpl = self.parseImplAsMatch + self.parseImpl = self.parseImplAsMatch # type: ignore [assignment] @cached_property def re(self): @@ -2966,9 +3094,7 @@ class Regex(Token): try: return re.compile(self.pattern, self.flags) except re.error: - raise ValueError( - "invalid pattern ({!r}) passed to Regex".format(self.pattern) - ) + raise ValueError(f"invalid pattern ({self.pattern!r}) passed to Regex") @cached_property def re_match(self): @@ -2978,7 +3104,7 @@ class Regex(Token): def mayReturnEmpty(self): return self.re_match("") is not None - def _generateDefaultName(self): + def _generateDefaultName(self) -> str: return "Re:({})".format(repr(self.pattern).replace("\\\\", "\\")) def parseImpl(self, instring, loc, doActions=True): @@ -3024,10 +3150,12 @@ class Regex(Token): # prints "

main title

" """ if self.asGroupList: - raise TypeError("cannot use sub() with Regex(asGroupList=True)") + raise TypeError("cannot use sub() with Regex(as_group_list=True)") if self.asMatch and callable(repl): - raise TypeError("cannot use sub() with a callable with Regex(asMatch=True)") + raise TypeError( + "cannot use sub() with a callable with Regex(as_match=True)" + ) if self.asMatch: @@ -3081,7 +3209,7 @@ class QuotedString(Token): [['This is the "quote"']] [['This is the quote with "embedded" quotes']] """ - ws_map = ((r"\t", "\t"), (r"\n", "\n"), (r"\f", "\f"), (r"\r", "\r")) + ws_map = dict(((r"\t", "\t"), (r"\n", "\n"), (r"\f", "\f"), (r"\r", "\r"))) def __init__( self, @@ -3101,139 +3229,164 @@ class QuotedString(Token): convertWhitespaceEscapes: bool = True, ): super().__init__() - escChar = escChar or esc_char - escQuote = escQuote or esc_quote - unquoteResults = unquoteResults and unquote_results - endQuoteChar = endQuoteChar or end_quote_char - convertWhitespaceEscapes = ( + esc_char = escChar or esc_char + esc_quote = escQuote or esc_quote + unquote_results = unquoteResults and unquote_results + end_quote_char = endQuoteChar or end_quote_char + convert_whitespace_escapes = ( convertWhitespaceEscapes and convert_whitespace_escapes ) quote_char = quoteChar or quote_char - # remove white space from quote chars - wont work anyway + # remove white space from quote chars quote_char = quote_char.strip() if not quote_char: raise ValueError("quote_char cannot be the empty string") - if endQuoteChar is None: - endQuoteChar = quote_char + if end_quote_char is None: + end_quote_char = quote_char else: - endQuoteChar = endQuoteChar.strip() - if not endQuoteChar: - raise ValueError("endQuoteChar cannot be the empty string") + end_quote_char = end_quote_char.strip() + if not end_quote_char: + raise ValueError("end_quote_char cannot be the empty string") - self.quoteChar = quote_char - self.quoteCharLen = len(quote_char) - self.firstQuoteChar = quote_char[0] - 
self.endQuoteChar = endQuoteChar - self.endQuoteCharLen = len(endQuoteChar) - self.escChar = escChar - self.escQuote = escQuote - self.unquoteResults = unquoteResults - self.convertWhitespaceEscapes = convertWhitespaceEscapes + self.quote_char: str = quote_char + self.quote_char_len: int = len(quote_char) + self.first_quote_char: str = quote_char[0] + self.end_quote_char: str = end_quote_char + self.end_quote_char_len: int = len(end_quote_char) + self.esc_char: str = esc_char or "" + self.has_esc_char: bool = esc_char is not None + self.esc_quote: str = esc_quote or "" + self.unquote_results: bool = unquote_results + self.convert_whitespace_escapes: bool = convert_whitespace_escapes + self.multiline = multiline + self.re_flags = re.RegexFlag(0) - sep = "" - inner_pattern = "" + # fmt: off + # build up re pattern for the content between the quote delimiters + inner_pattern = [] - if escQuote: - inner_pattern += r"{}(?:{})".format(sep, re.escape(escQuote)) - sep = "|" + if esc_quote: + inner_pattern.append(rf"(?:{re.escape(esc_quote)})") - if escChar: - inner_pattern += r"{}(?:{}.)".format(sep, re.escape(escChar)) - sep = "|" - self.escCharReplacePattern = re.escape(self.escChar) + "(.)" + if esc_char: + inner_pattern.append(rf"(?:{re.escape(esc_char)}.)") - if len(self.endQuoteChar) > 1: - inner_pattern += ( - "{}(?:".format(sep) + if len(self.end_quote_char) > 1: + inner_pattern.append( + "(?:" + "|".join( - "(?:{}(?!{}))".format( - re.escape(self.endQuoteChar[:i]), - re.escape(self.endQuoteChar[i:]), - ) - for i in range(len(self.endQuoteChar) - 1, 0, -1) + f"(?:{re.escape(self.end_quote_char[:i])}(?!{re.escape(self.end_quote_char[i:])}))" + for i in range(len(self.end_quote_char) - 1, 0, -1) ) + ")" ) - sep = "|" - if multiline: - self.flags = re.MULTILINE | re.DOTALL - inner_pattern += r"{}(?:[^{}{}])".format( - sep, - _escape_regex_range_chars(self.endQuoteChar[0]), - (_escape_regex_range_chars(escChar) if escChar is not None else ""), + if self.multiline: + 
self.re_flags |= re.MULTILINE | re.DOTALL + inner_pattern.append( + rf"(?:[^{_escape_regex_range_chars(self.end_quote_char[0])}" + rf"{(_escape_regex_range_chars(esc_char) if self.has_esc_char else '')}])" ) else: - self.flags = 0 - inner_pattern += r"{}(?:[^{}\n\r{}])".format( - sep, - _escape_regex_range_chars(self.endQuoteChar[0]), - (_escape_regex_range_chars(escChar) if escChar is not None else ""), + inner_pattern.append( + rf"(?:[^{_escape_regex_range_chars(self.end_quote_char[0])}\n\r" + rf"{(_escape_regex_range_chars(esc_char) if self.has_esc_char else '')}])" ) self.pattern = "".join( [ - re.escape(self.quoteChar), + re.escape(self.quote_char), "(?:", - inner_pattern, + '|'.join(inner_pattern), ")*", - re.escape(self.endQuoteChar), + re.escape(self.end_quote_char), ] ) + if self.unquote_results: + if self.convert_whitespace_escapes: + self.unquote_scan_re = re.compile( + rf"({'|'.join(re.escape(k) for k in self.ws_map)})" + rf"|({re.escape(self.esc_char)}.)" + rf"|(\n|.)", + flags=self.re_flags, + ) + else: + self.unquote_scan_re = re.compile( + rf"({re.escape(self.esc_char)}.)" + rf"|(\n|.)", + flags=self.re_flags + ) + # fmt: on + try: - self.re = re.compile(self.pattern, self.flags) + self.re = re.compile(self.pattern, self.re_flags) self.reString = self.pattern self.re_match = self.re.match except re.error: - raise ValueError( - "invalid pattern {!r} passed to Regex".format(self.pattern) - ) + raise ValueError(f"invalid pattern {self.pattern!r} passed to Regex") self.errmsg = "Expected " + self.name self.mayIndexError = False self.mayReturnEmpty = True - def _generateDefaultName(self): - if self.quoteChar == self.endQuoteChar and isinstance(self.quoteChar, str_type): - return "string enclosed in {!r}".format(self.quoteChar) + def _generateDefaultName(self) -> str: + if self.quote_char == self.end_quote_char and isinstance( + self.quote_char, str_type + ): + return f"string enclosed in {self.quote_char!r}" - return "quoted string, starting with {} 
ending with {}".format( - self.quoteChar, self.endQuoteChar - ) + return f"quoted string, starting with {self.quote_char} ending with {self.end_quote_char}" def parseImpl(self, instring, loc, doActions=True): + # check first character of opening quote to see if that is a match + # before doing the more complicated regex match result = ( - instring[loc] == self.firstQuoteChar + instring[loc] == self.first_quote_char and self.re_match(instring, loc) or None ) if not result: raise ParseException(instring, loc, self.errmsg, self) + # get ending loc and matched string from regex matching result loc = result.end() ret = result.group() - if self.unquoteResults: - + if self.unquote_results: # strip off quotes - ret = ret[self.quoteCharLen : -self.endQuoteCharLen] + ret = ret[self.quote_char_len : -self.end_quote_char_len] if isinstance(ret, str_type): - # replace escaped whitespace - if "\\" in ret and self.convertWhitespaceEscapes: - for wslit, wschar in self.ws_map: - ret = ret.replace(wslit, wschar) - - # replace escaped characters - if self.escChar: - ret = re.sub(self.escCharReplacePattern, r"\g<1>", ret) + # fmt: off + if self.convert_whitespace_escapes: + # as we iterate over matches in the input string, + # collect from whichever match group of the unquote_scan_re + # regex matches (only 1 group will match at any given time) + ret = "".join( + # match group 1 matches \t, \n, etc. 
+ self.ws_map[match.group(1)] if match.group(1) + # match group 2 matches escaped characters + else match.group(2)[-1] if match.group(2) + # match group 3 matches any character + else match.group(3) + for match in self.unquote_scan_re.finditer(ret) + ) + else: + ret = "".join( + # match group 1 matches escaped characters + match.group(1)[-1] if match.group(1) + # match group 2 matches any character + else match.group(2) + for match in self.unquote_scan_re.finditer(ret) + ) + # fmt: on # replace escaped quotes - if self.escQuote: - ret = ret.replace(self.escQuote, self.endQuoteChar) + if self.esc_quote: + ret = ret.replace(self.esc_quote, self.end_quote_char) return loc, ret @@ -3252,7 +3405,7 @@ class CharsNotIn(Token): # define a comma-separated-value as anything that is not a ',' csv_value = CharsNotIn(',') - print(delimited_list(csv_value).parse_string("dkls,lsdkjf,s12 34,@!#,213")) + print(DelimitedList(csv_value).parse_string("dkls,lsdkjf,s12 34,@!#,213")) prints:: @@ -3294,12 +3447,12 @@ class CharsNotIn(Token): self.mayReturnEmpty = self.minLen == 0 self.mayIndexError = False - def _generateDefaultName(self): + def _generateDefaultName(self) -> str: not_chars_str = _collapse_string_to_ranges(self.notChars) if len(not_chars_str) > 16: - return "!W:({}...)".format(self.notChars[: 16 - 3]) + return f"!W:({self.notChars[: 16 - 3]}...)" else: - return "!W:({})".format(self.notChars) + return f"!W:({self.notChars})" def parseImpl(self, instring, loc, doActions=True): notchars = self.notCharsSet @@ -3376,7 +3529,7 @@ class White(Token): self.maxLen = exact self.minLen = exact - def _generateDefaultName(self): + def _generateDefaultName(self) -> str: return "".join(White.whiteStrs[c] for c in self.matchWhite) def parseImpl(self, instring, loc, doActions=True): @@ -3411,7 +3564,7 @@ class GoToColumn(PositionToken): super().__init__() self.col = colno - def preParse(self, instring, loc): + def preParse(self, instring: str, loc: int) -> int: if col(loc, instring) != 
self.col: instrlen = len(instring) if self.ignoreExprs: @@ -3446,7 +3599,7 @@ class LineStart(PositionToken): B AAA and definitely not this one ''' - for t in (LineStart() + 'AAA' + restOfLine).search_string(test): + for t in (LineStart() + 'AAA' + rest_of_line).search_string(test): print(t) prints:: @@ -3464,7 +3617,7 @@ class LineStart(PositionToken): self.skipper = Empty().set_whitespace_chars(self.whiteChars) self.errmsg = "Expected start of line" - def preParse(self, instring, loc): + def preParse(self, instring: str, loc: int) -> int: if loc == 0: return loc else: @@ -3624,7 +3777,7 @@ class ParseExpression(ParserElement): self.exprs = [exprs] self.callPreparse = False - def recurse(self) -> Sequence[ParserElement]: + def recurse(self) -> List[ParserElement]: return self.exprs[:] def append(self, other) -> ParserElement: @@ -3669,8 +3822,8 @@ class ParseExpression(ParserElement): e.ignore(self.ignoreExprs[-1]) return self - def _generateDefaultName(self): - return "{}:({})".format(self.__class__.__name__, str(self.exprs)) + def _generateDefaultName(self) -> str: + return f"{self.__class__.__name__}:({str(self.exprs)})" def streamline(self) -> ParserElement: if self.streamlined: @@ -3714,6 +3867,11 @@ class ParseExpression(ParserElement): return self def validate(self, validateTrace=None) -> None: + warnings.warn( + "ParserElement.validate() is deprecated, and should not be used to check for left recursion", + DeprecationWarning, + stacklevel=2, + ) tmp = (validateTrace if validateTrace is not None else [])[:] + [self] for e in self.exprs: e.validate(tmp) @@ -3721,6 +3879,7 @@ class ParseExpression(ParserElement): def copy(self) -> ParserElement: ret = super().copy() + ret = typing.cast(ParseExpression, ret) ret.exprs = [e.copy() for e in self.exprs] return ret @@ -3750,8 +3909,14 @@ class ParseExpression(ParserElement): return super()._setResultsName(name, listAllMatches) - ignoreWhitespace = ignore_whitespace - leaveWhitespace = leave_whitespace + # 
Compatibility synonyms + # fmt: off + @replaced_by_pep8(leave_whitespace) + def leaveWhitespace(self): ... + + @replaced_by_pep8(ignore_whitespace) + def ignoreWhitespace(self): ... + # fmt: on class And(ParseExpression): @@ -3777,7 +3942,7 @@ class And(ParseExpression): super().__init__(*args, **kwargs) self.leave_whitespace() - def _generateDefaultName(self): + def _generateDefaultName(self) -> str: return "-" def __init__( @@ -3789,7 +3954,9 @@ class And(ParseExpression): for i, expr in enumerate(exprs): if expr is Ellipsis: if i < len(exprs) - 1: - skipto_arg: ParserElement = (Empty() + exprs[i + 1]).exprs[-1] + skipto_arg: ParserElement = typing.cast( + ParseExpression, (Empty() + exprs[i + 1]) + ).exprs[-1] tmp.append(SkipTo(skipto_arg)("_skipped*")) else: raise Exception( @@ -3822,8 +3989,9 @@ class And(ParseExpression): and isinstance(e.exprs[-1], _PendingSkip) for e in self.exprs[:-1] ): + deleted_expr_marker = NoMatch() for i, e in enumerate(self.exprs[:-1]): - if e is None: + if e is deleted_expr_marker: continue if ( isinstance(e, ParseExpression) @@ -3831,17 +3999,19 @@ class And(ParseExpression): and isinstance(e.exprs[-1], _PendingSkip) ): e.exprs[-1] = e.exprs[-1] + self.exprs[i + 1] - self.exprs[i + 1] = None - self.exprs = [e for e in self.exprs if e is not None] + self.exprs[i + 1] = deleted_expr_marker + self.exprs = [e for e in self.exprs if e is not deleted_expr_marker] super().streamline() # link any IndentedBlocks to the prior expression + prev: ParserElement + cur: ParserElement for prev, cur in zip(self.exprs, self.exprs[1:]): # traverse cur or any first embedded expr of cur looking for an IndentedBlock # (but watch out for recursive grammar) seen = set() - while cur: + while True: if id(cur) in seen: break seen.add(id(cur)) @@ -3853,7 +4023,10 @@ class And(ParseExpression): ) break subs = cur.recurse() - cur = next(iter(subs), None) + next_first = next(iter(subs), None) + if next_first is None: + break + cur = typing.cast(ParserElement, 
next_first) self.mayReturnEmpty = all(e.mayReturnEmpty for e in self.exprs) return self @@ -3884,13 +4057,14 @@ class And(ParseExpression): ) else: loc, exprtokens = e._parse(instring, loc, doActions) - if exprtokens or exprtokens.haskeys(): - resultlist += exprtokens + resultlist += exprtokens return loc, resultlist def __iadd__(self, other): if isinstance(other, str_type): other = self._literalStringClass(other) + if not isinstance(other, ParserElement): + return NotImplemented return self.append(other) # And([self, other]) def _checkRecursion(self, parseElementList): @@ -3900,7 +4074,7 @@ class And(ParseExpression): if not e.mayReturnEmpty: break - def _generateDefaultName(self): + def _generateDefaultName(self) -> str: inner = " ".join(str(e) for e in self.exprs) # strip off redundant inner {}'s while len(inner) > 1 and inner[0 :: len(inner) - 1] == "{}": @@ -3958,7 +4132,7 @@ class Or(ParseExpression): loc2 = e.try_parse(instring, loc, raise_fatal=True) except ParseFatalException as pfe: pfe.__traceback__ = None - pfe.parserElement = e + pfe.parser_element = e fatals.append(pfe) maxException = None maxExcLoc = -1 @@ -4016,12 +4190,15 @@ class Or(ParseExpression): if len(fatals) > 1: fatals.sort(key=lambda e: -e.loc) if fatals[0].loc == fatals[1].loc: - fatals.sort(key=lambda e: (-e.loc, -len(str(e.parserElement)))) + fatals.sort(key=lambda e: (-e.loc, -len(str(e.parser_element)))) max_fatal = fatals[0] raise max_fatal if maxException is not None: - maxException.msg = self.errmsg + # infer from this check that all alternatives failed at the current position + # so emit this collective error message instead of any single error message + if maxExcLoc == loc: + maxException.msg = self.errmsg raise maxException else: raise ParseException( @@ -4031,9 +4208,11 @@ class Or(ParseExpression): def __ixor__(self, other): if isinstance(other, str_type): other = self._literalStringClass(other) + if not isinstance(other, ParserElement): + return NotImplemented return 
self.append(other) # Or([self, other]) - def _generateDefaultName(self): + def _generateDefaultName(self) -> str: return "{" + " ^ ".join(str(e) for e in self.exprs) + "}" def _setResultsName(self, name, listAllMatches=False): @@ -4118,7 +4297,7 @@ class MatchFirst(ParseExpression): ) except ParseFatalException as pfe: pfe.__traceback__ = None - pfe.parserElement = e + pfe.parser_element = e raise except ParseException as err: if err.loc > maxExcLoc: @@ -4132,7 +4311,10 @@ class MatchFirst(ParseExpression): maxExcLoc = len(instring) if maxException is not None: - maxException.msg = self.errmsg + # infer from this check that all alternatives failed at the current position + # so emit this collective error message instead of any individual error message + if maxExcLoc == loc: + maxException.msg = self.errmsg raise maxException else: raise ParseException( @@ -4142,9 +4324,11 @@ class MatchFirst(ParseExpression): def __ior__(self, other): if isinstance(other, str_type): other = self._literalStringClass(other) + if not isinstance(other, ParserElement): + return NotImplemented return self.append(other) # MatchFirst([self, other]) - def _generateDefaultName(self): + def _generateDefaultName(self) -> str: return "{" + " | ".join(str(e) for e in self.exprs) + "}" def _setResultsName(self, name, listAllMatches=False): @@ -4242,6 +4426,13 @@ class Each(ParseExpression): self.initExprGroups = True self.saveAsList = True + def __iand__(self, other): + if isinstance(other, str_type): + other = self._literalStringClass(other) + if not isinstance(other, ParserElement): + return NotImplemented + return self.append(other) # Each([self, other]) + def streamline(self) -> ParserElement: super().streamline() if self.exprs: @@ -4296,7 +4487,7 @@ class Each(ParseExpression): tmpLoc = e.try_parse(instring, tmpLoc, raise_fatal=True) except ParseFatalException as pfe: pfe.__traceback__ = None - pfe.parserElement = e + pfe.parser_element = e fatals.append(pfe) failed.append(e) except 
ParseException: @@ -4315,7 +4506,7 @@ class Each(ParseExpression): if len(fatals) > 1: fatals.sort(key=lambda e: -e.loc) if fatals[0].loc == fatals[1].loc: - fatals.sort(key=lambda e: (-e.loc, -len(str(e.parserElement)))) + fatals.sort(key=lambda e: (-e.loc, -len(str(e.parser_element)))) max_fatal = fatals[0] raise max_fatal @@ -4324,7 +4515,7 @@ class Each(ParseExpression): raise ParseException( instring, loc, - "Missing one or more required elements ({})".format(missing), + f"Missing one or more required elements ({missing})", ) # add any unmatched Opts, in case they have default values defined @@ -4337,7 +4528,7 @@ class Each(ParseExpression): return loc, total_results - def _generateDefaultName(self): + def _generateDefaultName(self) -> str: return "{" + " & ".join(str(e) for e in self.exprs) + "}" @@ -4349,12 +4540,14 @@ class ParseElementEnhance(ParserElement): def __init__(self, expr: Union[ParserElement, str], savelist: bool = False): super().__init__(savelist) if isinstance(expr, str_type): + expr_str = typing.cast(str, expr) if issubclass(self._literalStringClass, Token): - expr = self._literalStringClass(expr) + expr = self._literalStringClass(expr_str) # type: ignore[call-arg] elif issubclass(type(self), self._literalStringClass): - expr = Literal(expr) + expr = Literal(expr_str) else: - expr = self._literalStringClass(Literal(expr)) + expr = self._literalStringClass(Literal(expr_str)) # type: ignore[assignment, call-arg] + expr = typing.cast(ParserElement, expr) self.expr = expr if expr is not None: self.mayIndexError = expr.mayIndexError @@ -4367,12 +4560,17 @@ class ParseElementEnhance(ParserElement): self.callPreparse = expr.callPreparse self.ignoreExprs.extend(expr.ignoreExprs) - def recurse(self) -> Sequence[ParserElement]: + def recurse(self) -> List[ParserElement]: return [self.expr] if self.expr is not None else [] def parseImpl(self, instring, loc, doActions=True): if self.expr is not None: - return self.expr._parse(instring, loc, doActions, 
callPreParse=False) + try: + return self.expr._parse(instring, loc, doActions, callPreParse=False) + except ParseBaseException as pbe: + if not isinstance(self, Forward) or self.customName is not None: + pbe.msg = self.errmsg + raise else: raise ParseException(instring, loc, "No expression defined", self) @@ -4380,8 +4578,8 @@ class ParseElementEnhance(ParserElement): super().leave_whitespace(recursive) if recursive: - self.expr = self.expr.copy() if self.expr is not None: + self.expr = self.expr.copy() self.expr.leave_whitespace(recursive) return self @@ -4389,8 +4587,8 @@ class ParseElementEnhance(ParserElement): super().ignore_whitespace(recursive) if recursive: - self.expr = self.expr.copy() if self.expr is not None: + self.expr = self.expr.copy() self.expr.ignore_whitespace(recursive) return self @@ -4420,6 +4618,11 @@ class ParseElementEnhance(ParserElement): self.expr._checkRecursion(subRecCheckList) def validate(self, validateTrace=None) -> None: + warnings.warn( + "ParserElement.validate() is deprecated, and should not be used to check for left recursion", + DeprecationWarning, + stacklevel=2, + ) if validateTrace is None: validateTrace = [] tmp = validateTrace[:] + [self] @@ -4427,11 +4630,17 @@ class ParseElementEnhance(ParserElement): self.expr.validate(tmp) self._checkRecursion([]) - def _generateDefaultName(self): - return "{}:({})".format(self.__class__.__name__, str(self.expr)) + def _generateDefaultName(self) -> str: + return f"{self.__class__.__name__}:({str(self.expr)})" - ignoreWhitespace = ignore_whitespace - leaveWhitespace = leave_whitespace + # Compatibility synonyms + # fmt: off + @replaced_by_pep8(leave_whitespace) + def leaveWhitespace(self): ... + + @replaced_by_pep8(ignore_whitespace) + def ignoreWhitespace(self): ... 
+ # fmt: on class IndentedBlock(ParseElementEnhance): @@ -4443,13 +4652,13 @@ class IndentedBlock(ParseElementEnhance): class _Indent(Empty): def __init__(self, ref_col: int): super().__init__() - self.errmsg = "expected indent at column {}".format(ref_col) + self.errmsg = f"expected indent at column {ref_col}" self.add_condition(lambda s, l, t: col(l, s) == ref_col) class _IndentGreater(Empty): def __init__(self, ref_col: int): super().__init__() - self.errmsg = "expected indent at column greater than {}".format(ref_col) + self.errmsg = f"expected indent at column greater than {ref_col}" self.add_condition(lambda s, l, t: col(l, s) > ref_col) def __init__( @@ -4469,7 +4678,7 @@ class IndentedBlock(ParseElementEnhance): # see if self.expr matches at the current location - if not it will raise an exception # and no further work is necessary - self.expr.try_parse(instring, anchor_loc, doActions) + self.expr.try_parse(instring, anchor_loc, do_actions=doActions) indent_col = col(anchor_loc, instring) peer_detect_expr = self._Indent(indent_col) @@ -4532,7 +4741,7 @@ class AtLineStart(ParseElementEnhance): B AAA and definitely not this one ''' - for t in (AtLineStart('AAA') + restOfLine).search_string(test): + for t in (AtLineStart('AAA') + rest_of_line).search_string(test): print(t) prints:: @@ -4598,9 +4807,9 @@ class PrecededBy(ParseElementEnhance): Parameters: - - expr - expression that must match prior to the current parse + - ``expr`` - expression that must match prior to the current parse location - - retreat - (default= ``None``) - (int) maximum number of characters + - ``retreat`` - (default= ``None``) - (int) maximum number of characters to lookbehind prior to the current parse location If the lookbehind expression is a string, :class:`Literal`, @@ -4627,6 +4836,7 @@ class PrecededBy(ParseElementEnhance): self.mayIndexError = False self.exact = False if isinstance(expr, str_type): + expr = typing.cast(str, expr) retreat = len(expr) self.exact = True elif 
isinstance(expr, (Literal, Keyword)): @@ -4746,18 +4956,18 @@ class NotAny(ParseElementEnhance): self.errmsg = "Found unwanted token, " + str(self.expr) def parseImpl(self, instring, loc, doActions=True): - if self.expr.can_parse_next(instring, loc): + if self.expr.can_parse_next(instring, loc, do_actions=doActions): raise ParseException(instring, loc, self.errmsg, self) return loc, [] - def _generateDefaultName(self): + def _generateDefaultName(self) -> str: return "~{" + str(self.expr) + "}" class _MultipleMatch(ParseElementEnhance): def __init__( self, - expr: ParserElement, + expr: Union[str, ParserElement], stop_on: typing.Optional[Union[ParserElement, str]] = None, *, stopOn: typing.Optional[Union[ParserElement, str]] = None, @@ -4781,7 +4991,7 @@ class _MultipleMatch(ParseElementEnhance): self_skip_ignorables = self._skipIgnorables check_ender = self.not_ender is not None if check_ender: - try_not_ender = self.not_ender.tryParse + try_not_ender = self.not_ender.try_parse # must be at least one (but first see if we are the stopOn sentinel; # if so, fail) @@ -4798,8 +5008,7 @@ class _MultipleMatch(ParseElementEnhance): else: preloc = loc loc, tmptokens = self_expr_parse(instring, preloc, doActions) - if tmptokens or tmptokens.haskeys(): - tokens += tmptokens + tokens += tmptokens except (ParseException, IndexError): pass @@ -4837,10 +5046,11 @@ class OneOrMore(_MultipleMatch): Repetition of one or more of the given expression. 
Parameters: - - expr - expression that must match one or more times - - stop_on - (default= ``None``) - expression for a terminating sentinel - (only required if the sentinel would ordinarily match the repetition - expression) + + - ``expr`` - expression that must match one or more times + - ``stop_on`` - (default= ``None``) - expression for a terminating sentinel + (only required if the sentinel would ordinarily match the repetition + expression) Example:: @@ -4859,7 +5069,7 @@ class OneOrMore(_MultipleMatch): (attr_expr * (1,)).parse_string(text).pprint() """ - def _generateDefaultName(self): + def _generateDefaultName(self) -> str: return "{" + str(self.expr) + "}..." @@ -4868,6 +5078,7 @@ class ZeroOrMore(_MultipleMatch): Optional repetition of zero or more of the given expression. Parameters: + - ``expr`` - expression that must match zero or more times - ``stop_on`` - expression for a terminating sentinel (only required if the sentinel would ordinarily match the repetition @@ -4878,7 +5089,7 @@ class ZeroOrMore(_MultipleMatch): def __init__( self, - expr: ParserElement, + expr: Union[str, ParserElement], stop_on: typing.Optional[Union[ParserElement, str]] = None, *, stopOn: typing.Optional[Union[ParserElement, str]] = None, @@ -4892,10 +5103,75 @@ class ZeroOrMore(_MultipleMatch): except (ParseException, IndexError): return loc, ParseResults([], name=self.resultsName) - def _generateDefaultName(self): + def _generateDefaultName(self) -> str: return "[" + str(self.expr) + "]..." +class DelimitedList(ParseElementEnhance): + def __init__( + self, + expr: Union[str, ParserElement], + delim: Union[str, ParserElement] = ",", + combine: bool = False, + min: typing.Optional[int] = None, + max: typing.Optional[int] = None, + *, + allow_trailing_delim: bool = False, + ): + """Helper to define a delimited list of expressions - the delimiter + defaults to ','. 
By default, the list elements and delimiters can + have intervening whitespace, and comments, but this can be + overridden by passing ``combine=True`` in the constructor. If + ``combine`` is set to ``True``, the matching tokens are + returned as a single token string, with the delimiters included; + otherwise, the matching tokens are returned as a list of tokens, + with the delimiters suppressed. + + If ``allow_trailing_delim`` is set to True, then the list may end with + a delimiter. + + Example:: + + DelimitedList(Word(alphas)).parse_string("aa,bb,cc") # -> ['aa', 'bb', 'cc'] + DelimitedList(Word(hexnums), delim=':', combine=True).parse_string("AA:BB:CC:DD:EE") # -> ['AA:BB:CC:DD:EE'] + """ + if isinstance(expr, str_type): + expr = ParserElement._literalStringClass(expr) + expr = typing.cast(ParserElement, expr) + + if min is not None: + if min < 1: + raise ValueError("min must be greater than 0") + if max is not None: + if min is not None and max < min: + raise ValueError("max must be greater than, or equal to min") + + self.content = expr + self.raw_delim = str(delim) + self.delim = delim + self.combine = combine + if not combine: + self.delim = Suppress(delim) + self.min = min or 1 + self.max = max + self.allow_trailing_delim = allow_trailing_delim + + delim_list_expr = self.content + (self.delim + self.content) * ( + self.min - 1, + None if self.max is None else self.max - 1, + ) + if self.allow_trailing_delim: + delim_list_expr += Opt(self.delim) + + if self.combine: + delim_list_expr = Combine(delim_list_expr) + + super().__init__(delim_list_expr, savelist=True) + + def _generateDefaultName(self) -> str: + return "{0} [{1} {0}]...".format(self.content.streamline(), self.raw_delim) + + class _NullToken: def __bool__(self): return False @@ -4909,6 +5185,7 @@ class Opt(ParseElementEnhance): Optional matching of the given expression. 
Parameters: + - ``expr`` - expression that must match zero or more times - ``default`` (optional) - value to be returned if the optional expression is not found. @@ -4969,7 +5246,7 @@ class Opt(ParseElementEnhance): tokens = [] return loc, tokens - def _generateDefaultName(self): + def _generateDefaultName(self) -> str: inner = str(self.expr) # strip off redundant inner {}'s while len(inner) > 1 and inner[0 :: len(inner) - 1] == "{}": @@ -4986,6 +5263,7 @@ class SkipTo(ParseElementEnhance): expression is found. Parameters: + - ``expr`` - target expression marking the end of the data to be skipped - ``include`` - if ``True``, the target expression is also parsed (the skipped text and target expression are returned as a 2-element @@ -5045,10 +5323,10 @@ class SkipTo(ParseElementEnhance): self, other: Union[ParserElement, str], include: bool = False, - ignore: bool = None, + ignore: typing.Optional[Union[ParserElement, str]] = None, fail_on: typing.Optional[Union[ParserElement, str]] = None, *, - failOn: Union[ParserElement, str] = None, + failOn: typing.Optional[Union[ParserElement, str]] = None, ): super().__init__(other) failOn = failOn or fail_on @@ -5062,6 +5340,20 @@ class SkipTo(ParseElementEnhance): else: self.failOn = failOn self.errmsg = "No match found for " + str(self.expr) + self.ignorer = Empty().leave_whitespace() + self._update_ignorer() + + def _update_ignorer(self): + # rebuild internal ignore expr from current ignore exprs and assigned ignoreExpr + self.ignorer.ignoreExprs.clear() + for e in self.expr.ignoreExprs: + self.ignorer.ignore(e) + if self.ignoreExpr: + self.ignorer.ignore(self.ignoreExpr) + + def ignore(self, expr): + super().ignore(expr) + self._update_ignorer() def parseImpl(self, instring, loc, doActions=True): startloc = loc @@ -5070,9 +5362,7 @@ class SkipTo(ParseElementEnhance): self_failOn_canParseNext = ( self.failOn.canParseNext if self.failOn is not None else None ) - self_ignoreExpr_tryParse = ( - self.ignoreExpr.tryParse if 
self.ignoreExpr is not None else None - ) + ignorer_try_parse = self.ignorer.try_parse if self.ignorer.ignoreExprs else None tmploc = loc while tmploc <= instrlen: @@ -5081,13 +5371,18 @@ class SkipTo(ParseElementEnhance): if self_failOn_canParseNext(instring, tmploc): break - if self_ignoreExpr_tryParse is not None: + if ignorer_try_parse is not None: # advance past ignore expressions + prev_tmploc = tmploc while 1: try: - tmploc = self_ignoreExpr_tryParse(instring, tmploc) + tmploc = ignorer_try_parse(instring, tmploc) except ParseBaseException: break + # see if all ignorers matched, but didn't actually ignore anything + if tmploc == prev_tmploc: + break + prev_tmploc = tmploc try: self_expr_parse(instring, tmploc, doActions=False, callPreParse=False) @@ -5145,15 +5440,20 @@ class Forward(ParseElementEnhance): def __init__(self, other: typing.Optional[Union[ParserElement, str]] = None): self.caller_frame = traceback.extract_stack(limit=2)[0] - super().__init__(other, savelist=False) + super().__init__(other, savelist=False) # type: ignore[arg-type] self.lshift_line = None - def __lshift__(self, other): + def __lshift__(self, other) -> "Forward": if hasattr(self, "caller_frame"): del self.caller_frame if isinstance(other, str_type): other = self._literalStringClass(other) + + if not isinstance(other, ParserElement): + return NotImplemented + self.expr = other + self.streamlined = other.streamlined self.mayIndexError = self.expr.mayIndexError self.mayReturnEmpty = self.expr.mayReturnEmpty self.set_whitespace_chars( @@ -5162,13 +5462,16 @@ class Forward(ParseElementEnhance): self.skipWhitespace = self.expr.skipWhitespace self.saveAsList = self.expr.saveAsList self.ignoreExprs.extend(self.expr.ignoreExprs) - self.lshift_line = traceback.extract_stack(limit=2)[-2] + self.lshift_line = traceback.extract_stack(limit=2)[-2] # type: ignore[assignment] return self - def __ilshift__(self, other): + def __ilshift__(self, other) -> "Forward": + if not isinstance(other, 
ParserElement): + return NotImplemented + return self << other - def __or__(self, other): + def __or__(self, other) -> "ParserElement": caller_line = traceback.extract_stack(limit=2)[-2] if ( __diag__.warn_on_match_first_with_lshift_operator @@ -5205,12 +5508,12 @@ class Forward(ParseElementEnhance): not in self.suppress_warnings_ ): # walk stack until parse_string, scan_string, search_string, or transform_string is found - parse_fns = [ + parse_fns = ( "parse_string", "scan_string", "search_string", "transform_string", - ] + ) tb = traceback.extract_stack(limit=200) for i, frm in enumerate(reversed(tb), start=1): if frm.name in parse_fns: @@ -5308,6 +5611,11 @@ class Forward(ParseElementEnhance): return self def validate(self, validateTrace=None) -> None: + warnings.warn( + "ParserElement.validate() is deprecated, and should not be used to check for left recursion", + DeprecationWarning, + stacklevel=2, + ) if validateTrace is None: validateTrace = [] @@ -5317,7 +5625,7 @@ class Forward(ParseElementEnhance): self.expr.validate(tmp) self._checkRecursion([]) - def _generateDefaultName(self): + def _generateDefaultName(self) -> str: # Avoid infinite recursion by setting a temporary _defaultName self._defaultName = ": ..." @@ -5356,8 +5664,14 @@ class Forward(ParseElementEnhance): return super()._setResultsName(name, list_all_matches) - ignoreWhitespace = ignore_whitespace - leaveWhitespace = leave_whitespace + # Compatibility synonyms + # fmt: off + @replaced_by_pep8(leave_whitespace) + def leaveWhitespace(self): ... + + @replaced_by_pep8(ignore_whitespace) + def ignoreWhitespace(self): ... 
+ # fmt: on class TokenConverter(ParseElementEnhance): @@ -5439,11 +5753,11 @@ class Group(TokenConverter): ident = Word(alphas) num = Word(nums) term = ident | num - func = ident + Opt(delimited_list(term)) + func = ident + Opt(DelimitedList(term)) print(func.parse_string("fn a, b, 100")) # -> ['fn', 'a', 'b', '100'] - func = ident + Group(Opt(delimited_list(term))) + func = ident + Group(Opt(DelimitedList(term))) print(func.parse_string("fn a, b, 100")) # -> ['fn', ['a', 'b', '100']] """ @@ -5579,7 +5893,7 @@ class Suppress(TokenConverter): ['a', 'b', 'c', 'd'] ['START', 'relevant text ', 'END'] - (See also :class:`delimited_list`.) + (See also :class:`DelimitedList`.) """ def __init__(self, expr: Union[ParserElement, str], savelist: bool = False): @@ -5638,15 +5952,13 @@ def trace_parse_action(f: ParseAction) -> ParseAction: s, l, t = paArgs[-3:] if len(paArgs) > 3: thisFunc = paArgs[0].__class__.__name__ + "." + thisFunc - sys.stderr.write( - ">>entering {}(line: {!r}, {}, {!r})\n".format(thisFunc, line(l, s), l, t) - ) + sys.stderr.write(f">>entering {thisFunc}(line: {line(l, s)!r}, {l}, {t!r})\n") try: ret = f(*paArgs) except Exception as exc: - sys.stderr.write("< str: ) try: return "".join(_expanded(part) for part in _reBracketExpr.parse_string(s).body) - except Exception: + except Exception as e: return "" @@ -5769,7 +6081,11 @@ def autoname_elements() -> None: Utility to simplify mass-naming of parser elements, for generating railroad diagram with named subdiagrams. 
""" - for name, var in sys._getframe().f_back.f_locals.items(): + calling_frame = sys._getframe().f_back + if calling_frame is None: + return + calling_frame = typing.cast(types.FrameType, calling_frame) + for name, var in calling_frame.f_locals.items(): if isinstance(var, ParserElement) and not var.customName: var.set_name(name) @@ -5783,9 +6099,28 @@ sgl_quoted_string = Combine( ).set_name("string enclosed in single quotes") quoted_string = Combine( - Regex(r'"(?:[^"\n\r\\]|(?:"")|(?:\\(?:[^x]|x[0-9a-fA-F]+)))*') + '"' - | Regex(r"'(?:[^'\n\r\\]|(?:'')|(?:\\(?:[^x]|x[0-9a-fA-F]+)))*") + "'" -).set_name("quotedString using single or double quotes") + (Regex(r'"(?:[^"\n\r\\]|(?:"")|(?:\\(?:[^x]|x[0-9a-fA-F]+)))*') + '"').set_name( + "double quoted string" + ) + | (Regex(r"'(?:[^'\n\r\\]|(?:'')|(?:\\(?:[^x]|x[0-9a-fA-F]+)))*") + "'").set_name( + "single quoted string" + ) +).set_name("quoted string using single or double quotes") + +python_quoted_string = Combine( + (Regex(r'"""(?:[^"\\]|""(?!")|"(?!"")|\\.)*', flags=re.MULTILINE) + '"""').set_name( + "multiline double quoted string" + ) + ^ ( + Regex(r"'''(?:[^'\\]|''(?!')|'(?!'')|\\.)*", flags=re.MULTILINE) + "'''" + ).set_name("multiline single quoted string") + ^ (Regex(r'"(?:[^"\n\r\\]|(?:\\")|(?:\\(?:[^x]|x[0-9a-fA-F]+)))*') + '"').set_name( + "double quoted string" + ) + ^ (Regex(r"'(?:[^'\n\r\\]|(?:\\')|(?:\\(?:[^x]|x[0-9a-fA-F]+)))*") + "'").set_name( + "single quoted string" + ) +).set_name("Python quoted string") unicode_string = Combine("u" + quoted_string.copy()).set_name("unicode string literal") @@ -5800,9 +6135,7 @@ _builtin_exprs: List[ParserElement] = [ ] # backward compatibility names -tokenMap = token_map -conditionAsParseAction = condition_as_parse_action -nullDebugAction = null_debug_action +# fmt: off sglQuotedString = sgl_quoted_string dblQuotedString = dbl_quoted_string quotedString = quoted_string @@ -5811,4 +6144,16 @@ lineStart = line_start lineEnd = line_end stringStart = string_start 
stringEnd = string_end -traceParseAction = trace_parse_action + +@replaced_by_pep8(null_debug_action) +def nullDebugAction(): ... + +@replaced_by_pep8(trace_parse_action) +def traceParseAction(): ... + +@replaced_by_pep8(condition_as_parse_action) +def conditionAsParseAction(): ... + +@replaced_by_pep8(token_map) +def tokenMap(): ... +# fmt: on diff --git a/lib/pyparsing/diagram/__init__.py b/lib/pyparsing/diagram/__init__.py index 89864475..267f3447 100644 --- a/lib/pyparsing/diagram/__init__.py +++ b/lib/pyparsing/diagram/__init__.py @@ -1,3 +1,4 @@ +# mypy: ignore-errors import railroad import pyparsing import typing @@ -17,11 +18,13 @@ import inspect jinja2_template_source = """\ +{% if not embed %} +{% endif %} {% if not head %} -