Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable flake8 annotations #3098

Merged
merged 18 commits into from
Nov 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ allowed-confusables = ["–"]

select = [
"A", # flake8-builtins
"ANN", # flake8-annotations
"ASYNC", # flake8-async
"B", # flake8-bugbear
"C4", # flake8-comprehensions
Expand All @@ -131,6 +132,9 @@ select = [
]
extend-ignore = [
'A002', # builtin-argument-shadowing
'ANN101', # missing-type-self
'ANN102', # missing-type-cls
'ANN401', # any-type (mypy's `disallow_any_explicit` is better)
'E402', # module-import-not-at-top-of-file (usually OS-specific)
'E501', # line-too-long
'F403', # undefined-local-with-import-star
Expand Down Expand Up @@ -160,6 +164,8 @@ extend-ignore = [
'src/trio/_abc.py' = ['A005']
'src/trio/_socket.py' = ['A005']
'src/trio/_ssl.py' = ['A005']
# Don't check annotations in notes-to-self
'notes-to-self/*.py' = ['ANN001', 'ANN002', 'ANN003', 'ANN201', 'ANN202', 'ANN204']
Comment on lines +167 to +168
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

whichever is merged latest of this and #3117 can remove these lines.


[tool.ruff.lint.isort]
combine-as-imports = true
Expand Down
2 changes: 1 addition & 1 deletion src/trio/_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def __new__( # type: ignore[misc] # "must return a subtype"
) -> tuple[MemorySendChannel[T], MemoryReceiveChannel[T]]:
return _open_memory_channel(max_buffer_size)

def __init__(self, max_buffer_size: int | float): # noqa: PYI041
def __init__(self, max_buffer_size: int | float) -> None: # noqa: PYI041
...

else:
Expand Down
2 changes: 1 addition & 1 deletion src/trio/_core/_instrumentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class Instruments(dict[str, dict[Instrument, None]]):

__slots__ = ()

def __init__(self, incoming: Sequence[Instrument]):
def __init__(self, incoming: Sequence[Instrument]) -> None:
self["_all"] = {}
for instrument in incoming:
self.add_instrument(instrument)
Expand Down
2 changes: 1 addition & 1 deletion src/trio/_core/_mock_clock.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class MockClock(Clock):

"""

def __init__(self, rate: float = 0.0, autojump_threshold: float = inf):
def __init__(self, rate: float = 0.0, autojump_threshold: float = inf) -> None:
# when the real clock said 'real_base', the virtual time was
# 'virtual_base', and since then it's advanced at 'rate' virtual
# seconds per real second.
Expand Down
2 changes: 1 addition & 1 deletion src/trio/_core/_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -1128,7 +1128,7 @@ def __init__(
parent_task: Task,
cancel_scope: CancelScope,
strict_exception_groups: bool,
):
) -> None:
self._parent_task = parent_task
self._strict_exception_groups = strict_exception_groups
parent_task._child_nurseries.append(self)
Expand Down
2 changes: 1 addition & 1 deletion src/trio/_dtls.py
Original file line number Diff line number Diff line change
Expand Up @@ -665,7 +665,7 @@ def challenge_for(


class _Queue(Generic[_T]):
def __init__(self, incoming_packets_buffer: int | float): # noqa: PYI041
def __init__(self, incoming_packets_buffer: int | float) -> None: # noqa: PYI041
self.s, self.r = trio.open_memory_channel[_T](incoming_packets_buffer)


Expand Down
7 changes: 6 additions & 1 deletion src/trio/_file_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
)
from typing_extensions import Literal

from ._sync import CapacityLimiter

# This list is also in the docs, make sure to keep them in sync
_FILE_SYNC_ATTRS: set[str] = {
"closed",
Expand Down Expand Up @@ -241,7 +243,10 @@ def __getattr__(self, name: str) -> object:
meth = getattr(self._wrapped, name)

@async_wraps(self.__class__, self._wrapped.__class__, name)
async def wrapper(*args, **kwargs):
async def wrapper(
*args: Callable[..., T],
**kwargs: object | str | bool | CapacityLimiter | None,
) -> T:
func = partial(meth, *args, **kwargs)
return await trio.to_thread.run_sync(func)

Expand Down
4 changes: 2 additions & 2 deletions src/trio/_highlevel_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class SocketStream(HalfCloseableStream):

"""

def __init__(self, socket: SocketType):
def __init__(self, socket: SocketType) -> None:
if not isinstance(socket, tsocket.SocketType):
raise TypeError("SocketStream requires a Trio socket object")
if socket.type != tsocket.SOCK_STREAM:
Expand Down Expand Up @@ -364,7 +364,7 @@ class SocketListener(Listener[SocketStream]):

"""

def __init__(self, socket: SocketType):
def __init__(self, socket: SocketType) -> None:
if not isinstance(socket, tsocket.SocketType):
raise TypeError("SocketListener requires a Trio socket object")
if socket.type != tsocket.SOCK_STREAM:
Expand Down
2 changes: 1 addition & 1 deletion src/trio/_repl.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class TrioInteractiveConsole(InteractiveConsole):
# we make the type more specific on our subclass
locals: dict[str, object]

def __init__(self, repl_locals: dict[str, object] | None = None):
def __init__(self, repl_locals: dict[str, object] | None = None) -> None:
super().__init__(locals=repl_locals)
self.compile.compiler.flags |= ast.PyCF_ALLOW_TOP_LEVEL_AWAIT

Expand Down
4 changes: 2 additions & 2 deletions src/trio/_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ class _try_sync:
def __init__(
self,
blocking_exc_override: Callable[[BaseException], bool] | None = None,
):
) -> None:
self._blocking_exc_override = blocking_exc_override

def _is_blocking_io_error(self, exc: BaseException) -> bool:
Expand Down Expand Up @@ -782,7 +782,7 @@ async def sendmsg(


class _SocketType(SocketType):
def __init__(self, sock: _stdlib_socket.socket):
def __init__(self, sock: _stdlib_socket.socket) -> None:
if type(sock) is not _stdlib_socket.socket:
# For example, ssl.SSLSocket subclasses socket.socket, but we
# certainly don't want to blindly wrap one of those.
Expand Down
8 changes: 4 additions & 4 deletions src/trio/_subprocess_platform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,11 @@

elif os.name == "posix":

def create_pipe_to_child_stdin():
def create_pipe_to_child_stdin() -> tuple[trio.lowlevel.FdStream, int]:
rfd, wfd = os.pipe()
return trio.lowlevel.FdStream(wfd), rfd

def create_pipe_from_child_output():
def create_pipe_from_child_output() -> tuple[trio.lowlevel.FdStream, int]:
rfd, wfd = os.pipe()
return trio.lowlevel.FdStream(rfd), wfd

Expand All @@ -106,12 +106,12 @@

from .._windows_pipes import PipeReceiveStream, PipeSendStream

def create_pipe_to_child_stdin():
def create_pipe_to_child_stdin() -> tuple[PipeSendStream, int]:

Check warning on line 109 in src/trio/_subprocess_platform/__init__.py

View check run for this annotation

Codecov / codecov/patch

src/trio/_subprocess_platform/__init__.py#L109

Added line #L109 was not covered by tests
# for stdin, we want the write end (our end) to use overlapped I/O
rh, wh = windows_pipe(overlapped=(False, True))
return PipeSendStream(wh), msvcrt.open_osfhandle(rh, os.O_RDONLY)

def create_pipe_from_child_output():
def create_pipe_from_child_output() -> tuple[PipeReceiveStream, int]:

Check warning on line 114 in src/trio/_subprocess_platform/__init__.py

View check run for this annotation

Codecov / codecov/patch

src/trio/_subprocess_platform/__init__.py#L114

Added line #L114 was not covered by tests
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there's a mypy setting for catching this as well

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you are referring to disallow_incomplete_defs and check_untyped_defs, those two are already enabled for Trio.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I meant disallow_untyped_defs

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

disallow_untyped_defs is also already enabled

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

then why did this not get caught by mypy previously??

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not really sure

# for stdout/err, it's the read end that's overlapped
rh, wh = windows_pipe(overlapped=(True, False))
return PipeReceiveStream(rh), msvcrt.open_osfhandle(wh, 0)
Expand Down
6 changes: 3 additions & 3 deletions src/trio/_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ class CapacityLimiter(AsyncContextManagerMixin):
"""

# total_tokens would ideally be int|Literal[math.inf] - but that's not valid typing
def __init__(self, total_tokens: int | float): # noqa: PYI041
def __init__(self, total_tokens: int | float) -> None: # noqa: PYI041
self._lot = ParkingLot()
self._borrowers: set[Task | object] = set()
# Maps tasks attempting to acquire -> borrower, to handle on-behalf-of
Expand Down Expand Up @@ -433,7 +433,7 @@ class Semaphore(AsyncContextManagerMixin):

"""

def __init__(self, initial_value: int, *, max_value: int | None = None):
def __init__(self, initial_value: int, *, max_value: int | None = None) -> None:
if not isinstance(initial_value, int):
raise TypeError("initial_value must be an int")
if initial_value < 0:
Expand Down Expand Up @@ -759,7 +759,7 @@ class Condition(AsyncContextManagerMixin):

"""

def __init__(self, lock: Lock | None = None):
def __init__(self, lock: Lock | None = None) -> None:
if lock is None:
lock = Lock()
if type(lock) is not Lock:
Expand Down
5 changes: 1 addition & 4 deletions src/trio/_tests/test_highlevel_open_tcp_listeners.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,10 +197,7 @@ def setsockopt(
) -> None:
pass

async def bind(
self,
address: AddressFormat,
) -> None:
async def bind(self, address: AddressFormat) -> None:
pass

def listen(self, /, backlog: int = min(stdlib_socket.SOMAXCONN, 128)) -> None:
Expand Down
2 changes: 1 addition & 1 deletion src/trio/_tests/test_highlevel_ssl_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ async def getnameinfo(
self,
sockaddr: tuple[str, int] | tuple[str, int, int, int],
flags: int,
) -> NoReturn:
) -> NoReturn: # pragma: no cover
raise NotImplementedError


Expand Down
15 changes: 11 additions & 4 deletions src/trio/_tests/test_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def _frozenbind(
flags: int = 0,
) -> GetAddrInfoArgs:
sig = inspect.signature(self._orig_getaddrinfo)
bound = sig.bind(host, port, family, type, proto, flags)
bound = sig.bind(host, port, family=family, type=type, proto=proto, flags=flags)
bound.apply_defaults()
frozenbound = bound.args
assert not bound.kwargs
Expand All @@ -95,9 +95,16 @@ def set(
proto: int = 0,
flags: int = 0,
) -> None:
self._responses[self._frozenbind(host, port, family, type, proto, flags)] = (
response
)
self._responses[
self._frozenbind(
host,
port,
family=family,
type=type,
proto=proto,
flags=flags,
)
] = response

def getaddrinfo(
self,
Expand Down
63 changes: 41 additions & 22 deletions src/trio/_tests/test_ssl.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,8 +168,8 @@ def ssl_echo_serve_sync(
# Fixture that gives a raw socket connected to a trio-test-1 echo server
# (running in a thread). Useful for testing making connections with different
# SSLContexts.
@asynccontextmanager # type: ignore[misc] # decorated contains Any
async def ssl_echo_server_raw(**kwargs: Any) -> AsyncIterator[SocketStream]:
@asynccontextmanager
async def ssl_echo_server_raw(expect_fail: bool = False) -> AsyncIterator[SocketStream]:
a, b = stdlib_socket.socketpair()
async with trio.open_nursery() as nursery:
# Exiting the 'with a, b' context manager closes the sockets, which
Expand All @@ -178,20 +178,20 @@ async def ssl_echo_server_raw(**kwargs: Any) -> AsyncIterator[SocketStream]:
with a, b:
nursery.start_soon(
trio.to_thread.run_sync,
partial(ssl_echo_serve_sync, b, **kwargs),
partial(ssl_echo_serve_sync, b, expect_fail=expect_fail),
)

yield SocketStream(tsocket.from_stdlib_socket(a))


# Fixture that gives a properly set up SSLStream connected to a trio-test-1
# echo server (running in a thread)
@asynccontextmanager # type: ignore[misc] # decorated contains Any
@asynccontextmanager
async def ssl_echo_server(
client_ctx: SSLContext,
**kwargs: Any,
expect_fail: bool = False,
) -> AsyncIterator[SSLStream[Stream]]:
async with ssl_echo_server_raw(**kwargs) as sock:
async with ssl_echo_server_raw(expect_fail=expect_fail) as sock:
yield SSLStream(sock, client_ctx, server_hostname="trio-test-1.example.org")


Expand All @@ -201,7 +201,10 @@ async def ssl_echo_server(
# jakkdl: it seems to implement all the abstract methods (now), so I made it inherit
# from Stream for the sake of typechecking.
class PyOpenSSLEchoStream(Stream):
def __init__(self, sleeper: None = None) -> None:
def __init__(
self,
sleeper: Callable[[str], Awaitable[None]] | None = None,
) -> None:
ctx = SSL.Context(SSL.SSLv23_METHOD)
# TLS 1.3 removes renegotiation support. Which is great for them, but
# we still have to support versions before that, and that means we
Expand Down Expand Up @@ -249,6 +252,7 @@ def __init__(self, sleeper: None = None) -> None:
"simultaneous calls to PyOpenSSLEchoStream.receive_some",
)

self.sleeper: Callable[[str], Awaitable[None]]
if sleeper is None:

async def no_op_sleeper(_: object) -> None:
Expand Down Expand Up @@ -384,12 +388,12 @@ async def do_test(
await do_test("receive_some", (1,), "receive_some", (1,))


@contextmanager # type: ignore[misc] # decorated contains Any
@contextmanager
def virtual_ssl_echo_server(
client_ctx: SSLContext,
**kwargs: Any,
sleeper: Callable[[str], Awaitable[None]] | None = None,
) -> Iterator[SSLStream[PyOpenSSLEchoStream]]:
fakesock = PyOpenSSLEchoStream(**kwargs)
fakesock = PyOpenSSLEchoStream(sleeper=sleeper)
yield SSLStream(fakesock, client_ctx, server_hostname="trio-test-1.example.org")


Expand Down Expand Up @@ -424,31 +428,43 @@ def ssl_wrap_pair( # type: ignore[misc]
MemoryStapledStream: TypeAlias = StapledStream[MemorySendStream, MemoryReceiveStream]


# Explicit "Any" is not allowed
def ssl_memory_stream_pair( # type: ignore[misc]
def ssl_memory_stream_pair(
client_ctx: SSLContext,
**kwargs: Any,
client_kwargs: dict[str, str | bytes | bool | None] | None = None,
server_kwargs: dict[str, str | bytes | bool | None] | None = None,
) -> tuple[
SSLStream[MemoryStapledStream],
SSLStream[MemoryStapledStream],
]:
client_transport, server_transport = memory_stream_pair()
return ssl_wrap_pair(client_ctx, client_transport, server_transport, **kwargs)
return ssl_wrap_pair(
client_ctx,
client_transport,
server_transport,
client_kwargs=client_kwargs,
server_kwargs=server_kwargs,
)


MyStapledStream: TypeAlias = StapledStream[SendStream, ReceiveStream]


# Explicit "Any" is not allowed
def ssl_lockstep_stream_pair( # type: ignore[misc]
def ssl_lockstep_stream_pair(
client_ctx: SSLContext,
**kwargs: Any,
client_kwargs: dict[str, str | bytes | bool | None] | None = None,
server_kwargs: dict[str, str | bytes | bool | None] | None = None,
) -> tuple[
SSLStream[MyStapledStream],
SSLStream[MyStapledStream],
]:
client_transport, server_transport = lockstep_stream_pair()
return ssl_wrap_pair(client_ctx, client_transport, server_transport, **kwargs)
return ssl_wrap_pair(
client_ctx,
client_transport,
server_transport,
client_kwargs=client_kwargs,
server_kwargs=server_kwargs,
)


# Simple smoke test for handshake/send/receive/shutdown talking to a
Expand Down Expand Up @@ -1327,15 +1343,18 @@ async def test_getpeercert(client_ctx: SSLContext) -> None:


async def test_SSLListener(client_ctx: SSLContext) -> None:
# Explicit "Any" is not allowed
async def setup( # type: ignore[misc]
**kwargs: Any,
async def setup(
https_compatible: bool = False,
) -> tuple[tsocket.SocketType, SSLListener[SocketStream], SSLStream[SocketStream]]:
listen_sock = tsocket.socket()
await listen_sock.bind(("127.0.0.1", 0))
listen_sock.listen(1)
socket_listener = SocketListener(listen_sock)
ssl_listener = SSLListener(socket_listener, SERVER_CTX, **kwargs)
ssl_listener = SSLListener(
socket_listener,
SERVER_CTX,
https_compatible=https_compatible,
)

transport_client = await open_tcp_stream(*listen_sock.getsockname())
ssl_client = SSLStream(
Expand Down
Loading
Loading