Skip to content

asgiref.sync trio support #509

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 11 commits into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 114 additions & 0 deletions asgiref/_asyncio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
__all__ = [
"get_running_loop",
"create_task_threadsafe",
"wrap_task_context",
"run_in_executor",
]

import asyncio
import concurrent.futures
import contextvars
import functools
import sys
import types
from asyncio import get_running_loop
from collections.abc import Awaitable, Callable, Coroutine
from typing import Any, Generic, Protocol, TypeVar, Union

from ._context import restore_context as _restore_context

_R = TypeVar("_R")

Coro = Coroutine[Any, Any, _R]


def create_task_threadsafe(
loop: asyncio.AbstractEventLoop, awaitable: Coro[object]
) -> None:
loop.call_soon_threadsafe(loop.create_task, awaitable)


async def wrap_task_context(
loop: asyncio.AbstractEventLoop,
task_context: list[asyncio.Task[Any]],
awaitable: Awaitable[_R],
) -> _R:
if task_context is None:
return await awaitable

current_task = asyncio.current_task(loop)
if current_task is None:
return await awaitable

task_context.append(current_task)
try:
return await awaitable
finally:
task_context.remove(current_task)


ExcInfo = Union[
tuple[type[BaseException], BaseException, types.TracebackType],
tuple[None, None, None],
]


class ThreadHandlerType(Protocol, Generic[_R]):
def __call__(
self,
loop: asyncio.AbstractEventLoop,
exc_info: ExcInfo,
task_context: list[asyncio.Task[Any]],
func: Callable[[Callable[[], _R]], _R],
child: Callable[[], _R],
) -> _R:
...


async def run_in_executor(
*,
loop: asyncio.AbstractEventLoop,
executor: concurrent.futures.ThreadPoolExecutor,
thread_handler: ThreadHandlerType[_R],
child: Callable[[], _R],
) -> _R:
context = contextvars.copy_context()
func = context.run
task_context: list[asyncio.Task[Any]] = []

# Run the code in the right thread
exec_coro = loop.run_in_executor(
executor,
functools.partial(
thread_handler,
loop,
sys.exc_info(),
task_context,
func,
child,
),
)
ret: _R
try:
ret = await asyncio.shield(exec_coro)
except asyncio.CancelledError:
cancel_parent = True
try:
task = task_context[0]
task.cancel()
try:
await task
cancel_parent = False
except asyncio.CancelledError:
pass
except IndexError:
pass
if exec_coro.done():
raise
if cancel_parent:
exec_coro.cancel()
ret = await exec_coro
finally:
_restore_context(context)

return ret
13 changes: 13 additions & 0 deletions asgiref/_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import contextvars


def restore_context(context: contextvars.Context) -> None:
# Check for changes in contextvars, and set them to the current
# context for downstream consumers
for cvar in context:
cvalue = context.get(cvar)
try:
if cvar.get() != cvalue:
cvar.set(cvalue)
except LookupError:
cvar.set(cvalue)
176 changes: 176 additions & 0 deletions asgiref/_trio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
import asyncio
import concurrent.futures
import contextvars
import functools
import sys
import types
from collections.abc import Awaitable, Callable, Coroutine
from typing import Any, Generic, Protocol, TypeVar, Union

import sniffio
import trio.lowlevel
import trio.to_thread

from . import _asyncio
from ._context import restore_context as _restore_context

_R = TypeVar("_R")

Coro = Coroutine[Any, Any, _R]

Loop = Union[asyncio.AbstractEventLoop, trio.lowlevel.TrioToken]
TaskContext = list[Any]


class TrioThreadCancelled(BaseException):
pass


def get_running_loop() -> Loop:

try:
asynclib = sniffio.current_async_library()
except sniffio.AsyncLibraryNotFoundError:
return asyncio.get_running_loop()

if asynclib == "asyncio":
return asyncio.get_running_loop()
if asynclib == "trio":
return trio.lowlevel.current_trio_token()
raise RuntimeError(f"unsupported library {asynclib}")


@trio.lowlevel.disable_ki_protection
async def wrap_awaitable(awaitable: Awaitable[_R]) -> _R:
return await awaitable


def create_task_threadsafe(loop: Loop, awaitable: Coro[_R]) -> None:
if isinstance(loop, trio.lowlevel.TrioToken):
try:
loop.run_sync_soon(
trio.lowlevel.spawn_system_task,
wrap_awaitable,
awaitable,
)
except trio.RunFinishedError:
raise RuntimeError("trio loop no-longer running")
return

