diff --git a/lib/graphql/internal_representation/rewrite.rb b/lib/graphql/internal_representation/rewrite.rb index 033ffc3a25..5f5c7a30a7 100644 --- a/lib/graphql/internal_representation/rewrite.rb +++ b/lib/graphql/internal_representation/rewrite.rb @@ -13,127 +13,29 @@ module InternalRepresentation # # The rewritten query tree serves as the basis for the `FieldsWillMerge` validation. # - class Rewrite + module Rewrite include GraphQL::Language NO_DIRECTIVES = [].freeze # @return InternalRepresentation::Document - attr_reader :document + attr_reader :rewrite_document - def initialize - @document = InternalRepresentation::Document.new - end - - # @return [Hash] Roots of this query - def operations - warn "#{self.class}#operations is deprecated; use `document.operation_definitions` instead" - document.operation_definitions - end - - def validate(context) - visitor = context.visitor - query = context.query + def initialize(*) + super + @query = context.query + @rewrite_document = InternalRepresentation::Document.new # Hash Set> # A record of fragment spreads and the irep nodes that used them - spread_parents = Hash.new { |h, k| h[k] = Set.new } + @rewrite_spread_parents = Hash.new { |h, k| h[k] = Set.new } # Hash Scope> - spread_scopes = {} + @rewrite_spread_scopes = {} # Array> # The current point of the irep_tree during visitation - nodes_stack = [] + @rewrite_nodes_stack = [] # Array - scopes_stack = [] - - skip_nodes = Set.new - - visit_op = VisitDefinition.new(context, @document.operation_definitions, nodes_stack, scopes_stack) - visitor[Nodes::OperationDefinition].enter << visit_op.method(:enter) - visitor[Nodes::OperationDefinition].leave << visit_op.method(:leave) - - visit_frag = VisitDefinition.new(context, @document.fragment_definitions, nodes_stack, scopes_stack) - visitor[Nodes::FragmentDefinition].enter << visit_frag.method(:enter) - visitor[Nodes::FragmentDefinition].leave << visit_frag.method(:leave) - - visitor[Nodes::InlineFragment].enter << ->(ast_node, ast_parent) { - # Inline fragments provide two things to the rewritten tree: - # - They _may_ narrow the scope by their type condition - # - They _may_ apply their directives to their children - if skip?(ast_node, query) - skip_nodes.add(ast_node) - end - - if skip_nodes.none? - scopes_stack.push(scopes_stack.last.enter(context.type_definition)) - end - } - - visitor[Nodes::InlineFragment].leave << ->(ast_node, ast_parent) { - if skip_nodes.none? - scopes_stack.pop - end - - if skip_nodes.include?(ast_node) - skip_nodes.delete(ast_node) - end - } - - visitor[Nodes::Field].enter << ->(ast_node, ast_parent) { - if skip?(ast_node, query) - skip_nodes.add(ast_node) - end - - if skip_nodes.none? - node_name = ast_node.alias || ast_node.name - parent_nodes = nodes_stack.last - next_nodes = [] - - field_defn = context.field_definition - if field_defn.nil? - # It's a non-existent field - new_scope = nil - else - field_return_type = field_defn.type - scopes_stack.last.each do |scope_type| - parent_nodes.each do |parent_node| - node = parent_node.scoped_children[scope_type][node_name] ||= Node.new( - parent: parent_node, - name: node_name, - owner_type: scope_type, - query: query, - return_type: field_return_type, - ) - node.ast_nodes << ast_node - node.definitions << field_defn - next_nodes << node - end - end - new_scope = Scope.new(query, field_return_type.unwrap) - end - - nodes_stack.push(next_nodes) - scopes_stack.push(new_scope) - end - } - - visitor[Nodes::Field].leave << ->(ast_node, ast_parent) { - if skip_nodes.none? - nodes_stack.pop - scopes_stack.pop - end - - if skip_nodes.include?(ast_node) - skip_nodes.delete(ast_node) - end - } - - visitor[Nodes::FragmentSpread].enter << ->(ast_node, ast_parent) { - if skip_nodes.none? && !skip?(ast_node, query) - # Register the irep nodes that depend on this AST node: - spread_parents[ast_node].merge(nodes_stack.last) - spread_scopes[ast_node] = scopes_stack.last - end - } + @rewrite_scopes_stack = [] + @rewrite_skip_nodes = Set.new # Resolve fragment spreads. # Fragment definitions got their own irep trees during visitation. @@ -142,12 +44,12 @@ def validate(context) # can be shared between its usages. context.on_dependency_resolve do |defn_ast_node, spread_ast_nodes, frag_ast_node| frag_name = frag_ast_node.name - fragment_node = @document.fragment_definitions[frag_name] + fragment_node = @rewrite_document.fragment_definitions[frag_name] if fragment_node spread_ast_nodes.each do |spread_ast_node| - parent_nodes = spread_parents[spread_ast_node] - parent_scope = spread_scopes[spread_ast_node] + parent_nodes = @rewrite_spread_parents[spread_ast_node] + parent_scope = @rewrite_spread_scopes[spread_ast_node] parent_nodes.each do |parent_node| parent_node.deep_merge_node(fragment_node, scope: parent_scope, merge_self: false) end @@ -156,43 +58,126 @@ def validate(context) end end - def skip?(ast_node, query) - dir = ast_node.directives - dir.any? && !GraphQL::Execution::DirectiveChecks.include?(dir, query) + # @return [Hash] Roots of this query + def operations + warn "#{self.class}#operations is deprecated; use `document.operation_definitions` instead" + @document.operation_definitions + end + + def on_operation_definition(ast_node, parent) + push_root_node(ast_node, @rewrite_document.operation_definitions) { super } + end + + def on_fragment_definition(ast_node, parent) + push_root_node(ast_node, @rewrite_document.fragment_definitions) { super } + end + + def push_root_node(ast_node, definitions) + # Either QueryType or the fragment type condition + owner_type = context.type_definition + defn_name = ast_node.name + + node = Node.new( + parent: nil, + name: defn_name, + owner_type: owner_type, + query: @query, + ast_nodes: [ast_node], + return_type: owner_type, + ) + + definitions[defn_name] = node + @rewrite_scopes_stack.push(Scope.new(@query, owner_type)) + @rewrite_nodes_stack.push([node]) + yield + @rewrite_nodes_stack.pop + @rewrite_scopes_stack.pop + end + + def on_inline_fragment(node, parent) + # Inline fragments provide two things to the rewritten tree: + # - They _may_ narrow the scope by their type condition + # - They _may_ apply their directives to their children + if skip?(node) + @rewrite_skip_nodes.add(node) + end + + if @rewrite_skip_nodes.none? + @rewrite_scopes_stack.push(@rewrite_scopes_stack.last.enter(context.type_definition)) + end + + super + + if @rewrite_skip_nodes.none? + @rewrite_scopes_stack.pop + end + + if @rewrite_skip_nodes.include?(node) + @rewrite_skip_nodes.delete(node) + end end - class VisitDefinition - def initialize(context, definitions, nodes_stack, scopes_stack) - @context = context - @query = context.query - @definitions = definitions - @nodes_stack = nodes_stack - @scopes_stack = scopes_stack + def on_field(ast_node, ast_parent) + if skip?(ast_node) + @rewrite_skip_nodes.add(ast_node) + end + + if @rewrite_skip_nodes.none? + node_name = ast_node.alias || ast_node.name + parent_nodes = @rewrite_nodes_stack.last + next_nodes = [] + + field_defn = context.field_definition + if field_defn.nil? + # It's a non-existent field + new_scope = nil + else + field_return_type = field_defn.type + @rewrite_scopes_stack.last.each do |scope_type| + parent_nodes.each do |parent_node| + node = parent_node.scoped_children[scope_type][node_name] ||= Node.new( + parent: parent_node, + name: node_name, + owner_type: scope_type, + query: @query, + return_type: field_return_type, + ) + node.ast_nodes << ast_node + node.definitions << field_defn + next_nodes << node + end + end + new_scope = Scope.new(@query, field_return_type.unwrap) + end + + @rewrite_nodes_stack.push(next_nodes) + @rewrite_scopes_stack.push(new_scope) + end + + super + + if @rewrite_skip_nodes.none? + @rewrite_nodes_stack.pop + @rewrite_scopes_stack.pop end - def enter(ast_node, ast_parent) - # Either QueryType or the fragment type condition - owner_type = @context.type_definition && @context.type_definition.unwrap - defn_name = ast_node.name - - node = Node.new( - parent: nil, - name: defn_name, - owner_type: owner_type, - query: @query, - ast_nodes: [ast_node], - return_type: @context.type_definition, - ) - - @definitions[defn_name] = node - @scopes_stack.push(Scope.new(@query, owner_type)) - @nodes_stack.push([node]) + if @rewrite_skip_nodes.include?(ast_node) + @rewrite_skip_nodes.delete(ast_node) end + end - def leave(ast_node, ast_parent) - @nodes_stack.pop - @scopes_stack.pop + def on_fragment_spread(ast_node, ast_parent) + if @rewrite_skip_nodes.none? && !skip?(ast_node) + # Register the irep nodes that depend on this AST node: + @rewrite_spread_parents[ast_node].merge(@rewrite_nodes_stack.last) + @rewrite_spread_scopes[ast_node] = @rewrite_scopes_stack.last end + super + end + + def skip?(ast_node) + dir = ast_node.directives + dir.any? && !GraphQL::Execution::DirectiveChecks.include?(dir, @query) end end end diff --git a/lib/graphql/language/nodes.rb b/lib/graphql/language/nodes.rb index 63ae2aa679..977e2ef243 100644 --- a/lib/graphql/language/nodes.rb +++ b/lib/graphql/language/nodes.rb @@ -58,6 +58,11 @@ def scalars [] end + # @return [Symbol] the method to call on {Language::Visitor} for this node + def visit_method + raise NotImplementedError, "#{self.class.name}#visit_method shold return a symbol" + end + def position [line, col] end @@ -112,6 +117,10 @@ def scalars def children [value].flatten.select { |v| v.is_a?(AbstractNode) } end + + def visit_method + :on_argument + end end class Directive < AbstractNode @@ -123,6 +132,10 @@ def initialize_node(name: nil, arguments: []) @name = name @arguments = arguments end + + def visit_method + :on_directive + end end class DirectiveDefinition < AbstractNode @@ -139,9 +152,17 @@ def initialize_node(name: nil, arguments: [], locations: [], description: nil) def children arguments + locations end + + def visit_method + :on_directive_definition + end end - class DirectiveLocation < NameOnlyNode; end + class DirectiveLocation < NameOnlyNode + def visit_method + :on_directive_location + end + end # This is the AST root for normal queries # @@ -174,13 +195,25 @@ def initialize_node(definitions: []) def slice_definition(name) GraphQL::Language::DefinitionSlice.slice(self, name) end + + def visit_method + :on_document + end end # An enum value. The string is available as {#name}. - class Enum < NameOnlyNode; end + class Enum < NameOnlyNode + def visit_method + :on_enum + end + end # A null value literal. - class NullValue < NameOnlyNode; end + class NullValue < NameOnlyNode + def visit_method + :on_null_value + end + end # A single selection in a GraphQL query. class Field < AbstractNode @@ -205,6 +238,10 @@ def scalars def children arguments + directives + selections end + + def visit_method + :on_field + end end # A reusable fragment, defined at document-level. @@ -230,6 +267,10 @@ def children def scalars [name, type] end + + def visit_method + :on_fragment_definition + end end # Application of a named fragment in a selection @@ -245,6 +286,10 @@ def initialize_node(name: nil, directives: []) @name = name @directives = directives end + + def visit_method + :on_fragment_spread + end end # An unnamed fragment, defined directly in the query with `... { }` @@ -267,6 +312,10 @@ def children def scalars [type] end + + def visit_method + :on_inline_fragment + end end # A collection of key-value inputs which may be a field argument @@ -290,6 +339,9 @@ def to_h(options={}) end end + def visit_method + :on_input_object + end private def serialize_value_for_hash(value) @@ -312,10 +364,18 @@ def serialize_value_for_hash(value) # A list type definition, denoted with `[...]` (used for variable type definitions) - class ListType < WrapperType; end + class ListType < WrapperType + def visit_method + :on_list_type + end + end # A non-null type definition, denoted with `...!` (used for variable type definitions) - class NonNullType < WrapperType; end + class NonNullType < WrapperType + def visit_method + :on_non_null_type + end + end # A query, mutation or subscription. # May be anonymous or named. @@ -350,10 +410,18 @@ def children def scalars [operation_type, name] end + + def visit_method + :on_operation_definition + end end # A type name, used for variable definitions - class TypeName < NameOnlyNode; end + class TypeName < NameOnlyNode + def visit_method + :on_type_name + end + end # An operation-level query variable class VariableDefinition < AbstractNode @@ -374,13 +442,21 @@ def initialize_node(name: nil, type: nil, default_value: nil) @default_value = default_value end + def visit_method + :on_variable_definition + end + def scalars [name, type, default_value] end end # Usage of a variable in a query. Name does _not_ include `$`. - class VariableIdentifier < NameOnlyNode; end + class VariableIdentifier < NameOnlyNode + def visit_method + :on_variable_identifier + end + end class SchemaDefinition < AbstractNode attr_reader :query, :mutation, :subscription, :directives @@ -396,6 +472,10 @@ def scalars [query, mutation, subscription] end + def visit_method + :on_schema_definition + end + alias :children :directives end @@ -414,6 +494,10 @@ def scalars end alias :children :directives + + def visit_method + :on_schema_extension + end end class ScalarTypeDefinition < AbstractNode @@ -426,6 +510,10 @@ def initialize_node(name:, directives: [], description: nil) @directives = directives @description = description end + + def visit_method + :on_scalar_type_definition + end end class ScalarTypeExtension < AbstractNode @@ -436,6 +524,10 @@ def initialize_node(name:, directives: []) @name = name @directives = directives end + + def visit_method + :on_scalar_type_extension + end end class ObjectTypeDefinition < AbstractNode @@ -453,6 +545,10 @@ def initialize_node(name:, interfaces:, fields:, directives: [], description: ni def children interfaces + fields + directives end + + def visit_method + :on_object_type_definition + end end class ObjectTypeExtension < AbstractNode @@ -468,6 +564,10 @@ def initialize_node(name:, interfaces:, fields:, directives: []) def children interfaces + fields + directives end + + def visit_method + :on_object_type_extension + end end class InputValueDefinition < AbstractNode @@ -485,6 +585,10 @@ def initialize_node(name:, type:, default_value: nil, directives: [], descriptio def scalars [name, type, default_value] end + + def visit_method + :on_input_value_definition + end end class FieldDefinition < AbstractNode @@ -505,6 +609,10 @@ def children def scalars [name, type] end + + def visit_method + :on_field_definition + end end class InterfaceTypeDefinition < AbstractNode @@ -521,6 +629,10 @@ def initialize_node(name:, fields:, directives: [], description: nil) def children fields + directives end + + def visit_method + :on_interface_type_definition + end end class InterfaceTypeExtension < AbstractNode @@ -535,12 +647,16 @@ def initialize_node(name:, fields:, directives: []) def children fields + directives end + + def visit_method + :on_interface_type_extension + end end class UnionTypeDefinition < AbstractNode attr_reader :name, :types, :directives, :description include Scalars::Name - + def initialize_node(name:, types:, directives: [], description: nil) @name = name @types = types @@ -551,6 +667,10 @@ def initialize_node(name:, types:, directives: [], description: nil) def children types + directives end + + def visit_method + :on_union_type_definition + end end class UnionTypeExtension < AbstractNode @@ -565,6 +685,10 @@ def initialize_node(name:, types:, directives: []) def children types + directives end + + def visit_method + :on_union_type_extension + end end class EnumTypeDefinition < AbstractNode @@ -581,6 +705,10 @@ def initialize_node(name:, values:, directives: [], description: nil) def children values + directives end + + def visit_method + :on_enum_type_extension + end end class EnumTypeExtension < AbstractNode @@ -595,6 +723,10 @@ def initialize_node(name:, values:, directives: []) def children values + directives end + + def visit_method + :on_enum_type_extension + end end class EnumValueDefinition < AbstractNode @@ -607,6 +739,10 @@ def initialize_node(name:, directives: [], description: nil) @directives = directives @description = description end + + def visit_method + :on_enum_type_definition + end end class InputObjectTypeDefinition < AbstractNode @@ -620,6 +756,10 @@ def initialize_node(name:, fields:, directives: [], description: nil) @description = description end + def visit_method + :on_input_object_type_definition + end + def children fields + directives end @@ -637,6 +777,10 @@ def initialize_node(name:, fields:, directives: []) def children fields + directives end + + def visit_method + :on_input_object_type_extension + end end end end diff --git a/lib/graphql/language/visitor.rb b/lib/graphql/language/visitor.rb index 509af2dae8..0efca8b8dc 100644 --- a/lib/graphql/language/visitor.rb +++ b/lib/graphql/language/visitor.rb @@ -3,19 +3,37 @@ module GraphQL module Language # Depth-first traversal through the tree, calling hooks at each stop. # - # @example Create a visitor, add hooks, then search a document - # total_field_count = 0 - # visitor = GraphQL::Language::Visitor.new(document) - # # Whenever you find a field, increment the field count: - # visitor[GraphQL::Language::Nodes::Field] << ->(node) { total_field_count += 1 } - # # When we finish, print the field count: - # visitor[GraphQL::Language::Nodes::Document].leave << ->(node) { p total_field_count } - # visitor.visit - # # => 6 + # @example Create a visitor counting certain field names + # class NameCounter < GraphQL::Language::Visitor + # def initialize(document, field_name) + # super(document) + # @field_name + # @count = 0 + # end + # + # attr_reader :count + # + # def on_field(node, parent) + # # if this field matches our search, increment the counter + # if node.name == @field_name + # @count = 0 + # end + # # Continue visiting subfields: + # super + # end + # end # + # # Initialize a visitor + # visitor = GraphQL::Language::Visitor.new(document, "name") + # # Run it + # visitor.visit + # # Check the result + # visitor.count + # # => 3 class Visitor # If any hook returns this value, the {Visitor} stops visiting this # node right away + # @deprecated Use `super` to continue the visit; or don't call it to halt. SKIP = :_skip def initialize(document) @@ -29,6 +47,7 @@ def initialize(document) # # @example Run a hook whenever you enter a new Field # visitor[GraphQL::Language::Nodes::Field] << ->(node, parent) { p "Here's a field" } + # @deprecated see `on_` methods, like {#on_field} def [](node_class) @visitors[node_class] ||= NodeVisitor.new end @@ -36,19 +55,67 @@ def [](node_class) # Visit `document` and all children, applying hooks as you go # @return [void] def visit - visit_node(@document, nil) + on_document(@document, nil) end - private - - def visit_node(node, parent) - begin_hooks_ok = begin_visit(node, parent) + # The default implementation for visiting an AST node. + # It doesn't _do_ anything, but it continues to visiting the node's children. + # To customize this hook, override one of its aliases (or the base method?) + # in your subclasses. + # + # For compatibility, it calls hook procs, too. + # @param node [GraphQL::Language::Nodes::AbstractNode] the node being visited + # @param parent [GraphQL::Language::Nodes::AbstractNode, nil] the previously-visited node, or `nil` if this is the root node. + # @return [void] + def on_abstract_node(node, parent) + # Run hooks if there are any + begin_hooks_ok = @visitors.none? || begin_visit(node, parent) if begin_hooks_ok - node.children.each { |child| visit_node(child, node) } + node.children.each do |child_node| + public_send(child_node.visit_method, child_node, node) + end end - end_visit(node, parent) + @visitors.any? && end_visit(node, parent) end + alias :on_argument :on_abstract_node + alias :on_directive :on_abstract_node + alias :on_directive_definition :on_abstract_node + alias :on_directive_location :on_abstract_node + alias :on_document :on_abstract_node + alias :on_enum :on_abstract_node + alias :on_enum_type_definition :on_abstract_node + alias :on_enum_type_extension :on_abstract_node + alias :on_enum_value_definition :on_abstract_node + alias :on_field :on_abstract_node + alias :on_field_definition :on_abstract_node + alias :on_fragment_definition :on_abstract_node + alias :on_fragment_spread :on_abstract_node + alias :on_inline_fragment :on_abstract_node + alias :on_input_object :on_abstract_node + alias :on_input_object_type_definition :on_abstract_node + alias :on_input_object_type_extension :on_abstract_node + alias :on_input_value_definition :on_abstract_node + alias :on_interface_type_definition :on_abstract_node + alias :on_interface_type_extension :on_abstract_node + alias :on_list_type :on_abstract_node + alias :on_non_null_type :on_abstract_node + alias :on_null_value :on_abstract_node + alias :on_object_type_definition :on_abstract_node + alias :on_object_type_extension :on_abstract_node + alias :on_operation_definition :on_abstract_node + alias :on_scalar_type_definition :on_abstract_node + alias :on_scalar_type_extension :on_abstract_node + alias :on_schema_definition :on_abstract_node + alias :on_schema_extension :on_abstract_node + alias :on_type_name :on_abstract_node + alias :on_union_type_definition :on_abstract_node + alias :on_union_type_extension :on_abstract_node + alias :on_variable_definition :on_abstract_node + alias :on_variable_identifier :on_abstract_node + + private + def begin_visit(node, parent) node_visitor = self[node.class] self.class.apply_hooks(node_visitor.enter, node, parent) diff --git a/lib/graphql/static_validation.rb b/lib/graphql/static_validation.rb index caac1cfd23..9727bd0203 100644 --- a/lib/graphql/static_validation.rb +++ b/lib/graphql/static_validation.rb @@ -1,12 +1,12 @@ # frozen_string_literal: true require "graphql/static_validation/message" -require "graphql/static_validation/arguments_validator" require "graphql/static_validation/definition_dependencies" require "graphql/static_validation/type_stack" require "graphql/static_validation/validator" require "graphql/static_validation/validation_context" require "graphql/static_validation/literal_validator" - +require "graphql/static_validation/base_visitor" +require "graphql/static_validation/no_validate_visitor" rules_glob = File.expand_path("../static_validation/rules/*.rb", __FILE__) Dir.glob(rules_glob).each do |file| @@ -14,3 +14,4 @@ end require "graphql/static_validation/all_rules" +require "graphql/static_validation/default_visitor" diff --git a/lib/graphql/static_validation/all_rules.rb b/lib/graphql/static_validation/all_rules.rb index d0b5d1b6e2..30edb5f8a9 100644 --- a/lib/graphql/static_validation/all_rules.rb +++ b/lib/graphql/static_validation/all_rules.rb @@ -11,9 +11,10 @@ module StaticValidation GraphQL::StaticValidation::DirectivesAreDefined, GraphQL::StaticValidation::DirectivesAreInValidLocations, GraphQL::StaticValidation::UniqueDirectivesPerLocation, + GraphQL::StaticValidation::OperationNamesAreValid, + GraphQL::StaticValidation::FragmentNamesAreUnique, GraphQL::StaticValidation::FragmentsAreFinite, GraphQL::StaticValidation::FragmentsAreNamed, - GraphQL::StaticValidation::FragmentNamesAreUnique, GraphQL::StaticValidation::FragmentsAreUsed, GraphQL::StaticValidation::FragmentTypesExist, GraphQL::StaticValidation::FragmentsAreOnCompositeTypes, @@ -32,7 +33,6 @@ module StaticValidation GraphQL::StaticValidation::VariableUsagesAreAllowed, GraphQL::StaticValidation::MutationRootExists, GraphQL::StaticValidation::SubscriptionRootExists, - GraphQL::StaticValidation::OperationNamesAreValid, ] end end diff --git a/lib/graphql/static_validation/arguments_validator.rb b/lib/graphql/static_validation/arguments_validator.rb deleted file mode 100644 index ff72d4ccae..0000000000 --- a/lib/graphql/static_validation/arguments_validator.rb +++ /dev/null @@ -1,50 +0,0 @@ -# frozen_string_literal: true -module GraphQL - module StaticValidation - # Implement validate_node - class ArgumentsValidator - include GraphQL::StaticValidation::Message::MessageHelper - - def validate(context) - visitor = context.visitor - visitor[GraphQL::Language::Nodes::Argument] << ->(node, parent) { - case parent - when GraphQL::Language::Nodes::InputObject - arg_defn = context.argument_definition - if arg_defn.nil? - return - else - parent_defn = arg_defn.type.unwrap - if !parent_defn.is_a?(GraphQL::InputObjectType) - return - end - end - when GraphQL::Language::Nodes::Directive - parent_defn = context.schema.directives[parent.name] - when GraphQL::Language::Nodes::Field - parent_defn = context.field_definition - else - raise "Unexpected argument parent: #{parent.class} (##{parent})" - end - validate_node(parent, node, parent_defn, context) - } - end - - private - - def parent_name(parent, type_defn) - if parent.is_a?(GraphQL::Language::Nodes::Field) - parent.alias || parent.name - elsif parent.is_a?(GraphQL::Language::Nodes::InputObject) - type_defn.name - else - parent.name - end - end - - def node_type(parent) - parent.class.name.split("::").last - end - end - end -end diff --git a/lib/graphql/static_validation/base_visitor.rb b/lib/graphql/static_validation/base_visitor.rb new file mode 100644 index 0000000000..0894365bcc --- /dev/null +++ b/lib/graphql/static_validation/base_visitor.rb @@ -0,0 +1,184 @@ +# frozen_string_literal: true +module GraphQL + module StaticValidation + class BaseVisitor < GraphQL::Language::Visitor + def initialize(document, context) + @path = [] + @object_types = [] + @directives = [] + @field_definitions = [] + @argument_definitions = [] + @directive_definitions = [] + @context = context + @schema = context.schema + super(document) + end + + attr_reader :context + + # @return [Array] Types whose scope we've entered + attr_reader :object_types + + # @return [Array] The nesting of the current position in the AST + def path + @path.dup + end + + # Build a class to visit the AST and perform validation, + # or use a pre-built class if rules is `ALL_RULES` or empty. + # @param rules [Array] + # @return [Class] A class for validating `rules` during visitation + def self.including_rules(rules) + if rules.none? + NoValidateVisitor + elsif rules == ALL_RULES + DefaultVisitor + else + visitor_class = Class.new(self) do + include(GraphQL::StaticValidation::DefinitionDependencies) + end + + rules.reverse_each do |r| + # If it's a class, it gets attached later. + if !r.is_a?(Class) + visitor_class.include(r) + end + end + + visitor_class.include(GraphQL::InternalRepresentation::Rewrite) + visitor_class.include(ContextMethods) + visitor_class + end + end + + module ContextMethods + def on_operation_definition(node, parent) + object_type = @schema.root_type_for_operation(node.operation_type) + @object_types.push(object_type) + @path.push("#{node.operation_type}#{node.name ? " #{node.name}" : ""}") + super + @object_types.pop + @path.pop + end + + def on_fragment_definition(node, parent) + on_fragment_with_type(node) do + @path.push("fragment #{node.name}") + super + end + end + + def on_inline_fragment(node, parent) + on_fragment_with_type(node) do + @path.push("...#{node.type ? " on #{node.type.to_query_string}" : ""}") + super + end + end + + def on_field(node, parent) + parent_type = @object_types.last + field_definition = @schema.get_field(parent_type, node.name) + @field_definitions.push(field_definition) + if !field_definition.nil? + next_object_type = field_definition.type.unwrap + @object_types.push(next_object_type) + else + @object_types.push(nil) + end + @path.push(node.alias || node.name) + super + @field_definitions.pop + @object_types.pop + @path.pop + end + + def on_directive(node, parent) + directive_defn = @schema.directives[node.name] + @directive_definitions.push(directive_defn) + super + @directive_definitions.pop + end + + def on_argument(node, parent) + argument_defn = if (arg = @argument_definitions.last) + arg_type = arg.type.unwrap + if arg_type.kind.input_object? + arg_type.input_fields[node.name] + else + nil + end + elsif (directive_defn = @directive_definitions.last) + directive_defn.arguments[node.name] + elsif (field_defn = @field_definitions.last) + field_defn.arguments[node.name] + else + nil + end + + @argument_definitions.push(argument_defn) + @path.push(node.name) + super + @argument_definitions.pop + @path.pop + end + + def on_fragment_spread(node, parent) + @path.push("... #{node.name}") + super + @path.pop + end + + # @return [GraphQL::BaseType] The current object type + def type_definition + @object_types.last + end + + # @return [GraphQL::BaseType] The type which the current type came from + def parent_type_definition + @object_types[-2] + end + + # @return [GraphQL::Field, nil] The most-recently-entered GraphQL::Field, if currently inside one + def field_definition + @field_definitions.last + end + + # @return [GraphQL::Directive, nil] The most-recently-entered GraphQL::Directive, if currently inside one + def directive_definition + @directive_definitions.last + end + + # @return [GraphQL::Argument, nil] The most-recently-entered GraphQL::Argument, if currently inside one + def argument_definition + # Don't get the _last_ one because that's the current one. + # Get the second-to-last one, which is the parent of the current one. + @argument_definitions[-2] + end + + private + + def on_fragment_with_type(node) + object_type = if node.type + @schema.types.fetch(node.type.name, nil) + else + @object_types.last + end + @object_types.push(object_type) + yield(node) + @object_types.pop + @path.pop + end + end + + private + + # Error `message` is located at `node` + def add_error(message, nodes, path: nil) + path ||= @path.dup + nodes = Array(nodes) + m = GraphQL::StaticValidation::Message.new(message, nodes: nodes, path: path) + context.errors << m + end + end + end +end diff --git a/lib/graphql/static_validation/default_visitor.rb b/lib/graphql/static_validation/default_visitor.rb new file mode 100644 index 0000000000..1202f5b3f2 --- /dev/null +++ b/lib/graphql/static_validation/default_visitor.rb @@ -0,0 +1,15 @@ +# frozen_string_literal: true +module GraphQL + module StaticValidation + class DefaultVisitor < BaseVisitor + include(GraphQL::StaticValidation::DefinitionDependencies) + + StaticValidation::ALL_RULES.reverse_each do |r| + include(r) + end + + include(GraphQL::InternalRepresentation::Rewrite) + include(ContextMethods) + end + end +end diff --git a/lib/graphql/static_validation/definition_dependencies.rb b/lib/graphql/static_validation/definition_dependencies.rb index 8deba284e6..9fd1fde127 100644 --- a/lib/graphql/static_validation/definition_dependencies.rb +++ b/lib/graphql/static_validation/definition_dependencies.rb @@ -4,79 +4,72 @@ module StaticValidation # Track fragment dependencies for operations # and expose the fragment definitions which # are used by a given operation - class DefinitionDependencies - def self.mount(visitor) - deps = self.new - deps.mount(visitor) - deps - end + module DefinitionDependencies + attr_reader :dependencies - def initialize - @node_paths = {} + def initialize(*) + super + @defdep_node_paths = {} # { name => node } pairs for fragments - @fragment_definitions = {} + @defdep_fragment_definitions = {} # This tracks dependencies from fragment to Node where it was used # { fragment_definition_node => [dependent_node, dependent_node]} - @dependent_definitions = Hash.new { |h, k| h[k] = Set.new } + @defdep_dependent_definitions = Hash.new { |h, k| h[k] = Set.new } # First-level usages of spreads within definitions # (When a key has an empty list as its value, # we can resolve that key's depenedents) # { definition_node => [node, node ...] } - @immediate_dependencies = Hash.new { |h, k| h[k] = Set.new } - end - - # A map of operation definitions to an array of that operation's dependencies - # @return [DependencyMap] - def dependency_map(&block) - @dependency_map ||= resolve_dependencies(&block) - end + @defdep_immediate_dependencies = Hash.new { |h, k| h[k] = Set.new } - def mount(context) - visitor = context.visitor # When we encounter a spread, # this node is the one who depends on it - current_parent = nil - - visitor[GraphQL::Language::Nodes::Document] << ->(node, prev_node) { - node.definitions.each do |definition| - case definition - when GraphQL::Language::Nodes::OperationDefinition - when GraphQL::Language::Nodes::FragmentDefinition - @fragment_definitions[definition.name] = definition - end - end - } + @defdep_current_parent = nil + end - visitor[GraphQL::Language::Nodes::OperationDefinition] << ->(node, prev_node) { - @node_paths[node] = NodeWithPath.new(node, context.path) - current_parent = node + def on_document(node, parent) + node.definitions.each do |definition| + if definition.is_a? GraphQL::Language::Nodes::FragmentDefinition + @defdep_fragment_definitions[definition.name] = definition + end + end + super + @dependencies = dependency_map { |defn, spreads, frag| + context.on_dependency_resolve_handlers.each { |h| h.call(defn, spreads, frag) } } + end - visitor[GraphQL::Language::Nodes::OperationDefinition].leave << ->(node, prev_node) { - current_parent = nil - } + def on_operation_definition(node, prev_node) + @defdep_node_paths[node] = NodeWithPath.new(node, context.path) + @defdep_current_parent = node + super + @defdep_current_parent = nil + end - visitor[GraphQL::Language::Nodes::FragmentDefinition] << ->(node, prev_node) { - @node_paths[node] = NodeWithPath.new(node, context.path) - current_parent = node - } + def on_fragment_definition(node, parent) + @defdep_node_paths[node] = NodeWithPath.new(node, context.path) + @defdep_current_parent = node + super + @defdep_current_parent = nil + end - visitor[GraphQL::Language::Nodes::FragmentDefinition].leave << ->(node, prev_node) { - current_parent = nil - } + def on_fragment_spread(node, parent) + @defdep_node_paths[node] = NodeWithPath.new(node, context.path) - visitor[GraphQL::Language::Nodes::FragmentSpread] << ->(node, prev_node) { - @node_paths[node] = NodeWithPath.new(node, context.path) + # Track both sides of the dependency + @defdep_dependent_definitions[@defdep_fragment_definitions[node.name]] << @defdep_current_parent + @defdep_immediate_dependencies[@defdep_current_parent] << node + end - # Track both sides of the dependency - @dependent_definitions[@fragment_definitions[node.name]] << current_parent - @immediate_dependencies[current_parent] << node - } + # A map of operation definitions to an array of that operation's dependencies + # @return [DependencyMap] + def dependency_map(&block) + @dependency_map ||= resolve_dependencies(&block) end + # Map definition AST nodes to the definition AST nodes they depend on. # Expose circular depednencies. class DependencyMap @@ -122,14 +115,14 @@ def resolve_dependencies dependency_map = DependencyMap.new # Don't allow the loop to run more times # than the number of fragments in the document - max_loops = @fragment_definitions.size + max_loops = @defdep_fragment_definitions.size loops = 0 # Instead of tracking independent fragments _as you visit_, # determine them at the end. This way, we can treat fragments with the # same name as if they were the same name. If _any_ of the fragments # with that name has a dependency, we record it. - independent_fragment_nodes = @fragment_definitions.values - @immediate_dependencies.keys + independent_fragment_nodes = @defdep_fragment_definitions.values - @defdep_immediate_dependencies.keys while fragment_node = independent_fragment_nodes.pop loops += 1 @@ -138,22 +131,22 @@ def resolve_dependencies end # Since it's independent, let's remove it from here. # That way, we can use the remainder to identify cycles - @immediate_dependencies.delete(fragment_node) - fragment_usages = @dependent_definitions[fragment_node] + @defdep_immediate_dependencies.delete(fragment_node) + fragment_usages = @defdep_dependent_definitions[fragment_node] if fragment_usages.none? # If we didn't record any usages during the visit, # then this fragment is unused. - dependency_map.unused_dependencies << @node_paths[fragment_node] + dependency_map.unused_dependencies << @defdep_node_paths[fragment_node] else fragment_usages.each do |definition_node| # Register the dependency AND second-order dependencies dependency_map[definition_node] << fragment_node dependency_map[definition_node].concat(dependency_map[fragment_node]) # Since we've regestered it, remove it from our to-do list - deps = @immediate_dependencies[definition_node] + deps = @defdep_immediate_dependencies[definition_node] # Can't find a way to _just_ delete from `deps` and return the deleted entries removed, remaining = deps.partition { |spread| spread.name == fragment_node.name } - @immediate_dependencies[definition_node] = remaining + @defdep_immediate_dependencies[definition_node] = remaining if block_given? yield(definition_node, removed, fragment_node) end @@ -170,20 +163,20 @@ def resolve_dependencies # If any dependencies were _unmet_ # (eg, spreads with no corresponding definition) # then they're still in there - @immediate_dependencies.each do |defn_node, deps| + @defdep_immediate_dependencies.each do |defn_node, deps| deps.each do |spread| - if @fragment_definitions[spread.name].nil? - dependency_map.unmet_dependencies[@node_paths[defn_node]] << @node_paths[spread] + if @defdep_fragment_definitions[spread.name].nil? + dependency_map.unmet_dependencies[@defdep_node_paths[defn_node]] << @defdep_node_paths[spread] deps.delete(spread) end end if deps.none? - @immediate_dependencies.delete(defn_node) + @defdep_immediate_dependencies.delete(defn_node) end end # Anything left in @immediate_dependencies is cyclical - cyclical_nodes = @immediate_dependencies.keys.map { |n| @node_paths[n] } + cyclical_nodes = @defdep_immediate_dependencies.keys.map { |n| @defdep_node_paths[n] } # @immediate_dependencies also includes operation names, but we don't care about # those. They became nil when we looked them up on `@fragment_definitions`, so remove them. cyclical_nodes.compact! diff --git a/lib/graphql/static_validation/no_validate_visitor.rb b/lib/graphql/static_validation/no_validate_visitor.rb new file mode 100644 index 0000000000..4fc3303433 --- /dev/null +++ b/lib/graphql/static_validation/no_validate_visitor.rb @@ -0,0 +1,10 @@ +# frozen_string_literal: true +module GraphQL + module StaticValidation + class NoValidateVisitor < StaticValidation::BaseVisitor + include(GraphQL::InternalRepresentation::Rewrite) + include(GraphQL::StaticValidation::DefinitionDependencies) + include(ContextMethods) + end + end +end diff --git a/lib/graphql/static_validation/rules/argument_literals_are_compatible.rb b/lib/graphql/static_validation/rules/argument_literals_are_compatible.rb index 0fb4872dee..476d3dcce2 100644 --- a/lib/graphql/static_validation/rules/argument_literals_are_compatible.rb +++ b/lib/graphql/static_validation/rules/argument_literals_are_compatible.rb @@ -1,27 +1,69 @@ # frozen_string_literal: true module GraphQL module StaticValidation - class ArgumentLiteralsAreCompatible < GraphQL::StaticValidation::ArgumentsValidator - def validate_node(parent, node, defn, context) - return if node.value.is_a?(GraphQL::Language::Nodes::VariableIdentifier) - arg_defn = defn.arguments[node.name] - return unless arg_defn - - begin - valid = context.valid_literal?(node.value, arg_defn.type) - rescue GraphQL::CoercionError => err - error_message = err.message + module ArgumentLiteralsAreCompatible + # TODO dedup with ArgumentsAreDefined + def on_argument(node, parent) + parent_defn = case parent + when GraphQL::Language::Nodes::InputObject + arg_defn = context.argument_definition + if arg_defn.nil? + nil + else + arg_ret_type = arg_defn.type.unwrap + if !arg_ret_type.is_a?(GraphQL::InputObjectType) + nil + else + arg_ret_type + end + end + when GraphQL::Language::Nodes::Directive + context.schema.directives[parent.name] + when GraphQL::Language::Nodes::Field + context.field_definition + else + raise "Unexpected argument parent: #{parent.class} (##{parent})" end - return if valid + if parent_defn && !node.value.is_a?(GraphQL::Language::Nodes::VariableIdentifier) + arg_defn = parent_defn.arguments[node.name] + if arg_defn + begin + valid = context.valid_literal?(node.value, arg_defn.type) + rescue GraphQL::CoercionError => err + error_message = err.message + end - error_message ||= begin - kind_of_node = node_type(parent) - error_arg_name = parent_name(parent, defn) - "Argument '#{node.name}' on #{kind_of_node} '#{error_arg_name}' has an invalid value. Expected type '#{arg_defn.type}'." + if !valid + error_message ||= begin + kind_of_node = node_type(parent) + error_arg_name = parent_name(parent, parent_defn) + "Argument '#{node.name}' on #{kind_of_node} '#{error_arg_name}' has an invalid value. Expected type '#{arg_defn.type}'." + end + + add_error(error_message, parent) + end + end + end + + super + end + + + private + + def parent_name(parent, type_defn) + if parent.is_a?(GraphQL::Language::Nodes::Field) + parent.alias || parent.name + elsif parent.is_a?(GraphQL::Language::Nodes::InputObject) + type_defn.name + else + parent.name end + end - context.errors << message(error_message, parent, context: context) + def node_type(parent) + parent.class.name.split("::").last end end end diff --git a/lib/graphql/static_validation/rules/argument_names_are_unique.rb b/lib/graphql/static_validation/rules/argument_names_are_unique.rb index 43c7f68f9a..cb77e4f60c 100644 --- a/lib/graphql/static_validation/rules/argument_names_are_unique.rb +++ b/lib/graphql/static_validation/rules/argument_names_are_unique.rb @@ -1,27 +1,27 @@ # frozen_string_literal: true module GraphQL module StaticValidation - class ArgumentNamesAreUnique + module ArgumentNamesAreUnique include GraphQL::StaticValidation::Message::MessageHelper - def validate(context) - context.visitor[GraphQL::Language::Nodes::Field] << ->(node, parent) { - validate_arguments(node, context) - } + def on_field(node, parent) + validate_arguments(node) + super + end - context.visitor[GraphQL::Language::Nodes::Directive] << ->(node, parent) { - validate_arguments(node, context) - } + def on_directive(node, parent) + validate_arguments(node) + super end - def validate_arguments(node, context) + def validate_arguments(node) argument_defns = node.arguments if argument_defns.any? args_by_name = Hash.new { |h, k| h[k] = [] } argument_defns.each { |a| args_by_name[a.name] << a } args_by_name.each do |name, defns| if defns.size > 1 - context.errors << message("There can be only one argument named \"#{name}\"", defns, context: context) + add_error("There can be only one argument named \"#{name}\"", defns) end end end diff --git a/lib/graphql/static_validation/rules/arguments_are_defined.rb b/lib/graphql/static_validation/rules/arguments_are_defined.rb index 2bee1b344f..6f99ea4dd0 100644 --- a/lib/graphql/static_validation/rules/arguments_are_defined.rb +++ b/lib/graphql/static_validation/rules/arguments_are_defined.rb @@ -1,18 +1,56 @@ # frozen_string_literal: true module GraphQL module StaticValidation - class ArgumentsAreDefined < GraphQL::StaticValidation::ArgumentsValidator - def validate_node(parent, node, defn, context) - argument_defn = context.warden.arguments(defn).find { |arg| arg.name == node.name } - if argument_defn.nil? + module ArgumentsAreDefined + def on_argument(node, parent) + parent_defn = case parent + when GraphQL::Language::Nodes::InputObject + arg_defn = context.argument_definition + if arg_defn.nil? + nil + else + arg_ret_type = arg_defn.type.unwrap + if !arg_ret_type.is_a?(GraphQL::InputObjectType) + nil + else + arg_ret_type + end + end + when GraphQL::Language::Nodes::Directive + context.schema.directives[parent.name] + when GraphQL::Language::Nodes::Field + context.field_definition + else + raise "Unexpected argument parent: #{parent.class} (##{parent})" + end + + if parent_defn && context.warden.arguments(parent_defn).any? { |arg| arg.name == node.name } + super + elsif parent_defn kind_of_node = node_type(parent) - error_arg_name = parent_name(parent, defn) - context.errors << message("#{kind_of_node} '#{error_arg_name}' doesn't accept argument '#{node.name}'", node, context: context) - GraphQL::Language::Visitor::SKIP + error_arg_name = parent_name(parent, parent_defn) + add_error("#{kind_of_node} '#{error_arg_name}' doesn't accept argument '#{node.name}'", node) + else + # Some other weird error + super + end + end + + private + + def parent_name(parent, type_defn) + if parent.is_a?(GraphQL::Language::Nodes::Field) + parent.alias || parent.name + elsif parent.is_a?(GraphQL::Language::Nodes::InputObject) + type_defn.name else - nil + parent.name end end + + def node_type(parent) + parent.class.name.split("::").last + end end end end diff --git a/lib/graphql/static_validation/rules/directives_are_defined.rb b/lib/graphql/static_validation/rules/directives_are_defined.rb index 5811005f9b..e2f14347fb 100644 --- a/lib/graphql/static_validation/rules/directives_are_defined.rb +++ b/lib/graphql/static_validation/rules/directives_are_defined.rb @@ -1,24 +1,17 @@ # frozen_string_literal: true module GraphQL module StaticValidation - class DirectivesAreDefined - include GraphQL::StaticValidation::Message::MessageHelper - - def validate(context) - directive_names = context.schema.directives.keys - context.visitor[GraphQL::Language::Nodes::Directive] << ->(node, parent) { - validate_directive(node, directive_names, context) - } + module DirectivesAreDefined + def initialize(*) + super + @directive_names = context.schema.directives.keys end - private - - def validate_directive(ast_directive, directive_names, context) - if !directive_names.include?(ast_directive.name) - context.errors << message("Directive @#{ast_directive.name} is not defined", ast_directive, context: context) - GraphQL::Language::Visitor::SKIP + def on_directive(node, parent) + if !@directive_names.include?(node.name) + add_error("Directive @#{node.name} is not defined", node) else - nil + super end end end diff --git a/lib/graphql/static_validation/rules/directives_are_in_valid_locations.rb b/lib/graphql/static_validation/rules/directives_are_in_valid_locations.rb index e14d5bb4ac..5a52d03593 100644 --- a/lib/graphql/static_validation/rules/directives_are_in_valid_locations.rb +++ b/lib/graphql/static_validation/rules/directives_are_in_valid_locations.rb @@ -1,16 +1,12 @@ # frozen_string_literal: true module GraphQL module StaticValidation - class DirectivesAreInValidLocations - include GraphQL::StaticValidation::Message::MessageHelper + module DirectivesAreInValidLocations include GraphQL::Language - def validate(context) - directives = context.schema.directives - - context.visitor[Nodes::Directive] << ->(node, parent) { - validate_location(node, parent, directives, context) - } + def on_directive(node, parent) + validate_location(node, parent, context.schema.directives) + super end private @@ -34,25 +30,25 @@ def validate(context) SIMPLE_LOCATION_NODES = SIMPLE_LOCATIONS.keys - def validate_location(ast_directive, ast_parent, directives, context) + def validate_location(ast_directive, ast_parent, directives) directive_defn = directives[ast_directive.name] case ast_parent when Nodes::OperationDefinition required_location = GraphQL::Directive.const_get(ast_parent.operation_type.upcase) - assert_includes_location(directive_defn, ast_directive, required_location, context) + assert_includes_location(directive_defn, ast_directive, required_location) when *SIMPLE_LOCATION_NODES required_location = SIMPLE_LOCATIONS[ast_parent.class] - assert_includes_location(directive_defn, ast_directive, required_location, context) + assert_includes_location(directive_defn, ast_directive, required_location) else - context.errors << message("Directives can't be applied to #{ast_parent.class.name}s", ast_directive, context: context) + add_error("Directives can't be applied to #{ast_parent.class.name}s", ast_directive) end end - def assert_includes_location(directive_defn, directive_ast, required_location, context) + def assert_includes_location(directive_defn, directive_ast, required_location) if !directive_defn.locations.include?(required_location) location_name = LOCATION_MESSAGE_NAMES[required_location] allowed_location_names = directive_defn.locations.map { |loc| LOCATION_MESSAGE_NAMES[loc] } - context.errors << message("'@#{directive_defn.name}' can't be applied to #{location_name} (allowed: #{allowed_location_names.join(", ")})", directive_ast, context: context) + add_error("'@#{directive_defn.name}' can't be applied to #{location_name} (allowed: #{allowed_location_names.join(", ")})", directive_ast) end end end diff --git a/lib/graphql/static_validation/rules/fields_are_defined_on_type.rb b/lib/graphql/static_validation/rules/fields_are_defined_on_type.rb index 7870148729..d942de9558 100644 --- a/lib/graphql/static_validation/rules/fields_are_defined_on_type.rb +++ b/lib/graphql/static_validation/rules/fields_are_defined_on_type.rb @@ -1,30 +1,19 @@ # frozen_string_literal: true module GraphQL module StaticValidation - class FieldsAreDefinedOnType - include GraphQL::StaticValidation::Message::MessageHelper - - def validate(context) - visitor = context.visitor - visitor[GraphQL::Language::Nodes::Field] << ->(node, parent) { - parent_type = context.object_types[-2] - parent_type = parent_type.unwrap - validate_field(context, node, parent_type, parent) - } - end - - private - - def validate_field(context, ast_field, parent_type, parent) - field = context.warden.get_field(parent_type, ast_field.name) + module FieldsAreDefinedOnType + def on_field(node, parent) + parent_type = @object_types[-2] + field = context.warden.get_field(parent_type, node.name) if field.nil? if parent_type.kind.union? - context.errors << message("Selections can't be made directly on unions (see selections on #{parent_type.name})", parent, context: context) + add_error("Selections can't be made directly on unions (see selections on #{parent_type.name})", parent) else - context.errors << message("Field '#{ast_field.name}' doesn't exist on type '#{parent_type.name}'", ast_field, context: context) + add_error("Field '#{node.name}' doesn't exist on type '#{parent_type.name}'", node) end - return GraphQL::Language::Visitor::SKIP + else + super end end end diff --git a/lib/graphql/static_validation/rules/fields_have_appropriate_selections.rb b/lib/graphql/static_validation/rules/fields_have_appropriate_selections.rb index 71f3cf3023..865fb6efa2 100644 --- a/lib/graphql/static_validation/rules/fields_have_appropriate_selections.rb +++ b/lib/graphql/static_validation/rules/fields_have_appropriate_selections.rb @@ -3,24 +3,26 @@ module GraphQL module StaticValidation # Scalars _can't_ have selections # Objects _must_ have selections - class FieldsHaveAppropriateSelections + module FieldsHaveAppropriateSelections include GraphQL::StaticValidation::Message::MessageHelper - def validate(context) - context.visitor[GraphQL::Language::Nodes::Field] << ->(node, parent) { - field_defn = context.field_definition - validate_field_selections(node, field_defn.type.unwrap, context) - } + def on_field(node, parent) + field_defn = field_definition + if validate_field_selections(node, field_defn.type.unwrap) + super + end + end - context.visitor[GraphQL::Language::Nodes::OperationDefinition] << ->(node, parent) { - validate_field_selections(node, context.type_definition, context) - } + def on_operation_definition(node, _parent) + if validate_field_selections(node, type_definition) + super + end end private - def validate_field_selections(ast_node, resolved_type, context) + def validate_field_selections(ast_node, resolved_type) msg = if resolved_type.nil? nil elsif resolved_type.kind.scalar? && ast_node.selections.any? @@ -48,8 +50,10 @@ def validate_field_selections(ast_node, resolved_type, context) else raise("Unexpected node #{ast_node}") end - context.errors << message(msg % { node_name: node_name }, ast_node, context: context) - GraphQL::Language::Visitor::SKIP + add_error(msg % { node_name: node_name }, ast_node) + false + else + true end end end diff --git a/lib/graphql/static_validation/rules/fields_will_merge.rb b/lib/graphql/static_validation/rules/fields_will_merge.rb index 5897e936d7..af9e9df3c9 100644 --- a/lib/graphql/static_validation/rules/fields_will_merge.rb +++ b/lib/graphql/static_validation/rules/fields_will_merge.rb @@ -1,11 +1,13 @@ # frozen_string_literal: true module GraphQL module StaticValidation - class FieldsWillMerge + module FieldsWillMerge # Special handling for fields without arguments NO_ARGS = {}.freeze - def validate(context) + def initialize(*) + super + context.each_irep_node do |node| if node.ast_nodes.size > 1 defn_names = Set.new(node.ast_nodes.map(&:name)) diff --git a/lib/graphql/static_validation/rules/fragment_names_are_unique.rb b/lib/graphql/static_validation/rules/fragment_names_are_unique.rb index f4e9dff7de..aa650ccfbe 100644 --- a/lib/graphql/static_validation/rules/fragment_names_are_unique.rb +++ b/lib/graphql/static_validation/rules/fragment_names_are_unique.rb @@ -1,22 +1,25 @@ # frozen_string_literal: true module GraphQL module StaticValidation - class FragmentNamesAreUnique - include GraphQL::StaticValidation::Message::MessageHelper + module FragmentNamesAreUnique - def validate(context) - fragments_by_name = Hash.new { |h, k| h[k] = [] } - context.visitor[GraphQL::Language::Nodes::FragmentDefinition] << ->(node, parent) { - fragments_by_name[node.name] << node - } + def initialize(*) + super + @fragments_by_name = Hash.new { |h, k| h[k] = [] } + end + + def on_fragment_definition(node, parent) + @fragments_by_name[node.name] << node + super + end - context.visitor[GraphQL::Language::Nodes::Document].leave << ->(node, parent) { - fragments_by_name.each do |name, fragments| - if fragments.length > 1 - context.errors << message(%|Fragment name "#{name}" must be unique|, fragments, context: context) - end + def on_document(_n, _p) + super + @fragments_by_name.each do |name, fragments| + if fragments.length > 1 + add_error(%|Fragment name "#{name}" must be unique|, fragments) end - } + end end end end diff --git a/lib/graphql/static_validation/rules/fragment_spreads_are_possible.rb b/lib/graphql/static_validation/rules/fragment_spreads_are_possible.rb index 5351c28bbb..57421d1a38 100644 --- a/lib/graphql/static_validation/rules/fragment_spreads_are_possible.rb +++ b/lib/graphql/static_validation/rules/fragment_spreads_are_possible.rb @@ -1,39 +1,40 @@ # frozen_string_literal: true module GraphQL module StaticValidation - class FragmentSpreadsArePossible - include GraphQL::StaticValidation::Message::MessageHelper - - def validate(context) - - context.visitor[GraphQL::Language::Nodes::InlineFragment] << ->(node, parent) { - fragment_parent = context.object_types[-2] - fragment_child = context.object_types.last - if fragment_child - validate_fragment_in_scope(fragment_parent, fragment_child, node, context, context.path) - end - } + module FragmentSpreadsArePossible + def initialize(*) + super + @spreads_to_validate = [] + end - spreads_to_validate = [] + def on_inline_fragment(node, parent) + fragment_parent = context.object_types[-2] + fragment_child = context.object_types.last + if fragment_child + validate_fragment_in_scope(fragment_parent, fragment_child, node, context, context.path) + end + super + end - context.visitor[GraphQL::Language::Nodes::FragmentSpread] << ->(node, parent) { - fragment_parent = context.object_types.last - spreads_to_validate << FragmentSpread.new(node: node, parent_type: fragment_parent, path: context.path) - } + def on_fragment_spread(node, parent) + fragment_parent = context.object_types.last + @spreads_to_validate << FragmentSpread.new(node: node, parent_type: fragment_parent, path: context.path) + super + end - context.visitor[GraphQL::Language::Nodes::Document].leave << ->(doc_node, parent) { - spreads_to_validate.each do |frag_spread| - frag_node = context.fragments[frag_spread.node.name] - if frag_node - fragment_child_name = frag_node.type.name - fragment_child = context.warden.get_type(fragment_child_name) - # Might be non-existent type name - if fragment_child - validate_fragment_in_scope(frag_spread.parent_type, fragment_child, frag_spread.node, context, frag_spread.path) - end + def on_document(node, parent) + super + @spreads_to_validate.each do |frag_spread| + frag_node = context.fragments[frag_spread.node.name] + if frag_node + fragment_child_name = frag_node.type.name + fragment_child = context.warden.get_type(fragment_child_name) + # Might be non-existent type name + if fragment_child + validate_fragment_in_scope(frag_spread.parent_type, fragment_child, frag_spread.node, context, frag_spread.path) end end - } + end end private @@ -48,7 +49,7 @@ def validate_fragment_in_scope(parent_type, child_type, node, context, path) if child_types.none? { |c| parent_types.include?(c) } name = node.respond_to?(:name) ? " #{node.name}" : "" - context.errors << message("Fragment#{name} on #{child_type.name} can't be spread inside #{parent_type.name}", node, path: path) + add_error("Fragment#{name} on #{child_type.name} can't be spread inside #{parent_type.name}", node, path: path) end end diff --git a/lib/graphql/static_validation/rules/fragment_types_exist.rb b/lib/graphql/static_validation/rules/fragment_types_exist.rb index 06fac1e4d8..f087fc6e0b 100644 --- a/lib/graphql/static_validation/rules/fragment_types_exist.rb +++ b/lib/graphql/static_validation/rules/fragment_types_exist.rb @@ -1,29 +1,33 @@ # frozen_string_literal: true module GraphQL module StaticValidation - class FragmentTypesExist - include GraphQL::StaticValidation::Message::MessageHelper - - FRAGMENTS_ON_TYPES = [ - GraphQL::Language::Nodes::FragmentDefinition, - GraphQL::Language::Nodes::InlineFragment, - ] + module FragmentTypesExist + def on_fragment_definition(node, _parent) + if validate_type_exists(node) + super + end + end - def validate(context) - FRAGMENTS_ON_TYPES.each do |node_class| - context.visitor[node_class] << ->(node, parent) { validate_type_exists(node, context) } + def on_inline_fragment(node, _parent) + if validate_type_exists(node) + super end end private - def validate_type_exists(node, context) - return unless node.type - type_name = node.type.name - type = context.warden.get_type(type_name) - if type.nil? - context.errors << message("No such type #{type_name}, so it can't be a fragment condition", node, context: context) - GraphQL::Language::Visitor::SKIP + def validate_type_exists(fragment_node) + if !fragment_node.type + true + else + type_name = fragment_node.type.name + type = context.warden.get_type(type_name) + if type.nil? + add_error("No such type #{type_name}, so it can't be a fragment condition", fragment_node) + false + else + true + end end end end diff --git a/lib/graphql/static_validation/rules/fragments_are_finite.rb b/lib/graphql/static_validation/rules/fragments_are_finite.rb index da0d198f54..d0bd368e55 100644 --- a/lib/graphql/static_validation/rules/fragments_are_finite.rb +++ b/lib/graphql/static_validation/rules/fragments_are_finite.rb @@ -1,16 +1,13 @@ # frozen_string_literal: true module GraphQL module StaticValidation - class FragmentsAreFinite - include GraphQL::StaticValidation::Message::MessageHelper - - def validate(context) - context.visitor[GraphQL::Language::Nodes::Document].leave << ->(_n, _p) do - dependency_map = context.dependencies - dependency_map.cyclical_definitions.each do |defn| - if defn.node.is_a?(GraphQL::Language::Nodes::FragmentDefinition) - context.errors << message("Fragment #{defn.name} contains an infinite loop", defn.node, path: defn.path) - end + module FragmentsAreFinite + def on_document(_n, _p) + super + dependency_map = context.dependencies + dependency_map.cyclical_definitions.each do |defn| + if defn.node.is_a?(GraphQL::Language::Nodes::FragmentDefinition) + context.errors << message("Fragment #{defn.name} contains an infinite loop", defn.node, path: defn.path) end end end diff --git a/lib/graphql/static_validation/rules/fragments_are_named.rb b/lib/graphql/static_validation/rules/fragments_are_named.rb index d48dc1f6d2..c460ad0adb 100644 --- a/lib/graphql/static_validation/rules/fragments_are_named.rb +++ b/lib/graphql/static_validation/rules/fragments_are_named.rb @@ -1,19 +1,12 @@ # frozen_string_literal: true module GraphQL module StaticValidation - class FragmentsAreNamed - include GraphQL::StaticValidation::Message::MessageHelper - - def validate(context) - context.visitor[GraphQL::Language::Nodes::FragmentDefinition] << ->(node, parent) { validate_name_exists(node, context) } - end - - private - - def validate_name_exists(node, context) + module FragmentsAreNamed + def on_fragment_definition(node, _parent) if node.name.nil? - context.errors << message("Fragment definition has no name", node, context: context) + add_error("Fragment definition has no name", node) end + super end end end diff --git a/lib/graphql/static_validation/rules/fragments_are_on_composite_types.rb b/lib/graphql/static_validation/rules/fragments_are_on_composite_types.rb index 5d92cfa535..ce92ca1beb 100644 --- a/lib/graphql/static_validation/rules/fragments_are_on_composite_types.rb +++ b/lib/graphql/static_validation/rules/fragments_are_on_composite_types.rb @@ -1,34 +1,30 @@ # frozen_string_literal: true module GraphQL module StaticValidation - class FragmentsAreOnCompositeTypes - include GraphQL::StaticValidation::Message::MessageHelper - - HAS_TYPE_CONDITION = [ - GraphQL::Language::Nodes::FragmentDefinition, - GraphQL::Language::Nodes::InlineFragment, - ] + module FragmentsAreOnCompositeTypes + def on_fragment_definition(node, parent) + validate_type_is_composite(node) && super + end - def validate(context) - HAS_TYPE_CONDITION.each do |node_class| - context.visitor[node_class] << ->(node, parent) { - validate_type_is_composite(node, context) - } - end + def on_inline_fragment(node, parent) + validate_type_is_composite(node) && super end private - def validate_type_is_composite(node, context) + def validate_type_is_composite(node) node_type = node.type if node_type.nil? # Inline fragment on the same type + true else type_name = node_type.to_query_string type_def = context.warden.get_type(type_name) if type_def.nil? || !type_def.kind.composite? - context.errors << message("Invalid fragment on type #{type_name} (must be Union, Interface or Object)", node, context: context) - GraphQL::Language::Visitor::SKIP + add_error("Invalid fragment on type #{type_name} (must be Union, Interface or Object)", node) + false + else + true end end end diff --git a/lib/graphql/static_validation/rules/fragments_are_used.rb b/lib/graphql/static_validation/rules/fragments_are_used.rb index 5489818087..0d76ef224a 100644 --- a/lib/graphql/static_validation/rules/fragments_are_used.rb +++ b/lib/graphql/static_validation/rules/fragments_are_used.rb @@ -1,22 +1,19 @@ # frozen_string_literal: true module GraphQL module StaticValidation - class FragmentsAreUsed - include GraphQL::StaticValidation::Message::MessageHelper - - def validate(context) - context.visitor[GraphQL::Language::Nodes::Document].leave << ->(_n, _p) do - dependency_map = context.dependencies - dependency_map.unmet_dependencies.each do |op_defn, spreads| - spreads.each do |fragment_spread| - context.errors << message("Fragment #{fragment_spread.name} was used, but not defined", fragment_spread.node, path: fragment_spread.path) - end + module FragmentsAreUsed + def on_document(node, parent) + super + dependency_map = context.dependencies + dependency_map.unmet_dependencies.each do |op_defn, spreads| + spreads.each do |fragment_spread| + add_error("Fragment #{fragment_spread.name} was used, but not defined", fragment_spread.node, path: fragment_spread.path) end + end - dependency_map.unused_dependencies.each do |fragment| - if !fragment.name.nil? - context.errors << message("Fragment #{fragment.name} was defined, but not used", fragment.node, path: fragment.path) - end + dependency_map.unused_dependencies.each do |fragment| + if !fragment.name.nil? + add_error("Fragment #{fragment.name} was defined, but not used", fragment.node, path: fragment.path) end end end diff --git a/lib/graphql/static_validation/rules/mutation_root_exists.rb b/lib/graphql/static_validation/rules/mutation_root_exists.rb index 0f56859e68..8d65f95e0e 100644 --- a/lib/graphql/static_validation/rules/mutation_root_exists.rb +++ b/lib/graphql/static_validation/rules/mutation_root_exists.rb @@ -1,20 +1,13 @@ # frozen_string_literal: true module GraphQL module StaticValidation - class MutationRootExists - include GraphQL::StaticValidation::Message::MessageHelper - - def validate(context) - return if context.warden.root_type_for_operation("mutation") - - visitor = context.visitor - - visitor[GraphQL::Language::Nodes::OperationDefinition].enter << ->(ast_node, prev_ast_node) { - if ast_node.operation_type == 'mutation' - context.errors << message('Schema is not configured for mutations', ast_node, context: context) - return GraphQL::Language::Visitor::SKIP - end - } + module MutationRootExists + def on_operation_definition(node, _parent) + if node.operation_type == 'mutation' && context.warden.root_type_for_operation("mutation").nil? + add_error('Schema is not configured for mutations', node) + else + super + end end end end diff --git a/lib/graphql/static_validation/rules/no_definitions_are_present.rb b/lib/graphql/static_validation/rules/no_definitions_are_present.rb index c2e661fb2e..60ef505a3d 100644 --- a/lib/graphql/static_validation/rules/no_definitions_are_present.rb +++ b/lib/graphql/static_validation/rules/no_definitions_are_present.rb @@ -1,40 +1,39 @@ # frozen_string_literal: true module GraphQL module StaticValidation - class NoDefinitionsArePresent + module NoDefinitionsArePresent include GraphQL::StaticValidation::Message::MessageHelper - def validate(context) - schema_definition_nodes = [] - register_node = ->(node, _p) { - schema_definition_nodes << node - GraphQL::Language::Visitor::SKIP - } - - visitor = context.visitor + def initialize(*) + super + @schema_definition_nodes = [] + end - visitor[GraphQL::Language::Nodes::DirectiveDefinition] << register_node - visitor[GraphQL::Language::Nodes::SchemaDefinition] << register_node - visitor[GraphQL::Language::Nodes::ScalarTypeDefinition] << register_node - visitor[GraphQL::Language::Nodes::ObjectTypeDefinition] << register_node - visitor[GraphQL::Language::Nodes::InputObjectTypeDefinition] << register_node - visitor[GraphQL::Language::Nodes::InterfaceTypeDefinition] << register_node - visitor[GraphQL::Language::Nodes::UnionTypeDefinition] << register_node - visitor[GraphQL::Language::Nodes::EnumTypeDefinition] << register_node + def on_invalid_node(node, parent) + @schema_definition_nodes << node + end - visitor[GraphQL::Language::Nodes::SchemaExtension] << register_node - visitor[GraphQL::Language::Nodes::ScalarTypeExtension] << register_node - visitor[GraphQL::Language::Nodes::ObjectTypeExtension] << register_node - visitor[GraphQL::Language::Nodes::InputObjectTypeExtension] << register_node - visitor[GraphQL::Language::Nodes::InterfaceTypeExtension] << register_node - visitor[GraphQL::Language::Nodes::UnionTypeExtension] << register_node - visitor[GraphQL::Language::Nodes::EnumTypeExtension] << register_node + alias :on_directive_definition :on_invalid_node + alias :on_schema_definition :on_invalid_node + alias :on_scalar_type_definition :on_invalid_node + alias :on_object_type_definition :on_invalid_node + alias :on_input_object_type_definition :on_invalid_node + alias :on_interface_type_definition :on_invalid_node + alias :on_union_type_definition :on_invalid_node + alias :on_enum_type_definition :on_invalid_node + alias :on_schema_extension :on_invalid_node + alias :on_scalar_type_extension :on_invalid_node + alias :on_object_type_extension :on_invalid_node + alias :on_input_object_type_extension :on_invalid_node + alias :on_interface_type_extension :on_invalid_node + alias :on_union_type_extension :on_invalid_node + alias :on_enum_type_extension :on_invalid_node - visitor[GraphQL::Language::Nodes::Document].leave << ->(node, _p) { - if schema_definition_nodes.any? - context.errors << message(%|Query cannot contain schema definitions|, schema_definition_nodes, context: context) - end - } + def on_document(node, parent) + super + if @schema_definition_nodes.any? + add_error(%|Query cannot contain schema definitions|, @schema_definition_nodes) + end end end end diff --git a/lib/graphql/static_validation/rules/operation_names_are_valid.rb b/lib/graphql/static_validation/rules/operation_names_are_valid.rb index e464dc1d2d..9a7ad2226e 100644 --- a/lib/graphql/static_validation/rules/operation_names_are_valid.rb +++ b/lib/graphql/static_validation/rules/operation_names_are_valid.rb @@ -1,27 +1,28 @@ # frozen_string_literal: true module GraphQL module StaticValidation - class OperationNamesAreValid - include GraphQL::StaticValidation::Message::MessageHelper - - def validate(context) - op_names = Hash.new { |h, k| h[k] = [] } + module OperationNamesAreValid + def initialize(*) + super + @operation_names = Hash.new { |h, k| h[k] = [] } + end - context.visitor[GraphQL::Language::Nodes::OperationDefinition].enter << ->(node, _parent) { - op_names[node.name] << node - } + def on_operation_definition(node, parent) + @operation_names[node.name] << node + super + end - context.visitor[GraphQL::Language::Nodes::Document].leave << ->(node, _parent) { - op_count = op_names.values.inject(0) { |m, v| m + v.size } + def on_document(node, parent) + super + op_count = @operation_names.values.inject(0) { |m, v| m + v.size } - op_names.each do |name, nodes| - if name.nil? && op_count > 1 - context.errors << message(%|Operation name is required when multiple operations are present|, nodes, context: context) - elsif nodes.length > 1 - context.errors << message(%|Operation name "#{name}" must be unique|, nodes, context: context) - end + @operation_names.each do |name, nodes| + if name.nil? && op_count > 1 + add_error(%|Operation name is required when multiple operations are present|, nodes) + elsif nodes.length > 1 + add_error(%|Operation name "#{name}" must be unique|, nodes) end - } + end end end end diff --git a/lib/graphql/static_validation/rules/required_arguments_are_present.rb b/lib/graphql/static_validation/rules/required_arguments_are_present.rb index 50e50e97bf..3569dd910f 100644 --- a/lib/graphql/static_validation/rules/required_arguments_are_present.rb +++ b/lib/graphql/static_validation/rules/required_arguments_are_present.rb @@ -1,28 +1,21 @@ # frozen_string_literal: true module GraphQL module StaticValidation - class RequiredArgumentsArePresent - include GraphQL::StaticValidation::Message::MessageHelper - - def validate(context) - v = context.visitor - v[GraphQL::Language::Nodes::Field] << ->(node, parent) { validate_field(node, context) } - v[GraphQL::Language::Nodes::Directive] << ->(node, parent) { validate_directive(node, context) } + module RequiredArgumentsArePresent + def on_field(node, _parent) + assert_required_args(node, field_definition) + super end - private - - def validate_directive(ast_directive, context) - directive_defn = context.schema.directives[ast_directive.name] - assert_required_args(ast_directive, directive_defn, context) + def on_directive(node, _parent) + directive_defn = context.schema.directives[node.name] + assert_required_args(node, directive_defn) + super end - def validate_field(ast_field, context) - defn = context.field_definition - assert_required_args(ast_field, defn, context) - end + private - def assert_required_args(ast_node, defn, context) + def assert_required_args(ast_node, defn) present_argument_names = ast_node.arguments.map(&:name) required_argument_names = defn.arguments.values .select { |a| a.type.kind.non_null? } @@ -30,7 +23,7 @@ def assert_required_args(ast_node, defn, context) missing_names = required_argument_names - present_argument_names if missing_names.any? - context.errors << message("#{ast_node.class.name.split("::").last} '#{ast_node.name}' is missing required arguments: #{missing_names.join(", ")}", ast_node, context: context) + add_error("#{ast_node.class.name.split("::").last} '#{ast_node.name}' is missing required arguments: #{missing_names.join(", ")}", ast_node) end end end diff --git a/lib/graphql/static_validation/rules/subscription_root_exists.rb b/lib/graphql/static_validation/rules/subscription_root_exists.rb index 065b9b0a42..1143aa7ef1 100644 --- a/lib/graphql/static_validation/rules/subscription_root_exists.rb +++ b/lib/graphql/static_validation/rules/subscription_root_exists.rb @@ -1,20 +1,13 @@ # frozen_string_literal: true module GraphQL module StaticValidation - class SubscriptionRootExists - include GraphQL::StaticValidation::Message::MessageHelper - - def validate(context) - return if context.warden.root_type_for_operation("subscription") - - visitor = context.visitor - - visitor[GraphQL::Language::Nodes::OperationDefinition].enter << ->(ast_node, prev_ast_node) { - if ast_node.operation_type == 'subscription' - context.errors << message('Schema is not configured for subscriptions', ast_node, context: context) - return GraphQL::Language::Visitor::SKIP - end - } + module SubscriptionRootExists + def on_operation_definition(node, _parent) + if node.operation_type == "subscription" && context.warden.root_type_for_operation("subscription").nil? + add_error('Schema is not configured for subscriptions', node) + else + super + end end end end diff --git a/lib/graphql/static_validation/rules/unique_directives_per_location.rb b/lib/graphql/static_validation/rules/unique_directives_per_location.rb index 43944c4225..ec8370cedc 100644 --- a/lib/graphql/static_validation/rules/unique_directives_per_location.rb +++ b/lib/graphql/static_validation/rules/unique_directives_per_location.rb @@ -1,33 +1,43 @@ # frozen_string_literal: true module GraphQL module StaticValidation - class UniqueDirectivesPerLocation - include GraphQL::StaticValidation::Message::MessageHelper + module UniqueDirectivesPerLocation + DIRECTIVE_NODE_HOOKS = [ + :on_fragment_definition, + :on_fragment_spread, + :on_inline_fragment, + :on_operation_definition, + :on_scalar_type_definition, + :on_object_type_definition, + :on_input_value_definition, + :on_field_definition, + :on_interface_type_definition, + :on_union_type_definition, + :on_enum_type_definition, + :on_enum_value_definition, + :on_input_object_type_definition, + :on_field, + ] - NODES_WITH_DIRECTIVES = GraphQL::Language::Nodes.constants - .map{|c| GraphQL::Language::Nodes.const_get(c)} - .select{|c| c.is_a?(Class) && c.instance_methods.include?(:directives)} - - def validate(context) - NODES_WITH_DIRECTIVES.each do |node_class| - context.visitor[node_class] << ->(node, _) { - validate_directives(node, context) unless node.directives.empty? - } + DIRECTIVE_NODE_HOOKS.each do |method_name| + define_method(method_name) do |node, parent| + if node.directives.any? + validate_directive_location(node) + end + super(node, parent) end end private - def validate_directives(node, context) + def validate_directive_location(node) used_directives = {} - node.directives.each do |ast_directive| directive_name = ast_directive.name if used_directives[directive_name] - context.errors << message( + add_error( "The directive \"#{directive_name}\" can only be used once at this location.", - [used_directives[directive_name], ast_directive], - context: context + [used_directives[directive_name], ast_directive] ) else used_directives[directive_name] = ast_directive diff --git a/lib/graphql/static_validation/rules/variable_default_values_are_correctly_typed.rb b/lib/graphql/static_validation/rules/variable_default_values_are_correctly_typed.rb index 83d552a7ba..b799e0d202 100644 --- a/lib/graphql/static_validation/rules/variable_default_values_are_correctly_typed.rb +++ b/lib/graphql/static_validation/rules/variable_default_values_are_correctly_typed.rb @@ -1,39 +1,33 @@ # frozen_string_literal: true module GraphQL module StaticValidation - class VariableDefaultValuesAreCorrectlyTyped - include GraphQL::StaticValidation::Message::MessageHelper - - def validate(context) - context.visitor[GraphQL::Language::Nodes::VariableDefinition] << ->(node, parent) { - if !node.default_value.nil? - validate_default_value(node, context) - end - } - end - - def validate_default_value(node, context) - value = node.default_value - if node.type.is_a?(GraphQL::Language::Nodes::NonNullType) - context.errors << message("Non-null variable $#{node.name} can't have a default value", node, context: context) - else - type = context.schema.type_from_ast(node.type) - if type.nil? - # This is handled by another validator + module VariableDefaultValuesAreCorrectlyTyped + def on_variable_definition(node, parent) + if !node.default_value.nil? + value = node.default_value + if node.type.is_a?(GraphQL::Language::Nodes::NonNullType) + add_error("Non-null variable $#{node.name} can't have a default value", node) else - begin - valid = context.valid_literal?(value, type) - rescue GraphQL::CoercionError => err - error_message = err.message - end + type = context.schema.type_from_ast(node.type) + if type.nil? + # This is handled by another validator + else + begin + valid = context.valid_literal?(value, type) + rescue GraphQL::CoercionError => err + error_message = err.message + end - if !valid - error_message ||= "Default value for $#{node.name} doesn't match type #{type}" - context.errors << message(error_message, node, context: context) - end - end + if !valid + error_message ||= "Default value for $#{node.name} doesn't match type #{type}" + add_error(error_message, node) + end + end + end end + + super end end - end + end end diff --git a/lib/graphql/static_validation/rules/variable_names_are_unique.rb b/lib/graphql/static_validation/rules/variable_names_are_unique.rb index 340aa2eb47..c8117a00ec 100644 --- a/lib/graphql/static_validation/rules/variable_names_are_unique.rb +++ b/lib/graphql/static_validation/rules/variable_names_are_unique.rb @@ -1,22 +1,19 @@ # frozen_string_literal: true module GraphQL module StaticValidation - class VariableNamesAreUnique - include GraphQL::StaticValidation::Message::MessageHelper - - def validate(context) - context.visitor[GraphQL::Language::Nodes::OperationDefinition] << ->(node, parent) { - var_defns = node.variables - if var_defns.any? - vars_by_name = Hash.new { |h, k| h[k] = [] } - var_defns.each { |v| vars_by_name[v.name] << v } - vars_by_name.each do |name, defns| - if defns.size > 1 - context.errors << message("There can only be one variable named \"#{name}\"", defns, context: context) - end + module VariableNamesAreUnique + def on_operation_definition(node, parent) + var_defns = node.variables + if var_defns.any? + vars_by_name = Hash.new { |h, k| h[k] = [] } + var_defns.each { |v| vars_by_name[v.name] << v } + vars_by_name.each do |name, defns| + if defns.size > 1 + add_error("There can only be one variable named \"#{name}\"", defns) end end - } + end + super end end end diff --git a/lib/graphql/static_validation/rules/variable_usages_are_allowed.rb b/lib/graphql/static_validation/rules/variable_usages_are_allowed.rb index 50a9d6fe3e..5ce3605560 100644 --- a/lib/graphql/static_validation/rules/variable_usages_are_allowed.rb +++ b/lib/graphql/static_validation/rules/variable_usages_are_allowed.rb @@ -1,53 +1,57 @@ # frozen_string_literal: true module GraphQL module StaticValidation - class VariableUsagesAreAllowed - include GraphQL::StaticValidation::Message::MessageHelper - - def validate(context) + module VariableUsagesAreAllowed + def initialize(*) + super # holds { name => ast_node } pairs - declared_variables = {} - context.visitor[GraphQL::Language::Nodes::OperationDefinition] << ->(node, parent) { - declared_variables = node.variables.each_with_object({}) { |var, memo| memo[var.name] = var } - } - - context.visitor[GraphQL::Language::Nodes::Argument] << ->(node, parent) { - node_values = if node.value.is_a?(Array) - node.value - else - [node.value] - end - node_values = node_values.select { |value| value.is_a? GraphQL::Language::Nodes::VariableIdentifier } + @declared_variables = {} + end - return if node_values.none? + def on_operation_definition(node, parent) + @declared_variables = node.variables.each_with_object({}) { |var, memo| memo[var.name] = var } + super + end - arguments = nil - case parent + def on_argument(node, parent) + node_values = if node.value.is_a?(Array) + node.value + else + [node.value] + end + node_values = node_values.select { |value| value.is_a? GraphQL::Language::Nodes::VariableIdentifier } + + if node_values.any? + arguments = case parent when GraphQL::Language::Nodes::Field - arguments = context.field_definition.arguments + context.field_definition.arguments when GraphQL::Language::Nodes::Directive - arguments = context.directive_definition.arguments + context.directive_definition.arguments when GraphQL::Language::Nodes::InputObject arg_type = context.argument_definition.type.unwrap if arg_type.is_a?(GraphQL::InputObjectType) arguments = arg_type.input_fields + else + # This is some kind of error + nil end else raise("Unexpected argument parent: #{parent}") end node_values.each do |node_value| - var_defn_ast = declared_variables[node_value.name] + var_defn_ast = @declared_variables[node_value.name] # Might be undefined :( # VariablesAreUsedAndDefined can't finalize its search until the end of the document. - var_defn_ast && arguments && validate_usage(arguments, node, var_defn_ast, context) + var_defn_ast && arguments && validate_usage(arguments, node, var_defn_ast) end - } + end + super end private - def validate_usage(arguments, arg_node, ast_var, context) + def validate_usage(arguments, arg_node, ast_var) var_type = context.schema.type_from_ast(ast_var.type) if var_type.nil? return @@ -71,16 +75,16 @@ def validate_usage(arguments, arg_node, ast_var, context) var_type = wrap_var_type_with_depth_of_arg(var_type, arg_node) if var_inner_type != arg_inner_type - context.errors << create_error("Type mismatch", var_type, ast_var, arg_defn, arg_node, context) + create_error("Type mismatch", var_type, ast_var, arg_defn, arg_node) elsif list_dimension(var_type) != list_dimension(arg_defn_type) - context.errors << create_error("List dimension mismatch", var_type, ast_var, arg_defn, arg_node, context) + create_error("List dimension mismatch", var_type, ast_var, arg_defn, arg_node) elsif !non_null_levels_match(arg_defn_type, var_type) - context.errors << create_error("Nullability mismatch", var_type, ast_var, arg_defn, arg_node, context) + create_error("Nullability mismatch", var_type, ast_var, arg_defn, arg_node) end end - def create_error(error_message, var_type, ast_var, arg_defn, arg_node, context) - message("#{error_message} on variable $#{ast_var.name} and argument #{arg_node.name} (#{var_type.to_s} / #{arg_defn.type.to_s})", arg_node, context: context) + def create_error(error_message, var_type, ast_var, arg_defn, arg_node) + add_error("#{error_message} on variable $#{ast_var.name} and argument #{arg_node.name} (#{var_type.to_s} / #{arg_defn.type.to_s})", arg_node) end def wrap_var_type_with_depth_of_arg(var_type, arg_node) diff --git a/lib/graphql/static_validation/rules/variables_are_input_types.rb b/lib/graphql/static_validation/rules/variables_are_input_types.rb index 0f3f6552af..01f1a5be36 100644 --- a/lib/graphql/static_validation/rules/variables_are_input_types.rb +++ b/lib/graphql/static_validation/rules/variables_are_input_types.rb @@ -1,28 +1,22 @@ # frozen_string_literal: true module GraphQL module StaticValidation - class VariablesAreInputTypes - include GraphQL::StaticValidation::Message::MessageHelper - - def validate(context) - context.visitor[GraphQL::Language::Nodes::VariableDefinition] << ->(node, parent) { - validate_is_input_type(node, context) - } - end - - private - - def validate_is_input_type(node, context) + module VariablesAreInputTypes + def on_variable_definition(node, parent) type_name = get_type_name(node.type) type = context.warden.get_type(type_name) if type.nil? - context.errors << message("#{type_name} isn't a defined input type (on $#{node.name})", node, context: context) + add_error("#{type_name} isn't a defined input type (on $#{node.name})", node) elsif !type.kind.input? - context.errors << message("#{type.name} isn't a valid input type (on $#{node.name})", node, context: context) + add_error("#{type.name} isn't a valid input type (on $#{node.name})", node) end + + super end + private + def get_type_name(ast_type) if ast_type.respond_to?(:of_type) get_type_name(ast_type.of_type) diff --git a/lib/graphql/static_validation/rules/variables_are_used_and_defined.rb b/lib/graphql/static_validation/rules/variables_are_used_and_defined.rb index c6277e627a..380748d09b 100644 --- a/lib/graphql/static_validation/rules/variables_are_used_and_defined.rb +++ b/lib/graphql/static_validation/rules/variables_are_used_and_defined.rb @@ -11,9 +11,7 @@ module StaticValidation # - re-visiting the AST for each validator # - allowing validators to say `followSpreads: true` # - class VariablesAreUsedAndDefined - include GraphQL::StaticValidation::Message::MessageHelper - + module VariablesAreUsedAndDefined class VariableUsage attr_accessor :ast_node, :used_by, :declared_by, :path def used? @@ -25,73 +23,67 @@ def declared? end end - def variable_hash - Hash.new {|h, k| h[k] = VariableUsage.new } + def initialize(*) + super + @variable_usages_for_context = Hash.new {|hash, key| hash[key] = Hash.new {|h, k| h[k] = VariableUsage.new } } + @spreads_for_context = Hash.new {|hash, key| hash[key] = [] } + @variable_context_stack = [] end - def validate(context) - variable_usages_for_context = Hash.new {|hash, key| hash[key] = variable_hash } - spreads_for_context = Hash.new {|hash, key| hash[key] = [] } - variable_context_stack = [] - - # OperationDefinitions and FragmentDefinitions - # both push themselves onto the context stack (and pop themselves off) - push_variable_context_stack = ->(node, parent) { - # initialize the hash of vars for this context: - variable_usages_for_context[node] - variable_context_stack.push(node) + def on_operation_definition(node, parent) + # initialize the hash of vars for this context: + @variable_usages_for_context[node] + @variable_context_stack.push(node) + # mark variables as defined: + var_hash = @variable_usages_for_context[node] + node.variables.each { |var| + var_usage = var_hash[var.name] + var_usage.declared_by = node + var_usage.path = context.path } + super + @variable_context_stack.pop + end - pop_variable_context_stack = ->(node, parent) { - variable_context_stack.pop - } - - - context.visitor[GraphQL::Language::Nodes::OperationDefinition] << push_variable_context_stack - context.visitor[GraphQL::Language::Nodes::OperationDefinition] << ->(node, parent) { - # mark variables as defined: - var_hash = variable_usages_for_context[node] - node.variables.each { |var| - var_usage = var_hash[var.name] - var_usage.declared_by = node - var_usage.path = context.path - } - } - context.visitor[GraphQL::Language::Nodes::OperationDefinition].leave << pop_variable_context_stack - - context.visitor[GraphQL::Language::Nodes::FragmentDefinition] << push_variable_context_stack - context.visitor[GraphQL::Language::Nodes::FragmentDefinition].leave << pop_variable_context_stack - - # For FragmentSpreads: - # - find the context on the stack - # - mark the context as containing this spread - context.visitor[GraphQL::Language::Nodes::FragmentSpread] << ->(node, parent) { - variable_context = variable_context_stack.last - spreads_for_context[variable_context] << node.name - } + def on_fragment_definition(node, parent) + # initialize the hash of vars for this context: + @variable_usages_for_context[node] + @variable_context_stack.push(node) + super + @variable_context_stack.pop + end - # For VariableIdentifiers: - # - mark the variable as used - # - assign its AST node - context.visitor[GraphQL::Language::Nodes::VariableIdentifier] << ->(node, parent) { - usage_context = variable_context_stack.last - declared_variables = variable_usages_for_context[usage_context] - usage = declared_variables[node.name] - usage.used_by = usage_context - usage.ast_node = node - usage.path = context.path - } + # For FragmentSpreads: + # - find the context on the stack + # - mark the context as containing this spread + def on_fragment_spread(node, parent) + variable_context = @variable_context_stack.last + @spreads_for_context[variable_context] << node.name + super + end + # For VariableIdentifiers: + # - mark the variable as used + # - assign its AST node + def on_variable_identifier(node, parent) + usage_context = @variable_context_stack.last + declared_variables = @variable_usages_for_context[usage_context] + usage = declared_variables[node.name] + usage.used_by = usage_context + usage.ast_node = node + usage.path = context.path + super + end - context.visitor[GraphQL::Language::Nodes::Document].leave << ->(node, parent) { - fragment_definitions = variable_usages_for_context.select { |key, value| key.is_a?(GraphQL::Language::Nodes::FragmentDefinition) } - operation_definitions = variable_usages_for_context.select { |key, value| key.is_a?(GraphQL::Language::Nodes::OperationDefinition) } + def on_document(node, parent) + super + fragment_definitions = @variable_usages_for_context.select { |key, value| key.is_a?(GraphQL::Language::Nodes::FragmentDefinition) } + operation_definitions = @variable_usages_for_context.select { |key, value| key.is_a?(GraphQL::Language::Nodes::OperationDefinition) } - operation_definitions.each do |node, node_variables| - follow_spreads(node, node_variables, spreads_for_context, fragment_definitions, []) - create_errors(node_variables, context) - end - } + operation_definitions.each do |node, node_variables| + follow_spreads(node, node_variables, @spreads_for_context, fragment_definitions, []) + create_errors(node_variables) + end end private @@ -129,16 +121,16 @@ def follow_spreads(node, parent_variables, spreads_for_context, fragment_definit # Determine all the error messages, # Then push messages into the validation context - def create_errors(node_variables, context) + def create_errors(node_variables) # Declared but not used: node_variables .select { |name, usage| usage.declared? && !usage.used? } - .each { |var_name, usage| context.errors << message("Variable $#{var_name} is declared by #{usage.declared_by.name} but not used", usage.declared_by, path: usage.path) } + .each { |var_name, usage| add_error("Variable $#{var_name} is declared by #{usage.declared_by.name} but not used", usage.declared_by, path: usage.path) } # Used but not declared: node_variables .select { |name, usage| usage.used? && !usage.declared? } - .each { |var_name, usage| context.errors << message("Variable $#{var_name} is used by #{usage.used_by.name} but not declared", usage.ast_node, path: usage.path) } + .each { |var_name, usage| add_error("Variable $#{var_name} is used by #{usage.used_by.name} but not declared", usage.ast_node, path: usage.path) } end end end diff --git a/lib/graphql/static_validation/validation_context.rb b/lib/graphql/static_validation/validation_context.rb index 7787936cfe..855b5b8c91 100644 --- a/lib/graphql/static_validation/validation_context.rb +++ b/lib/graphql/static_validation/validation_context.rb @@ -14,70 +14,35 @@ module StaticValidation class ValidationContext extend Forwardable - attr_reader :query, :errors, :visitor, :dependencies, :each_irep_node_handlers + + attr_reader :query, :errors, :visitor, + :on_dependency_resolve_handlers, :each_irep_node_handlers def_delegators :@query, :schema, :document, :fragments, :operations, :warden - def initialize(query) + def initialize(query, visitor_class) @query = query @literal_validator = LiteralValidator.new(context: query.context) @errors = [] - @visitor = GraphQL::Language::Visitor.new(document) - @type_stack = GraphQL::StaticValidation::TypeStack.new(schema, visitor) - definition_dependencies = DefinitionDependencies.mount(self) - @on_dependency_resolve_handlers = [] + # TODO it will take some finegalling but I think all this state could + # be moved to `Visitor` @each_irep_node_handlers = [] - visitor[GraphQL::Language::Nodes::Document].leave << ->(_n, _p) { - @dependencies = definition_dependencies.dependency_map { |defn, spreads, frag| - @on_dependency_resolve_handlers.each { |h| h.call(defn, spreads, frag) } - } - } + @on_dependency_resolve_handlers = [] + @visitor = visitor_class.new(document, self) end + def_delegators :@visitor, + :path, :type_definition, :field_definition, :argument_definition, + :parent_type_definition, :directive_definition, :object_types, :dependencies + def on_dependency_resolve(&handler) @on_dependency_resolve_handlers << handler end - def object_types - @type_stack.object_types - end - def each_irep_node(&handler) @each_irep_node_handlers << handler end - # @return [GraphQL::BaseType] The current object type - def type_definition - object_types.last - end - - # @return [GraphQL::BaseType] The type which the current type came from - def parent_type_definition - object_types[-2] - end - - # @return [GraphQL::Field, nil] The most-recently-entered GraphQL::Field, if currently inside one - def field_definition - @type_stack.field_definitions.last - end - - # @return [Array] Field names to get to the current field - def path - @type_stack.path.dup - end - - # @return [GraphQL::Directive, nil] The most-recently-entered GraphQL::Directive, if currently inside one - def directive_definition - @type_stack.directive_definitions.last - end - - # @return [GraphQL::Argument, nil] The most-recently-entered GraphQL::Argument, if currently inside one - def argument_definition - # Don't get the _last_ one because that's the current one. - # Get the second-to-last one, which is the parent of the current one. - @type_stack.argument_definitions[-2] - end - def valid_literal?(ast_value, type) @literal_validator.validate(ast_value, type) end diff --git a/lib/graphql/static_validation/validator.rb b/lib/graphql/static_validation/validator.rb index 32d349c3de..2b3eaf1094 100644 --- a/lib/graphql/static_validation/validator.rb +++ b/lib/graphql/static_validation/validator.rb @@ -23,23 +23,22 @@ def initialize(schema:, rules: GraphQL::StaticValidation::ALL_RULES) # @return [Array] def validate(query, validate: true) query.trace("validate", { validate: validate, query: query }) do - context = GraphQL::StaticValidation::ValidationContext.new(query) - rewrite = GraphQL::InternalRepresentation::Rewrite.new - # Put this first so its enters and exits are always called - rewrite.validate(context) + rules_to_use = validate ? @rules : [] + visitor_class = BaseVisitor.including_rules(rules_to_use) - # If the caller opted out of validation, don't attach these - if validate - @rules.each do |rules| - rules.new.validate(context) + context = GraphQL::StaticValidation::ValidationContext.new(query, visitor_class) + + # Attach legacy-style rules + rules_to_use.each do |rule_class_or_module| + if rule_class_or_module.method_defined?(:validate) + rule_class_or_module.new.validate(context) end end context.visitor.visit - rewrite_result = rewrite.document - # Post-validation: allow validators to register handlers on rewritten query nodes + rewrite_result = context.visitor.rewrite_document GraphQL::InternalRepresentation::Visit.visit_each_node(rewrite_result.operation_definitions, context.each_irep_node_handlers) { diff --git a/spec/graphql/language/nodes_spec.rb b/spec/graphql/language/nodes_spec.rb index 7c46a2024f..0e09ec8dac 100644 --- a/spec/graphql/language/nodes_spec.rb +++ b/spec/graphql/language/nodes_spec.rb @@ -44,4 +44,22 @@ def print_field_definition(print_field_definition) assert_equal expected.chomp, document.to_query_string(printer: custom_printer_class.new) end end + + describe "#visit_method" do + it "is implemented by all node classes" do + node_classes = GraphQL::Language::Nodes.constants - [:WrapperType, :NameOnlyNode] + node_classes.each do |const| + node_class = GraphQL::Language::Nodes.const_get(const) + if node_class.is_a?(Class) && node_class < GraphQL::Language::Nodes::AbstractNode + concrete_method = node_class.instance_method(:visit_method) + refute_nil concrete_method.super_method, "#{node_class} overrides #visit_method" + visit_method_name = "on_" + node_class.name + .split("::").last + .gsub(/([a-z\d])([A-Z])/,'\1_\2') # someThing -> some_Thing + .downcase + assert GraphQL::Language::Visitor.method_defined?(visit_method_name), "Language::Visitor has a method for #{node_class} (##{visit_method_name})" + end + end + end + end end diff --git a/spec/graphql/language/visitor_spec.rb b/spec/graphql/language/visitor_spec.rb index 10f5fda046..2e9565852f 100644 --- a/spec/graphql/language/visitor_spec.rb +++ b/spec/graphql/language/visitor_spec.rb @@ -16,10 +16,11 @@ fragment cheeseFields on Cheese { flavor } ")} - let(:counts) { {fields_entered: 0, arguments_entered: 0, arguments_left: 0, argument_names: []} } + let(:hooks_counts) { {fields_entered: 0, arguments_entered: 0, arguments_left: 0, argument_names: []} } - let(:visitor) do + let(:hooks_visitor) do v = GraphQL::Language::Visitor.new(document) + counts = hooks_counts v[GraphQL::Language::Nodes::Field] << ->(node, parent) { counts[:fields_entered] += 1 } # two ways to set up enter hooks: v[GraphQL::Language::Nodes::Argument] << ->(node, parent) { counts[:argument_names] << node.name } @@ -30,14 +31,43 @@ v end - it "calls hooks during a depth-first tree traversal" do - assert_equal(2, visitor[GraphQL::Language::Nodes::Argument].enter.length) - visitor.visit - assert_equal(6, counts[:fields_entered]) - assert_equal(2, counts[:arguments_entered]) - assert_equal(2, counts[:arguments_left]) - assert_equal(["id", "first"], counts[:argument_names]) - assert(counts[:finished]) + class VisitorSpecVisitor < GraphQL::Language::Visitor + attr_reader :counts + def initialize(document) + @counts = {fields_entered: 0, arguments_entered: 0, arguments_left: 0, argument_names: []} + super + end + + def on_field(node, parent) + counts[:fields_entered] += 1 + super(node, parent) + end + + def on_argument(node, parent) + counts[:argument_names] << node.name + counts[:arguments_entered] += 1 + super + ensure + counts[:arguments_left] += 1 + end + + def on_document(node, parent) + counts[:finished] = true + super + end + end + + class SkippingVisitor < VisitorSpecVisitor + def on_document(_n, _p) + SKIP + end + end + + let(:class_based_visitor) { VisitorSpecVisitor.new(document) } + let(:class_based_counts) { class_based_visitor.counts } + + it "has an array of hooks" do + assert_equal(2, hooks_visitor[GraphQL::Language::Nodes::Argument].enter.length) end it "can visit a document with directive definitions" do @@ -64,12 +94,31 @@ assert_equal "preview", directive.name assert_equal 10, directive_locations.length end - - describe "Visitor::SKIP" do - it "skips the rest of the node" do - visitor[GraphQL::Language::Nodes::Document] << ->(node, parent) { GraphQL::Language::Visitor::SKIP } + + [:hooks, :class_based].each do |visitor_type| + it "#{visitor_type} visitor calls hooks during a depth-first tree traversal" do + visitor = public_send("#{visitor_type}_visitor") visitor.visit - assert_equal(0, counts[:fields_entered]) + counts = public_send("#{visitor_type}_counts") + assert_equal(6, counts[:fields_entered]) + assert_equal(2, counts[:arguments_entered]) + assert_equal(2, counts[:arguments_left]) + assert_equal(["id", "first"], counts[:argument_names]) + assert(counts[:finished]) + end + + describe "Visitor::SKIP" do + let(:class_based_visitor) { SkippingVisitor.new(document) } + + it "#{visitor_type} visitor skips the rest of the node" do + visitor = public_send("#{visitor_type}_visitor") + if visitor_type == :hooks + visitor[GraphQL::Language::Nodes::Document] << ->(node, parent) { GraphQL::Language::Visitor::SKIP } + end + visitor.visit + counts = public_send("#{visitor_type}_counts") + assert_equal(0, counts[:fields_entered]) + end end end diff --git a/spec/graphql/static_validation/type_stack_spec.rb b/spec/graphql/static_validation/type_stack_spec.rb index 6ce5a3478a..a749fe01ca 100644 --- a/spec/graphql/static_validation/type_stack_spec.rb +++ b/spec/graphql/static_validation/type_stack_spec.rb @@ -1,19 +1,6 @@ # frozen_string_literal: true require "spec_helper" -class TypeCheckValidator - def self.checks - @checks ||= [] - end - - def validate(context) - self.class.checks.clear - context.visitor[GraphQL::Language::Nodes::Field] << ->(node, parent) { - self.class.checks << context.object_types.map {|t| t.name || t.kind.name } - } - end -end - describe GraphQL::StaticValidation::TypeStack do let(:query_string) {%| query getCheese { @@ -22,17 +9,21 @@ def validate(context) fragment edibleFields on Edible { fatContent @skip(if: false)} |} - let(:validator) { GraphQL::StaticValidation::Validator.new(schema: Dummy::Schema, rules: [TypeCheckValidator]) } - let(:query) { GraphQL::Query.new(Dummy::Schema, query_string) } - - it "stores up types" do - validator.validate(query) + document = GraphQL.parse(query_string) + visitor = GraphQL::Language::Visitor.new(document) + type_stack = GraphQL::StaticValidation::TypeStack.new(Dummy::Schema, visitor) + checks = [] + visitor[GraphQL::Language::Nodes::Field].enter << ->(node, parent) { + checks << type_stack.object_types.map {|t| t.name || t.kind.name } + } + visitor.visit + expected = [ ["Query", "Cheese"], ["Query", "Cheese", "NON_NULL"], ["Edible", "NON_NULL"] ] - assert_equal(expected, TypeCheckValidator.checks) + assert_equal(expected, checks) end end diff --git a/spec/graphql/static_validation/validator_spec.rb b/spec/graphql/static_validation/validator_spec.rb index 3982d824f1..8a41dc61f6 100644 --- a/spec/graphql/static_validation/validator_spec.rb +++ b/spec/graphql/static_validation/validator_spec.rb @@ -153,4 +153,50 @@ end end end + + describe "Custom ruleset" do + let(:query_string) { " + fragment Thing on Cheese { + __typename + similarCheese(source: COW) + } + " + } + + let(:rules) { + # This is from graphql-client, eg + # https://github.com/github/graphql-client/blob/c86fc05d7eba2370452592bb93572caced4123af/lib/graphql/client.rb#L168 + GraphQL::StaticValidation::ALL_RULES - [ + GraphQL::StaticValidation::FragmentsAreUsed, + GraphQL::StaticValidation::FieldsHaveAppropriateSelections + ] + } + let(:validator) { GraphQL::StaticValidation::Validator.new(schema: Dummy::Schema, rules: rules) } + + it "runs the specified rules" do + assert_equal 0, errors.size + end + + describe "With a legacy-style rule" do + # GraphQL-Pro's operation store uses this + class ValidatorSpecLegacyRule + include GraphQL::StaticValidation::Message::MessageHelper + def validate(ctx) + ctx.visitor[GraphQL::Language::Nodes::OperationDefinition] << ->(n, _p) { + ctx.errors << message("Busted!", n, context: ctx) + } + end + end + + let(:rules) { + GraphQL::StaticValidation::ALL_RULES + [ValidatorSpecLegacyRule] + } + + let(:query_string) { "{ __typename }"} + + it "runs the rule" do + assert_equal ["Busted!"], errors.map { |e| e["message"] } + end + end + end end diff --git a/spec/support/static_validation_helpers.rb b/spec/support/static_validation_helpers.rb index c620dedfb8..7d510e8b4d 100644 --- a/spec/support/static_validation_helpers.rb +++ b/spec/support/static_validation_helpers.rb @@ -12,10 +12,12 @@ # end module StaticValidationHelpers def errors - target_schema = schema - validator = GraphQL::StaticValidation::Validator.new(schema: target_schema) - query = GraphQL::Query.new(target_schema, query_string) - validator.validate(query)[:errors].map(&:to_h) + @errors ||= begin + target_schema = schema + validator = GraphQL::StaticValidation::Validator.new(schema: target_schema) + query = GraphQL::Query.new(target_schema, query_string) + validator.validate(query)[:errors].map(&:to_h) + end end def error_messages