diff --git a/graphql/execution/executor.py b/graphql/execution/executor.py index 1b5e884e..0566d083 100644 --- a/graphql/execution/executor.py +++ b/graphql/execution/executor.py @@ -107,6 +107,8 @@ def execute( if executor is None: executor = SyncExecutor() + # operation_name, document_ast + exe_context = ExecutionContext( schema, document_ast, diff --git a/graphql/execution/utils.py b/graphql/execution/utils.py index b1e7ff25..b3ac2d5d 100644 --- a/graphql/execution/utils.py +++ b/graphql/execution/utils.py @@ -1,12 +1,17 @@ # -*- coding: utf-8 -*- import logging from traceback import format_exception +from copy import deepcopy from ..error import GraphQLError from ..language import ast from ..pyutils.default_ordered_dict import DefaultOrderedDict from ..type.definition import GraphQLInterfaceType, GraphQLUnionType -from ..type.directives import GraphQLIncludeDirective, GraphQLSkipDirective +from ..type.directives import ( + GraphQLIncludeDirective, + GraphQLSkipDirective, + GraphQLRecursionDirective, +) from ..type.introspection import ( SchemaMetaFieldDef, TypeMetaFieldDef, @@ -57,16 +62,16 @@ class ExecutionContext(object): ) def __init__( - self, - schema, # type: GraphQLSchema - document_ast, # type: Document - root_value, # type: Any - context_value, # type: Any - variable_values, # type: Optional[Dict[str, Any]] - operation_name, # type: Optional[str] - executor, # type: Any - middleware, # type: Optional[Any] - allow_subscriptions, # type: bool + self, + schema, # type: GraphQLSchema + document_ast, # type: Document + root_value, # type: Any + context_value, # type: Any + variable_values, # type: Optional[Dict[str, Any]] + operation_name, # type: Optional[str] + executor, # type: Any + middleware, # type: Optional[Any] + allow_subscriptions, # type: bool ): # type: (...) -> None """Constructs a ExecutionContext object from the arguments passed @@ -84,9 +89,9 @@ def __init__( ) if ( - not operation_name - or definition.name - and definition.name.value == operation_name + not operation_name + or definition.name + and definition.name.value == operation_name ): operation = definition @@ -218,11 +223,11 @@ def get_operation_root_type(schema, operation): def collect_fields( - ctx, # type: ExecutionContext - runtime_type, # type: GraphQLObjectType - selection_set, # type: SelectionSet - fields, # type: DefaultOrderedDict - prev_fragment_names, # type: Set[str] + ctx, # type: ExecutionContext + runtime_type, # type: GraphQLObjectType + selection_set, # type: SelectionSet + fields, # type: DefaultOrderedDict + prev_fragment_names, # type: Set[str] ): # type: (...) -> DefaultOrderedDict """ @@ -237,7 +242,8 @@ def collect_fields( directives = selection.directives if isinstance(selection, ast.Field): - if not should_include_node(ctx, directives): + validate = validate_directives(ctx, directives, selection) + if isinstance(validate, bool) and not validate: continue name = get_field_entry_key(selection) @@ -245,7 +251,7 @@ def collect_fields( elif isinstance(selection, ast.InlineFragment): if not should_include_node( - ctx, directives + ctx, directives ) or not does_fragment_condition_match(ctx, selection, runtime_type): continue @@ -257,7 +263,7 @@ def collect_fields( frag_name = selection.name.value if frag_name in prev_fragment_names or not should_include_node( - ctx, directives + ctx, directives ): continue @@ -265,9 +271,9 @@ def collect_fields( fragment = ctx.fragments[frag_name] frag_directives = fragment.directives if ( - not fragment - or not should_include_node(ctx, frag_directives) - or not does_fragment_condition_match(ctx, fragment, runtime_type) + not fragment + or not should_include_node(ctx, frag_directives) + or not does_fragment_condition_match(ctx, fragment, runtime_type) ): continue @@ -316,10 +322,73 @@ def should_include_node(ctx, directives): return True +def validate_directives(ctx, directives, selection): + for directive in directives: + if directive.name.value in (GraphQLSkipDirective.name, GraphQLIncludeDirective.name): + # @skip, @include checking directive + return should_include_node(ctx, directive) + elif directive.name.value == GraphQLRecursionDirective.name: + # @recursive directive check + build_recursive_selection_set(ctx, directive, selection) + + +def relay_node_check(selection, frame): + """ Check it if relay structure is presented + modules { + edges { + node { + uid # place new recursive query here + } + } + } + """ + if frame: + relay_frame = frame.pop(0) + else: + return True + for selection in selection.selection_set.selections: + if selection.name.value == relay_frame: + return relay_node_check(selection, frame) + return False + + +def insert_recursive_selection(selection, depth, frame=[]): + def insert_in_frame(selection, paste_selection, frame=frame): + if frame: + relay_frame = frame.pop(0) + else: + # remove directive + selection.directives = [] + paste_selection.directives = [] + # return inner selection + returnable_selection_set = selection.selection_set + # insert in depth + returnable_selection_set.selections.append(paste_selection) + return paste_selection + for selection in selection.selection_set.selections: + if selection.name.value == relay_frame: + return insert_in_frame(selection, paste_selection, frame) + + # remove_directive(selection) + for counter in range(int(depth)): + copy_selection = deepcopy(selection) + copy_frame = deepcopy(frame) + selection = insert_in_frame(selection, copy_selection, copy_frame) + + +def build_recursive_selection_set(ctx, directive, selection): + depth_size = directive.arguments[0].value.value + is_relay = relay_node_check(selection, ['edges', 'node']) + if is_relay: + insert_recursive_selection(selection, depth_size, ['edges', 'node']) + else: + insert_recursive_selection(selection, depth_size) + + def does_fragment_condition_match( - ctx, # type: ExecutionContext - fragment, # type: Union[FragmentDefinition, InlineFragment] - type_, # type: GraphQLObjectType + ctx, # type: ExecutionContext + fragment, # type: Union[FragmentDefinition, InlineFragment] + type_, # type: GraphQLObjectType ): # type: (...) -> bool type_condition_ast = fragment.type_condition @@ -356,9 +425,9 @@ def default_resolve_fn(source, info, **args): def get_field_def( - schema, # type: GraphQLSchema - parent_type, # type: GraphQLObjectType - field_name, # type: str + schema, # type: GraphQLSchema + parent_type, # type: GraphQLObjectType + field_name, # type: str ): # type: (...) -> Optional[GraphQLField] """This method looks up the field on the given type defintion. diff --git a/graphql/language/lexer.py b/graphql/language/lexer.py index a60bc6e2..1103422a 100644 --- a/graphql/language/lexer.py +++ b/graphql/language/lexer.py @@ -76,6 +76,7 @@ class TokenKind(object): INT = 17 FLOAT = 18 STRING = 19 + ASTERISK = 20 # recursive symbol def get_token_desc(token): @@ -92,7 +93,7 @@ def get_token_kind_desc(kind): TOKEN_DESCRIPTION = { - TokenKind.EOF: "EOF", + TokenKind.EOF: "EOF", # end of file TokenKind.BANG: "!", TokenKind.DOLLAR: "$", TokenKind.PAREN_L: "(", @@ -111,6 +112,7 @@ def get_token_kind_desc(kind): TokenKind.INT: "Int", TokenKind.FLOAT: "Float", TokenKind.STRING: "String", + TokenKind.ASTERISK: "RS", # recursion selection } @@ -118,7 +120,6 @@ def char_code_at(s, pos): # type: (str, int) -> Optional[int] if 0 <= pos < len(s): return ord(s[pos]) - return None @@ -135,6 +136,7 @@ def char_code_at(s, pos): ord("{"): TokenKind.BRACE_L, ord("|"): TokenKind.PIPE, ord("}"): TokenKind.BRACE_R, + ord("*"): TokenKind.ASTERISK, # recursive } @@ -155,14 +157,14 @@ def read_token(source, from_position): This skips over whitespace and comments until it finds the next lexable token, then lexes punctuators immediately or calls the appropriate - helper fucntion for more complicated tokens.""" + helper function for more complicated tokens.""" body = source.body body_length = len(body) position = position_after_whitespace(body, from_position) if position >= body_length: - return Token(TokenKind.EOF, position, position) + return Token(TokenKind.EOF, position, position) # \n send token code = char_code_at(body, position) if code: @@ -173,15 +175,15 @@ def read_token(source, from_position): kind = PUNCT_CODE_TO_KIND.get(code) if kind is not None: - return Token(kind, position, position + 1) + return Token(kind, position, position + 1) # send token of 20 - asterisk - if code == 46: # . + if code == 46: # . if token is point if ( char_code_at(body, position + 1) == char_code_at(body, position + 2) == 46 ): - return Token(TokenKind.SPREAD, position, position + 3) + return Token(TokenKind.SPREAD, position, position + 3) # this definition of fragments elif 65 <= code <= 90 or code == 95 or 97 <= code <= 122: # A-Z, _, a-z diff --git a/graphql/language/parser.py b/graphql/language/parser.py index 8b658e50..3473a7ee 100644 --- a/graphql/language/parser.py +++ b/graphql/language/parser.py @@ -51,6 +51,12 @@ __all__ = ["parse"] +def parse_recursive_body(source,): + # Attrs: + # source: Type[Source] + pass + + def parse(source, **kwargs): # type: (Union[Source, str], **Any) -> Document """Given a GraphQL source, parses it into a Document.""" @@ -241,11 +247,11 @@ def parse_document(parser): start = parser.token.start definitions = [] while True: + # all root types (query, subscription, mutation) definitions.append(parse_definition(parser)) if skip(parser, TokenKind.EOF): break - return ast.Document(definitions=definitions, loc=loc(parser, start)) @@ -255,6 +261,7 @@ def parse_definition(parser): return parse_operation_definition(parser) if peek(parser, TokenKind.NAME): + name = parser.token.value if name in ("query", "mutation", "subscription"): @@ -307,17 +314,15 @@ def parse_operation_definition(parser): ) +OPERATION_NAMES = frozenset(("query", "mutation", "subscription")) + + def parse_operation_type(parser): # type: (Parser) -> str operation_token = expect(parser, TokenKind.NAME) operation = operation_token.value - if operation == "query": - return "query" - elif operation == "mutation": - return "mutation" - elif operation == "subscription": - return "subscription" - + if operation in OPERATION_NAMES: + return operation raise unexpected(parser, operation_token) diff --git a/graphql/type/directives.py b/graphql/type/directives.py index ef7417c4..12f6937c 100644 --- a/graphql/type/directives.py +++ b/graphql/type/directives.py @@ -3,7 +3,7 @@ from ..pyutils.ordereddict import OrderedDict from ..utils.assert_valid_name import assert_valid_name from .definition import GraphQLArgument, GraphQLNonNull, is_input_type -from .scalars import GraphQLBoolean, GraphQLString +from .scalars import GraphQLBoolean, GraphQLString, GraphQLInt class DirectiveLocation(object): @@ -96,6 +96,26 @@ def __init__(self, name, description=None, args=None, locations=None): ], ) + +# Recursive directive (TimurMardanov for neomodel) +GraphQLRecursionDirective = GraphQLDirective( + name='recursive', + description = "Recursion of the selection set, with depth.", + args = { + 'depth': GraphQLArgument( + type=GraphQLNonNull(GraphQLInt), description='Depth of recursion.', + default_value=1, + ) + }, + locations=[ + DirectiveLocation.FIELD, + DirectiveLocation.FIELD_DEFINITION, + DirectiveLocation.FRAGMENT_SPREAD, + DirectiveLocation.INLINE_FRAGMENT, + ], +) +# + """Constant string used for default reason for a deprecation.""" DEFAULT_DEPRECATION_REASON = "No longer supported" @@ -121,4 +141,5 @@ def __init__(self, name, description=None, args=None, locations=None): GraphQLIncludeDirective, GraphQLSkipDirective, GraphQLDeprecatedDirective, + GraphQLRecursionDirective, ] diff --git a/graphql/type/schema.py b/graphql/type/schema.py index 3fe84659..c245adad 100644 --- a/graphql/type/schema.py +++ b/graphql/type/schema.py @@ -1,7 +1,7 @@ from collections import Iterable from .definition import GraphQLObjectType -from .directives import GraphQLDirective, specified_directives +from .directives import GraphQLDirective, specified_directives, GraphQLRecursionDirective from .introspection import IntrospectionSchema from .typemap import GraphQLTypeMap @@ -85,7 +85,6 @@ def __init__( self._subscription = subscription if directives is None: directives = specified_directives - assert all( isinstance(d, GraphQLDirective) for d in directives ), "Schema directives must be List[GraphQLDirective] if provided but got: {}.".format( @@ -123,7 +122,7 @@ def get_type(self, name): def get_directives(self): # type: () -> List[GraphQLDirective] - return self._directives + return self._directives + [GraphQLRecursionDirective, ] def get_directive(self, name): # type: (str) -> Optional[GraphQLDirective]