Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pyright lint #1147

Merged
merged 12 commits into from
Oct 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ check: test
type:
python -m mypy --install-types --non-interactive --disallow-incomplete-defs dns

pyright:
pyright dns

lint:
pylint dns

Expand Down
2 changes: 1 addition & 1 deletion dns/_immutable_ctx.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def nf(*args, **kwargs):
finally:
_in__init__.reset(previous)

nf.__signature__ = inspect.signature(f)
nf.__signature__ = inspect.signature(f) # pyright: ignore
return nf


Expand Down
73 changes: 48 additions & 25 deletions dns/asyncquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
import dns.rdataclass
import dns.rdatatype
import dns.transaction
import dns.tsig
import dns.xfr
from dns._asyncbackend import NullContext
from dns.query import (
BadResponse,
Expand Down Expand Up @@ -219,9 +221,9 @@ async def udp(
dtuple = None
cm = await backend.make_socket(af, socket.SOCK_DGRAM, 0, stuple, dtuple)
async with cm as s:
await send_udp(s, wire, destination, expiration)
await send_udp(s, wire, destination, expiration) # pyright: ignore
(r, received_time, _) = await receive_udp(
s,
s, # pyright: ignore
destination,
expiration,
ignore_unexpected,
Expand Down Expand Up @@ -424,9 +426,14 @@ async def tcp(
af, socket.SOCK_STREAM, 0, stuple, dtuple, timeout
)
async with cm as s:
await send_tcp(s, wire, expiration)
await send_tcp(s, wire, expiration) # pyright: ignore
(r, received_time) = await receive_tcp(
s, expiration, one_rr_per_rrset, q.keyring, q.mac, ignore_trailing
s, # pyright: ignore
expiration,
one_rr_per_rrset,
q.keyring,
q.mac,
ignore_trailing,
)
r.time = received_time - begin_time
if not q.is_response(r):
Expand Down Expand Up @@ -469,7 +476,9 @@ async def tls(
cm: contextlib.AbstractAsyncContextManager = NullContext(sock)
else:
if ssl_context is None:
ssl_context = _make_dot_ssl_context(server_hostname, verify)
ssl_context = _make_dot_ssl_context(
server_hostname, verify
) # pyright: ignore
af = dns.inet.af_for_address(where)
stuple = _source_tuple(af, source, source_port)
dtuple = (where, port)
Expand Down Expand Up @@ -505,8 +514,8 @@ async def tls(


def _maybe_get_resolver(
resolver: Optional["dns.asyncresolver.Resolver"],
) -> "dns.asyncresolver.Resolver":
resolver: Optional["dns.asyncresolver.Resolver"], # pyright: ignore
) -> "dns.asyncresolver.Resolver": # pyright: ignore
# We need a separate method for this to avoid overriding the global
# variable "dns" with the as-yet undefined local variable "dns"
# in https().
Expand All @@ -532,7 +541,7 @@ async def https(
post: bool = True,
verify: Union[bool, str] = True,
bootstrap_address: Optional[str] = None,
resolver: Optional["dns.asyncresolver.Resolver"] = None,
resolver: Optional["dns.asyncresolver.Resolver"] = None, # pyright: ignore
family: int = socket.AF_UNSPEC,
http_version: HTTPVersion = HTTPVersion.DEFAULT,
) -> dns.message.Message:
Expand All @@ -552,13 +561,13 @@ async def https(
af = dns.inet.af_for_address(where)
except ValueError:
af = None
# we bind url and then override as pyright can't figure out all paths bind.
url = where
if af is not None and dns.inet.is_address(where):
if af == socket.AF_INET:
url = f"https://{where}:{port}{path}"
elif af == socket.AF_INET6:
url = f"https://[{where}]:{port}{path}"
else:
url = where

extensions = {}
if bootstrap_address is None:
Expand All @@ -577,8 +586,10 @@ async def https(
):
if bootstrap_address is None:
resolver = _maybe_get_resolver(resolver)
assert parsed.hostname is not None # for mypy
answers = await resolver.resolve_name(parsed.hostname, family)
assert parsed.hostname is not None # pyright: ignore
answers = await resolver.resolve_name( # pyright: ignore
parsed.hostname, family # pyright: ignore
)
bootstrap_address = random.choice(list(answers.addresses()))
return await _http3(
q,
Expand All @@ -597,7 +608,7 @@ async def https(
if not have_doh:
raise NoDOH # pragma: no cover
# pylint: disable=possibly-used-before-assignment
if client and not isinstance(client, httpx.AsyncClient):
if client and not isinstance(client, httpx.AsyncClient): # pyright: ignore
raise ValueError("session parameter must be an httpx.AsyncClient")
# pylint: enable=possibly-used-before-assignment

Expand Down Expand Up @@ -630,7 +641,9 @@ async def https(
family=family,
)

cm = httpx.AsyncClient(http1=h1, http2=h2, verify=verify, transport=transport)
cm = httpx.AsyncClient( # pyright: ignore
http1=h1, http2=h2, verify=verify, transport=transport
)

async with cm as the_client:
# see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH
Expand All @@ -643,7 +656,7 @@ async def https(
}
)
response = await backend.wait_for(
the_client.post(
the_client.post( # pyright: ignore
url,
headers=headers,
content=wire,
Expand All @@ -655,7 +668,7 @@ async def https(
wire = base64.urlsafe_b64encode(wire).rstrip(b"=")
twire = wire.decode() # httpx does a repr() if we give it bytes
response = await backend.wait_for(
the_client.get(
the_client.get( # pyright: ignore
url,
headers=headers,
params={"dns": twire},
Expand Down Expand Up @@ -785,9 +798,11 @@ async def quic(
server_name=server_hostname,
) as the_manager:
if not connection:
the_connection = the_manager.connect(where, port, source, source_port)
the_connection = the_manager.connect( # pyright: ignore
where, port, source, source_port
)
(start, expiration) = _compute_times(timeout)
stream = await the_connection.make_stream(timeout)
stream = await the_connection.make_stream(timeout) # pyright: ignore
async with stream:
await stream.send(wire, True)
wire = await stream.receive(_remaining(expiration))
Expand Down Expand Up @@ -829,6 +844,7 @@ async def _inbound_xfr(
with dns.xfr.Inbound(txn_manager, rdtype, serial, is_udp) as inbound:
done = False
tsig_ctx = None
r: Optional[dns.message.Message] = None
while not done:
(_, mexpiration) = _compute_times(timeout)
if mexpiration is None or (
Expand All @@ -837,11 +853,11 @@ async def _inbound_xfr(
mexpiration = expiration
if is_udp:
timeout = _timeout(mexpiration)
(rwire, _) = await udp_sock.recvfrom(65535, timeout)
(rwire, _) = await udp_sock.recvfrom(65535, timeout) # pyright: ignore
else:
ldata = await _read_exactly(tcp_sock, 2, mexpiration)
ldata = await _read_exactly(tcp_sock, 2, mexpiration) # pyright: ignore
(l,) = struct.unpack("!H", ldata)
rwire = await _read_exactly(tcp_sock, l, mexpiration)
rwire = await _read_exactly(tcp_sock, l, mexpiration) # pyright: ignore
r = dns.message.from_wire(
rwire,
keyring=query.keyring,
Expand All @@ -855,7 +871,7 @@ async def _inbound_xfr(
done = inbound.process_message(r)
yield r
tsig_ctx = r.tsig_ctx
if query.keyring and not r.had_tsig:
if query.keyring and r is not None and not r.had_tsig:
raise dns.exception.FormError("missing TSIG")


Expand Down Expand Up @@ -896,8 +912,13 @@ async def inbound_xfr(
)
async with s:
try:
async for _ in _inbound_xfr(
txn_manager, s, query, serial, timeout, expiration
async for _ in _inbound_xfr( # pyright: ignore
txn_manager,
s,
query,
serial,
timeout,
expiration, # pyright: ignore
):
pass
return
Expand All @@ -909,5 +930,7 @@ async def inbound_xfr(
af, socket.SOCK_STREAM, 0, stuple, dtuple, _timeout(expiration)
)
async with s:
async for _ in _inbound_xfr(txn_manager, s, query, serial, timeout, expiration):
async for _ in _inbound_xfr( # pyright: ignore
txn_manager, s, query, serial, timeout, expiration # pyright: ignore
):
pass
5 changes: 4 additions & 1 deletion dns/asyncresolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,14 @@
import dns.asyncbackend
import dns.asyncquery
import dns.exception
import dns.inet
import dns.name
import dns.nameserver
import dns.query
import dns.rdataclass
import dns.rdatatype
import dns.resolver # lgtm[py/import-and-import-from]
import dns.reversename

# import some resolver symbols for brevity
from dns.resolver import NXDOMAIN, NoAnswer, NoRootSOA, NotAbsolute
Expand Down Expand Up @@ -426,7 +429,7 @@ async def make_resolver_at(
answers = await resolver.resolve_name(where, family)
for address in answers.addresses():
nameservers.append(dns.nameserver.Do53Nameserver(address, port))
res = dns.asyncresolver.Resolver(configure=False)
res = Resolver(configure=False)
res.nameservers = nameservers
return res

Expand Down
12 changes: 6 additions & 6 deletions dns/dnssec.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,16 +135,16 @@ class Policy:
def __init__(self):
pass

def ok_to_sign(self, _: DNSKEY) -> bool: # pragma: no cover
def ok_to_sign(self, key: DNSKEY) -> bool: # pragma: no cover
return False

def ok_to_validate(self, _: DNSKEY) -> bool: # pragma: no cover
def ok_to_validate(self, key: DNSKEY) -> bool: # pragma: no cover
return False

def ok_to_create_ds(self, _: DSDigest) -> bool: # pragma: no cover
def ok_to_create_ds(self, algorithm: DSDigest) -> bool: # pragma: no cover
return False

def ok_to_validate_ds(self, _: DSDigest) -> bool: # pragma: no cover
def ok_to_validate_ds(self, algorithm: DSDigest) -> bool: # pragma: no cover
return False


Expand Down Expand Up @@ -587,7 +587,7 @@ def _sign(
signature=b"",
)

data = dns.dnssec._make_rrsig_signature_data(rrset, rrsig_template, origin)
data = _make_rrsig_signature_data(rrset, rrsig_template, origin)

# pylint: disable=possibly-used-before-assignment
if isinstance(private_key, GenericPrivateKey):
Expand Down Expand Up @@ -979,7 +979,7 @@ def default_rrset_signer(
keys = zsks

for private_key, dnskey in keys:
rrsig = dns.dnssec.sign(
rrsig = sign(
rrset=rrset,
private_key=private_key,
dnskey=dnskey,
Expand Down
2 changes: 1 addition & 1 deletion dns/e164.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def query(
for domain in domains:
if isinstance(domain, str):
domain = dns.name.from_text(domain)
qname = dns.e164.from_e164(number, domain)
qname = from_e164(number, domain)
try:
return resolver.resolve(qname, "NAPTR")
except dns.resolver.NXDOMAIN as e:
Expand Down
11 changes: 7 additions & 4 deletions dns/edns.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@

import dns.enum
import dns.inet
import dns.ipv4
import dns.ipv6
import dns.name
import dns.rdata
import dns.wire

Expand Down Expand Up @@ -81,14 +84,14 @@ def to_wire(self, file: Optional[Any] = None) -> Optional[bytes]:
def to_text(self) -> str:
raise NotImplementedError # pragma: no cover

def to_generic(self) -> "dns.edns.GenericOption":
def to_generic(self) -> "GenericOption":
"""Creates a dns.edns.GenericOption equivalent of this rdata.

Returns a ``dns.edns.GenericOption``.
"""
wire = self.to_wire()
assert wire is not None # for mypy
return dns.edns.GenericOption(self.otype, wire)
return GenericOption(self.otype, wire)

@classmethod
def from_wire_parser(cls, otype: OptionType, parser: "dns.wire.Parser") -> "Option":
Expand Down Expand Up @@ -175,7 +178,7 @@ def to_wire(self, file: Optional[Any] = None) -> Optional[bytes]:
def to_text(self) -> str:
return "Generic %d" % self.otype

def to_generic(self) -> "dns.edns.GenericOption":
def to_generic(self) -> "GenericOption":
return self

@classmethod
Expand Down Expand Up @@ -444,7 +447,7 @@ def from_wire_parser(

class CookieOption(Option):
def __init__(self, client: bytes, server: bytes):
super().__init__(dns.edns.OptionType.COOKIE)
super().__init__(OptionType.COOKIE)
self.client = client
self.server = server
if len(client) != 8:
Expand Down
6 changes: 3 additions & 3 deletions dns/entropy.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import random
import threading
import time
from typing import Any, Optional
from typing import Any, Optional, Union


class EntropyPool:
Expand All @@ -45,15 +45,15 @@ def __init__(self, seed: Optional[bytes] = None):
self.seeded = False
self.seed_pid = 0

def _stir(self, entropy: bytes) -> None:
def _stir(self, entropy: Union[bytes, bytearray]) -> None:
for c in entropy:
if self.pool_index == self.hash_len:
self.pool_index = 0
b = c & 0xFF
self.pool[self.pool_index] ^= b
self.pool_index += 1

def stir(self, entropy: bytes) -> None:
def stir(self, entropy: Union[bytes, bytearray]) -> None:
with self.lock:
self._stir(entropy)

Expand Down
Loading