Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add trio.testing.wait_all_threads_completed #2937

Merged
merged 3 commits into from
Feb 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/source/reference-testing.rst
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,10 @@ Inter-task ordering

.. autofunction:: wait_all_tasks_blocked

.. autofunction:: wait_all_threads_completed

.. autofunction:: active_thread_count


.. _testing-streams:

Expand Down
1 change: 1 addition & 0 deletions newsfragments/2937.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add `trio.testing.wait_all_threads_completed`, which blocks until no threads are running tasks. This is intended to be used in the same way as `trio.testing.wait_all_tasks_blocked`.
49 changes: 49 additions & 0 deletions src/trio/_tests/test_threads.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,13 @@
from .._core._tests.test_ki import ki_self
from .._core._tests.tutil import slow
from .._threads import (
active_thread_count,
current_default_thread_limiter,
from_thread_check_cancelled,
from_thread_run,
from_thread_run_sync,
to_thread_run_sync,
wait_all_threads_completed,
)
from ..testing import wait_all_tasks_blocked

Expand Down Expand Up @@ -1114,3 +1116,50 @@ async def test_cancellable_warns() -> None:

with pytest.warns(TrioDeprecationWarning):
await to_thread_run_sync(bool, cancellable=True)


async def test_wait_all_threads_completed() -> None:
no_threads_left = False
e1 = Event()
e2 = Event()

e1_exited = Event()
e2_exited = Event()

async def wait_event(e: Event, e_exit: Event) -> None:
def thread() -> None:
from_thread_run(e.wait)

await to_thread_run_sync(thread)
e_exit.set()

async def wait_no_threads_left() -> None:
nonlocal no_threads_left
await wait_all_threads_completed()
no_threads_left = True

async with _core.open_nursery() as nursery:
nursery.start_soon(wait_event, e1, e1_exited)
nursery.start_soon(wait_event, e2, e2_exited)
await wait_all_tasks_blocked()
nursery.start_soon(wait_no_threads_left)
await wait_all_tasks_blocked()
assert not no_threads_left
assert active_thread_count() == 2

e1.set()
await e1_exited.wait()
await wait_all_tasks_blocked()
assert not no_threads_left
assert active_thread_count() == 1

e2.set()
await e2_exited.wait()
await wait_all_tasks_blocked()
assert no_threads_left
assert active_thread_count() == 0


async def test_wait_all_threads_completed_no_threads() -> None:
await wait_all_threads_completed()
assert active_thread_count() == 0
137 changes: 103 additions & 34 deletions src/trio/_threads.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,23 +10,25 @@

import attr
import outcome
from attrs import define
from sniffio import current_async_library_cvar

import trio

from ._core import (
RunVar,
TrioToken,
checkpoint,
disable_ki_protection,
enable_ki_protection,
start_thread_soon,
)
from ._deprecate import warn_deprecated
from ._sync import CapacityLimiter
from ._sync import CapacityLimiter, Event
from ._util import coroutine_or_error

if TYPE_CHECKING:
from collections.abc import Awaitable, Callable
from collections.abc import Awaitable, Callable, Generator

from trio._core._traps import RaiseCancelT

Expand All @@ -52,6 +54,72 @@ class _ParentTaskData(threading.local):
_thread_counter = count()


@define
class _ActiveThreadCount:
count: int
event: Event


_active_threads_local: RunVar[_ActiveThreadCount] = RunVar("active_threads")


@contextlib.contextmanager
def _track_active_thread() -> Generator[None, None, None]:
try:
active_threads_local = _active_threads_local.get()
except LookupError:
active_threads_local = _ActiveThreadCount(0, Event())
_active_threads_local.set(active_threads_local)

active_threads_local.count += 1
try:
yield
finally:
active_threads_local.count -= 1
if active_threads_local.count == 0:
active_threads_local.event.set()
active_threads_local.event = Event()


async def wait_all_threads_completed() -> None:
"""Wait until no threads are still running tasks.

This is intended to be used when testing code with trio.to_thread to
make sure no tasks are still making progress in a thread. See the
following code for a usage example::

async def wait_all_settled():
while True:
await trio.testing.wait_all_threads_complete()
await trio.testing.wait_all_tasks_blocked()
if trio.testing.active_thread_count() == 0:
break
"""

await checkpoint()

try:
active_threads_local = _active_threads_local.get()
except LookupError:
# If there would have been active threads, the
# _active_threads_local would have been set
return

