From ae51d2220f8c23d0daaa7f7c67547c447c6831fd Mon Sep 17 00:00:00 2001 From: Eugene Tverdokhleb Date: Tue, 6 Apr 2021 17:50:49 +0300 Subject: [PATCH] Fix type checks for functions passed to async_to_sync/sync_to_async On Python < 3.8 a functool.partial is not detected as async function, which is due to a bug in inspect module, see https://bugs.python.org/issue33261 This commit adds a proper type check for earlier Python versions and fixes #251. --- asgiref/sync.py | 20 ++++++++++++++++++-- tests/test_sync.py | 40 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 58 insertions(+), 2 deletions(-) diff --git a/asgiref/sync.py b/asgiref/sync.py index cff5e1d7..b2e9e735 100644 --- a/asgiref/sync.py +++ b/asgiref/sync.py @@ -1,5 +1,6 @@ import asyncio.coroutines import functools +import inspect import os import sys import threading @@ -27,6 +28,21 @@ def _restore_context(context): cvar.set(context.get(cvar)) +def _iscoroutinefunction_or_partial(func: Any) -> bool: + # Python < 3.8 does not correctly determine partially wrapped + # coroutine functions are coroutine functions, hence the need for + # this to exist. Code taken from CPython. + if sys.version_info >= (3, 8): + return asyncio.iscoroutinefunction(func) + else: + while inspect.ismethod(func): + func = func.__func__ + while isinstance(func, functools.partial): + func = func.func + + return asyncio.iscoroutinefunction(func) + + class ThreadSensitiveContext: """Async context manager to manage context for thread sensitive mode @@ -101,7 +117,7 @@ class AsyncToSync: executors = Local() def __init__(self, awaitable, force_new_loop=False): - if not callable(awaitable) or not asyncio.iscoroutinefunction(awaitable): + if not callable(awaitable) or not _iscoroutinefunction_or_partial(awaitable): raise TypeError("async_to_sync can only be applied to async functions.") self.awaitable = awaitable try: @@ -336,7 +352,7 @@ def __init__( thread_sensitive: bool = True, executor: Optional["ThreadPoolExecutor"] = None, ) -> None: - if not callable(func) or asyncio.iscoroutinefunction(func): + if not callable(func) or _iscoroutinefunction_or_partial(func): raise TypeError("sync_to_async can only be applied to sync functions.") self.func = func functools.update_wrapper(self, func) diff --git a/tests/test_sync.py b/tests/test_sync.py index 0aa336cc..e7c5bab0 100644 --- a/tests/test_sync.py +++ b/tests/test_sync.py @@ -1,4 +1,5 @@ import asyncio +import functools import multiprocessing import threading import time @@ -73,6 +74,24 @@ async def test_function(): ) +@pytest.mark.asyncio +async def test_async_to_sync_fail_partial(): + """ + sync_to_async raises a TypeError when applied to a sync partial. + """ + with pytest.raises(TypeError) as excinfo: + + async def test_function(*args): + pass + + partial_function = functools.partial(test_function, 42) + sync_to_async(partial_function) + + assert excinfo.value.args == ( + "sync_to_async can only be applied to sync functions.", + ) + + @pytest.mark.asyncio async def test_sync_to_async_decorator(): """ @@ -319,6 +338,27 @@ async def test_function(): assert test_function() == 42 +def test_async_to_sync_partial(): + """ + Tests we can call async_to_sync on an async partial. + """ + result = {} + + # Define async function + async def inner_async_function(*args): + await asyncio.sleep(0) + result["worked"] = True + return [*args] + + partial_function = functools.partial(inner_async_function, 42) + + # Run it + sync_function = async_to_sync(partial_function) + out = sync_function(84) + assert out == [42, 84] + assert result["worked"] + + def test_async_to_async_method_self_attribute(): """ Tests async_to_async on a method copies __self__.