diff --git a/src/graphql/pyutils/__init__.py b/src/graphql/pyutils/__init__.py index 797c1321..8c09132a 100644 --- a/src/graphql/pyutils/__init__.py +++ b/src/graphql/pyutils/__init__.py @@ -21,6 +21,7 @@ from .event_emitter import EventEmitter, EventEmitterAsyncIterator from .identity_func import identity_func from .inspect import inspect +from .is_awaitable import is_awaitable from .is_collection import is_collection from .is_finite import is_finite from .is_integer import is_integer @@ -47,6 +48,7 @@ "EventEmitterAsyncIterator", "identity_func", "inspect", + "is_awaitable", "is_collection", "is_finite", "is_integer", diff --git a/src/graphql/pyutils/is_awaitable.py b/src/graphql/pyutils/is_awaitable.py new file mode 100644 index 00000000..80c3be3f --- /dev/null +++ b/src/graphql/pyutils/is_awaitable.py @@ -0,0 +1,24 @@ +import inspect +from typing import Any +from types import CoroutineType, GeneratorType + +__all__ = ["is_awaitable"] + +CO_ITERABLE_COROUTINE = inspect.CO_ITERABLE_COROUTINE + + +def is_awaitable(value: Any) -> bool: + """Return true if object can be passed to an ``await`` expression. + + Instead of testing if the object is an instance of abc.Awaitable, it checks + the existence of an `__await__` attribute. This is much faster. + """ + return ( + # check for coroutine objects + isinstance(value, CoroutineType) + # check for old-style generator based coroutine objects + or isinstance(value, GeneratorType) + and bool(value.gi_code.co_flags & CO_ITERABLE_COROUTINE) + # check for other awaitables (e.g. futures) + or hasattr(value, "__await__") + ) diff --git a/tests/pyutils/test_is_awaitable.py b/tests/pyutils/test_is_awaitable.py new file mode 100644 index 00000000..ac6390e6 --- /dev/null +++ b/tests/pyutils/test_is_awaitable.py @@ -0,0 +1,103 @@ +import asyncio +from inspect import isawaitable + +from pytest import mark # type: ignore + +from graphql.pyutils import is_awaitable + + +def describe_is_awaitable(): + def declines_the_none_value(): + assert not isawaitable(None) + assert not is_awaitable(None) + + def declines_a_boolean_value(): + assert not isawaitable(True) + assert not is_awaitable(True) + + def declines_an_int_value(): + assert not is_awaitable(42) + + def declines_a_string_value(): + assert not isawaitable("some_string") + assert not is_awaitable("some_string") + + def declines_a_dict_value(): + assert not isawaitable({}) + assert not is_awaitable({}) + + def declines_an_object_instance(): + assert not isawaitable(object()) + assert not is_awaitable(object()) + + def declines_the_type_class(): + assert not isawaitable(type) + assert not is_awaitable(type) + + def declines_a_lambda_function(): + assert not isawaitable(lambda: True) # pragma: no cover + assert not is_awaitable(lambda: True) # pragma: no cover + + def declines_a_normal_function(): + def some_function(): + return True + + assert not isawaitable(some_function()) + assert not is_awaitable(some_function) + + def declines_a_normal_generator_function(): + def some_generator(): + yield True # pragma: no cover + + assert not isawaitable(some_generator) + assert not is_awaitable(some_generator) + + def declines_a_normal_generator_object(): + def some_generator(): + yield True # pragma: no cover + + assert not isawaitable(some_generator()) + assert not is_awaitable(some_generator()) + + def declines_a_coroutine_function(): + async def some_coroutine(): + return True # pragma: no cover + + assert not isawaitable(some_coroutine) + assert not is_awaitable(some_coroutine) + + @mark.filterwarnings("ignore::RuntimeWarning") + def recognizes_a_coroutine_object(): + async def some_coroutine(): + return False # pragma: no cover + + assert isawaitable(some_coroutine()) + assert is_awaitable(some_coroutine()) + + @mark.filterwarnings("ignore::RuntimeWarning") + @mark.filterwarnings("ignore::DeprecationWarning") + def recognizes_an_old_style_coroutine(): + @asyncio.coroutine + def some_old_style_coroutine(): + yield False # pragma: no cover + + assert is_awaitable(some_old_style_coroutine()) + assert is_awaitable(some_old_style_coroutine()) + + @mark.filterwarnings("ignore::RuntimeWarning") + def recognizes_a_future_object(): + async def some_coroutine(): + return False # pragma: no cover + + some_future = asyncio.ensure_future(some_coroutine()) + + assert is_awaitable(some_future) + assert is_awaitable(some_future) + + @mark.filterwarnings("ignore::RuntimeWarning") + def declines_an_async_generator(): + async def some_async_generator(): + yield True # pragma: no cover + + assert not isawaitable(some_async_generator()) + assert not is_awaitable(some_async_generator())