diff --git a/src/graphql/execution/execute.py b/src/graphql/execution/execute.py index d6b6573c..74dacbbb 100644 --- a/src/graphql/execution/execute.py +++ b/src/graphql/execution/execute.py @@ -299,11 +299,8 @@ def build_response( defined by the "Response" section of the GraphQL spec. """ if self.is_awaitable(data): - - async def build_response_async() -> ExecutionResult: - return self.build_response(await data) # type: ignore - - return build_response_async() + data = cast(Awaitable, data) + return self.build_response_async(data) data = cast(Optional[Dict[str, Any]], data) errors = self.errors if not errors: @@ -350,14 +347,8 @@ def execute_operation( return None else: if self.is_awaitable(result): - # noinspection PyShadowingNames - async def await_result() -> Any: - try: - return await result # type: ignore - except GraphQLError as error: - self.errors.append(error) - - return await_result() + result = cast(Awaitable, result) + return self.await_result(result) return result def execute_fields_serially( @@ -382,42 +373,17 @@ def execute_fields_serially( if result is Undefined: continue if is_awaitable(results): - # noinspection PyShadowingNames - async def await_and_set_result( - results: Awaitable[Dict[str, Any]], - response_name: str, - result: AwaitableOrValue[Any], - ) -> Dict[str, Any]: - awaited_results = await results - awaited_results[response_name] = ( - await result if is_awaitable(result) else result - ) - return awaited_results - - results = await_and_set_result( - cast(Awaitable, results), response_name, result + results = ExecutionContext.await_and_set_result( + is_awaitable, cast(Awaitable, results), response_name, result ) elif is_awaitable(result): - # noinspection PyShadowingNames - async def set_result( - results: Dict[str, Any], - response_name: str, - result: Awaitable, - ) -> Dict[str, Any]: - results[response_name] = await result - return results - - results = set_result( + results = ExecutionContext.set_result( cast(Dict[str, Any], results), response_name, result ) else: cast(Dict[str, Any], results)[response_name] = result if is_awaitable(results): - # noinspection PyShadowingNames - async def get_results() -> Any: - return await cast(Awaitable, results) - - return get_results() + return ExecutionContext.get_results(cast(Awaitable, results)) return results def execute_fields( @@ -454,16 +420,7 @@ def execute_fields( # field, which is possibly a coroutine object. Return a coroutine object that # will yield this same map, but with any coroutines awaited in parallel and # replaced with the values they yielded. - async def get_results() -> Dict[str, Any]: - results.update( - zip( - awaitable_fields, - await gather(*(results[field] for field in awaitable_fields)), - ) - ) - return results - - return get_results() + return ExecutionContext.get_results_map(awaitable_fields, results) def build_resolve_info( self, @@ -532,36 +489,15 @@ def execute_field( completed: AwaitableOrValue[Any] if self.is_awaitable(result): - # noinspection PyShadowingNames - async def await_result() -> Any: - try: - completed = self.complete_value( - return_type, field_nodes, info, path, await result - ) - if self.is_awaitable(completed): - return await completed - return completed - except Exception as raw_error: - error = located_error(raw_error, field_nodes, path.as_list()) - self.handle_field_error(error, return_type) - return None - - return await_result() + return self.await_field_result( + result, field_nodes, info, path, return_type + ) completed = self.complete_value( return_type, field_nodes, info, path, result ) if self.is_awaitable(completed): - # noinspection PyShadowingNames - async def await_completed() -> Any: - try: - return await completed - except Exception as raw_error: - error = located_error(raw_error, field_nodes, path.as_list()) - self.handle_field_error(error, return_type) - return None - - return await_completed() + return self.await_completed(completed, field_nodes, path, return_type) return completed except Exception as raw_error: @@ -964,6 +900,87 @@ def collect_subfields( cache[key] = sub_field_nodes return sub_field_nodes + # Async methods + async def build_response_async( + self, data: Awaitable[Optional[Dict[str, Any]]] + ) -> ExecutionResult: + return self.build_response(await data) # type: ignore + + async def await_result(self, result: Awaitable[Dict[str, Any]]) -> Optional[Any]: + try: + return await result + except GraphQLError as error: + self.errors.append(error) + return None + + @staticmethod + async def await_and_set_result( + is_awaitable: Callable[[Any], bool], + results: Awaitable[Dict[str, Any]], + response_name: str, + result: AwaitableOrValue[Any], + ) -> Dict[str, Any]: + awaited_results = await results + awaited_results[response_name] = ( + await result if is_awaitable(result) else result + ) + return awaited_results + + @staticmethod + async def set_result( + results: Dict[str, Any], + response_name: str, + result: Awaitable, + ) -> Dict[str, Any]: + results[response_name] = await result + return results + + @staticmethod + async def get_results(results: Awaitable[Dict[str, Any]]) -> Any: + return await results + + @staticmethod + async def get_results_map( + awaitable_fields: List[str], results: Dict[str, Any] + ) -> Dict[str, Any]: + results.update( + zip( + awaitable_fields, + await gather(*(results[field] for field in awaitable_fields)), + ) + ) + return results + + async def await_field_result( + self, + result: Any, + field_nodes: List[FieldNode], + info: GraphQLResolveInfo, + path: Path, + return_type, + ) -> Any: + try: + completed = self.complete_value( + return_type, field_nodes, info, path, await result + ) + if self.is_awaitable(completed): + return await completed + return completed + except Exception as raw_error: + error = located_error(raw_error, field_nodes, path.as_list()) + self.handle_field_error(error, return_type) + return None + + async def await_completed( + self, completed: Any, field_nodes: List[FieldNode], path: Path, return_type + ) -> Any: + try: + return await completed + except Exception as raw_error: + error = located_error(raw_error, field_nodes, path.as_list()) + self.handle_field_error(error, return_type) + return None + def execute( schema: GraphQLSchema,