_asyncio.create_task_threadsafe(loop, awaitable)


ExcInfo = Union[
tuple[type[BaseException], BaseException, types.TracebackType],
tuple[None, None, None],
]


class ThreadHandlerType(Protocol, Generic[_R]):
def __call__(
self,
loop: Loop,
exc_info: ExcInfo,
task_context: TaskContext,
func: Callable[[Callable[[], _R]], _R],
child: Callable[[], _R],
) -> _R:
...


async def run_in_executor(
*,
loop: Loop,
executor: concurrent.futures.ThreadPoolExecutor,
thread_handler: ThreadHandlerType[_R],
child: Callable[[], _R],
) -> _R:
if isinstance(loop, trio.lowlevel.TrioToken):
context = contextvars.copy_context()
func = context.run
task_context: TaskContext = []

# Run the code in the right thread
full_func = functools.partial(
thread_handler,
loop,
sys.exc_info(),
task_context,
func,
child,
)
try:
if executor is None:

async def handle_cancel() -> None:
try:
await trio.sleep_forever()
except trio.Cancelled:
if task_context:
task_context[0].cancel()
raise

async with trio.open_nursery() as nursery:
nursery.start_soon(handle_cancel)
try:
return await trio.to_thread.run_sync(
full_func, abandon_on_cancel=False
)
except TrioThreadCancelled:
pass
finally:
nursery.cancel_scope.cancel()
assert False
else:
event = trio.Event()

def callback(fut: object) -> None:
loop.run_sync_soon(event.set)

fut = executor.submit(full_func)
fut.add_done_callback(callback)

async def handle_cancel_fut() -> None:
try:
await trio.sleep_forever()
except trio.Cancelled:
fut.cancel()
if task_context:
task_context[0].cancel()
raise

async with trio.open_nursery() as nursery:
nursery.start_soon(handle_cancel_fut)
with trio.CancelScope(shield=True):
await event.wait()
nursery.cancel_scope.cancel()
try:
return fut.result()
except TrioThreadCancelled:
pass
assert False
finally:
_restore_context(context)

else:
return await _asyncio.run_in_executor(
loop=loop, executor=executor, thread_handler=thread_handler, child=child
)


async def wrap_task_context(
loop: Loop, task_context: Union[TaskContext, None], awaitable: Awaitable[_R]
) -> _R:
if task_context is None:
return await awaitable

if isinstance(loop, trio.lowlevel.TrioToken):
with trio.CancelScope() as scope:
task_context.append(scope)
try:
return await awaitable
finally:
task_context.remove(scope)
raise TrioThreadCancelled

return await _asyncio.wrap_task_context(loop, task_context, awaitable)
47 changes: 35 additions & 12 deletions asgiref/local.py
Original file line number Diff line number Diff line change
@@ -5,6 +5,32 @@
from typing import Any, Union


def _is_asyncio_running():
try:
asyncio.get_running_loop()
except RuntimeError:
return False
else:
return True


try:
import sniffio
except ModuleNotFoundError:
_is_async = _is_asyncio_running
else:

def _is_async():
try:
sniffio.current_async_library()
except sniffio.AsyncLibraryNotFoundError:
pass
else:
return True

return _is_asyncio_running()


class _CVar:
"""Storage utility for Local."""

@@ -83,18 +109,9 @@ def __init__(self, thread_critical: bool = False) -> None:
def _lock_storage(self):
# Thread safe access to storage
if self._thread_critical:
try:
# this is a test for are we in a async or sync
# thread - will raise RuntimeError if there is
# no current loop
asyncio.get_running_loop()
except RuntimeError:
# We are in a sync thread, the storage is
# just the plain thread local (i.e, "global within
# this thread" - it doesn't matter where you are
# in a call stack you see the same storage)
yield self._storage
else:
# this is a test for are we in a async or sync
# thread
if _is_async():
# We are in an async thread - storage is still
# local to this thread, but additionally should
# behave like a context var (is only visible with
@@ -108,6 +125,12 @@ def _lock_storage(self):
# can't be accessed in another thread (we don't
# need any locks)
yield self._storage.cvar
else:
# We are in a sync thread, the storage is
# just the plain thread local (i.e, "global within
# this thread" - it doesn't matter where you are
# in a call stack you see the same storage)
yield self._storage
else:
# Lock for thread_critical=False as other threads
# can access the exact same storage object
139 changes: 67 additions & 72 deletions asgiref/sync.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import asyncio
import asyncio.coroutines
import contextvars
import enum
import functools
import inspect
import os
@@ -24,6 +24,7 @@
overload,
)

