From 3adda46cf0ef00a071af9b7e6b3539876a1b5fd8 Mon Sep 17 00:00:00 2001 From: Tai Sakuma Date: Mon, 16 Sep 2024 11:23:28 -0400 Subject: [PATCH 1/2] Add type hints to all defs --- nextline/events.py | 26 ++++++------ nextline/main.py | 3 +- nextline/utils/multiprocessing_logging.py | 5 ++- tests/conftest.py | 5 ++- tests/fsm/test_config.py | 2 +- tests/main/scenarios/test_example.py | 4 +- tests/main/scenarios/test_interruption.py | 2 +- tests/main/scenarios/test_kill.py | 2 +- tests/main/scenarios/test_raise.py | 2 +- .../scenarios/test_raise_dynamic_class.py | 2 +- tests/main/scenarios/test_simple.py | 2 +- tests/main/scenarios/test_terminate.py | 2 +- tests/main/test_nextline.py | 8 ++-- tests/main/test_prompts.py | 2 +- tests/main/test_register.py | 2 +- tests/script.py | 2 +- .../test_filter_by_module_name_profile.py | 8 ++-- tests/spawned/run/example/lib.py | 2 +- tests/spawned/run/test_run.py | 35 +++++++++------- tests/spawned/test_call.py | 6 +-- tests/spawned/test_io.py | 4 +- tests/spawned/trace/conftest.py | 13 +++--- tests/spawned/trace/module_a.py | 2 +- tests/spawned/trace/module_b.py | 2 +- tests/spawned/trace/test_filter.py | 14 +++---- tests/spawned/trace/test_fixture.py | 8 ++-- tests/test_version.py | 2 +- tests/utils/done_callback/test_task.py | 42 +++++++++---------- tests/utils/done_callback/test_thread.py | 36 ++++++++-------- tests/utils/done_callback/test_union.py | 26 ++++++------ tests/utils/pubsub/test_broker.py | 26 ++++++------ tests/utils/run/test_signal.py | 2 +- tests/utils/test_agen_with_wait.py | 20 +++++---- tests/utils/test_multiprocessing_logging.py | 4 +- tests/utils/test_peek.py | 6 +-- tests/utils/test_profile.py | 4 +- tests/utils/test_thread_exception.py | 6 ++- tests/utils/test_thread_task_id.py | 28 ++++++------- 38 files changed, 190 insertions(+), 177 deletions(-) diff --git a/nextline/events.py b/nextline/events.py index e6b65df7..dc9f5b6f 100644 --- a/nextline/events.py +++ b/nextline/events.py @@ -25,7 +25,7 @@ class OnStartRun(Event): run_no: RunNo statement: Statement - def __post_init__(self): + def __post_init__(self) -> None: _assert_aware_datetime(self.started_at) @@ -36,7 +36,7 @@ class OnEndRun(Event): returned: str raised: str - def __post_init__(self): + def __post_init__(self) -> None: _assert_aware_datetime(self.ended_at) @@ -48,7 +48,7 @@ class OnStartTrace(Event): thread_no: ThreadNo task_no: Optional[TaskNo] - def __post_init__(self): + def __post_init__(self) -> None: _assert_naive_datetime(self.started_at) @@ -58,7 +58,7 @@ class OnEndTrace(Event): run_no: RunNo trace_no: TraceNo - def __post_init__(self): + def __post_init__(self) -> None: _assert_naive_datetime(self.ended_at) @@ -73,7 +73,7 @@ class OnStartTraceCall(Event): frame_object_id: int event: str - def __post_init__(self): + def __post_init__(self) -> None: _assert_naive_datetime(self.started_at) @@ -84,7 +84,7 @@ class OnEndTraceCall(Event): trace_no: TraceNo trace_call_no: TraceCallNo - def __post_init__(self): + def __post_init__(self) -> None: _assert_naive_datetime(self.ended_at) @@ -95,7 +95,7 @@ class OnStartCmdloop(Event): trace_no: TraceNo trace_call_no: TraceCallNo - def __post_init__(self): + def __post_init__(self) -> None: _assert_naive_datetime(self.started_at) @@ -106,7 +106,7 @@ class OnEndCmdloop(Event): trace_no: TraceNo trace_call_no: TraceCallNo - def __post_init__(self): + def __post_init__(self) -> None: _assert_naive_datetime(self.ended_at) @@ -123,7 +123,7 @@ class OnStartPrompt(Event): frame_object_id: int event: str - def __post_init__(self): + def __post_init__(self) -> None: _assert_naive_datetime(self.started_at) @@ -136,7 +136,7 @@ class OnEndPrompt(Event): prompt_no: PromptNo command: str - def __post_init__(self): + def __post_init__(self) -> None: _assert_naive_datetime(self.ended_at) @@ -147,15 +147,15 @@ class OnWriteStdout(Event): trace_no: TraceNo text: str - def __post_init__(self): + def __post_init__(self) -> None: _assert_naive_datetime(self.written_at) -def _assert_naive_datetime(dt: datetime.datetime): +def _assert_naive_datetime(dt: datetime.datetime) -> None: if is_timezone_aware(dt): raise ValueError(f'Not a timezone-naive object: {dt!r}') -def _assert_aware_datetime(dt: datetime.datetime): +def _assert_aware_datetime(dt: datetime.datetime) -> None: if not is_timezone_aware(dt): raise ValueError(f'Not a timezone-aware object: {dt!r}') diff --git a/nextline/main.py b/nextline/main.py index 857c7b4f..d5a34e15 100644 --- a/nextline/main.py +++ b/nextline/main.py @@ -101,8 +101,7 @@ async def __aenter__(self) -> 'Nextline': await self.start() return self - async def __aexit__(self, exc_type, exc_value, traceback) -> None: - del exc_type, exc_value, traceback + async def __aexit__(self, *_: Any, **__: Any) -> None: await asyncio.wait_for(self.close(), timeout=self._timeout_on_exit) async def run(self) -> None: diff --git a/nextline/utils/multiprocessing_logging.py b/nextline/utils/multiprocessing_logging.py index 9c318efb..0d7c470c 100644 --- a/nextline/utils/multiprocessing_logging.py +++ b/nextline/utils/multiprocessing_logging.py @@ -1,4 +1,5 @@ import asyncio +from collections.abc import AsyncIterator, Callable import contextlib import logging import multiprocessing as mp @@ -19,7 +20,9 @@ def example_func() -> None: @contextlib.asynccontextmanager -async def MultiprocessingLogging(mp_context: Optional[BaseContext] = None): +async def MultiprocessingLogging( + mp_context: Optional[BaseContext] = None, +) -> AsyncIterator[Callable[[], None]]: '''Collect logging from other processes in the main process. Example: diff --git a/tests/conftest.py b/tests/conftest.py index 12f660a0..5bae1b78 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,13 +1,14 @@ import sys import threading +from collections.abc import Iterator import pytest @pytest.fixture(autouse=True) -def recover_trace(): +def recover_trace() -> Iterator[None]: """Set the original trace function back after each test""" trace_org = sys.gettrace() yield sys.settrace(trace_org) - threading.settrace(trace_org) + threading.settrace(trace_org) # type: ignore diff --git a/tests/fsm/test_config.py b/tests/fsm/test_config.py index 00770830..71b74af9 100644 --- a/tests/fsm/test_config.py +++ b/tests/fsm/test_config.py @@ -43,7 +43,7 @@ def test_restore_from_markup() -> None: @pytest.mark.skip -def test_graph(tmp_path: Path): +def test_graph(tmp_path: Path) -> None: FILE_NAME = 'states.png' path = tmp_path / FILE_NAME # print(f'Saving the state diagram to {path}...') diff --git a/tests/main/scenarios/test_example.py b/tests/main/scenarios/test_example.py index d83a3359..81f93a32 100644 --- a/tests/main/scenarios/test_example.py +++ b/tests/main/scenarios/test_example.py @@ -10,7 +10,7 @@ from .funcs import extract_comment -async def test_run(statement: str): +async def test_run(statement: str) -> None: nextline = Nextline(statement, trace_threads=True, trace_modules=True) assert nextline.state == 'created' plugin = Plugin() @@ -131,7 +131,7 @@ def find_command(line: str) -> Optional[str]: @pytest.fixture -def statement(script_dir, monkey_patch_syspath) -> str: +def statement(script_dir : str, monkey_patch_syspath: None) -> str: del monkey_patch_syspath return (Path(script_dir) / 'script.py').read_text() diff --git a/tests/main/scenarios/test_interruption.py b/tests/main/scenarios/test_interruption.py index ed0fd588..4cff57c0 100644 --- a/tests/main/scenarios/test_interruption.py +++ b/tests/main/scenarios/test_interruption.py @@ -47,7 +47,7 @@ async def on_finished(self, context: Context) -> None: assert 'KeyboardInterrupt' in fmt_exc -async def run(nextline: Nextline): +async def run(nextline: Nextline) -> None: async with nextline.run_session(): pass fmt_exc = nextline.format_exception() diff --git a/tests/main/scenarios/test_kill.py b/tests/main/scenarios/test_kill.py index 8227c8f6..016c7644 100644 --- a/tests/main/scenarios/test_kill.py +++ b/tests/main/scenarios/test_kill.py @@ -46,7 +46,7 @@ async def on_finished(self, context: Context) -> None: assert exited_process.process.exitcode == -signal.SIGKILL -async def run(nextline: Nextline): +async def run(nextline: Nextline) -> None: async with nextline.run_session(): pass assert not nextline.format_exception() diff --git a/tests/main/scenarios/test_raise.py b/tests/main/scenarios/test_raise.py index 3aef04fa..688aea21 100644 --- a/tests/main/scenarios/test_raise.py +++ b/tests/main/scenarios/test_raise.py @@ -42,7 +42,7 @@ async def on_finished(self, context: Context) -> None: assert 'RuntimeError' in fmt_exc -async def run(nextline: Nextline): +async def run(nextline: Nextline) -> None: async with nextline.run_session(): pass fmt_exc = nextline.format_exception() diff --git a/tests/main/scenarios/test_raise_dynamic_class.py b/tests/main/scenarios/test_raise_dynamic_class.py index 804f025a..13dfbd13 100644 --- a/tests/main/scenarios/test_raise_dynamic_class.py +++ b/tests/main/scenarios/test_raise_dynamic_class.py @@ -45,7 +45,7 @@ async def on_finished(self, context: Context) -> None: assert 'MyError' in fmt_exc -async def run(nextline: Nextline): +async def run(nextline: Nextline) -> None: async with nextline.run_session(): pass fmt_exc = nextline.format_exception() diff --git a/tests/main/scenarios/test_simple.py b/tests/main/scenarios/test_simple.py index 7fe466d9..efc3dfc0 100644 --- a/tests/main/scenarios/test_simple.py +++ b/tests/main/scenarios/test_simple.py @@ -38,7 +38,7 @@ async def on_finished(self, context: Context) -> None: assert exited_process.process.exitcode == 0 -async def run(nextline: Nextline): +async def run(nextline: Nextline) -> None: async with nextline.run_session(): pass assert not nextline.format_exception() diff --git a/tests/main/scenarios/test_terminate.py b/tests/main/scenarios/test_terminate.py index 78ef7916..fffc24c7 100644 --- a/tests/main/scenarios/test_terminate.py +++ b/tests/main/scenarios/test_terminate.py @@ -46,7 +46,7 @@ async def on_finished(self, context: Context) -> None: assert exited_process.process.exitcode == -signal.SIGTERM -async def run(nextline: Nextline): +async def run(nextline: Nextline) -> None: async with nextline.run_session(): pass assert not nextline.format_exception() diff --git a/tests/main/test_nextline.py b/tests/main/test_nextline.py index 16a0a1a4..be64a754 100644 --- a/tests/main/test_nextline.py +++ b/tests/main/test_nextline.py @@ -15,7 +15,7 @@ """.strip() -def test_init_sync(): +def test_init_sync() -> None: '''Assert the init without the running loop.''' with pytest.raises(RuntimeError): asyncio.get_running_loop() @@ -23,7 +23,7 @@ def test_init_sync(): assert nextline -async def test_repr(): +async def test_repr() -> None: nextline = Nextline(SOURCE) assert repr(nextline) async with nextline: @@ -48,8 +48,8 @@ async def test_one() -> None: ) -async def test_timeout(imp: Mock): - async def close(): +async def test_timeout(imp: Mock) -> None: + async def close() -> None: await asyncio.sleep(5) imp.aclose.side_effect = close diff --git a/tests/main/test_prompts.py b/tests/main/test_prompts.py index 61c899bf..55b35716 100644 --- a/tests/main/test_prompts.py +++ b/tests/main/test_prompts.py @@ -3,7 +3,7 @@ from nextline import Nextline -def func(): +def func() -> None: time.sleep(0.001) diff --git a/tests/main/test_register.py b/tests/main/test_register.py index d683d8c2..a4a72453 100644 --- a/tests/main/test_register.py +++ b/tests/main/test_register.py @@ -5,7 +5,7 @@ from nextline.plugin.spec import Context, hookimpl -def func(): +def func() -> None: time.sleep(0.001) diff --git a/tests/script.py b/tests/script.py index 3ac27cec..e7135f47 100644 --- a/tests/script.py +++ b/tests/script.py @@ -1,3 +1,3 @@ -def script(): +def script() -> None: for i in range(3): print("in script() i={}".format(i)) diff --git a/tests/spawned/callback/plugins/test_filter_by_module_name_profile.py b/tests/spawned/callback/plugins/test_filter_by_module_name_profile.py index 3d23b705..6feb5f23 100644 --- a/tests/spawned/callback/plugins/test_filter_by_module_name_profile.py +++ b/tests/spawned/callback/plugins/test_filter_by_module_name_profile.py @@ -9,7 +9,7 @@ from nextline.utils import profile_func -def test_timeit(plugin: FilterByModuleName): +def test_timeit(plugin: FilterByModuleName) -> None: n_calls = 200_000 thread_id = threading.current_thread().ident @@ -24,7 +24,7 @@ def test_timeit(plugin: FilterByModuleName): assert sec < 1 -def test_profile(plugin: FilterByModuleName): +def test_profile(plugin: FilterByModuleName) -> None: '''Used to print the profile.''' n_calls = 20_000 @@ -34,7 +34,7 @@ def test_profile(plugin: FilterByModuleName): frame = sys._current_frames()[thread_id] - def func(): + def func() -> None: for _ in range(n_calls): plugin.filter((frame, "line", None)) @@ -44,7 +44,7 @@ def func(): @pytest.fixture() -def plugin(): +def plugin() -> FilterByModuleName: p = FilterByModuleName() p.init(modules_to_skip=MODULES_TO_SKIP) return p diff --git a/tests/spawned/run/example/lib.py b/tests/spawned/run/example/lib.py index 4e133d0c..21fa973d 100644 --- a/tests/spawned/run/example/lib.py +++ b/tests/spawned/run/example/lib.py @@ -1,3 +1,3 @@ -def g(): +def g() -> None: print('here!') pass diff --git a/tests/spawned/run/test_run.py b/tests/spawned/run/test_run.py index 47b3c5b2..9ff6eb53 100644 --- a/tests/spawned/run/test_run.py +++ b/tests/spawned/run/test_run.py @@ -1,6 +1,8 @@ import queue +from collections.abc import Iterator from concurrent.futures import ThreadPoolExecutor from pathlib import Path +from typing import Any, TypeAlias import pytest @@ -8,14 +10,16 @@ from nextline.spawned import PdbCommand, QueueIn, QueueOut, RunArg, main, set_queues from nextline.types import RunNo +RunArgParams: TypeAlias = tuple[RunArg, Any, str | None] + def test_one( run_arg: RunArg, - expected_exc, - expected_ret, - call_set_queues, - task_send_commands, -): + expected_exc : str | None, + expected_ret : Any, + call_set_queues : None, + task_send_commands: None, +) -> None: del call_set_queues, task_send_commands result = main(run_arg) assert result.ret == expected_ret @@ -27,13 +31,12 @@ def test_one( @pytest.fixture -def call_set_queues(queue_in: QueueIn, queue_out: QueueOut): +def call_set_queues(queue_in: QueueIn, queue_out: QueueOut) -> None: set_queues(queue_in, queue_out) - yield @pytest.fixture -def task_send_commands(queue_in: QueueIn, queue_out: QueueOut): +def task_send_commands(queue_in: QueueIn, queue_out: QueueOut) -> Iterator[None]: with ThreadPoolExecutor(max_workers=1) as executor: fut = executor.submit(respond_prompt, queue_in, queue_out) yield @@ -41,7 +44,7 @@ def task_send_commands(queue_in: QueueIn, queue_out: QueueOut): fut.result() -def respond_prompt(queue_in: QueueIn, queue_out: QueueOut): +def respond_prompt(queue_in: QueueIn, queue_out: QueueOut) -> None: while (event := queue_out.get()) is not None: if not isinstance(event, OnStartPrompt): continue @@ -52,17 +55,17 @@ def respond_prompt(queue_in: QueueIn, queue_out: QueueOut): @pytest.fixture -def run_arg(run_arg_params) -> RunArg: +def run_arg(run_arg_params: RunArgParams) -> RunArg: return run_arg_params[0] @pytest.fixture -def expected_ret(run_arg_params): +def expected_ret(run_arg_params: RunArgParams) -> Any: return run_arg_params[1] @pytest.fixture -def expected_exc(run_arg_params): +def expected_exc(run_arg_params: RunArgParams) -> str | None: return run_arg_params[2] @@ -90,11 +93,11 @@ class MyException(Exception): CODE_OBJECT = compile(SRC_ONE, '', 'exec') -def func_one(): +def func_one() -> int: return 123 -def func_err(): +def func_err() -> None: 1 / 0 @@ -105,7 +108,7 @@ def func_err(): ERR_PATH = SCRIPT_DIR / 'err.py' assert ERR_PATH.is_file() -params = [ +params: list[RunArgParams] = [ (RunArg(run_no=RunNo(1), statement=SRC_ONE, filename=''), None, None), ( RunArg(run_no=RunNo(1), statement=SRC_COMPILE_ERROR, filename=''), @@ -131,7 +134,7 @@ def func_err(): @pytest.fixture(params=params) -def run_arg_params(request) -> tuple[RunArg, type[BaseException] | None]: +def run_arg_params(request: pytest.FixtureRequest) -> RunArgParams: return request.param diff --git a/tests/spawned/test_call.py b/tests/spawned/test_call.py index 9b9a300b..5e4ba179 100644 --- a/tests/spawned/test_call.py +++ b/tests/spawned/test_call.py @@ -16,7 +16,7 @@ def trace() -> Mock: return f -def test_simple(trace: Mock): +def test_simple(trace: Mock) -> None: def func() -> int: x = 123 return x @@ -33,7 +33,7 @@ class MockError(Exception): pass -def test_raise(trace: Mock): +def test_raise(trace: Mock) -> None: def func() -> NoReturn: raise MockError() @@ -66,7 +66,7 @@ def func() -> NoReturn: @pytest.mark.parametrize("thread", [True, False]) -def test_threading(trace: Mock, thread: bool): +def test_threading(trace: Mock, thread: bool) -> None: def f1() -> None: return diff --git a/tests/spawned/test_io.py b/tests/spawned/test_io.py index 5c6d151b..c5a7e307 100644 --- a/tests/spawned/test_io.py +++ b/tests/spawned/test_io.py @@ -12,7 +12,7 @@ from nextline.utils import current_task_or_thread -def print_lines(lines: Iterator[str], file: Optional[TextIO] = None): +def print_lines(lines: Iterator[str], file: Optional[TextIO] = None) -> None: for line in lines: time.sleep(0.0001) print(line, file=file) @@ -20,7 +20,7 @@ def print_lines(lines: Iterator[str], file: Optional[TextIO] = None): @given(st.data()) @settings(suppress_health_check=[HealthCheck.function_scoped_fixture]) -def test_one(capsys: pytest.CaptureFixture, data: st.DataObject): +def test_one(capsys: pytest.CaptureFixture, data: st.DataObject) -> None: capsys.readouterr() # clear # exclude line breaks diff --git a/tests/spawned/trace/conftest.py b/tests/spawned/trace/conftest.py index 4962960b..0df3b5e4 100644 --- a/tests/spawned/trace/conftest.py +++ b/tests/spawned/trace/conftest.py @@ -1,4 +1,5 @@ -from typing import Any, Callable +from collections.abc import Callable +from typing import Any from unittest.mock import Mock import pytest @@ -13,7 +14,7 @@ def target( wrap_target_trace_func: Mock, modules_in_summary: set[str] | None, - run_target, + run_target: None, ) -> TraceSummary: '''Summary of the calls to the target trace function. @@ -29,7 +30,7 @@ def target( def probe( probe_trace_func: Mock, modules_in_summary: set[str] | None, - run_target, + run_target: None, ) -> TraceSummary: '''Summary of the calls to the probe trace function. @@ -45,7 +46,7 @@ def probe( def ref( ref_trace_func: Mock, modules_in_summary: set[str] | None, - run_ref, + run_ref: None, ) -> TraceSummary: '''Summary of the calls to the reference trace function. @@ -97,7 +98,7 @@ def wrap_target_trace_func(target_trace_func: TraceFunction) -> Mock: '''A mock object wrapping the trace function under test to collect trace calls.''' wrap = Mock(wraps=target_trace_func) - def side_effect(*a, **k): + def side_effect(*a: Any, **k: Any) -> TraceFunction | None: # Wrap again if the target returns itself. local_trace_func = target_trace_func(*a, **k) if local_trace_func is target_trace_func: @@ -109,7 +110,7 @@ def side_effect(*a, **k): @pytest.fixture() -def target_trace_func(probe_trace_func: Mock): +def target_trace_func(probe_trace_func: Mock) -> TraceFunction | None: '''The trace function under test. This fixture is to be overridden by the test.''' del probe_trace_func raise RuntimeError('This fixture must be overridden by the test') diff --git a/tests/spawned/trace/module_a.py b/tests/spawned/trace/module_a.py index 0bc6954a..f5d8afbf 100644 --- a/tests/spawned/trace/module_a.py +++ b/tests/spawned/trace/module_a.py @@ -1,6 +1,6 @@ from . import module_b -def func_a(): +def func_a() -> None: module_b.func_b() return diff --git a/tests/spawned/trace/module_b.py b/tests/spawned/trace/module_b.py index 3631bc83..c0b840e5 100644 --- a/tests/spawned/trace/module_b.py +++ b/tests/spawned/trace/module_b.py @@ -1,2 +1,2 @@ -def func_b(): +def func_b() -> None: pass diff --git a/tests/spawned/trace/test_filter.py b/tests/spawned/trace/test_filter.py index 7ea488d6..151d02c4 100644 --- a/tests/spawned/trace/test_filter.py +++ b/tests/spawned/trace/test_filter.py @@ -15,7 +15,7 @@ def Filter( ) -> TraceFunction: '''Skip if the filter returns False.''' - def _trace(frame: FrameType, event, arg) -> Optional[TraceFunction]: + def _trace(frame: FrameType, event: str, arg: Any) -> Optional[TraceFunction]: if filter(frame, event, arg): return trace(frame, event, arg) return None @@ -26,7 +26,7 @@ def _trace(frame: FrameType, event, arg) -> Optional[TraceFunction]: def FilterLambda(trace: TraceFunction) -> TraceFunction: '''An example filter''' - def filter(frame: FrameType, event, arg) -> bool: + def filter(frame: FrameType, event: str, arg: Any) -> bool: del event, arg func_name = frame.f_code.co_name return not func_name == '' @@ -34,7 +34,7 @@ def filter(frame: FrameType, event, arg) -> bool: return Filter(trace=trace, filter=filter) -def test_one(target: TraceSummary, probe: TraceSummary, ref: TraceSummary): +def test_one(target: TraceSummary, probe: TraceSummary, ref: TraceSummary) -> None: assert ref.call.func assert ref.return_.func assert ref.call.func == target.call.func @@ -43,20 +43,20 @@ def test_one(target: TraceSummary, probe: TraceSummary, ref: TraceSummary): assert set(ref.return_.func) - {""} == set(probe.return_.func) -def f(): +def f() -> None: module_a.func_a() -def g(): +def g() -> None: (lambda: module_a.func_a())() @pytest.fixture(params=[f, g, lambda: module_a.func_a()]) -def func(request): +def func(request: pytest.FixtureRequest) -> Callable[[], Any]: return request.param @pytest.fixture() -def target_trace_func(probe_trace_func: Mock): +def target_trace_func(probe_trace_func: Mock) -> TraceFunction: y = FilterLambda(trace=probe_trace_func) return y diff --git a/tests/spawned/trace/test_fixture.py b/tests/spawned/trace/test_fixture.py index 5a88a0b6..856cf25f 100644 --- a/tests/spawned/trace/test_fixture.py +++ b/tests/spawned/trace/test_fixture.py @@ -8,7 +8,7 @@ def test_wrap_target_trace_func( target_trace_func: Mock | TraceFunction, wrap_target_trace_func: Mock | TraceFunction, -): +) -> None: arg = (Mock(), "", None) if target_trace_func is target_trace_func(*arg): assert wrap_target_trace_func is wrap_target_trace_func(*arg) @@ -17,7 +17,9 @@ def test_wrap_target_trace_func( @pytest.fixture(params=["self", "another", "none"]) -def target_trace_func(request, another_trace_func): +def target_trace_func( + request: pytest.FixtureRequest, another_trace_func: TraceFunction +) -> TraceFunction: y = Mock() map = {"self": y, "another": another_trace_func, "none": None} y.return_value = map[request.param] @@ -25,7 +27,7 @@ def target_trace_func(request, another_trace_func): @pytest.fixture() -def another_trace_func(): +def another_trace_func() -> TraceFunction: y = Mock() y.return_value = y return y diff --git a/tests/test_version.py b/tests/test_version.py index 894f60b7..bfeb525a 100644 --- a/tests/test_version.py +++ b/tests/test_version.py @@ -1,6 +1,6 @@ import nextline -def test_version(): +def test_version() -> None: """test if the version string is attached to the module""" nextline.__version__ diff --git a/tests/utils/done_callback/test_task.py b/tests/utils/done_callback/test_task.py index 55450efe..add94e5d 100644 --- a/tests/utils/done_callback/test_task.py +++ b/tests/utils/done_callback/test_task.py @@ -9,7 +9,7 @@ from nextline.utils import TaskDoneCallback -async def target(obj: TaskDoneCallback): +async def target(obj: TaskDoneCallback) -> None: assert asyncio.current_task() == obj.register() delay = random.random() * 0.01 await asyncio.sleep(delay) @@ -18,20 +18,20 @@ async def target(obj: TaskDoneCallback): class Done: """A callback function""" - def __init__(self): - self.args = set() + def __init__(self) -> None: + self.args = set[asyncio.Task]() - def __call__(self, arg): + def __call__(self, arg: asyncio.Task) -> None: self.args.add(arg) @pytest.fixture() -def done(): +def done() -> Done: """A callback function""" - yield Done() + return Done() -async def test_aclose(done: Done): +async def test_aclose(done: Done) -> None: obj = TaskDoneCallback(done=done) t = asyncio.create_task(target(obj)) await t @@ -39,30 +39,30 @@ async def test_aclose(done: Done): assert {t} == done.args -async def test_async_with(done: Done): +async def test_async_with(done: Done) -> None: async with TaskDoneCallback(done=done) as obj: t = asyncio.create_task(target(obj)) await t assert {t} == done.args -def test_asyncio_run_close(done: Done): +def test_asyncio_run_close(done: Done) -> None: obj = TaskDoneCallback(done=done) asyncio.run(target(obj)) obj.close() assert 1 == len(done.args) -def test_asyncio_run_with(done: Done): +def test_asyncio_run_with(done: Done) -> None: with TaskDoneCallback(done=done) as obj: asyncio.run(target(obj)) assert 1 == len(done.args) -def test_thread(done: Done): +def test_thread(done: Done) -> None: event = threading.Event() - def f(obj: TaskDoneCallback): + def f(obj: TaskDoneCallback) -> None: asyncio.run(target(obj)) event.set() @@ -75,8 +75,8 @@ def f(obj: TaskDoneCallback): t.join() -async def test_register_arg(done: Done): - async def target(): +async def test_register_arg(done: Done) -> None: + async def target() -> None: delay = random.random() * 0.01 await asyncio.sleep(delay) @@ -92,7 +92,7 @@ async def target(): @pytest.mark.parametrize("n_tasks", [0, 1, 2, 5, 10]) -async def test_multiple(n_tasks: int, done: Done): +async def test_multiple(n_tasks: int, done: Done) -> None: async with TaskDoneCallback(done=done) as obj: tasks = {asyncio.create_task(target(obj)) for _ in range(n_tasks)} await asyncio.gather(*tasks) @@ -100,8 +100,8 @@ async def test_multiple(n_tasks: int, done: Done): assert tasks == done.args -async def test_raise_aclose_from_task(done: Done): - async def target(obj: TaskDoneCallback): +async def test_raise_aclose_from_task(done: Done) -> None: + async def target(obj: TaskDoneCallback) -> None: obj.register() delay = random.random() * 0.01 time.sleep(delay) @@ -114,8 +114,8 @@ async def target(obj: TaskDoneCallback): await t -async def test_raise_close_from_task(done: Done): - async def target(obj: TaskDoneCallback): +async def test_raise_close_from_task(done: Done) -> None: + async def target(obj: TaskDoneCallback) -> None: obj.register() delay = random.random() * 0.01 time.sleep(delay) @@ -128,7 +128,7 @@ async def target(obj: TaskDoneCallback): await t -async def test_raise_in_done(): +async def test_raise_in_done() -> None: done = Mock(side_effect=ValueError) obj = TaskDoneCallback(done=done) t = asyncio.create_task(target(obj)) @@ -139,7 +139,7 @@ async def test_raise_in_done(): await obj.aclose() -async def test_done_none(): +async def test_done_none() -> None: async with TaskDoneCallback() as obj: t = asyncio.create_task(target(obj)) await asyncio.sleep(0) # let the task be registered diff --git a/tests/utils/done_callback/test_thread.py b/tests/utils/done_callback/test_thread.py index 28992458..2d1c8f06 100644 --- a/tests/utils/done_callback/test_thread.py +++ b/tests/utils/done_callback/test_thread.py @@ -1,6 +1,6 @@ import random import time -from threading import current_thread +from threading import Thread, current_thread from unittest.mock import Mock import pytest @@ -8,7 +8,7 @@ from nextline.utils import ExcThread, ThreadDoneCallback -def target(obj: ThreadDoneCallback): +def target(obj: ThreadDoneCallback) -> None: """To run in a thread""" assert current_thread() == obj.register() delay = random.random() * 0.01 @@ -18,20 +18,20 @@ def target(obj: ThreadDoneCallback): class Done: """A callback function""" - def __init__(self): - self.args = set() + def __init__(self) -> None: + self.args = set[Thread]() - def __call__(self, arg): + def __call__(self, arg: Thread) -> None: self.args.add(arg) @pytest.fixture() -def done(): +def done() -> Done: """A callback function""" - yield Done() + return Done() -def test_close(done: Done): +def test_close(done: Done) -> None: obj = ThreadDoneCallback(done=done) t = ExcThread(target=target, args=(obj,)) t.start() @@ -41,7 +41,7 @@ def test_close(done: Done): t.join() -def test_with(done: Done): +def test_with(done: Done) -> None: with ThreadDoneCallback(done=done) as obj: t = ExcThread(target=target, args=(obj,)) t.start() @@ -50,8 +50,8 @@ def test_with(done: Done): t.join() -def test_register_arg(done: Done): - def target(): +def test_register_arg(done: Done) -> None: + def target() -> None: delay = random.random() * 0.01 time.sleep(delay) @@ -68,7 +68,7 @@ def target(): t.join() -def test_daemon(done: Done): +def test_daemon(done: Done) -> None: """Not blocked even if close() is not called""" obj = ThreadDoneCallback(done=done) t = ExcThread(target=target, args=(obj,)) @@ -79,7 +79,7 @@ def test_daemon(done: Done): @pytest.mark.parametrize("n_threads", [0, 1, 2, 5, 10]) -def test_multiple(n_threads: int, done: Done): +def test_multiple(n_threads: int, done: Done) -> None: with ThreadDoneCallback(done=done) as obj: threads = {ExcThread(target=target, args=(obj,)) for _ in range(n_threads)} for t in threads: @@ -92,8 +92,8 @@ def test_multiple(n_threads: int, done: Done): t.join() -def test_raise_close_from_thread(done: Done): - def target(obj: ThreadDoneCallback): +def test_raise_close_from_thread(done: Done) -> None: + def target(obj: ThreadDoneCallback) -> None: obj.register() obj.close() @@ -107,7 +107,7 @@ def target(obj: ThreadDoneCallback): t.join() -def test_raise_in_done(): +def test_raise_in_done() -> None: done = Mock(side_effect=ValueError) obj = ThreadDoneCallback(done=done) t = ExcThread(target=target, args=(obj,)) @@ -120,7 +120,7 @@ def test_raise_in_done(): t.join() -def test_interval(done: Done): +def test_interval(done: Done) -> None: interval = 0.02 obj = ThreadDoneCallback(done=done, interval=interval) t = ExcThread(target=target, args=(obj,)) @@ -131,7 +131,7 @@ def test_interval(done: Done): t.join() -def test_done_none(): +def test_done_none() -> None: with ThreadDoneCallback() as obj: t = ExcThread(target=target, args=(obj,)) t.start() diff --git a/tests/utils/done_callback/test_union.py b/tests/utils/done_callback/test_union.py index d59e6d07..c3200338 100644 --- a/tests/utils/done_callback/test_union.py +++ b/tests/utils/done_callback/test_union.py @@ -8,34 +8,34 @@ from nextline.utils import ThreadTaskDoneCallback, current_task_or_thread -def target(obj: ThreadTaskDoneCallback): +def target(obj: ThreadTaskDoneCallback) -> None: """To run in a thread or task""" assert current_task_or_thread() == obj.register() delay = random.random() * 0.01 time.sleep(delay) -async def atarget(obj: ThreadTaskDoneCallback): +async def atarget(obj: ThreadTaskDoneCallback) -> None: target(obj) class Done: """A callback function""" - def __init__(self): - self.args = set() + def __init__(self) -> None: + self.args = set[Thread | asyncio.Task]() - def __call__(self, arg): + def __call__(self, arg: Thread | asyncio.Task) -> None: self.args.add(arg) @pytest.fixture() -def done(): +def done() -> Done: """A callback function""" - yield Done() + return Done() -def test_thread(done: Done): +def test_thread(done: Done) -> None: obj = ThreadTaskDoneCallback(done=done) t = Thread(target=target, args=(obj,)) t.start() @@ -45,7 +45,7 @@ def test_thread(done: Done): t.join() -async def test_task(done: Done): +async def test_task(done: Done) -> None: obj = ThreadTaskDoneCallback(done=done) t = asyncio.create_task(atarget(obj)) await t @@ -53,7 +53,7 @@ async def test_task(done: Done): assert {t} == done.args -def test_with_thread(done: Done): +def test_with_thread(done: Done) -> None: with ThreadTaskDoneCallback(done=done) as obj: t = Thread(target=target, args=(obj,)) t.start() @@ -62,14 +62,14 @@ def test_with_thread(done: Done): t.join() -async def test_with_task(done: Done): +async def test_with_task(done: Done) -> None: async with ThreadTaskDoneCallback(done=done) as obj: t = asyncio.create_task(atarget(obj)) await t assert {t} == done.args -def test_done_none_thread(): +def test_done_none_thread() -> None: with ThreadTaskDoneCallback() as obj: t = Thread(target=target, args=(obj,)) t.start() @@ -78,7 +78,7 @@ def test_done_none_thread(): t.join() -async def test_done_none_task(): +async def test_done_none_task() -> None: async with ThreadTaskDoneCallback() as obj: t = asyncio.create_task(atarget(obj)) await asyncio.sleep(0) # let the task be registered diff --git a/tests/utils/pubsub/test_broker.py b/tests/utils/pubsub/test_broker.py index 71b2325e..f7b30599 100644 --- a/tests/utils/pubsub/test_broker.py +++ b/tests/utils/pubsub/test_broker.py @@ -9,14 +9,14 @@ from nextline.utils import PubSub -async def test_end(): +async def test_end() -> None: key = 'foo' async with PubSub[str, str]() as obj: - async def subscribe(): + async def subscribe() -> tuple[str, ...]: return tuple([y async for y in obj.subscribe(key)]) - async def put(): + async def put() -> None: await asyncio.sleep(0.001) await obj.end(key) @@ -24,23 +24,23 @@ async def put(): assert result == () -async def test_end_without_subscription(): +async def test_end_without_subscription() -> None: key = 'foo' async with PubSub[str, str]() as obj: await obj.end(key) @given(items=st.lists(st.text())) -async def test_close(items: Sequence[str]): +async def test_close(items: Sequence[str]) -> None: items = tuple(items) key = 'foo' async with PubSub[str, str]() as obj: - async def subscribe(): + async def subscribe() -> tuple[str, ...]: return tuple([y async for y in obj.subscribe(key)]) - async def put(): + async def put() -> None: await asyncio.sleep(0.001) for item in items: await obj.publish(key, item) @@ -63,7 +63,7 @@ async def test_last( pre_items: Sequence[str], items: Sequence[str], last: bool, -): +) -> None: key = 'foo' pre_items = tuple(pre_items) items = tuple(items) @@ -75,10 +75,10 @@ async def test_last( await asyncio.sleep(0.001) - async def subscribe(): + async def subscribe() -> tuple[str, ...]: return tuple([y async for y in obj.subscribe(key, last=last)]) - async def put(): + async def put() -> None: await asyncio.sleep(0.001) for item in items: await obj.publish(key, item) @@ -97,17 +97,17 @@ async def test_matrix( keys: Sequence[str], n_items: int, n_subscribers: int, -): +) -> None: keys = tuple(keys) n_keys = len(keys) items = {k: tuple(f"{k}-{i+1}" for i in range(n_items)) for k in keys} async with PubSub[str, str]() as obj: - async def subscribe(key): + async def subscribe(key: str) -> tuple[str, ...]: return tuple([y async for y in obj.subscribe(key)]) - async def put(key): + async def put(key: str) -> None: time.sleep(0.01) for item in items[key]: await obj.publish(key, item) diff --git a/tests/utils/run/test_signal.py b/tests/utils/run/test_signal.py index 758cde04..078cdd7a 100644 --- a/tests/utils/run/test_signal.py +++ b/tests/utils/run/test_signal.py @@ -96,7 +96,7 @@ class Handled(Exception): pass -def handler(signum: signal._SIGNUM, frame: FrameType): +def handler(signum: signal._SIGNUM, frame: FrameType) -> NoReturn: raise Handled diff --git a/tests/utils/test_agen_with_wait.py b/tests/utils/test_agen_with_wait.py index 4b514754..68e1b895 100644 --- a/tests/utils/test_agen_with_wait.py +++ b/tests/utils/test_agen_with_wait.py @@ -1,5 +1,7 @@ import asyncio +from collections.abc import AsyncIterator from random import randint, random +from typing import NoReturn import pytest @@ -7,12 +9,12 @@ async def test_one() -> None: - async def agen(): + async def agen() -> AsyncIterator[int]: for i in range(3): yield i await asyncio.sleep(0.001) - async def afunc(): + async def afunc() -> None: delay = random() * 0.001 await asyncio.sleep(delay) @@ -23,7 +25,7 @@ async def afunc(): async for _ in obj: tasks = {asyncio.create_task(afunc()) for _ in range(randint(0, 5))} all |= tasks - done_, pending = await obj.asend(tasks) + done_, pending = await obj.asend(tasks) # type: ignore done.extend(done_) # type: ignore await asyncio.gather(*pending) @@ -31,13 +33,13 @@ async def afunc(): assert all == set(done) | set(pending) -async def test_raise(): - async def agen(): +async def test_raise() -> None: + async def agen() -> AsyncIterator[int]: yield 0 await asyncio.sleep(0.1) assert False # The line shouldn't be reached - async def afunc(): + async def afunc() -> NoReturn: await asyncio.sleep(0) raise Exception("foo", "bar") @@ -45,13 +47,13 @@ async def afunc(): with pytest.raises(Exception) as exc: async for _ in obj: tasks = {asyncio.create_task(afunc())} - _, pending = await obj.asend(tasks) + _, pending = await obj.asend(tasks) # type: ignore assert ("foo", "bar") == exc.value.args -async def test_without_send(): - async def agen(): +async def test_without_send() -> None: + async def agen() -> AsyncIterator[int]: for i in range(3): yield i await asyncio.sleep(0) diff --git a/tests/utils/test_multiprocessing_logging.py b/tests/utils/test_multiprocessing_logging.py index 871cdd23..6c8def77 100644 --- a/tests/utils/test_multiprocessing_logging.py +++ b/tests/utils/test_multiprocessing_logging.py @@ -11,7 +11,7 @@ @pytest.mark.parametrize('mp_method', [None, 'spawn', 'fork', 'forkserver']) async def test_multiprocessing_logging( mp_method: str | None, caplog: LogCaptureFixture -): +) -> None: mp_context = mp.get_context(mp_method) if mp_method else None with caplog.at_level(logging.DEBUG): @@ -28,7 +28,7 @@ async def test_multiprocessing_logging( assert caplog.records[0].name == __name__ -def fn(): +def fn() -> str: logger = logging.getLogger(__name__) logger.debug('bar') return 'foo' diff --git a/tests/utils/test_peek.py b/tests/utils/test_peek.py index 3e856717..477d945b 100644 --- a/tests/utils/test_peek.py +++ b/tests/utils/test_peek.py @@ -21,7 +21,7 @@ def test_print( post_msg: str, errs: list[str], post_err: str, -): +) -> None: capsys.readouterr() # clear callback_out = Mock() @@ -60,7 +60,7 @@ def test_yield( capsys: pytest.CaptureFixture, msgs: list[str], errs: list[str], -): +) -> None: capsys.readouterr() # clear callback_out = Mock() @@ -84,7 +84,7 @@ def test_yield( assert ''.join(errs) == captured.err -def test_raise(capsys: pytest.CaptureFixture): +def test_raise(capsys: pytest.CaptureFixture) -> None: callback = Mock(side_effect=MockError) with peek_stdout(callback): diff --git a/tests/utils/test_profile.py b/tests/utils/test_profile.py index be06ba22..b2792372 100644 --- a/tests/utils/test_profile.py +++ b/tests/utils/test_profile.py @@ -3,8 +3,8 @@ from nextline.utils import profile_func -def test_one(): - def func(): +def test_one() -> None: + def func() -> int: re.compile("foo|bar") return 123 diff --git a/tests/utils/test_thread_exception.py b/tests/utils/test_thread_exception.py index 289e68d6..09220241 100644 --- a/tests/utils/test_thread_exception.py +++ b/tests/utils/test_thread_exception.py @@ -1,3 +1,5 @@ +from typing import NoReturn + import pytest from nextline.utils import ExcThread @@ -7,11 +9,11 @@ class ErrorInThread(Exception): pass -def func_raise(): +def func_raise() -> NoReturn: raise ErrorInThread() -def test_return(): +def test_return() -> None: t = ExcThread(target=func_raise) t.start() with pytest.raises(ErrorInThread): diff --git a/tests/utils/test_thread_task_id.py b/tests/utils/test_thread_task_id.py index d6c8a78e..4ab6e0dd 100644 --- a/tests/utils/test_thread_task_id.py +++ b/tests/utils/test_thread_task_id.py @@ -11,7 +11,7 @@ from nextline.utils import ThreadTaskIdComposer as IdComposer -def assert_call(obj: IdComposer, expected: ThreadTaskId, has_id: bool = False): +def assert_call(obj: IdComposer, expected: ThreadTaskId, has_id: bool = False) -> None: assert obj.has_id() is has_id assert obj.has_id() is has_id assert expected == obj() @@ -21,7 +21,7 @@ def assert_call(obj: IdComposer, expected: ThreadTaskId, has_id: bool = False): async def async_assert_call( obj: IdComposer, expected: ThreadTaskId, has_id: bool = False -): +) -> None: assert obj.has_id() is has_id assert obj.has_id() is has_id await asyncio.sleep(0) @@ -32,12 +32,12 @@ async def async_assert_call( @pytest.fixture() -def obj(): +def obj() -> IdComposer: y = IdComposer() - yield y + return y -def test_compose(obj: IdComposer): +def test_compose(obj: IdComposer) -> None: expected = ThreadTaskId(ThreadNo(1), None) assert_call(obj, expected) @@ -46,7 +46,7 @@ def test_compose(obj: IdComposer): assert_call(obj, expected, True) -def test_threads(obj: IdComposer): +def test_threads(obj: IdComposer) -> None: expected = ThreadTaskId(ThreadNo(1), None) assert_call(obj, expected) @@ -76,7 +76,7 @@ def test_threads(obj: IdComposer): t.join() -async def test_async_coroutine(obj: IdComposer): +async def test_async_coroutine(obj: IdComposer) -> None: expected = ThreadTaskId(ThreadNo(1), TaskNo(1)) assert_call(obj, expected) @@ -90,7 +90,7 @@ async def test_async_coroutine(obj: IdComposer): await async_assert_call(obj, expected, True) -async def test_async_tasks(obj: IdComposer): +async def test_async_tasks(obj: IdComposer) -> None: expected = ThreadTaskId(ThreadNo(1), TaskNo(1)) assert_call(obj, expected) @@ -116,7 +116,7 @@ async def test_async_tasks(obj: IdComposer): await t -async def test_async_tasks_gather(obj: IdComposer): +async def test_async_tasks_gather(obj: IdComposer) -> None: expected = ThreadTaskId(ThreadNo(1), TaskNo(1)) assert_call(obj, expected) @@ -140,7 +140,7 @@ async def test_async_tasks_gather(obj: IdComposer): await asyncio.gather(*aws) -def test_async_asyncio_run(obj: IdComposer): +def test_async_asyncio_run(obj: IdComposer) -> None: expected = ThreadTaskId(ThreadNo(1), None) assert_call(obj, expected) @@ -157,7 +157,7 @@ def test_async_asyncio_run(obj: IdComposer): @mark.skipif(getenv('GITHUB_ACTIONS') == 'true', reason='Fails on GitHub Actions') -async def test_async_asyncio_to_thread(obj: IdComposer): +async def test_async_asyncio_to_thread(obj: IdComposer) -> None: expected = ThreadTaskId(ThreadNo(1), TaskNo(1)) assert_call(obj, expected) @@ -179,7 +179,7 @@ async def test_async_asyncio_to_thread(obj: IdComposer): await to_thread(partial(assert_call, obj, expected, True)) -async def async_nested(obj: IdComposer, expected_thread_id): +async def async_nested(obj: IdComposer, expected_thread_id: ThreadNo) -> None: expected1 = ThreadTaskId(expected_thread_id, TaskNo(1)) assert_call(obj, expected1) await async_assert_call(obj, expected1, True) @@ -192,14 +192,14 @@ async def async_nested(obj: IdComposer, expected_thread_id): await asyncio.gather(*aws) -def nested(obj: IdComposer, expected_thread_id): +def nested(obj: IdComposer, expected_thread_id: ThreadNo) -> None: expected = ThreadTaskId(expected_thread_id, None) assert_call(obj, expected) asyncio.run(async_nested(obj, expected_thread_id)) -def test_nested(obj: IdComposer): +def test_nested(obj: IdComposer) -> None: expected = ThreadTaskId(ThreadNo(1), None) assert_call(obj, expected) From 297ee86eafa9ce19a018a51c334be4b8ce0bcdd5 Mon Sep 17 00:00:00 2001 From: Tai Sakuma Date: Mon, 16 Sep 2024 11:24:31 -0400 Subject: [PATCH 2/2] Update mypy settings --- pyproject.toml | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 08be8477..297552b8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -88,8 +88,11 @@ target_version = ['py310', 'py311', 'py312'] profile = "black" [tool.mypy] -# disallow_untyped_defs = true -exclude = ['script\.py'] +disallow_untyped_defs = true +exclude = '''(?x)( + example/.*\.py$ + | test_script\.py$ +)''' [[tool.mypy.overrides]]