Skip to content

Move async methods into class methods #144

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
171 changes: 94 additions & 77 deletions src/graphql/execution/execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,11 +299,8 @@ def build_response(
defined by the "Response" section of the GraphQL spec.
"""
if self.is_awaitable(data):

async def build_response_async() -> ExecutionResult:
return self.build_response(await data) # type: ignore

return build_response_async()
data = cast(Awaitable, data)
return self.build_response_async(data)
data = cast(Optional[Dict[str, Any]], data)
errors = self.errors
if not errors:
Expand Down Expand Up @@ -350,14 +347,8 @@ def execute_operation(
return None
else:
if self.is_awaitable(result):
# noinspection PyShadowingNames
async def await_result() -> Any:
try:
return await result # type: ignore
except GraphQLError as error:
self.errors.append(error)

return await_result()
result = cast(Awaitable, result)
return self.await_result(result)
return result

def execute_fields_serially(
Expand All @@ -382,42 +373,17 @@ def execute_fields_serially(
if result is Undefined:
continue
if is_awaitable(results):
# noinspection PyShadowingNames
async def await_and_set_result(
results: Awaitable[Dict[str, Any]],
response_name: str,
result: AwaitableOrValue[Any],
) -> Dict[str, Any]:
awaited_results = await results
awaited_results[response_name] = (
await result if is_awaitable(result) else result
)
return awaited_results

results = await_and_set_result(
cast(Awaitable, results), response_name, result
results = ExecutionContext.await_and_set_result(
is_awaitable, cast(Awaitable, results), response_name, result
)
elif is_awaitable(result):
# noinspection PyShadowingNames
async def set_result(
results: Dict[str, Any],
response_name: str,
result: Awaitable,
) -> Dict[str, Any]:
results[response_name] = await result
return results

results = set_result(
results = ExecutionContext.set_result(
cast(Dict[str, Any], results), response_name, result
)
else:
cast(Dict[str, Any], results)[response_name] = result
if is_awaitable(results):
# noinspection PyShadowingNames
async def get_results() -> Any:
return await cast(Awaitable, results)

return get_results()
return ExecutionContext.get_results(cast(Awaitable, results))
return results

def execute_fields(
Expand Down Expand Up @@ -454,16 +420,7 @@ def execute_fields(
# field, which is possibly a coroutine object. Return a coroutine object that
# will yield this same map, but with any coroutines awaited in parallel and
# replaced with the values they yielded.
async def get_results() -> Dict[str, Any]:
results.update(
zip(
awaitable_fields,
await gather(*(results[field] for field in awaitable_fields)),
)
)
return results

return get_results()
return ExecutionContext.get_results_map(awaitable_fields, results)

def build_resolve_info(
self,
Expand Down Expand Up @@ -532,36 +489,15 @@ def execute_field(

completed: AwaitableOrValue[Any]
if self.is_awaitable(result):
# noinspection PyShadowingNames
async def await_result() -> Any:
try:
completed = self.complete_value(
return_type, field_nodes, info, path, await result
)
if self.is_awaitable(completed):
return await completed
return completed
except Exception as raw_error:
error = located_error(raw_error, field_nodes, path.as_list())
self.handle_field_error(error, return_type)
return None

return await_result()
return self.await_field_result(
result, field_nodes, info, path, return_type
)

completed = self.complete_value(
return_type, field_nodes, info, path, result
)
if self.is_awaitable(completed):
# noinspection PyShadowingNames
async def await_completed() -> Any:
try:
return await completed
except Exception as raw_error:
error = located_error(raw_error, field_nodes, path.as_list())
self.handle_field_error(error, return_type)
return None

return await_completed()
return self.await_completed(completed, field_nodes, path, return_type)

return completed
except Exception as raw_error:
Expand Down Expand Up @@ -964,6 +900,87 @@ def collect_subfields(
cache[key] = sub_field_nodes
return sub_field_nodes

# Async methods
async def build_response_async(
self, data: Awaitable[Optional[Dict[str, Any]]]
) -> ExecutionResult:
return self.build_response(await data) # type: ignore

async def await_result(self, result: Awaitable[Dict[str, Any]]) -> Optional[Any]:
try:
return await result
except GraphQLError as error:
self.errors.append(error)
return None

@staticmethod
async def await_and_set_result(
is_awaitable: Callable[[Any], bool],
results: Awaitable[Dict[str, Any]],
response_name: str,
result: AwaitableOrValue[Any],
) -> Dict[str, Any]:
awaited_results = await results
awaited_results[response_name] = (
await result if is_awaitable(result) else result
)
return awaited_results

@staticmethod
async def set_result(
results: Dict[str, Any],
response_name: str,
result: Awaitable,
) -> Dict[str, Any]:
results[response_name] = await result
return results

@staticmethod
async def get_results(results: Awaitable[Dict[str, Any]]) -> Any:
return await results

@staticmethod
async def get_results_map(
awaitable_fields: List[str], results: Dict[str, Any]
) -> Dict[str, Any]:
results.update(
zip(
awaitable_fields,
await gather(*(results[field] for field in awaitable_fields)),
)
)
return results

async def await_field_result(
self,
result: Any,
field_nodes: List[FieldNode],
info: GraphQLResolveInfo,
path: Path,
return_type,
) -> Any:
try:
completed = self.complete_value(
return_type, field_nodes, info, path, await result
)
if self.is_awaitable(completed):
return await completed
return completed
except Exception as raw_error:
error = located_error(raw_error, field_nodes, path.as_list())
self.handle_field_error(error, return_type)
return None

async def await_completed(
self, completed: Any, field_nodes: List[FieldNode], path: Path, return_type
) -> Any:
try:
return await completed
except Exception as raw_error:
error = located_error(raw_error, field_nodes, path.as_list())
self.handle_field_error(error, return_type)
return None


def execute(
schema: GraphQLSchema,
Expand Down