from ._context import restore_context as _restore_context
from .current_thread_executor import CurrentThreadExecutor
from .local import Local

@@ -36,23 +37,35 @@
# This is not available to import at runtime
from _typeshed import OptExcInfo

from ._trio import (
create_task_threadsafe,
get_running_loop,
run_in_executor,
wrap_task_context,
)
else:
try:
__import__("trio")
except ModuleNotFoundError:
from ._asyncio import (
create_task_threadsafe,
get_running_loop,
run_in_executor,
wrap_task_context,
)
else:
from ._trio import (
create_task_threadsafe,
get_running_loop,
run_in_executor,
wrap_task_context,
)

_F = TypeVar("_F", bound=Callable[..., Any])
_P = ParamSpec("_P")
_R = TypeVar("_R")


def _restore_context(context: contextvars.Context) -> None:
# Check for changes in contextvars, and set them to the current
# context for downstream consumers
for cvar in context:
cvalue = context.get(cvar)
try:
if cvar.get() != cvalue:
cvar.set(cvalue)
except LookupError:
cvar.set(cvalue)


# Python 3.12 deprecates asyncio.iscoroutinefunction() as an alias for
# inspect.iscoroutinefunction(), whilst also removing the _is_coroutine marker.
# The latter is replaced with the inspect.markcoroutinefunction decorator.
@@ -110,6 +123,19 @@ async def __aexit__(self, exc, value, tb):
SyncToAsync.thread_sensitive_context.reset(self.token)


class LoopType(enum.Enum):
ASYNCIO = enum.auto()
TRIO = enum.auto()


def run(async_backend, callable, /, *args):
if async_backend is LoopType.TRIO:
import trio

return trio.run(callable, *args)
return asyncio.run(callable(*args))


