Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use asyncio.iscoroutine instead of inspect.isawaitable #74

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 17 additions & 18 deletions src/graphql/execution/execute.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from asyncio import gather
from inspect import isawaitable
from asyncio import gather, iscoroutine
from typing import (
Any,
Awaitable,
Expand Down Expand Up @@ -303,7 +302,7 @@ def build_response(
Given a completed execution context and data, build the (data, errors) response
defined by the "Response" section of the GraphQL spec.
"""
if isawaitable(data):
if iscoroutine(data):

async def build_response_async():
return self.build_response(await data) # type: ignore
Expand Down Expand Up @@ -351,7 +350,7 @@ def execute_operation(
self.errors.append(error)
return None
else:
if isawaitable(result):
if iscoroutine(result):
# noinspection PyShadowingNames
async def await_result():
try:
Expand Down Expand Up @@ -384,20 +383,20 @@ def execute_fields_serially(
)
if result is INVALID:
continue
if isawaitable(results):
if iscoroutine(results):
# noinspection PyShadowingNames
async def await_and_set_result(results, response_name, result):
awaited_results = await results
awaited_results[response_name] = (
await result if isawaitable(result) else result
await result if iscoroutine(result) else result
)
return awaited_results

# noinspection PyTypeChecker
results = await_and_set_result(
cast(Awaitable, results), response_name, result
)
elif isawaitable(result):
elif iscoroutine(result):
# noinspection PyShadowingNames
async def set_result(results, response_name, result):
results[response_name] = await result
Expand All @@ -407,7 +406,7 @@ async def set_result(results, response_name, result):
results = set_result(results, response_name, result)
else:
results[response_name] = result
if isawaitable(results):
if iscoroutine(results):
# noinspection PyShadowingNames
async def get_results():
return await cast(Awaitable, results)
Expand Down Expand Up @@ -436,7 +435,7 @@ def execute_fields(
)
if result is not INVALID:
results[response_name] = result
if isawaitable(result):
if iscoroutine(result):
append_awaitable(response_name)

# If there are no coroutines, we can just return the object
Expand Down Expand Up @@ -634,7 +633,7 @@ def resolve_field_value_or_error(
# Note that contrary to the JavaScript implementation, we pass the context
# value as part of the resolve info.
result = resolve_fn(source, info, **args)
if isawaitable(result):
if iscoroutine(result):
# noinspection PyShadowingNames
async def await_result():
try:
Expand Down Expand Up @@ -665,13 +664,13 @@ def complete_value_catching_error(
the execution context.
"""
try:
if isawaitable(result):
if iscoroutine(result):

async def await_result():
value = self.complete_value(
return_type, field_nodes, info, path, await result
)
if isawaitable(value):
if iscoroutine(value):
return await value
return value

Expand All @@ -680,7 +679,7 @@ async def await_result():
completed = self.complete_value(
return_type, field_nodes, info, path, result
)
if isawaitable(completed):
if iscoroutine(completed):
# noinspection PyShadowingNames
async def await_completed():
try:
Expand Down Expand Up @@ -830,7 +829,7 @@ def complete_list_value(
item_type, field_nodes, info, field_path, item
)

if isawaitable(completed_item):
if iscoroutine(completed_item):
append_awaitable(index)
append_result(completed_item)

Expand Down Expand Up @@ -881,7 +880,7 @@ def complete_abstract_value(
resolve_type_fn = return_type.resolve_type or self.type_resolver
runtime_type = resolve_type_fn(result, info, return_type) # type: ignore

if isawaitable(runtime_type):
if iscoroutine(runtime_type):

async def await_complete_object_value():
value = self.complete_object_value(
Expand All @@ -897,7 +896,7 @@ async def await_complete_object_value():
path,
result,
)
if isawaitable(value):
if iscoroutine(value):
return await value # type: ignore
return value

Expand Down Expand Up @@ -965,7 +964,7 @@ def complete_object_value(
if return_type.is_type_of:
is_type_of = return_type.is_type_of(result, info)

if isawaitable(is_type_of):
if iscoroutine(is_type_of):

async def collect_and_execute_subfields_async():
if not await is_type_of: # type: ignore
Expand Down Expand Up @@ -1126,7 +1125,7 @@ def default_type_resolver(
if type_.is_type_of:
is_type_of_result = type_.is_type_of(value, info)

if isawaitable(is_type_of_result):
if iscoroutine(is_type_of_result):
append_awaitable_results(cast(Awaitable, is_type_of_result))
append_awaitable_types(type_)
elif is_type_of_result:
Expand Down
7 changes: 3 additions & 4 deletions src/graphql/graphql.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from asyncio import ensure_future
from inspect import isawaitable
from asyncio import ensure_future, iscoroutine
from typing import Any, Awaitable, Dict, Union, Type, cast

from .error import GraphQLError
Expand Down Expand Up @@ -84,7 +83,7 @@ async def graphql(
execution_context_class,
)

if isawaitable(result):
if iscoroutine(result):
return await cast(Awaitable[ExecutionResult], result)

return cast(ExecutionResult, result)
Expand Down Expand Up @@ -123,7 +122,7 @@ def graphql_sync(
)

# Assert that the execution was synchronous.
if isawaitable(result):
if iscoroutine(result):
ensure_future(cast(Awaitable[ExecutionResult], result)).cancel()
raise RuntimeError("GraphQL execution failed to complete synchronously.")

Expand Down
5 changes: 2 additions & 3 deletions src/graphql/pyutils/event_emitter.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from typing import cast, Callable, Dict, List, Optional

from asyncio import AbstractEventLoop, Queue, ensure_future
from inspect import isawaitable
from asyncio import AbstractEventLoop, Queue, ensure_future, iscoroutine

from collections import defaultdict

Expand Down Expand Up @@ -32,7 +31,7 @@ def emit(self, event_name, *args, **kwargs):
return False
for listener in listeners:
result = listener(*args, **kwargs)
if isawaitable(result):
if iscoroutine(result):
ensure_future(result, loop=self.loop)
return True

Expand Down
6 changes: 3 additions & 3 deletions src/graphql/subscription/map_async_iterator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from asyncio import Event, ensure_future, Future, wait
from asyncio import Event, ensure_future, Future, wait, iscoroutine
from concurrent.futures import FIRST_COMPLETED
from inspect import isasyncgen, isawaitable
from inspect import isasyncgen
from typing import AsyncIterable, Callable, Set

__all__ = ["MapAsyncIterator"]
Expand Down Expand Up @@ -62,7 +62,7 @@ async def __anext__(self):
value = anext.result()
result = self.callback(value)

return await result if isawaitable(result) else result
return await result if iscoroutine(result) else result

async def athrow(self, type_, value=None, traceback=None):
if not self.is_closed:
Expand Down
6 changes: 3 additions & 3 deletions src/graphql/subscription/subscribe.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from inspect import isawaitable
from asyncio import iscoroutine
from typing import Any, AsyncIterable, AsyncIterator, Awaitable, Dict, Union, cast

from ..error import GraphQLError, located_error
Expand Down Expand Up @@ -81,7 +81,7 @@ async def map_source_to_response(payload) -> ExecutionResult:
operation_name,
field_resolver,
)
return await result if isawaitable(result) else result # type: ignore
return await result if iscoroutine(result) else result # type: ignore

return MapAsyncIterator(result_or_stream, map_source_to_response)

Expand Down Expand Up @@ -162,7 +162,7 @@ async def create_source_event_stream(
result = context.resolve_field_value_or_error(
field_def, field_nodes, resolve_fn, root_value, info
)
event_stream = await cast(Awaitable, result) if isawaitable(result) else result
event_stream = await cast(Awaitable, result) if iscoroutine(result) else result
# If `event_stream` is an Error, rethrow a located error.
if isinstance(event_stream, Exception):
raise located_error(event_stream, field_nodes, path.as_list())
Expand Down
4 changes: 2 additions & 2 deletions tests/execution/test_nonnull.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import re
from inspect import isawaitable
from asyncio import iscoroutine

from pytest import mark # type: ignore

Expand Down Expand Up @@ -107,7 +107,7 @@ def patch(data):

async def execute_sync_and_async(query, root_value):
sync_result = execute_query(query, root_value)
if isawaitable(sync_result):
if iscoroutine(sync_result):
sync_result = await sync_result
async_result = await execute_query(patch(query), root_value)

Expand Down
4 changes: 2 additions & 2 deletions tests/execution/test_sync.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from inspect import isawaitable
from asyncio import iscoroutine
from typing import Awaitable, cast

from pytest import mark, raises # type: ignore
Expand Down Expand Up @@ -55,7 +55,7 @@ def does_not_return_a_promise_if_mutation_fields_are_all_synchronous():
async def returns_a_promise_if_any_field_is_asynchronous():
doc = "query Example { syncField, asyncField }"
result = execute(schema, parse(doc), "rootValue")
assert isawaitable(result)
assert iscoroutine(result)
result = cast(Awaitable, result)
assert await result == (
{"syncField": "rootValue", "asyncField": "rootValue"},
Expand Down