diff --git a/pyproject.toml b/pyproject.toml index b5e3d43153..79beab840d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,8 +64,6 @@ module = [ "trio/_core/_generated_io_windows", "trio/_core/_io_windows", -"trio/_signals", - # internal "trio/_windows_pipes", @@ -93,7 +91,6 @@ module = [ "trio/_tests/test_highlevel_ssl_helpers", "trio/_tests/test_path", "trio/_tests/test_scheduler_determinism", -"trio/_tests/test_signals", "trio/_tests/test_socket", "trio/_tests/test_ssl", "trio/_tests/test_subprocess", diff --git a/trio/_signals.py b/trio/_signals.py index fe2bde946e..283c3a44a8 100644 --- a/trio/_signals.py +++ b/trio/_signals.py @@ -1,11 +1,19 @@ +from __future__ import annotations + import signal from collections import OrderedDict +from collections.abc import AsyncIterator, Callable, Generator, Iterable from contextlib import contextmanager +from types import FrameType +from typing import TYPE_CHECKING import trio from ._util import ConflictDetector, is_main_thread, signal_raise +if TYPE_CHECKING: + from typing_extensions import Self + # Discussion of signal handling strategies: # # - On Windows signals barely exist. There are no options; signal handlers are @@ -43,7 +51,10 @@ @contextmanager -def _signal_handler(signals, handler): +def _signal_handler( + signals: Iterable[int], + handler: Callable[[int, FrameType | None], object] | int | signal.Handlers | None, +) -> Generator[None, None, None]: original_handlers = {} try: for signum in set(signals): @@ -55,23 +66,23 @@ def _signal_handler(signals, handler): class SignalReceiver: - def __init__(self): + def __init__(self) -> None: # {signal num: None} - self._pending = OrderedDict() + self._pending: OrderedDict[int, None] = OrderedDict() self._lot = trio.lowlevel.ParkingLot() self._conflict_detector = ConflictDetector( "only one task can iterate on a signal receiver at a time" ) self._closed = False - def _add(self, signum): + def _add(self, signum: int) -> None: if self._closed: signal_raise(signum) else: self._pending[signum] = None self._lot.unpark() - def _redeliver_remaining(self): + def _redeliver_remaining(self) -> None: # First make sure that any signals still in the delivery pipeline will # get redelivered self._closed = True @@ -79,7 +90,7 @@ def _redeliver_remaining(self): # And then redeliver any that are sitting in pending. This is done # using a weird recursive construct to make sure we process everything # even if some of the handlers raise exceptions. - def deliver_next(): + def deliver_next() -> None: if self._pending: signum, _ = self._pending.popitem(last=False) try: @@ -89,14 +100,10 @@ def deliver_next(): deliver_next() - # Helper for tests, not public or otherwise used - def _pending_signal_count(self): - return len(self._pending) - - def __aiter__(self): + def __aiter__(self) -> Self: return self - async def __anext__(self): + async def __anext__(self) -> int: if self._closed: raise RuntimeError("open_signal_receiver block already exited") # In principle it would be possible to support multiple concurrent @@ -111,8 +118,17 @@ async def __anext__(self): return signum +def get_pending_signal_count(rec: AsyncIterator[int]) -> int: + """Helper for tests, not public or otherwise used.""" + # open_signal_receiver() always produces SignalReceiver, this should not fail. + assert isinstance(rec, SignalReceiver) + return len(rec._pending) + + @contextmanager -def open_signal_receiver(*signals): +def open_signal_receiver( + *signals: signal.Signals | int, +) -> Generator[AsyncIterator[int], None, None]: """A context manager for catching signals. Entering this context manager starts listening for the given signals and @@ -158,7 +174,7 @@ def open_signal_receiver(*signals): token = trio.lowlevel.current_trio_token() queue = SignalReceiver() - def handler(signum, _): + def handler(signum: int, frame: FrameType | None) -> None: token.run_sync_soon(queue._add, signum, idempotent=True) try: diff --git a/trio/_tests/test_signals.py b/trio/_tests/test_signals.py index 313cce259f..1e42239e35 100644 --- a/trio/_tests/test_signals.py +++ b/trio/_tests/test_signals.py @@ -1,15 +1,19 @@ +from __future__ import annotations + import signal +from types import FrameType +from typing import NoReturn import pytest import trio from .. import _core -from .._signals import _signal_handler, open_signal_receiver +from .._signals import _signal_handler, get_pending_signal_count, open_signal_receiver from .._util import signal_raise -async def test_open_signal_receiver(): +async def test_open_signal_receiver() -> None: orig = signal.getsignal(signal.SIGILL) with open_signal_receiver(signal.SIGILL) as receiver: # Raise it a few times, to exercise signal coalescing, both at the @@ -22,18 +26,18 @@ async def test_open_signal_receiver(): async for signum in receiver: # pragma: no branch assert signum == signal.SIGILL break - assert receiver._pending_signal_count() == 0 + assert get_pending_signal_count(receiver) == 0 signal_raise(signal.SIGILL) async for signum in receiver: # pragma: no branch assert signum == signal.SIGILL break - assert receiver._pending_signal_count() == 0 + assert get_pending_signal_count(receiver) == 0 with pytest.raises(RuntimeError): await receiver.__anext__() assert signal.getsignal(signal.SIGILL) is orig -async def test_open_signal_receiver_restore_handler_after_one_bad_signal(): +async def test_open_signal_receiver_restore_handler_after_one_bad_signal() -> None: orig = signal.getsignal(signal.SIGILL) with pytest.raises(ValueError): with open_signal_receiver(signal.SIGILL, 1234567): @@ -42,13 +46,13 @@ async def test_open_signal_receiver_restore_handler_after_one_bad_signal(): assert signal.getsignal(signal.SIGILL) is orig -async def test_open_signal_receiver_empty_fail(): +async def test_open_signal_receiver_empty_fail() -> None: with pytest.raises(TypeError, match="No signals were provided"): with open_signal_receiver(): pass -async def test_open_signal_receiver_restore_handler_after_duplicate_signal(): +async def test_open_signal_receiver_restore_handler_after_duplicate_signal() -> None: orig = signal.getsignal(signal.SIGILL) with open_signal_receiver(signal.SIGILL, signal.SIGILL): pass @@ -56,8 +60,8 @@ async def test_open_signal_receiver_restore_handler_after_duplicate_signal(): assert signal.getsignal(signal.SIGILL) is orig -async def test_catch_signals_wrong_thread(): - async def naughty(): +async def test_catch_signals_wrong_thread() -> None: + async def naughty() -> None: with open_signal_receiver(signal.SIGINT): pass # pragma: no cover @@ -65,7 +69,7 @@ async def naughty(): await trio.to_thread.run_sync(trio.run, naughty) -async def test_open_signal_receiver_conflict(): +async def test_open_signal_receiver_conflict() -> None: with pytest.raises(trio.BusyResourceError): with open_signal_receiver(signal.SIGILL) as receiver: async with trio.open_nursery() as nursery: @@ -75,14 +79,14 @@ async def test_open_signal_receiver_conflict(): # Blocks until all previous calls to run_sync_soon(idempotent=True) have been # processed. -async def wait_run_sync_soon_idempotent_queue_barrier(): +async def wait_run_sync_soon_idempotent_queue_barrier() -> None: ev = trio.Event() token = _core.current_trio_token() token.run_sync_soon(ev.set, idempotent=True) await ev.wait() -async def test_open_signal_receiver_no_starvation(): +async def test_open_signal_receiver_no_starvation() -> None: # Set up a situation where there are always 2 pending signals available to # report, and make sure that instead of getting the same signal reported # over and over, it alternates between reporting both of them. @@ -101,8 +105,8 @@ async def test_open_signal_receiver_no_starvation(): assert got in [signal.SIGILL, signal.SIGFPE] assert got != previous previous = got - # Clear out the last signal so it doesn't get redelivered - while receiver._pending_signal_count() != 0: + # Clear out the last signal so that it doesn't get redelivered + while get_pending_signal_count(receiver) != 0: await receiver.__anext__() except: # pragma: no cover # If there's an unhandled exception above, then exiting the @@ -113,10 +117,10 @@ async def test_open_signal_receiver_no_starvation(): traceback.print_exc() -async def test_catch_signals_race_condition_on_exit(): - delivered_directly = set() +async def test_catch_signals_race_condition_on_exit() -> None: + delivered_directly: set[int] = set() - def direct_handler(signo, frame): + def direct_handler(signo: int, frame: FrameType | None) -> None: delivered_directly.add(signo) print(1) @@ -138,7 +142,7 @@ def direct_handler(signo, frame): signal_raise(signal.SIGILL) signal_raise(signal.SIGFPE) await wait_run_sync_soon_idempotent_queue_barrier() - assert receiver._pending_signal_count() == 2 + assert get_pending_signal_count(receiver) == 2 assert delivered_directly == {signal.SIGILL, signal.SIGFPE} delivered_directly.clear() @@ -156,12 +160,12 @@ def direct_handler(signo, frame): with open_signal_receiver(signal.SIGILL) as receiver: signal_raise(signal.SIGILL) await wait_run_sync_soon_idempotent_queue_barrier() - assert receiver._pending_signal_count() == 1 + assert get_pending_signal_count(receiver) == 1 # test passes if the process reaches this point without dying # Check exception chaining if there are multiple exception-raising # handlers - def raise_handler(signum, _): + def raise_handler(signum: int, frame: FrameType | None) -> NoReturn: raise RuntimeError(signum) with _signal_handler({signal.SIGILL, signal.SIGFPE}, raise_handler): @@ -170,7 +174,7 @@ def raise_handler(signum, _): signal_raise(signal.SIGILL) signal_raise(signal.SIGFPE) await wait_run_sync_soon_idempotent_queue_barrier() - assert receiver._pending_signal_count() == 2 + assert get_pending_signal_count(receiver) == 2 exc = excinfo.value signums = {exc.args[0]} assert isinstance(exc.__context__, RuntimeError)