class AsyncToSync(Generic[_P, _R]):
"""
Utility class which turns an awaitable that only works on the thread with
@@ -129,16 +155,19 @@ class AsyncToSync(Generic[_P, _R]):

# When we can't find a CurrentThreadExecutor from the context, such as
# inside create_task, we'll look it up here from the running event loop.
loop_thread_executors: "Dict[asyncio.AbstractEventLoop, CurrentThreadExecutor]" = {}
loop_thread_executors: "Dict[object, CurrentThreadExecutor]" = {}

def __init__(
self,
awaitable: Union[
Callable[_P, Coroutine[Any, Any, _R]],
Callable[_P, Awaitable[_R]],
],
force_new_loop: bool = False,
force_new_loop: Union[LoopType, bool] = False,
):
if force_new_loop and not isinstance(force_new_loop, LoopType):
force_new_loop = LoopType.ASYNCIO

if not callable(awaitable) or (
not iscoroutinefunction(awaitable)
and not iscoroutinefunction(getattr(awaitable, "__call__", awaitable))
@@ -156,7 +185,7 @@ def __init__(
self.force_new_loop = force_new_loop
self.main_event_loop = None
try:
self.main_event_loop = asyncio.get_running_loop()
self.main_event_loop = get_running_loop()
except RuntimeError:
# There's no event loop in this thread.
pass
@@ -179,7 +208,7 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R:

# You can't call AsyncToSync from a thread with a running event loop
try:
asyncio.get_running_loop()
get_running_loop()
except RuntimeError:
pass
else:
@@ -224,7 +253,7 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R:
)

async def new_loop_wrap() -> None:
loop = asyncio.get_running_loop()
loop = get_running_loop()
self.loop_thread_executors[loop] = current_executor
try:
await awaitable
@@ -233,8 +262,9 @@ async def new_loop_wrap() -> None:

if self.main_event_loop is not None:
try:
self.main_event_loop.call_soon_threadsafe(
self.main_event_loop.create_task, awaitable
create_task_threadsafe(
self.main_event_loop,
awaitable,
)
except RuntimeError:
running_in_main_event_loop = False
@@ -248,7 +278,9 @@ async def new_loop_wrap() -> None:
if not running_in_main_event_loop:
# Make our own event loop - in a new thread - and run inside that.
loop_executor = ThreadPoolExecutor(max_workers=1)
loop_future = loop_executor.submit(asyncio.run, new_loop_wrap())
loop_future = loop_executor.submit(
run, self.force_new_loop, new_loop_wrap
)
# Run the CurrentThreadExecutor until the future is done.
current_executor.run_until_future(loop_future)
# Wait for future and/or allow for exception propagation
@@ -283,30 +315,26 @@ async def main_wrap(

__traceback_hide__ = True # noqa: F841

loop = get_running_loop()
if context is not None:
_restore_context(context[0])

current_task = asyncio.current_task()
if current_task is not None and task_context is not None:
task_context.append(current_task)

result: _R
try:
# If we have an exception, run the function inside the except block
# after raising it so exc_info is correctly populated.
if exc_info[1]:
try:
raise exc_info[1]
except BaseException:
result = await awaitable
result = await wrap_task_context(loop, task_context, awaitable)
else:
result = await awaitable
result = await wrap_task_context(loop, task_context, awaitable)
except BaseException as e:
call_result.set_exception(e)
else:
call_result.set_result(result)
finally:
if current_task is not None and task_context is not None:
task_context.remove(current_task)
context[0] = contextvars.copy_context()


@@ -382,7 +410,7 @@ def __init__(

async def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R:
__traceback_hide__ = True # noqa: F841
loop = asyncio.get_running_loop()
loop = get_running_loop()

# Work out what thread to run the code in
if self._thread_sensitive:
@@ -417,49 +445,16 @@ async def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R:
# Use the passed in executor, or the loop's default if it is None
executor = self._executor

context = contextvars.copy_context()
child = functools.partial(self.func, *args, **kwargs)
func = context.run
task_context: List[asyncio.Task[Any]] = []

# Run the code in the right thread
exec_coro = loop.run_in_executor(
executor,
functools.partial(
self.thread_handler,
loop,
sys.exc_info(),
task_context,
func,
child,
),
)
ret: _R
try:
ret = await asyncio.shield(exec_coro)
except asyncio.CancelledError:
cancel_parent = True
try:
task = task_context[0]
task.cancel()
try:
await task
cancel_parent = False
except asyncio.CancelledError:
pass
except IndexError:
pass
if exec_coro.done():
raise
if cancel_parent:
exec_coro.cancel()
ret = await exec_coro
return await run_in_executor(
loop=loop,
executor=executor,
thread_handler=self.thread_handler,
child=functools.partial(self.func, *args, **kwargs),
)
finally:
_restore_context(context)
self.deadlock_context.set(False)

return ret

def __get__(
self, parent: Any, objtype: Any
) -> Callable[_P, Coroutine[Any, Any, _R]]:
@@ -496,7 +491,7 @@ def thread_handler(self, loop, exc_info, task_context, func, *args, **kwargs):
@overload
def async_to_sync(
*,
force_new_loop: bool = False,
force_new_loop: Union[LoopType, bool] = False,
) -> Callable[
[Union[Callable[_P, Coroutine[Any, Any, _R]], Callable[_P, Awaitable[_R]]]],
Callable[_P, _R],
@@ -511,7 +506,7 @@ def async_to_sync(
Callable[_P, Awaitable[_R]],
],
*,
force_new_loop: bool = False,
force_new_loop: Union[LoopType, bool] = False,
) -> Callable[_P, _R]:
...

@@ -524,7 +519,7 @@ def async_to_sync(
]
] = None,
*,
force_new_loop: bool = False,
force_new_loop: Union[LoopType, bool] = False,
) -> Union[
Callable[
[Union[Callable[_P, Coroutine[Any, Any, _R]], Callable[_P, Awaitable[_R]]]],
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
@@ -39,6 +39,7 @@ zip_safe = false
tests =
pytest
pytest-asyncio
anyio[trio]
mypy>=1.14.0

[tool:pytest]
103 changes: 55 additions & 48 deletions tests/test_sync.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import asyncio
import functools
import multiprocessing
import sys
import threading
import time
import warnings
@@ -10,7 +9,9 @@
from typing import Any
from unittest import TestCase

import anyio
import pytest
import trio.to_thread

from asgiref.sync import (
ThreadSensitiveContext,
@@ -21,7 +22,7 @@
from asgiref.timeout import timeout


@pytest.mark.asyncio
@pytest.mark.anyio
async def test_sync_to_async():
"""
Tests we can call sync functions from an async thread
@@ -41,6 +42,16 @@ def sync_function():
end = time.monotonic()
assert result == 42
assert end - start >= 1


