Skip to content

Commit

Permalink
Add parameterization to executor
Browse files Browse the repository at this point in the history
Allows users to pass in an optional executor to SyncToAsync to override the default.
  • Loading branch information
joshuahaertel authored Mar 13, 2021
1 parent 9aae49c commit b3257f6
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 6 deletions.
31 changes: 25 additions & 6 deletions asgiref/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import threading
import weakref
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Dict
from typing import Any, Callable, Dict, Optional, Union

from .current_thread_executor import CurrentThreadExecutor
from .local import Local
Expand Down Expand Up @@ -293,6 +293,10 @@ class SyncToAsync:
outermost), this will just be the main thread. This is achieved by idling
with a CurrentThreadExecutor while AsyncToSync is blocking its sync parent,
rather than just blocking.
If executor is passed in, that will be used instead of the loop's default executor.
In order to pass in an executor, thread_sensitive must be set to False, otherwise
a TypeError will be raised.
"""

# If they've set ASGI_THREADS, update the default asyncio executor for now
Expand Down Expand Up @@ -326,15 +330,23 @@ class SyncToAsync:
weakref.WeakKeyDictionary()
)

def __init__(self, func, thread_sensitive=True):
def __init__(
self,
func: Callable[..., Any],
thread_sensitive: bool = True,
executor: Optional["ThreadPoolExecutor"] = None,
) -> None:
if not callable(func) or asyncio.iscoroutinefunction(func):
raise TypeError("sync_to_async can only be applied to sync functions.")
self.func = func
functools.update_wrapper(self, func)
self._thread_sensitive = thread_sensitive
self._is_coroutine = asyncio.coroutines._is_coroutine
self._is_coroutine = asyncio.coroutines._is_coroutine # type: ignore
if thread_sensitive and executor is not None:
raise TypeError("executor must not be set when thread_sensitive is True")
self._executor = executor
try:
self.__self__ = func.__self__
self.__self__ = func.__self__ # type: ignore
except AttributeError:
pass

Expand Down Expand Up @@ -364,7 +376,8 @@ async def __call__(self, *args, **kwargs):
# Otherwise, we run it in a fixed single thread
executor = self.single_thread_executor
else:
executor = None # Use default
# Use the passed in executor, or the loop's default if it is None
executor = self._executor

if contextvars is not None:
context = contextvars.copy_context()
Expand Down Expand Up @@ -456,13 +469,19 @@ def get_current_task():
async_to_sync = AsyncToSync


def sync_to_async(func=None, thread_sensitive=True):
def sync_to_async(
func: Optional[Callable[..., Any]] = None,
thread_sensitive: bool = True,
executor: Optional["ThreadPoolExecutor"] = None,
) -> Union[SyncToAsync, Callable[[Callable[..., Any]], SyncToAsync]]:
if func is None:
return lambda f: SyncToAsync(
f,
thread_sensitive=thread_sensitive,
executor=executor,
)
return SyncToAsync(
func,
thread_sensitive=thread_sensitive,
executor=executor,
)
37 changes: 37 additions & 0 deletions tests/test_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,3 +583,40 @@ def fork_first():
return test_queue.get(True, 1)

assert await sync_to_async(fork_first)() == 42


@pytest.mark.asyncio
async def test_sync_to_async_uses_executor():
"""
Tests that SyncToAsync uses the passed in executor correctly.
"""

class CustomExecutor:
def __init__(self):
self.executor = ThreadPoolExecutor(max_workers=1)
self.times_submit_called = 0

def submit(self, callable_, *args, **kwargs):
self.times_submit_called += 1
return self.executor.submit(callable_, *args, **kwargs)

expected_result = "expected_result"

def sync_func():
return expected_result

custom_executor = CustomExecutor()
async_function = sync_to_async(
sync_func, thread_sensitive=False, executor=custom_executor
)
actual_result = await async_function()
assert actual_result == expected_result
assert custom_executor.times_submit_called == 1

pytest.raises(
TypeError,
sync_to_async,
sync_func,
thread_sensitive=True,
executor=custom_executor,
)

0 comments on commit b3257f6

Please sign in to comment.