diff --git a/graphql/__init__.py b/graphql/__init__.py index 08f25683..bf7e17b0 100644 --- a/graphql/__init__.py +++ b/graphql/__init__.py @@ -193,6 +193,16 @@ Undefined, ) +# Utilities for dynamic execution engines +from .backend import ( + GraphQLBackend, + GraphQLDocument, + GraphQLCoreBackend, + GraphQLDeciderBackend, + GraphQLCachedBackend, + get_default_backend, + set_default_backend, +) VERSION = (2, 0, 1, 'final', 0) __version__ = get_version(VERSION) @@ -282,4 +292,11 @@ 'value_from_ast', 'get_version', 'Undefined', + 'GraphQLBackend', + 'GraphQLDocument', + 'GraphQLCoreBackend', + 'GraphQLDeciderBackend', + 'GraphQLCachedBackend', + 'get_default_backend', + 'set_default_backend', ) diff --git a/graphql/backend/__init__.py b/graphql/backend/__init__.py new file mode 100644 index 00000000..bda60818 --- /dev/null +++ b/graphql/backend/__init__.py @@ -0,0 +1,38 @@ +# -*- coding: utf-8 -*- +""" +This module provides a dynamic way of using different +engines for a GraphQL schema query resolution. +""" + +from .base import GraphQLBackend, GraphQLDocument +from .core import GraphQLCoreBackend +from .decider import GraphQLDeciderBackend +from .cache import GraphQLCachedBackend + +_default_backend = None + + +def get_default_backend(): + global _default_backend + if _default_backend is None: + _default_backend = GraphQLCoreBackend() + return _default_backend + + +def set_default_backend(backend): + global _default_backend + assert isinstance( + backend, GraphQLBackend + ), "backend must be an instance of GraphQLBackend." + _default_backend = backend + + +__all__ = [ + "GraphQLBackend", + "GraphQLDocument", + "GraphQLCoreBackend", + "GraphQLDeciderBackend", + "GraphQLCachedBackend", + "get_default_backend", + "set_default_backend", +] diff --git a/graphql/backend/base.py b/graphql/backend/base.py new file mode 100644 index 00000000..d178abf5 --- /dev/null +++ b/graphql/backend/base.py @@ -0,0 +1,31 @@ +from ..language import ast +from abc import ABCMeta, abstractmethod +import six + + +class GraphQLBackend(six.with_metaclass(ABCMeta)): + @abstractmethod + def document_from_string(self, schema, request_string): + raise NotImplementedError( + "document_from_string method not implemented in {}.".format(self.__class__) + ) + + +class GraphQLDocument(object): + def __init__(self, schema, document_string, document_ast, execute): + self.schema = schema + self.document_string = document_string + self.document_ast = document_ast + self.execute = execute + + def get_operations(self): + document_ast = self.document_ast + operations = {} + for definition in document_ast.definitions: + if isinstance(definition, ast.OperationDefinition): + if definition.name: + operation_name = definition.name.value + else: + operation_name = None + operations[operation_name] = definition.operation + return operations diff --git a/graphql/backend/cache.py b/graphql/backend/cache.py new file mode 100644 index 00000000..aa1631fb --- /dev/null +++ b/graphql/backend/cache.py @@ -0,0 +1,61 @@ +from hashlib import sha1 +from six import string_types +from ..type import GraphQLSchema + +from .base import GraphQLBackend + +_cached_schemas = {} + +_cached_queries = {} + + +def get_unique_schema_id(schema): + """Get a unique id given a GraphQLSchema""" + assert isinstance(schema, GraphQLSchema), ( + "Must receive a GraphQLSchema as schema. Received {}" + ).format(repr(schema)) + + if schema not in _cached_schemas: + _cached_schemas[schema] = sha1(str(schema).encode("utf-8")).hexdigest() + return _cached_schemas[schema] + + +def get_unique_document_id(query_str): + """Get a unique id given a query_string""" + assert isinstance(query_str, string_types), ( + "Must receive a string as query_str. Received {}" + ).format(repr(query_str)) + + if query_str not in _cached_queries: + _cached_queries[query_str] = sha1(str(query_str).encode("utf-8")).hexdigest() + return _cached_queries[query_str] + + +class GraphQLCachedBackend(GraphQLBackend): + def __init__(self, backend, cache_map=None, use_consistent_hash=False): + assert isinstance( + backend, GraphQLBackend + ), "Provided backend must be an instance of GraphQLBackend" + if cache_map is None: + cache_map = {} + self.backend = backend + self.cache_map = cache_map + self.use_consistent_hash = use_consistent_hash + + def get_key_for_schema_and_document_string(self, schema, request_string): + """This method returns a unique key given a schema and a request_string""" + if self.use_consistent_hash: + schema_id = get_unique_schema_id(schema) + document_id = get_unique_document_id(request_string) + return (schema_id, document_id) + return hash((schema, request_string)) + + def document_from_string(self, schema, request_string): + """This method returns a GraphQLQuery (from cache if present)""" + key = self.get_key_for_schema_and_document_string(schema, request_string) + if key not in self.cache_map: + self.cache_map[key] = self.backend.document_from_string( + schema, request_string + ) + + return self.cache_map[key] diff --git a/graphql/backend/compiled.py b/graphql/backend/compiled.py new file mode 100644 index 00000000..9f579252 --- /dev/null +++ b/graphql/backend/compiled.py @@ -0,0 +1,37 @@ +from .base import GraphQLDocument + + +class GraphQLCompiledDocument(GraphQLDocument): + @classmethod + def from_code(cls, schema, code, uptodate=None, extra_namespace=None): + """Creates a GraphQLDocument object from compiled code and the globals. This + is used by the loaders and schema to create a document object. + """ + namespace = {"__file__": code.co_filename} + exec(code, namespace) + if extra_namespace: + namespace.update(extra_namespace) + rv = cls._from_namespace(schema, namespace) + rv._uptodate = uptodate + return rv + + @classmethod + def from_module_dict(cls, schema, module_dict): + """Creates a template object from a module. This is used by the + module loader to create a document object. + """ + return cls._from_namespace(schema, module_dict) + + @classmethod + def _from_namespace(cls, schema, namespace): + document_string = namespace.get("document_string", "") + document_ast = namespace.get("document_ast") + execute = namespace["execute"] + + namespace["schema"] = schema + return cls( + schema=schema, + document_string=document_string, + document_ast=document_ast, + execute=execute, + ) diff --git a/graphql/backend/core.py b/graphql/backend/core.py new file mode 100644 index 00000000..7cbb9d3e --- /dev/null +++ b/graphql/backend/core.py @@ -0,0 +1,42 @@ +from functools import partial +from six import string_types + +from ..execution import execute, ExecutionResult +from ..language.base import parse, print_ast +from ..language import ast +from ..validation import validate + +from .base import GraphQLBackend, GraphQLDocument + + +def execute_and_validate(schema, document_ast, *args, **kwargs): + do_validation = kwargs.get('validate', True) + if do_validation: + validation_errors = validate(schema, document_ast) + if validation_errors: + return ExecutionResult( + errors=validation_errors, + invalid=True, + ) + + return execute(schema, document_ast, *args, **kwargs) + + +class GraphQLCoreBackend(GraphQLBackend): + def __init__(self, executor=None, **kwargs): + super(GraphQLCoreBackend, self).__init__(**kwargs) + self.execute_params = {"executor": executor} + + def document_from_string(self, schema, document_string): + if isinstance(document_string, ast.Document): + document_ast = document_string + document_string = print_ast(document_ast) + else: + assert isinstance(document_string, string_types), "The query must be a string" + document_ast = parse(document_string) + return GraphQLDocument( + schema=schema, + document_string=document_string, + document_ast=document_ast, + execute=partial(execute_and_validate, schema, document_ast, **self.execute_params), + ) diff --git a/graphql/backend/decider.py b/graphql/backend/decider.py new file mode 100644 index 00000000..e923b0ce --- /dev/null +++ b/graphql/backend/decider.py @@ -0,0 +1,24 @@ +from .base import GraphQLBackend + + +class GraphQLDeciderBackend(GraphQLBackend): + def __init__(self, backends=None): + if not backends: + raise Exception("Need to provide backends to decide into.") + if not isinstance(backends, (list, tuple)): + raise Exception("Provided backends need to be a list or tuple.") + self.backends = backends + super(GraphQLDeciderBackend, self).__init__() + + def document_from_string(self, schema, request_string): + for backend in self.backends: + try: + return backend.document_from_string(schema, request_string) + except Exception: + continue + + raise Exception( + "GraphQLDeciderBackend was not able to retrieve a document. Backends tried: {}".format( + repr(self.backends) + ) + ) diff --git a/graphql/backend/quiver_cloud.py b/graphql/backend/quiver_cloud.py new file mode 100644 index 00000000..ca198dee --- /dev/null +++ b/graphql/backend/quiver_cloud.py @@ -0,0 +1,103 @@ +try: + import requests +except ImportError: + raise ImportError( + "requests package is required for Quiver Cloud backend.\n" + "You can install it using: pip install requests" + ) + +from ..utils.schema_printer import print_schema + +from .base import GraphQLBackend +from .compiled import GraphQLCompiledDocument + +from six import urlparse + +GRAPHQL_QUERY = """ +mutation($schemaDsl: String!, $query: String!) { + generateCode( + schemaDsl: $schemaDsl + query: $query, + language: PYTHON, + pythonOptions: { + asyncFramework: PROMISE + } + ) { + code + compilationTime + errors { + type + } + } +} +""" + + +class GraphQLQuiverCloudBackend(GraphQLBackend): + def __init__(self, dsn, python_options=None, **options): + super(GraphQLQuiverCloudBackend, self).__init__(**options) + try: + url = urlparse(dsn.strip()) + except Exception: + raise Exception("Received wrong url {}".format(dsn)) + + netloc = url.hostname + if url.port: + netloc += ":%s" % url.port + + path_bits = url.path.rsplit("/", 1) + if len(path_bits) > 1: + path = path_bits[0] + else: + path = "" + + self.api_url = "%s://%s%s" % (url.scheme.rsplit("+", 1)[-1], netloc, path) + self.public_key = url.username + self.secret_key = url.password + self.extra_namespace = {} + if python_options is None: + python_options = {} + wait_for_promises = python_options.pop("wait_for_promises", None) + if wait_for_promises: + assert callable(wait_for_promises), "wait_for_promises must be callable." + self.extra_namespace["wait_for_promises"] = wait_for_promises + self.python_options = python_options + + def make_post_request(self, url, auth, json_payload): + """This function executes the request with the provided + json payload and return the json response""" + response = requests.post(url, auth=auth, json=json_payload) + return response.json() + + def generate_source(self, schema, query): + variables = {"schemaDsl": print_schema(schema), "query": query} + + json_response = self.make_post_request( + "{}/graphql".format(self.api_url), + auth=(self.public_key, self.secret_key), + json_payload={"query": GRAPHQL_QUERY, "variables": variables}, + ) + + errors = json_response.get('errors') + if errors: + raise Exception(errors[0].get('message')) + data = json_response.get("data", {}) + code_generation = data.get("generateCode", {}) + code = code_generation.get("code") + if not code: + raise Exception("Cant get the code. Received json from Quiver Cloud") + code = str(code) + return code + + def document_from_string(self, schema, request_string): + source = self.generate_source(schema, request_string) + filename = "" + code = compile(source, filename, "exec") + + def uptodate(): + return True + + document = GraphQLCompiledDocument.from_code( + schema, code, uptodate, self.extra_namespace + ) + return document diff --git a/graphql/backend/tests/__init__.py b/graphql/backend/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/graphql/backend/tests/schema.py b/graphql/backend/tests/schema.py new file mode 100644 index 00000000..a01f0d62 --- /dev/null +++ b/graphql/backend/tests/schema.py @@ -0,0 +1,9 @@ +from graphql.type import (GraphQLField, GraphQLObjectType, + GraphQLSchema, GraphQLString) + + +Query = GraphQLObjectType('Query', lambda: { + 'hello': GraphQLField(GraphQLString, resolver=lambda *_: "World"), +}) + +schema = GraphQLSchema(Query) diff --git a/graphql/backend/tests/test_cache.py b/graphql/backend/tests/test_cache.py new file mode 100644 index 00000000..1aacea02 --- /dev/null +++ b/graphql/backend/tests/test_cache.py @@ -0,0 +1,17 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +"""Tests for `graphql.backend.cache` module.""" + +import pytest + +from ..core import GraphQLCoreBackend +from ..cache import GraphQLCachedBackend +from graphql.execution.executors.sync import SyncExecutor +from .schema import schema + + +def test_backend_is_cached_when_needed(): + cached_backend = GraphQLCachedBackend(GraphQLCoreBackend()) + document1 = cached_backend.document_from_string(schema, "{ hello }") + document2 = cached_backend.document_from_string(schema, "{ hello }") + assert document1 == document2 diff --git a/graphql/backend/tests/test_core.py b/graphql/backend/tests/test_core.py new file mode 100644 index 00000000..43fb8357 --- /dev/null +++ b/graphql/backend/tests/test_core.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +"""Tests for `graphql.backend.core` module.""" + +import pytest +from graphql.execution.executors.sync import SyncExecutor + +from ..base import GraphQLBackend, GraphQLDocument +from ..core import GraphQLCoreBackend +from .schema import schema + + +def test_core_backend(): + """Sample pytest test function with the pytest fixture as an argument.""" + backend = GraphQLCoreBackend() + assert isinstance(backend, GraphQLBackend) + document = backend.document_from_string(schema, "{ hello }") + assert isinstance(document, GraphQLDocument) + result = document.execute() + assert not result.errors + assert result.data == {"hello": "World"} + + +def test_backend_is_not_cached_by_default(): + """Sample pytest test function with the pytest fixture as an argument.""" + backend = GraphQLCoreBackend() + document1 = backend.document_from_string(schema, "{ hello }") + document2 = backend.document_from_string(schema, "{ hello }") + assert document1 != document2 + + +class BaseExecutor(SyncExecutor): + executed = False + + def execute(self, *args, **kwargs): + self.executed = True + return super(BaseExecutor, self).execute(*args, **kwargs) + + +def test_backend_can_execute_custom_executor(): + executor = BaseExecutor() + backend = GraphQLCoreBackend(executor=executor) + document1 = backend.document_from_string(schema, "{ hello }") + result = document1.execute() + assert not result.errors + assert result.data == {"hello": "World"} + assert executor.executed diff --git a/graphql/backend/tests/test_decider.py b/graphql/backend/tests/test_decider.py new file mode 100644 index 00000000..7ec38a4b --- /dev/null +++ b/graphql/backend/tests/test_decider.py @@ -0,0 +1,77 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +"""Tests for `graphql.backend.decider` module.""" + +import pytest + +from ..base import GraphQLBackend, GraphQLDocument +from ..core import GraphQLCoreBackend +from ..cache import GraphQLCachedBackend +from ..decider import GraphQLDeciderBackend + +from .schema import schema + + +class FakeBackend(GraphQLBackend): + reached = False + + def __init__(self, raises=False): + self.raises = raises + + def document_from_string(self, *args, **kwargs): + self.reached = True + if self.raises: + raise Exception("Backend failed") + + def reset(self): + self.reached = False + + +def test_decider_backend_healthy_backend(): + backend1 = FakeBackend() + backend2 = FakeBackend() + decider_backend = GraphQLDeciderBackend([backend1, backend2]) + + decider_backend.document_from_string(schema, "{ hello }") + assert backend1.reached + assert not backend2.reached + + +def test_decider_backend_unhealthy_backend(): + backend1 = FakeBackend(raises=True) + backend2 = FakeBackend() + decider_backend = GraphQLDeciderBackend([backend1, backend2]) + + decider_backend.document_from_string(schema, "{ hello }") + assert backend1.reached + assert backend2.reached + + +def test_decider_backend_dont_use_cache(): + backend1 = FakeBackend() + backend2 = FakeBackend() + decider_backend = GraphQLDeciderBackend([backend1, backend2]) + + decider_backend.document_from_string(schema, "{ hello }") + assert backend1.reached + assert not backend2.reached + + backend1.reset() + decider_backend.document_from_string(schema, "{ hello }") + assert backend1.reached + + +def test_decider_backend_use_cache_if_provided(): + backend1 = FakeBackend() + backend2 = FakeBackend() + decider_backend = GraphQLDeciderBackend( + [GraphQLCachedBackend(backend1), GraphQLCachedBackend(backend2)] + ) + + decider_backend.document_from_string(schema, "{ hello }") + assert backend1.reached + assert not backend2.reached + + backend1.reset() + decider_backend.document_from_string(schema, "{ hello }") + assert not backend1.reached diff --git a/graphql/execution/base.py b/graphql/execution/base.py index 21d7378c..266771ac 100644 --- a/graphql/execution/base.py +++ b/graphql/execution/base.py @@ -1,127 +1,15 @@ -# -*- coding: utf-8 -*- -import logging -from traceback import format_exception - -from ..error import GraphQLError -from ..language import ast -from ..pyutils.default_ordered_dict import DefaultOrderedDict -from ..type.definition import GraphQLInterfaceType, GraphQLUnionType -from ..type.directives import GraphQLIncludeDirective, GraphQLSkipDirective -from ..type.introspection import (SchemaMetaFieldDef, TypeMetaFieldDef, - TypeNameMetaFieldDef) -from ..utils.type_from_ast import type_from_ast -from .values import get_argument_values, get_variable_values - -logger = logging.getLogger(__name__) - - -class ExecutionContext(object): - """Data that must be available at all points during query execution. - - Namely, schema of the type system that is currently executing, - and the fragments defined in the query document""" - - __slots__ = 'schema', 'fragments', 'root_value', 'operation', 'variable_values', 'errors', 'context_value', \ - 'argument_values_cache', 'executor', 'middleware', 'allow_subscriptions', '_subfields_cache' - - def __init__(self, schema, document_ast, root_value, context_value, variable_values, operation_name, executor, middleware, allow_subscriptions): - """Constructs a ExecutionContext object from the arguments passed - to execute, which we will pass throughout the other execution - methods.""" - errors = [] - operation = None - fragments = {} - - for definition in document_ast.definitions: - if isinstance(definition, ast.OperationDefinition): - if not operation_name and operation: - raise GraphQLError( - 'Must provide operation name if query contains multiple operations.') - - if not operation_name or definition.name and definition.name.value == operation_name: - operation = definition - - elif isinstance(definition, ast.FragmentDefinition): - fragments[definition.name.value] = definition - - else: - raise GraphQLError( - u'GraphQL cannot execute a request containing a {}.'.format( - definition.__class__.__name__), - definition - ) - - if not operation: - if operation_name: - raise GraphQLError( - u'Unknown operation named "{}".'.format(operation_name)) - - else: - raise GraphQLError('Must provide an operation.') - - variable_values = get_variable_values( - schema, operation.variable_definitions or [], variable_values) - - self.schema = schema - self.fragments = fragments - self.root_value = root_value - self.operation = operation - self.variable_values = variable_values - self.errors = errors - self.context_value = context_value - self.argument_values_cache = {} - self.executor = executor - self.middleware = middleware - self.allow_subscriptions = allow_subscriptions - self._subfields_cache = {} - - def get_field_resolver(self, field_resolver): - if not self.middleware: - return field_resolver - return self.middleware.get_field_resolver(field_resolver) - - def get_argument_values(self, field_def, field_ast): - k = field_def, field_ast - result = self.argument_values_cache.get(k) - if not result: - result = self.argument_values_cache[k] = get_argument_values(field_def.args, field_ast.arguments, - self.variable_values) - - return result - - def report_error(self, error, traceback=None): - exception = format_exception(type(error), error, getattr(error, 'stack', None) or traceback) - logger.error(''.join(exception)) - self.errors.append(error) - - def get_sub_fields(self, return_type, field_asts): - k = return_type, tuple(field_asts) - if k not in self._subfields_cache: - subfield_asts = DefaultOrderedDict(list) - visited_fragment_names = set() - for field_ast in field_asts: - selection_set = field_ast.selection_set - if selection_set: - subfield_asts = collect_fields( - self, return_type, selection_set, - subfield_asts, visited_fragment_names - ) - self._subfields_cache[k] = subfield_asts - return self._subfields_cache[k] - - -class SubscriberExecutionContext(object): - __slots__ = 'exe_context', 'errors' - - def __init__(self, exe_context): - self.exe_context = exe_context - self.errors = [] - - def reset(self): - self.errors = [] - - def __getattr__(self, name): - return getattr(self.exe_context, name) +# We keep the following imports to preserve compatibility +from .utils import ( + ExecutionContext, + SubscriberExecutionContext, + get_operation_root_type, + collect_fields, + should_include_node, + does_fragment_condition_match, + get_field_entry_key, + default_resolve_fn, + get_field_def +) class ExecutionResult(object): @@ -152,156 +40,12 @@ def __eq__(self, other): ) -def get_operation_root_type(schema, operation): - op = operation.operation - if op == 'query': - return schema.get_query_type() - - elif op == 'mutation': - mutation_type = schema.get_mutation_type() - - if not mutation_type: - raise GraphQLError( - 'Schema is not configured for mutations', - [operation] - ) - - return mutation_type - - elif op == 'subscription': - subscription_type = schema.get_subscription_type() - - if not subscription_type: - raise GraphQLError( - 'Schema is not configured for subscriptions', - [operation] - ) - - return subscription_type - - raise GraphQLError( - 'Can only execute queries, mutations and subscriptions', - [operation] - ) - - -def collect_fields(ctx, runtime_type, selection_set, fields, prev_fragment_names): - """ - Given a selectionSet, adds all of the fields in that selection to - the passed in map of fields, and returns it at the end. - - collect_fields requires the "runtime type" of an object. For a field which - returns and Interface or Union type, the "runtime type" will be the actual - Object type returned by that field. - """ - for selection in selection_set.selections: - directives = selection.directives - - if isinstance(selection, ast.Field): - if not should_include_node(ctx, directives): - continue - - name = get_field_entry_key(selection) - fields[name].append(selection) - - elif isinstance(selection, ast.InlineFragment): - if not should_include_node( - ctx, directives) or not does_fragment_condition_match( - ctx, selection, runtime_type): - continue - - collect_fields(ctx, runtime_type, - selection.selection_set, fields, prev_fragment_names) - - elif isinstance(selection, ast.FragmentSpread): - frag_name = selection.name.value - - if frag_name in prev_fragment_names or not should_include_node(ctx, directives): - continue - - prev_fragment_names.add(frag_name) - fragment = ctx.fragments.get(frag_name) - frag_directives = fragment.directives - if not fragment or not \ - should_include_node(ctx, frag_directives) or not \ - does_fragment_condition_match(ctx, fragment, runtime_type): - continue - - collect_fields(ctx, runtime_type, - fragment.selection_set, fields, prev_fragment_names) - - return fields - - -def should_include_node(ctx, directives): - """Determines if a field should be included based on the @include and - @skip directives, where @skip has higher precidence than @include.""" - # TODO: Refactor based on latest code - if directives: - skip_ast = None - - for directive in directives: - if directive.name.value == GraphQLSkipDirective.name: - skip_ast = directive - break - - if skip_ast: - args = get_argument_values( - GraphQLSkipDirective.args, - skip_ast.arguments, - ctx.variable_values, - ) - if args.get('if') is True: - return False - - include_ast = None - - for directive in directives: - if directive.name.value == GraphQLIncludeDirective.name: - include_ast = directive - break - - if include_ast: - args = get_argument_values( - GraphQLIncludeDirective.args, - include_ast.arguments, - ctx.variable_values, - ) - - if args.get('if') is False: - return False - - return True - - -def does_fragment_condition_match(ctx, fragment, type_): - type_condition_ast = fragment.type_condition - if not type_condition_ast: - return True - - conditional_type = type_from_ast(ctx.schema, type_condition_ast) - if conditional_type.is_same_type(type_): - return True - - if isinstance(conditional_type, (GraphQLInterfaceType, GraphQLUnionType)): - return ctx.schema.is_possible_type(conditional_type, type_) - - return False - - -def get_field_entry_key(node): - """Implements the logic to compute the key of a given field's entry""" - if node.alias: - return node.alias.value - return node.name.value - - class ResolveInfo(object): __slots__ = ('field_name', 'field_asts', 'return_type', 'parent_type', 'schema', 'fragments', 'root_value', 'operation', 'variable_values', 'context', 'path') def __init__(self, field_name, field_asts, return_type, parent_type, - schema, fragments, root_value, operation, variable_values, context, path): + schema, fragments, root_value, operation, variable_values, context, path=None): self.field_name = field_name self.field_asts = field_asts self.return_type = return_type @@ -315,28 +59,16 @@ def __init__(self, field_name, field_asts, return_type, parent_type, self.path = path -def default_resolve_fn(source, info, **args): - """If a resolve function is not given, then a default resolve behavior is used which takes the property of the source object - of the same name as the field and returns it as the result, or if it's a function, returns the result of calling that function.""" - name = info.field_name - property = getattr(source, name, None) - if callable(property): - return property() - return property - - -def get_field_def(schema, parent_type, field_name): - """This method looks up the field on the given type defintion. - It has special casing for the two introspection fields, __schema - and __typename. __typename is special because it can always be - queried as a field, even in situations where no other fields - are allowed, like on a Union. __schema could get automatically - added to the query type, but that would require mutating type - definitions, which would cause issues.""" - if field_name == '__schema' and schema.get_query_type() == parent_type: - return SchemaMetaFieldDef - elif field_name == '__type' and schema.get_query_type() == parent_type: - return TypeMetaFieldDef - elif field_name == '__typename': - return TypeNameMetaFieldDef - return parent_type.fields.get(field_name) +__all__ = [ + 'ExecutionResult', + 'ResolveInfo', + 'ExecutionContext', + 'SubscriberExecutionContext', + 'get_operation_root_type', + 'collect_fields', + 'should_include_node', + 'does_fragment_condition_match', + 'get_field_entry_key', + 'default_resolve_fn', + 'get_field_def', +] diff --git a/graphql/execution/executor.py b/graphql/execution/executor.py index 31428ee6..82a23b87 100644 --- a/graphql/execution/executor.py +++ b/graphql/execution/executor.py @@ -2,6 +2,7 @@ import functools import logging import sys +import warnings from rx import Observable from six import string_types @@ -28,9 +29,31 @@ def subscribe(*args, **kwargs): return execute(*args, allow_subscriptions=allow_subscriptions, **kwargs) -def execute(schema, document_ast, root_value=None, context_value=None, - variable_values=None, operation_name=None, executor=None, - return_promise=False, middleware=None, allow_subscriptions=False): +def execute(schema, document_ast, root=None, context=None, + variables=None, operation_name=None, executor=None, + return_promise=False, middleware=None, allow_subscriptions=False, **options): + + if root is None and 'root_value' in options: + warnings.warn( + 'root_value has been deprecated. Please use root=... instead.', + category=DeprecationWarning, + stacklevel=2 + ) + root = options['root_value'] + if context is None and 'context_value' in options: + warnings.warn( + 'context_value has been deprecated. Please use context=... instead.', + category=DeprecationWarning, + stacklevel=2 + ) + context = options['context_value'] + if variables is None and 'variable_values' in options: + warnings.warn( + 'variable_values has been deprecated. Please use values=... instead.', + category=DeprecationWarning, + stacklevel=2 + ) + variables = options['variable_values'] assert schema, 'Must provide schema' assert isinstance(schema, GraphQLSchema), ( 'Schema must be an instance of GraphQLSchema. Also ensure that there are ' + @@ -49,12 +72,12 @@ def execute(schema, document_ast, root_value=None, context_value=None, if executor is None: executor = SyncExecutor() - context = ExecutionContext( + exe_context = ExecutionContext( schema, document_ast, - root_value, - context_value, - variable_values, + root, + context, + variables or {}, operation_name, executor, middleware, @@ -62,25 +85,25 @@ def execute(schema, document_ast, root_value=None, context_value=None, ) def executor(v): - return execute_operation(context, context.operation, root_value) + return execute_operation(exe_context, exe_context.operation, root) def on_rejected(error): - context.errors.append(error) + exe_context.errors.append(error) return None def on_resolve(data): if isinstance(data, Observable): return data - if not context.errors: + if not exe_context.errors: return ExecutionResult(data=data) - return ExecutionResult(data=data, errors=context.errors) + return ExecutionResult(data=data, errors=exe_context.errors) promise = Promise.resolve(None).then(executor).catch(on_rejected).then(on_resolve) if not return_promise: - context.executor.wait_until_finished() + exe_context.executor.wait_until_finished() return promise.get() return promise diff --git a/graphql/execution/utils.py b/graphql/execution/utils.py new file mode 100644 index 00000000..a9da6bac --- /dev/null +++ b/graphql/execution/utils.py @@ -0,0 +1,295 @@ +# -*- coding: utf-8 -*- +import logging +from traceback import format_exception + +from ..error import GraphQLError +from ..language import ast +from ..pyutils.default_ordered_dict import DefaultOrderedDict +from ..type.definition import GraphQLInterfaceType, GraphQLUnionType +from ..type.directives import GraphQLIncludeDirective, GraphQLSkipDirective +from ..type.introspection import (SchemaMetaFieldDef, TypeMetaFieldDef, + TypeNameMetaFieldDef) +from ..utils.type_from_ast import type_from_ast +from .values import get_argument_values, get_variable_values + +logger = logging.getLogger(__name__) + + +class ExecutionContext(object): + """Data that must be available at all points during query execution. + + Namely, schema of the type system that is currently executing, + and the fragments defined in the query document""" + + __slots__ = 'schema', 'fragments', 'root_value', 'operation', 'variable_values', 'errors', 'context_value', \ + 'argument_values_cache', 'executor', 'middleware', 'allow_subscriptions', '_subfields_cache' + + def __init__(self, schema, document_ast, root_value, context_value, variable_values, operation_name, executor, middleware, allow_subscriptions): + """Constructs a ExecutionContext object from the arguments passed + to execute, which we will pass throughout the other execution + methods.""" + errors = [] + operation = None + fragments = {} + + for definition in document_ast.definitions: + if isinstance(definition, ast.OperationDefinition): + if not operation_name and operation: + raise GraphQLError( + 'Must provide operation name if query contains multiple operations.') + + if not operation_name or definition.name and definition.name.value == operation_name: + operation = definition + + elif isinstance(definition, ast.FragmentDefinition): + fragments[definition.name.value] = definition + + else: + raise GraphQLError( + u'GraphQL cannot execute a request containing a {}.'.format( + definition.__class__.__name__), + definition + ) + + if not operation: + if operation_name: + raise GraphQLError( + u'Unknown operation named "{}".'.format(operation_name)) + + else: + raise GraphQLError('Must provide an operation.') + + variable_values = get_variable_values( + schema, operation.variable_definitions or [], variable_values) + + self.schema = schema + self.fragments = fragments + self.root_value = root_value + self.operation = operation + self.variable_values = variable_values + self.errors = errors + self.context_value = context_value + self.argument_values_cache = {} + self.executor = executor + self.middleware = middleware + self.allow_subscriptions = allow_subscriptions + self._subfields_cache = {} + + def get_field_resolver(self, field_resolver): + if not self.middleware: + return field_resolver + return self.middleware.get_field_resolver(field_resolver) + + def get_argument_values(self, field_def, field_ast): + k = field_def, field_ast + result = self.argument_values_cache.get(k) + if not result: + result = self.argument_values_cache[k] = get_argument_values(field_def.args, field_ast.arguments, + self.variable_values) + + return result + + def report_error(self, error, traceback=None): + exception = format_exception(type(error), error, getattr(error, 'stack', None) or traceback) + logger.error(''.join(exception)) + self.errors.append(error) + + def get_sub_fields(self, return_type, field_asts): + k = return_type, tuple(field_asts) + if k not in self._subfields_cache: + subfield_asts = DefaultOrderedDict(list) + visited_fragment_names = set() + for field_ast in field_asts: + selection_set = field_ast.selection_set + if selection_set: + subfield_asts = collect_fields( + self, return_type, selection_set, + subfield_asts, visited_fragment_names + ) + self._subfields_cache[k] = subfield_asts + return self._subfields_cache[k] + + +class SubscriberExecutionContext(object): + __slots__ = 'exe_context', 'errors' + + def __init__(self, exe_context): + self.exe_context = exe_context + self.errors = [] + + def reset(self): + self.errors = [] + + def __getattr__(self, name): + return getattr(self.exe_context, name) + + +def get_operation_root_type(schema, operation): + op = operation.operation + if op == 'query': + return schema.get_query_type() + + elif op == 'mutation': + mutation_type = schema.get_mutation_type() + + if not mutation_type: + raise GraphQLError( + 'Schema is not configured for mutations', + [operation] + ) + + return mutation_type + + elif op == 'subscription': + subscription_type = schema.get_subscription_type() + + if not subscription_type: + raise GraphQLError( + 'Schema is not configured for subscriptions', + [operation] + ) + + return subscription_type + + raise GraphQLError( + 'Can only execute queries, mutations and subscriptions', + [operation] + ) + + +def collect_fields(ctx, runtime_type, selection_set, fields, prev_fragment_names): + """ + Given a selectionSet, adds all of the fields in that selection to + the passed in map of fields, and returns it at the end. + + collect_fields requires the "runtime type" of an object. For a field which + returns and Interface or Union type, the "runtime type" will be the actual + Object type returned by that field. + """ + for selection in selection_set.selections: + directives = selection.directives + + if isinstance(selection, ast.Field): + if not should_include_node(ctx, directives): + continue + + name = get_field_entry_key(selection) + fields[name].append(selection) + + elif isinstance(selection, ast.InlineFragment): + if not should_include_node( + ctx, directives) or not does_fragment_condition_match( + ctx, selection, runtime_type): + continue + + collect_fields(ctx, runtime_type, + selection.selection_set, fields, prev_fragment_names) + + elif isinstance(selection, ast.FragmentSpread): + frag_name = selection.name.value + + if frag_name in prev_fragment_names or not should_include_node(ctx, directives): + continue + + prev_fragment_names.add(frag_name) + fragment = ctx.fragments.get(frag_name) + frag_directives = fragment.directives + if not fragment or not \ + should_include_node(ctx, frag_directives) or not \ + does_fragment_condition_match(ctx, fragment, runtime_type): + continue + + collect_fields(ctx, runtime_type, + fragment.selection_set, fields, prev_fragment_names) + + return fields + + +def should_include_node(ctx, directives): + """Determines if a field should be included based on the @include and + @skip directives, where @skip has higher precidence than @include.""" + # TODO: Refactor based on latest code + if directives: + skip_ast = None + + for directive in directives: + if directive.name.value == GraphQLSkipDirective.name: + skip_ast = directive + break + + if skip_ast: + args = get_argument_values( + GraphQLSkipDirective.args, + skip_ast.arguments, + ctx.variable_values, + ) + if args.get('if') is True: + return False + + include_ast = None + + for directive in directives: + if directive.name.value == GraphQLIncludeDirective.name: + include_ast = directive + break + + if include_ast: + args = get_argument_values( + GraphQLIncludeDirective.args, + include_ast.arguments, + ctx.variable_values, + ) + + if args.get('if') is False: + return False + + return True + + +def does_fragment_condition_match(ctx, fragment, type_): + type_condition_ast = fragment.type_condition + if not type_condition_ast: + return True + + conditional_type = type_from_ast(ctx.schema, type_condition_ast) + if conditional_type.is_same_type(type_): + return True + + if isinstance(conditional_type, (GraphQLInterfaceType, GraphQLUnionType)): + return ctx.schema.is_possible_type(conditional_type, type_) + + return False + + +def get_field_entry_key(node): + """Implements the logic to compute the key of a given field's entry""" + if node.alias: + return node.alias.value + return node.name.value + + +def default_resolve_fn(source, info, **args): + """If a resolve function is not given, then a default resolve behavior is used which takes the property of the source object + of the same name as the field and returns it as the result, or if it's a function, returns the result of calling that function.""" + name = info.field_name + property = getattr(source, name, None) + if callable(property): + return property() + return property + + +def get_field_def(schema, parent_type, field_name): + """This method looks up the field on the given type defintion. + It has special casing for the two introspection fields, __schema + and __typename. __typename is special because it can always be + queried as a field, even in situations where no other fields + are allowed, like on a Union. __schema could get automatically + added to the query type, but that would require mutating type + definitions, which would cause issues.""" + if field_name == '__schema' and schema.get_query_type() == parent_type: + return SchemaMetaFieldDef + elif field_name == '__type' and schema.get_query_type() == parent_type: + return TypeMetaFieldDef + elif field_name == '__typename': + return TypeNameMetaFieldDef + return parent_type.fields.get(field_name) diff --git a/graphql/graphql.py b/graphql/graphql.py index 401a75a2..8dc4a906 100644 --- a/graphql/graphql.py +++ b/graphql/graphql.py @@ -1,8 +1,5 @@ -from .execution import ExecutionResult, execute -from .language.ast import Document -from .language.parser import parse -from .language.source import Source -from .validation import validate +from .execution import ExecutionResult +from .backend import get_default_backend from promise import promisify @@ -38,32 +35,21 @@ def graphql(*args, **kwargs): return execute_graphql(*args, **kwargs) -def execute_graphql(schema, request_string='', root_value=None, context_value=None, - variable_values=None, operation_name=None, executor=None, - return_promise=False, middleware=None, allow_subscriptions=False): +def execute_graphql(schema, request_string='', root=None, context=None, + variables=None, operation_name=None, + middleware=None, backend=None, **execute_options): try: - if isinstance(request_string, Document): - ast = request_string - else: - source = Source(request_string, 'GraphQL request') - ast = parse(source) - validation_errors = validate(schema, ast) - if validation_errors: - return ExecutionResult( - errors=validation_errors, - invalid=True, - ) - return execute( - schema, - ast, - root_value, - context_value, + if backend is None: + backend = get_default_backend() + + document = backend.document_from_string(schema, request_string) + return document.execute( + root=root, + context=context, operation_name=operation_name, - variable_values=variable_values or {}, - executor=executor, + variables=variables, middleware=middleware, - return_promise=return_promise, - allow_subscriptions=allow_subscriptions, + **execute_options ) except Exception as e: return ExecutionResult( diff --git a/tests/starwars/test_query.py b/tests/starwars/test_query.py index 7490e633..81a20c99 100644 --- a/tests/starwars/test_query.py +++ b/tests/starwars/test_query.py @@ -355,5 +355,5 @@ def test_parse_error(): assert result.invalid formatted_error = format_error(result.errors[0]) assert formatted_error['locations'] == [{'column': 9, 'line': 2}] - assert 'Syntax Error GraphQL request (2:9) Unexpected Name "qeury"' in formatted_error['message'] + assert 'Syntax Error GraphQL (2:9) Unexpected Name "qeury"' in formatted_error['message'] assert result.data is None