@pytest.mark.asyncio
async def test_sync_to_async_one_worker():
# Define sync function
@sync_to_async
def async_function():
time.sleep(1)
return 42

# Set workers to 1, call it twice and make sure that works right
loop = asyncio.get_running_loop()
old_executor = loop._default_executor or ThreadPoolExecutor()
@@ -72,7 +83,7 @@ def test_sync_to_async_fail_non_function():
)


@pytest.mark.asyncio
@pytest.mark.anyio
async def test_sync_to_async_fail_async():
"""
sync_to_async raises a TypeError when applied to a sync function.
@@ -88,7 +99,7 @@ async def test_function():
)


@pytest.mark.asyncio
@pytest.mark.anyio
async def test_async_to_sync_fail_partial():
"""
sync_to_async raises a TypeError when applied to a sync partial.
@@ -106,7 +117,7 @@ async def test_function(*args):
)


@pytest.mark.asyncio
@pytest.mark.anyio
async def test_sync_to_async_raises_typeerror_for_async_callable_instance():
class CallableClass:
async def __call__(self):
@@ -118,7 +129,7 @@ async def __call__(self):
sync_to_async(CallableClass())


@pytest.mark.asyncio
@pytest.mark.anyio
async def test_sync_to_async_decorator():
"""
Tests sync_to_async as a decorator
@@ -134,7 +145,7 @@ def test_function():
assert result == 43


@pytest.mark.asyncio
@pytest.mark.anyio
async def test_nested_sync_to_async_retains_wrapped_function_attributes():
"""
Tests that attributes of functions wrapped by sync_to_async are retained
@@ -157,7 +168,7 @@ def test_function():
assert test_function.__name__ == "test_function"


@pytest.mark.asyncio
@pytest.mark.anyio
async def test_sync_to_async_method_decorator():
"""
Tests sync_to_async as a method decorator
@@ -175,7 +186,7 @@ def test_method(self):
assert result == 44


@pytest.mark.asyncio
@pytest.mark.anyio
async def test_sync_to_async_method_self_attribute():
"""
Tests sync_to_async on a method copies __self__
@@ -197,7 +208,7 @@ def test_method(self):
assert method.__self__ == instance


@pytest.mark.asyncio
@pytest.mark.anyio
async def test_async_to_sync_to_async():
"""
Tests we can call async functions from a sync thread created by async_to_sync
@@ -225,7 +236,7 @@ def sync_function():
assert result["thread"] == threading.current_thread()


@pytest.mark.asyncio
@pytest.mark.anyio
async def test_async_to_sync_to_async_decorator():
"""
Test async_to_sync as a function decorator uses the outer thread
@@ -253,9 +264,8 @@ def sync_function():
assert result["thread"] == threading.current_thread()


@pytest.mark.asyncio
@pytest.mark.skipif(sys.version_info < (3, 9), reason="requires python3.9")
async def test_async_to_sync_to_thread_decorator():
@pytest.mark.anyio
async def test_async_to_sync_to_thread_decorator(anyio_backend_name):
"""
Test async_to_sync as a function decorator uses the outer thread
when used inside another sync thread.
@@ -270,7 +280,10 @@ async def inner_async_function():
return 42

# Check it works right
number = await asyncio.to_thread(inner_async_function)
if anyio_backend_name == "trio":
number = await trio.to_thread.run_sync(inner_async_function)
else:
number = await asyncio.to_thread(inner_async_function)
assert number == 42
assert result["worked"]
# Make sure that it didn't needlessly make a new async loop
@@ -363,7 +376,7 @@ async def test_function(self):
assert result["worked"]


@pytest.mark.asyncio
@pytest.mark.anyio
async def test_async_to_sync_in_async():
"""
Makes sure async_to_sync bails if you try to call it from an async loop
@@ -509,7 +522,7 @@ def inner_task():
assert result["thread2"] == threading.current_thread()