while active_threads_local.count != 0:
await active_threads_local.event.wait()
Comment on lines +80 to +109
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just wanting to be sure, I know event wait call is doing what we want and will block until thread count is zero and it's set, but is there any way with race conditions that active_threads_local.event is reset to a new lock before the wait call sees that? Would it be beneficial to move the event resetting to before the active_threads_local.count += 1 line?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ooh. Yeah shouldn't it maybe be something like this:

    try:
        active_threads_local = _active_threads_local.get()
    except LookupError:
        active_threads_local = _ActiveThreadCount(1, Event())
        _active_threads_local.set(active_threads_local)
    else:
        active_threads_local.count += 1

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yea, probably something like that and adding a part in the else block there where it resets the event object if the event it has has already fired.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand where a race could arise here. This is all happening in the main thread so everything between awaits is effectively atomic.

I have similar code in a different project and it hasn't given me trouble... Although speaking of that, it would be good to assert active_threads_local.count >= 0 after decrementing it, because if a negative number sneaks in it'd be better to fail fast.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I don't think we have to care about thread safety here. Rather, the only times _active_threads_local can be changed (that we worry about) is at an await point.

After all to_thread_run_sync can be only run from the main thread (that has your trio event loop).



def active_thread_count() -> int:
"""Returns the number of threads that are currently running a task

See `trio.testing.wait_all_threads_completed`
"""
try:
return _active_threads_local.get().count
except LookupError:
return 0


def current_default_thread_limiter() -> CapacityLimiter:
"""Get the default `~trio.CapacityLimiter` used by
`trio.to_thread.run_sync`.
Expand Down Expand Up @@ -373,39 +441,40 @@ def deliver_worker_fn_result(result: outcome.Outcome[RetT]) -> None:
current_trio_token.run_sync_soon(report_back_in_trio_thread_fn, result)

await limiter.acquire_on_behalf_of(placeholder)
try:
start_thread_soon(worker_fn, deliver_worker_fn_result, thread_name)
except:
limiter.release_on_behalf_of(placeholder)
raise

def abort(raise_cancel: RaiseCancelT) -> trio.lowlevel.Abort:
# fill so from_thread_check_cancelled can raise
cancel_register[0] = raise_cancel
if abandon_bool:
# empty so report_back_in_trio_thread_fn cannot reschedule
task_register[0] = None
return trio.lowlevel.Abort.SUCCEEDED
else:
return trio.lowlevel.Abort.FAILED

while True:
# wait_task_rescheduled return value cannot be typed
msg_from_thread: outcome.Outcome[RetT] | Run[object] | RunSync[object] = (
await trio.lowlevel.wait_task_rescheduled(abort)
)
if isinstance(msg_from_thread, outcome.Outcome):
return msg_from_thread.unwrap()
elif isinstance(msg_from_thread, Run):
await msg_from_thread.run()
elif isinstance(msg_from_thread, RunSync):
msg_from_thread.run_sync()
else: # pragma: no cover, internal debugging guard TODO: use assert_never
raise TypeError(
"trio.to_thread.run_sync received unrecognized thread message {!r}."
"".format(msg_from_thread)
with _track_active_thread():
try:
start_thread_soon(worker_fn, deliver_worker_fn_result, thread_name)
except:
limiter.release_on_behalf_of(placeholder)
raise

def abort(raise_cancel: RaiseCancelT) -> trio.lowlevel.Abort:
# fill so from_thread_check_cancelled can raise
cancel_register[0] = raise_cancel
if abandon_bool:
# empty so report_back_in_trio_thread_fn cannot reschedule
task_register[0] = None
return trio.lowlevel.Abort.SUCCEEDED
else:
return trio.lowlevel.Abort.FAILED

while True:
# wait_task_rescheduled return value cannot be typed
msg_from_thread: outcome.Outcome[RetT] | Run[object] | RunSync[object] = (
await trio.lowlevel.wait_task_rescheduled(abort)
)
del msg_from_thread
if isinstance(msg_from_thread, outcome.Outcome):
return msg_from_thread.unwrap()
elif isinstance(msg_from_thread, Run):
await msg_from_thread.run()
elif isinstance(msg_from_thread, RunSync):
msg_from_thread.run_sync()
else: # pragma: no cover, internal debugging guard TODO: use assert_never
raise TypeError(
"trio.to_thread.run_sync received unrecognized thread message {!r}."
"".format(msg_from_thread)
)
del msg_from_thread


def from_thread_check_cancelled() -> None:
Expand Down
4 changes: 4 additions & 0 deletions src/trio/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@
MockClock as MockClock,
wait_all_tasks_blocked as wait_all_tasks_blocked,
)
from .._threads import (
active_thread_count as active_thread_count,
wait_all_threads_completed as wait_all_threads_completed,
)
from .._util import fixup_module_metadata
from ._check_streams import (
check_half_closeable_stream as check_half_closeable_stream,
Expand Down
Loading