From be39867137ee4593a8ab6cc3f3db83d7072eca05 Mon Sep 17 00:00:00 2001 From: CoolCat467 <52022020+CoolCat467@users.noreply.github.com> Date: Thu, 7 Sep 2023 18:14:55 -0500 Subject: [PATCH] Add typing for some of the tests (#2771) * Add typing for tests `test_channel` had an error, was awaiting a synchronous function * Use `Union` for runtime eval types * More WIP tests * Finish `test_deprecate.py` I think * Work on `test_dtls` * Import `TypeAlias` from typing_extensions if type checking * More work on `test_dtls` and complete `tutil` * Work on `test_asyncgen` * Silence some issues from FakeNet (needs another PR) and other type fixes * Fix return type of `pipe_with_overlapped_read` * Ignore FakeNet typing issues for the moment * Help mypy understand we are in 3.10 * 2nd try fixing mypy missing `contextlib.aclosing` * Use assert instead * Use type var to indicate input is unchanged * Make parameter more specific * Use a version check mypy will understand * Fix missing `Callable` import * Attempt inlining `aclosing` Co-authored-by: TeamSpen210 * It's `aclose` not `close` * Suggestions by TeamSpen210 Co-authored-by: TeamSpen210 * Ignore incorrect typing for LogCaptureFixture * Type check tests more strictly * Remove type alias for `pytest.WarningsRecorder` * Fix a bunch of mypy errors * Use `Any` for generic `CaptureFixture`'s argument * Ignore `skipif` leaves function untyped error * Use `str` for `CaptureFixture`'s generic and remove skipif ignore --------- Co-authored-by: TeamSpen210 --- pyproject.toml | 9 -- trio/_core/_tests/test_asyncgen.py | 93 ++++++++------- trio/_core/_tests/test_tutil.py | 2 +- trio/_core/_tests/test_windows.py | 42 ++++--- trio/_core/_tests/tutil.py | 17 ++- trio/_tests/check_type_completeness.py | 4 +- trio/_tests/pytest_plugin.py | 10 +- trio/_tests/test_abc.py | 18 +-- trio/_tests/test_channel.py | 96 ++++++++------- trio/_tests/test_deprecate.py | 82 ++++++++----- trio/_tests/test_dtls.py | 154 ++++++++++++++----------- 11 files changed, 300 insertions(+), 227 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 17dd2aa1b7..b5e3d43153 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -71,7 +71,6 @@ module = [ # tests "trio/testing/_fake_net", -"trio/_core/_tests/test_asyncgen", "trio/_core/_tests/test_guest_mode", "trio/_core/_tests/test_instrumentation", "trio/_core/_tests/test_ki", @@ -82,15 +81,7 @@ module = [ "trio/_core/_tests/test_multierror_scripts/simple_excepthook", "trio/_core/_tests/test_parking_lot", "trio/_core/_tests/test_thread_cache", -"trio/_core/_tests/test_tutil", "trio/_core/_tests/test_unbounded_queue", -"trio/_core/_tests/test_windows", -"trio/_core/_tests/tutil", -"trio/_tests/pytest_plugin", -"trio/_tests/test_abc", -"trio/_tests/test_channel", -"trio/_tests/test_deprecate", -"trio/_tests/test_dtls", "trio/_tests/test_exports", "trio/_tests/test_file_io", "trio/_tests/test_highlevel_generic", diff --git a/trio/_core/_tests/test_asyncgen.py b/trio/_core/_tests/test_asyncgen.py index f72d5c6859..7e6a5fb4b9 100644 --- a/trio/_core/_tests/test_asyncgen.py +++ b/trio/_core/_tests/test_asyncgen.py @@ -1,7 +1,10 @@ -import contextlib +from __future__ import annotations + import sys import weakref +from collections.abc import AsyncGenerator from math import inf +from typing import NoReturn import pytest @@ -9,11 +12,10 @@ from .tutil import buggy_pypy_asyncgens, gc_collect_harder, restore_unraisablehook -@pytest.mark.skipif(sys.version_info < (3, 10), reason="no aclosing() in stdlib<3.10") -def test_asyncgen_basics(): +def test_asyncgen_basics() -> None: collected = [] - async def example(cause): + async def example(cause: str) -> AsyncGenerator[int, None]: try: try: yield 42 @@ -37,7 +39,7 @@ async def example(cause): saved = [] - async def async_main(): + async def async_main() -> None: # GC'ed before exhausted with pytest.warns( ResourceWarning, match="Async generator.*collected before.*exhausted" @@ -47,9 +49,11 @@ async def async_main(): await _core.wait_all_tasks_blocked() assert collected.pop() == "abandoned" - # aclosing() ensures it's cleaned up at point of use - async with contextlib.aclosing(example("exhausted 1")) as aiter: + aiter = example("exhausted 1") + try: assert 42 == await aiter.asend(None) + finally: + await aiter.aclose() assert collected.pop() == "exhausted 1" # Also fine if you exhaust it at point of use @@ -60,9 +64,12 @@ async def async_main(): gc_collect_harder() # No problems saving the geniter when using either of these patterns - async with contextlib.aclosing(example("exhausted 3")) as aiter: + aiter = example("exhausted 3") + try: saved.append(aiter) assert 42 == await aiter.asend(None) + finally: + await aiter.aclose() assert collected.pop() == "exhausted 3" # Also fine if you exhaust it at point of use @@ -85,10 +92,12 @@ async def async_main(): assert agen.ag_frame is None # all should now be exhausted -async def test_asyncgen_throws_during_finalization(caplog): +async def test_asyncgen_throws_during_finalization( + caplog: pytest.LogCaptureFixture, +) -> None: record = [] - async def agen(): + async def agen() -> AsyncGenerator[int, None]: try: yield 1 finally: @@ -101,18 +110,19 @@ async def agen(): gc_collect_harder() await _core.wait_all_tasks_blocked() assert record == ["crashing"] - exc_type, exc_value, exc_traceback = caplog.records[0].exc_info + # Following type ignore is because typing for LogCaptureFixture is wrong + exc_type, exc_value, exc_traceback = caplog.records[0].exc_info # type: ignore[misc] assert exc_type is ValueError assert str(exc_value) == "oops" assert "during finalization of async generator" in caplog.records[0].message @pytest.mark.skipif(buggy_pypy_asyncgens, reason="pypy 7.2.0 is buggy") -def test_firstiter_after_closing(): +def test_firstiter_after_closing() -> None: saved = [] record = [] - async def funky_agen(): + async def funky_agen() -> AsyncGenerator[int, None]: try: yield 1 except GeneratorExit: @@ -124,7 +134,7 @@ async def funky_agen(): record.append("cleanup 2") await funky_agen().asend(None) - async def async_main(): + async def async_main() -> None: aiter = funky_agen() saved.append(aiter) assert 1 == await aiter.asend(None) @@ -135,18 +145,20 @@ async def async_main(): @pytest.mark.skipif(buggy_pypy_asyncgens, reason="pypy 7.2.0 is buggy") -def test_interdependent_asyncgen_cleanup_order(): - saved = [] - record = [] +def test_interdependent_asyncgen_cleanup_order() -> None: + saved: list[AsyncGenerator[int, None]] = [] + record: list[int | str] = [] - async def innermost(): + async def innermost() -> AsyncGenerator[int, None]: try: yield 1 finally: await _core.cancel_shielded_checkpoint() record.append("innermost") - async def agen(label, inner): + async def agen( + label: int, inner: AsyncGenerator[int, None] + ) -> AsyncGenerator[int, None]: try: yield await inner.asend(None) finally: @@ -158,7 +170,7 @@ async def agen(label, inner): await inner.asend(None) record.append(label) - async def async_main(): + async def async_main() -> None: # This makes a chain of 101 interdependent asyncgens: # agen(99)'s cleanup will iterate agen(98)'s will iterate # ... agen(0)'s will iterate innermost()'s @@ -174,19 +186,20 @@ async def async_main(): @restore_unraisablehook() -def test_last_minute_gc_edge_case(): - saved = [] +def test_last_minute_gc_edge_case() -> None: + saved: list[AsyncGenerator[int, None]] = [] record = [] needs_retry = True - async def agen(): + async def agen() -> AsyncGenerator[int, None]: try: yield 1 finally: record.append("cleaned up") - def collect_at_opportune_moment(token): + def collect_at_opportune_moment(token: _core._entry_queue.TrioToken) -> None: runner = _core._run.GLOBAL_RUN_CONTEXT.runner + assert runner.system_nursery is not None if runner.system_nursery._closed and isinstance( runner.asyncgens.alive, weakref.WeakSet ): @@ -201,7 +214,7 @@ def collect_at_opportune_moment(token): nonlocal needs_retry needs_retry = True - async def async_main(): + async def async_main() -> None: token = _core.current_trio_token() token.run_sync_soon(collect_at_opportune_moment, token) saved.append(agen()) @@ -231,7 +244,7 @@ async def async_main(): ) -async def step_outside_async_context(aiter): +async def step_outside_async_context(aiter: AsyncGenerator[int, None]) -> None: # abort_fns run outside of task context, at least if they're # triggered by a deadline expiry rather than a direct # cancellation. Thus, an asyncgen first iterated inside one @@ -242,13 +255,13 @@ async def step_outside_async_context(aiter): # NB: the strangeness with aiter being an attribute of abort_fn is # to make it as easy as possible to ensure we don't hang onto a # reference to aiter inside the guts of the run loop. - def abort_fn(_): + def abort_fn(_: _core.RaiseCancelT) -> _core.Abort: with pytest.raises(StopIteration, match="42"): - abort_fn.aiter.asend(None).send(None) - del abort_fn.aiter + abort_fn.aiter.asend(None).send(None) # type: ignore[attr-defined] # Callables don't have attribute "aiter" + del abort_fn.aiter # type: ignore[attr-defined] return _core.Abort.SUCCEEDED - abort_fn.aiter = aiter + abort_fn.aiter = aiter # type: ignore[attr-defined] async with _core.open_nursery() as nursery: nursery.start_soon(_core.wait_task_rescheduled, abort_fn) @@ -257,16 +270,18 @@ def abort_fn(_): @pytest.mark.skipif(buggy_pypy_asyncgens, reason="pypy 7.2.0 is buggy") -async def test_fallback_when_no_hook_claims_it(capsys): - async def well_behaved(): +async def test_fallback_when_no_hook_claims_it( + capsys: pytest.CaptureFixture[str], +) -> None: + async def well_behaved() -> AsyncGenerator[int, None]: yield 42 - async def yields_after_yield(): + async def yields_after_yield() -> AsyncGenerator[int, None]: with pytest.raises(GeneratorExit): yield 42 yield 100 - async def awaits_after_yield(): + async def awaits_after_yield() -> AsyncGenerator[int, None]: with pytest.raises(GeneratorExit): yield 42 await _core.cancel_shielded_checkpoint() @@ -286,16 +301,16 @@ async def awaits_after_yield(): @pytest.mark.skipif(buggy_pypy_asyncgens, reason="pypy 7.2.0 is buggy") -def test_delegation_to_existing_hooks(): +def test_delegation_to_existing_hooks() -> None: record = [] - def my_firstiter(agen): + def my_firstiter(agen: AsyncGenerator[object, NoReturn]) -> None: record.append("firstiter " + agen.ag_frame.f_locals["arg"]) - def my_finalizer(agen): + def my_finalizer(agen: AsyncGenerator[object, NoReturn]) -> None: record.append("finalizer " + agen.ag_frame.f_locals["arg"]) - async def example(arg): + async def example(arg: str) -> AsyncGenerator[int, None]: try: yield 42 finally: @@ -303,7 +318,7 @@ async def example(arg): await _core.checkpoint() record.append("trio collected " + arg) - async def async_main(): + async def async_main() -> None: await step_outside_async_context(example("theirs")) assert 42 == await example("ours").asend(None) gc_collect_harder() diff --git a/trio/_core/_tests/test_tutil.py b/trio/_core/_tests/test_tutil.py index eb16de883f..07bba9407d 100644 --- a/trio/_core/_tests/test_tutil.py +++ b/trio/_core/_tests/test_tutil.py @@ -3,7 +3,7 @@ from .tutil import check_sequence_matches -def test_check_sequence_matches(): +def test_check_sequence_matches() -> None: check_sequence_matches([1, 2, 3], [1, 2, 3]) with pytest.raises(AssertionError): check_sequence_matches([1, 3, 2], [1, 2, 3]) diff --git a/trio/_core/_tests/test_windows.py b/trio/_core/_tests/test_windows.py index 99bb97284b..7beb59cc21 100644 --- a/trio/_core/_tests/test_windows.py +++ b/trio/_core/_tests/test_windows.py @@ -1,7 +1,11 @@ +from __future__ import annotations + import os import sys import tempfile +from collections.abc import Generator from contextlib import contextmanager +from io import BufferedWriter from typing import TYPE_CHECKING from unittest.mock import create_autospec @@ -27,7 +31,7 @@ ) -def test_winerror(monkeypatch) -> None: +def test_winerror(monkeypatch: pytest.MonkeyPatch) -> None: mock = create_autospec(ffi.getwinerror) monkeypatch.setattr(ffi, "getwinerror", mock) @@ -68,8 +72,8 @@ def test_winerror(monkeypatch) -> None: # UnboundedQueue (or just removed until we have time to redo it), but until # then we filter out the warning. @pytest.mark.filterwarnings("ignore:.*UnboundedQueue:trio.TrioDeprecationWarning") -async def test_completion_key_listen(): - async def post(key): +async def test_completion_key_listen() -> None: + async def post(key: int) -> None: iocp = ffi.cast("HANDLE", _core.current_iocp()) for i in range(10): print("post", i) @@ -94,7 +98,7 @@ async def post(key): print("end loop") -async def test_readinto_overlapped(): +async def test_readinto_overlapped() -> None: data = b"1" * 1024 + b"2" * 1024 + b"3" * 1024 + b"4" * 1024 buffer = bytearray(len(data)) @@ -121,7 +125,7 @@ async def test_readinto_overlapped(): try: with memoryview(buffer) as buffer_view: - async def read_region(start, end): + async def read_region(start: int, end: int) -> None: await _core.readinto_overlapped( handle, buffer_view[start:end], start ) @@ -140,7 +144,7 @@ async def read_region(start, end): @contextmanager -def pipe_with_overlapped_read(): +def pipe_with_overlapped_read() -> Generator[tuple[BufferedWriter, int], None, None]: import msvcrt from asyncio.windows_utils import pipe @@ -154,14 +158,14 @@ def pipe_with_overlapped_read(): @restore_unraisablehook() -def test_forgot_to_register_with_iocp(): +def test_forgot_to_register_with_iocp() -> None: with pipe_with_overlapped_read() as (write_fp, read_handle): with write_fp: write_fp.write(b"test\n") left_run_yet = False - async def main(): + async def main() -> None: target = bytearray(1) try: async with _core.open_nursery() as nursery: @@ -188,7 +192,7 @@ async def main(): @slow -async def test_too_late_to_cancel(): +async def test_too_late_to_cancel() -> None: import time with pipe_with_overlapped_read() as (write_fp, read_handle): @@ -216,11 +220,15 @@ async def test_too_late_to_cancel(): assert target[:6] == b"test2\n" -def test_lsp_that_hooks_select_gives_good_error(monkeypatch): +def test_lsp_that_hooks_select_gives_good_error( + monkeypatch: pytest.MonkeyPatch, +) -> None: from .. import _io_windows - from .._windows_cffi import WSAIoctls, _handle + from .._windows_cffi import CData, WSAIoctls, _handle - def patched_get_underlying(sock, *, which=WSAIoctls.SIO_BASE_HANDLE): + def patched_get_underlying( + sock: int | CData, *, which: int = WSAIoctls.SIO_BASE_HANDLE + ) -> CData: if hasattr(sock, "fileno"): # pragma: no branch sock = sock.fileno() if which == WSAIoctls.SIO_BSP_HANDLE_SELECT: @@ -235,16 +243,20 @@ def patched_get_underlying(sock, *, which=WSAIoctls.SIO_BASE_HANDLE): _core.run(sleep, 0) -def test_lsp_that_completely_hides_base_socket_gives_good_error(monkeypatch): +def test_lsp_that_completely_hides_base_socket_gives_good_error( + monkeypatch: pytest.MonkeyPatch, +) -> None: # This tests behavior with an LSP that fails SIO_BASE_HANDLE and returns # self for SIO_BSP_HANDLE_SELECT (like Komodia), but also returns # self for SIO_BSP_HANDLE_POLL. No known LSP does this, but we want to # make sure we get an error rather than an infinite loop. from .. import _io_windows - from .._windows_cffi import WSAIoctls, _handle + from .._windows_cffi import CData, WSAIoctls, _handle - def patched_get_underlying(sock, *, which=WSAIoctls.SIO_BASE_HANDLE): + def patched_get_underlying( + sock: int | CData, *, which: int = WSAIoctls.SIO_BASE_HANDLE + ) -> CData: if hasattr(sock, "fileno"): # pragma: no branch sock = sock.fileno() if which == WSAIoctls.SIO_BASE_HANDLE: diff --git a/trio/_core/_tests/tutil.py b/trio/_core/_tests/tutil.py index b3aa73fb7d..070af8ed15 100644 --- a/trio/_core/_tests/tutil.py +++ b/trio/_core/_tests/tutil.py @@ -1,10 +1,13 @@ # Utilities for testing +from __future__ import annotations + import asyncio import gc import os import socket as stdlib_socket import sys import warnings +from collections.abc import Generator, Iterable, Sequence from contextlib import closing, contextmanager from typing import TYPE_CHECKING @@ -50,7 +53,7 @@ binds_ipv6 = pytest.mark.skipif(not can_bind_ipv6, reason="need IPv6") -def gc_collect_harder(): +def gc_collect_harder() -> None: # In the test suite we sometimes want to call gc.collect() to make sure # that any objects with noisy __del__ methods (e.g. unawaited coroutines) # get collected before we continue, so their noise doesn't leak into @@ -69,7 +72,7 @@ def gc_collect_harder(): # manager should be used anywhere this happens to hide those messages, because # when expected they're clutter. @contextmanager -def ignore_coroutine_never_awaited_warnings(): +def ignore_coroutine_never_awaited_warnings() -> Generator[None, None, None]: with warnings.catch_warnings(): warnings.filterwarnings("ignore", message="coroutine '.*' was never awaited") try: @@ -80,12 +83,12 @@ def ignore_coroutine_never_awaited_warnings(): gc_collect_harder() -def _noop(*args, **kwargs): +def _noop(*args: object, **kwargs: object) -> None: pass @contextmanager -def restore_unraisablehook(): +def restore_unraisablehook() -> Generator[None, None, None]: sys.unraisablehook, prev = sys.__unraisablehook__, sys.unraisablehook try: yield @@ -95,7 +98,9 @@ def restore_unraisablehook(): # template is like: # [1, {2.1, 2.2}, 3] -> matches [1, 2.1, 2.2, 3] or [1, 2.2, 2.1, 3] -def check_sequence_matches(seq, template): +def check_sequence_matches( + seq: Sequence[object], template: Iterable[object | set[object]] +) -> None: i = 0 for pattern in template: if not isinstance(pattern, set): @@ -115,6 +120,6 @@ def check_sequence_matches(seq, template): ) -def create_asyncio_future_in_new_loop(): +def create_asyncio_future_in_new_loop() -> asyncio.Future[object]: with closing(asyncio.new_event_loop()) as loop: return loop.create_future() diff --git a/trio/_tests/check_type_completeness.py b/trio/_tests/check_type_completeness.py index 1352926be3..233a2ab5dc 100755 --- a/trio/_tests/check_type_completeness.py +++ b/trio/_tests/check_type_completeness.py @@ -36,8 +36,8 @@ def run_pyright(platform: str) -> subprocess.CompletedProcess[bytes]: def check_less_than( key: str, - current_dict: Mapping[str, float], - last_dict: Mapping[str, float], + current_dict: Mapping[str, int | float], + last_dict: Mapping[str, int | float], /, invert: bool = False, ) -> None: diff --git a/trio/_tests/pytest_plugin.py b/trio/_tests/pytest_plugin.py index c6d73e25ea..2170a1e8b6 100644 --- a/trio/_tests/pytest_plugin.py +++ b/trio/_tests/pytest_plugin.py @@ -7,22 +7,22 @@ RUN_SLOW = True -def pytest_addoption(parser): +def pytest_addoption(parser: pytest.Parser) -> None: parser.addoption("--run-slow", action="store_true", help="run slow tests") -def pytest_configure(config): +def pytest_configure(config: pytest.Config) -> None: global RUN_SLOW RUN_SLOW = config.getoption("--run-slow", True) @pytest.fixture -def mock_clock(): +def mock_clock() -> MockClock: return MockClock() @pytest.fixture -def autojump_clock(): +def autojump_clock() -> MockClock: return MockClock(autojump_threshold=0) @@ -31,6 +31,6 @@ def autojump_clock(): # guess it's useful with the class- and file-level marking machinery (where # the raw @trio_test decorator isn't enough). @pytest.hookimpl(tryfirst=True) -def pytest_pyfunc_call(pyfuncitem): +def pytest_pyfunc_call(pyfuncitem: pytest.Function) -> None: if inspect.iscoroutinefunction(pyfuncitem.obj): pyfuncitem.obj = trio_test(pyfuncitem.obj) diff --git a/trio/_tests/test_abc.py b/trio/_tests/test_abc.py index 2b0b7088b0..b93c48f539 100644 --- a/trio/_tests/test_abc.py +++ b/trio/_tests/test_abc.py @@ -1,15 +1,17 @@ +from __future__ import annotations + import attr import pytest from .. import abc as tabc -async def test_AsyncResource_defaults(): +async def test_AsyncResource_defaults() -> None: @attr.s class MyAR(tabc.AsyncResource): - record = attr.ib(factory=list) + record: list[str] = attr.ib(factory=list) - async def aclose(self): + async def aclose(self) -> None: self.record.append("ac") async with MyAR() as myar: @@ -19,7 +21,7 @@ async def aclose(self): assert myar.record == ["ac"] -def test_abc_generics(): +def test_abc_generics() -> None: # Pythons below 3.5.2 had a typing.Generic that would throw # errors when instantiating or subclassing a parameterized # version of a class with any __slots__. This is why RunVar @@ -30,16 +32,16 @@ def test_abc_generics(): class SlottedChannel(tabc.SendChannel[tabc.Stream]): __slots__ = ("x",) - def send_nowait(self, value): + def send_nowait(self, value: object) -> None: raise RuntimeError - async def send(self, value): + async def send(self, value: object) -> None: raise RuntimeError # pragma: no cover - def clone(self): + def clone(self) -> None: raise RuntimeError # pragma: no cover - async def aclose(self): + async def aclose(self) -> None: pass # pragma: no cover channel = SlottedChannel() diff --git a/trio/_tests/test_channel.py b/trio/_tests/test_channel.py index 4478c523f5..c81933b6b7 100644 --- a/trio/_tests/test_channel.py +++ b/trio/_tests/test_channel.py @@ -1,3 +1,7 @@ +from __future__ import annotations + +from typing import Union + import pytest import trio @@ -6,13 +10,13 @@ from ..testing import assert_checkpoints, wait_all_tasks_blocked -async def test_channel(): +async def test_channel() -> None: with pytest.raises(TypeError): open_memory_channel(1.0) with pytest.raises(ValueError): open_memory_channel(-1) - s, r = open_memory_channel(2) + s, r = open_memory_channel[Union[int, str, None]](2) repr(s) # smoke test repr(r) # smoke test @@ -45,26 +49,26 @@ async def test_channel(): with pytest.raises(trio.ClosedResourceError): await r.receive() with pytest.raises(trio.ClosedResourceError): - await r.receive_nowait() + r.receive_nowait() await r.aclose() -async def test_553(autojump_clock): - s, r = open_memory_channel(1) +async def test_553(autojump_clock: trio.abc.Clock) -> None: + s, r = open_memory_channel[str](1) with trio.move_on_after(10) as timeout_scope: await r.receive() assert timeout_scope.cancelled_caught await s.send("Test for PR #553") -async def test_channel_multiple_producers(): - async def producer(send_channel, i): +async def test_channel_multiple_producers() -> None: + async def producer(send_channel: trio.MemorySendChannel[int], i: int) -> None: # We close our handle when we're done with it async with send_channel: for j in range(3 * i, 3 * (i + 1)): await send_channel.send(j) - send_channel, receive_channel = open_memory_channel(0) + send_channel, receive_channel = open_memory_channel[int](0) async with trio.open_nursery() as nursery: # We hand out clones to all the new producers, and then close the # original. @@ -80,17 +84,17 @@ async def producer(send_channel, i): assert got == list(range(30)) -async def test_channel_multiple_consumers(): +async def test_channel_multiple_consumers() -> None: successful_receivers = set() received = [] - async def consumer(receive_channel, i): + async def consumer(receive_channel: trio.MemoryReceiveChannel[int], i: int) -> None: async for value in receive_channel: successful_receivers.add(i) received.append(value) async with trio.open_nursery() as nursery: - send_channel, receive_channel = trio.open_memory_channel(1) + send_channel, receive_channel = trio.open_memory_channel[int](1) async with send_channel: for i in range(5): nursery.start_soon(consumer, receive_channel, i) @@ -103,13 +107,15 @@ async def consumer(receive_channel, i): assert set(received) == set(range(10)) -async def test_close_basics(): - async def send_block(s, expect): +async def test_close_basics() -> None: + async def send_block( + s: trio.MemorySendChannel[None], expect: type[BaseException] + ) -> None: with pytest.raises(expect): await s.send(None) # closing send -> other send gets ClosedResourceError - s, r = open_memory_channel(0) + s, r = open_memory_channel[None](0) async with trio.open_nursery() as nursery: nursery.start_soon(send_block, s, trio.ClosedResourceError) await wait_all_tasks_blocked() @@ -128,7 +134,7 @@ async def send_block(s, expect): await r.receive() # closing receive -> send gets BrokenResourceError - s, r = open_memory_channel(0) + s, r = open_memory_channel[None](0) async with trio.open_nursery() as nursery: nursery.start_soon(send_block, s, trio.BrokenResourceError) await wait_all_tasks_blocked() @@ -141,11 +147,11 @@ async def send_block(s, expect): await s.send(None) # closing receive -> other receive gets ClosedResourceError - async def receive_block(r): + async def receive_block(r: trio.MemoryReceiveChannel[int]) -> None: with pytest.raises(trio.ClosedResourceError): await r.receive() - s, r = open_memory_channel(0) + s, r = open_memory_channel[None](0) async with trio.open_nursery() as nursery: nursery.start_soon(receive_block, r) await wait_all_tasks_blocked() @@ -158,13 +164,15 @@ async def receive_block(r): await r.receive() -async def test_close_sync(): - async def send_block(s, expect): +async def test_close_sync() -> None: + async def send_block( + s: trio.MemorySendChannel[None], expect: type[BaseException] + ) -> None: with pytest.raises(expect): await s.send(None) # closing send -> other send gets ClosedResourceError - s, r = open_memory_channel(0) + s, r = open_memory_channel[None](0) async with trio.open_nursery() as nursery: nursery.start_soon(send_block, s, trio.ClosedResourceError) await wait_all_tasks_blocked() @@ -183,7 +191,7 @@ async def send_block(s, expect): await r.receive() # closing receive -> send gets BrokenResourceError - s, r = open_memory_channel(0) + s, r = open_memory_channel[None](0) async with trio.open_nursery() as nursery: nursery.start_soon(send_block, s, trio.BrokenResourceError) await wait_all_tasks_blocked() @@ -196,11 +204,11 @@ async def send_block(s, expect): await s.send(None) # closing receive -> other receive gets ClosedResourceError - async def receive_block(r): + async def receive_block(r: trio.MemoryReceiveChannel[int]) -> None: with pytest.raises(trio.ClosedResourceError): await r.receive() - s, r = open_memory_channel(0) + s, r = open_memory_channel[None](0) async with trio.open_nursery() as nursery: nursery.start_soon(receive_block, r) await wait_all_tasks_blocked() @@ -213,8 +221,8 @@ async def receive_block(r): await r.receive() -async def test_receive_channel_clone_and_close(): - s, r = open_memory_channel(10) +async def test_receive_channel_clone_and_close() -> None: + s, r = open_memory_channel[None](10) r2 = r.clone() r3 = r.clone() @@ -240,17 +248,17 @@ async def test_receive_channel_clone_and_close(): s.send_nowait(None) -async def test_close_multiple_send_handles(): +async def test_close_multiple_send_handles() -> None: # With multiple send handles, closing one handle only wakes senders on # that handle, but others can continue just fine - s1, r = open_memory_channel(0) + s1, r = open_memory_channel[str](0) s2 = s1.clone() - async def send_will_close(): + async def send_will_close() -> None: with pytest.raises(trio.ClosedResourceError): await s1.send("nope") - async def send_will_succeed(): + async def send_will_succeed() -> None: await s2.send("ok") async with trio.open_nursery() as nursery: @@ -261,17 +269,17 @@ async def send_will_succeed(): assert await r.receive() == "ok" -async def test_close_multiple_receive_handles(): +async def test_close_multiple_receive_handles() -> None: # With multiple receive handles, closing one handle only wakes receivers on # that handle, but others can continue just fine - s, r1 = open_memory_channel(0) + s, r1 = open_memory_channel[str](0) r2 = r1.clone() - async def receive_will_close(): + async def receive_will_close() -> None: with pytest.raises(trio.ClosedResourceError): await r1.receive() - async def receive_will_succeed(): + async def receive_will_succeed() -> None: assert await r2.receive() == "ok" async with trio.open_nursery() as nursery: @@ -282,8 +290,8 @@ async def receive_will_succeed(): await s.send("ok") -async def test_inf_capacity(): - s, r = open_memory_channel(float("inf")) +async def test_inf_capacity() -> None: + s, r = open_memory_channel[int](float("inf")) # It's accepted, and we can send all day without blocking with s: @@ -296,8 +304,8 @@ async def test_inf_capacity(): assert got == list(range(10)) -async def test_statistics(): - s, r = open_memory_channel(2) +async def test_statistics() -> None: + s, r = open_memory_channel[None](2) assert s.statistics() == r.statistics() stats = s.statistics() @@ -346,10 +354,10 @@ async def test_statistics(): assert s.statistics().tasks_waiting_receive == 0 -async def test_channel_fairness(): +async def test_channel_fairness() -> None: # We can remove an item we just sent, and send an item back in after, if # no-one else is waiting. - s, r = open_memory_channel(1) + s, r = open_memory_channel[Union[int, None]](1) s.send_nowait(1) assert r.receive_nowait() == 1 s.send_nowait(2) @@ -360,7 +368,7 @@ async def test_channel_fairness(): result = None - async def do_receive(r): + async def do_receive(r: trio.MemoryReceiveChannel[int]) -> None: nonlocal result result = await r.receive() @@ -375,7 +383,7 @@ async def do_receive(r): # And the analogous situation for send: if we free up a space, we can't # immediately send something in it if someone is already waiting to do # that - s, r = open_memory_channel(1) + s, r = open_memory_channel[Union[int, None]](1) s.send_nowait(1) with pytest.raises(trio.WouldBlock): s.send_nowait(None) @@ -388,14 +396,14 @@ async def do_receive(r): assert (await r.receive()) == 2 -async def test_unbuffered(): - s, r = open_memory_channel(0) +async def test_unbuffered() -> None: + s, r = open_memory_channel[int](0) with pytest.raises(trio.WouldBlock): r.receive_nowait() with pytest.raises(trio.WouldBlock): s.send_nowait(1) - async def do_send(s, v): + async def do_send(s: trio.MemorySendChannel[int], v: int) -> None: with assert_checkpoints(): await s.send(v) diff --git a/trio/_tests/test_deprecate.py b/trio/_tests/test_deprecate.py index 33c05ffd25..da548fc715 100644 --- a/trio/_tests/test_deprecate.py +++ b/trio/_tests/test_deprecate.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import inspect import warnings @@ -13,7 +15,7 @@ @pytest.fixture -def recwarn_always(recwarn): +def recwarn_always(recwarn: pytest.WarningsRecorder) -> pytest.WarningsRecorder: warnings.simplefilter("always") # ResourceWarnings about unclosed sockets can occur nondeterministically # (during GC) which throws off the tests in this file @@ -21,19 +23,23 @@ def recwarn_always(recwarn): return recwarn -def _here(): - info = inspect.getframeinfo(inspect.currentframe().f_back) +def _here() -> tuple[str, int]: + frame = inspect.currentframe() + assert frame is not None + assert frame.f_back is not None + info = inspect.getframeinfo(frame.f_back) return (info.filename, info.lineno) -def test_warn_deprecated(recwarn_always): - def deprecated_thing(): +def test_warn_deprecated(recwarn_always: pytest.WarningsRecorder) -> None: + def deprecated_thing() -> None: warn_deprecated("ice", "1.2", issue=1, instead="water") deprecated_thing() filename, lineno = _here() assert len(recwarn_always) == 1 got = recwarn_always.pop(TrioDeprecationWarning) + assert isinstance(got.message, Warning) assert "ice is deprecated" in got.message.args[0] assert "Trio 1.2" in got.message.args[0] assert "water instead" in got.message.args[0] @@ -42,21 +48,24 @@ def deprecated_thing(): assert got.lineno == lineno - 1 -def test_warn_deprecated_no_instead_or_issue(recwarn_always): +def test_warn_deprecated_no_instead_or_issue( + recwarn_always: pytest.WarningsRecorder, +) -> None: # Explicitly no instead or issue warn_deprecated("water", "1.3", issue=None, instead=None) assert len(recwarn_always) == 1 got = recwarn_always.pop(TrioDeprecationWarning) + assert isinstance(got.message, Warning) assert "water is deprecated" in got.message.args[0] assert "no replacement" in got.message.args[0] assert "Trio 1.3" in got.message.args[0] -def test_warn_deprecated_stacklevel(recwarn_always): - def nested1(): +def test_warn_deprecated_stacklevel(recwarn_always: pytest.WarningsRecorder) -> None: + def nested1() -> None: nested2() - def nested2(): + def nested2() -> None: warn_deprecated("x", "1.3", issue=7, instead="y", stacklevel=3) filename, lineno = _here() @@ -66,29 +75,31 @@ def nested2(): assert got.lineno == lineno + 1 -def old(): # pragma: no cover +def old() -> None: # pragma: no cover pass -def new(): # pragma: no cover +def new() -> None: # pragma: no cover pass -def test_warn_deprecated_formatting(recwarn_always): +def test_warn_deprecated_formatting(recwarn_always: pytest.WarningsRecorder) -> None: warn_deprecated(old, "1.0", issue=1, instead=new) got = recwarn_always.pop(TrioDeprecationWarning) + assert isinstance(got.message, Warning) assert "test_deprecate.old is deprecated" in got.message.args[0] assert "test_deprecate.new instead" in got.message.args[0] @deprecated("1.5", issue=123, instead=new) -def deprecated_old(): +def deprecated_old() -> int: return 3 -def test_deprecated_decorator(recwarn_always): +def test_deprecated_decorator(recwarn_always: pytest.WarningsRecorder) -> None: assert deprecated_old() == 3 got = recwarn_always.pop(TrioDeprecationWarning) + assert isinstance(got.message, Warning) assert "test_deprecate.deprecated_old is deprecated" in got.message.args[0] assert "1.5" in got.message.args[0] assert "test_deprecate.new" in got.message.args[0] @@ -97,50 +108,56 @@ def test_deprecated_decorator(recwarn_always): class Foo: @deprecated("1.0", issue=123, instead="crying") - def method(self): + def method(self) -> int: return 7 -def test_deprecated_decorator_method(recwarn_always): +def test_deprecated_decorator_method(recwarn_always: pytest.WarningsRecorder) -> None: f = Foo() assert f.method() == 7 got = recwarn_always.pop(TrioDeprecationWarning) + assert isinstance(got.message, Warning) assert "test_deprecate.Foo.method is deprecated" in got.message.args[0] @deprecated("1.2", thing="the thing", issue=None, instead=None) -def deprecated_with_thing(): +def deprecated_with_thing() -> int: return 72 -def test_deprecated_decorator_with_explicit_thing(recwarn_always): +def test_deprecated_decorator_with_explicit_thing( + recwarn_always: pytest.WarningsRecorder, +) -> None: assert deprecated_with_thing() == 72 got = recwarn_always.pop(TrioDeprecationWarning) + assert isinstance(got.message, Warning) assert "the thing is deprecated" in got.message.args[0] -def new_hotness(): +def new_hotness() -> str: return "new hotness" old_hotness = deprecated_alias("old_hotness", new_hotness, "1.23", issue=1) -def test_deprecated_alias(recwarn_always): +def test_deprecated_alias(recwarn_always: pytest.WarningsRecorder) -> None: assert old_hotness() == "new hotness" got = recwarn_always.pop(TrioDeprecationWarning) + assert isinstance(got.message, Warning) assert "test_deprecate.old_hotness is deprecated" in got.message.args[0] assert "1.23" in got.message.args[0] assert "test_deprecate.new_hotness instead" in got.message.args[0] assert "issues/1" in got.message.args[0] + assert isinstance(old_hotness.__doc__, str) assert ".. deprecated:: 1.23" in old_hotness.__doc__ assert "test_deprecate.new_hotness instead" in old_hotness.__doc__ assert "issues/1>`__" in old_hotness.__doc__ class Alias: - def new_hotness_method(self): + def new_hotness_method(self) -> str: return "new hotness method" old_hotness_method = deprecated_alias( @@ -148,36 +165,37 @@ def new_hotness_method(self): ) -def test_deprecated_alias_method(recwarn_always): +def test_deprecated_alias_method(recwarn_always: pytest.WarningsRecorder) -> None: obj = Alias() assert obj.old_hotness_method() == "new hotness method" got = recwarn_always.pop(TrioDeprecationWarning) + assert isinstance(got.message, Warning) msg = got.message.args[0] assert "test_deprecate.Alias.old_hotness_method is deprecated" in msg assert "test_deprecate.Alias.new_hotness_method instead" in msg @deprecated("2.1", issue=1, instead="hi") -def docstring_test1(): # pragma: no cover +def docstring_test1() -> None: # pragma: no cover """Hello!""" @deprecated("2.1", issue=None, instead="hi") -def docstring_test2(): # pragma: no cover +def docstring_test2() -> None: # pragma: no cover """Hello!""" @deprecated("2.1", issue=1, instead=None) -def docstring_test3(): # pragma: no cover +def docstring_test3() -> None: # pragma: no cover """Hello!""" @deprecated("2.1", issue=None, instead=None) -def docstring_test4(): # pragma: no cover +def docstring_test4() -> None: # pragma: no cover """Hello!""" -def test_deprecated_docstring_munging(): +def test_deprecated_docstring_munging() -> None: assert ( docstring_test1.__doc__ == """Hello! @@ -219,13 +237,14 @@ def test_deprecated_docstring_munging(): ) -def test_module_with_deprecations(recwarn_always): +def test_module_with_deprecations(recwarn_always: pytest.WarningsRecorder) -> None: assert module_with_deprecations.regular == "hi" assert len(recwarn_always) == 0 filename, lineno = _here() - assert module_with_deprecations.dep1 == "value1" + assert module_with_deprecations.dep1 == "value1" # type: ignore[attr-defined] got = recwarn_always.pop(TrioDeprecationWarning) + assert isinstance(got.message, Warning) assert got.filename == filename assert got.lineno == lineno + 1 @@ -234,12 +253,13 @@ def test_module_with_deprecations(recwarn_always): assert "/issues/1" in got.message.args[0] assert "value1 instead" in got.message.args[0] - assert module_with_deprecations.dep2 == "value2" + assert module_with_deprecations.dep2 == "value2" # type: ignore[attr-defined] got = recwarn_always.pop(TrioDeprecationWarning) + assert isinstance(got.message, Warning) assert "instead-string instead" in got.message.args[0] with pytest.raises(AttributeError): - module_with_deprecations.asdf + module_with_deprecations.asdf # type: ignore[attr-defined] def test_tests_is_deprecated1() -> None: diff --git a/trio/_tests/test_dtls.py b/trio/_tests/test_dtls.py index 8cb06ccb3d..b7cb1830d1 100644 --- a/trio/_tests/test_dtls.py +++ b/trio/_tests/test_dtls.py @@ -1,6 +1,10 @@ +from __future__ import annotations + import random +from collections.abc import AsyncGenerator from contextlib import asynccontextmanager from itertools import count +from typing import NoReturn import attr import pytest @@ -9,8 +13,8 @@ import trio import trio.testing -from trio import DTLSEndpoint -from trio.testing._fake_net import FakeNet +from trio import DTLSChannel, DTLSEndpoint +from trio.testing._fake_net import FakeNet, UDPPacket from .._core._tests.tutil import binds_ipv6, gc_collect_harder, slow @@ -29,7 +33,7 @@ ) -def endpoint(**kwargs): +def endpoint(**kwargs: int | bool) -> DTLSEndpoint: ipv6 = kwargs.pop("ipv6", False) if ipv6: family = trio.socket.AF_INET6 @@ -40,7 +44,9 @@ def endpoint(**kwargs): @asynccontextmanager -async def dtls_echo_server(*, autocancel=True, mtu=None, ipv6=False): +async def dtls_echo_server( + *, autocancel: bool = True, mtu: int | None = None, ipv6: bool = False +) -> AsyncGenerator[tuple[DTLSEndpoint, tuple[str, int]], None]: with endpoint(ipv6=ipv6) as server: if ipv6: localhost = "::1" @@ -49,11 +55,11 @@ async def dtls_echo_server(*, autocancel=True, mtu=None, ipv6=False): await server.socket.bind((localhost, 0)) async with trio.open_nursery() as nursery: - async def echo_handler(dtls_channel): + async def echo_handler(dtls_channel: DTLSChannel) -> None: print( "echo handler started: " - f"server {dtls_channel.endpoint.socket.getsockname()} " - f"client {dtls_channel.peer_address}" + f"server {dtls_channel.endpoint.socket.getsockname()!r} " + f"client {dtls_channel.peer_address!r}" ) if mtu is not None: dtls_channel.set_ciphertext_mtu(mtu) @@ -62,7 +68,7 @@ async def echo_handler(dtls_channel): await dtls_channel.do_handshake() print("server finished do_handshake") async for packet in dtls_channel: - print(f"echoing {packet} -> {dtls_channel.peer_address}") + print(f"echoing {packet!r} -> {dtls_channel.peer_address!r}") await dtls_channel.send(packet) except trio.BrokenResourceError: # pragma: no cover print("echo handler channel broken") @@ -76,7 +82,7 @@ async def echo_handler(dtls_channel): @parametrize_ipv6 -async def test_smoke(ipv6): +async def test_smoke(ipv6: bool) -> None: async with dtls_echo_server(ipv6=ipv6) as (server_endpoint, address): with endpoint(ipv6=ipv6) as client_endpoint: client_channel = client_endpoint.connect(address, client_ctx) @@ -101,7 +107,9 @@ async def test_smoke(ipv6): @slow -async def test_handshake_over_terrible_network(autojump_clock): +async def test_handshake_over_terrible_network( + autojump_clock: trio.testing.MockClock, +) -> None: HANDSHAKES = 100 r = random.Random(0) fn = FakeNet() @@ -112,7 +120,7 @@ async def test_handshake_over_terrible_network(autojump_clock): async with dtls_echo_server() as (_, address): async with trio.open_nursery() as nursery: - async def route_packet(packet): + async def route_packet(packet: UDPPacket) -> None: while True: op = r.choices( ["deliver", "drop", "dupe", "delay"], @@ -157,7 +165,7 @@ async def route_packet(packet): fn.deliver_packet(packet) break - def route_packet_wrapper(packet): + def route_packet_wrapper(packet: UDPPacket) -> None: try: nursery.start_soon(route_packet, packet) except RuntimeError: # pragma: no cover @@ -165,7 +173,7 @@ def route_packet_wrapper(packet): # dropped pass - fn.route_packet = route_packet_wrapper + fn.route_packet = route_packet_wrapper # type: ignore[assignment] # TODO: Fix FakeNet typing for i in range(HANDSHAKES): print("#" * 80) @@ -187,7 +195,7 @@ def route_packet_wrapper(packet): break -async def test_implicit_handshake(): +async def test_implicit_handshake() -> None: async with dtls_echo_server() as (_, address): with endpoint() as client_endpoint: client = client_endpoint.connect(address, client_ctx) @@ -197,14 +205,14 @@ async def test_implicit_handshake(): assert await client.receive() == b"xyz" -async def test_full_duplex(): +async def test_full_duplex() -> None: # Tests simultaneous send/receive, and also multiple methods implicitly invoking # do_handshake simultaneously. with endpoint() as server_endpoint, endpoint() as client_endpoint: await server_endpoint.socket.bind(("127.0.0.1", 0)) async with trio.open_nursery() as server_nursery: - async def handler(channel): + async def handler(channel: DTLSChannel) -> None: async with trio.open_nursery() as nursery: nursery.start_soon(channel.send, b"from server") nursery.start_soon(channel.receive) @@ -221,7 +229,7 @@ async def handler(channel): server_nursery.cancel_scope.cancel() -async def test_channel_closing(): +async def test_channel_closing() -> None: async with dtls_echo_server() as (_, address): with endpoint() as client_endpoint: client = client_endpoint.connect(address, client_ctx) @@ -239,7 +247,7 @@ async def test_channel_closing(): await client.aclose() -async def test_serve_exits_cleanly_on_close(): +async def test_serve_exits_cleanly_on_close() -> None: async with dtls_echo_server(autocancel=False) as (server_endpoint, address): server_endpoint.close() # Testing that the nursery exits even without being cancelled @@ -247,7 +255,7 @@ async def test_serve_exits_cleanly_on_close(): server_endpoint.close() -async def test_client_multiplex(): +async def test_client_multiplex() -> None: async with dtls_echo_server() as (_, address1), dtls_echo_server() as (_, address2): with endpoint() as client_endpoint: client1 = client_endpoint.connect(address1, client_ctx) @@ -261,7 +269,7 @@ async def test_client_multiplex(): client_endpoint.close() with pytest.raises(trio.ClosedResourceError): - await client1.send("xxx") + await client1.send(b"xxx") with pytest.raises(trio.ClosedResourceError): await client2.receive() with pytest.raises(trio.ClosedResourceError): @@ -270,20 +278,20 @@ async def test_client_multiplex(): async with trio.open_nursery() as nursery: with pytest.raises(trio.ClosedResourceError): - async def null_handler(_): # pragma: no cover + async def null_handler(_: object) -> None: # pragma: no cover pass await nursery.start(client_endpoint.serve, server_ctx, null_handler) -async def test_dtls_over_dgram_only(): +async def test_dtls_over_dgram_only() -> None: with trio.socket.socket() as s: with pytest.raises(ValueError): DTLSEndpoint(s) -async def test_double_serve(): - async def null_handler(_): # pragma: no cover +async def test_double_serve() -> None: + async def null_handler(_: object) -> None: # pragma: no cover pass with endpoint() as server_endpoint: @@ -300,7 +308,7 @@ async def null_handler(_): # pragma: no cover nursery.cancel_scope.cancel() -async def test_connect_to_non_server(autojump_clock): +async def test_connect_to_non_server(autojump_clock: trio.abc.Clock) -> None: fn = FakeNet() fn.enable() with endpoint() as client1, endpoint() as client2: @@ -312,7 +320,7 @@ async def test_connect_to_non_server(autojump_clock): assert cscope.cancelled_caught -async def test_incoming_buffer_overflow(autojump_clock): +async def test_incoming_buffer_overflow(autojump_clock: trio.abc.Clock) -> None: fn = FakeNet() fn.enable() for buffer_size in [10, 20]: @@ -331,7 +339,9 @@ async def test_incoming_buffer_overflow(autojump_clock): assert await client.receive() == b"buffer clear now" -async def test_server_socket_doesnt_crash_on_garbage(autojump_clock): +async def test_server_socket_doesnt_crash_on_garbage( + autojump_clock: trio.abc.Clock, +) -> None: fn = FakeNet() fn.enable() @@ -443,7 +453,7 @@ async def test_server_socket_doesnt_crash_on_garbage(autojump_clock): await trio.sleep(1) -async def test_invalid_cookie_rejected(autojump_clock): +async def test_invalid_cookie_rejected(autojump_clock: trio.abc.Clock) -> None: fn = FakeNet() fn.enable() @@ -454,7 +464,7 @@ async def test_invalid_cookie_rejected(autojump_clock): # corrupting bytes after that. offset_to_corrupt = count(11) - def route_packet(packet): + def route_packet(packet: UDPPacket) -> None: try: _, cookie, _ = decode_client_hello_untrusted(packet.payload) except BadPacket: @@ -475,17 +485,19 @@ def route_packet(packet): fn.deliver_packet(packet) - fn.route_packet = route_packet + fn.route_packet = route_packet # type: ignore[assignment] # TODO: Fix FakeNet typing async with dtls_echo_server() as (_, address): while True: with endpoint() as client: channel = client.connect(address, client_ctx) await channel.do_handshake() - assert cscope.cancelled_caught + assert cscope.cancelled_caught -async def test_client_cancels_handshake_and_starts_new_one(autojump_clock): +async def test_client_cancels_handshake_and_starts_new_one( + autojump_clock: trio.abc.Clock, +) -> None: # if a client disappears during the handshake, and then starts a new handshake from # scratch, then the first handler's channel should fail, and a new handler get # started @@ -497,7 +509,7 @@ async def test_client_cancels_handshake_and_starts_new_one(autojump_clock): async with trio.open_nursery() as nursery: first_time = True - async def handler(channel): + async def handler(channel: DTLSChannel) -> None: nonlocal first_time if first_time: first_time = False @@ -528,16 +540,16 @@ async def handler(channel): nursery.cancel_scope.cancel() -async def test_swap_client_server(): +async def test_swap_client_server() -> None: with endpoint() as a, endpoint() as b: await a.socket.bind(("127.0.0.1", 0)) await b.socket.bind(("127.0.0.1", 0)) - async def echo_handler(channel): + async def echo_handler(channel: DTLSChannel) -> None: async for packet in channel: await channel.send(packet) - async def crashing_echo_handler(channel): + async def crashing_echo_handler(channel: DTLSChannel) -> None: with pytest.raises(trio.BrokenResourceError): await echo_handler(channel) @@ -560,7 +572,7 @@ async def crashing_echo_handler(channel): @slow -async def test_openssl_retransmit_doesnt_break_stuff(): +async def test_openssl_retransmit_doesnt_break_stuff() -> None: # can't use autojump_clock here, because the point of the test is to wait for # openssl's built-in retransmit timer to expire, which is hard-coded to use # wall-clock time. @@ -569,7 +581,7 @@ async def test_openssl_retransmit_doesnt_break_stuff(): blackholed = True - def route_packet(packet): + def route_packet(packet: UDPPacket) -> None: if blackholed: print("dropped packet", packet) return @@ -583,13 +595,13 @@ def route_packet(packet): # ) fn.deliver_packet(packet) - fn.route_packet = route_packet + fn.route_packet = route_packet # type: ignore[assignment] # TODO add type annotations for FakeNet async with dtls_echo_server() as (server_endpoint, address): with endpoint() as client_endpoint: async with trio.open_nursery() as nursery: - async def connecter(): + async def connecter() -> None: client = client_endpoint.connect(address, client_ctx) await client.do_handshake(initial_retransmit_timeout=1.5) await client.send(b"hi") @@ -611,20 +623,22 @@ async def connecter(): # scapy.all.wrpcap("/tmp/trace.pcap", packets) -async def test_initial_retransmit_timeout_configuration(autojump_clock): +async def test_initial_retransmit_timeout_configuration( + autojump_clock: trio.abc.Clock, +) -> None: fn = FakeNet() fn.enable() blackholed = True - def route_packet(packet): + def route_packet(packet: UDPPacket) -> None: nonlocal blackholed if blackholed: blackholed = False else: fn.deliver_packet(packet) - fn.route_packet = route_packet + fn.route_packet = route_packet # type: ignore[assignment] # TODO add type annotations for FakeNet async with dtls_echo_server() as (_, address): for t in [1, 2, 4]: @@ -637,7 +651,7 @@ def route_packet(packet): assert after - before == t -async def test_explicit_tiny_mtu_is_respected(): +async def test_explicit_tiny_mtu_is_respected() -> None: # ClientHello is ~240 bytes, and it can't be fragmented, so our mtu has to # be larger than that. (300 is still smaller than any real network though.) MTU = 300 @@ -645,13 +659,13 @@ async def test_explicit_tiny_mtu_is_respected(): fn = FakeNet() fn.enable() - def route_packet(packet): + def route_packet(packet: UDPPacket) -> None: print(f"delivering {packet}") print(f"payload size: {len(packet.payload)}") assert len(packet.payload) <= MTU fn.deliver_packet(packet) - fn.route_packet = route_packet + fn.route_packet = route_packet # type: ignore[assignment] # TODO add type annotations for FakeNet async with dtls_echo_server(mtu=MTU) as (server, address): with endpoint() as client: @@ -663,7 +677,9 @@ def route_packet(packet): @parametrize_ipv6 -async def test_handshake_handles_minimum_network_mtu(ipv6, autojump_clock): +async def test_handshake_handles_minimum_network_mtu( + ipv6: bool, autojump_clock: trio.abc.Clock +) -> None: # Fake network that has the minimum allowable MTU for whatever protocol we're using. fn = FakeNet() fn.enable() @@ -673,14 +689,14 @@ async def test_handshake_handles_minimum_network_mtu(ipv6, autojump_clock): else: mtu = 576 - 28 - def route_packet(packet): + def route_packet(packet: UDPPacket) -> None: if len(packet.payload) > mtu: print(f"dropping {packet}") else: print(f"delivering {packet}") fn.deliver_packet(packet) - fn.route_packet = route_packet + fn.route_packet = route_packet # type: ignore[assignment] # TODO add type annotations for FakeNet # See if we can successfully do a handshake -- some of the volleys will get dropped, # and the retransmit logic should detect this and back off the MTU to something @@ -698,14 +714,14 @@ def route_packet(packet): @pytest.mark.filterwarnings("always:unclosed DTLS:ResourceWarning") -async def test_system_task_cleaned_up_on_gc(): +async def test_system_task_cleaned_up_on_gc() -> None: before_tasks = trio.lowlevel.current_statistics().tasks_living # We put this into a sub-function so that everything automatically becomes garbage # when the frame exits. For some reason just doing 'del e' wasn't enough on pypy # with coverage enabled -- I think we were hitting this bug: # https://foss.heptapod.net/pypy/pypy/-/issues/3656 - async def start_and_forget_endpoint(): + async def start_and_forget_endpoint() -> int: e = endpoint() # This connection/handshake attempt can't succeed. The only purpose is to force @@ -734,7 +750,7 @@ async def start_and_forget_endpoint(): @pytest.mark.filterwarnings("always:unclosed DTLS:ResourceWarning") -async def test_gc_before_system_task_starts(): +async def test_gc_before_system_task_starts() -> None: e = endpoint() with pytest.warns(ResourceWarning): @@ -745,7 +761,7 @@ async def test_gc_before_system_task_starts(): @pytest.mark.filterwarnings("always:unclosed DTLS:ResourceWarning") -async def test_gc_as_packet_received(): +async def test_gc_as_packet_received() -> None: fn = FakeNet() fn.enable() @@ -766,8 +782,8 @@ async def test_gc_as_packet_received(): @pytest.mark.filterwarnings("always:unclosed DTLS:ResourceWarning") -def test_gc_after_trio_exits(): - async def main(): +def test_gc_after_trio_exits() -> None: + async def main() -> DTLSEndpoint: # We use fakenet just to make sure no real sockets can leak out of the test # case - on pypy somehow the socket was outliving the gc_collect_harder call # below. Since the test is just making sure DTLSEndpoint.__del__ doesn't explode @@ -782,7 +798,7 @@ async def main(): gc_collect_harder() -async def test_already_closed_socket_doesnt_crash(): +async def test_already_closed_socket_doesnt_crash() -> None: with endpoint() as e: # We close the socket before checkpointing, so the socket will already be closed # when the system task starts up @@ -791,7 +807,9 @@ async def test_already_closed_socket_doesnt_crash(): await trio.testing.wait_all_tasks_blocked() -async def test_socket_closed_while_processing_clienthello(autojump_clock): +async def test_socket_closed_while_processing_clienthello( + autojump_clock: trio.abc.Clock, +) -> None: fn = FakeNet() fn.enable() @@ -799,11 +817,11 @@ async def test_socket_closed_while_processing_clienthello(autojump_clock): # HelloVerifyRequest, since that has its own sending logic async with dtls_echo_server() as (server, address): - def route_packet(packet): + def route_packet(packet: UDPPacket) -> None: fn.deliver_packet(packet) server.socket.close() - fn.route_packet = route_packet + fn.route_packet = route_packet # type: ignore[assignment] # TODO add type annotations for FakeNet with endpoint() as client_endpoint: with trio.move_on_after(10): @@ -811,21 +829,23 @@ def route_packet(packet): await client.do_handshake() -async def test_association_replaced_while_handshake_running(autojump_clock): +async def test_association_replaced_while_handshake_running( + autojump_clock: trio.abc.Clock, +) -> None: fn = FakeNet() fn.enable() - def route_packet(packet): + def route_packet(packet: UDPPacket) -> None: pass - fn.route_packet = route_packet + fn.route_packet = route_packet # type: ignore[assignment] # TODO add type annotations for FakeNet async with dtls_echo_server() as (_, address): with endpoint() as client_endpoint: c1 = client_endpoint.connect(address, client_ctx) async with trio.open_nursery() as nursery: - async def doomed_handshake(): + async def doomed_handshake() -> None: with pytest.raises(trio.BrokenResourceError): await c1.do_handshake() @@ -836,15 +856,15 @@ async def doomed_handshake(): client_endpoint.connect(address, client_ctx) -async def test_association_replaced_before_handshake_starts(): +async def test_association_replaced_before_handshake_starts() -> None: fn = FakeNet() fn.enable() # This test shouldn't send any packets - def route_packet(packet): # pragma: no cover + def route_packet(packet: UDPPacket) -> NoReturn: # pragma: no cover assert False - fn.route_packet = route_packet + fn.route_packet = route_packet # type: ignore[assignment] # TODO add type annotations for FakeNet async with dtls_echo_server() as (_, address): with endpoint() as client_endpoint: @@ -854,7 +874,7 @@ def route_packet(packet): # pragma: no cover await c1.do_handshake() -async def test_send_to_closed_local_port(): +async def test_send_to_closed_local_port() -> None: # On Windows, sending a UDP packet to a closed local port can cause a weird # ECONNRESET error later, inside the receive task. Make sure we're handling it # properly.