@pytest.mark.asyncio
@pytest.mark.anyio
async def test_thread_sensitive_outside_async():
"""
Tests that thread_sensitive SyncToAsync where the outside is async code runs
@@ -535,16 +548,16 @@ def inner(result):
result["thread"] = threading.current_thread()

# Run it (in supposed parallel!)
await asyncio.wait(
[asyncio.create_task(outer(result_1)), asyncio.create_task(inner(result_2))]
)
async with anyio.create_task_group() as tg:
tg.start_soon(outer, result_1)
await inner(result_2)

# They should not have run in the main thread, but in the same thread
assert result_1["thread"] != threading.current_thread()
assert result_1["thread"] == result_2["thread"]


@pytest.mark.asyncio
@pytest.mark.anyio
async def test_thread_sensitive_with_context_matches():
result_1 = {}
result_2 = {}
@@ -557,12 +570,9 @@ def store_thread(result):
async def fn():
async with ThreadSensitiveContext():
# Run it (in supposed parallel!)
await asyncio.wait(
[
asyncio.create_task(store_thread_async(result_1)),
asyncio.create_task(store_thread_async(result_2)),
]
)
async with anyio.create_task_group() as tg:
tg.start_soon(store_thread_async, result_1)
await store_thread_async(result_2)

await fn()

@@ -571,7 +581,7 @@ async def fn():
assert result_1["thread"] == result_2["thread"]


@pytest.mark.asyncio
@pytest.mark.anyio
async def test_thread_sensitive_nested_context():
result_1 = {}
result_2 = {}
@@ -590,7 +600,7 @@ def store_thread(result):
assert result_1["thread"] == result_2["thread"]


@pytest.mark.asyncio
@pytest.mark.anyio
async def test_thread_sensitive_context_without_sync_work():
async with ThreadSensitiveContext():
pass
@@ -629,7 +639,7 @@ def level4():
assert result["thread"] == threading.current_thread()


@pytest.mark.asyncio
@pytest.mark.anyio
async def test_thread_sensitive_double_nested_async():
"""
Tests that thread_sensitive SyncToAsync nests inside itself where the
@@ -729,7 +739,7 @@ def fork_first():
return queue.get(True, 1)


@pytest.mark.asyncio
@pytest.mark.anyio
async def test_multiprocessing():
"""
Tests that a forked process can use async_to_sync without it looking for
@@ -738,7 +748,7 @@ async def test_multiprocessing():
assert await sync_to_async(fork_first)() == 42


@pytest.mark.asyncio
@pytest.mark.anyio
async def test_sync_to_async_uses_executor():
"""
Tests that SyncToAsync uses the passed in executor correctly.
@@ -834,7 +844,7 @@ async def async_process_that_triggers_event():
await trigger_task


@pytest.mark.asyncio
@pytest.mark.anyio
async def test_sync_to_async_with_blocker_non_thread_sensitive():
"""
Tests sync_to_async running on a long-time blocker in a non_thread_sensitive context.
@@ -850,23 +860,20 @@ async def async_process_waiting_on_event():

async def async_process_that_triggers_event():
"""Sleep, then set the event."""
await asyncio.sleep(1)
await anyio.sleep(1)
await sync_to_async(event.set)()

# Run the event setter as a task.
trigger_task = asyncio.ensure_future(async_process_that_triggers_event())
async with anyio.create_task_group() as tg:
# Run the event setter as a task.
tg.start_soon(async_process_that_triggers_event)

try:
# wait on the event waiter, which is now blocking the event setter.
async with timeout(delay + 1):
assert await async_process_waiting_on_event() == 42
except asyncio.TimeoutError:
# In case of timeout, set the event to unblock things, else
# downstream tests will get fouled up.
event.set()
raise
finally:
await trigger_task
try:
with anyio.fail_after(delay + 1):
assert await async_process_waiting_on_event() == 42
except TimeoutError:
# In case of timeout, set the event to unblock things, else
# downstream tests will get fouled up.
event.set()


@pytest.mark.asyncio
@@ -1194,7 +1201,7 @@ async def test_function(**kwargs: Any) -> None:
test_function(context=1)


@pytest.mark.asyncio
@pytest.mark.anyio
async def test_sync_to_async_overlapping_kwargs() -> None:
"""
Tests that SyncToAsync correctly passes through kwargs to the wrapped function,
2 changes: 1 addition & 1 deletion tox.ini
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tox]
envlist =
py{38,39,310,311,312,313}-{test,mypy}
py{38,39,310,311,312,313}-{test,mypy,trio}
qa

[testenv]