From 353fbc76669e0b28dbd7171c6c3478944b21b996 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sat, 11 Jul 2020 22:24:31 +0200 Subject: [PATCH] Split parsing, validation and execution (#43) Instead of graphql()/graphql_sync() we now call execute() directly. This also allows adding custom validation rules and limiting the number of reported errors. --- graphql_server/__init__.py | 82 ++++++++++++++++++++++---------------- tests/test_query.py | 63 +++++++++++++++++++++++++++++ 2 files changed, 111 insertions(+), 34 deletions(-) diff --git a/graphql_server/__init__.py b/graphql_server/__init__.py index 99452b1..2148389 100644 --- a/graphql_server/__init__.py +++ b/graphql_server/__init__.py @@ -9,13 +9,16 @@ import json from collections import namedtuple from collections.abc import MutableMapping -from typing import Any, Callable, Dict, List, Optional, Type, Union +from typing import Any, Callable, Collection, Dict, List, Optional, Type, Union -from graphql import ExecutionResult, GraphQLError, GraphQLSchema, OperationType -from graphql import format_error as format_error_default -from graphql import get_operation_ast, parse -from graphql.graphql import graphql, graphql_sync +from graphql.error import GraphQLError +from graphql.error import format_error as format_error_default +from graphql.execution import ExecutionResult, execute +from graphql.language import OperationType, parse from graphql.pyutils import AwaitableOrValue +from graphql.type import GraphQLSchema, validate_schema +from graphql.utilities import get_operation_ast +from graphql.validation import ASTValidationRule, validate from .error import HttpQueryError from .version import version, version_info @@ -223,36 +226,48 @@ def load_json_variables(variables: Optional[Union[str, Dict]]) -> Optional[Dict] return variables # type: ignore +def assume_not_awaitable(_value: Any) -> bool: + """Replacement for isawaitable if everything is assumed to be synchronous.""" + return False + + def get_response( schema: GraphQLSchema, params: GraphQLParams, catch_exc: Type[BaseException], allow_only_query: bool = False, run_sync: bool = True, + validation_rules: Optional[Collection[Type[ASTValidationRule]]] = None, + max_errors: Optional[int] = None, **kwargs, ) -> Optional[AwaitableOrValue[ExecutionResult]]: """Get an individual execution result as response, with option to catch errors. - This does the same as graphql_impl() except that you can either - throw an error on the ExecutionResult if allow_only_query is set to True - or catch errors that belong to an exception class that you need to pass - as a parameter. + This will validate the schema (if the schema is used for the first time), + parse the query, check if this is a query if allow_only_query is set to True, + validate the query (optionally with additional validation rules and limiting + the number of errors), execute the request (asynchronously if run_sync is not + set to True), and return the ExecutionResult. You can also catch all errors that + belong to an exception class specified by catch_exc. """ - # noinspection PyBroadException try: if not params.query: raise HttpQueryError(400, "Must provide query string.") + schema_validation_errors = validate_schema(schema) + if schema_validation_errors: + return ExecutionResult(data=None, errors=schema_validation_errors) + + try: + document = parse(params.query) + except GraphQLError as e: + return ExecutionResult(data=None, errors=[e]) + except Exception as e: + e = GraphQLError(str(e), original_error=e) + return ExecutionResult(data=None, errors=[e]) + if allow_only_query: - # Parse document to check that only query operations are used - try: - document = parse(params.query) - except GraphQLError as e: - return ExecutionResult(data=None, errors=[e]) - except Exception as e: - e = GraphQLError(str(e), original_error=e) - return ExecutionResult(data=None, errors=[e]) operation_ast = get_operation_ast(document, params.operation_name) if operation_ast: operation = operation_ast.operation.value @@ -264,22 +279,21 @@ def get_response( headers={"Allow": "POST"}, ) - if run_sync: - execution_result = graphql_sync( - schema=schema, - source=params.query, - variable_values=params.variables, - operation_name=params.operation_name, - **kwargs, - ) - else: - execution_result = graphql( # type: ignore - schema=schema, - source=params.query, - variable_values=params.variables, - operation_name=params.operation_name, - **kwargs, - ) + validation_errors = validate( + schema, document, rules=validation_rules, max_errors=max_errors + ) + if validation_errors: + return ExecutionResult(data=None, errors=validation_errors) + + execution_result = execute( + schema, + document, + variable_values=params.variables, + operation_name=params.operation_name, + is_awaitable=assume_not_awaitable if run_sync else None, + **kwargs, + ) + except catch_exc: return None diff --git a/tests/test_query.py b/tests/test_query.py index 7f5ab6f..70f49ac 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -2,6 +2,7 @@ from graphql.error import GraphQLError from graphql.execution import ExecutionResult +from graphql.validation import ValidationRule from pytest import raises from graphql_server import ( @@ -123,6 +124,68 @@ def test_reports_validation_errors(): assert response.status_code == 400 +def test_reports_custom_validation_errors(): + class CustomValidationRule(ValidationRule): + def enter_field(self, node, *_args): + self.report_error(GraphQLError("Custom validation error.", node)) + + results, params = run_http_query( + schema, + "get", + {}, + query_data=dict(query="{ test }"), + validation_rules=[CustomValidationRule], + ) + + assert as_dicts(results) == [ + { + "data": None, + "errors": [ + { + "message": "Custom validation error.", + "locations": [{"line": 1, "column": 3}], + "path": None, + } + ], + } + ] + + response = encode_execution_results(results) + assert response.status_code == 400 + + +def test_reports_max_num_of_validation_errors(): + results, params = run_http_query( + schema, + "get", + {}, + query_data=dict(query="{ test, unknownOne, unknownTwo }"), + max_errors=1, + ) + + assert as_dicts(results) == [ + { + "data": None, + "errors": [ + { + "message": "Cannot query field 'unknownOne' on type 'QueryRoot'.", + "locations": [{"line": 1, "column": 9}], + "path": None, + }, + { + "message": "Too many validation errors, error limit reached." + " Validation aborted.", + "locations": None, + "path": None, + }, + ], + } + ] + + response = encode_execution_results(results) + assert response.status_code == 400 + + def test_non_dict_params_in_non_batch_query(): with raises(HttpQueryError) as exc_info: # noinspection PyTypeChecker