diff --git a/docs/source/reference-core.rst b/docs/source/reference-core.rst index a6dc5b1f5a..6cb0a561c2 100644 --- a/docs/source/reference-core.rst +++ b/docs/source/reference-core.rst @@ -802,9 +802,9 @@ to set the default behavior for any nursery in your program that doesn't overrid wrapping, so you'll get maximum compatibility with code that was written to support older versions of Trio. -To maintain backwards compatibility, the default is ``strict_exception_groups=False``. -The default will eventually change to ``True`` in a future version of Trio, once -Python 3.11 and later versions are in wide use. +The default is set to ``strict_exception_groups=True`` in line with the default behaviour +of ``TaskGroup`` in asyncio and anyio. This is also to avoid any bugs caused by only +catching one type of exceptions/exceptiongroups. .. _exceptiongroup: https://pypi.org/project/exceptiongroup/ diff --git a/newsfragments/2786.breaking.rst b/newsfragments/2786.breaking.rst new file mode 100644 index 0000000000..35433da6df --- /dev/null +++ b/newsfragments/2786.breaking.rst @@ -0,0 +1,2 @@ +``strict_exception_groups`` now defaults to True in ``trio.run`` and ``trio.start_guest_run``, as well as ``trio.open_nursery`` as a result of that. +This is unfortunately very tricky to change with a deprecation period, as raising a ``DeprecationWarning`` whenever ``strict_exception_groups`` is not specified would raise a lot of unnecessary warnings. diff --git a/src/trio/_core/_run.py b/src/trio/_core/_run.py index d5006eed91..bc2059761c 100644 --- a/src/trio/_core/_run.py +++ b/src/trio/_core/_run.py @@ -917,7 +917,7 @@ class NurseryManager: """ - strict_exception_groups: bool = attr.ib(default=False) + strict_exception_groups: bool = attr.ib(default=True) @enable_ki_protection async def __aenter__(self) -> Nursery: @@ -985,8 +985,8 @@ def open_nursery( Args: strict_exception_groups (bool): If true, even a single raised exception will be - wrapped in an exception group. This will eventually become the default - behavior. If not specified, uses the value passed to :func:`run`. + wrapped in an exception group. If not specified, uses the value passed to + :func:`run`, which defaults to true. """ if strict_exception_groups is None: @@ -2150,7 +2150,7 @@ def run( clock: Clock | None = None, instruments: Sequence[Instrument] = (), restrict_keyboard_interrupt_to_checkpoints: bool = False, - strict_exception_groups: bool = False, + strict_exception_groups: bool = True, ) -> RetT: """Run a Trio-flavored async function, and return the result. @@ -2207,9 +2207,9 @@ def run( main thread (this is a Python limitation), or if you use :func:`open_signal_receiver` to catch SIGINT. - strict_exception_groups (bool): If true, nurseries will always wrap even a single - raised exception in an exception group. This can be overridden on the level of - individual nurseries. This will eventually become the default behavior. + strict_exception_groups (bool): Unless set to false, nurseries will always wrap + even a single raised exception in an exception group. This can be overridden + on the level of individual nurseries. Returns: Whatever ``async_fn`` returns. @@ -2267,7 +2267,7 @@ def start_guest_run( clock: Clock | None = None, instruments: Sequence[Instrument] = (), restrict_keyboard_interrupt_to_checkpoints: bool = False, - strict_exception_groups: bool = False, + strict_exception_groups: bool = True, ) -> None: """Start a "guest" run of Trio on top of some other "host" event loop. diff --git a/src/trio/_core/_tests/test_ki.py b/src/trio/_core/_tests/test_ki.py index cd98bc9bca..fde7ea7f76 100644 --- a/src/trio/_core/_tests/test_ki.py +++ b/src/trio/_core/_tests/test_ki.py @@ -9,6 +9,9 @@ import outcome import pytest +from trio import testing +from trio.testing import ExpectedExceptionGroup + try: from async_generator import async_generator, yield_ except ImportError: # pragma: no cover @@ -293,7 +296,7 @@ async def check_unprotected_kill() -> None: nursery.start_soon(sleeper, "s2", record_set) nursery.start_soon(raiser, "r1", record_set) - with pytest.raises(KeyboardInterrupt): + with testing.raises(ExpectedExceptionGroup(KeyboardInterrupt)): _core.run(check_unprotected_kill) assert record_set == {"s1 ok", "s2 ok", "r1 raise ok"} @@ -309,7 +312,7 @@ async def check_protected_kill() -> None: nursery.start_soon(_core.enable_ki_protection(raiser), "r1", record_set) # __aexit__ blocks, and then receives the KI - with pytest.raises(KeyboardInterrupt): + with testing.raises(ExpectedExceptionGroup(KeyboardInterrupt)): _core.run(check_protected_kill) assert record_set == {"s1 ok", "s2 ok", "r1 cancel ok"} @@ -331,7 +334,8 @@ def kill_during_shutdown() -> None: token.run_sync_soon(kill_during_shutdown) - with pytest.raises(KeyboardInterrupt): + # Does not wrap in an ExceptionGroup(!!) + with testing.raises(KeyboardInterrupt): _core.run(check_kill_during_shutdown) # KI arrives very early, before main is even spawned @@ -344,6 +348,7 @@ def before_run(self) -> None: async def main_1() -> None: await _core.checkpoint() + # Does not wrap in an ExceptionGroup(!!) with pytest.raises(KeyboardInterrupt): _core.run(main_1, instruments=[InstrumentOfDeath()]) diff --git a/src/trio/_core/_tests/test_multierror.py b/src/trio/_core/_tests/test_multierror.py index 505b371f41..7f47e04d74 100644 --- a/src/trio/_core/_tests/test_multierror.py +++ b/src/trio/_core/_tests/test_multierror.py @@ -136,7 +136,7 @@ async def test_MultiErrorNotHashable() -> None: assert exc1 != exc2 assert exc1 != exc3 - with pytest.raises(MultiError): + with pytest.raises(ExceptionGroup): async with open_nursery() as nursery: nursery.start_soon(raise_nothashable, 42) nursery.start_soon(raise_nothashable, 4242) diff --git a/src/trio/_core/_tests/test_run.py b/src/trio/_core/_tests/test_run.py index 310e9a67e5..056bdc6a3e 100644 --- a/src/trio/_core/_tests/test_run.py +++ b/src/trio/_core/_tests/test_run.py @@ -16,8 +16,9 @@ import pytest import sniffio +from trio.testing import ExpectedExceptionGroup + from ... import _core -from ..._core._multierror import MultiError, NonBaseMultiError from ..._threads import to_thread_run_sync from ..._timeouts import fail_after, sleep from ...testing import Sequencer, assert_checkpoints, wait_all_tasks_blocked @@ -123,10 +124,11 @@ async def test_nursery_warn_use_async_with() -> None: async def test_nursery_main_block_error_basic() -> None: exc = ValueError("whoops") - with pytest.raises(ValueError) as excinfo: + with pytest.raises(ExpectedExceptionGroup(ValueError)) as excinfo: async with _core.open_nursery(): raise exc - assert excinfo.value is exc + assert len(excinfo.value.exceptions) == 1 + assert excinfo.value.exceptions[0] is exc async def test_child_crash_basic() -> None: @@ -139,8 +141,9 @@ async def erroring() -> NoReturn: # nursery.__aexit__ propagates exception from child back to parent async with _core.open_nursery() as nursery: nursery.start_soon(erroring) - except ValueError as e: - assert e is exc + except ExceptionGroup as e: + assert len(e.exceptions) == 1 + assert e.exceptions[0] is exc async def test_basic_interleave() -> None: @@ -178,16 +181,15 @@ async def main() -> None: nursery.start_soon(looper) nursery.start_soon(crasher) - with pytest.raises(ValueError) as excinfo: + with pytest.raises(ExpectedExceptionGroup(ValueError("argh"))): _core.run(main) assert looper_record == ["cancelled"] - assert excinfo.value.args == ("argh",) def test_main_and_task_both_crash() -> None: - # If main crashes and there's also a task crash, then we get both in a - # MultiError + # If main crashes and there's also a task crash, then we get both in an + # ExceptionGroup async def crasher() -> NoReturn: raise ValueError @@ -196,13 +198,8 @@ async def main() -> NoReturn: nursery.start_soon(crasher) raise KeyError - with pytest.raises(MultiError) as excinfo: + with pytest.raises(ExpectedExceptionGroup((ValueError, KeyError))): _core.run(main) - print(excinfo.value) - assert {type(exc) for exc in excinfo.value.exceptions} == { - ValueError, - KeyError, - } def test_two_child_crashes() -> None: @@ -214,19 +211,15 @@ async def main() -> None: nursery.start_soon(crasher, KeyError) nursery.start_soon(crasher, ValueError) - with pytest.raises(MultiError) as excinfo: + with pytest.raises(ExpectedExceptionGroup((ValueError, KeyError))): _core.run(main) - assert {type(exc) for exc in excinfo.value.exceptions} == { - ValueError, - KeyError, - } async def test_child_crash_wakes_parent() -> None: async def crasher() -> NoReturn: raise ValueError - with pytest.raises(ValueError): + with pytest.raises(ExpectedExceptionGroup(ValueError)): async with _core.open_nursery() as nursery: nursery.start_soon(crasher) await sleep_forever() @@ -433,7 +426,10 @@ async def test_cancel_scope_multierror_filtering() -> None: async def crasher() -> NoReturn: raise KeyError - try: + # This is outside the outer scope, so all the Cancelled + # exceptions should have been absorbed, leaving just a regular + # KeyError from crasher() + with pytest.raises(ExpectedExceptionGroup(KeyError)): with _core.CancelScope() as outer: try: async with _core.open_nursery() as nursery: @@ -449,7 +445,7 @@ async def crasher() -> NoReturn: # And one that raises a different error nursery.start_soon(crasher) # t4 # and then our __aexit__ also receives an outer Cancelled - except MultiError as multi_exc: + except ExceptionGroup as multi_exc: # Since the outer scope became cancelled before the # nursery block exited, all cancellations inside the # nursery block continue propagating to reach the @@ -461,15 +457,8 @@ async def crasher() -> NoReturn: summary[type(exc)] += 1 assert summary == {_core.Cancelled: 3, KeyError: 1} raise - except AssertionError: # pragma: no cover - raise - except BaseException as exc: - # This is outside the outer scope, so all the Cancelled - # exceptions should have been absorbed, leaving just a regular - # KeyError from crasher() - assert type(exc) is KeyError - else: # pragma: no cover - raise AssertionError() + else: # pragma: no cover + raise AssertionError("no ExceptionGroup raised") async def test_precancelled_task() -> None: @@ -785,17 +774,19 @@ async def task2() -> None: await wait_all_tasks_blocked() nursery.cancel_scope.__exit__(None, None, None) finally: - with pytest.raises(RuntimeError) as exc_info: + with pytest.raises( + RuntimeError, match="which had already been exited" + ) as exc_info: await nursery_mgr.__aexit__(*sys.exc_info()) - assert "which had already been exited" in str(exc_info.value) - assert type(exc_info.value.__context__) is NonBaseMultiError - assert len(exc_info.value.__context__.exceptions) == 3 - cancelled_in_context = False - for exc in exc_info.value.__context__.exceptions: - assert isinstance(exc, RuntimeError) - assert "closed before the task exited" in str(exc) - cancelled_in_context |= isinstance(exc.__context__, _core.Cancelled) - assert cancelled_in_context # for the sleep_forever + assert ExpectedExceptionGroup( + (RuntimeError("closed before the task exited"),) * 3 + ).matches(exc_info.value.__context__) + # TODO: TypeGuard + assert isinstance(exc_info.value.__context__, BaseExceptionGroup) + assert any( + isinstance(exc.__context__, _core.Cancelled) + for exc in exc_info.value.__context__.exceptions + ) # Trying to exit a cancel scope from an unrelated task raises an error # without affecting any state @@ -946,14 +937,14 @@ async def main() -> None: _core.spawn_system_task(system_task) await sleep_forever() + # not wrapped in ExceptionGroup with pytest.raises(_core.TrioInternalError) as excinfo: _core.run(main) - me = excinfo.value.__cause__ - assert isinstance(me, MultiError) - assert len(me.exceptions) == 2 - for exc in me.exceptions: - assert isinstance(exc, (KeyError, ValueError)) + # triple-wrapped exceptions ?!?! + assert ExpectedExceptionGroup( + ExpectedExceptionGroup(ExpectedExceptionGroup(KeyError, ValueError)) + ).matches(excinfo.value.__cause__) def test_system_task_crash_plus_Cancelled() -> None: @@ -979,7 +970,10 @@ async def main() -> None: with pytest.raises(_core.TrioInternalError) as excinfo: _core.run(main) - assert type(excinfo.value.__cause__) is ValueError + # triple-wrap + assert ExpectedExceptionGroup( + ExpectedExceptionGroup(ExpectedExceptionGroup(ValueError)) + ).matches(excinfo.value.__cause__) def test_system_task_crash_KeyboardInterrupt() -> None: @@ -992,7 +986,9 @@ async def main() -> None: with pytest.raises(_core.TrioInternalError) as excinfo: _core.run(main) - assert isinstance(excinfo.value.__cause__, KeyboardInterrupt) + assert ExpectedExceptionGroup(ExpectedExceptionGroup(KeyboardInterrupt)).matches( + excinfo.value.__cause__ + ) # This used to fail because checkpoint was a yield followed by an immediate @@ -1093,7 +1089,7 @@ async def child() -> None: await sleep_forever() raise - with pytest.raises(KeyError): + with pytest.raises(ExpectedExceptionGroup(KeyError)): async with _core.open_nursery() as nursery: nursery.start_soon(child) await wait_all_tasks_blocked() @@ -1114,20 +1110,19 @@ async def child() -> None: except Exception: await sleep_forever() - with pytest.raises(ValueError) as excinfo: + with pytest.raises(ExpectedExceptionGroup(ValueError)) as excinfo: async with _core.open_nursery() as nursery: nursery.start_soon(child) await wait_all_tasks_blocked() _core.reschedule(not_none(child_task), outcome.Error(ValueError())) - - assert isinstance(excinfo.value.__context__, KeyError) + assert isinstance(excinfo.value.exceptions[0].__context__, KeyError) async def test_nursery_exception_chaining_doesnt_make_context_loops() -> None: async def crasher() -> NoReturn: raise KeyError - with pytest.raises(MultiError) as excinfo: + with pytest.raises(ExpectedExceptionGroup(KeyError, ValueError)) as excinfo: async with _core.open_nursery() as nursery: nursery.start_soon(crasher) raise ValueError @@ -1257,8 +1252,9 @@ async def main() -> None: with pytest.raises(_core.TrioInternalError) as excinfo: _core.run(main) - - assert type(excinfo.value.__cause__) is KeyError + assert ExpectedExceptionGroup(ExpectedExceptionGroup(KeyError)).matches( + excinfo.value.__cause__ + ) assert record == {"2nd run_sync_soon ran", "cancelled!"} @@ -1373,7 +1369,7 @@ async def main() -> None: with pytest.raises(_core.TrioInternalError) as excinfo: _core.run(main) - assert type(excinfo.value.__cause__) is KeyError + assert ExpectedExceptionGroup(KeyError).matches(excinfo.value.__cause__) assert record == ["main exiting", "2nd ran"] @@ -1587,16 +1583,30 @@ async def main() -> None: async def f() -> None: # pragma: no cover pass - with pytest.raises(TypeError, match="expecting an async function"): - bad_call(f()) # type: ignore[arg-type] - async def async_gen(arg: T) -> AsyncGenerator[T, None]: # pragma: no cover yield arg - with pytest.raises( - TypeError, match="expected an async function but got an async generator" - ): - bad_call(async_gen, 0) # type: ignore + # this is obviously horribly ugly code + # but one raising an exceptiongroup and one not doing so is probably bad + if bad_call is bad_call_run: + with pytest.raises(TypeError, match="expecting an async function"): + bad_call(f()) # type: ignore[arg-type] + with pytest.raises( + TypeError, match="expected an async function but got an async generator" + ): + bad_call(async_gen, 0) # type: ignore + else: + with pytest.raises( + ExpectedExceptionGroup(TypeError("expecting an async function")) + ): + bad_call(f()) # type: ignore[arg-type] + + with pytest.raises( + ExpectedExceptionGroup( + TypeError("expected an async function but got an async generator") + ) + ): + bad_call(async_gen, 0) # type: ignore def test_calling_asyncio_function_gives_nice_error() -> None: @@ -1618,11 +1628,10 @@ async def misguided() -> None: async def test_asyncio_function_inside_nursery_does_not_explode() -> None: # Regression test for https://github.com/python-trio/trio/issues/552 - with pytest.raises(TypeError) as excinfo: + with pytest.raises(ExpectedExceptionGroup(TypeError("asyncio"))): async with _core.open_nursery() as nursery: nursery.start_soon(sleep_forever) await create_asyncio_future_in_new_loop() - assert "asyncio" in str(excinfo.value) async def test_trivial_yields() -> None: @@ -1659,7 +1668,7 @@ async def noop_with_no_checkpoint() -> None: with _core.CancelScope() as cancel_scope: cancel_scope.cancel() - with pytest.raises(KeyError): + with pytest.raises(ExpectedExceptionGroup(KeyError)): async with _core.open_nursery(): raise KeyError @@ -1788,7 +1797,7 @@ async def test_task_nursery_stack() -> None: assert task._child_nurseries == [] async with _core.open_nursery() as nursery1: assert task._child_nurseries == [nursery1] - with pytest.raises(KeyError): + with pytest.raises(ExpectedExceptionGroup(KeyError)): async with _core.open_nursery() as nursery2: assert task._child_nurseries == [nursery1, nursery2] raise KeyError @@ -1881,7 +1890,7 @@ async def start_sleep_then_crash(nursery: _core.Nursery) -> None: async def test_nursery_explicit_exception() -> None: - with pytest.raises(KeyError): + with pytest.raises(ExpectedExceptionGroup(KeyError)): async with _core.open_nursery(): raise KeyError() @@ -1890,12 +1899,10 @@ async def test_nursery_stop_iteration() -> None: async def fail() -> NoReturn: raise ValueError - try: + with pytest.raises(ExpectedExceptionGroup(StopIteration, ValueError)): async with _core.open_nursery() as nursery: nursery.start_soon(fail) raise StopIteration - except MultiError as e: - assert tuple(map(type, e.exceptions)) == (StopIteration, ValueError) async def test_nursery_stop_async_iteration() -> None: @@ -1928,9 +1935,18 @@ async def __anext__(self) -> list[int]: nexts = self.nexts items: list[int] = [-1] * len(nexts) - async with _core.open_nursery() as nursery: - for i, f in enumerate(nexts): - nursery.start_soon(self._accumulate, f, items, i) + try: + async with _core.open_nursery() as nursery: + for i, f in enumerate(nexts): + nursery.start_soon(self._accumulate, f, items, i) + except ExceptionGroup as e: + # I think requiring this is acceptable? + if len(e.exceptions) == 1 and isinstance( + e.exceptions[0], StopAsyncIteration + ): + raise StopAsyncIteration from None + else: + raise return items @@ -1944,7 +1960,7 @@ async def test_traceback_frame_removal() -> None: async def my_child_task() -> NoReturn: raise KeyError() - try: + with pytest.raises(ExceptionGroup) as exc: # Trick: For now cancel/nursery scopes still leave a bunch of tb gunk # behind. But if there's a MultiError, they leave it on the MultiError, # which lets us get a clean look at the KeyError itself. Someday I @@ -1953,16 +1969,15 @@ async def my_child_task() -> NoReturn: async with _core.open_nursery() as nursery: nursery.start_soon(my_child_task) nursery.start_soon(my_child_task) - except MultiError as exc: - first_exc = exc.exceptions[0] - assert isinstance(first_exc, KeyError) - # The top frame in the exception traceback should be inside the child - # task, not trio/contextvars internals. And there's only one frame - # inside the child task, so this will also detect if our frame-removal - # is too eager. - tb = first_exc.__traceback__ - assert tb is not None - assert tb.tb_frame.f_code is my_child_task.__code__ + first_exc = exc.value.exceptions[0] + assert isinstance(first_exc, KeyError) + # The top frame in the exception traceback should be inside the child + # task, not trio/contextvars internals. And there's only one frame + # inside the child task, so this will also detect if our frame-removal + # is too eager. + tb = first_exc.__traceback__ + assert tb is not None + assert tb.tb_frame.f_code is my_child_task.__code__ def test_contextvar_support() -> None: @@ -2154,7 +2169,7 @@ async def detachable_coroutine( # Check the exception paths too task = None pdco_outcome = None - with pytest.raises(KeyError): + with pytest.raises(ExpectedExceptionGroup(KeyError)): async with _core.open_nursery() as nursery: nursery.start_soon(detachable_coroutine, outcome.Error(KeyError()), "uh oh") throw_in = ValueError() @@ -2323,7 +2338,7 @@ async def crasher() -> NoReturn: # (See https://github.com/python-trio/trio/pull/1864) await do_a_cancel() - with pytest.raises(ValueError): + with pytest.raises(ExpectedExceptionGroup(ValueError)): async with _core.open_nursery() as nursery: # cover NurseryManager.__aexit__ nursery.start_soon(crasher) @@ -2347,7 +2362,9 @@ async def crasher() -> NoReturn: old_flags = gc.get_debug() try: - with pytest.raises(ValueError), _core.CancelScope() as outer: + with pytest.raises( + ExpectedExceptionGroup(ValueError) + ), _core.CancelScope() as outer: async with _core.open_nursery() as nursery: gc.collect() gc.set_debug(gc.DEBUG_SAVEALL) @@ -2436,12 +2453,10 @@ async def main() -> NoReturn: async with _core.open_nursery(): raise Exception("foo") - with pytest.raises(MultiError) as exc: + with pytest.raises(ExpectedExceptionGroup(Exception("foo"))) as exc: _core.run(main, strict_exception_groups=True) - assert len(exc.value.exceptions) == 1 assert type(exc.value.exceptions[0]) is Exception - assert exc.value.exceptions[0].args == ("foo",) def test_run_strict_exception_groups_nursery_override() -> None: @@ -2460,14 +2475,10 @@ async def main() -> NoReturn: async def test_nursery_strict_exception_groups() -> None: """Test that strict exception groups can be enabled on a per-nursery basis.""" - with pytest.raises(MultiError) as exc: + with pytest.raises(ExpectedExceptionGroup(Exception("foo"))): async with _core.open_nursery(strict_exception_groups=True): raise Exception("foo") - assert len(exc.value.exceptions) == 1 - assert type(exc.value.exceptions[0]) is Exception - assert exc.value.exceptions[0].args == ("foo",) - async def test_nursery_collapse_strict() -> None: """ @@ -2478,7 +2489,9 @@ async def test_nursery_collapse_strict() -> None: async def raise_error() -> NoReturn: raise RuntimeError("test error") - with pytest.raises(MultiError) as exc: + with pytest.raises( + ExpectedExceptionGroup((RuntimeError, ExpectedExceptionGroup(RuntimeError))) + ): async with _core.open_nursery() as nursery: nursery.start_soon(sleep_forever) nursery.start_soon(raise_error) @@ -2487,13 +2500,6 @@ async def raise_error() -> NoReturn: nursery2.start_soon(raise_error) nursery.cancel_scope.cancel() - exceptions = exc.value.exceptions - assert len(exceptions) == 2 - assert isinstance(exceptions[0], RuntimeError) - assert isinstance(exceptions[1], MultiError) - assert len(exceptions[1].exceptions) == 1 - assert isinstance(exceptions[1].exceptions[0], RuntimeError) - async def test_nursery_collapse_loose() -> None: """ @@ -2504,20 +2510,15 @@ async def test_nursery_collapse_loose() -> None: async def raise_error() -> NoReturn: raise RuntimeError("test error") - with pytest.raises(MultiError) as exc: - async with _core.open_nursery() as nursery: + with pytest.raises(ExpectedExceptionGroup(RuntimeError, RuntimeError)): + async with _core.open_nursery(strict_exception_groups=False) as nursery: nursery.start_soon(sleep_forever) nursery.start_soon(raise_error) - async with _core.open_nursery() as nursery2: + async with _core.open_nursery(strict_exception_groups=False) as nursery2: nursery2.start_soon(sleep_forever) nursery2.start_soon(raise_error) nursery.cancel_scope.cancel() - exceptions = exc.value.exceptions - assert len(exceptions) == 2 - assert isinstance(exceptions[0], RuntimeError) - assert isinstance(exceptions[1], RuntimeError) - async def test_cancel_scope_no_cancellederror() -> None: """ diff --git a/src/trio/_tests/test_highlevel_serve_listeners.py b/src/trio/_tests/test_highlevel_serve_listeners.py index 5fd8ac72e3..1053eddb20 100644 --- a/src/trio/_tests/test_highlevel_serve_listeners.py +++ b/src/trio/_tests/test_highlevel_serve_listeners.py @@ -10,6 +10,7 @@ import trio from trio import Nursery, StapledStream, TaskStatus from trio.testing import ( + ExpectedExceptionGroup, MemoryReceiveStream, MemorySendStream, MockClock, @@ -112,9 +113,10 @@ async def raise_error() -> NoReturn: listener.accept_hook = raise_error - with pytest.raises(type(error)) as excinfo: + with pytest.raises(ExpectedExceptionGroup(type(error))) as excinfo: await trio.serve_listeners(None, [listener]) # type: ignore[arg-type] - assert excinfo.value is error + + assert excinfo.value.exceptions[0] is error async def test_serve_listeners_accept_capacity_error( @@ -158,7 +160,8 @@ async def connection_watcher( assert len(nursery.child_tasks) == 10 raise Done - with pytest.raises(Done): + # the exception is wrapped twice because we open two nested nurseries + with pytest.raises(ExpectedExceptionGroup(ExpectedExceptionGroup(Done))): async with trio.open_nursery() as nursery: handler_nursery: trio.Nursery = await nursery.start(connection_watcher) await nursery.start( diff --git a/src/trio/_tests/test_signals.py b/src/trio/_tests/test_signals.py index 6caadc3f8b..f694894476 100644 --- a/src/trio/_tests/test_signals.py +++ b/src/trio/_tests/test_signals.py @@ -6,6 +6,7 @@ import pytest import trio +from trio.testing import ExpectedExceptionGroup from .. import _core from .._signals import _signal_handler, get_pending_signal_count, open_signal_receiver @@ -72,7 +73,7 @@ async def naughty() -> None: async def test_open_signal_receiver_conflict() -> None: - with pytest.raises(trio.BusyResourceError): + with pytest.raises(ExpectedExceptionGroup(trio.BusyResourceError)): with open_signal_receiver(signal.SIGILL) as receiver: async with trio.open_nursery() as nursery: nursery.start_soon(receiver.__anext__) diff --git a/src/trio/_tests/test_ssl.py b/src/trio/_tests/test_ssl.py index 94e0356f06..9dacf2d5c7 100644 --- a/src/trio/_tests/test_ssl.py +++ b/src/trio/_tests/test_ssl.py @@ -15,7 +15,7 @@ from trio import StapledStream from trio._tests.pytest_plugin import skip_if_optional_else_raise from trio.abc import ReceiveStream, SendStream -from trio.testing import MemoryReceiveStream, MemorySendStream +from trio.testing import ExpectedExceptionGroup, MemoryReceiveStream, MemorySendStream try: import trustme @@ -345,32 +345,28 @@ async def test_PyOpenSSLEchoStream_gives_resource_busy_errors() -> None: # PyOpenSSLEchoStream will notice and complain. s = PyOpenSSLEchoStream() - with pytest.raises(_core.BusyResourceError) as excinfo: + with pytest.raises(ExpectedExceptionGroup(_core.BusyResourceError("simultaneous"))): async with _core.open_nursery() as nursery: nursery.start_soon(s.send_all, b"x") nursery.start_soon(s.send_all, b"x") - assert "simultaneous" in str(excinfo.value) s = PyOpenSSLEchoStream() - with pytest.raises(_core.BusyResourceError) as excinfo: + with pytest.raises(ExpectedExceptionGroup(_core.BusyResourceError("simultaneous"))): async with _core.open_nursery() as nursery: nursery.start_soon(s.send_all, b"x") nursery.start_soon(s.wait_send_all_might_not_block) - assert "simultaneous" in str(excinfo.value) s = PyOpenSSLEchoStream() - with pytest.raises(_core.BusyResourceError) as excinfo: + with pytest.raises(ExpectedExceptionGroup(_core.BusyResourceError("simultaneous"))): async with _core.open_nursery() as nursery: nursery.start_soon(s.wait_send_all_might_not_block) nursery.start_soon(s.wait_send_all_might_not_block) - assert "simultaneous" in str(excinfo.value) s = PyOpenSSLEchoStream() - with pytest.raises(_core.BusyResourceError) as excinfo: + with pytest.raises(ExpectedExceptionGroup(_core.BusyResourceError("simultaneous"))): async with _core.open_nursery() as nursery: nursery.start_soon(s.receive_some, 1) nursery.start_soon(s.receive_some, 1) - assert "simultaneous" in str(excinfo.value) @contextmanager # type: ignore[misc] # decorated contains Any @@ -740,32 +736,28 @@ async def do_wait_send_all_might_not_block() -> None: await s.wait_send_all_might_not_block() s, _ = ssl_lockstep_stream_pair(client_ctx) - with pytest.raises(_core.BusyResourceError) as excinfo: + with pytest.raises(ExpectedExceptionGroup(_core.BusyResourceError("another task"))): async with _core.open_nursery() as nursery: nursery.start_soon(do_send_all) nursery.start_soon(do_send_all) - assert "another task" in str(excinfo.value) s, _ = ssl_lockstep_stream_pair(client_ctx) - with pytest.raises(_core.BusyResourceError) as excinfo: + with pytest.raises(ExpectedExceptionGroup(_core.BusyResourceError("another task"))): async with _core.open_nursery() as nursery: nursery.start_soon(do_receive_some) nursery.start_soon(do_receive_some) - assert "another task" in str(excinfo.value) s, _ = ssl_lockstep_stream_pair(client_ctx) - with pytest.raises(_core.BusyResourceError) as excinfo: + with pytest.raises(ExpectedExceptionGroup(_core.BusyResourceError("another task"))): async with _core.open_nursery() as nursery: nursery.start_soon(do_send_all) nursery.start_soon(do_wait_send_all_might_not_block) - assert "another task" in str(excinfo.value) s, _ = ssl_lockstep_stream_pair(client_ctx) - with pytest.raises(_core.BusyResourceError) as excinfo: + with pytest.raises(ExpectedExceptionGroup(_core.BusyResourceError("another task"))): async with _core.open_nursery() as nursery: nursery.start_soon(do_wait_send_all_might_not_block) nursery.start_soon(do_wait_send_all_might_not_block) - assert "another task" in str(excinfo.value) async def test_wait_writable_calls_underlying_wait_writable() -> None: diff --git a/src/trio/_tests/test_subprocess.py b/src/trio/_tests/test_subprocess.py index 93f3d3ac53..5c83573581 100644 --- a/src/trio/_tests/test_subprocess.py +++ b/src/trio/_tests/test_subprocess.py @@ -21,6 +21,8 @@ import pytest from pytest import MonkeyPatch, WarningsRecorder +from trio.testing import ExpectedExceptionGroup + from .. import ( ClosedResourceError, Event, @@ -607,7 +609,7 @@ async def test_warn_on_cancel_SIGKILL_escalation( # the background_process_param exercises a lot of run_process cases, but it uses # check=False, so lets have a test that uses check=True as well async def test_run_process_background_fail() -> None: - with pytest.raises(subprocess.CalledProcessError): + with pytest.raises(ExpectedExceptionGroup(subprocess.CalledProcessError)): async with _core.open_nursery() as nursery: proc: subprocess.CompletedProcess[bytes] = await nursery.start( run_process, EXIT_FALSE @@ -624,11 +626,12 @@ async def test_for_leaking_fds() -> None: await run_process(EXIT_TRUE) assert set(SyncPath("/dev/fd").iterdir()) == starting_fds + # TODO: one of these raises an exceptiongroup, one doesn't. That's bad with pytest.raises(subprocess.CalledProcessError): await run_process(EXIT_FALSE) assert set(SyncPath("/dev/fd").iterdir()) == starting_fds - with pytest.raises(PermissionError): + with pytest.raises(ExpectedExceptionGroup(PermissionError)): await run_process(["/dev/fd/0"]) assert set(SyncPath("/dev/fd").iterdir()) == starting_fds diff --git a/src/trio/_tests/test_testing.py b/src/trio/_tests/test_testing.py index f6e137c55e..1d787f5f49 100644 --- a/src/trio/_tests/test_testing.py +++ b/src/trio/_tests/test_testing.py @@ -7,6 +7,8 @@ import pytest from pytest import WarningsRecorder +from trio.testing import ExpectedExceptionGroup + from .. import _core, sleep, socket as tsocket from .._core._tests.tutil import can_bind_ipv6 from .._highlevel_generic import StapledStream, aclose_forcefully @@ -293,7 +295,7 @@ async def getter(expect: bytes) -> None: nursery.start_soon(putter, b"xyz") # Two gets at the same time -> BusyResourceError - with pytest.raises(_core.BusyResourceError): + with pytest.raises(ExpectedExceptionGroup(_core.BusyResourceError)): async with _core.open_nursery() as nursery: nursery.start_soon(getter, b"asdf") nursery.start_soon(getter, b"asdf") @@ -427,7 +429,7 @@ async def do_receive_some(max_bytes: int | None) -> bytes: mrs.put_data(b"abc") assert await do_receive_some(None) == b"abc" - with pytest.raises(_core.BusyResourceError): + with pytest.raises(ExpectedExceptionGroup(_core.BusyResourceError)): async with _core.open_nursery() as nursery: nursery.start_soon(do_receive_some, 10) nursery.start_soon(do_receive_some, 10) diff --git a/src/trio/_tests/test_util.py b/src/trio/_tests/test_util.py index 40c2fd11bb..17381dcaf4 100644 --- a/src/trio/_tests/test_util.py +++ b/src/trio/_tests/test_util.py @@ -6,6 +6,7 @@ import pytest import trio +from trio.testing import ExpectedExceptionGroup from .. import _core from .._core._tests.tutil import ( @@ -49,21 +50,19 @@ async def test_ConflictDetector() -> None: with ul2: print("ok") - with pytest.raises(_core.BusyResourceError) as excinfo: + with pytest.raises(_core.BusyResourceError, match="ul1"): with ul1: with ul1: pass # pragma: no cover - assert "ul1" in str(excinfo.value) async def wait_with_ul1() -> None: with ul1: await wait_all_tasks_blocked() - with pytest.raises(_core.BusyResourceError) as excinfo: + with pytest.raises(ExpectedExceptionGroup(_core.BusyResourceError("ul1"))): async with _core.open_nursery() as nursery: nursery.start_soon(wait_with_ul1) nursery.start_soon(wait_with_ul1) - assert "ul1" in str(excinfo.value) def test_module_metadata_is_fixed_up() -> None: diff --git a/src/trio/testing/__init__.py b/src/trio/testing/__init__.py index fa683e1145..d02b6fae19 100644 --- a/src/trio/testing/__init__.py +++ b/src/trio/testing/__init__.py @@ -1,5 +1,7 @@ # Uses `from x import y as y` for compatibility with `pyright --verifytypes` (#2625) +import pytest + from .._core import ( MockClock as MockClock, wait_all_tasks_blocked as wait_all_tasks_blocked, @@ -14,6 +16,7 @@ assert_checkpoints as assert_checkpoints, assert_no_checkpoints as assert_no_checkpoints, ) +from ._exceptiongroup_util import ExpectedExceptionGroup, raises from ._memory_streams import ( MemoryReceiveStream as MemoryReceiveStream, MemorySendStream as MemorySendStream, @@ -29,6 +32,8 @@ ################################################################ +# Highly illegally override `raises` for ours that has ExpectedExceptionGroup support +pytest.raises = raises # type: ignore fixup_module_metadata(__name__, globals()) del fixup_module_metadata diff --git a/src/trio/testing/_check_streams.py b/src/trio/testing/_check_streams.py index 0b9c904275..22e5a427e7 100644 --- a/src/trio/testing/_check_streams.py +++ b/src/trio/testing/_check_streams.py @@ -2,16 +2,19 @@ from __future__ import annotations import random -from contextlib import contextmanager, suppress +from contextlib import suppress from typing import TYPE_CHECKING, Awaitable, Callable, Generic, Tuple, TypeVar +import pytest + +from trio.testing._exceptiongroup_util import ExpectedExceptionGroup + from .. import CancelScope, _core from .._abc import AsyncResource, HalfCloseableStream, ReceiveStream, SendStream, Stream from .._highlevel_generic import aclose_forcefully from ._checkpoints import assert_checkpoints if TYPE_CHECKING: - from collections.abc import Generator from types import TracebackType from typing_extensions import ParamSpec, TypeAlias @@ -42,17 +45,6 @@ async def __aexit__( await aclose_forcefully(self._second) -@contextmanager -def _assert_raises(exc: type[BaseException]) -> Generator[None, None, None]: - __tracebackhide__ = True - try: - yield - except exc: - pass - else: - raise AssertionError(f"expected exception: {exc}") - - async def check_one_way_stream( stream_maker: StreamMaker[SendStream, ReceiveStream], clogged_stream_maker: StreamMaker[SendStream, ReceiveStream] | None, @@ -121,11 +113,11 @@ async def send_empty_then_y() -> None: nursery.start_soon(checked_receive_1, b"2") # max_bytes must be a positive integer - with _assert_raises(ValueError): + with pytest.raises(ValueError): await r.receive_some(-1) - with _assert_raises(ValueError): + with pytest.raises(ValueError): await r.receive_some(0) - with _assert_raises(TypeError): + with pytest.raises(TypeError): await r.receive_some(1.5) # type: ignore[arg-type] # it can also be missing or None async with _core.open_nursery() as nursery: @@ -135,7 +127,7 @@ async def send_empty_then_y() -> None: nursery.start_soon(do_send_all, b"x") assert await do_receive_some(None) == b"x" - with _assert_raises(_core.BusyResourceError): + with pytest.raises(ExpectedExceptionGroup(_core.BusyResourceError)): async with _core.open_nursery() as nursery: nursery.start_soon(do_receive_some, 1) nursery.start_soon(do_receive_some, 1) @@ -161,7 +153,7 @@ async def simple_check_wait_send_all_might_not_block( # closing the r side leads to BrokenResourceError on the s side # (eventually) async def expect_broken_stream_on_send() -> None: - with _assert_raises(_core.BrokenResourceError): + with pytest.raises(_core.BrokenResourceError): while True: await do_send_all(b"x" * 100) @@ -170,11 +162,11 @@ async def expect_broken_stream_on_send() -> None: nursery.start_soon(do_aclose, r) # once detected, the stream stays broken - with _assert_raises(_core.BrokenResourceError): + with pytest.raises(_core.BrokenResourceError): await do_send_all(b"x" * 100) # r closed -> ClosedResourceError on the receive side - with _assert_raises(_core.ClosedResourceError): + with pytest.raises(_core.ClosedResourceError): await do_receive_some(4096) # we can close the same stream repeatedly, it's fine @@ -185,15 +177,15 @@ async def expect_broken_stream_on_send() -> None: await do_aclose(s) # now trying to send raises ClosedResourceError - with _assert_raises(_core.ClosedResourceError): + with pytest.raises(_core.ClosedResourceError): await do_send_all(b"x" * 100) # even if it's an empty send - with _assert_raises(_core.ClosedResourceError): + with pytest.raises(_core.ClosedResourceError): await do_send_all(b"") # ditto for wait_send_all_might_not_block - with _assert_raises(_core.ClosedResourceError): + with pytest.raises(_core.ClosedResourceError): with assert_checkpoints(): await s.wait_send_all_might_not_block() @@ -224,17 +216,17 @@ async def receive_send_then_close() -> None: async with _ForceCloseBoth(await stream_maker()) as (s, r): await aclose_forcefully(r) - with _assert_raises(_core.BrokenResourceError): + with pytest.raises(_core.BrokenResourceError): while True: await do_send_all(b"x" * 100) - with _assert_raises(_core.ClosedResourceError): + with pytest.raises(_core.ClosedResourceError): await do_receive_some(4096) async with _ForceCloseBoth(await stream_maker()) as (s, r): await aclose_forcefully(s) - with _assert_raises(_core.ClosedResourceError): + with pytest.raises(_core.ClosedResourceError): await do_send_all(b"123") # after the sender does a forceful close, the receiver might either @@ -253,10 +245,10 @@ async def receive_send_then_close() -> None: scope.cancel() await s.aclose() - with _assert_raises(_core.ClosedResourceError): + with pytest.raises(_core.ClosedResourceError): await do_send_all(b"123") - with _assert_raises(_core.ClosedResourceError): + with pytest.raises(_core.ClosedResourceError): await do_receive_some(4096) # Check that we can still gracefully close a stream after an operation has @@ -275,7 +267,7 @@ async def expect_cancelled( *args: ArgsT.args, **kwargs: ArgsT.kwargs, ) -> None: - with _assert_raises(_core.Cancelled): + with pytest.raises(_core.Cancelled): await afn(*args, **kwargs) with _core.CancelScope() as scope: @@ -293,7 +285,7 @@ async def expect_cancelled( async with _ForceCloseBoth(await stream_maker()) as (s, r): async def receive_expecting_closed(): - with _assert_raises(_core.ClosedResourceError): + with pytest.raises(_core.ClosedResourceError): await r.receive_some(10) async with _core.open_nursery() as nursery: @@ -333,7 +325,7 @@ async def receiver() -> None: async with _ForceCloseBoth(await clogged_stream_maker()) as (s, r): # simultaneous wait_send_all_might_not_block fails - with _assert_raises(_core.BusyResourceError): + with pytest.raises(ExpectedExceptionGroup(_core.BusyResourceError)): async with _core.open_nursery() as nursery: nursery.start_soon(s.wait_send_all_might_not_block) nursery.start_soon(s.wait_send_all_might_not_block) @@ -342,7 +334,7 @@ async def receiver() -> None: # this test might destroy the stream b/c we end up cancelling # send_all and e.g. SSLStream can't handle that, so we have to # recreate afterwards) - with _assert_raises(_core.BusyResourceError): + with pytest.raises(ExpectedExceptionGroup(_core.BusyResourceError)): async with _core.open_nursery() as nursery: nursery.start_soon(s.wait_send_all_might_not_block) nursery.start_soon(s.send_all, b"123") @@ -350,7 +342,7 @@ async def receiver() -> None: async with _ForceCloseBoth(await clogged_stream_maker()) as (s, r): # send_all and send_all blocked simultaneously should also raise # (but again this might destroy the stream) - with _assert_raises(_core.BusyResourceError): + with pytest.raises(ExpectedExceptionGroup(_core.BusyResourceError)): async with _core.open_nursery() as nursery: nursery.start_soon(s.send_all, b"123") nursery.start_soon(s.send_all, b"123") @@ -392,13 +384,13 @@ async def close_soon(s: SendStream) -> None: async with _ForceCloseBoth(await clogged_stream_maker()) as (s, r): async with _core.open_nursery() as nursery: nursery.start_soon(close_soon, s) - with _assert_raises(_core.ClosedResourceError): + with pytest.raises(_core.ClosedResourceError): await s.send_all(b"xyzzy") async with _ForceCloseBoth(await clogged_stream_maker()) as (s, r): async with _core.open_nursery() as nursery: nursery.start_soon(close_soon, s) - with _assert_raises(_core.ClosedResourceError): + with pytest.raises(_core.ClosedResourceError): await s.wait_send_all_might_not_block() @@ -517,7 +509,7 @@ async def expect_x_then_eof(r: HalfCloseableStream) -> None: nursery.start_soon(expect_x_then_eof, s2) # now sending is disallowed - with _assert_raises(_core.ClosedResourceError): + with pytest.raises(_core.ClosedResourceError): await s1.send_all(b"y") # but we can do send_eof again @@ -532,7 +524,7 @@ async def expect_x_then_eof(r: HalfCloseableStream) -> None: if clogged_stream_maker is not None: async with _ForceCloseBoth(await clogged_stream_maker()) as (s1, s2): # send_all and send_eof simultaneously is not ok - with _assert_raises(_core.BusyResourceError): + with pytest.raises(ExpectedExceptionGroup(_core.BusyResourceError)): async with _core.open_nursery() as nursery: nursery.start_soon(s1.send_all, b"x") await _core.wait_all_tasks_blocked() @@ -541,7 +533,7 @@ async def expect_x_then_eof(r: HalfCloseableStream) -> None: async with _ForceCloseBoth(await clogged_stream_maker()) as (s1, s2): # wait_send_all_might_not_block and send_eof simultaneously is not # ok either - with _assert_raises(_core.BusyResourceError): + with pytest.raises(ExpectedExceptionGroup(_core.BusyResourceError)): async with _core.open_nursery() as nursery: nursery.start_soon(s1.wait_send_all_might_not_block) await _core.wait_all_tasks_blocked() diff --git a/src/trio/testing/_exceptiongroup_util.py b/src/trio/testing/_exceptiongroup_util.py new file mode 100644 index 0000000000..276dde22fc --- /dev/null +++ b/src/trio/testing/_exceptiongroup_util.py @@ -0,0 +1,385 @@ +from __future__ import annotations + +import re +import sys +from types import TracebackType +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Generic, + Optional, + Pattern, + Tuple, + Type, + TypeVar, + Union, + cast, + final, + overload, +) + +import _pytest +import _pytest._code + +if TYPE_CHECKING: + from typing_extensions import TypeAlias +if sys.version_info < (3, 11): + from exceptiongroup import BaseExceptionGroup + +E = TypeVar("E", bound=BaseException) +E2 = TypeVar("E2", bound=BaseException) +T = TypeVar("T") +EE: TypeAlias = Union[Type[E], "ExpectedExceptionGroup[E]"] +EEE: TypeAlias = Union[E, Type[E], "ExpectedExceptionGroup[E]"] + + +# copy-pasting code between here and WIP pytest PR, which doesn't use future annotations +# ruff: noqa: UP007 + + +# inherit from BaseExceptionGroup for the sake of typing the return type of raises. +# Maybe it also works to do +# `if TYPE_CHECKING: ExpectedExceptionGroup = BaseExceptionGroup` +# though then we would probably need to support `raises(ExceptionGroup(...))` +# class ExpectedExceptionGroup(BaseExceptionGroup[E]): +@final +class ExpectedExceptionGroup(Generic[E]): + # one might consider also accepting `ExpectedExceptionGroup(SyntaxError, ValueError)` + @overload + def __init__(self, exceptions: EEE[E], *args: EEE[E]): + ... + + @overload + def __init__(self, exceptions: tuple[EEE[E], ...]): + ... + + def __init__(self, exceptions: EEE[E] | tuple[EEE[E], ...], *args: EEE[E]): + if isinstance(exceptions, tuple): + if args: + raise ValueError( + "All arguments must be exceptions if passing multiple positional arguments." + ) + self.expected_exceptions = exceptions + else: + self.expected_exceptions = (exceptions, *args) + if not all( + isinstance(exc, (type, BaseException, ExpectedExceptionGroup)) + for exc in self.expected_exceptions + ): + raise ValueError( + "All arguments must be exception instances, types, or ExpectedExceptionGroup." + ) + + # TODO: TypeGuard + def matches( + self, + exc_val: Optional[BaseException], + ) -> bool: + if exc_val is None: + return False + if not isinstance(exc_val, BaseExceptionGroup): + return False + if len(exc_val.exceptions) != len(self.expected_exceptions): + return False + remaining_exceptions = list(self.expected_exceptions) + for e in exc_val.exceptions: + for rem_e in remaining_exceptions: + # TODO: how to print string diff on mismatch? + # Probably accumulate them, and then if fail, print them + if ( + (isinstance(rem_e, type) and isinstance(e, rem_e)) + or ( + isinstance(e, BaseExceptionGroup) + and isinstance(rem_e, ExpectedExceptionGroup) + and rem_e.matches(e) + ) + or ( + isinstance(rem_e, BaseException) + and isinstance(e, type(rem_e)) + and re.search(str(rem_e), str(e)) + ) + ): + remaining_exceptions.remove(rem_e) # type: ignore # ?? + break + else: + return False + return True + + # def __str__(self) -> str: + # return f"ExceptionGroup{self.expected_exceptions}" + # str(tuple(...)) seems to call repr + def __repr__(self) -> str: + # TODO: [Base]ExceptionGroup + return f"ExceptionGroup{self.expected_exceptions}" + + +@overload +def raises( + expected_exception: Union[type[E], tuple[type[E], ...]], + *, + match: Optional[Union[str, Pattern[str]]] = ..., +) -> RaisesContext[E]: + ... + + +@overload +def raises( + expected_exception: Union[ + ExpectedExceptionGroup[E], tuple[ExpectedExceptionGroup[E], ...] + ], + *, + match: Optional[Union[str, Pattern[str]]] = ..., +) -> RaisesContext[BaseExceptionGroup[E]]: + ... + + +@overload +def raises( + expected_exception: tuple[Union[type[E], ExpectedExceptionGroup[E2]], ...], + *, + match: Optional[Union[str, Pattern[str]]] = ..., +) -> RaisesContext[Union[E, BaseExceptionGroup[E2]]]: + ... + + +@overload +def raises( # type: ignore[misc] + expected_exception: Union[type[E], tuple[type[E], ...]], + func: Callable[..., Any], + *args: Any, + **kwargs: Any, +) -> _pytest._code.ExceptionInfo[E]: + ... + + +def raises( + expected_exception: Union[ + type[E], + ExpectedExceptionGroup[E2], + tuple[Union[type[E], ExpectedExceptionGroup[E2]], ...], + ], + *args: Any, + **kwargs: Any, +) -> Union[ + RaisesContext[E], + RaisesContext[BaseExceptionGroup[E2]], + RaisesContext[Union[E, BaseExceptionGroup[E2]]], + _pytest._code.ExceptionInfo[E], +]: + r"""Assert that a code block/function call raises ``expected_exception`` + or raise a failure exception otherwise. + + :kwparam match: + If specified, a string containing a regular expression, + or a regular expression object, that is tested against the string + representation of the exception using :py:func:`re.search`. To match a literal + string that may contain :std:ref:`special characters `, the pattern can + first be escaped with :py:func:`re.escape`. + + (This is only used when :py:func:`pytest.raises` is used as a context manager, + and passed through to the function otherwise. + When using :py:func:`pytest.raises` as a function, you can use: + ``pytest.raises(Exc, func, match="passed on").match("my pattern")``.) + + .. currentmodule:: _pytest._code + + Use ``pytest.raises`` as a context manager, which will capture the exception of the given + type:: + + >>> import pytest + >>> with pytest.raises(ZeroDivisionError): + ... 1/0 + + If the code block does not raise the expected exception (``ZeroDivisionError`` in the example + above), or no exception at all, the check will fail instead. + + You can also use the keyword argument ``match`` to assert that the + exception matches a text or regex:: + + >>> with pytest.raises(ValueError, match='must be 0 or None'): + ... raise ValueError("value must be 0 or None") + + >>> with pytest.raises(ValueError, match=r'must be \d+$'): + ... raise ValueError("value must be 42") + + The context manager produces an :class:`ExceptionInfo` object which can be used to inspect the + details of the captured exception:: + + >>> with pytest.raises(ValueError) as exc_info: + ... raise ValueError("value must be 42") + >>> assert exc_info.type is ValueError + >>> assert exc_info.value.args[0] == "value must be 42" + + .. note:: + + When using ``pytest.raises`` as a context manager, it's worthwhile to + note that normal context manager rules apply and that the exception + raised *must* be the final line in the scope of the context manager. + Lines of code after that, within the scope of the context manager will + not be executed. For example:: + + >>> value = 15 + >>> with pytest.raises(ValueError) as exc_info: + ... if value > 10: + ... raise ValueError("value must be <= 10") + ... assert exc_info.type is ValueError # this will not execute + + Instead, the following approach must be taken (note the difference in + scope):: + + >>> with pytest.raises(ValueError) as exc_info: + ... if value > 10: + ... raise ValueError("value must be <= 10") + ... + >>> assert exc_info.type is ValueError + + **Using with** ``pytest.mark.parametrize`` + + When using :ref:`pytest.mark.parametrize ref` + it is possible to parametrize tests such that + some runs raise an exception and others do not. + + See :ref:`parametrizing_conditional_raising` for an example. + + **Legacy form** + + It is possible to specify a callable by passing a to-be-called lambda:: + + >>> raises(ZeroDivisionError, lambda: 1/0) + + + or you can specify an arbitrary callable with arguments:: + + >>> def f(x): return 1/x + ... + >>> raises(ZeroDivisionError, f, 0) + + >>> raises(ZeroDivisionError, f, x=0) + + + The form above is fully supported but discouraged for new code because the + context manager form is regarded as more readable and less error-prone. + + .. note:: + Similar to caught exception objects in Python, explicitly clearing + local references to returned ``ExceptionInfo`` objects can + help the Python interpreter speed up its garbage collection. + + Clearing those references breaks a reference cycle + (``ExceptionInfo`` --> caught exception --> frame stack raising + the exception --> current frame stack --> local variables --> + ``ExceptionInfo``) which makes Python keep all objects referenced + from that cycle (including all local variables in the current + frame) alive until the next cyclic garbage collection run. + More detailed information can be found in the official Python + documentation for :ref:`the try statement `. + """ + __tracebackhide__ = True + + if not expected_exception: + raise ValueError( + f"Expected an exception type or a tuple of exception types, but got `{expected_exception!r}`. " + f"Raising exceptions is already understood as failing the test, so you don't need " + f"any special code to say 'this should never raise an exception'." + ) + + if isinstance(expected_exception, (type, ExpectedExceptionGroup)): + expected_exception_tuple: tuple[ + Union[type[E], ExpectedExceptionGroup[E2]], ... + ] = (expected_exception,) + else: + expected_exception_tuple = expected_exception + for exc in expected_exception_tuple: + if ( + not isinstance(exc, type) or not issubclass(exc, BaseException) + ) and not isinstance(exc, ExpectedExceptionGroup): + msg = "expected exception must be a BaseException type or ExpectedExceptionGroup instance, not {}" # type: ignore[unreachable] + not_a = exc.__name__ if isinstance(exc, type) else type(exc).__name__ + raise TypeError(msg.format(not_a)) + + message = f"DID NOT RAISE {expected_exception}" + + if not args: + match: Optional[Union[str, Pattern[str]]] = kwargs.pop("match", None) + if kwargs: + msg = "Unexpected keyword arguments passed to pytest.raises: " + msg += ", ".join(sorted(kwargs)) + msg += "\nUse context-manager form instead?" + raise TypeError(msg) + # the ExpectedExceptionGroup -> BaseExceptionGroup swap necessitates an ignore + return RaisesContext(expected_exception, message, match) # type: ignore[misc] + else: + func = args[0] + + for exc in expected_exception_tuple: + if isinstance(exc, ExpectedExceptionGroup): + raise TypeError( + "Only contextmanager form is supported for ExpectedExceptionGroup" + ) + + if not callable(func): + raise TypeError(f"{func!r} object (type: {type(func)}) must be callable") + try: + func(*args[1:], **kwargs) + except expected_exception as e: # type: ignore[misc] # TypeError raised for any ExpectedExceptionGroup + # We just caught the exception - there is a traceback. + assert e.__traceback__ is not None + return _pytest._code.ExceptionInfo.from_exc_info( + (type(e), e, e.__traceback__) + ) + raise AssertionError(message) + + +@final +class RaisesContext(Generic[E]): + def __init__( + self, + expected_exception: Union[EE[E], tuple[EE[E], ...]], + message: str, + match_expr: Optional[Union[str, Pattern[str]]] = None, + ) -> None: + self.expected_exception = expected_exception + self.message = message + self.match_expr = match_expr + self.excinfo: Optional[_pytest._code.ExceptionInfo[E]] = None + + def __enter__(self) -> _pytest._code.ExceptionInfo[E]: + self.excinfo = _pytest._code.ExceptionInfo.for_later() + return self.excinfo + + def __exit__( + self, + exc_type: Optional[type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> bool: + __tracebackhide__ = True + if exc_type is None: + raise AssertionError(self.message) + assert self.excinfo is not None + + if isinstance(self.expected_exception, ExpectedExceptionGroup): + if not self.expected_exception.matches(exc_val): + return False + elif isinstance(self.expected_exception, tuple): + for expected_exc in self.expected_exception: + if ( + isinstance(expected_exc, ExpectedExceptionGroup) + and expected_exc.matches(exc_val) + ) or ( + isinstance(expected_exc, type) + and issubclass(exc_type, expected_exc) + ): + break + else: + return False + elif not issubclass(exc_type, self.expected_exception): + return False + + # Cast to narrow the exception type now that it's verified. + exc_info = cast(Tuple[Type[E], E, TracebackType], (exc_type, exc_val, exc_tb)) + self.excinfo.fill_unfilled(exc_info) + if self.match_expr is not None: + self.excinfo.match(self.match_expr) + return True