Skip to content

Commit

Permalink
Add RaisesGroup, a helper for catching ExceptionGroups in tests (#2898)
Browse files Browse the repository at this point in the history
* Add RaisesGroup, a helper for catching ExceptionGroups in tests
* Added helpers: Matcher and _ExceptionInfo
* Tests and type tests for all of the above
* Rewrite several existing tests to use this helper

---------

Co-authored-by: CoolCat467 <52022020+CoolCat467@users.noreply.github.com>
Co-authored-by: Spencer Brown <spencerb21@live.com>
Co-authored-by: EXPLOSION <git@helvetica.moe>
  • Loading branch information
4 people authored Jan 7, 2024
1 parent ec011f4 commit aadd1ea
Show file tree
Hide file tree
Showing 9 changed files with 962 additions and 101 deletions.
13 changes: 13 additions & 0 deletions docs/source/reference-testing.rst
Original file line number Diff line number Diff line change
Expand Up @@ -219,3 +219,16 @@ Testing checkpoints

.. autofunction:: assert_no_checkpoints
:with:


ExceptionGroup helpers
----------------------

.. autoclass:: RaisesGroup
:members:

.. autoclass:: Matcher
:members:

.. autoclass:: trio.testing._raises_group._ExceptionInfo
:members:
4 changes: 4 additions & 0 deletions newsfragments/2785.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
New helper classes: :class:`~.testing.RaisesGroup` and :class:`~.testing.Matcher`.

In preparation for changing the default of ``strict_exception_groups`` to `True`, we're introducing a set of helper classes that can be used in place of `pytest.raises <https://docs.pytest.org/en/stable/reference/reference.html#pytest.raises>`_ in tests, to check for an expected `ExceptionGroup`.
These are provisional, and only planned to be supplied until there's a good solution in ``pytest``. See https://github.com/pytest-dev/pytest/issues/11538
165 changes: 67 additions & 98 deletions src/trio/_core/_tests/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,13 @@
from ... import _core
from ..._threads import to_thread_run_sync
from ..._timeouts import fail_after, sleep
from ...testing import Sequencer, assert_checkpoints, wait_all_tasks_blocked
from ...testing import (
Matcher,
RaisesGroup,
Sequencer,
assert_checkpoints,
wait_all_tasks_blocked,
)
from .._run import DEADLINE_HEAP_MIN_PRUNE_THRESHOLD
from .tutil import (
check_sequence_matches,
Expand Down Expand Up @@ -192,13 +198,8 @@ async def main() -> NoReturn:
nursery.start_soon(crasher)
raise KeyError

with pytest.raises(ExceptionGroup) as excinfo:
with RaisesGroup(ValueError, KeyError):
_core.run(main)
print(excinfo.value)
assert {type(exc) for exc in excinfo.value.exceptions} == {
ValueError,
KeyError,
}


def test_two_child_crashes() -> None:
Expand All @@ -210,12 +211,8 @@ async def main() -> None:
nursery.start_soon(crasher, KeyError)
nursery.start_soon(crasher, ValueError)

with pytest.raises(ExceptionGroup) as excinfo:
with RaisesGroup(ValueError, KeyError):
_core.run(main)
assert {type(exc) for exc in excinfo.value.exceptions} == {
ValueError,
KeyError,
}


async def test_child_crash_wakes_parent() -> None:
Expand Down Expand Up @@ -429,16 +426,18 @@ async def test_cancel_scope_exceptiongroup_filtering() -> None:
async def crasher() -> NoReturn:
raise KeyError

# check that the inner except is properly executed.
# alternative would be to have a `except BaseException` and an `else`
exception_group_caught_inner = False

# This is outside the outer scope, so all the Cancelled
# exceptions should have been absorbed, leaving just a regular
# KeyError from crasher()
with pytest.raises(KeyError): # noqa: PT012
with _core.CancelScope() as outer:
try:
# Since the outer scope became cancelled before the
# nursery block exited, all cancellations inside the
# nursery block continue propagating to reach the
# outer scope.
with RaisesGroup(
_core.Cancelled, _core.Cancelled, _core.Cancelled, KeyError
) as excinfo:
async with _core.open_nursery() as nursery:
# Two children that get cancelled by the nursery scope
nursery.start_soon(sleep_forever) # t1
Expand All @@ -452,22 +451,9 @@ async def crasher() -> NoReturn:
# And one that raises a different error
nursery.start_soon(crasher) # t4
# and then our __aexit__ also receives an outer Cancelled
except BaseExceptionGroup as multi_exc:
exception_group_caught_inner = True
# Since the outer scope became cancelled before the
# nursery block exited, all cancellations inside the
# nursery block continue propagating to reach the
# outer scope.
# the noqa is for "Found assertion on exception `multi_exc` in `except` block"
assert len(multi_exc.exceptions) == 4 # noqa: PT017
summary: dict[type, int] = {}
for exc in multi_exc.exceptions:
summary.setdefault(type(exc), 0)
summary[type(exc)] += 1
assert summary == {_core.Cancelled: 3, KeyError: 1}
raise

assert exception_group_caught_inner
# reraise the exception caught by RaisesGroup for the
# CancelScope to handle
raise excinfo.value


async def test_precancelled_task() -> None:
Expand Down Expand Up @@ -788,14 +774,22 @@ async def task2() -> None:
RuntimeError, match="which had already been exited"
) as exc_info:
await nursery_mgr.__aexit__(*sys.exc_info())
assert type(exc_info.value.__context__) is ExceptionGroup
assert len(exc_info.value.__context__.exceptions) == 3
cancelled_in_context = False
for exc in exc_info.value.__context__.exceptions:
assert isinstance(exc, RuntimeError)
assert "closed before the task exited" in str(exc)
cancelled_in_context |= isinstance(exc.__context__, _core.Cancelled)
assert cancelled_in_context # for the sleep_forever

def no_context(exc: RuntimeError) -> bool:
return exc.__context__ is None

msg = "closed before the task exited"
group = RaisesGroup(
Matcher(RuntimeError, match=msg, check=no_context),
Matcher(RuntimeError, match=msg, check=no_context),
# sleep_forever
Matcher(
RuntimeError,
match=msg,
check=lambda x: isinstance(x.__context__, _core.Cancelled),
),
)
assert group.matches(exc_info.value.__context__)

# Trying to exit a cancel scope from an unrelated task raises an error
# without affecting any state
Expand Down Expand Up @@ -949,11 +943,7 @@ async def main() -> None:
with pytest.raises(_core.TrioInternalError) as excinfo:
_core.run(main)

me = excinfo.value.__cause__
assert isinstance(me, ExceptionGroup)
assert len(me.exceptions) == 2
for exc in me.exceptions:
assert isinstance(exc, (KeyError, ValueError))
assert RaisesGroup(KeyError, ValueError).matches(excinfo.value.__cause__)


def test_system_task_crash_plus_Cancelled() -> None:
Expand Down Expand Up @@ -1210,12 +1200,11 @@ async def test_nursery_exception_chaining_doesnt_make_context_loops() -> None:
async def crasher() -> NoReturn:
raise KeyError

with pytest.raises(ExceptionGroup) as excinfo: # noqa: PT012
# the ExceptionGroup should not have the KeyError or ValueError as context
with RaisesGroup(ValueError, KeyError, check=lambda x: x.__context__ is None):
async with _core.open_nursery() as nursery:
nursery.start_soon(crasher)
raise ValueError
# the ExceptionGroup should not have the KeyError or ValueError as context
assert excinfo.value.__context__ is None


def test_TrioToken_identity() -> None:
Expand Down Expand Up @@ -1980,11 +1969,10 @@ async def test_nursery_stop_iteration() -> None:
async def fail() -> NoReturn:
raise ValueError

with pytest.raises(ExceptionGroup) as excinfo: # noqa: PT012
with RaisesGroup(StopIteration, ValueError):
async with _core.open_nursery() as nursery:
nursery.start_soon(fail)
raise StopIteration
assert tuple(map(type, excinfo.value.exceptions)) == (StopIteration, ValueError)


async def test_nursery_stop_async_iteration() -> None:
Expand Down Expand Up @@ -2033,7 +2021,18 @@ async def test_traceback_frame_removal() -> None:
async def my_child_task() -> NoReturn:
raise KeyError()

with pytest.raises(ExceptionGroup) as excinfo: # noqa: PT012
def check_traceback(exc: KeyError) -> bool:
# The top frame in the exception traceback should be inside the child
# task, not trio/contextvars internals. And there's only one frame
# inside the child task, so this will also detect if our frame-removal
# is too eager.
tb = exc.__traceback__
assert tb is not None
return tb.tb_frame.f_code is my_child_task.__code__

expected_exception = Matcher(KeyError, check=check_traceback)

with RaisesGroup(expected_exception, expected_exception):
# Trick: For now cancel/nursery scopes still leave a bunch of tb gunk
# behind. But if there's an ExceptionGroup, they leave it on the group,
# which lets us get a clean look at the KeyError itself. Someday I
Expand All @@ -2042,15 +2041,6 @@ async def my_child_task() -> NoReturn:
async with _core.open_nursery() as nursery:
nursery.start_soon(my_child_task)
nursery.start_soon(my_child_task)
first_exc = excinfo.value.exceptions[0]
assert isinstance(first_exc, KeyError)
# The top frame in the exception traceback should be inside the child
# task, not trio/contextvars internals. And there's only one frame
# inside the child task, so this will also detect if our frame-removal
# is too eager.
tb = first_exc.__traceback__
assert tb is not None
assert tb.tb_frame.f_code is my_child_task.__code__


def test_contextvar_support() -> None:
Expand Down Expand Up @@ -2529,15 +2519,12 @@ async def main() -> NoReturn:
async with _core.open_nursery():
raise Exception("foo")

with pytest.raises(
ExceptionGroup, match="^Exceptions from Trio nursery \\(1 sub-exception\\)$"
) as exc:
with RaisesGroup(
Matcher(Exception, match="^foo$"),
match="^Exceptions from Trio nursery \\(1 sub-exception\\)$",
):
_core.run(main, strict_exception_groups=True)

assert len(exc.value.exceptions) == 1
assert type(exc.value.exceptions[0]) is Exception
assert exc.value.exceptions[0].args == ("foo",)


def test_run_strict_exception_groups_nursery_override() -> None:
"""
Expand All @@ -2555,14 +2542,10 @@ async def main() -> NoReturn:

async def test_nursery_strict_exception_groups() -> None:
"""Test that strict exception groups can be enabled on a per-nursery basis."""
with pytest.raises(ExceptionGroup) as exc:
with RaisesGroup(Matcher(Exception, match="^foo$")):
async with _core.open_nursery(strict_exception_groups=True):
raise Exception("foo")

assert len(exc.value.exceptions) == 1
assert type(exc.value.exceptions[0]) is Exception
assert exc.value.exceptions[0].args == ("foo",)


async def test_nursery_loose_exception_groups() -> None:
"""Test that loose exception groups can be enabled on a per-nursery basis."""
Expand All @@ -2573,20 +2556,18 @@ async def raise_error() -> NoReturn:
with pytest.raises(RuntimeError, match="^test error$"):
async with _core.open_nursery(strict_exception_groups=False) as nursery:
nursery.start_soon(raise_error)

with pytest.raises( # noqa: PT012 # multiple statements
ExceptionGroup, match="^Exceptions from Trio nursery \\(2 sub-exceptions\\)$"
) as exc:
m = Matcher(RuntimeError, match="^test error$")

with RaisesGroup(
m,
m,
match="Exceptions from Trio nursery \\(2 sub-exceptions\\)",
check=lambda x: x.__notes__ == [_core._run.NONSTRICT_EXCEPTIONGROUP_NOTE],
):
async with _core.open_nursery(strict_exception_groups=False) as nursery:
nursery.start_soon(raise_error)
nursery.start_soon(raise_error)

assert exc.value.__notes__ == [_core._run.NONSTRICT_EXCEPTIONGROUP_NOTE]
assert len(exc.value.exceptions) == 2
for subexc in exc.value.exceptions:
assert type(subexc) is RuntimeError
assert subexc.args == ("test error",)


async def test_nursery_collapse_strict() -> None:
"""
Expand All @@ -2597,7 +2578,7 @@ async def test_nursery_collapse_strict() -> None:
async def raise_error() -> NoReturn:
raise RuntimeError("test error")

with pytest.raises(ExceptionGroup) as exc: # noqa: PT012
with RaisesGroup(RuntimeError, RaisesGroup(RuntimeError)):
async with _core.open_nursery() as nursery:
nursery.start_soon(sleep_forever)
nursery.start_soon(raise_error)
Expand All @@ -2606,13 +2587,6 @@ async def raise_error() -> NoReturn:
nursery2.start_soon(raise_error)
nursery.cancel_scope.cancel()

exceptions = exc.value.exceptions
assert len(exceptions) == 2
assert isinstance(exceptions[0], RuntimeError)
assert isinstance(exceptions[1], ExceptionGroup)
assert len(exceptions[1].exceptions) == 1
assert isinstance(exceptions[1].exceptions[0], RuntimeError)


async def test_nursery_collapse_loose() -> None:
"""
Expand All @@ -2623,7 +2597,7 @@ async def test_nursery_collapse_loose() -> None:
async def raise_error() -> NoReturn:
raise RuntimeError("test error")

with pytest.raises(ExceptionGroup) as exc: # noqa: PT012
with RaisesGroup(RuntimeError, RuntimeError):
async with _core.open_nursery() as nursery:
nursery.start_soon(sleep_forever)
nursery.start_soon(raise_error)
Expand All @@ -2632,19 +2606,14 @@ async def raise_error() -> NoReturn:
nursery2.start_soon(raise_error)
nursery.cancel_scope.cancel()

exceptions = exc.value.exceptions
assert len(exceptions) == 2
assert isinstance(exceptions[0], RuntimeError)
assert isinstance(exceptions[1], RuntimeError)


async def test_cancel_scope_no_cancellederror() -> None:
"""
Test that when a cancel scope encounters an exception group that does NOT contain
a Cancelled exception, it will NOT set the ``cancelled_caught`` flag.
"""

with pytest.raises(ExceptionGroup): # noqa: PT012
with RaisesGroup(RuntimeError, RuntimeError, match="test"):
with _core.CancelScope() as scope:
scope.cancel()
raise ExceptionGroup("test", [RuntimeError(), RuntimeError()])
Expand Down
13 changes: 12 additions & 1 deletion src/trio/_tests/test_exports.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,10 @@ def lookup_symbol(symbol: str) -> dict[str, str]:
if module_name == "trio.socket" and class_name in dir(stdlib_socket):
continue

# ignore class that does dirty tricks
if class_ is trio.testing.RaisesGroup:
continue

# dir() and inspect.getmembers doesn't display properties from the metaclass
# also ignore some dunder methods that tend to differ but are of no consequence
ignore_names = set(dir(type(class_))) | {
Expand Down Expand Up @@ -429,7 +433,9 @@ def lookup_symbol(symbol: str) -> dict[str, str]:
if tool == "mypy" and class_ == trio.Nursery:
extra.remove("cancel_scope")

# TODO: I'm not so sure about these, but should still be looked at.
# These are (mostly? solely?) *runtime* attributes, often set in
# __init__, which doesn't show up with dir() or inspect.getmembers,
# but we get them in the way we query mypy & jedi
EXTRAS = {
trio.DTLSChannel: {"peer_address", "endpoint"},
trio.DTLSEndpoint: {"socket", "incoming_packets_buffer"},
Expand All @@ -444,6 +450,11 @@ def lookup_symbol(symbol: str) -> dict[str, str]:
"send_all_hook",
"wait_send_all_might_not_block_hook",
},
trio.testing.Matcher: {
"exception_type",
"match",
"check",
},
}
if tool == "mypy" and class_ in EXTRAS:
before = len(extra)
Expand Down
9 changes: 7 additions & 2 deletions src/trio/_tests/test_highlevel_open_tcp_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
reorder_for_rfc_6555_section_5_4,
)
from trio.socket import AF_INET, AF_INET6, IPPROTO_TCP, SOCK_STREAM, SocketType
from trio.testing import Matcher, RaisesGroup

if TYPE_CHECKING:
from trio.testing import MockClock
Expand Down Expand Up @@ -530,8 +531,12 @@ async def test_all_fail(autojump_clock: MockClock) -> None:
expect_error=OSError,
)
assert isinstance(exc, OSError)
assert isinstance(exc.__cause__, BaseExceptionGroup)
assert len(exc.__cause__.exceptions) == 4

subexceptions = (Matcher(OSError, match="^sorry$"),) * 4
assert RaisesGroup(
*subexceptions, match="all attempts to connect to test.example.com:80 failed"
).matches(exc.__cause__)

assert trio.current_time() == (0.1 + 0.2 + 10)
assert scenario.connect_times == {
"1.1.1.1": 0,
Expand Down
Loading

0 comments on commit aadd1ea

Please sign in to comment.