diff --git a/docs-requirements.in b/docs-requirements.in
index 4acf894..1b08fce 100644
--- a/docs-requirements.in
+++ b/docs-requirements.in
@@ -12,3 +12,4 @@ towncrier != 19.9.0,!= 21.3.0
trio >= 0.22.0
outcome >= 1.1.0
pytest >= 7.2.0
+pytest_timeout
diff --git a/docs-requirements.txt b/docs-requirements.txt
index ae0c01d..7cbc706 100644
--- a/docs-requirements.txt
+++ b/docs-requirements.txt
@@ -1,6 +1,6 @@
#
-# This file is autogenerated by pip-compile with python 3.8
-# To update, run:
+# This file is autogenerated by pip-compile with Python 3.10
+# by the following command:
#
# pip-compile docs-requirements.in
#
@@ -40,8 +40,6 @@ idna==3.4
# trio
imagesize==1.4.1
# via sphinx
-importlib-metadata==5.0.0
- # via sphinx
incremental==22.10.0
# via towncrier
iniconfig==1.1.1
@@ -67,6 +65,10 @@ pygments==2.13.0
pyparsing==3.0.9
# via packaging
pytest==7.2.0
+ # via
+ # -r docs-requirements.in
+ # pytest-timeout
+pytest-timeout==2.1.0
# via -r docs-requirements.in
pytz==2022.5
# via babel
@@ -109,8 +111,6 @@ trio==0.22.0
# via -r docs-requirements.in
urllib3==1.26.12
# via requests
-zipp==3.10.0
- # via importlib-metadata
# The following packages are considered to be unsafe in a requirements file:
# setuptools
diff --git a/docs/source/index.rst b/docs/source/index.rst
index fd29f82..28173ea 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -33,6 +33,8 @@ and async I/O in Python. Features include:
`__ library, so your async tests can use
property-based testing: just use ``@given`` like you're used to.
+* Integration with `pytest-timeout `
+
* Support for testing projects that use Trio exclusively and want to
use pytest-trio everywhere, and also for testing projects that
support multiple async libraries and only want to enable
diff --git a/docs/source/reference.rst b/docs/source/reference.rst
index 351e8c1..7ae3be3 100644
--- a/docs/source/reference.rst
+++ b/docs/source/reference.rst
@@ -420,3 +420,28 @@ it can be passed directly to the marker.
@pytest.mark.trio(run=qtrio.run)
async def test():
assert True
+
+
+Configuring timeouts with pytest-timeout
+----------------------------------------
+
+Timeouts can be configured using the ``@pytest.mark.timeout`` decorator.
+
+.. code-block:: python
+
+ import pytest
+ import trio
+
+ @pytest.mark.timeout(10)
+ async def test():
+ await trio.sleep_forever() # will error after 10 seconds
+
+To get clean stacktraces that cover all tasks running when the timeout was triggered, enable the ``trio_timeout`` option.
+
+.. code-block:: ini
+
+ # pytest.ini
+ [pytest]
+ trio_timeout = true
+
+This timeout method requires a functioning loop, and hence will not be triggered if your test doesn't yield to the loop. This typically occurs when the test is stuck on some non-async piece of code.
diff --git a/newsfragments/53.feature.rst b/newsfragments/53.feature.rst
new file mode 100644
index 0000000..b1247b2
--- /dev/null
+++ b/newsfragments/53.feature.rst
@@ -0,0 +1 @@
+Add support for pytest-timeout using our own timeout method. This timeout method can be enable via the option ``trio_timeout`` in ``pytest.ini`` and will print structured tracebacks of all tasks running when the timeout happened.
diff --git a/pytest_trio/_tests/test_basic.py b/pytest_trio/_tests/test_basic.py
index f95538f..6020ff0 100644
--- a/pytest_trio/_tests/test_basic.py
+++ b/pytest_trio/_tests/test_basic.py
@@ -1,4 +1,6 @@
+import functools
import pytest
+import trio
def test_async_test_is_executed(testdir):
@@ -73,15 +75,17 @@ def test_invalid():
result.assert_outcomes(errors=1)
-def test_skip_and_xfail(testdir):
+def test_skip_and_xfail(testdir, monkeypatch):
+ monkeypatch.setattr(
+ trio, "run", functools.partial(trio.run, strict_exception_groups=True)
+ )
+
testdir.makepyfile(
"""
import functools
import pytest
import trio
- trio.run = functools.partial(trio.run, strict_exception_groups=True)
-
@pytest.mark.trio
async def test_xfail():
pytest.xfail()
diff --git a/pytest_trio/_tests/test_timeout.py b/pytest_trio/_tests/test_timeout.py
new file mode 100644
index 0000000..473ba77
--- /dev/null
+++ b/pytest_trio/_tests/test_timeout.py
@@ -0,0 +1,50 @@
+import trio
+import functools
+
+
+def test_timeout(testdir):
+ testdir.makepyfile(
+ """
+ from trio import sleep
+ import pytest
+ import pytest_trio.timeout
+
+ @pytest.mark.timeout(0.01)
+ @pytest.mark.trio
+ async def test_will_timeout():
+ await sleep(10)
+ """
+ )
+
+ testdir.makefile(".ini", pytest="[pytest]\ntrio_timeout=true\n")
+
+ result = testdir.runpytest()
+
+ result.stdout.fnmatch_lines(["Timeout reached"])
+ result.assert_outcomes(failed=1)
+
+
+def test_timeout_strict_exception_group(testdir, monkeypatch):
+ monkeypatch.setattr(
+ trio, "run", functools.partial(trio.run, strict_exception_groups=True)
+ )
+
+ testdir.makepyfile(
+ """
+ from trio import sleep
+ import pytest
+ import pytest_trio.timeout
+
+ @pytest.mark.timeout(0.01)
+ @pytest.mark.trio
+ async def test_will_timeout():
+ await sleep(10)
+ """
+ )
+
+ testdir.makefile(".ini", pytest="[pytest]\ntrio_timeout=true\n")
+
+ result = testdir.runpytest()
+
+ result.stdout.fnmatch_lines(["Timeout reached"])
+ result.assert_outcomes(failed=1)
diff --git a/pytest_trio/plugin.py b/pytest_trio/plugin.py
index 1a56a83..4eeeb45 100644
--- a/pytest_trio/plugin.py
+++ b/pytest_trio/plugin.py
@@ -1,4 +1,5 @@
"""pytest-trio implementation."""
+from __future__ import annotations
import sys
from functools import wraps, partial
from collections.abc import Coroutine, Generator
@@ -12,6 +13,13 @@
from trio.testing import MockClock
from _pytest.outcomes import Skipped, XFailed
+# pytest_timeout_set_timer needs to be imported here for pluggy
+from .timeout import (
+ set_timeout,
+ TimeoutTriggeredException,
+ pytest_timeout_set_timer as pytest_timeout_set_timer,
+)
+
if sys.version_info[:2] < (3, 11):
from exceptiongroup import BaseExceptionGroup
@@ -41,6 +49,12 @@ def pytest_addoption(parser):
type="bool",
default=False,
)
+ parser.addini(
+ "trio_timeout",
+ "should pytest-trio handle timeouts on async functions?",
+ type="bool",
+ default=False,
+ )
parser.addini(
"trio_run",
"what runner should pytest-trio use? [trio, qtrio]",
@@ -353,6 +367,8 @@ def wrapper(**kwargs):
ex = queue.pop()
if isinstance(ex, BaseExceptionGroup):
queue.extend(ex.exceptions)
+ elif isinstance(ex, TimeoutTriggeredException):
+ pytest.fail(str(ex), pytrace=False)
else:
leaves.append(ex)
if len(leaves) == 1:
@@ -363,6 +379,8 @@ def wrapper(**kwargs):
# Since our leaf exceptions don't consist of exactly one 'magic'
# skipped or xfailed exception, re-raise the whole group.
raise
+ except TimeoutTriggeredException as ex:
+ pytest.fail(str(ex), pytrace=False)
return wrapper
@@ -404,6 +422,9 @@ async def _bootstrap_fixtures_and_run_test(**kwargs):
contextvars_ctx = contextvars.copy_context()
contextvars_ctx.run(canary.set, "in correct context")
+ if item is not None:
+ set_timeout(item)
+
async with trio.open_nursery() as nursery:
for fixture in test.register_and_collect_dependencies():
nursery.start_soon(
diff --git a/pytest_trio/timeout.py b/pytest_trio/timeout.py
new file mode 100644
index 0000000..863c918
--- /dev/null
+++ b/pytest_trio/timeout.py
@@ -0,0 +1,109 @@
+from __future__ import annotations
+from typing import Optional
+import warnings
+import threading
+import trio
+import pytest
+import pytest_timeout
+from .traceback_format import format_recursive_nursery_stack
+
+
+pytest_timeout_settings = pytest.StashKey[pytest_timeout.Settings]()
+send_timeout_callable = None
+send_timeout_callable_ready_event = threading.Event()
+
+
+def set_timeout(item: pytest.Item) -> None:
+ try:
+ settings = item.stash[pytest_timeout_settings]
+ except KeyError:
+ # No timeout or not our timeout
+ return
+
+ if settings.func_only:
+ warnings.warn(
+ "Function only timeouts are not supported for trio based timeouts"
+ )
+
+ global send_timeout_callable
+
+ # Shouldn't be racy, as xdist uses different processes
+ if send_timeout_callable is None:
+ threading.Thread(target=trio_timeout_thread, daemon=True).start()
+
+ send_timeout_callable_ready_event.wait()
+
+ send_timeout_callable(settings.timeout)
+
+
+@pytest.hookimpl()
+def pytest_timeout_set_timer(
+ item: pytest.Item, settings: pytest_timeout.Settings
+) -> Optional[bool]:
+ if item.get_closest_marker("trio") is not None and item.config.getini(
+ "trio_timeout"
+ ):
+ item.stash[pytest_timeout_settings] = settings
+ return True
+
+
+# No need for pytest_timeout_cancel_timer as we detect that the test loop has exited
+
+
+class TimeoutTriggeredException(Exception):
+ pass
+
+
+def trio_timeout_thread():
+ async def run_timeouts():
+ async with trio.open_nursery() as nursery:
+ token = trio.lowlevel.current_trio_token()
+
+ async def wait_timeout(token: trio.TrioToken, timeout: float) -> None:
+ await trio.sleep(timeout)
+
+ try:
+ token.run_sync_soon(
+ lambda: trio.lowlevel.spawn_system_task(execute_timeout)
+ )
+ except RuntimeError:
+ # test has finished
+ pass
+
+ def send_timeout(timeout: float):
+ test_token = trio.lowlevel.current_trio_token()
+ token.run_sync_soon(
+ lambda: nursery.start_soon(wait_timeout, test_token, timeout)
+ )
+
+ global send_timeout_callable
+ send_timeout_callable = send_timeout
+ send_timeout_callable_ready_event.set()
+
+ await trio.sleep_forever()
+
+ trio.run(run_timeouts)
+
+
+async def execute_timeout() -> None:
+ if pytest_timeout.is_debugging():
+ return
+
+ nursery = get_test_nursery()
+ stack = "\n".join(format_recursive_nursery_stack(nursery) + ["Timeout reached"])
+
+ async def report():
+ raise TimeoutTriggeredException(stack)
+
+ nursery.start_soon(report)
+
+
+def get_test_nursery() -> trio.Nursery:
+ task = trio.lowlevel.current_task().parent_nursery.parent_task
+
+ for nursery in task.child_nurseries:
+ for task in nursery.child_tasks:
+ if task.name.startswith("pytest_trio.plugin._trio_test_runner_factory"):
+ return task.child_nurseries[0]
+
+ raise Exception("Could not find test nursery")
diff --git a/pytest_trio/traceback_format.py b/pytest_trio/traceback_format.py
new file mode 100644
index 0000000..eb4a962
--- /dev/null
+++ b/pytest_trio/traceback_format.py
@@ -0,0 +1,70 @@
+from __future__ import annotations
+from trio.lowlevel import Task
+from itertools import chain
+import traceback
+
+
+def format_stack_for_task(task: Task, prefix: str) -> list[str]:
+ stack = list(task.iter_await_frames())
+
+ nursery_waiting_children = False
+
+ for i, (frame, lineno) in enumerate(stack):
+ if frame.f_code.co_name == "_nested_child_finished":
+ stack = stack[: i - 1]
+ nursery_waiting_children = True
+ break
+ if frame.f_code.co_name == "wait_task_rescheduled":
+ stack = stack[:i]
+ break
+ if frame.f_code.co_name == "checkpoint":
+ stack = stack[:i]
+ break
+
+ stack = (frame for frame in stack if "__tracebackhide__" not in frame[0].f_locals)
+
+ ss = traceback.StackSummary.extract(stack)
+ formated_traceback = list(
+ map(lambda x: prefix + x[2:], "".join(ss.format()).splitlines())
+ )
+
+ if nursery_waiting_children:
+ formated_traceback.append(prefix + "Awaiting completion of children")
+ formated_traceback.append(prefix)
+
+ return formated_traceback
+
+
+def format_task(task: Task, prefix: str = "") -> list[str]:
+ lines = []
+
+ subtasks = list(
+ chain(*(child_nursery.child_tasks for child_nursery in task.child_nurseries))
+ )
+
+ if subtasks:
+ trace_prefix = prefix + "│"
+ else:
+ trace_prefix = prefix + " "
+
+ lines.extend(format_stack_for_task(task, trace_prefix))
+
+ for i, subtask in enumerate(subtasks):
+ if (i + 1) != len(subtasks):
+ lines.append(f"{prefix}├ {subtask.name}")
+ lines.extend(format_task(subtask, prefix=f"{prefix}│ "))
+ else:
+ lines.append(f"{prefix}└ {subtask.name}")
+ lines.extend(format_task(subtask, prefix=f"{prefix} "))
+
+ return lines
+
+
+def format_recursive_nursery_stack(nursery) -> list[str]:
+ stack = []
+
+ for task in nursery.child_tasks:
+ stack.append(task.name)
+ stack.extend(format_task(task))
+
+ return stack
diff --git a/setup.py b/setup.py
index 9fbb81f..98dfcb4 100644
--- a/setup.py
+++ b/setup.py
@@ -19,6 +19,7 @@
"trio >= 0.22.0", # for ExceptionGroup support
"outcome >= 1.1.0",
"pytest >= 7.2.0", # for ExceptionGroup support
+ "pytest_timeout",
],
keywords=[
"async",