Skip to content

Commit

Permalink
Split parsing, validation and execution (#43) (#53)
Browse files Browse the repository at this point in the history
Instead of graphql()/graphql_sync() we now call execute() directly.

This also allows adding custom validation rules and limiting the number
of reported errors.
  • Loading branch information
Cito authored Jul 11, 2020
1 parent db23e62 commit 90cfb09
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 34 deletions.
82 changes: 48 additions & 34 deletions graphql_server/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,16 @@
import json
from collections import namedtuple
from collections.abc import MutableMapping
from typing import Any, Callable, Dict, List, Optional, Type, Union
from typing import Any, Callable, Collection, Dict, List, Optional, Type, Union

from graphql import ExecutionResult, GraphQLError, GraphQLSchema, OperationType
from graphql import format_error as format_error_default
from graphql import get_operation_ast, parse
from graphql.graphql import graphql, graphql_sync
from graphql.error import GraphQLError
from graphql.error import format_error as format_error_default
from graphql.execution import ExecutionResult, execute
from graphql.language import OperationType, parse
from graphql.pyutils import AwaitableOrValue
from graphql.type import GraphQLSchema, validate_schema
from graphql.utilities import get_operation_ast
from graphql.validation import ASTValidationRule, validate

from .error import HttpQueryError
from .version import version, version_info
Expand Down Expand Up @@ -223,36 +226,48 @@ def load_json_variables(variables: Optional[Union[str, Dict]]) -> Optional[Dict]
return variables # type: ignore


def assume_not_awaitable(_value: Any) -> bool:
"""Replacement for isawaitable if everything is assumed to be synchronous."""
return False


def get_response(
schema: GraphQLSchema,
params: GraphQLParams,
catch_exc: Type[BaseException],
allow_only_query: bool = False,
run_sync: bool = True,
validation_rules: Optional[Collection[Type[ASTValidationRule]]] = None,
max_errors: Optional[int] = None,
**kwargs,
) -> Optional[AwaitableOrValue[ExecutionResult]]:
"""Get an individual execution result as response, with option to catch errors.
This does the same as graphql_impl() except that you can either
throw an error on the ExecutionResult if allow_only_query is set to True
or catch errors that belong to an exception class that you need to pass
as a parameter.
This will validate the schema (if the schema is used for the first time),
parse the query, check if this is a query if allow_only_query is set to True,
validate the query (optionally with additional validation rules and limiting
the number of errors), execute the request (asynchronously if run_sync is not
set to True), and return the ExecutionResult. You can also catch all errors that
belong to an exception class specified by catch_exc.
"""

# noinspection PyBroadException
try:
if not params.query:
raise HttpQueryError(400, "Must provide query string.")

schema_validation_errors = validate_schema(schema)
if schema_validation_errors:
return ExecutionResult(data=None, errors=schema_validation_errors)

try:
document = parse(params.query)
except GraphQLError as e:
return ExecutionResult(data=None, errors=[e])
except Exception as e:
e = GraphQLError(str(e), original_error=e)
return ExecutionResult(data=None, errors=[e])

if allow_only_query:
# Parse document to check that only query operations are used
try:
document = parse(params.query)
except GraphQLError as e:
return ExecutionResult(data=None, errors=[e])
except Exception as e:
e = GraphQLError(str(e), original_error=e)
return ExecutionResult(data=None, errors=[e])
operation_ast = get_operation_ast(document, params.operation_name)
if operation_ast:
operation = operation_ast.operation.value
Expand All @@ -264,22 +279,21 @@ def get_response(
headers={"Allow": "POST"},
)

if run_sync:
execution_result = graphql_sync(
schema=schema,
source=params.query,
variable_values=params.variables,
operation_name=params.operation_name,
**kwargs,
)
else:
execution_result = graphql( # type: ignore
schema=schema,
source=params.query,
variable_values=params.variables,
operation_name=params.operation_name,
**kwargs,
)
validation_errors = validate(
schema, document, rules=validation_rules, max_errors=max_errors
)
if validation_errors:
return ExecutionResult(data=None, errors=validation_errors)

execution_result = execute(
schema,
document,
variable_values=params.variables,
operation_name=params.operation_name,
is_awaitable=assume_not_awaitable if run_sync else None,
**kwargs,
)

except catch_exc:
return None

Expand Down
63 changes: 63 additions & 0 deletions tests/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from graphql.error import GraphQLError
from graphql.execution import ExecutionResult
from graphql.validation import ValidationRule
from pytest import raises

from graphql_server import (
Expand Down Expand Up @@ -123,6 +124,68 @@ def test_reports_validation_errors():
assert response.status_code == 400


def test_reports_custom_validation_errors():
class CustomValidationRule(ValidationRule):
def enter_field(self, node, *_args):
self.report_error(GraphQLError("Custom validation error.", node))

results, params = run_http_query(
schema,
"get",
{},
query_data=dict(query="{ test }"),
validation_rules=[CustomValidationRule],
)

assert as_dicts(results) == [
{
"data": None,
"errors": [
{
"message": "Custom validation error.",
"locations": [{"line": 1, "column": 3}],
"path": None,
}
],
}
]

response = encode_execution_results(results)
assert response.status_code == 400


def test_reports_max_num_of_validation_errors():
results, params = run_http_query(
schema,
"get",
{},
query_data=dict(query="{ test, unknownOne, unknownTwo }"),
max_errors=1,
)

assert as_dicts(results) == [
{
"data": None,
"errors": [
{
"message": "Cannot query field 'unknownOne' on type 'QueryRoot'.",
"locations": [{"line": 1, "column": 9}],
"path": None,
},
{
"message": "Too many validation errors, error limit reached."
" Validation aborted.",
"locations": None,
"path": None,
},
],
}
]

response = encode_execution_results(results)
assert response.status_code == 400


def test_non_dict_params_in_non_batch_query():
with raises(HttpQueryError) as exc_info:
# noinspection PyTypeChecker
Expand Down

0 comments on commit 90cfb09

Please sign in to comment.