Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Split parsing, validation and execution (#43) #53

Merged
merged 1 commit into from
Jul 11, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 48 additions & 34 deletions graphql_server/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,16 @@
import json
from collections import namedtuple
from collections.abc import MutableMapping
from typing import Any, Callable, Dict, List, Optional, Type, Union
from typing import Any, Callable, Collection, Dict, List, Optional, Type, Union

from graphql import ExecutionResult, GraphQLError, GraphQLSchema, OperationType
from graphql import format_error as format_error_default
from graphql import get_operation_ast, parse
from graphql.graphql import graphql, graphql_sync
from graphql.error import GraphQLError
from graphql.error import format_error as format_error_default
from graphql.execution import ExecutionResult, execute
from graphql.language import OperationType, parse
from graphql.pyutils import AwaitableOrValue
from graphql.type import GraphQLSchema, validate_schema
from graphql.utilities import get_operation_ast
from graphql.validation import ASTValidationRule, validate

from .error import HttpQueryError
from .version import version, version_info
Expand Down Expand Up @@ -223,36 +226,48 @@ def load_json_variables(variables: Optional[Union[str, Dict]]) -> Optional[Dict]
return variables # type: ignore


def assume_not_awaitable(_value: Any) -> bool:
"""Replacement for isawaitable if everything is assumed to be synchronous."""
return False


def get_response(
schema: GraphQLSchema,
params: GraphQLParams,
catch_exc: Type[BaseException],
allow_only_query: bool = False,
run_sync: bool = True,
validation_rules: Optional[Collection[Type[ASTValidationRule]]] = None,
max_errors: Optional[int] = None,
**kwargs,
) -> Optional[AwaitableOrValue[ExecutionResult]]:
"""Get an individual execution result as response, with option to catch errors.

This does the same as graphql_impl() except that you can either
throw an error on the ExecutionResult if allow_only_query is set to True
or catch errors that belong to an exception class that you need to pass
as a parameter.
This will validate the schema (if the schema is used for the first time),
parse the query, check if this is a query if allow_only_query is set to True,
validate the query (optionally with additional validation rules and limiting
the number of errors), execute the request (asynchronously if run_sync is not
set to True), and return the ExecutionResult. You can also catch all errors that
belong to an exception class specified by catch_exc.
"""

# noinspection PyBroadException
try:
if not params.query:
raise HttpQueryError(400, "Must provide query string.")

schema_validation_errors = validate_schema(schema)
if schema_validation_errors:
return ExecutionResult(data=None, errors=schema_validation_errors)

try:
document = parse(params.query)
except GraphQLError as e:
return ExecutionResult(data=None, errors=[e])
except Exception as e:
e = GraphQLError(str(e), original_error=e)
return ExecutionResult(data=None, errors=[e])

if allow_only_query:
# Parse document to check that only query operations are used
try:
document = parse(params.query)
except GraphQLError as e:
return ExecutionResult(data=None, errors=[e])
except Exception as e:
e = GraphQLError(str(e), original_error=e)
return ExecutionResult(data=None, errors=[e])
operation_ast = get_operation_ast(document, params.operation_name)
if operation_ast:
operation = operation_ast.operation.value
Expand All @@ -264,22 +279,21 @@ def get_response(
headers={"Allow": "POST"},
)

if run_sync:
execution_result = graphql_sync(
schema=schema,
source=params.query,
variable_values=params.variables,
operation_name=params.operation_name,
**kwargs,
)
else:
execution_result = graphql( # type: ignore
schema=schema,
source=params.query,
variable_values=params.variables,
operation_name=params.operation_name,
**kwargs,
)
validation_errors = validate(
schema, document, rules=validation_rules, max_errors=max_errors
)
if validation_errors:
return ExecutionResult(data=None, errors=validation_errors)

execution_result = execute(
schema,
document,
variable_values=params.variables,
operation_name=params.operation_name,
is_awaitable=assume_not_awaitable if run_sync else None,
**kwargs,
)

except catch_exc:
return None

Expand Down
63 changes: 63 additions & 0 deletions tests/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from graphql.error import GraphQLError
from graphql.execution import ExecutionResult
from graphql.validation import ValidationRule
from pytest import raises

from graphql_server import (
Expand Down Expand Up @@ -123,6 +124,68 @@ def test_reports_validation_errors():
assert response.status_code == 400


def test_reports_custom_validation_errors():
class CustomValidationRule(ValidationRule):
def enter_field(self, node, *_args):
self.report_error(GraphQLError("Custom validation error.", node))

results, params = run_http_query(
schema,
"get",
{},
query_data=dict(query="{ test }"),
validation_rules=[CustomValidationRule],
)

assert as_dicts(results) == [
{
"data": None,
"errors": [
{
"message": "Custom validation error.",
"locations": [{"line": 1, "column": 3}],
"path": None,
}
],
}
]

response = encode_execution_results(results)
assert response.status_code == 400


def test_reports_max_num_of_validation_errors():
results, params = run_http_query(
schema,
"get",
{},
query_data=dict(query="{ test, unknownOne, unknownTwo }"),
max_errors=1,
)

assert as_dicts(results) == [
{
"data": None,
"errors": [
{
"message": "Cannot query field 'unknownOne' on type 'QueryRoot'.",
"locations": [{"line": 1, "column": 9}],
"path": None,
},
{
"message": "Too many validation errors, error limit reached."
" Validation aborted.",
"locations": None,
"path": None,
},
],
}
]

response = encode_execution_results(results)
assert response.status_code == 400


def test_non_dict_params_in_non_batch_query():
with raises(HttpQueryError) as exc_info:
# noinspection PyTypeChecker
Expand Down