From aa060dfa5506ed038d4078a3757ecc2302bd8655 Mon Sep 17 00:00:00 2001 From: Bob Halley Date: Sun, 6 Oct 2024 15:04:56 -0700 Subject: [PATCH 01/12] start of pyright linting --- dns/dnssec.py | 12 ++++++------ dns/edns.py | 5 ++++- dns/grange.py | 2 +- dns/name.py | 6 ++++++ dns/rdtypes/util.py | 28 ++++++++++++++++++++-------- dns/reversename.py | 1 + dns/update.py | 2 ++ pyproject.toml | 5 +++++ 8 files changed, 45 insertions(+), 16 deletions(-) diff --git a/dns/dnssec.py b/dns/dnssec.py index b69d0a126..76d728a5f 100644 --- a/dns/dnssec.py +++ b/dns/dnssec.py @@ -135,16 +135,16 @@ class Policy: def __init__(self): pass - def ok_to_sign(self, _: DNSKEY) -> bool: # pragma: no cover + def ok_to_sign(self, key: DNSKEY) -> bool: # pragma: no cover return False - def ok_to_validate(self, _: DNSKEY) -> bool: # pragma: no cover + def ok_to_validate(self, key: DNSKEY) -> bool: # pragma: no cover return False - def ok_to_create_ds(self, _: DSDigest) -> bool: # pragma: no cover + def ok_to_create_ds(self, algorithm: DSDigest) -> bool: # pragma: no cover return False - def ok_to_validate_ds(self, _: DSDigest) -> bool: # pragma: no cover + def ok_to_validate_ds(self, algorithm: DSDigest) -> bool: # pragma: no cover return False @@ -587,7 +587,7 @@ def _sign( signature=b"", ) - data = dns.dnssec._make_rrsig_signature_data(rrset, rrsig_template, origin) + data = _make_rrsig_signature_data(rrset, rrsig_template, origin) # pylint: disable=possibly-used-before-assignment if isinstance(private_key, GenericPrivateKey): @@ -979,7 +979,7 @@ def default_rrset_signer( keys = zsks for private_key, dnskey in keys: - rrsig = dns.dnssec.sign( + rrsig = sign( rrset=rrset, private_key=private_key, dnskey=dnskey, diff --git a/dns/edns.py b/dns/edns.py index c36036864..ad9a07a8e 100644 --- a/dns/edns.py +++ b/dns/edns.py @@ -25,6 +25,9 @@ import dns.enum import dns.inet +import dns.ipv4 +import dns.ipv6 +import dns.name import dns.rdata import dns.wire @@ -444,7 +447,7 @@ def from_wire_parser( class CookieOption(Option): def __init__(self, client: bytes, server: bytes): - super().__init__(dns.edns.OptionType.COOKIE) + super().__init__(OptionType.COOKIE) self.client = client self.server = server if len(client) != 8: diff --git a/dns/grange.py b/dns/grange.py index a967ca41c..8d366dc8d 100644 --- a/dns/grange.py +++ b/dns/grange.py @@ -19,7 +19,7 @@ from typing import Tuple -import dns +import dns.exception def from_text(text: str) -> Tuple[int, int, int]: diff --git a/dns/name.py b/dns/name.py index f79f0d0f6..4861e11cf 100644 --- a/dns/name.py +++ b/dns/name.py @@ -30,6 +30,11 @@ import dns.immutable import dns.wire +# Dnspython will never access idna if the import fails, but pyright can't figure +# that out, so... +# +# pyright: reportAttributeAccessIssue = false, reportPossiblyUnboundVariable = false + if dns._features.have("idna"): import idna # type: ignore @@ -37,6 +42,7 @@ else: # pragma: no cover have_idna_2008 = False + CompressType = Dict["Name", int] diff --git a/dns/rdtypes/util.py b/dns/rdtypes/util.py index 653a0bf2e..defb8011b 100644 --- a/dns/rdtypes/util.py +++ b/dns/rdtypes/util.py @@ -18,13 +18,16 @@ import collections import random import struct -from typing import Any, List +from typing import Any, Iterable, List, Optional, Tuple, Union import dns.exception import dns.ipv4 import dns.ipv6 import dns.name import dns.rdata +import dns.rdatatype +import dns.tokenizer +import dns.wire class Gateway: @@ -32,7 +35,7 @@ class Gateway: name = "" - def __init__(self, type, gateway=None): + def __init__(self, type, gateway: Optional[Union[str, dns.name.Name]] = None): self.type = dns.rdata.Rdata._as_uint8(type) self.gateway = gateway self._check() @@ -48,9 +51,11 @@ def _check(self): self.gateway = None elif self.type == 1: # check that it's OK + assert isinstance(self.gateway, str) dns.ipv4.inet_aton(self.gateway) elif self.type == 2: # check that it's OK + assert isinstance(self.gateway, str) dns.ipv6.inet_aton(self.gateway) elif self.type == 3: if not isinstance(self.gateway, dns.name.Name): @@ -64,6 +69,7 @@ def to_text(self, origin=None, relativize=True): elif self.type in (1, 2): return self.gateway elif self.type == 3: + assert isinstance(self.gateway, dns.name.Name) return str(self.gateway.choose_relativity(origin, relativize)) else: raise ValueError(self._invalid_type(self.type)) # pragma: no cover @@ -87,10 +93,13 @@ def to_wire(self, file, compress=None, origin=None, canonicalize=False): if self.type == 0: pass elif self.type == 1: + assert isinstance(self.gateway, str) file.write(dns.ipv4.inet_aton(self.gateway)) elif self.type == 2: + assert isinstance(self.gateway, str) file.write(dns.ipv6.inet_aton(self.gateway)) elif self.type == 3: + assert isinstance(self.gateway, dns.name.Name) self.gateway.to_wire(file, None, origin, False) else: raise ValueError(self._invalid_type(self.type)) # pragma: no cover @@ -117,9 +126,11 @@ class Bitmap: type_name = "" - def __init__(self, windows=None): + def __init__(self, windows: Optional[Iterable[Tuple[int, bytes]]] = None): last_window = -1 - self.windows = windows + if windows is None: + windows = [] + self.windows = tuple(windows) for window, bitmap in self.windows: if not isinstance(window, int): raise ValueError(f"bad {self.type_name} window type") @@ -140,7 +151,7 @@ def to_text(self) -> str: for i, byte in enumerate(bitmap): for j in range(0, 8): if byte & (0x80 >> j): - rdtype = window * 256 + i * 8 + j + rdtype = dns.rdatatype.RdataType.make(window * 256 + i * 8 + j) bits.append(dns.rdatatype.to_text(rdtype)) text += " " + " ".join(bits) return text @@ -236,9 +247,10 @@ def weighted_processing_order(iterable): if weight > r: break r -= weight - total -= weight - ordered.append(rdata) # pylint: disable=undefined-loop-variable - del rdatas[n] # pylint: disable=undefined-loop-variable + total -= weight # pyright: ignore[reportPossiblyUnboundVariable] + # pylint: disable=undefined-loop-variable + ordered.append(rdata) # pyright: ignore[reportPossiblyUnboundVariable] + del rdatas[n] # pyright: ignore[reportPossiblyUnboundVariable] ordered.append(rdatas[0]) return ordered diff --git a/dns/reversename.py b/dns/reversename.py index 8236c711f..dc5f33e35 100644 --- a/dns/reversename.py +++ b/dns/reversename.py @@ -19,6 +19,7 @@ import binascii +import dns.exception import dns.ipv4 import dns.ipv6 import dns.name diff --git a/dns/update.py b/dns/update.py index bf1157acd..a1834b59e 100644 --- a/dns/update.py +++ b/dns/update.py @@ -19,6 +19,8 @@ from typing import Any, List, Optional, Union +import dns.enum +import dns.exception import dns.message import dns.name import dns.opcode diff --git a/pyproject.toml b/pyproject.toml index 2a4d045c6..10e860e37 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -116,3 +116,8 @@ ignore_missing_imports = true [[tool.mypy.overrides]] module = "wmi" ignore_missing_imports = true + +[tool.pyright] +reportUnsupportedDunderAll = false +# temporary! +exclude = ["dns/rdtypes/ANY/*.py", "dns/rdtypes/CH/*.py", "dns/rdtypes/IN/*.py"] From c5b9038921e0413305670e117967130edf6216e2 Mon Sep 17 00:00:00 2001 From: Bob Halley Date: Mon, 7 Oct 2024 04:11:20 -0700 Subject: [PATCH 02/12] add pyright rule --- Makefile | 3 +++ 1 file changed, 3 insertions(+) diff --git a/Makefile b/Makefile index 11edb489a..ecb0115e5 100644 --- a/Makefile +++ b/Makefile @@ -38,6 +38,9 @@ check: test type: python -m mypy --install-types --non-interactive --disallow-incomplete-defs dns +pyright: + pyright dns + lint: pylint dns From ad238dfae6431a4cfc12b75f6d27eba9d94d781f Mon Sep 17 00:00:00 2001 From: Bob Halley Date: Sat, 12 Oct 2024 10:19:12 -0700 Subject: [PATCH 03/12] checkpoint more linting --- dns/ipv4.py | 2 +- dns/ipv6.py | 2 +- dns/message.py | 9 ++++++-- dns/renderer.py | 4 ++++ dns/resolver.py | 59 +++++++++++++++++++++++++++---------------------- 5 files changed, 46 insertions(+), 30 deletions(-) diff --git a/dns/ipv4.py b/dns/ipv4.py index 65ee69c0d..21f529614 100644 --- a/dns/ipv4.py +++ b/dns/ipv4.py @@ -74,4 +74,4 @@ def canonicalize(text: Union[str, bytes]) -> str: """ # Note that inet_aton() only accepts canonial form, but we still run through # inet_ntoa() to ensure the output is a str. - return dns.ipv4.inet_ntoa(dns.ipv4.inet_aton(text)) + return inet_ntoa(inet_aton(text)) diff --git a/dns/ipv6.py b/dns/ipv6.py index 4dd1d1cad..4f27b415a 100644 --- a/dns/ipv6.py +++ b/dns/ipv6.py @@ -214,4 +214,4 @@ def canonicalize(text: Union[str, bytes]) -> str: Raises ``dns.exception.SyntaxError`` if the text is not valid. """ - return dns.ipv6.inet_ntoa(dns.ipv6.inet_aton(text)) + return inet_ntoa(inet_aton(text)) diff --git a/dns/message.py b/dns/message.py index e978a0a2e..568799572 100644 --- a/dns/message.py +++ b/dns/message.py @@ -38,6 +38,7 @@ import dns.rdtypes.ANY.TSIG import dns.renderer import dns.rrset +import dns.tokenizer import dns.tsig import dns.ttl import dns.wire @@ -1091,7 +1092,7 @@ def _message_factory_from_opcode(opcode): return QueryMessage elif opcode == dns.opcode.UPDATE: _maybe_import_update() - return dns.update.UpdateMessage + return dns.update.UpdateMessage # pyright: ignore else: return Message @@ -1421,7 +1422,7 @@ def __init__( relativize=True, relativize_to=None, ): - self.message = None + self.message: Optional[Message] = None self.tok = dns.tokenizer.Tokenizer(text, idna_codec=idna_codec) self.last_name = None self.one_rr_per_rrset = one_rr_per_rrset @@ -1480,6 +1481,7 @@ def _header_line(self, _): def _question_line(self, section_number): """Process one line from the text format question section.""" + assert self.message is not None section = self.message.sections[section_number] token = self.tok.get(want_leading=True) if not token.is_whitespace(): @@ -1517,6 +1519,7 @@ def _rr_line(self, section_number): additional data sections. """ + assert self.message is not None section = self.message.sections[section_number] # Name token = self.tok.get(want_leading=True) @@ -1910,6 +1913,8 @@ def make_response( pad = 468 response.use_edns(0, 0, our_payload, query.payload, pad=pad) if query.had_tsig: + assert query.mac is not None + assert query.keyalgorithm is not None response.use_tsig( query.keyring, query.keyname, diff --git a/dns/renderer.py b/dns/renderer.py index a77481f67..1b9903985 100644 --- a/dns/renderer.py +++ b/dns/renderer.py @@ -23,7 +23,11 @@ import struct import time +import dns.edns import dns.exception +import dns.message +import dns.rdataclass +import dns.rdatatype import dns.tsig QUESTION = 0 diff --git a/dns/resolver.py b/dns/resolver.py index af90dd8f1..1e23c0c86 100644 --- a/dns/resolver.py +++ b/dns/resolver.py @@ -24,7 +24,7 @@ import threading import time import warnings -from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple, Union +from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple, Union, cast from urllib.parse import urlparse import dns._ddr @@ -42,6 +42,7 @@ import dns.rdata import dns.rdataclass import dns.rdatatype +import dns.rdtypes.ANY.PTR import dns.rdtypes.svcbbase import dns.reversename import dns.tsig @@ -63,7 +64,7 @@ class NXDOMAIN(dns.exception.DNSException): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - def _check_kwargs(self, qnames, responses=None): + def _check_kwargs(self, qnames, responses=None): # pyright: ignore if not isinstance(qnames, (list, tuple, set)): raise AttributeError("qnames must be a list, tuple or set") if len(qnames) == 0: @@ -282,24 +283,25 @@ def __init__( self.expiration = time.time() + self.chaining_result.minimum_ttl def __getattr__(self, attr): # pragma: no cover - if attr == "name": - return self.rrset.name - elif attr == "ttl": - return self.rrset.ttl - elif attr == "covers": - return self.rrset.covers - elif attr == "rdclass": - return self.rrset.rdclass - elif attr == "rdtype": - return self.rrset.rdtype + if self.rrset is not None: + if attr == "name": + return self.rrset.name + elif attr == "ttl": + return self.rrset.ttl + elif attr == "covers": + return self.rrset.covers + elif attr == "rdclass": + return self.rrset.rdclass + elif attr == "rdtype": + return self.rrset.rdtype else: raise AttributeError(attr) def __len__(self) -> int: - return self.rrset and len(self.rrset) or 0 + return self.rrset is not None and len(self.rrset) or 0 def __iter__(self) -> Iterator[Any]: - return self.rrset and iter(self.rrset) or iter(tuple()) + return self.rrset is not None and iter(self.rrset) or iter(tuple()) def __getitem__(self, i): if self.rrset is None: @@ -1480,7 +1482,7 @@ def canonical_name(self, name: Union[dns.name.Name, str]) -> dns.name.Name: try: answer = self.resolve(name, raise_on_no_answer=False) canonical_name = answer.canonical_name - except dns.resolver.NXDOMAIN as e: + except NXDOMAIN as e: canonical_name = e.canonical_name return canonical_name @@ -1655,7 +1657,7 @@ def zone_for_name( tcp: bool = False, resolver: Optional[Resolver] = None, lifetime: Optional[float] = None, -) -> dns.name.Name: +) -> dns.name.Name: # pyright: ignore[reportReturnType] """Find the name of the zone which contains the specified name. *name*, an absolute ``dns.name.Name`` or ``str``, the query name. @@ -1709,8 +1711,8 @@ def zone_for_name( if answer.rrset.name == name: return name # otherwise we were CNAMEd or DNAMEd and need to look higher - except (dns.resolver.NXDOMAIN, dns.resolver.NoAnswer) as e: - if isinstance(e, dns.resolver.NXDOMAIN): + except (NXDOMAIN, NoAnswer) as e: + if isinstance(e, NXDOMAIN): response = e.responses().get(name) else: response = e.response() # pylint: disable=no-value-for-parameter @@ -1765,7 +1767,7 @@ def make_resolver_at( else: for address in resolver.resolve_name(where, family).addresses(): nameservers.append(dns.nameserver.Do53Nameserver(address, port)) - res = dns.resolver.Resolver(configure=False) + res = Resolver(configure=False) res.nameservers = nameservers return res @@ -1816,12 +1818,12 @@ def resolve_at( # running process. # -_protocols_for_socktype = { +_protocols_for_socktype: Dict[Any, List[Any]] = { socket.SOCK_DGRAM: [socket.SOL_UDP], socket.SOCK_STREAM: [socket.SOL_TCP], } -_resolver = None +_resolver: Optional[Resolver] = None _original_getaddrinfo = socket.getaddrinfo _original_getnameinfo = socket.getnameinfo _original_getfqdn = socket.getfqdn @@ -1870,10 +1872,11 @@ def _getaddrinfo( pass # Something needs resolution! try: + assert _resolver is not None answers = _resolver.resolve_name(host, family) addrs = answers.addresses_and_families() canonical_name = answers.canonical_name().to_text(True) - except dns.resolver.NXDOMAIN: + except NXDOMAIN: raise socket.gaierror(socket.EAI_NONAME, "Name or service not known") except Exception: # We raise EAI_AGAIN here as the failure may be temporary @@ -1890,7 +1893,7 @@ def _getaddrinfo( except Exception: if flags & socket.AI_NUMERICSERV == 0: try: - port = socket.getservbyname(service) + port = socket.getservbyname(service) # pyright: ignore except Exception: pass if port is None: @@ -1906,7 +1909,8 @@ def _getaddrinfo( cname = "" for addr, af in addrs: for socktype in socktypes: - for proto in _protocols_for_socktype[socktype]: + for sockproto in _protocols_for_socktype[socktype]: + proto = int(sockproto) addr_tuple = dns.inet.low_level_address_tuple((addr, port), af) tuples.append((af, socktype, proto, cname, addr_tuple)) if len(tuples) == 0: @@ -1934,9 +1938,12 @@ def _getnameinfo(sockaddr, flags=0): qname = dns.reversename.from_address(addr) if flags & socket.NI_NUMERICHOST == 0: try: + assert _resolver is not None answer = _resolver.resolve(qname, "PTR") - hostname = answer.rrset[0].target.to_text(True) - except (dns.resolver.NXDOMAIN, dns.resolver.NoAnswer): + assert answer.rrset is not None + rdata = cast(dns.rdtypes.ANY.PTR.PTR, answer.rrset[0]) + hostname = rdata.target.to_text(True) + except (NXDOMAIN, NoAnswer): if flags & socket.NI_NAMEREQD: raise socket.gaierror(socket.EAI_NONAME, "Name or service not known") hostname = addr From 4975e04a4c423cc9edbee9975ffb753f783953f2 Mon Sep 17 00:00:00 2001 From: Bob Halley Date: Sat, 12 Oct 2024 10:22:56 -0700 Subject: [PATCH 04/12] lint transaction.py --- dns/transaction.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/dns/transaction.py b/dns/transaction.py index aa2e11603..bcdda9e0d 100644 --- a/dns/transaction.py +++ b/dns/transaction.py @@ -6,6 +6,7 @@ import dns.exception import dns.name import dns.node +import dns.rdata import dns.rdataclass import dns.rdataset import dns.rdatatype @@ -416,12 +417,12 @@ def _rdataset_from_args(self, method, deleting, args): raise TypeError(f"{method}: expected more arguments") def _add(self, replace, args): + if replace: + method = "replace()" + else: + method = "add()" try: args = collections.deque(args) - if replace: - method = "replace()" - else: - method = "add()" arg = args.popleft() if isinstance(arg, str): arg = dns.name.from_text(arg, None) @@ -438,6 +439,7 @@ def _add(self, replace, args): raise TypeError( f"{method} requires a name or RRset as the first argument" ) + assert rdataset is not None # for type checkers if rdataset.rdclass != self.manager.get_class(): raise ValueError(f"{method} has objects of wrong RdataClass") if rdataset.rdtype == dns.rdatatype.SOA: @@ -460,12 +462,12 @@ def _add(self, replace, args): raise TypeError(f"not enough parameters to {method}") def _delete(self, exact, args): + if exact: + method = "delete_exact()" + else: + method = "delete()" try: args = collections.deque(args) - if exact: - method = "delete_exact()" - else: - method = "delete()" arg = args.popleft() if isinstance(arg, str): arg = dns.name.from_text(arg, None) From 12c31746fca1482b6c3e932c0227c4545cacec1b Mon Sep 17 00:00:00 2001 From: Bob Halley Date: Sat, 12 Oct 2024 10:24:07 -0700 Subject: [PATCH 05/12] lint tsig.py --- dns/tsig.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/dns/tsig.py b/dns/tsig.py index 780852e8e..8dee78dae 100644 --- a/dns/tsig.py +++ b/dns/tsig.py @@ -26,6 +26,7 @@ import dns.name import dns.rcode import dns.rdataclass +import dns.rdatatype class BadTime(dns.exception.DNSException): @@ -221,6 +222,7 @@ def _digest(wire, key, rdata, time=None, request_mac=None, ctx=None, multi=None) if request_mac: ctx.update(struct.pack("!H", len(request_mac))) ctx.update(request_mac) + assert ctx is not None # for type checkers ctx.update(struct.pack("!H", rdata.original_id)) ctx.update(wire[2:]) if first: From c908a8e81c9c345d2ad2fc75ad11f19b4ece2b00 Mon Sep 17 00:00:00 2001 From: Bob Halley Date: Sat, 12 Oct 2024 10:34:52 -0700 Subject: [PATCH 06/12] rdata type linting --- dns/rdata.py | 23 +++++++++++------------ dns/rdtypes/txtbase.py | 3 +++ 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/dns/rdata.py b/dns/rdata.py index 0189f2409..bcdac094e 100644 --- a/dns/rdata.py +++ b/dns/rdata.py @@ -210,7 +210,7 @@ def to_text( def _to_wire( self, - file: Optional[Any], + file: Any, compress: Optional[dns.name.CompressType] = None, origin: Optional[dns.name.Name] = None, canonicalize: bool = False, @@ -241,16 +241,12 @@ def to_wire( self._to_wire(f, compress, origin, canonicalize) return f.getvalue() - def to_generic( - self, origin: Optional[dns.name.Name] = None - ) -> "dns.rdata.GenericRdata": + def to_generic(self, origin: Optional[dns.name.Name] = None) -> "GenericRdata": """Creates a dns.rdata.GenericRdata equivalent of this rdata. Returns a ``dns.rdata.GenericRdata``. """ - return dns.rdata.GenericRdata( - self.rdclass, self.rdtype, self.to_wire(origin=origin) - ) + return GenericRdata(self.rdclass, self.rdtype, self.to_wire(origin=origin)) def to_digestable(self, origin: Optional[dns.name.Name] = None) -> bytes: """Convert rdata to a format suitable for digesting in hashes. This @@ -298,6 +294,9 @@ def _cmp(self, other): In the future, all ordering comparisons for rdata with relative names will be disallowed. """ + # the next two lines are for type checkers, so they are bound + our = b"" + their = b"" try: our = self.to_digestable() our_relative = False @@ -620,7 +619,7 @@ def to_text( relativize: bool = True, **kw: Dict[str, Any], ) -> str: - return r"\# %d " % len(self.data) + _hexify(self.data, **kw) + return r"\# %d " % len(self.data) + _hexify(self.data, **kw) # pyright: ignore @classmethod def from_text( @@ -639,9 +638,7 @@ def from_text( def _to_wire(self, file, compress=None, origin=None, canonicalize=False): file.write(self.data) - def to_generic( - self, origin: Optional[dns.name.Name] = None - ) -> "dns.rdata.GenericRdata": + def to_generic(self, origin: Optional[dns.name.Name] = None) -> "GenericRdata": return self @classmethod @@ -659,7 +656,7 @@ def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): def get_rdata_class(rdclass, rdtype, use_generic=True): cls = _rdata_classes.get((rdclass, rdtype)) if not cls: - cls = _rdata_classes.get((dns.rdatatype.ANY, rdtype)) + cls = _rdata_classes.get((dns.rdataclass.ANY, rdtype)) if not cls and _dynamic_load_allowed: rdclass_text = dns.rdataclass.to_text(rdclass) rdtype_text = dns.rdatatype.to_text(rdtype) @@ -758,6 +755,7 @@ def from_text( rdclass = dns.rdataclass.RdataClass.make(rdclass) rdtype = dns.rdatatype.RdataType.make(rdtype) cls = get_rdata_class(rdclass, rdtype) + assert cls is not None # for type checkers with dns.exception.ExceptionWrapper(dns.exception.SyntaxError): rdata = None if cls != GenericRdata: @@ -830,6 +828,7 @@ def from_wire_parser( rdclass = dns.rdataclass.RdataClass.make(rdclass) rdtype = dns.rdatatype.RdataType.make(rdtype) cls = get_rdata_class(rdclass, rdtype) + assert cls is not None # for type checkers with dns.exception.ExceptionWrapper(dns.exception.FormError): return cls.from_wire_parser(rdclass, rdtype, parser, origin) diff --git a/dns/rdtypes/txtbase.py b/dns/rdtypes/txtbase.py index 73db6d9e2..6ecdd35fc 100644 --- a/dns/rdtypes/txtbase.py +++ b/dns/rdtypes/txtbase.py @@ -21,7 +21,10 @@ import dns.exception import dns.immutable +import dns.name import dns.rdata +import dns.rdataclass +import dns.rdatatype import dns.renderer import dns.tokenizer From 1728dd494468d93d4a1c3fae0d68ded2c88f5b88 Mon Sep 17 00:00:00 2001 From: Bob Halley Date: Sat, 12 Oct 2024 14:53:49 -0700 Subject: [PATCH 07/12] checkpoint more linting --- dns/_immutable_ctx.py | 2 +- dns/asyncresolver.py | 5 ++++- dns/e164.py | 2 +- dns/edns.py | 6 +++--- dns/entropy.py | 6 +++--- dns/enum.py | 10 ++++++---- dns/message.py | 32 ++++++++++++++++++++++---------- dns/opcode.py | 4 +++- dns/rcode.py | 4 ++-- dns/tsig.py | 8 +++++++- dns/ttl.py | 2 +- dns/zonefile.py | 24 +++++++++++++++++++----- pyproject.toml | 2 +- 13 files changed, 73 insertions(+), 34 deletions(-) diff --git a/dns/_immutable_ctx.py b/dns/_immutable_ctx.py index ae7a33bf3..b3d72deef 100644 --- a/dns/_immutable_ctx.py +++ b/dns/_immutable_ctx.py @@ -41,7 +41,7 @@ def nf(*args, **kwargs): finally: _in__init__.reset(previous) - nf.__signature__ = inspect.signature(f) + nf.__signature__ = inspect.signature(f) # pyright: ignore return nf diff --git a/dns/asyncresolver.py b/dns/asyncresolver.py index 8f5e062a9..1df89e6ca 100644 --- a/dns/asyncresolver.py +++ b/dns/asyncresolver.py @@ -25,11 +25,14 @@ import dns.asyncbackend import dns.asyncquery import dns.exception +import dns.inet import dns.name +import dns.nameserver import dns.query import dns.rdataclass import dns.rdatatype import dns.resolver # lgtm[py/import-and-import-from] +import dns.reversename # import some resolver symbols for brevity from dns.resolver import NXDOMAIN, NoAnswer, NoRootSOA, NotAbsolute @@ -426,7 +429,7 @@ async def make_resolver_at( answers = await resolver.resolve_name(where, family) for address in answers.addresses(): nameservers.append(dns.nameserver.Do53Nameserver(address, port)) - res = dns.asyncresolver.Resolver(configure=False) + res = Resolver(configure=False) res.nameservers = nameservers return res diff --git a/dns/e164.py b/dns/e164.py index 453736d40..dd9aebc8a 100644 --- a/dns/e164.py +++ b/dns/e164.py @@ -108,7 +108,7 @@ def query( for domain in domains: if isinstance(domain, str): domain = dns.name.from_text(domain) - qname = dns.e164.from_e164(number, domain) + qname = from_e164(number, domain) try: return resolver.resolve(qname, "NAPTR") except dns.resolver.NXDOMAIN as e: diff --git a/dns/edns.py b/dns/edns.py index ad9a07a8e..8db1d2e0f 100644 --- a/dns/edns.py +++ b/dns/edns.py @@ -84,14 +84,14 @@ def to_wire(self, file: Optional[Any] = None) -> Optional[bytes]: def to_text(self) -> str: raise NotImplementedError # pragma: no cover - def to_generic(self) -> "dns.edns.GenericOption": + def to_generic(self) -> "GenericOption": """Creates a dns.edns.GenericOption equivalent of this rdata. Returns a ``dns.edns.GenericOption``. """ wire = self.to_wire() assert wire is not None # for mypy - return dns.edns.GenericOption(self.otype, wire) + return GenericOption(self.otype, wire) @classmethod def from_wire_parser(cls, otype: OptionType, parser: "dns.wire.Parser") -> "Option": @@ -178,7 +178,7 @@ def to_wire(self, file: Optional[Any] = None) -> Optional[bytes]: def to_text(self) -> str: return "Generic %d" % self.otype - def to_generic(self) -> "dns.edns.GenericOption": + def to_generic(self) -> "GenericOption": return self @classmethod diff --git a/dns/entropy.py b/dns/entropy.py index 4dcdc6272..45e79e3dd 100644 --- a/dns/entropy.py +++ b/dns/entropy.py @@ -20,7 +20,7 @@ import random import threading import time -from typing import Any, Optional +from typing import Any, Optional, Union class EntropyPool: @@ -45,7 +45,7 @@ def __init__(self, seed: Optional[bytes] = None): self.seeded = False self.seed_pid = 0 - def _stir(self, entropy: bytes) -> None: + def _stir(self, entropy: Union[bytes, bytearray]) -> None: for c in entropy: if self.pool_index == self.hash_len: self.pool_index = 0 @@ -53,7 +53,7 @@ def _stir(self, entropy: bytes) -> None: self.pool[self.pool_index] ^= b self.pool_index += 1 - def stir(self, entropy: bytes) -> None: + def stir(self, entropy: Union[bytes, bytearray]) -> None: with self.lock: self._stir(entropy) diff --git a/dns/enum.py b/dns/enum.py index 71461f177..72e3e4aef 100644 --- a/dns/enum.py +++ b/dns/enum.py @@ -18,6 +18,8 @@ import enum from typing import Type, TypeVar, Union +import dns.exception + TIntEnum = TypeVar("TIntEnum", bound="IntEnum") @@ -25,9 +27,9 @@ class IntEnum(enum.IntEnum): @classmethod def _missing_(cls, value): cls._check_value(value) - val = int.__new__(cls, value) + val = int.__new__(cls, value) # pyright: ignore val._name_ = cls._extra_to_text(value, None) or f"{cls._prefix()}{value}" - val._value_ = value + val._value_ = value # pyright: ignore return val @classmethod @@ -56,7 +58,7 @@ def from_text(cls: Type[TIntEnum], text: str) -> TIntEnum: try: return cls(value) except ValueError: - return value + return value # pyright: ignore raise cls._unknown_exception_class() @classmethod @@ -112,5 +114,5 @@ def _extra_to_text(cls, value, current_text): # pylint: disable=W0613 return current_text @classmethod - def _unknown_exception_class(cls): + def _unknown_exception_class(cls) -> Type[Exception]: return ValueError diff --git a/dns/message.py b/dns/message.py index 568799572..c89a5c110 100644 --- a/dns/message.py +++ b/dns/message.py @@ -35,6 +35,7 @@ import dns.rdataclass import dns.rdatatype import dns.rdtypes.ANY.OPT +import dns.rdtypes.ANY.SOA import dns.rdtypes.ANY.TSIG import dns.renderer import dns.rrset @@ -530,7 +531,8 @@ def _compute_opt_reserve(self) -> int: # worry about that for now. We also don't worry if there is an existing padding # option, as it is unlikely and probably harmless, as the worst case is that we # may add another, and this seems to be legal. - for option in self.opt[0].options: + opt_rdata = cast(dns.rdtypes.ANY.OPT.OPT, self.opt[0]) + for option in opt_rdata.options: wire = option.to_wire() # We add 4 here to account for the option type and length size += len(wire) + 4 @@ -754,21 +756,24 @@ def keyname(self) -> Optional[dns.name.Name]: @property def keyalgorithm(self) -> Optional[dns.name.Name]: if self.tsig: - return self.tsig[0].algorithm + rdata = cast(dns.rdtypes.ANY.TSIG.TSIG, self.tsig[0]) + return rdata.algorithm else: return None @property def mac(self) -> Optional[bytes]: if self.tsig: - return self.tsig[0].mac + rdata = cast(dns.rdtypes.ANY.TSIG.TSIG, self.tsig[0]) + return rdata.mac else: return None @property def tsig_error(self) -> Optional[int]: if self.tsig: - return self.tsig[0].error + rdata = cast(dns.rdtypes.ANY.TSIG.TSIG, self.tsig[0]) + return rdata.error else: return None @@ -858,14 +863,16 @@ def ednsflags(self, v): @property def payload(self) -> int: if self.opt: - return self.opt[0].payload + rdata = cast(dns.rdtypes.ANY.OPT.OPT, self.opt[0]) + return rdata.payload else: return 0 @property def options(self) -> Tuple: if self.opt: - return self.opt[0].options + rdata = cast(dns.rdtypes.ANY.OPT.OPT, self.opt[0]) + return rdata.options else: return () @@ -1052,7 +1059,8 @@ def resolve_chaining(self) -> ChainingResult: srrset = self.find_rrset( self.authority, auname, question.rdclass, dns.rdatatype.SOA ) - min_ttl = min(min_ttl, srrset.ttl, srrset[0].minimum) + srdata = cast(dns.rdtypes.ANY.SOA.SOA, srrset[0]) + min_ttl = min(min_ttl, srrset.ttl, srdata.minimum) break except KeyError: try: @@ -1196,7 +1204,10 @@ def _get_section(self, section_number, count): else: with self.parser.restrict_to(rdlen): rd = dns.rdata.from_wire_parser( - rdclass, rdtype, self.parser, self.message.origin + rdclass, # pyright: ignore + rdtype, + self.parser, + self.message.origin, ) covers = rd.covers() if self.message.xfr and rdtype == dns.rdatatype.SOA: @@ -1204,12 +1215,13 @@ def _get_section(self, section_number, count): if rdtype == dns.rdatatype.OPT: self.message.opt = dns.rrset.from_rdata(name, ttl, rd) elif rdtype == dns.rdatatype.TSIG: + trd = cast(dns.rdtypes.ANY.TSIG.TSIG, rd) if self.keyring is None or self.keyring is True: raise UnknownTSIGKey("got signed message without keyring") elif isinstance(self.keyring, dict): key = self.keyring.get(absolute_name) if isinstance(key, bytes): - key = dns.tsig.Key(absolute_name, key, rd.algorithm) + key = dns.tsig.Key(absolute_name, key, trd.algorithm) elif callable(self.keyring): key = self.keyring(self.message, absolute_name) else: @@ -1234,7 +1246,7 @@ def _get_section(self, section_number, count): rrset = self.message.find_rrset( section, name, - rdclass, + rdclass, # pyright: ignore rdtype, covers, deleting, diff --git a/dns/opcode.py b/dns/opcode.py index 78b43d2cb..3fa610d04 100644 --- a/dns/opcode.py +++ b/dns/opcode.py @@ -17,6 +17,8 @@ """DNS Opcodes.""" +from typing import Type + import dns.enum import dns.exception @@ -38,7 +40,7 @@ def _maximum(cls): return 15 @classmethod - def _unknown_exception_class(cls): + def _unknown_exception_class(cls) -> Type[Exception]: return UnknownOpcode diff --git a/dns/rcode.py b/dns/rcode.py index 8e6386f82..7bb8467e2 100644 --- a/dns/rcode.py +++ b/dns/rcode.py @@ -17,7 +17,7 @@ """DNS Result Codes.""" -from typing import Tuple +from typing import Tuple, Type import dns.enum import dns.exception @@ -72,7 +72,7 @@ def _maximum(cls): return 4095 @classmethod - def _unknown_exception_class(cls): + def _unknown_exception_class(cls) -> Type[Exception]: return UnknownRcode diff --git a/dns/tsig.py b/dns/tsig.py index 8dee78dae..18640a817 100644 --- a/dns/tsig.py +++ b/dns/tsig.py @@ -21,6 +21,7 @@ import hashlib import hmac import struct +from typing import Union import dns.exception import dns.name @@ -327,7 +328,12 @@ def get_context(key): class Key: - def __init__(self, name, secret, algorithm=default_algorithm): + def __init__( + self, + name: Union[dns.name.Name, str], + secret: Union[bytes, str], + algorithm: Union[dns.name.Name, str] = default_algorithm, + ): if isinstance(name, str): name = dns.name.from_text(name) self.name = name diff --git a/dns/ttl.py b/dns/ttl.py index b9a99fe3c..06c11eeff 100644 --- a/dns/ttl.py +++ b/dns/ttl.py @@ -87,6 +87,6 @@ def make(value: Union[int, str]) -> int: if isinstance(value, int): return value elif isinstance(value, str): - return dns.ttl.from_text(value) + return from_text(value) else: raise ValueError("cannot convert value to TTL") diff --git a/dns/zonefile.py b/dns/zonefile.py index d74510b29..a6a8531b2 100644 --- a/dns/zonefile.py +++ b/dns/zonefile.py @@ -19,7 +19,7 @@ import re import sys -from typing import Any, Iterable, List, Optional, Set, Tuple, Union +from typing import Any, Iterable, List, Optional, Set, Tuple, Union, cast import dns.exception import dns.grange @@ -169,6 +169,9 @@ def _rr_line(self): return self.tok.unget(token) name = self.last_name + if name is None: + raise dns.exception.SyntaxError("the last used name is undefined") + assert self.zone_origin is not None if not name.is_subdomain(self.zone_origin): self._eat_line() return @@ -257,11 +260,12 @@ def _rr_line(self): # The pre-RFC2308 and pre-BIND9 behavior inherits the zone default # TTL from the SOA minttl if no $TTL statement is present before the # SOA is parsed. - self.default_ttl = rd.minimum + soa_rd = cast(dns.rdtypes.ANY.SOA.SOA, rd) + self.default_ttl = soa_rd.minimum self.default_ttl_known = True if ttl is None: # if we didn't have a TTL on the SOA, set it! - ttl = rd.minimum + ttl = soa_rd.minimum # TTL check. We had to wait until now to do this as the SOA RR's # own TTL can be inferred from its minimum. @@ -356,6 +360,12 @@ def _generate_line(self): ttl = self.default_ttl elif self.last_ttl_known: ttl = self.last_ttl + else: + # We don't go to the extra "look at the SOA" level of effort for + # $GENERATE, because the user really ought to have defined a TTL + # somehow! + raise dns.exception.SyntaxError("Missing default TTL value") + # Class try: rdclass = dns.rdataclass.from_text(token.value) @@ -417,6 +427,7 @@ def _format_index(index: int, base: str, width: int) -> str: name, self.current_origin, self.tok.idna_codec ) name = self.last_name + assert self.zone_origin is not None if not name.is_subdomain(self.zone_origin): self._eat_line() return @@ -606,7 +617,7 @@ def _end_transaction(self, commit): ) rrset.update(rdataset) rrsets.append(rrset) - self.manager.set_rrsets(rrsets) + self.manager.set_rrsets(rrsets) # pyright: ignore def _set_origin(self, origin): pass @@ -620,7 +631,10 @@ def _iterate_names(self): class RRSetsReaderManager(dns.transaction.TransactionManager): def __init__( - self, origin=dns.name.root, relativize=False, rdclass=dns.rdataclass.IN + self, + origin: Optional[dns.name.Name] = dns.name.root, + relativize=False, + rdclass=dns.rdataclass.IN, ): self.origin = origin self.relativize = relativize diff --git a/pyproject.toml b/pyproject.toml index 10e860e37..af218f7e5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -120,4 +120,4 @@ ignore_missing_imports = true [tool.pyright] reportUnsupportedDunderAll = false # temporary! -exclude = ["dns/rdtypes/ANY/*.py", "dns/rdtypes/CH/*.py", "dns/rdtypes/IN/*.py"] +exclude = ["dns/rdtypes/ANY/*.py", "dns/rdtypes/CH/*.py", "dns/rdtypes/IN/*.py", "dns/quic/*.py" ] From a07ade552b58ca79393ab8b2273c51e72a694e53 Mon Sep 17 00:00:00 2001 From: Bob Halley Date: Sun, 13 Oct 2024 11:49:49 -0700 Subject: [PATCH 08/12] checkpoint more linting --- dns/versioned.py | 13 +++++++----- dns/xfr.py | 36 ++++++++++++++++++++------------ dns/zone.py | 53 ++++++++++++++++++++++++++++++++---------------- 3 files changed, 67 insertions(+), 35 deletions(-) diff --git a/dns/versioned.py b/dns/versioned.py index fd78e674e..6479ae47e 100644 --- a/dns/versioned.py +++ b/dns/versioned.py @@ -4,7 +4,7 @@ import collections import threading -from typing import Callable, Deque, Optional, Set, Union +from typing import Callable, Deque, Optional, Set, Union, cast import dns.exception import dns.immutable @@ -105,7 +105,10 @@ def reader( n = v.nodes.get(oname) if n: rds = n.get_rdataset(self.rdclass, dns.rdatatype.SOA) - if rds and rds[0].serial == serial: + if rds is None: + continue + soa = cast(dns.rdtypes.ANY.SOA.SOA, rds[0]) + if rds and soa.serial == serial: version = v break if version is None: @@ -186,7 +189,7 @@ def _prune_versions_unlocked(self): # Note our definition of least_kept also ensures we do not try to # delete the greatest version. if len(self._readers) > 0: - least_kept = min(txn.version.id for txn in self._readers) + least_kept = min(txn.version.id for txn in self._readers) # pyright: ignore else: least_kept = self._versions[-1].id while self._versions[0].id < least_kept and self._pruning_policy( @@ -201,8 +204,8 @@ def set_max_versions(self, max_versions: Optional[int]) -> None: if max_versions is not None and max_versions < 1: raise ValueError("max versions must be at least 1") if max_versions is None: - - def policy(zone, _): # pylint: disable=unused-argument + # pylint: disable=unused-argument + def policy(zone, _): # pyright: ignore return False else: diff --git a/dns/xfr.py b/dns/xfr.py index 520aa32dd..f1b875934 100644 --- a/dns/xfr.py +++ b/dns/xfr.py @@ -15,14 +15,21 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. -from typing import Any, List, Optional, Tuple, Union +from typing import Any, List, Optional, Tuple, Union, cast +import dns.edns import dns.exception import dns.message import dns.name import dns.rcode +import dns.rdata import dns.rdataset import dns.rdatatype +import dns.rdtypes +import dns.rdtypes.ANY +import dns.rdtypes.ANY.SMIMEA +import dns.rdtypes.ANY.SOA +import dns.rdtypes.svcbbase import dns.serial import dns.transaction import dns.tsig @@ -123,14 +130,16 @@ def process_message(self, message: dns.message.Message) -> bool: if rdataset.rdtype != dns.rdatatype.SOA: raise dns.exception.FormError("first RRset is not an SOA") answer_index = 1 - self.soa_rdataset = rdataset.copy() + self.soa_rdataset = rdataset.copy() # pyright: ignore if self.rdtype == dns.rdatatype.IXFR: - if self.soa_rdataset[0].serial == self.serial: + assert self.soa_rdataset is not None + soa = cast(dns.rdtypes.ANY.SOA.SOA, self.soa_rdataset[0]) + if soa.serial == self.serial: # # We're already up-to-date. # self.done = True - elif dns.serial.Serial(self.soa_rdataset[0].serial) < self.serial: + elif dns.serial.Serial(soa.serial) < self.serial: # It went backwards! raise SerialWentBackwards else: @@ -174,13 +183,11 @@ def process_message(self, message: dns.message.Message) -> bool: # # This is the final SOA # + soa = cast(dns.rdtypes.ANY.SOA.SOA, rdataset[0]) if self.expecting_SOA: # We got an empty IXFR sequence! raise dns.exception.FormError("empty IXFR sequence") - if ( - self.rdtype == dns.rdatatype.IXFR - and self.serial != rdataset[0].serial - ): + if self.rdtype == dns.rdatatype.IXFR and self.serial != soa.serial: raise dns.exception.FormError("unexpected end of IXFR sequence") self.txn.replace(name, rdataset) self.txn.commit() @@ -191,16 +198,17 @@ def process_message(self, message: dns.message.Message) -> bool: # This is not the final SOA # self.expecting_SOA = False + soa = cast(dns.rdtypes.ANY.SOA.SOA, rdataset[0]) if self.rdtype == dns.rdatatype.IXFR: if self.delete_mode: # This is the start of an IXFR deletion set - if rdataset[0].serial != self.serial: + if soa.serial != self.serial: raise dns.exception.FormError( "IXFR base serial mismatch" ) else: # This is the start of an IXFR addition set - self.serial = rdataset[0].serial + self.serial = soa.serial self.txn.replace(name, rdataset) else: # We saw a non-final SOA for the origin in an AXFR. @@ -289,7 +297,8 @@ def make_query( with txn_manager.reader() as txn: rdataset = txn.get(origin, "SOA") if rdataset: - serial = rdataset[0].serial + soa = cast(dns.rdtypes.ANY.SOA.SOA, rdataset[0]) + serial = soa.serial rdtype = dns.rdatatype.IXFR else: serial = None @@ -337,7 +346,8 @@ def extract_serial_from_query(query: dns.message.Message) -> Optional[int]: return None elif question.rdtype != dns.rdatatype.IXFR: raise ValueError("query is not an AXFR or IXFR") - soa = query.find_rrset( + soa_rrset = query.find_rrset( query.authority, question.name, question.rdclass, dns.rdatatype.SOA ) - return soa[0].serial + soa = cast(dns.rdtypes.ANY.SOA.SOA, soa_rrset[0]) + return soa.serial diff --git a/dns/zone.py b/dns/zone.py index 844919e41..7cba657d4 100644 --- a/dns/zone.py +++ b/dns/zone.py @@ -32,6 +32,7 @@ Set, Tuple, Union, + cast, ) import dns.exception @@ -698,9 +699,9 @@ def to_file( for n in names: l = self[n].to_text( n, - origin=self.origin, - relativize=relativize, - want_comments=want_comments, + origin=self.origin, # pyright: ignore + relativize=relativize, # pyright: ignore + want_comments=want_comments, # pyright: ignore ) l_b = l.encode(file_enc) @@ -786,14 +787,16 @@ def get_soa( # an SOA if there is no origin. raise NoSOA origin_name = self.origin - soa: Optional[dns.rdataset.Rdataset] + soa_rds: Optional[dns.rdataset.Rdataset] if txn: - soa = txn.get(origin_name, dns.rdatatype.SOA) + soa_rds = txn.get(origin_name, dns.rdatatype.SOA) else: - soa = self.get_rdataset(origin_name, dns.rdatatype.SOA) - if soa is None: + soa_rds = self.get_rdataset(origin_name, dns.rdatatype.SOA) + if soa_rds is None: raise NoSOA - return soa[0] + else: + soa = cast(dns.rdtypes.ANY.SOA.SOA, soa_rds[0]) + return soa def _compute_digest( self, @@ -892,12 +895,12 @@ def _end_read(self, txn): def _end_write(self, txn): pass - def _commit_version(self, _, version, origin): + def _commit_version(self, txn, version, origin): self.nodes = version.nodes if self.origin is None: self.origin = origin - def _get_next_version_id(self): + def _get_next_version_id(self) -> int: # Versions are ephemeral and all have id 1 return 1 @@ -1106,67 +1109,83 @@ def zone(self): def _setup_version(self): assert self.version is None - factory = self.manager.writable_version_factory + factory = self.manager.writable_version_factory # pyright: ignore if factory is None: factory = WritableVersion - self.version = factory(self.zone, self.replacement) + self.version = factory(self.zone, self.replacement) # pyright: ignore def _get_rdataset(self, name, rdtype, covers): + assert self.version is not None return self.version.get_rdataset(name, rdtype, covers) def _put_rdataset(self, name, rdataset): assert not self.read_only + assert self.version is not None self.version.put_rdataset(name, rdataset) def _delete_name(self, name): assert not self.read_only + assert self.version is not None self.version.delete_node(name) def _delete_rdataset(self, name, rdtype, covers): assert not self.read_only + assert self.version is not None self.version.delete_rdataset(name, rdtype, covers) def _name_exists(self, name): + assert self.version is not None return self.version.get_node(name) is not None def _changed(self): if self.read_only: return False else: + assert self.version is not None return len(self.version.changed) > 0 def _end_transaction(self, commit): + assert self.zone is not None + assert self.version is not None if self.read_only: - self.zone._end_read(self) + self.zone._end_read(self) # pyright: ignore elif commit and len(self.version.changed) > 0: if self.make_immutable: - factory = self.manager.immutable_version_factory + factory = self.manager.immutable_version_factory # pyright: ignore if factory is None: factory = ImmutableVersion version = factory(self.version) else: version = self.version - self.zone._commit_version(self, version, self.version.origin) + self.zone._commit_version( # pyright: ignore + self, version, self.version.origin + ) + else: # rollback - self.zone._end_write(self) + self.zone._end_write(self) # pyright: ignore def _set_origin(self, origin): + assert self.version is not None if self.version.origin is None: self.version.origin = origin def _iterate_rdatasets(self): + assert self.version is not None for name, node in self.version.items(): for rdataset in node: yield (name, rdataset) def _iterate_names(self): + assert self.version is not None return self.version.keys() def _get_node(self, name): + assert self.version is not None return self.version.get_node(name) def _origin_information(self): + assert self.version is not None (absolute, relativize, effective) = self.manager.origin_information() if absolute is None and self.version.origin is not None: # No origin has been committed yet, but we've learned one as part of @@ -1214,7 +1233,7 @@ def _from_text( reader.read() except dns.zonefile.UnknownOrigin: # for backwards compatibility - raise dns.zone.UnknownOrigin + raise UnknownOrigin # Now that we're done reading, do some basic checking of the zone. if check_origin: zone.check_origin() From 369c3c552a2c5c3db1e2b120b925b7a3158e8101 Mon Sep 17 00:00:00 2001 From: Bob Halley Date: Sun, 13 Oct 2024 12:16:43 -0700 Subject: [PATCH 09/12] checkpoint more linting --- dns/enum.py | 6 +++--- dns/rdtypes/dnskeybase.py | 2 +- dns/rdtypes/dsbase.py | 4 +++- dns/rdtypes/euibase.py | 5 ++++- dns/rdtypes/svcbbase.py | 14 ++++++++------ dns/rdtypes/tlsabase.py | 2 +- dns/update.py | 3 ++- pyproject.toml | 11 +++++++++-- 8 files changed, 31 insertions(+), 16 deletions(-) diff --git a/dns/enum.py b/dns/enum.py index 72e3e4aef..24942ccc8 100644 --- a/dns/enum.py +++ b/dns/enum.py @@ -16,7 +16,7 @@ # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. import enum -from typing import Type, TypeVar, Union +from typing import Any, Optional, Type, TypeVar, Union import dns.exception @@ -102,11 +102,11 @@ def _short_name(cls): return cls.__name__.lower() @classmethod - def _prefix(cls): + def _prefix(cls) -> str: return "" @classmethod - def _extra_from_text(cls, text): # pylint: disable=W0613 + def _extra_from_text(cls, text) -> Optional[Any]: # pylint: disable=W0613 return None @classmethod diff --git a/dns/rdtypes/dnskeybase.py b/dns/rdtypes/dnskeybase.py index db300f8b1..381fe770d 100644 --- a/dns/rdtypes/dnskeybase.py +++ b/dns/rdtypes/dnskeybase.py @@ -52,7 +52,7 @@ def to_text(self, origin=None, relativize=True, **kw): self.flags, self.protocol, self.algorithm, - dns.rdata._base64ify(self.key, **kw), + dns.rdata._base64ify(self.key, **kw), # pyright: ignore ) @classmethod diff --git a/dns/rdtypes/dsbase.py b/dns/rdtypes/dsbase.py index cd21f026d..a9269d22c 100644 --- a/dns/rdtypes/dsbase.py +++ b/dns/rdtypes/dsbase.py @@ -59,7 +59,9 @@ def to_text(self, origin=None, relativize=True, **kw): self.key_tag, self.algorithm, self.digest_type, - dns.rdata._hexify(self.digest, chunksize=chunksize, **kw), + dns.rdata._hexify( + self.digest, chunksize=chunksize, **kw # pyright: ignore + ), ) @classmethod diff --git a/dns/rdtypes/euibase.py b/dns/rdtypes/euibase.py index a39c166b9..4eb82eb5e 100644 --- a/dns/rdtypes/euibase.py +++ b/dns/rdtypes/euibase.py @@ -16,6 +16,7 @@ import binascii +import dns.exception import dns.immutable import dns.rdata @@ -27,7 +28,9 @@ class EUIBase(dns.rdata.Rdata): # see: rfc7043.txt __slots__ = ["eui"] - # define these in subclasses + # redefine these in subclasses + byte_len = 0 + text_len = 0 # byte_len = 6 # 0123456789ab (in hex) # text_len = byte_len * 3 - 1 # 01-23-45-67-89-ab diff --git a/dns/rdtypes/svcbbase.py b/dns/rdtypes/svcbbase.py index a2b15b922..bcde5cbbb 100644 --- a/dns/rdtypes/svcbbase.py +++ b/dns/rdtypes/svcbbase.py @@ -3,6 +3,7 @@ import base64 import enum import struct +from typing import Any, Dict import dns.enum import dns.exception @@ -97,9 +98,9 @@ def _escapify(qstring): return text -def _unescape(value): +def _unescape(value: str) -> bytes: if value == "": - return value + return b"" unescaped = b"" l = len(value) i = 0 @@ -159,7 +160,7 @@ class Param: """Abstract base class for SVCB parameters""" @classmethod - def emptiness(cls): + def emptiness(cls) -> Emptiness: return Emptiness.NEVER @@ -427,7 +428,7 @@ def to_wire(self, file, origin=None): # pylint: disable=W0613 raise NotImplementedError # pragma: no cover -_class_for_key = { +_class_for_key: Dict[ParamKey, Any] = { ParamKey.MANDATORY: MandatoryParam, ParamKey.ALPN: ALPNParam, ParamKey.NO_DEFAULT_ALPN: NoDefaultALPNParam, @@ -571,10 +572,11 @@ def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): raise dns.exception.FormError("keys not in order") prior_key = key vlen = parser.get_uint16() - pcls = _class_for_key.get(key, GenericParam) + pkey = ParamKey.make(key) + pcls = _class_for_key.get(pkey, GenericParam) with parser.restrict_to(vlen): value = pcls.from_wire_parser(parser, origin) - params[key] = value + params[pkey] = value return cls(rdclass, rdtype, priority, target, params) def _processing_priority(self): diff --git a/dns/rdtypes/tlsabase.py b/dns/rdtypes/tlsabase.py index a059d2c4a..44d8cc24a 100644 --- a/dns/rdtypes/tlsabase.py +++ b/dns/rdtypes/tlsabase.py @@ -45,7 +45,7 @@ def to_text(self, origin=None, relativize=True, **kw): self.usage, self.selector, self.mtype, - dns.rdata._hexify(self.cert, chunksize=chunksize, **kw), + dns.rdata._hexify(self.cert, chunksize=chunksize, **kw), # pyright: ignore ) @classmethod diff --git a/dns/update.py b/dns/update.py index a1834b59e..cbf207977 100644 --- a/dns/update.py +++ b/dns/update.py @@ -28,6 +28,7 @@ import dns.rdataclass import dns.rdataset import dns.rdatatype +import dns.rrset import dns.tsig @@ -353,7 +354,7 @@ def _get_one_rr_per_rrset(self, value): # Updates are always one_rr_per_rrset return True - def _parse_rr_header(self, section, name, rdclass, rdtype): + def _parse_rr_header(self, section, name, rdclass, rdtype): # pyright: ignore deleting = None empty = False if section == UpdateSection.ZONE: diff --git a/pyproject.toml b/pyproject.toml index af218f7e5..1d89711d7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -119,5 +119,12 @@ ignore_missing_imports = true [tool.pyright] reportUnsupportedDunderAll = false -# temporary! -exclude = ["dns/rdtypes/ANY/*.py", "dns/rdtypes/CH/*.py", "dns/rdtypes/IN/*.py", "dns/quic/*.py" ] +exclude = [ + "dns/dnssecalgs/*.py", + "dns/quic/*.py", + "dns/rdtypes/ANY/*.py", + "dns/rdtypes/CH/*.py", + "dns/rdtypes/IN/*.py", + "dns/rrset.py", + "dns/rdataset.py", +] # temporary! From 88ea8bf0719e72515d0b096e106e62c313e8c304 Mon Sep 17 00:00:00 2001 From: Bob Halley Date: Mon, 14 Oct 2024 06:30:17 -0700 Subject: [PATCH 10/12] lint query --- dns/query.py | 76 +++++++++++++++++++++++++++++++++---------------- dns/rdataset.py | 19 +++++++------ dns/rrset.py | 4 ++- pyproject.toml | 6 ++-- 4 files changed, 68 insertions(+), 37 deletions(-) diff --git a/dns/query.py b/dns/query.py index 0d8a977ab..068729db3 100644 --- a/dns/query.py +++ b/dns/query.py @@ -38,6 +38,7 @@ import dns.name import dns.quic import dns.rcode +import dns.rdata import dns.rdataclass import dns.rdatatype import dns.serial @@ -78,7 +79,7 @@ def __init__(self, resolver, local_port, bootstrap_address, family): self._family = family def connect_tcp( - self, host, port, timeout, local_address, socket_options=None + self, host, port, timeout=None, local_address=None, socket_options=None ): # pylint: disable=signature-differs addresses = [] _, expiration = _compute_times(timeout) @@ -98,6 +99,8 @@ def connect_tcp( for address in addresses: af = dns.inet.af_for_address(address) if local_address is not None or self._local_port != 0: + if local_address is None: + local_address = "0.0.0.0" source = dns.inet.low_level_address_tuple( (local_address, self._local_port), af ) @@ -117,11 +120,11 @@ def connect_tcp( raise httpcore.ConnectError def connect_unix_socket( - self, path, timeout, socket_options=None + self, path, timeout=None, socket_options=None ): # pylint: disable=signature-differs raise NotImplementedError - class _HTTPTransport(httpx.HTTPTransport): + class _HTTPTransport(httpx.HTTPTransport): # pyright: ignore def __init__( self, *args, @@ -144,6 +147,17 @@ def __init__( else: class _HTTPTransport: # type: ignore + def __init__( + self, + *args, + local_port=0, + bootstrap_address=None, + resolver=None, + family=socket.AF_UNSPEC, + **kwargs, + ): + pass + def connect_tcp(self, host, port, timeout, local_address): raise NotImplementedError @@ -151,7 +165,7 @@ def connect_tcp(self, host, port, timeout, local_address): have_doh = _have_httpx try: - import ssl + import ssl # pyright: ignore except ImportError: # pragma: no cover class ssl: # type: ignore @@ -163,11 +177,18 @@ class WantReadException(Exception): class WantWriteException(Exception): pass + class SSLWantReadError(Exception): + pass + + class SSLWantWriteError(Exception): + pass + class SSLContext: pass class SSLSocket: - pass + def pending(self) -> bool: + return False @classmethod def create_default_context(cls, *args, **kwargs): @@ -226,7 +247,7 @@ def _wait_for(fd, readable, writable, _, expiration): if writable: events |= selectors.EVENT_WRITE if events: - sel.register(fd, events) + sel.register(fd, events) # pyright: ignore if expiration is None: timeout = None else: @@ -338,8 +359,8 @@ def _make_socket(af, type, source, ssl_context=None, server_hostname=None): def _maybe_get_resolver( - resolver: Optional["dns.resolver.Resolver"], -) -> "dns.resolver.Resolver": + resolver: Optional["dns.resolver.Resolver"], # pyright: ignore +) -> "dns.resolver.Resolver": # pyright: ignore # We need a separate method for this to avoid overriding the global # variable "dns" with the as-yet undefined local variable "dns" # in https(). @@ -381,7 +402,7 @@ def https( post: bool = True, bootstrap_address: Optional[str] = None, verify: Union[bool, str] = True, - resolver: Optional["dns.resolver.Resolver"] = None, + resolver: Optional["dns.resolver.Resolver"] = None, # pyright: ignore family: int = socket.AF_UNSPEC, http_version: HTTPVersion = HTTPVersion.DEFAULT, ) -> dns.message.Message: @@ -441,13 +462,13 @@ def https( (af, _, the_source) = _destination_and_source( where, port, source, source_port, False ) + # we bind url and then override as pyright can't figure out all paths bind. + url = where if af is not None and dns.inet.is_address(where): if af == socket.AF_INET: url = f"https://{where}:{port}{path}" elif af == socket.AF_INET6: url = f"https://[{where}]:{port}{path}" - else: - url = where extensions = {} if bootstrap_address is None: @@ -466,13 +487,13 @@ def https( ): if bootstrap_address is None: resolver = _maybe_get_resolver(resolver) - assert parsed.hostname is not None # for mypy - answers = resolver.resolve_name(parsed.hostname, family) + assert parsed.hostname is not None # pyright: ignore + answers = resolver.resolve_name(parsed.hostname, family) # pyright: ignore bootstrap_address = random.choice(list(answers.addresses())) return _http3( q, bootstrap_address, - url, + url, # pyright: ignore timeout, port, source, @@ -485,7 +506,7 @@ def https( if not have_doh: raise NoDOH # pragma: no cover - if session and not isinstance(session, httpx.Client): + if session and not isinstance(session, httpx.Client): # pyright: ignore raise ValueError("session parameter must be an httpx.Client") wire = q.to_wire() @@ -514,10 +535,12 @@ def https( local_port=local_port, bootstrap_address=bootstrap_address, resolver=resolver, - family=family, + family=family, # pyright: ignore ) - cm = httpx.Client(http1=h1, http2=h2, verify=verify, transport=transport) + cm = httpx.Client( # pyright: ignore + http1=h1, http2=h2, verify=verify, transport=transport # pyright: ignore + ) with cm as session: # see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH # GET and POST examples @@ -617,7 +640,7 @@ def _http3( q.id = 0 wire = q.to_wire() manager = dns.quic.SyncQuicManager( - verify_mode=verify, server_name=hostname, h3=True + verify_mode=verify, server_name=hostname, h3=True # pyright: ignore ) with manager: @@ -1162,7 +1185,7 @@ def tcp( with cm as s: if not sock: # pylint: disable=possibly-used-before-assignment - _connect(s, destination, expiration) + _connect(s, destination, expiration) # pyright: ignore send_tcp(s, wire, expiration) (r, received_time) = receive_tcp( s, expiration, one_rr_per_rrset, q.keyring, q.mac, ignore_trailing @@ -1385,14 +1408,18 @@ def quic( manager: contextlib.AbstractContextManager = contextlib.nullcontext(None) the_connection = connection else: - manager = dns.quic.SyncQuicManager(verify_mode=verify, server_name=hostname) + manager = dns.quic.SyncQuicManager( + verify_mode=verify, server_name=hostname # pyright: ignore + ) the_manager = manager # for type checking happiness with manager: if not connection: - the_connection = the_manager.connect(where, port, source, source_port) + the_connection = the_manager.connect( # pyright: ignore + where, port, source, source_port + ) (start, expiration) = _compute_times(timeout) - with the_connection.make_stream(timeout) as stream: + with the_connection.make_stream(timeout) as stream: # pyright: ignore stream.send(wire, True) wire = stream.receive(_remaining(expiration)) finish = time.time() @@ -1428,7 +1455,7 @@ def _inbound_xfr( query: dns.message.Message, serial: Optional[int], timeout: Optional[float], - expiration: float, + expiration: Optional[float], ) -> Any: """Given a socket, does the zone transfer.""" rdtype = query.question[0].rdtype @@ -1444,6 +1471,7 @@ def _inbound_xfr( with dns.xfr.Inbound(txn_manager, rdtype, serial, is_udp) as inbound: done = False tsig_ctx = None + r: Optional[dns.message.Message] = None while not done: (_, mexpiration) = _compute_times(timeout) if mexpiration is None or ( @@ -1469,7 +1497,7 @@ def _inbound_xfr( done = inbound.process_message(r) yield r tsig_ctx = r.tsig_ctx - if query.keyring and not r.had_tsig: + if query.keyring and r is not None and not r.had_tsig: raise dns.exception.FormError("missing TSIG") diff --git a/dns/rdataset.py b/dns/rdataset.py index 39cab2365..4b4bd7576 100644 --- a/dns/rdataset.py +++ b/dns/rdataset.py @@ -75,7 +75,7 @@ def __init__( self.ttl = ttl def _clone(self): - obj = super()._clone() + obj = cast(Rdataset, super()._clone()) obj.rdclass = self.rdclass obj.rdtype = self.rdtype obj.covers = self.covers @@ -97,7 +97,8 @@ def update_ttl(self, ttl: int) -> None: elif ttl < self.ttl: self.ttl = ttl - def add( # pylint: disable=arguments-differ,arguments-renamed + # pylint: disable=arguments-differ,arguments-renamed + def add( # pyright: ignore self, rd: dns.rdata.Rdata, ttl: Optional[int] = None ) -> None: """Add the specified rdata to the rdataset. @@ -355,7 +356,7 @@ def processing_order(self) -> List[dns.rdata.Rdata]: if len(self) == 0: return [] else: - return self[0]._processing_order(iter(self)) + return self[0]._processing_order(iter(self)) # pyright: ignore @dns.immutable.immutable @@ -410,22 +411,22 @@ def clear(self): raise TypeError("immutable") def __copy__(self): - return ImmutableRdataset(super().copy()) + return ImmutableRdataset(super().copy()) # pyright: ignore def copy(self): - return ImmutableRdataset(super().copy()) + return ImmutableRdataset(super().copy()) # pyright: ignore def union(self, other): - return ImmutableRdataset(super().union(other)) + return ImmutableRdataset(super().union(other)) # pyright: ignore def intersection(self, other): - return ImmutableRdataset(super().intersection(other)) + return ImmutableRdataset(super().intersection(other)) # pyright: ignore def difference(self, other): - return ImmutableRdataset(super().difference(other)) + return ImmutableRdataset(super().difference(other)) # pyright: ignore def symmetric_difference(self, other): - return ImmutableRdataset(super().symmetric_difference(other)) + return ImmutableRdataset(super().symmetric_difference(other)) # pyright: ignore def from_text_list( diff --git a/dns/rrset.py b/dns/rrset.py index 6f39b108d..2b0effaaa 100644 --- a/dns/rrset.py +++ b/dns/rrset.py @@ -20,8 +20,10 @@ from typing import Any, Collection, Dict, Optional, Union, cast import dns.name +import dns.rdata import dns.rdataclass import dns.rdataset +import dns.rdatatype import dns.renderer @@ -52,7 +54,7 @@ def __init__( self.deleting = deleting def _clone(self): - obj = super()._clone() + obj = cast(RRset, super()._clone()) obj.name = self.name obj.deleting = self.deleting return obj diff --git a/pyproject.toml b/pyproject.toml index 1d89711d7..16eeefb7d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -125,6 +125,6 @@ exclude = [ "dns/rdtypes/ANY/*.py", "dns/rdtypes/CH/*.py", "dns/rdtypes/IN/*.py", - "dns/rrset.py", - "dns/rdataset.py", -] # temporary! + "examples/*.py", + "tests/*.py", +] # (mostly) temporary! From f17a15f28ef99b9223b6b4ffa370f98095b4227c Mon Sep 17 00:00:00 2001 From: Bob Halley Date: Mon, 14 Oct 2024 06:40:15 -0700 Subject: [PATCH 11/12] lint asyncquery --- dns/asyncquery.py | 73 +++++++++++++++++++++++++++++++---------------- pyproject.toml | 1 + 2 files changed, 49 insertions(+), 25 deletions(-) diff --git a/dns/asyncquery.py b/dns/asyncquery.py index efad0fd75..883e8afc0 100644 --- a/dns/asyncquery.py +++ b/dns/asyncquery.py @@ -36,6 +36,8 @@ import dns.rdataclass import dns.rdatatype import dns.transaction +import dns.tsig +import dns.xfr from dns._asyncbackend import NullContext from dns.query import ( BadResponse, @@ -219,9 +221,9 @@ async def udp( dtuple = None cm = await backend.make_socket(af, socket.SOCK_DGRAM, 0, stuple, dtuple) async with cm as s: - await send_udp(s, wire, destination, expiration) + await send_udp(s, wire, destination, expiration) # pyright: ignore (r, received_time, _) = await receive_udp( - s, + s, # pyright: ignore destination, expiration, ignore_unexpected, @@ -424,9 +426,14 @@ async def tcp( af, socket.SOCK_STREAM, 0, stuple, dtuple, timeout ) async with cm as s: - await send_tcp(s, wire, expiration) + await send_tcp(s, wire, expiration) # pyright: ignore (r, received_time) = await receive_tcp( - s, expiration, one_rr_per_rrset, q.keyring, q.mac, ignore_trailing + s, # pyright: ignore + expiration, + one_rr_per_rrset, + q.keyring, + q.mac, + ignore_trailing, ) r.time = received_time - begin_time if not q.is_response(r): @@ -469,7 +476,9 @@ async def tls( cm: contextlib.AbstractAsyncContextManager = NullContext(sock) else: if ssl_context is None: - ssl_context = _make_dot_ssl_context(server_hostname, verify) + ssl_context = _make_dot_ssl_context( + server_hostname, verify + ) # pyright: ignore af = dns.inet.af_for_address(where) stuple = _source_tuple(af, source, source_port) dtuple = (where, port) @@ -505,8 +514,8 @@ async def tls( def _maybe_get_resolver( - resolver: Optional["dns.asyncresolver.Resolver"], -) -> "dns.asyncresolver.Resolver": + resolver: Optional["dns.asyncresolver.Resolver"], # pyright: ignore +) -> "dns.asyncresolver.Resolver": # pyright: ignore # We need a separate method for this to avoid overriding the global # variable "dns" with the as-yet undefined local variable "dns" # in https(). @@ -532,7 +541,7 @@ async def https( post: bool = True, verify: Union[bool, str] = True, bootstrap_address: Optional[str] = None, - resolver: Optional["dns.asyncresolver.Resolver"] = None, + resolver: Optional["dns.asyncresolver.Resolver"] = None, # pyright: ignore family: int = socket.AF_UNSPEC, http_version: HTTPVersion = HTTPVersion.DEFAULT, ) -> dns.message.Message: @@ -552,13 +561,13 @@ async def https( af = dns.inet.af_for_address(where) except ValueError: af = None + # we bind url and then override as pyright can't figure out all paths bind. + url = where if af is not None and dns.inet.is_address(where): if af == socket.AF_INET: url = f"https://{where}:{port}{path}" elif af == socket.AF_INET6: url = f"https://[{where}]:{port}{path}" - else: - url = where extensions = {} if bootstrap_address is None: @@ -577,8 +586,10 @@ async def https( ): if bootstrap_address is None: resolver = _maybe_get_resolver(resolver) - assert parsed.hostname is not None # for mypy - answers = await resolver.resolve_name(parsed.hostname, family) + assert parsed.hostname is not None # pyright: ignore + answers = await resolver.resolve_name( # pyright: ignore + parsed.hostname, family # pyright: ignore + ) bootstrap_address = random.choice(list(answers.addresses())) return await _http3( q, @@ -597,7 +608,7 @@ async def https( if not have_doh: raise NoDOH # pragma: no cover # pylint: disable=possibly-used-before-assignment - if client and not isinstance(client, httpx.AsyncClient): + if client and not isinstance(client, httpx.AsyncClient): # pyright: ignore raise ValueError("session parameter must be an httpx.AsyncClient") # pylint: enable=possibly-used-before-assignment @@ -630,7 +641,9 @@ async def https( family=family, ) - cm = httpx.AsyncClient(http1=h1, http2=h2, verify=verify, transport=transport) + cm = httpx.AsyncClient( # pyright: ignore + http1=h1, http2=h2, verify=verify, transport=transport + ) async with cm as the_client: # see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH @@ -643,7 +656,7 @@ async def https( } ) response = await backend.wait_for( - the_client.post( + the_client.post( # pyright: ignore url, headers=headers, content=wire, @@ -655,7 +668,7 @@ async def https( wire = base64.urlsafe_b64encode(wire).rstrip(b"=") twire = wire.decode() # httpx does a repr() if we give it bytes response = await backend.wait_for( - the_client.get( + the_client.get( # pyright: ignore url, headers=headers, params={"dns": twire}, @@ -785,9 +798,11 @@ async def quic( server_name=server_hostname, ) as the_manager: if not connection: - the_connection = the_manager.connect(where, port, source, source_port) + the_connection = the_manager.connect( # pyright: ignore + where, port, source, source_port + ) (start, expiration) = _compute_times(timeout) - stream = await the_connection.make_stream(timeout) + stream = await the_connection.make_stream(timeout) # pyright: ignore async with stream: await stream.send(wire, True) wire = await stream.receive(_remaining(expiration)) @@ -829,6 +844,7 @@ async def _inbound_xfr( with dns.xfr.Inbound(txn_manager, rdtype, serial, is_udp) as inbound: done = False tsig_ctx = None + r: Optional[dns.message.Message] = None while not done: (_, mexpiration) = _compute_times(timeout) if mexpiration is None or ( @@ -837,11 +853,11 @@ async def _inbound_xfr( mexpiration = expiration if is_udp: timeout = _timeout(mexpiration) - (rwire, _) = await udp_sock.recvfrom(65535, timeout) + (rwire, _) = await udp_sock.recvfrom(65535, timeout) # pyright: ignore else: - ldata = await _read_exactly(tcp_sock, 2, mexpiration) + ldata = await _read_exactly(tcp_sock, 2, mexpiration) # pyright: ignore (l,) = struct.unpack("!H", ldata) - rwire = await _read_exactly(tcp_sock, l, mexpiration) + rwire = await _read_exactly(tcp_sock, l, mexpiration) # pyright: ignore r = dns.message.from_wire( rwire, keyring=query.keyring, @@ -855,7 +871,7 @@ async def _inbound_xfr( done = inbound.process_message(r) yield r tsig_ctx = r.tsig_ctx - if query.keyring and not r.had_tsig: + if query.keyring and r is not None and not r.had_tsig: raise dns.exception.FormError("missing TSIG") @@ -896,8 +912,13 @@ async def inbound_xfr( ) async with s: try: - async for _ in _inbound_xfr( - txn_manager, s, query, serial, timeout, expiration + async for _ in _inbound_xfr( # pyright: ignore + txn_manager, + s, + query, + serial, + timeout, + expiration, # pyright: ignore ): pass return @@ -909,5 +930,7 @@ async def inbound_xfr( af, socket.SOCK_STREAM, 0, stuple, dtuple, _timeout(expiration) ) async with s: - async for _ in _inbound_xfr(txn_manager, s, query, serial, timeout, expiration): + async for _ in _inbound_xfr( # pyright: ignore + txn_manager, s, query, serial, timeout, expiration # pyright: ignore + ): pass diff --git a/pyproject.toml b/pyproject.toml index 16eeefb7d..75530baaa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -120,6 +120,7 @@ ignore_missing_imports = true [tool.pyright] reportUnsupportedDunderAll = false exclude = [ + "dns/_*_backend.py", "dns/dnssecalgs/*.py", "dns/quic/*.py", "dns/rdtypes/ANY/*.py", From c3543f5af3ba1579aeb46ca2d0c67d86eef62bac Mon Sep 17 00:00:00 2001 From: Bob Halley Date: Mon, 14 Oct 2024 07:53:26 -0700 Subject: [PATCH 12/12] harmonize with mypy, fix bug --- dns/enum.py | 7 ++----- dns/message.py | 14 +++++++------- dns/rdtypes/util.py | 4 ++-- dns/renderer.py | 15 ++++++++++----- dns/tsigkeyring.py | 4 ++-- dns/zonefile.py | 8 ++++---- 6 files changed, 27 insertions(+), 25 deletions(-) diff --git a/dns/enum.py b/dns/enum.py index 24942ccc8..d7f261870 100644 --- a/dns/enum.py +++ b/dns/enum.py @@ -55,10 +55,7 @@ def from_text(cls: Type[TIntEnum], text: str) -> TIntEnum: if text.startswith(prefix) and text[len(prefix) :].isdigit(): value = int(text[len(prefix) :]) cls._check_value(value) - try: - return cls(value) - except ValueError: - return value # pyright: ignore + return cls(value) raise cls._unknown_exception_class() @classmethod @@ -106,7 +103,7 @@ def _prefix(cls) -> str: return "" @classmethod - def _extra_from_text(cls, text) -> Optional[Any]: # pylint: disable=W0613 + def _extra_from_text(cls, text: str) -> Optional[Any]: # pylint: disable=W0613 return None @classmethod diff --git a/dns/message.py b/dns/message.py index c89a5c110..fc2a0e721 100644 --- a/dns/message.py +++ b/dns/message.py @@ -1427,14 +1427,14 @@ class _TextReader: def __init__( self, - text, - idna_codec, - one_rr_per_rrset=False, - origin=None, - relativize=True, - relativize_to=None, + text: str, + idna_codec: Optional[dns.name.IDNACodec], + one_rr_per_rrset: bool = False, + origin: Optional[dns.name.Name] = None, + relativize: bool = True, + relativize_to: Optional[dns.name.Name] = None, ): - self.message: Optional[Message] = None + self.message: Optional[Message] = None # mypy: ignore self.tok = dns.tokenizer.Tokenizer(text, idna_codec=idna_codec) self.last_name = None self.one_rr_per_rrset = one_rr_per_rrset diff --git a/dns/rdtypes/util.py b/dns/rdtypes/util.py index defb8011b..ee6e8acad 100644 --- a/dns/rdtypes/util.py +++ b/dns/rdtypes/util.py @@ -35,7 +35,7 @@ class Gateway: name = "" - def __init__(self, type, gateway: Optional[Union[str, dns.name.Name]] = None): + def __init__(self, type: Any, gateway: Optional[Union[str, dns.name.Name]] = None): self.type = dns.rdata.Rdata._as_uint8(type) self.gateway = gateway self._check() @@ -130,7 +130,7 @@ def __init__(self, windows: Optional[Iterable[Tuple[int, bytes]]] = None): last_window = -1 if windows is None: windows = [] - self.windows = tuple(windows) + self.windows = windows for window, bitmap in self.windows: if not isinstance(window, int): raise ValueError(f"bad {self.type_name} window type") diff --git a/dns/renderer.py b/dns/renderer.py index 1b9903985..cc912b29d 100644 --- a/dns/renderer.py +++ b/dns/renderer.py @@ -25,11 +25,12 @@ import dns.edns import dns.exception -import dns.message import dns.rdataclass import dns.rdatatype import dns.tsig +# Note we can't import dns.message for cicularity reasons + QUESTION = 0 ANSWER = 1 AUTHORITY = 2 @@ -218,7 +219,9 @@ def add_opt(self, opt, pad=0, opt_size=0, tsig_size=0): pad = b"" options = list(opt_rdata.options) options.append(dns.edns.GenericOption(dns.edns.OptionType.PADDING, pad)) - opt = dns.message.Message._make_opt(ttl, opt_rdata.rdclass, options) + opt = dns.message.Message._make_opt( # pyright: ignore + ttl, opt_rdata.rdclass, options + ) self.was_padded = True self.add_rrset(ADDITIONAL, opt) @@ -228,7 +231,9 @@ def add_edns(self, edns, ednsflags, payload, options=None): # make sure the EDNS version in ednsflags agrees with edns ednsflags &= 0xFF00FFFF ednsflags |= edns << 16 - opt = dns.message.Message._make_opt(ednsflags, payload, options) + opt = dns.message.Message._make_opt( # pyright: ignore + ednsflags, payload, options + ) self.add_opt(opt) def add_tsig( @@ -250,7 +255,7 @@ def add_tsig( key = secret else: key = dns.tsig.Key(keyname, secret, algorithm) - tsig = dns.message.Message._make_tsig( + tsig = dns.message.Message._make_tsig( # pyright: ignore keyname, algorithm, 0, fudge, b"", id, tsig_error, other_data ) (tsig, _) = dns.tsig.sign(s, key, tsig[0], int(time.time()), request_mac) @@ -282,7 +287,7 @@ def add_multi_tsig( key = secret else: key = dns.tsig.Key(keyname, secret, algorithm) - tsig = dns.message.Message._make_tsig( + tsig = dns.message.Message._make_tsig( # pyright: ignore keyname, algorithm, 0, fudge, b"", id, tsig_error, other_data ) (tsig, ctx) = dns.tsig.sign( diff --git a/dns/tsigkeyring.py b/dns/tsigkeyring.py index 1010a79f8..5996295a2 100644 --- a/dns/tsigkeyring.py +++ b/dns/tsigkeyring.py @@ -24,14 +24,14 @@ import dns.tsig -def from_text(textring: Dict[str, Any]) -> Dict[dns.name.Name, dns.tsig.Key]: +def from_text(textring: Dict[str, Any]) -> Dict[dns.name.Name, Any]: """Convert a dictionary containing (textual DNS name, base64 secret) pairs into a binary keyring which has (dns.name.Name, bytes) pairs, or a dictionary containing (textual DNS name, (algorithm, base64 secret)) pairs into a binary keyring which has (dns.name.Name, dns.tsig.Key) pairs. @rtype: dict""" - keyring = {} + keyring: Dict[dns.name.Name, Any] = {} for name, value in textring.items(): kname = dns.name.from_text(name) if isinstance(value, str): diff --git a/dns/zonefile.py b/dns/zonefile.py index a6a8531b2..af4778512 100644 --- a/dns/zonefile.py +++ b/dns/zonefile.py @@ -633,13 +633,13 @@ class RRSetsReaderManager(dns.transaction.TransactionManager): def __init__( self, origin: Optional[dns.name.Name] = dns.name.root, - relativize=False, - rdclass=dns.rdataclass.IN, + relativize: bool = False, + rdclass: dns.rdataclass.RdataClass = dns.rdataclass.IN, ): self.origin = origin self.relativize = relativize self.rdclass = rdclass - self.rrsets = [] + self.rrsets: List[dns.rrset.RRset] = [] def reader(self): # pragma: no cover raise NotImplementedError @@ -658,7 +658,7 @@ def origin_information(self): effective = self.origin return (self.origin, self.relativize, effective) - def set_rrsets(self, rrsets): + def set_rrsets(self, rrsets: List[dns.rrset.RRset]) -> None: self.rrsets = rrsets