diff --git a/guides/guides.html b/guides/guides.html
index b96b4868c7..af72ad5d28 100644
--- a/guides/guides.html
+++ b/guides/guides.html
@@ -14,6 +14,7 @@
- name: GraphQL Pro
- name: GraphQL Pro - OperationStore
- name: JavaScript Client
+ - name: Language Tools
- name: Other
---
diff --git a/guides/language_tools/visitor.md b/guides/language_tools/visitor.md
new file mode 100644
index 0000000000..f6cb3ce6bd
--- /dev/null
+++ b/guides/language_tools/visitor.md
@@ -0,0 +1,143 @@
+---
+layout: guide
+doc_stub: false
+search: true
+section: Language Tools
+title: AST Visitor
+desc: Analyze and modify parsed GraphQL code
+index: 0
+---
+
+GraphQL code is usually contained in a string, for example:
+
+```ruby
+query_string = "query { user(id: \"1\") { userName } }"
+```
+
+You can perform programmatic analysis and modifications to GraphQL code using a three-step process:
+
+- __Parse__ the code into an abstract syntax tree
+- __Analyze/Modify__ the code with a visitor
+- __Print__ the code back to a string
+
+## Parse
+
+{{ "GraphQL.parse" | api_doc }} turns a string into a GraphQL document:
+
+```ruby
+parsed_doc = GraphQL.parse("{ user(id: \"1\") { userName } }")
+# => #
+```
+
+Also, {{ "GraphQL.parse_file" | api_doc }} parses the contents of the named file and includes a `filename` in the parsed document.
+
+#### AST Nodes
+
+The parsed document is a tree of nodes, called an _abstract syntax tree_ (AST). This tree is _immutable_: once a document has been parsed, those Ruby objects can't be changed. Modifications are performed by _copying_ existing nodes, applying changes to the copy, then making a new tree to hold the copied node. Where possible, unmodified nodes are retained in the new tree (it's _persistent_).
+
+The copy-and-modify workflow is supported by a few methods on the AST nodes:
+
+- `.merge(new_attrs)` returns a copy of the node with `new_attrs` applied. This new copy can replace the original node.
+- `.add_{child}(new_child_attrs)` makes a new node with `new_child_attrs`, adds it to the array specified by `{child}`, and returns a copy whose `{children}` array contains the newly created node.
+
+For example, to rename a field and add an argument to it, you could:
+
+```ruby
+modified_node = field_node
+ # Apply a new name
+ .merge(name: "newName")
+ # Add an argument to this field's arguments
+ .add_argument(name: "newArgument", value: "newValue")
+```
+
+Above, `field_node` is unmodified, but `modified_node` reflects the new name and new argument.
+
+## Analyze/Modify
+
+To inspect or modify a parsed document, extend {{ "GraphQL::Language::Visitor" | api_doc }} and implement its various hooks. It's an implementation of the [visitor pattern](https://en.wikipedia.org/wiki/Visitor_pattern). In short, each node of the tree will be "visited" by calling a method, and those methods can gather information and perform modifications.
+
+In the visitor, each node class has a hook, for example:
+
+- {{ "GraphQL::Language::Nodes::Field" | api_doc }}s are routed to `#on_field`
+- {{ "GraphQL::Language::Nodes::Argument" | api_doc }}s are routed to `#on_argument`
+
+See the {{ "GraphQL::Language::Visitor" | api_doc }} API docs for a full list of methods.
+
+Each method is called with `(node, parent)`, where:
+
+- `node` is the AST node currently visited
+- `parent` is the AST node above this node in the tree
+
+The method has a few options for analyzing or modifying the AST:
+
+#### Continue/Halt
+
+To continue visiting, the hook should call `super`. This allows the visit to continue to `node`'s children in the tree, for example:
+
+```ruby
+def on_field(_node, _parent)
+ # Do nothing, this is the default behavior:
+ super
+end
+```
+
+To _halt_ the visit, a method may skip the call to `super`. For example, if the visitor encountered an error, it might want to return early instead of continuing to visit.
+
+#### Modify a Node
+
+Visitor hooks are expected to return the `(node, parent)` they are called with. If they return a different node, then that node will replace the original `node`. When you call `super(node, parent)`, the `node` is returned. So, to modify a node and continue visiting:
+
+- Make a modified copy of `node`
+- Pass the modified copy to `super(new_node, parent)`
+
+For example, to rename an argument:
+
+```ruby
+def on_argument(node, parent)
+ # make a copy of `node` with a new name
+ modified_node = node.merge(name: "renamed")
+ # continue visiting with the modified node and parent
+ super(modified_node, parent)
+end
+```
+
+#### Delete a Node
+
+To delete the currently-visited `node`, don't pass `node` to `super(...)`. Instead, pass a magic constant, `DELETE_NODE`, in place of `node`.
+
+For example, to delete a directive:
+
+```ruby
+def on_directive(node, parent)
+ # Don't pass `node` to `super`,
+ # instead, pass `DELETE_NODE`
+ super(DELETE_NODE, parent)
+end
+```
+
+#### Insert a Node
+
+Inserting nodes is similar to modifying nodes. To insert a new child into `node`, call one of its `.add_` helpers. This returns a copied node with a new child added. For example, to add a selection to a field's selection set:
+
+```ruby
+def on_field(node, parent)
+ node_with_selection = node.add_selection(name: "emailAddress")
+ super(node_with_selection, parent)
+end
+```
+
+This will add `emailAddress` the fields selection on `node`.
+
+
+(These `.add_*` helpers are wrappers around {{ "GraphQL::Language::Nodes::AbstractNode#merge" | api_doc }}.)
+
+## Print
+
+The easiest way to turn an AST back into a string of GraphQL is {{ "GraphQL::Language::Nodes::AbstractNode#to_query_string" | api_doc }}, for example:
+
+```ruby
+parsed_doc.to_query_string
+# => '{ user(id: "1") { userName } }'
+```
+
+You can also create a subclass of {{ "GraphQL::Language::Printer" | api_doc }} to customize how nodes are printed.
diff --git a/guides/type_definitions/field_extensions.md b/guides/type_definitions/field_extensions.md
new file mode 100644
index 0000000000..de8cdc09de
--- /dev/null
+++ b/guides/type_definitions/field_extensions.md
@@ -0,0 +1,104 @@
+---
+layout: guide
+doc_stub: false
+search: true
+section: Type Definitions
+title: Field Extensions
+desc: Programmatically modify field configuration and resolution
+index: 10
+class_based_api: true
+---
+
+{{ "GraphQL::Schema::FieldExtension" | api_doc }} provides a way to modify user-defined fields in a programmatic way. For example, Relay connections are implemented as a field extension ({{ "GraphQL::Schema::Field::ConnectionExtension" | api_doc }}).
+
+### Making a new extension
+
+Field extensions are subclasses of {{ "GraphQL::Schema::FieldExtension" | api_doc }}:
+
+```ruby
+class MyExtension < GraphQL::Schema::FieldExtension
+end
+```
+
+### Using an extension
+
+Defined extensions can be added to fields using the `extensions: [...]` option or the `extension(...)` method:
+
+```ruby
+field :name, String, null: false, extensions: [UpcaseExtension]
+# or:
+field :description, String, null: false do
+ extension(UpcaseExtension)
+end
+```
+
+See below for how extensions may modify fields.
+
+### Modifying field configuration
+
+When extensions are attached, they are initialized with a `field:` and `options:`. Then, `#apply` is called, when they may extend the field they're attached to. For example:
+
+```ruby
+class SearchableExtension < GraphQL::Schema::FieldExtension
+ def apply
+ # add an argument to this field:
+ field.argument(:query, String, required: false, description: "A search query")
+ end
+end
+```
+
+This way, an extension can encapsulate a behavior requiring several configuration options.
+
+### Modifying field execution
+
+Extensions have two hooks that wrap field resolution. Since GraphQL-Ruby supports deferred execution, these hooks _might not_ be called back-to-back.
+
+First, {{ "GraphQL::Schema::FieldExtension#before_resolve" | api_doc }} is called. `before_resolve` should `yield(object, arguments)` to continue execution. If it doesn't `yield`, then the field won't resolve, and the methods return value will be returned to GraphQL instead.
+
+After resolution, {{ "GraphQL::Schema::FieldExtension#after_resolve" | api_doc }} is called. Whatever that method returns will be used as the field's return value.
+
+See the linked API docs for the parameters of those methods.
+
+#### Execution "memo"
+
+One parameter to `after_resolve` deserves special attention: `memo:`. `before_resolve` _may_ yield a third value. For example:
+
+```ruby
+def before_resolve(object:, arguments:, **rest)
+ # yield the current time as `memo`
+ yield(object, arguments, Time.now.to_i)
+end
+```
+
+If a third value is yielded, it will be passed to `after_resolve` as `memo:`, for example:
+
+```ruby
+def after_resolve(value:, memo:, **rest)
+ puts "Elapsed: #{Time.now.to_i - memo}"
+ # Return the original value
+ value
+end
+```
+
+This allows the `before_resolve` hook to pass data to `after_resolve`.
+
+Instance variables may not be used because, in a given GraphQL query, the same field may be resolved several times concurrently, and that would result in overriding the instance variable in an unpredictable way. (In fact, extensions are frozen to prevent instance variable writes.)
+
+### Extension options
+
+The `extension(...)` method takes an optional second argument, for example:
+
+```ruby
+extension(LimitExtension, limit: 20)
+```
+
+In this case, `{limit: 20}` will be passed as `options:` to `#initialize` and `options[:limit]` will be `20`.
+
+For example, options can be used for modifying execution:
+
+```ruby
+def after_resolve(value:, **rest)
+ # Apply the limit from the options
+ value.limit(options[:limit])
+end
+```
diff --git a/lib/graphql/compatibility/schema_parser_specification.rb b/lib/graphql/compatibility/schema_parser_specification.rb
index bfe86fd8af..5d1c754e67 100644
--- a/lib/graphql/compatibility/schema_parser_specification.rb
+++ b/lib/graphql/compatibility/schema_parser_specification.rb
@@ -595,31 +595,27 @@ def test_it_parses_whole_definition_with_descriptions
assert_equal 6, document.definitions.size
- schema_definition = document.definitions.shift
+ schema_definition, directive_definition, enum_type_definition, object_type_definition, input_object_type_definition, interface_type_definition = document.definitions
+
assert_equal GraphQL::Language::Nodes::SchemaDefinition, schema_definition.class
- directive_definition = document.definitions.shift
assert_equal GraphQL::Language::Nodes::DirectiveDefinition, directive_definition.class
assert_equal 'This is a directive', directive_definition.description
- enum_type_definition = document.definitions.shift
assert_equal GraphQL::Language::Nodes::EnumTypeDefinition, enum_type_definition.class
assert_equal "Multiline comment\n\nWith an enum", enum_type_definition.description
assert_nil enum_type_definition.values[0].description
assert_equal 'Not a creative color', enum_type_definition.values[1].description
- object_type_definition = document.definitions.shift
assert_equal GraphQL::Language::Nodes::ObjectTypeDefinition, object_type_definition.class
assert_equal 'Comment without preceding space', object_type_definition.description
assert_equal 'And a field to boot', object_type_definition.fields[0].description
- input_object_type_definition = document.definitions.shift
assert_equal GraphQL::Language::Nodes::InputObjectTypeDefinition, input_object_type_definition.class
assert_equal 'Comment for input object types', input_object_type_definition.description
assert_equal 'Color of the car', input_object_type_definition.fields[0].description
- interface_type_definition = document.definitions.shift
assert_equal GraphQL::Language::Nodes::InterfaceTypeDefinition, interface_type_definition.class
assert_equal 'Comment for interface definitions', interface_type_definition.description
assert_equal 'Amount of wheels', interface_type_definition.fields[0].description
diff --git a/lib/graphql/internal_representation/rewrite.rb b/lib/graphql/internal_representation/rewrite.rb
index 033ffc3a25..5f5c7a30a7 100644
--- a/lib/graphql/internal_representation/rewrite.rb
+++ b/lib/graphql/internal_representation/rewrite.rb
@@ -13,127 +13,29 @@ module InternalRepresentation
#
# The rewritten query tree serves as the basis for the `FieldsWillMerge` validation.
#
- class Rewrite
+ module Rewrite
include GraphQL::Language
NO_DIRECTIVES = [].freeze
# @return InternalRepresentation::Document
- attr_reader :document
+ attr_reader :rewrite_document
- def initialize
- @document = InternalRepresentation::Document.new
- end
-
- # @return [Hash] Roots of this query
- def operations
- warn "#{self.class}#operations is deprecated; use `document.operation_definitions` instead"
- document.operation_definitions
- end
-
- def validate(context)
- visitor = context.visitor
- query = context.query
+ def initialize(*)
+ super
+ @query = context.query
+ @rewrite_document = InternalRepresentation::Document.new
# Hash Set>
# A record of fragment spreads and the irep nodes that used them
- spread_parents = Hash.new { |h, k| h[k] = Set.new }
+ @rewrite_spread_parents = Hash.new { |h, k| h[k] = Set.new }
# Hash Scope>
- spread_scopes = {}
+ @rewrite_spread_scopes = {}
# Array>
# The current point of the irep_tree during visitation
- nodes_stack = []
+ @rewrite_nodes_stack = []
# Array
- scopes_stack = []
-
- skip_nodes = Set.new
-
- visit_op = VisitDefinition.new(context, @document.operation_definitions, nodes_stack, scopes_stack)
- visitor[Nodes::OperationDefinition].enter << visit_op.method(:enter)
- visitor[Nodes::OperationDefinition].leave << visit_op.method(:leave)
-
- visit_frag = VisitDefinition.new(context, @document.fragment_definitions, nodes_stack, scopes_stack)
- visitor[Nodes::FragmentDefinition].enter << visit_frag.method(:enter)
- visitor[Nodes::FragmentDefinition].leave << visit_frag.method(:leave)
-
- visitor[Nodes::InlineFragment].enter << ->(ast_node, ast_parent) {
- # Inline fragments provide two things to the rewritten tree:
- # - They _may_ narrow the scope by their type condition
- # - They _may_ apply their directives to their children
- if skip?(ast_node, query)
- skip_nodes.add(ast_node)
- end
-
- if skip_nodes.none?
- scopes_stack.push(scopes_stack.last.enter(context.type_definition))
- end
- }
-
- visitor[Nodes::InlineFragment].leave << ->(ast_node, ast_parent) {
- if skip_nodes.none?
- scopes_stack.pop
- end
-
- if skip_nodes.include?(ast_node)
- skip_nodes.delete(ast_node)
- end
- }
-
- visitor[Nodes::Field].enter << ->(ast_node, ast_parent) {
- if skip?(ast_node, query)
- skip_nodes.add(ast_node)
- end
-
- if skip_nodes.none?
- node_name = ast_node.alias || ast_node.name
- parent_nodes = nodes_stack.last
- next_nodes = []
-
- field_defn = context.field_definition
- if field_defn.nil?
- # It's a non-existent field
- new_scope = nil
- else
- field_return_type = field_defn.type
- scopes_stack.last.each do |scope_type|
- parent_nodes.each do |parent_node|
- node = parent_node.scoped_children[scope_type][node_name] ||= Node.new(
- parent: parent_node,
- name: node_name,
- owner_type: scope_type,
- query: query,
- return_type: field_return_type,
- )
- node.ast_nodes << ast_node
- node.definitions << field_defn
- next_nodes << node
- end
- end
- new_scope = Scope.new(query, field_return_type.unwrap)
- end
-
- nodes_stack.push(next_nodes)
- scopes_stack.push(new_scope)
- end
- }
-
- visitor[Nodes::Field].leave << ->(ast_node, ast_parent) {
- if skip_nodes.none?
- nodes_stack.pop
- scopes_stack.pop
- end
-
- if skip_nodes.include?(ast_node)
- skip_nodes.delete(ast_node)
- end
- }
-
- visitor[Nodes::FragmentSpread].enter << ->(ast_node, ast_parent) {
- if skip_nodes.none? && !skip?(ast_node, query)
- # Register the irep nodes that depend on this AST node:
- spread_parents[ast_node].merge(nodes_stack.last)
- spread_scopes[ast_node] = scopes_stack.last
- end
- }
+ @rewrite_scopes_stack = []
+ @rewrite_skip_nodes = Set.new
# Resolve fragment spreads.
# Fragment definitions got their own irep trees during visitation.
@@ -142,12 +44,12 @@ def validate(context)
# can be shared between its usages.
context.on_dependency_resolve do |defn_ast_node, spread_ast_nodes, frag_ast_node|
frag_name = frag_ast_node.name
- fragment_node = @document.fragment_definitions[frag_name]
+ fragment_node = @rewrite_document.fragment_definitions[frag_name]
if fragment_node
spread_ast_nodes.each do |spread_ast_node|
- parent_nodes = spread_parents[spread_ast_node]
- parent_scope = spread_scopes[spread_ast_node]
+ parent_nodes = @rewrite_spread_parents[spread_ast_node]
+ parent_scope = @rewrite_spread_scopes[spread_ast_node]
parent_nodes.each do |parent_node|
parent_node.deep_merge_node(fragment_node, scope: parent_scope, merge_self: false)
end
@@ -156,43 +58,126 @@ def validate(context)
end
end
- def skip?(ast_node, query)
- dir = ast_node.directives
- dir.any? && !GraphQL::Execution::DirectiveChecks.include?(dir, query)
+ # @return [Hash] Roots of this query
+ def operations
+ warn "#{self.class}#operations is deprecated; use `document.operation_definitions` instead"
+ @document.operation_definitions
+ end
+
+ def on_operation_definition(ast_node, parent)
+ push_root_node(ast_node, @rewrite_document.operation_definitions) { super }
+ end
+
+ def on_fragment_definition(ast_node, parent)
+ push_root_node(ast_node, @rewrite_document.fragment_definitions) { super }
+ end
+
+ def push_root_node(ast_node, definitions)
+ # Either QueryType or the fragment type condition
+ owner_type = context.type_definition
+ defn_name = ast_node.name
+
+ node = Node.new(
+ parent: nil,
+ name: defn_name,
+ owner_type: owner_type,
+ query: @query,
+ ast_nodes: [ast_node],
+ return_type: owner_type,
+ )
+
+ definitions[defn_name] = node
+ @rewrite_scopes_stack.push(Scope.new(@query, owner_type))
+ @rewrite_nodes_stack.push([node])
+ yield
+ @rewrite_nodes_stack.pop
+ @rewrite_scopes_stack.pop
+ end
+
+ def on_inline_fragment(node, parent)
+ # Inline fragments provide two things to the rewritten tree:
+ # - They _may_ narrow the scope by their type condition
+ # - They _may_ apply their directives to their children
+ if skip?(node)
+ @rewrite_skip_nodes.add(node)
+ end
+
+ if @rewrite_skip_nodes.none?
+ @rewrite_scopes_stack.push(@rewrite_scopes_stack.last.enter(context.type_definition))
+ end
+
+ super
+
+ if @rewrite_skip_nodes.none?
+ @rewrite_scopes_stack.pop
+ end
+
+ if @rewrite_skip_nodes.include?(node)
+ @rewrite_skip_nodes.delete(node)
+ end
end
- class VisitDefinition
- def initialize(context, definitions, nodes_stack, scopes_stack)
- @context = context
- @query = context.query
- @definitions = definitions
- @nodes_stack = nodes_stack
- @scopes_stack = scopes_stack
+ def on_field(ast_node, ast_parent)
+ if skip?(ast_node)
+ @rewrite_skip_nodes.add(ast_node)
+ end
+
+ if @rewrite_skip_nodes.none?
+ node_name = ast_node.alias || ast_node.name
+ parent_nodes = @rewrite_nodes_stack.last
+ next_nodes = []
+
+ field_defn = context.field_definition
+ if field_defn.nil?
+ # It's a non-existent field
+ new_scope = nil
+ else
+ field_return_type = field_defn.type
+ @rewrite_scopes_stack.last.each do |scope_type|
+ parent_nodes.each do |parent_node|
+ node = parent_node.scoped_children[scope_type][node_name] ||= Node.new(
+ parent: parent_node,
+ name: node_name,
+ owner_type: scope_type,
+ query: @query,
+ return_type: field_return_type,
+ )
+ node.ast_nodes << ast_node
+ node.definitions << field_defn
+ next_nodes << node
+ end
+ end
+ new_scope = Scope.new(@query, field_return_type.unwrap)
+ end
+
+ @rewrite_nodes_stack.push(next_nodes)
+ @rewrite_scopes_stack.push(new_scope)
+ end
+
+ super
+
+ if @rewrite_skip_nodes.none?
+ @rewrite_nodes_stack.pop
+ @rewrite_scopes_stack.pop
end
- def enter(ast_node, ast_parent)
- # Either QueryType or the fragment type condition
- owner_type = @context.type_definition && @context.type_definition.unwrap
- defn_name = ast_node.name
-
- node = Node.new(
- parent: nil,
- name: defn_name,
- owner_type: owner_type,
- query: @query,
- ast_nodes: [ast_node],
- return_type: @context.type_definition,
- )
-
- @definitions[defn_name] = node
- @scopes_stack.push(Scope.new(@query, owner_type))
- @nodes_stack.push([node])
+ if @rewrite_skip_nodes.include?(ast_node)
+ @rewrite_skip_nodes.delete(ast_node)
end
+ end
- def leave(ast_node, ast_parent)
- @nodes_stack.pop
- @scopes_stack.pop
+ def on_fragment_spread(ast_node, ast_parent)
+ if @rewrite_skip_nodes.none? && !skip?(ast_node)
+ # Register the irep nodes that depend on this AST node:
+ @rewrite_spread_parents[ast_node].merge(@rewrite_nodes_stack.last)
+ @rewrite_spread_scopes[ast_node] = @rewrite_scopes_stack.last
end
+ super
+ end
+
+ def skip?(ast_node)
+ dir = ast_node.directives
+ dir.any? && !GraphQL::Execution::DirectiveChecks.include?(dir, @query)
end
end
end
diff --git a/lib/graphql/language/document_from_schema_definition.rb b/lib/graphql/language/document_from_schema_definition.rb
index 00bf9f0099..6d43166955 100644
--- a/lib/graphql/language/document_from_schema_definition.rb
+++ b/lib/graphql/language/document_from_schema_definition.rb
@@ -65,7 +65,7 @@ def build_field_node(field)
)
if field.deprecation_reason
- field_node.directives << GraphQL::Language::Nodes::Directive.new(
+ field_node = field_node.merge_directive(
name: GraphQL::Directive::DeprecatedDirective.name,
arguments: [GraphQL::Language::Nodes::Argument.new(name: "reason", value: field.deprecation_reason)]
)
@@ -107,7 +107,7 @@ def build_enum_value_node(enum_value)
)
if enum_value.deprecation_reason
- enum_value_node.directives << GraphQL::Language::Nodes::Directive.new(
+ enum_value_node = enum_value_node.merge_directive(
name: GraphQL::Directive::DeprecatedDirective.name,
arguments: [GraphQL::Language::Nodes::Argument.new(name: "reason", value: enum_value.deprecation_reason)]
)
@@ -124,16 +124,19 @@ def build_scalar_type_node(scalar_type)
end
def build_argument_node(argument)
+ if argument.default_value?
+ default_value = build_default_value(argument.default_value, argument.type)
+ else
+ default_value = nil
+ end
+
argument_node = GraphQL::Language::Nodes::InputValueDefinition.new(
name: argument.name,
description: argument.description,
type: build_type_name_node(argument.type),
+ default_value: default_value,
)
- if argument.default_value?
- argument_node.default_value = build_default_value(argument.default_value, argument.type)
- end
-
argument_node
end
diff --git a/lib/graphql/language/nodes.rb b/lib/graphql/language/nodes.rb
index a907744a5d..d6e62b2038 100644
--- a/lib/graphql/language/nodes.rb
+++ b/lib/graphql/language/nodes.rb
@@ -10,15 +10,7 @@ module Nodes
# - `to_query_string` turns an AST node into a GraphQL string
class AbstractNode
- module Scalars # :nodoc:
- module Name
- def scalars
- super + [name]
- end
- end
- end
-
- attr_accessor :line, :col, :filename
+ attr_reader :line, :col, :filename
# Initialize a node by extracting its position,
# then calling the class's `initialize_node` method.
@@ -34,11 +26,6 @@ def initialize(options={})
initialize_node(options)
end
- # This is called with node-specific options
- def initialize_node(options={})
- raise NotImplementedError
- end
-
# Value equality
# @return [Boolean] True if `self` is equivalent to `other`
def eql?(other)
@@ -58,6 +45,11 @@ def scalars
[]
end
+ # @return [Symbol] the method to call on {Language::Visitor} for this node
+ def visit_method
+ raise NotImplementedError, "#{self.class.name}#visit_method shold return a symbol"
+ end
+
def position
[line, col]
end
@@ -65,35 +57,202 @@ def position
def to_query_string(printer: GraphQL::Language::Printer.new)
printer.print(self)
end
- end
- # Base class for non-null type names and list type names
- class WrapperType < AbstractNode
- attr_accessor :of_type
+ # This creates a copy of `self`, with `new_options` applied.
+ # @param new_options [Hash]
+ # @return [AbstractNode] a shallow copy of `self`
+ def merge(new_options)
+ copied_self = dup
+ new_options.each do |key, value|
+ copied_self.instance_variable_set(:"@#{key}", value)
+ end
+ copied_self
+ end
+
+ # Copy `self`, but modify the copy so that `previous_child` is replaced by `new_child`
+ def replace_child(previous_child, new_child)
+ # Figure out which list `previous_child` may be found in
+ method_name = previous_child.children_method_name
+ # Get the value from this (original) node
+ prev_children = public_send(method_name)
+ if prev_children.is_a?(Array)
+ # Copy that list, and replace `previous_child` with `new_child`
+ # in the list.
+ new_children = public_send(method_name).dup
+ prev_idx = new_children.index(previous_child)
+ new_children[prev_idx] = new_child
+ else
+ # Use the new value for the given attribute
+ new_children = new_child
+ end
+ # Copy this node, but with the new child value
+ copy_of_self = merge(method_name => new_children)
+ # Return the copy:
+ copy_of_self
+ end
+
+ # TODO DRY with `replace_child`
+ def delete_child(previous_child)
+ # Figure out which list `previous_child` may be found in
+ method_name = previous_child.children_method_name
+ # Copy that list, and delete previous_child
+ new_children = public_send(method_name).dup
+ new_children.delete(previous_child)
+ # Copy this node, but with the new list of children:
+ copy_of_self = merge(method_name => new_children)
+ # Return the copy:
+ copy_of_self
+ end
+
+ class << self
+ # Add a default `#visit_method` and `#children_method_name` using the class name
+ def inherited(child_class)
+ super
+ name_underscored = child_class.name
+ .split("::").last
+ .gsub(/([a-z])([A-Z])/,'\1_\2') # insert underscores
+ .downcase # remove caps
+
+ child_class.module_eval <<-RUBY
+ def visit_method
+ :on_#{name_underscored}
+ end
+
+ def children_method_name
+ :#{name_underscored}s
+ end
+ RUBY
+ end
- def initialize_node(of_type: nil)
- @of_type = of_type
- end
+ private
+
+ # Name accessors which return lists of nodes,
+ # along with the kind of node they return, if possible.
+ # - Add a reader for these children
+ # - Add a persistent update method to add a child
+ # - Generate a `#children` method
+ def children_methods(children_of_type)
+ if @children_methods
+ raise "Can't re-call .children_methods for #{self} (already have: #{@children_methods})"
+ else
+ @children_methods = children_of_type
+ end
- def scalars
- [of_type]
+ if children_of_type == false
+ @children_methods = {}
+ # skip
+ else
+
+ children_of_type.each do |method_name, node_type|
+ module_eval <<-RUBY, __FILE__, __LINE__
+ # A reader for these children
+ attr_reader :#{method_name}
+ RUBY
+
+ if node_type
+ # Only generate a method if we know what kind of node to make
+ module_eval <<-RUBY, __FILE__, __LINE__
+ # Singular method: create a node with these options
+ # and return a new `self` which includes that node in this list.
+ def merge_#{method_name.to_s.sub(/s$/, "")}(node_opts)
+ merge(#{method_name}: #{method_name} + [#{node_type.name}.new(node_opts)])
+ end
+ RUBY
+ end
+ end
+
+ if children_of_type.size == 1
+ module_eval <<-RUBY, __FILE__, __LINE__
+ alias :children #{children_of_type.keys.first}
+ RUBY
+ else
+ module_eval <<-RUBY, __FILE__, __LINE__
+ def children
+ @children ||= (#{children_of_type.keys.map { |k| "@#{k}" }.join(" + ")}).freeze
+ end
+ RUBY
+ end
+ end
+
+ if defined?(@scalar_methods)
+ generate_initialize_node
+ else
+ raise "Can't generate_initialize_node because scalar_methods wasn't called; call it before children_methods"
+ end
+ end
+
+ # These methods return a plain Ruby value, not another node
+ # - Add reader methods
+ # - Add a `#scalars` method
+ def scalar_methods(*method_names)
+ if @scalar_methods
+ raise "Can't re-call .scalar_methods for #{self} (already have: #{@scalar_methods})"
+ else
+ @scalar_methods = method_names
+ end
+
+ if method_names == [false]
+ @scalar_methods = []
+ # skip it
+ else
+ module_eval <<-RUBY, __FILE__, __LINE__
+ # add readers for each scalar
+ attr_reader #{method_names.map { |m| ":#{m}"}.join(", ")}
+
+ def scalars
+ @scalars ||= [#{method_names.map { |k| "@#{k}" }.join(", ")}].freeze
+ end
+ RUBY
+ end
+ end
+
+ def generate_initialize_node
+ scalar_method_names = @scalar_methods
+ # TODO: These probably should be scalar methods, but `types` returns an array
+ [:types, :description].each do |extra_method|
+ if method_defined?(extra_method)
+ scalar_method_names += [extra_method]
+ end
+ end
+
+ all_method_names = scalar_method_names + @children_methods.keys
+ if all_method_names.include?(:alias)
+ # Rather than complicating this special case,
+ # let it be overridden (in field)
+ return
+ else
+ arguments = scalar_method_names.map { |m| "#{m}: nil"} +
+ @children_methods.keys.map { |m| "#{m}: []" }
+
+ assignments = scalar_method_names.map { |m| "@#{m} = #{m}"} +
+ @children_methods.keys.map { |m| "@#{m} = #{m}.freeze" }
+
+ module_eval <<-RUBY, __FILE__, __LINE__
+ def initialize_node #{arguments.join(", ")}
+ #{assignments.join("\n")}
+ end
+ RUBY
+ end
+ end
end
end
+ # Base class for non-null type names and list type names
+ class WrapperType < AbstractNode
+ scalar_methods :of_type
+ children_methods(false)
+ end
+
# Base class for nodes whose only value is a name (no child nodes or other scalars)
class NameOnlyNode < AbstractNode
- include Scalars::Name
-
- attr_accessor :name
-
- def initialize_node(name: nil)
- @name = name
- end
+ scalar_methods :name
+ children_methods(false)
end
# A key-value pair for a field's inputs
class Argument < AbstractNode
- attr_accessor :name, :value
+ scalar_methods :name, :value
+ children_methods(false)
# @!attribute name
# @return [String] the key for this argument
@@ -101,51 +260,28 @@ class Argument < AbstractNode
# @!attribute value
# @return [String, Float, Integer, Boolean, Array, InputObject] The value passed for this key
- def initialize_node(name: nil, value: nil)
- @name = name
- @value = value
- end
-
- def scalars
- [name, value]
- end
-
def children
[value].flatten.select { |v| v.is_a?(AbstractNode) }
end
end
class Directive < AbstractNode
- include Scalars::Name
-
- attr_accessor :name, :arguments
- alias :children :arguments
+ scalar_methods :name
+ children_methods(arguments: GraphQL::Language::Nodes::Argument)
+ end
- def initialize_node(name: nil, arguments: [])
- @name = name
- @arguments = arguments
- end
+ class DirectiveLocation < NameOnlyNode
end
class DirectiveDefinition < AbstractNode
- include Scalars::Name
-
- attr_accessor :name, :arguments, :locations, :description
-
- def initialize_node(name: nil, arguments: [], locations: [], description: nil)
- @name = name
- @arguments = arguments
- @locations = locations
- @description = description
- end
-
- def children
- arguments + locations
- end
+ attr_reader :description
+ scalar_methods :name
+ children_methods(
+ locations: Nodes::DirectiveLocation,
+ arguments: Nodes::Argument,
+ )
end
- class DirectiveLocation < NameOnlyNode; end
-
# This is the AST root for normal queries
#
# @example Deriving a document by parsing a string
@@ -165,14 +301,10 @@ class DirectiveLocation < NameOnlyNode; end
# document.to_query_string(printer: VariableSrubber.new)
#
class Document < AbstractNode
- attr_accessor :definitions
- alias :children :definitions
-
+ scalar_methods false
+ children_methods(definitions: nil)
# @!attribute definitions
# @return [Array] top-level GraphQL units: operations or fragments
- def initialize_node(definitions: [])
- @definitions = definitions
- end
def slice_definition(name)
GraphQL::Language::DefinitionSlice.slice(self, name)
@@ -180,39 +312,47 @@ def slice_definition(name)
end
# An enum value. The string is available as {#name}.
- class Enum < NameOnlyNode; end
+ class Enum < NameOnlyNode
+ end
# A null value literal.
- class NullValue < NameOnlyNode; end
+ class NullValue < NameOnlyNode
+ end
# A single selection in a GraphQL query.
class Field < AbstractNode
- attr_accessor :name, :alias, :arguments, :directives, :selections
+ scalar_methods :name, :alias
+ children_methods({
+ arguments: GraphQL::Language::Nodes::Argument,
+ selections: GraphQL::Language::Nodes::Field,
+ directives: GraphQL::Language::Nodes::Directive,
+ })
# @!attribute selections
# @return [Array] Selections on this object (or empty array if this is a scalar field)
def initialize_node(name: nil, arguments: [], directives: [], selections: [], **kwargs)
@name = name
- # oops, alias is a keyword:
- @alias = kwargs.fetch(:alias, nil)
@arguments = arguments
@directives = directives
@selections = selections
+ # oops, alias is a keyword:
+ @alias = kwargs.fetch(:alias, nil)
end
- def scalars
- [name, self.alias]
- end
-
- def children
- arguments + directives + selections
+ # Override this because default is `:fields`
+ def children_method_name
+ :selections
end
end
# A reusable fragment, defined at document-level.
class FragmentDefinition < AbstractNode
- attr_accessor :name, :type, :directives, :selections
+ scalar_methods :name, :type
+ children_methods({
+ selections: GraphQL::Language::Nodes::Field,
+ directives: GraphQL::Language::Nodes::Directive,
+ })
# @!attribute name
# @return [String] the identifier for this fragment, which may be applied with `...#{name}`
@@ -226,65 +366,39 @@ def initialize_node(name: nil, type: nil, directives: [], selections: [])
@selections = selections
end
- def children
- directives + selections
- end
-
- def scalars
- [name, type]
+ def children_method_name
+ :definitions
end
end
# Application of a named fragment in a selection
class FragmentSpread < AbstractNode
- include Scalars::Name
-
- attr_accessor :name, :directives
- alias :children :directives
-
+ scalar_methods :name
+ children_methods(directives: GraphQL::Language::Nodes::Directive)
# @!attribute name
# @return [String] The identifier of the fragment to apply, corresponds with {FragmentDefinition#name}
-
- def initialize_node(name: nil, directives: [])
- @name = name
- @directives = directives
- end
end
# An unnamed fragment, defined directly in the query with `... { }`
class InlineFragment < AbstractNode
- attr_accessor :type, :directives, :selections
+ scalar_methods :type
+ children_methods({
+ selections: GraphQL::Language::Nodes::Field,
+ directives: GraphQL::Language::Nodes::Directive,
+ })
# @!attribute type
# @return [String, nil] Name of the type this fragment applies to, or `nil` if this fragment applies to any type
-
- def initialize_node(type: nil, directives: [], selections: [])
- @type = type
- @directives = directives
- @selections = selections
- end
-
- def children
- directives + selections
- end
-
- def scalars
- [type]
- end
end
# A collection of key-value inputs which may be a field argument
class InputObject < AbstractNode
- attr_accessor :arguments
- alias :children :arguments
+ scalar_methods(false)
+ children_methods(arguments: GraphQL::Language::Nodes::Argument)
# @!attribute arguments
# @return [Array] A list of key-value pairs inside this input object
- def initialize_node(arguments: [])
- @arguments = arguments
- end
-
# @return [Hash] Recursively turn this input object into a Ruby Hash
def to_h(options={})
arguments.inject({}) do |memo, pair|
@@ -294,6 +408,10 @@ def to_h(options={})
end
end
+ def children_method_name
+ :value
+ end
+
private
def serialize_value_for_hash(value)
@@ -316,16 +434,37 @@ def serialize_value_for_hash(value)
# A list type definition, denoted with `[...]` (used for variable type definitions)
- class ListType < WrapperType; end
+ class ListType < WrapperType
+ end
# A non-null type definition, denoted with `...!` (used for variable type definitions)
- class NonNullType < WrapperType; end
+ class NonNullType < WrapperType
+ end
+
+ # An operation-level query variable
+ class VariableDefinition < AbstractNode
+ scalar_methods :name, :type, :default_value
+ children_methods false
+ # @!attribute default_value
+ # @return [String, Integer, Float, Boolean, Array, NullValue] A Ruby value to use if no other value is provided
+
+ # @!attribute type
+ # @return [TypeName, NonNullType, ListType] The expected type of this value
+
+ # @!attribute name
+ # @return [String] The identifier for this variable, _without_ `$`
+ end
# A query, mutation or subscription.
# May be anonymous or named.
# May be explicitly typed (eg `mutation { ... }`) or implicitly a query (eg `{ ... }`).
class OperationDefinition < AbstractNode
- attr_accessor :operation_type, :name, :variables, :directives, :selections
+ scalar_methods :operation_type, :name
+ children_methods({
+ variables: GraphQL::Language::Nodes::VariableDefinition,
+ selections: GraphQL::Language::Nodes::Field,
+ directives: GraphQL::Language::Nodes::Directive,
+ })
# @!attribute variables
# @return [Array] Variable definitions for this operation
@@ -339,315 +478,155 @@ class OperationDefinition < AbstractNode
# @!attribute name
# @return [String, nil] The name for this operation, or `nil` if unnamed
- def initialize_node(operation_type: nil, name: nil, variables: [], directives: [], selections: [])
- @operation_type = operation_type
- @name = name
- @variables = variables
- @directives = directives
- @selections = selections
- end
-
- def children
- variables + directives + selections
- end
-
- def scalars
- [operation_type, name]
+ def children_method_name
+ :definitions
end
end
# A type name, used for variable definitions
- class TypeName < NameOnlyNode; end
-
- # An operation-level query variable
- class VariableDefinition < AbstractNode
- attr_accessor :name, :type, :default_value
-
- # @!attribute default_value
- # @return [String, Integer, Float, Boolean, Array, NullValue] A Ruby value to use if no other value is provided
-
- # @!attribute type
- # @return [TypeName, NonNullType, ListType] The expected type of this value
-
- # @!attribute name
- # @return [String] The identifier for this variable, _without_ `$`
-
- def initialize_node(name: nil, type: nil, default_value: nil)
- @name = name
- @type = type
- @default_value = default_value
- end
-
- def scalars
- [name, type, default_value]
- end
+ class TypeName < NameOnlyNode
end
# Usage of a variable in a query. Name does _not_ include `$`.
- class VariableIdentifier < NameOnlyNode; end
+ class VariableIdentifier < NameOnlyNode
+ end
class SchemaDefinition < AbstractNode
- attr_accessor :query, :mutation, :subscription, :directives
-
- def initialize_node(query: nil, mutation: nil, subscription: nil, directives: [])
- @query = query
- @mutation = mutation
- @subscription = subscription
- @directives = directives
- end
-
- def scalars
- [query, mutation, subscription]
- end
-
- alias :children :directives
+ scalar_methods :query, :mutation, :subscription
+ children_methods({
+ directives: GraphQL::Language::Nodes::Directive,
+ })
end
class SchemaExtension < AbstractNode
- attr_accessor :query, :mutation, :subscription, :directives
-
- def initialize_node(query: nil, mutation: nil, subscription: nil, directives: [])
- @query = query
- @mutation = mutation
- @subscription = subscription
- @directives = directives
- end
-
- def scalars
- [query, mutation, subscription]
- end
-
- alias :children :directives
+ scalar_methods :query, :mutation, :subscription
+ children_methods({
+ directives: GraphQL::Language::Nodes::Directive,
+ })
end
class ScalarTypeDefinition < AbstractNode
- include Scalars::Name
-
- attr_accessor :name, :directives, :description
- alias :children :directives
-
- def initialize_node(name:, directives: [], description: nil)
- @name = name
- @directives = directives
- @description = description
- end
+ attr_reader :description
+ scalar_methods :name
+ children_methods({
+ directives: GraphQL::Language::Nodes::Directive,
+ })
end
class ScalarTypeExtension < AbstractNode
- attr_accessor :name, :directives
- alias :children :directives
-
- def initialize_node(name:, directives: [])
- @name = name
- @directives = directives
- end
- end
-
- class ObjectTypeDefinition < AbstractNode
- include Scalars::Name
-
- attr_accessor :name, :interfaces, :fields, :directives, :description
-
- def initialize_node(name:, interfaces:, fields:, directives: [], description: nil)
- @name = name
- @interfaces = interfaces || []
- @directives = directives
- @fields = fields
- @description = description
- end
-
- def children
- interfaces + fields + directives
- end
- end
-
- class ObjectTypeExtension < AbstractNode
- attr_accessor :name, :interfaces, :fields, :directives
-
- def initialize_node(name:, interfaces:, fields:, directives: [])
- @name = name
- @interfaces = interfaces || []
- @directives = directives
- @fields = fields
- end
-
- def children
- interfaces + fields + directives
- end
+ scalar_methods :name
+ children_methods({
+ directives: GraphQL::Language::Nodes::Directive,
+ })
end
class InputValueDefinition < AbstractNode
- attr_accessor :name, :type, :default_value, :directives,:description
- alias :children :directives
-
- def initialize_node(name:, type:, default_value: nil, directives: [], description: nil)
- @name = name
- @type = type
- @default_value = default_value
- @directives = directives
- @description = description
- end
-
- def scalars
- [name, type, default_value]
- end
+ attr_reader :description
+ scalar_methods :name, :type, :default_value
+ children_methods({
+ directives: GraphQL::Language::Nodes::Directive,
+ })
end
class FieldDefinition < AbstractNode
- attr_accessor :name, :arguments, :type, :directives, :description
-
- def initialize_node(name:, arguments:, type:, directives: [], description: nil)
- @name = name
- @arguments = arguments
- @type = type
- @directives = directives
- @description = description
- end
+ attr_reader :description
+ scalar_methods :name, :type
+ children_methods({
+ directives: GraphQL::Language::Nodes::Directive,
+ arguments: GraphQL::Language::Nodes::InputValueDefinition,
+ })
+ end
- def children
- arguments + directives
- end
+ class ObjectTypeDefinition < AbstractNode
+ attr_reader :description
+ scalar_methods :name, :interfaces
+ children_methods({
+ directives: GraphQL::Language::Nodes::Directive,
+ fields: GraphQL::Language::Nodes::FieldDefinition,
+ })
+ end
- def scalars
- [name, type]
- end
+ class ObjectTypeExtension < AbstractNode
+ scalar_methods :name, :interfaces
+ children_methods({
+ directives: GraphQL::Language::Nodes::Directive,
+ fields: GraphQL::Language::Nodes::FieldDefinition,
+ })
end
class InterfaceTypeDefinition < AbstractNode
- include Scalars::Name
-
- attr_accessor :name, :fields, :directives, :description
-
- def initialize_node(name:, fields:, directives: [], description: nil)
- @name = name
- @fields = fields
- @directives = directives
- @description = description
- end
-
- def children
- fields + directives
- end
+ attr_reader :description
+ scalar_methods :name
+ children_methods({
+ directives: GraphQL::Language::Nodes::Directive,
+ fields: GraphQL::Language::Nodes::FieldDefinition,
+ })
end
class InterfaceTypeExtension < AbstractNode
- attr_accessor :name, :fields, :directives
-
- def initialize_node(name:, fields:, directives: [])
- @name = name
- @fields = fields
- @directives = directives
- end
-
- def children
- fields + directives
- end
+ scalar_methods :name
+ children_methods({
+ directives: GraphQL::Language::Nodes::Directive,
+ fields: GraphQL::Language::Nodes::FieldDefinition,
+ })
end
class UnionTypeDefinition < AbstractNode
- include Scalars::Name
-
- attr_accessor :name, :types, :directives, :description
-
- def initialize_node(name:, types:, directives: [], description: nil)
- @name = name
- @types = types
- @directives = directives
- @description = description
- end
-
- def children
- types + directives
- end
+ attr_reader :description, :types
+ scalar_methods :name
+ children_methods({
+ directives: GraphQL::Language::Nodes::Directive,
+ })
end
class UnionTypeExtension < AbstractNode
- attr_accessor :name, :types, :directives
-
- def initialize_node(name:, types:, directives: [])
- @name = name
- @types = types
- @directives = directives
- end
+ attr_reader :types
+ scalar_methods :name
+ children_methods({
+ directives: GraphQL::Language::Nodes::Directive,
+ })
+ end
- def children
- types + directives
- end
+ class EnumValueDefinition < AbstractNode
+ attr_reader :description
+ scalar_methods :name
+ children_methods({
+ directives: GraphQL::Language::Nodes::Directive,
+ })
end
class EnumTypeDefinition < AbstractNode
- include Scalars::Name
-
- attr_accessor :name, :values, :directives, :description
-
- def initialize_node(name:, values:, directives: [], description: nil)
- @name = name
- @values = values
- @directives = directives
- @description = description
- end
-
- def children
- values + directives
- end
+ attr_reader :description
+ scalar_methods :name
+ children_methods({
+ directives: GraphQL::Language::Nodes::Directive,
+ values: GraphQL::Language::Nodes::EnumValueDefinition,
+ })
end
class EnumTypeExtension < AbstractNode
- attr_accessor :name, :values, :directives
-
- def initialize_node(name:, values:, directives: [])
- @name = name
- @values = values
- @directives = directives
- end
-
- def children
- values + directives
- end
- end
-
- class EnumValueDefinition < AbstractNode
- include Scalars::Name
-
- attr_accessor :name, :directives, :description
- alias :children :directives
-
- def initialize_node(name:, directives: [], description: nil)
- @name = name
- @directives = directives
- @description = description
- end
+ scalar_methods :name
+ children_methods({
+ directives: GraphQL::Language::Nodes::Directive,
+ values: GraphQL::Language::Nodes::EnumValueDefinition,
+ })
end
class InputObjectTypeDefinition < AbstractNode
- include Scalars::Name
-
- attr_accessor :name, :fields, :directives, :description
-
- def initialize_node(name:, fields:, directives: [], description: nil)
- @name = name
- @fields = fields
- @directives = directives
- @description = description
- end
-
- def children
- fields + directives
- end
+ attr_reader :description
+ scalar_methods :name
+ children_methods({
+ directives: GraphQL::Language::Nodes::Directive,
+ fields: GraphQL::Language::Nodes::InputValueDefinition,
+ })
end
class InputObjectTypeExtension < AbstractNode
- attr_accessor :name, :fields, :directives
-
- def initialize_node(name:, fields:, directives: [])
- @name = name
- @fields = fields
- @directives = directives
- end
-
- def children
- fields + directives
- end
+ scalar_methods :name
+ children_methods({
+ directives: GraphQL::Language::Nodes::Directive,
+ fields: GraphQL::Language::Nodes::InputValueDefinition,
+ })
end
end
end
diff --git a/lib/graphql/language/visitor.rb b/lib/graphql/language/visitor.rb
index 509af2dae8..2d58bd23e5 100644
--- a/lib/graphql/language/visitor.rb
+++ b/lib/graphql/language/visitor.rb
@@ -3,32 +3,60 @@ module GraphQL
module Language
# Depth-first traversal through the tree, calling hooks at each stop.
#
- # @example Create a visitor, add hooks, then search a document
- # total_field_count = 0
- # visitor = GraphQL::Language::Visitor.new(document)
- # # Whenever you find a field, increment the field count:
- # visitor[GraphQL::Language::Nodes::Field] << ->(node) { total_field_count += 1 }
- # # When we finish, print the field count:
- # visitor[GraphQL::Language::Nodes::Document].leave << ->(node) { p total_field_count }
- # visitor.visit
- # # => 6
+ # @example Create a visitor counting certain field names
+ # class NameCounter < GraphQL::Language::Visitor
+ # def initialize(document, field_name)
+ # super(document)
+ # @field_name
+ # @count = 0
+ # end
+ #
+ # attr_reader :count
+ #
+ # def on_field(node, parent)
+ # # if this field matches our search, increment the counter
+ # if node.name == @field_name
+ # @count = 0
+ # end
+ # # Continue visiting subfields:
+ # super
+ # end
+ # end
#
+ # # Initialize a visitor
+ # visitor = GraphQL::Language::Visitor.new(document, "name")
+ # # Run it
+ # visitor.visit
+ # # Check the result
+ # visitor.count
+ # # => 3
class Visitor
# If any hook returns this value, the {Visitor} stops visiting this
# node right away
+ # @deprecated Use `super` to continue the visit; or don't call it to halt.
SKIP = :_skip
+ class DeleteNode; end
+ # When this is returned from a visitor method,
+ # Then the `node` passed into the method is removed from `parent`'s children.
+ DELETE_NODE = DeleteNode.new
+
def initialize(document)
@document = document
@visitors = {}
+ @result = nil
end
+ # @return [GraphQL::Language::Nodes::Document] The document with any modifications applied
+ attr_reader :result
+
# Get a {NodeVisitor} for `node_class`
# @param node_class [Class] The node class that you want to listen to
# @return [NodeVisitor]
#
# @example Run a hook whenever you enter a new Field
# visitor[GraphQL::Language::Nodes::Field] << ->(node, parent) { p "Here's a field" }
+ # @deprecated see `on_` methods, like {#on_field}
def [](node_class)
@visitors[node_class] ||= NodeVisitor.new
end
@@ -36,17 +64,94 @@ def [](node_class)
# Visit `document` and all children, applying hooks as you go
# @return [void]
def visit
- visit_node(@document, nil)
+ @result, _nil_parent = on_node_with_modifications(@document, nil)
end
+ # The default implementation for visiting an AST node.
+ # It doesn't _do_ anything, but it continues to visiting the node's children.
+ # To customize this hook, override one of its aliases (or the base method?)
+ # in your subclasses.
+ #
+ # For compatibility, it calls hook procs, too.
+ # @param node [GraphQL::Language::Nodes::AbstractNode] the node being visited
+ # @param parent [GraphQL::Language::Nodes::AbstractNode, nil] the previously-visited node, or `nil` if this is the root node.
+ # @return [void]
+ def on_abstract_node(node, parent)
+ if node == DELETE_NODE
+ # This might be passed to `super(DELETE_NODE, ...)`
+ # by a user hook, don't want to keep visiting in that case.
+ return node, parent
+ else
+ # Run hooks if there are any
+ begin_hooks_ok = @visitors.none? || begin_visit(node, parent)
+ if begin_hooks_ok
+ node.children.each do |child_node|
+ # Reassign `node` in case the child hook makes a modification
+ _new_child_node, node = on_node_with_modifications(child_node, node)
+ end
+ end
+ @visitors.any? && end_visit(node, parent)
+ return node, parent
+ end
+ end
+
+ alias :on_argument :on_abstract_node
+ alias :on_directive :on_abstract_node
+ alias :on_directive_definition :on_abstract_node
+ alias :on_directive_location :on_abstract_node
+ alias :on_document :on_abstract_node
+ alias :on_enum :on_abstract_node
+ alias :on_enum_type_definition :on_abstract_node
+ alias :on_enum_type_extension :on_abstract_node
+ alias :on_enum_value_definition :on_abstract_node
+ alias :on_field :on_abstract_node
+ alias :on_field_definition :on_abstract_node
+ alias :on_fragment_definition :on_abstract_node
+ alias :on_fragment_spread :on_abstract_node
+ alias :on_inline_fragment :on_abstract_node
+ alias :on_input_object :on_abstract_node
+ alias :on_input_object_type_definition :on_abstract_node
+ alias :on_input_object_type_extension :on_abstract_node
+ alias :on_input_value_definition :on_abstract_node
+ alias :on_interface_type_definition :on_abstract_node
+ alias :on_interface_type_extension :on_abstract_node
+ alias :on_list_type :on_abstract_node
+ alias :on_non_null_type :on_abstract_node
+ alias :on_null_value :on_abstract_node
+ alias :on_object_type_definition :on_abstract_node
+ alias :on_object_type_extension :on_abstract_node
+ alias :on_operation_definition :on_abstract_node
+ alias :on_scalar_type_definition :on_abstract_node
+ alias :on_scalar_type_extension :on_abstract_node
+ alias :on_schema_definition :on_abstract_node
+ alias :on_schema_extension :on_abstract_node
+ alias :on_type_name :on_abstract_node
+ alias :on_union_type_definition :on_abstract_node
+ alias :on_union_type_extension :on_abstract_node
+ alias :on_variable_definition :on_abstract_node
+ alias :on_variable_identifier :on_abstract_node
+
private
- def visit_node(node, parent)
- begin_hooks_ok = begin_visit(node, parent)
- if begin_hooks_ok
- node.children.each { |child| visit_node(child, node) }
+ # Run the hooks for `node`, and if the hooks return a copy of `node`,
+ # copy `parent` so that it contains the copy of that node as a child,
+ # then return the copies
+ def on_node_with_modifications(node, parent)
+ new_node, new_parent = public_send(node.visit_method, node, parent)
+ if new_node.is_a?(Nodes::AbstractNode) && !node.equal?(new_node)
+ # The user-provided hook returned a new node.
+ new_parent = new_parent && new_parent.replace_child(node, new_node)
+ return new_node, new_parent
+ elsif new_node == DELETE_NODE
+ # The user-provided hook requested to remove this node
+ new_parent = new_parent && new_parent.delete_child(node)
+ return nil, new_parent
+ else
+ # The user-provided hook didn't make any modifications.
+ # In fact, the hook might have returned who-knows-what, so
+ # ignore the return value and use the original values.
+ return node, parent
end
- end_visit(node, parent)
end
def begin_visit(node, parent)
diff --git a/lib/graphql/relay/connection_instrumentation.rb b/lib/graphql/relay/connection_instrumentation.rb
index e193a7e293..1ddec45a9d 100644
--- a/lib/graphql/relay/connection_instrumentation.rb
+++ b/lib/graphql/relay/connection_instrumentation.rb
@@ -32,7 +32,9 @@ def self.default_arguments
# - Merging in the default arguments
# - Transforming its resolve function to return a connection object
def self.instrument(type, field)
- if field.connection?
+ # Don't apply the wrapper to class-based fields, since they
+ # use Schema::Field::ConnectionFilter
+ if field.connection? && !field.metadata[:type_class]
connection_arguments = default_arguments.merge(field.arguments)
original_resolve = field.resolve_proc
original_lazy_resolve = field.lazy_resolve_proc
diff --git a/lib/graphql/schema.rb b/lib/graphql/schema.rb
index de9473046e..81f29fb626 100644
--- a/lib/graphql/schema.rb
+++ b/lib/graphql/schema.rb
@@ -19,7 +19,6 @@
require "graphql/schema/warden"
require "graphql/schema/build_from_definition"
-
require "graphql/schema/member"
require "graphql/schema/wrapper"
require "graphql/schema/list"
@@ -27,15 +26,17 @@
require "graphql/schema/argument"
require "graphql/schema/enum_value"
require "graphql/schema/enum"
+require "graphql/schema/field_extension"
require "graphql/schema/field"
require "graphql/schema/input_object"
require "graphql/schema/interface"
+require "graphql/schema/scalar"
+require "graphql/schema/object"
+require "graphql/schema/union"
+
require "graphql/schema/resolver"
require "graphql/schema/mutation"
require "graphql/schema/relay_classic_mutation"
-require "graphql/schema/object"
-require "graphql/schema/scalar"
-require "graphql/schema/union"
module GraphQL
# A GraphQL schema which may be queried with {GraphQL::Query}.
diff --git a/lib/graphql/schema/argument.rb b/lib/graphql/schema/argument.rb
index 8701966ce9..80151751ac 100644
--- a/lib/graphql/schema/argument.rb
+++ b/lib/graphql/schema/argument.rb
@@ -97,12 +97,11 @@ def type
# Used by the runtime.
# @api private
def prepare_value(obj, value)
- case @prepare
- when nil
+ if @prepare.nil?
value
- when Symbol, String
+ elsif @prepare.is_a?(String) || @prepare.is_a?(Symbol)
obj.public_send(@prepare, value)
- when Proc
+ elsif @prepare.respond_to?(:call)
@prepare.call(value, obj.context)
else
raise "Invalid prepare for #{@owner.name}.name: #{@prepare.inspect}"
diff --git a/lib/graphql/schema/field.rb b/lib/graphql/schema/field.rb
index c7bda1e86b..5f4080b243 100644
--- a/lib/graphql/schema/field.rb
+++ b/lib/graphql/schema/field.rb
@@ -1,5 +1,8 @@
# frozen_string_literal: true
# test_via: ../object.rb
+require "graphql/schema/field/connection_extension"
+require "graphql/schema/field/scope_extension"
+
module GraphQL
class Schema
class Field
@@ -87,7 +90,8 @@ def connection?
elsif @return_type_expr
Member::BuildType.to_type_name(@return_type_expr)
else
- raise "No connection info possible"
+ # As a last ditch, try to force loading the return type:
+ type.unwrap.name
end
@connection = return_type_name.end_with?("Connection")
else
@@ -101,13 +105,12 @@ def scoped?
# The default was overridden
@scope
else
- @return_type_expr && type.unwrap.respond_to?(:scope_items) && (connection? || type.list?)
+ @return_type_expr.is_a?(Array) || (@return_type_expr.is_a?(String) && @return_type_expr.include?("[")) || connection?
end
end
# @param name [Symbol] The underscore-cased version of this field name (will be camelized for the GraphQL API)
- # @param return_type_expr [Class, GraphQL::BaseType, Array] The return type of this field
- # @param desc [String] Field description
+ # @param type [Class, GraphQL::BaseType, Array] The return type of this field
# @param owner [Class] The type that this field belongs to
# @param null [Boolean] `true` if this field may return `null`, `false` if it is never `null`
# @param description [String] Field description
@@ -126,8 +129,8 @@ def scoped?
# @param complexity [Numeric] When provided, set the complexity for this field
# @param scope [Boolean] If true, the return type's `.scope_items` method will be called on the return value
# @param subscription_scope [Symbol, String] A key in `context` which will be used to scope subscription payloads
- def initialize(type: nil, name: nil, owner: nil, null: nil, field: nil, function: nil, description: nil, deprecation_reason: nil, method: nil, connection: nil, max_page_size: nil, scope: nil, resolve: nil, introspection: false, hash_key: nil, camelize: true, trace: nil, complexity: 1, extras: [], resolver_class: nil, subscription_scope: nil, arguments: {}, &definition_block)
-
+ # @param extensions [Array] Named extensions to apply to this field (see also {#extension})
+ def initialize(type: nil, name: nil, owner: nil, null: nil, field: nil, function: nil, description: nil, deprecation_reason: nil, method: nil, connection: nil, max_page_size: nil, scope: nil, resolve: nil, introspection: false, hash_key: nil, camelize: true, trace: nil, complexity: 1, extras: [], extensions: [], resolver_class: nil, subscription_scope: nil, arguments: {}, &definition_block)
if name.nil?
raise ArgumentError, "missing first `name` argument or keyword `name:`"
end
@@ -145,7 +148,7 @@ def initialize(type: nil, name: nil, owner: nil, null: nil, field: nil, function
@name = camelize ? Member::BuildType.camelize(name.to_s) : name.to_s
@description = description
if field.is_a?(GraphQL::Schema::Field)
- @field_instance = field
+ raise ArgumentError, "Instead of passing a field as `field:`, use `add_field(field)` to add an already-defined field."
else
@field = field
end
@@ -185,9 +188,25 @@ def initialize(type: nil, name: nil, owner: nil, null: nil, field: nil, function
@owner = owner
@subscription_scope = subscription_scope
+ # Do this last so we have as much context as possible when initializing them:
+ @extensions = []
+ if extensions.any?
+ self.extensions(extensions)
+ end
+ # This should run before connection extension,
+ # but should it run after the definition block?
+ if scoped?
+ self.extension(ScopeExtension)
+ end
+ # The problem with putting this after the definition_block
+ # is that it would override arguments
+ if connection?
+ self.extension(ConnectionExtension)
+ end
+
if definition_block
if definition_block.arity == 1
- instance_exec(self, &definition_block)
+ yield self
else
instance_eval(&definition_block)
end
@@ -204,6 +223,49 @@ def description(text = nil)
end
end
+ # Read extension instances from this field,
+ # or add new classes/options to be initialized on this field.
+ #
+ # @param extensions [Array, Hash Object>] Add extensions to this field
+ # @return [Array] extensions to apply to this field
+ def extensions(new_extensions = nil)
+ if new_extensions.nil?
+ # Read the value
+ @extensions
+ else
+ if @resolve || @function
+ raise ArgumentError, <<-MSG
+Extensions are not supported with resolve procs or functions,
+but #{owner.name}.#{name} has: #{@resolve || @function}
+So, it can't have extensions: #{extensions}.
+Use a method or a Schema::Resolver instead.
+MSG
+ end
+
+ # Normalize to a Hash of {name => options}
+ extensions_with_options = if new_extensions.last.is_a?(Hash)
+ new_extensions.pop
+ else
+ {}
+ end
+ new_extensions.each do |f|
+ extensions_with_options[f] = nil
+ end
+
+ # Initialize each class and stash the instance
+ extensions_with_options.each do |extension_class, options|
+ @extensions << extension_class.new(field: self, options: options)
+ end
+ end
+ end
+
+ # Add `extension` to this field, initialized with `options` if provided.
+ # @param extension [Class] subclass of {Schema::Fieldextension}
+ # @param options [Object] if provided, given as `options:` when initializing `extension`.
+ def extension(extension, options = nil)
+ extensions([{extension => options}])
+ end
+
def complexity(new_complexity)
case new_complexity
when Proc
@@ -223,14 +285,11 @@ def complexity(new_complexity)
end
+ # @return [Integer, nil] Applied to connections if present
+ attr_reader :max_page_size
+
# @return [GraphQL::Field]
def to_graphql
- # this field was previously defined and passed here, so delegate to it
- if @field_instance
- return @field_instance.to_graphql
- end
-
-
field_defn = if @field
@field.dup
elsif @function
@@ -265,22 +324,11 @@ def to_graphql
field_defn.resolve = self.method(:resolve_field)
field_defn.connection = connection?
- field_defn.connection_max_page_size = @max_page_size
+ field_defn.connection_max_page_size = max_page_size
field_defn.introspection = @introspection
field_defn.complexity = @complexity
field_defn.subscription_scope = @subscription_scope
- # apply this first, so it can be overriden below
- if connection?
- # TODO: this could be a bit weird, because these fields won't be present
- # after initialization, only in the `to_graphql` response.
- # This calculation _could_ be moved up if need be.
- argument :after, "String", "Returns the elements in the list that come after the specified cursor.", required: false
- argument :before, "String", "Returns the elements in the list that come before the specified cursor.", required: false
- argument :first, "Int", "Returns the first _n_ elements from the list.", required: false
- argument :last, "Int", "Returns the last _n_ elements from the list.", required: false
- end
-
arguments.each do |name, defn|
arg_graphql = defn.to_graphql
field_defn.arguments[arg_graphql.name] = arg_graphql
@@ -341,16 +389,12 @@ def resolve_field(obj, args, ctx)
inner_obj = after_obj && after_obj.object
if authorized?(inner_obj, query_ctx) && arguments.each_value.all? { |a| a.authorized?(inner_obj, query_ctx) }
# Then if it passed, resolve the field
- v = if @resolve_proc
+ if @resolve_proc
# Might be nil, still want to call the func in that case
@resolve_proc.call(inner_obj, args, ctx)
- elsif @resolver_class
- singleton_inst = @resolver_class.new(object: inner_obj, context: query_ctx)
- public_send_field(singleton_inst, args, ctx)
else
public_send_field(after_obj, args, ctx)
end
- apply_scope(v, ctx)
else
nil
end
@@ -396,16 +440,6 @@ def resolve_field_method(obj, ruby_kwargs, ctx)
private
- def apply_scope(value, ctx)
- if scoped?
- ctx.schema.after_lazy(value) do |inner_value|
- @type.unwrap.scope_items(inner_value, ctx)
- end
- else
- value
- end
- end
-
NO_ARGS = {}.freeze
def public_send_field(obj, graphql_args, field_ctx)
@@ -420,14 +454,6 @@ def public_send_field(obj, graphql_args, field_ctx)
end
end
- if connection?
- # Remove pagination args before passing it to a user method
- ruby_kwargs.delete(:first)
- ruby_kwargs.delete(:last)
- ruby_kwargs.delete(:before)
- ruby_kwargs.delete(:after)
- end
-
@extras.each do |extra_arg|
# TODO: provide proper tests for `:ast_node`, `:irep_node`, `:parent`, others?
ruby_kwargs[extra_arg] = field_ctx.public_send(extra_arg)
@@ -436,11 +462,55 @@ def public_send_field(obj, graphql_args, field_ctx)
ruby_kwargs = NO_ARGS
end
+ query_ctx = field_ctx.query.context
+ with_extensions(obj, ruby_kwargs, query_ctx) do |extended_obj, extended_args|
+ if @resolver_class
+ if extended_obj.is_a?(GraphQL::Schema::Object)
+ extended_obj = extended_obj.object
+ end
+ extended_obj = @resolver_class.new(object: extended_obj, context: query_ctx)
+ end
+
+ if extended_args.any?
+ extended_obj.public_send(@method_sym, **extended_args)
+ else
+ extended_obj.public_send(@method_sym)
+ end
+ end
+ end
- if ruby_kwargs.any?
- obj.public_send(@method_sym, **ruby_kwargs)
+ # Wrap execution with hooks.
+ # Written iteratively to avoid big stack traces.
+ # @return [Object] Whatever the
+ def with_extensions(obj, args, ctx)
+ if @extensions.none?
+ yield(obj, args)
else
- obj.public_send(@method_sym)
+ # Save these so that the originals can be re-given to `after_resolve` handlers.
+ original_args = args
+ original_obj = obj
+
+ memos = []
+ @extensions.each do |ext|
+ ext.before_resolve(object: obj, arguments: args, context: ctx) do |extended_obj, extended_args, memo|
+ # update this scope with the yielded value
+ obj = extended_obj
+ args = extended_args
+ # record the memo (or nil if none was yielded)
+ memos << memo
+ end
+ end
+ # Call the block which actually calls resolve
+ value = yield(obj, args)
+
+ ctx.schema.after_lazy(value) do |resolved_value|
+ @extensions.each_with_index do |ext, idx|
+ memo = memos[idx]
+ # TODO after_lazy?
+ resolved_value = ext.after_resolve(object: original_obj, arguments: original_args, context: ctx, value: resolved_value, memo: memo)
+ end
+ resolved_value
+ end
end
end
end
diff --git a/lib/graphql/schema/field/connection_extension.rb b/lib/graphql/schema/field/connection_extension.rb
new file mode 100644
index 0000000000..3862c844c1
--- /dev/null
+++ b/lib/graphql/schema/field/connection_extension.rb
@@ -0,0 +1,50 @@
+# frozen_string_literal: true
+
+module GraphQL
+ class Schema
+ class Field
+ class ConnectionExtension < GraphQL::Schema::FieldExtension
+ def apply
+ field.argument :after, "String", "Returns the elements in the list that come after the specified cursor.", required: false
+ field.argument :before, "String", "Returns the elements in the list that come before the specified cursor.", required: false
+ field.argument :first, "Int", "Returns the first _n_ elements from the list.", required: false
+ field.argument :last, "Int", "Returns the last _n_ elements from the list.", required: false
+ end
+
+ # Remove pagination args before passing it to a user method
+ def before_resolve(object:, arguments:, context:)
+ next_args = arguments.dup
+ next_args.delete(:first)
+ next_args.delete(:last)
+ next_args.delete(:before)
+ next_args.delete(:after)
+ yield(object, next_args)
+ end
+
+ def after_resolve(value:, object:, arguments:, context:, memo:)
+ if value.is_a? GraphQL::ExecutionError
+ # This isn't even going to work because context doesn't have ast_node anymore
+ context.add_error(value)
+ nil
+ elsif value.nil?
+ nil
+ else
+ if object.is_a?(GraphQL::Schema::Object)
+ object = object.object
+ end
+ connection_class = GraphQL::Relay::BaseConnection.connection_for_nodes(value)
+ connection_class.new(
+ value,
+ arguments,
+ field: field,
+ max_page_size: field.max_page_size,
+ parent: object,
+ context: context,
+ )
+ end
+ end
+
+ end
+ end
+ end
+end
diff --git a/lib/graphql/schema/field/scope_extension.rb b/lib/graphql/schema/field/scope_extension.rb
new file mode 100644
index 0000000000..1c95f48979
--- /dev/null
+++ b/lib/graphql/schema/field/scope_extension.rb
@@ -0,0 +1,18 @@
+# frozen_string_literal: true
+
+module GraphQL
+ class Schema
+ class Field
+ class ScopeExtension < GraphQL::Schema::FieldExtension
+ def after_resolve(value:, context:, **rest)
+ ret_type = @field.type.unwrap
+ if ret_type.respond_to?(:scope_items)
+ ret_type.scope_items(value, context)
+ else
+ value
+ end
+ end
+ end
+ end
+ end
+end
diff --git a/lib/graphql/schema/field_extension.rb b/lib/graphql/schema/field_extension.rb
new file mode 100644
index 0000000000..fa9d923ad8
--- /dev/null
+++ b/lib/graphql/schema/field_extension.rb
@@ -0,0 +1,61 @@
+# frozen_string_literal: true
+module GraphQL
+ class Schema
+ # Extend this class to make field-level customizations to resolve behavior.
+ #
+ # When a extension is added to a field with `extension(MyExtension)`, a `MyExtension` instance
+ # is created, and its hooks are applied whenever that field is called.
+ #
+ # The instance is frozen so that instance variables aren't modified during query execution,
+ # which could cause all kinds of issues due to race conditions.
+ class FieldExtension
+ # @return [GraphQL::Schema::Field]
+ attr_reader :field
+
+ # @return [Object]
+ attr_reader :options
+
+ # Called when the extension is mounted with `extension(name, options)`.
+ # The instance is frozen to avoid improper use of state during execution.
+ # @param field [GraphQL::Schema::Field] The field where this extension was mounted
+ # @param options [Object] The second argument to `extension`, or `nil` if nothing was passed.
+ def initialize(field:, options:)
+ @field = field
+ @options = options
+ apply
+ freeze
+ end
+
+ # Called when this extension is attached to a field.
+ # The field definition may be extended during this method.
+ # @return [void]
+ def apply
+ end
+
+ # Called before resolving {#field}. It should either:
+ # - `yield` values to continue execution; OR
+ # - return something else to shortcut field execution.
+ # @param object [Object] The object the field is being resolved on
+ # @param arguments [Hash] Ruby keyword arguments for resolving this field
+ # @param context [Query::Context] the context for this query
+ # @yieldparam object [Object] The object to continue resolving the field on
+ # @yieldparam arguments [Hash] The keyword arguments to continue resolving with
+ # @yieldparam memo [Object] Any extension-specific value which will be passed to {#after_resolve} later
+ def before_resolve(object:, arguments:, context:)
+ yield(object, arguments, nil)
+ end
+
+ # Called after {#field} was resolved, but before the value was added to the GraphQL response.
+ # Whatever this hook returns will be used as the return value.
+ # @param object [Object] The object the field is being resolved on
+ # @param arguments [Hash] Ruby keyword arguments for resolving this field
+ # @param context [Query::Context] the context for this query
+ # @param value [Object] Whatever the field previously returned
+ # @param memo [Object] The third value yielded by {#before_resolve}, or `nil` if there wasn't one
+ # @return [Object] The return value for this field.
+ def after_resolve(object:, arguments:, context:, value:, memo:)
+ value
+ end
+ end
+ end
+end
diff --git a/lib/graphql/schema/mutation.rb b/lib/graphql/schema/mutation.rb
index fffe1d3f79..c737e18fe3 100644
--- a/lib/graphql/schema/mutation.rb
+++ b/lib/graphql/schema/mutation.rb
@@ -122,7 +122,9 @@ def generate_payload_type
description("Autogenerated return type of #{mutation_name}")
mutation(mutation_class)
mutation_fields.each do |name, f|
- field(name, field: f)
+ # Reattach the already-defined field here
+ # (The field's `.owner` will still point to the mutation, not the object type, I think)
+ add_field(f)
end
end
end
diff --git a/lib/graphql/schema/relay_classic_mutation.rb b/lib/graphql/schema/relay_classic_mutation.rb
index 7a5a06b93d..ede874244c 100644
--- a/lib/graphql/schema/relay_classic_mutation.rb
+++ b/lib/graphql/schema/relay_classic_mutation.rb
@@ -1,5 +1,5 @@
# frozen_string_literal: true
-
+require "graphql/types/string"
module GraphQL
class Schema
# Mutations that extend this base class get some conventions added for free:
diff --git a/lib/graphql/static_validation.rb b/lib/graphql/static_validation.rb
index caac1cfd23..9727bd0203 100644
--- a/lib/graphql/static_validation.rb
+++ b/lib/graphql/static_validation.rb
@@ -1,12 +1,12 @@
# frozen_string_literal: true
require "graphql/static_validation/message"
-require "graphql/static_validation/arguments_validator"
require "graphql/static_validation/definition_dependencies"
require "graphql/static_validation/type_stack"
require "graphql/static_validation/validator"
require "graphql/static_validation/validation_context"
require "graphql/static_validation/literal_validator"
-
+require "graphql/static_validation/base_visitor"
+require "graphql/static_validation/no_validate_visitor"
rules_glob = File.expand_path("../static_validation/rules/*.rb", __FILE__)
Dir.glob(rules_glob).each do |file|
@@ -14,3 +14,4 @@
end
require "graphql/static_validation/all_rules"
+require "graphql/static_validation/default_visitor"
diff --git a/lib/graphql/static_validation/all_rules.rb b/lib/graphql/static_validation/all_rules.rb
index d0b5d1b6e2..30edb5f8a9 100644
--- a/lib/graphql/static_validation/all_rules.rb
+++ b/lib/graphql/static_validation/all_rules.rb
@@ -11,9 +11,10 @@ module StaticValidation
GraphQL::StaticValidation::DirectivesAreDefined,
GraphQL::StaticValidation::DirectivesAreInValidLocations,
GraphQL::StaticValidation::UniqueDirectivesPerLocation,
+ GraphQL::StaticValidation::OperationNamesAreValid,
+ GraphQL::StaticValidation::FragmentNamesAreUnique,
GraphQL::StaticValidation::FragmentsAreFinite,
GraphQL::StaticValidation::FragmentsAreNamed,
- GraphQL::StaticValidation::FragmentNamesAreUnique,
GraphQL::StaticValidation::FragmentsAreUsed,
GraphQL::StaticValidation::FragmentTypesExist,
GraphQL::StaticValidation::FragmentsAreOnCompositeTypes,
@@ -32,7 +33,6 @@ module StaticValidation
GraphQL::StaticValidation::VariableUsagesAreAllowed,
GraphQL::StaticValidation::MutationRootExists,
GraphQL::StaticValidation::SubscriptionRootExists,
- GraphQL::StaticValidation::OperationNamesAreValid,
]
end
end
diff --git a/lib/graphql/static_validation/arguments_validator.rb b/lib/graphql/static_validation/arguments_validator.rb
deleted file mode 100644
index ff72d4ccae..0000000000
--- a/lib/graphql/static_validation/arguments_validator.rb
+++ /dev/null
@@ -1,50 +0,0 @@
-# frozen_string_literal: true
-module GraphQL
- module StaticValidation
- # Implement validate_node
- class ArgumentsValidator
- include GraphQL::StaticValidation::Message::MessageHelper
-
- def validate(context)
- visitor = context.visitor
- visitor[GraphQL::Language::Nodes::Argument] << ->(node, parent) {
- case parent
- when GraphQL::Language::Nodes::InputObject
- arg_defn = context.argument_definition
- if arg_defn.nil?
- return
- else
- parent_defn = arg_defn.type.unwrap
- if !parent_defn.is_a?(GraphQL::InputObjectType)
- return
- end
- end
- when GraphQL::Language::Nodes::Directive
- parent_defn = context.schema.directives[parent.name]
- when GraphQL::Language::Nodes::Field
- parent_defn = context.field_definition
- else
- raise "Unexpected argument parent: #{parent.class} (##{parent})"
- end
- validate_node(parent, node, parent_defn, context)
- }
- end
-
- private
-
- def parent_name(parent, type_defn)
- if parent.is_a?(GraphQL::Language::Nodes::Field)
- parent.alias || parent.name
- elsif parent.is_a?(GraphQL::Language::Nodes::InputObject)
- type_defn.name
- else
- parent.name
- end
- end
-
- def node_type(parent)
- parent.class.name.split("::").last
- end
- end
- end
-end
diff --git a/lib/graphql/static_validation/base_visitor.rb b/lib/graphql/static_validation/base_visitor.rb
new file mode 100644
index 0000000000..0894365bcc
--- /dev/null
+++ b/lib/graphql/static_validation/base_visitor.rb
@@ -0,0 +1,184 @@
+# frozen_string_literal: true
+module GraphQL
+ module StaticValidation
+ class BaseVisitor < GraphQL::Language::Visitor
+ def initialize(document, context)
+ @path = []
+ @object_types = []
+ @directives = []
+ @field_definitions = []
+ @argument_definitions = []
+ @directive_definitions = []
+ @context = context
+ @schema = context.schema
+ super(document)
+ end
+
+ attr_reader :context
+
+ # @return [Array] Types whose scope we've entered
+ attr_reader :object_types
+
+ # @return [Array] The nesting of the current position in the AST
+ def path
+ @path.dup
+ end
+
+ # Build a class to visit the AST and perform validation,
+ # or use a pre-built class if rules is `ALL_RULES` or empty.
+ # @param rules [Array]
+ # @return [Class] A class for validating `rules` during visitation
+ def self.including_rules(rules)
+ if rules.none?
+ NoValidateVisitor
+ elsif rules == ALL_RULES
+ DefaultVisitor
+ else
+ visitor_class = Class.new(self) do
+ include(GraphQL::StaticValidation::DefinitionDependencies)
+ end
+
+ rules.reverse_each do |r|
+ # If it's a class, it gets attached later.
+ if !r.is_a?(Class)
+ visitor_class.include(r)
+ end
+ end
+
+ visitor_class.include(GraphQL::InternalRepresentation::Rewrite)
+ visitor_class.include(ContextMethods)
+ visitor_class
+ end
+ end
+
+ module ContextMethods
+ def on_operation_definition(node, parent)
+ object_type = @schema.root_type_for_operation(node.operation_type)
+ @object_types.push(object_type)
+ @path.push("#{node.operation_type}#{node.name ? " #{node.name}" : ""}")
+ super
+ @object_types.pop
+ @path.pop
+ end
+
+ def on_fragment_definition(node, parent)
+ on_fragment_with_type(node) do
+ @path.push("fragment #{node.name}")
+ super
+ end
+ end
+
+ def on_inline_fragment(node, parent)
+ on_fragment_with_type(node) do
+ @path.push("...#{node.type ? " on #{node.type.to_query_string}" : ""}")
+ super
+ end
+ end
+
+ def on_field(node, parent)
+ parent_type = @object_types.last
+ field_definition = @schema.get_field(parent_type, node.name)
+ @field_definitions.push(field_definition)
+ if !field_definition.nil?
+ next_object_type = field_definition.type.unwrap
+ @object_types.push(next_object_type)
+ else
+ @object_types.push(nil)
+ end
+ @path.push(node.alias || node.name)
+ super
+ @field_definitions.pop
+ @object_types.pop
+ @path.pop
+ end
+
+ def on_directive(node, parent)
+ directive_defn = @schema.directives[node.name]
+ @directive_definitions.push(directive_defn)
+ super
+ @directive_definitions.pop
+ end
+
+ def on_argument(node, parent)
+ argument_defn = if (arg = @argument_definitions.last)
+ arg_type = arg.type.unwrap
+ if arg_type.kind.input_object?
+ arg_type.input_fields[node.name]
+ else
+ nil
+ end
+ elsif (directive_defn = @directive_definitions.last)
+ directive_defn.arguments[node.name]
+ elsif (field_defn = @field_definitions.last)
+ field_defn.arguments[node.name]
+ else
+ nil
+ end
+
+ @argument_definitions.push(argument_defn)
+ @path.push(node.name)
+ super
+ @argument_definitions.pop
+ @path.pop
+ end
+
+ def on_fragment_spread(node, parent)
+ @path.push("... #{node.name}")
+ super
+ @path.pop
+ end
+
+ # @return [GraphQL::BaseType] The current object type
+ def type_definition
+ @object_types.last
+ end
+
+ # @return [GraphQL::BaseType] The type which the current type came from
+ def parent_type_definition
+ @object_types[-2]
+ end
+
+ # @return [GraphQL::Field, nil] The most-recently-entered GraphQL::Field, if currently inside one
+ def field_definition
+ @field_definitions.last
+ end
+
+ # @return [GraphQL::Directive, nil] The most-recently-entered GraphQL::Directive, if currently inside one
+ def directive_definition
+ @directive_definitions.last
+ end
+
+ # @return [GraphQL::Argument, nil] The most-recently-entered GraphQL::Argument, if currently inside one
+ def argument_definition
+ # Don't get the _last_ one because that's the current one.
+ # Get the second-to-last one, which is the parent of the current one.
+ @argument_definitions[-2]
+ end
+
+ private
+
+ def on_fragment_with_type(node)
+ object_type = if node.type
+ @schema.types.fetch(node.type.name, nil)
+ else
+ @object_types.last
+ end
+ @object_types.push(object_type)
+ yield(node)
+ @object_types.pop
+ @path.pop
+ end
+ end
+
+ private
+
+ # Error `message` is located at `node`
+ def add_error(message, nodes, path: nil)
+ path ||= @path.dup
+ nodes = Array(nodes)
+ m = GraphQL::StaticValidation::Message.new(message, nodes: nodes, path: path)
+ context.errors << m
+ end
+ end
+ end
+end
diff --git a/lib/graphql/static_validation/default_visitor.rb b/lib/graphql/static_validation/default_visitor.rb
new file mode 100644
index 0000000000..1202f5b3f2
--- /dev/null
+++ b/lib/graphql/static_validation/default_visitor.rb
@@ -0,0 +1,15 @@
+# frozen_string_literal: true
+module GraphQL
+ module StaticValidation
+ class DefaultVisitor < BaseVisitor
+ include(GraphQL::StaticValidation::DefinitionDependencies)
+
+ StaticValidation::ALL_RULES.reverse_each do |r|
+ include(r)
+ end
+
+ include(GraphQL::InternalRepresentation::Rewrite)
+ include(ContextMethods)
+ end
+ end
+end
diff --git a/lib/graphql/static_validation/definition_dependencies.rb b/lib/graphql/static_validation/definition_dependencies.rb
index 8deba284e6..9fd1fde127 100644
--- a/lib/graphql/static_validation/definition_dependencies.rb
+++ b/lib/graphql/static_validation/definition_dependencies.rb
@@ -4,79 +4,72 @@ module StaticValidation
# Track fragment dependencies for operations
# and expose the fragment definitions which
# are used by a given operation
- class DefinitionDependencies
- def self.mount(visitor)
- deps = self.new
- deps.mount(visitor)
- deps
- end
+ module DefinitionDependencies
+ attr_reader :dependencies
- def initialize
- @node_paths = {}
+ def initialize(*)
+ super
+ @defdep_node_paths = {}
# { name => node } pairs for fragments
- @fragment_definitions = {}
+ @defdep_fragment_definitions = {}
# This tracks dependencies from fragment to Node where it was used
# { fragment_definition_node => [dependent_node, dependent_node]}
- @dependent_definitions = Hash.new { |h, k| h[k] = Set.new }
+ @defdep_dependent_definitions = Hash.new { |h, k| h[k] = Set.new }
# First-level usages of spreads within definitions
# (When a key has an empty list as its value,
# we can resolve that key's depenedents)
# { definition_node => [node, node ...] }
- @immediate_dependencies = Hash.new { |h, k| h[k] = Set.new }
- end
-
- # A map of operation definitions to an array of that operation's dependencies
- # @return [DependencyMap]
- def dependency_map(&block)
- @dependency_map ||= resolve_dependencies(&block)
- end
+ @defdep_immediate_dependencies = Hash.new { |h, k| h[k] = Set.new }
- def mount(context)
- visitor = context.visitor
# When we encounter a spread,
# this node is the one who depends on it
- current_parent = nil
-
- visitor[GraphQL::Language::Nodes::Document] << ->(node, prev_node) {
- node.definitions.each do |definition|
- case definition
- when GraphQL::Language::Nodes::OperationDefinition
- when GraphQL::Language::Nodes::FragmentDefinition
- @fragment_definitions[definition.name] = definition
- end
- end
- }
+ @defdep_current_parent = nil
+ end
- visitor[GraphQL::Language::Nodes::OperationDefinition] << ->(node, prev_node) {
- @node_paths[node] = NodeWithPath.new(node, context.path)
- current_parent = node
+ def on_document(node, parent)
+ node.definitions.each do |definition|
+ if definition.is_a? GraphQL::Language::Nodes::FragmentDefinition
+ @defdep_fragment_definitions[definition.name] = definition
+ end
+ end
+ super
+ @dependencies = dependency_map { |defn, spreads, frag|
+ context.on_dependency_resolve_handlers.each { |h| h.call(defn, spreads, frag) }
}
+ end
- visitor[GraphQL::Language::Nodes::OperationDefinition].leave << ->(node, prev_node) {
- current_parent = nil
- }
+ def on_operation_definition(node, prev_node)
+ @defdep_node_paths[node] = NodeWithPath.new(node, context.path)
+ @defdep_current_parent = node
+ super
+ @defdep_current_parent = nil
+ end
- visitor[GraphQL::Language::Nodes::FragmentDefinition] << ->(node, prev_node) {
- @node_paths[node] = NodeWithPath.new(node, context.path)
- current_parent = node
- }
+ def on_fragment_definition(node, parent)
+ @defdep_node_paths[node] = NodeWithPath.new(node, context.path)
+ @defdep_current_parent = node
+ super
+ @defdep_current_parent = nil
+ end
- visitor[GraphQL::Language::Nodes::FragmentDefinition].leave << ->(node, prev_node) {
- current_parent = nil
- }
+ def on_fragment_spread(node, parent)
+ @defdep_node_paths[node] = NodeWithPath.new(node, context.path)
- visitor[GraphQL::Language::Nodes::FragmentSpread] << ->(node, prev_node) {
- @node_paths[node] = NodeWithPath.new(node, context.path)
+ # Track both sides of the dependency
+ @defdep_dependent_definitions[@defdep_fragment_definitions[node.name]] << @defdep_current_parent
+ @defdep_immediate_dependencies[@defdep_current_parent] << node
+ end
- # Track both sides of the dependency
- @dependent_definitions[@fragment_definitions[node.name]] << current_parent
- @immediate_dependencies[current_parent] << node
- }
+ # A map of operation definitions to an array of that operation's dependencies
+ # @return [DependencyMap]
+ def dependency_map(&block)
+ @dependency_map ||= resolve_dependencies(&block)
end
+
# Map definition AST nodes to the definition AST nodes they depend on.
# Expose circular depednencies.
class DependencyMap
@@ -122,14 +115,14 @@ def resolve_dependencies
dependency_map = DependencyMap.new
# Don't allow the loop to run more times
# than the number of fragments in the document
- max_loops = @fragment_definitions.size
+ max_loops = @defdep_fragment_definitions.size
loops = 0
# Instead of tracking independent fragments _as you visit_,
# determine them at the end. This way, we can treat fragments with the
# same name as if they were the same name. If _any_ of the fragments
# with that name has a dependency, we record it.
- independent_fragment_nodes = @fragment_definitions.values - @immediate_dependencies.keys
+ independent_fragment_nodes = @defdep_fragment_definitions.values - @defdep_immediate_dependencies.keys
while fragment_node = independent_fragment_nodes.pop
loops += 1
@@ -138,22 +131,22 @@ def resolve_dependencies
end
# Since it's independent, let's remove it from here.
# That way, we can use the remainder to identify cycles
- @immediate_dependencies.delete(fragment_node)
- fragment_usages = @dependent_definitions[fragment_node]
+ @defdep_immediate_dependencies.delete(fragment_node)
+ fragment_usages = @defdep_dependent_definitions[fragment_node]
if fragment_usages.none?
# If we didn't record any usages during the visit,
# then this fragment is unused.
- dependency_map.unused_dependencies << @node_paths[fragment_node]
+ dependency_map.unused_dependencies << @defdep_node_paths[fragment_node]
else
fragment_usages.each do |definition_node|
# Register the dependency AND second-order dependencies
dependency_map[definition_node] << fragment_node
dependency_map[definition_node].concat(dependency_map[fragment_node])
# Since we've regestered it, remove it from our to-do list
- deps = @immediate_dependencies[definition_node]
+ deps = @defdep_immediate_dependencies[definition_node]
# Can't find a way to _just_ delete from `deps` and return the deleted entries
removed, remaining = deps.partition { |spread| spread.name == fragment_node.name }
- @immediate_dependencies[definition_node] = remaining
+ @defdep_immediate_dependencies[definition_node] = remaining
if block_given?
yield(definition_node, removed, fragment_node)
end
@@ -170,20 +163,20 @@ def resolve_dependencies
# If any dependencies were _unmet_
# (eg, spreads with no corresponding definition)
# then they're still in there
- @immediate_dependencies.each do |defn_node, deps|
+ @defdep_immediate_dependencies.each do |defn_node, deps|
deps.each do |spread|
- if @fragment_definitions[spread.name].nil?
- dependency_map.unmet_dependencies[@node_paths[defn_node]] << @node_paths[spread]
+ if @defdep_fragment_definitions[spread.name].nil?
+ dependency_map.unmet_dependencies[@defdep_node_paths[defn_node]] << @defdep_node_paths[spread]
deps.delete(spread)
end
end
if deps.none?
- @immediate_dependencies.delete(defn_node)
+ @defdep_immediate_dependencies.delete(defn_node)
end
end
# Anything left in @immediate_dependencies is cyclical
- cyclical_nodes = @immediate_dependencies.keys.map { |n| @node_paths[n] }
+ cyclical_nodes = @defdep_immediate_dependencies.keys.map { |n| @defdep_node_paths[n] }
# @immediate_dependencies also includes operation names, but we don't care about
# those. They became nil when we looked them up on `@fragment_definitions`, so remove them.
cyclical_nodes.compact!
diff --git a/lib/graphql/static_validation/no_validate_visitor.rb b/lib/graphql/static_validation/no_validate_visitor.rb
new file mode 100644
index 0000000000..4fc3303433
--- /dev/null
+++ b/lib/graphql/static_validation/no_validate_visitor.rb
@@ -0,0 +1,10 @@
+# frozen_string_literal: true
+module GraphQL
+ module StaticValidation
+ class NoValidateVisitor < StaticValidation::BaseVisitor
+ include(GraphQL::InternalRepresentation::Rewrite)
+ include(GraphQL::StaticValidation::DefinitionDependencies)
+ include(ContextMethods)
+ end
+ end
+end
diff --git a/lib/graphql/static_validation/rules/argument_literals_are_compatible.rb b/lib/graphql/static_validation/rules/argument_literals_are_compatible.rb
index 0fb4872dee..476d3dcce2 100644
--- a/lib/graphql/static_validation/rules/argument_literals_are_compatible.rb
+++ b/lib/graphql/static_validation/rules/argument_literals_are_compatible.rb
@@ -1,27 +1,69 @@
# frozen_string_literal: true
module GraphQL
module StaticValidation
- class ArgumentLiteralsAreCompatible < GraphQL::StaticValidation::ArgumentsValidator
- def validate_node(parent, node, defn, context)
- return if node.value.is_a?(GraphQL::Language::Nodes::VariableIdentifier)
- arg_defn = defn.arguments[node.name]
- return unless arg_defn
-
- begin
- valid = context.valid_literal?(node.value, arg_defn.type)
- rescue GraphQL::CoercionError => err
- error_message = err.message
+ module ArgumentLiteralsAreCompatible
+ # TODO dedup with ArgumentsAreDefined
+ def on_argument(node, parent)
+ parent_defn = case parent
+ when GraphQL::Language::Nodes::InputObject
+ arg_defn = context.argument_definition
+ if arg_defn.nil?
+ nil
+ else
+ arg_ret_type = arg_defn.type.unwrap
+ if !arg_ret_type.is_a?(GraphQL::InputObjectType)
+ nil
+ else
+ arg_ret_type
+ end
+ end
+ when GraphQL::Language::Nodes::Directive
+ context.schema.directives[parent.name]
+ when GraphQL::Language::Nodes::Field
+ context.field_definition
+ else
+ raise "Unexpected argument parent: #{parent.class} (##{parent})"
end
- return if valid
+ if parent_defn && !node.value.is_a?(GraphQL::Language::Nodes::VariableIdentifier)
+ arg_defn = parent_defn.arguments[node.name]
+ if arg_defn
+ begin
+ valid = context.valid_literal?(node.value, arg_defn.type)
+ rescue GraphQL::CoercionError => err
+ error_message = err.message
+ end
- error_message ||= begin
- kind_of_node = node_type(parent)
- error_arg_name = parent_name(parent, defn)
- "Argument '#{node.name}' on #{kind_of_node} '#{error_arg_name}' has an invalid value. Expected type '#{arg_defn.type}'."
+ if !valid
+ error_message ||= begin
+ kind_of_node = node_type(parent)
+ error_arg_name = parent_name(parent, parent_defn)
+ "Argument '#{node.name}' on #{kind_of_node} '#{error_arg_name}' has an invalid value. Expected type '#{arg_defn.type}'."
+ end
+
+ add_error(error_message, parent)
+ end
+ end
+ end
+
+ super
+ end
+
+
+ private
+
+ def parent_name(parent, type_defn)
+ if parent.is_a?(GraphQL::Language::Nodes::Field)
+ parent.alias || parent.name
+ elsif parent.is_a?(GraphQL::Language::Nodes::InputObject)
+ type_defn.name
+ else
+ parent.name
end
+ end
- context.errors << message(error_message, parent, context: context)
+ def node_type(parent)
+ parent.class.name.split("::").last
end
end
end
diff --git a/lib/graphql/static_validation/rules/argument_names_are_unique.rb b/lib/graphql/static_validation/rules/argument_names_are_unique.rb
index 43c7f68f9a..cb77e4f60c 100644
--- a/lib/graphql/static_validation/rules/argument_names_are_unique.rb
+++ b/lib/graphql/static_validation/rules/argument_names_are_unique.rb
@@ -1,27 +1,27 @@
# frozen_string_literal: true
module GraphQL
module StaticValidation
- class ArgumentNamesAreUnique
+ module ArgumentNamesAreUnique
include GraphQL::StaticValidation::Message::MessageHelper
- def validate(context)
- context.visitor[GraphQL::Language::Nodes::Field] << ->(node, parent) {
- validate_arguments(node, context)
- }
+ def on_field(node, parent)
+ validate_arguments(node)
+ super
+ end
- context.visitor[GraphQL::Language::Nodes::Directive] << ->(node, parent) {
- validate_arguments(node, context)
- }
+ def on_directive(node, parent)
+ validate_arguments(node)
+ super
end
- def validate_arguments(node, context)
+ def validate_arguments(node)
argument_defns = node.arguments
if argument_defns.any?
args_by_name = Hash.new { |h, k| h[k] = [] }
argument_defns.each { |a| args_by_name[a.name] << a }
args_by_name.each do |name, defns|
if defns.size > 1
- context.errors << message("There can be only one argument named \"#{name}\"", defns, context: context)
+ add_error("There can be only one argument named \"#{name}\"", defns)
end
end
end
diff --git a/lib/graphql/static_validation/rules/arguments_are_defined.rb b/lib/graphql/static_validation/rules/arguments_are_defined.rb
index 2bee1b344f..6f99ea4dd0 100644
--- a/lib/graphql/static_validation/rules/arguments_are_defined.rb
+++ b/lib/graphql/static_validation/rules/arguments_are_defined.rb
@@ -1,18 +1,56 @@
# frozen_string_literal: true
module GraphQL
module StaticValidation
- class ArgumentsAreDefined < GraphQL::StaticValidation::ArgumentsValidator
- def validate_node(parent, node, defn, context)
- argument_defn = context.warden.arguments(defn).find { |arg| arg.name == node.name }
- if argument_defn.nil?
+ module ArgumentsAreDefined
+ def on_argument(node, parent)
+ parent_defn = case parent
+ when GraphQL::Language::Nodes::InputObject
+ arg_defn = context.argument_definition
+ if arg_defn.nil?
+ nil
+ else
+ arg_ret_type = arg_defn.type.unwrap
+ if !arg_ret_type.is_a?(GraphQL::InputObjectType)
+ nil
+ else
+ arg_ret_type
+ end
+ end
+ when GraphQL::Language::Nodes::Directive
+ context.schema.directives[parent.name]
+ when GraphQL::Language::Nodes::Field
+ context.field_definition
+ else
+ raise "Unexpected argument parent: #{parent.class} (##{parent})"
+ end
+
+ if parent_defn && context.warden.arguments(parent_defn).any? { |arg| arg.name == node.name }
+ super
+ elsif parent_defn
kind_of_node = node_type(parent)
- error_arg_name = parent_name(parent, defn)
- context.errors << message("#{kind_of_node} '#{error_arg_name}' doesn't accept argument '#{node.name}'", node, context: context)
- GraphQL::Language::Visitor::SKIP
+ error_arg_name = parent_name(parent, parent_defn)
+ add_error("#{kind_of_node} '#{error_arg_name}' doesn't accept argument '#{node.name}'", node)
+ else
+ # Some other weird error
+ super
+ end
+ end
+
+ private
+
+ def parent_name(parent, type_defn)
+ if parent.is_a?(GraphQL::Language::Nodes::Field)
+ parent.alias || parent.name
+ elsif parent.is_a?(GraphQL::Language::Nodes::InputObject)
+ type_defn.name
else
- nil
+ parent.name
end
end
+
+ def node_type(parent)
+ parent.class.name.split("::").last
+ end
end
end
end
diff --git a/lib/graphql/static_validation/rules/directives_are_defined.rb b/lib/graphql/static_validation/rules/directives_are_defined.rb
index 5811005f9b..e2f14347fb 100644
--- a/lib/graphql/static_validation/rules/directives_are_defined.rb
+++ b/lib/graphql/static_validation/rules/directives_are_defined.rb
@@ -1,24 +1,17 @@
# frozen_string_literal: true
module GraphQL
module StaticValidation
- class DirectivesAreDefined
- include GraphQL::StaticValidation::Message::MessageHelper
-
- def validate(context)
- directive_names = context.schema.directives.keys
- context.visitor[GraphQL::Language::Nodes::Directive] << ->(node, parent) {
- validate_directive(node, directive_names, context)
- }
+ module DirectivesAreDefined
+ def initialize(*)
+ super
+ @directive_names = context.schema.directives.keys
end
- private
-
- def validate_directive(ast_directive, directive_names, context)
- if !directive_names.include?(ast_directive.name)
- context.errors << message("Directive @#{ast_directive.name} is not defined", ast_directive, context: context)
- GraphQL::Language::Visitor::SKIP
+ def on_directive(node, parent)
+ if !@directive_names.include?(node.name)
+ add_error("Directive @#{node.name} is not defined", node)
else
- nil
+ super
end
end
end
diff --git a/lib/graphql/static_validation/rules/directives_are_in_valid_locations.rb b/lib/graphql/static_validation/rules/directives_are_in_valid_locations.rb
index e14d5bb4ac..5a52d03593 100644
--- a/lib/graphql/static_validation/rules/directives_are_in_valid_locations.rb
+++ b/lib/graphql/static_validation/rules/directives_are_in_valid_locations.rb
@@ -1,16 +1,12 @@
# frozen_string_literal: true
module GraphQL
module StaticValidation
- class DirectivesAreInValidLocations
- include GraphQL::StaticValidation::Message::MessageHelper
+ module DirectivesAreInValidLocations
include GraphQL::Language
- def validate(context)
- directives = context.schema.directives
-
- context.visitor[Nodes::Directive] << ->(node, parent) {
- validate_location(node, parent, directives, context)
- }
+ def on_directive(node, parent)
+ validate_location(node, parent, context.schema.directives)
+ super
end
private
@@ -34,25 +30,25 @@ def validate(context)
SIMPLE_LOCATION_NODES = SIMPLE_LOCATIONS.keys
- def validate_location(ast_directive, ast_parent, directives, context)
+ def validate_location(ast_directive, ast_parent, directives)
directive_defn = directives[ast_directive.name]
case ast_parent
when Nodes::OperationDefinition
required_location = GraphQL::Directive.const_get(ast_parent.operation_type.upcase)
- assert_includes_location(directive_defn, ast_directive, required_location, context)
+ assert_includes_location(directive_defn, ast_directive, required_location)
when *SIMPLE_LOCATION_NODES
required_location = SIMPLE_LOCATIONS[ast_parent.class]
- assert_includes_location(directive_defn, ast_directive, required_location, context)
+ assert_includes_location(directive_defn, ast_directive, required_location)
else
- context.errors << message("Directives can't be applied to #{ast_parent.class.name}s", ast_directive, context: context)
+ add_error("Directives can't be applied to #{ast_parent.class.name}s", ast_directive)
end
end
- def assert_includes_location(directive_defn, directive_ast, required_location, context)
+ def assert_includes_location(directive_defn, directive_ast, required_location)
if !directive_defn.locations.include?(required_location)
location_name = LOCATION_MESSAGE_NAMES[required_location]
allowed_location_names = directive_defn.locations.map { |loc| LOCATION_MESSAGE_NAMES[loc] }
- context.errors << message("'@#{directive_defn.name}' can't be applied to #{location_name} (allowed: #{allowed_location_names.join(", ")})", directive_ast, context: context)
+ add_error("'@#{directive_defn.name}' can't be applied to #{location_name} (allowed: #{allowed_location_names.join(", ")})", directive_ast)
end
end
end
diff --git a/lib/graphql/static_validation/rules/fields_are_defined_on_type.rb b/lib/graphql/static_validation/rules/fields_are_defined_on_type.rb
index 7870148729..d942de9558 100644
--- a/lib/graphql/static_validation/rules/fields_are_defined_on_type.rb
+++ b/lib/graphql/static_validation/rules/fields_are_defined_on_type.rb
@@ -1,30 +1,19 @@
# frozen_string_literal: true
module GraphQL
module StaticValidation
- class FieldsAreDefinedOnType
- include GraphQL::StaticValidation::Message::MessageHelper
-
- def validate(context)
- visitor = context.visitor
- visitor[GraphQL::Language::Nodes::Field] << ->(node, parent) {
- parent_type = context.object_types[-2]
- parent_type = parent_type.unwrap
- validate_field(context, node, parent_type, parent)
- }
- end
-
- private
-
- def validate_field(context, ast_field, parent_type, parent)
- field = context.warden.get_field(parent_type, ast_field.name)
+ module FieldsAreDefinedOnType
+ def on_field(node, parent)
+ parent_type = @object_types[-2]
+ field = context.warden.get_field(parent_type, node.name)
if field.nil?
if parent_type.kind.union?
- context.errors << message("Selections can't be made directly on unions (see selections on #{parent_type.name})", parent, context: context)
+ add_error("Selections can't be made directly on unions (see selections on #{parent_type.name})", parent)
else
- context.errors << message("Field '#{ast_field.name}' doesn't exist on type '#{parent_type.name}'", ast_field, context: context)
+ add_error("Field '#{node.name}' doesn't exist on type '#{parent_type.name}'", node)
end
- return GraphQL::Language::Visitor::SKIP
+ else
+ super
end
end
end
diff --git a/lib/graphql/static_validation/rules/fields_have_appropriate_selections.rb b/lib/graphql/static_validation/rules/fields_have_appropriate_selections.rb
index 71f3cf3023..865fb6efa2 100644
--- a/lib/graphql/static_validation/rules/fields_have_appropriate_selections.rb
+++ b/lib/graphql/static_validation/rules/fields_have_appropriate_selections.rb
@@ -3,24 +3,26 @@ module GraphQL
module StaticValidation
# Scalars _can't_ have selections
# Objects _must_ have selections
- class FieldsHaveAppropriateSelections
+ module FieldsHaveAppropriateSelections
include GraphQL::StaticValidation::Message::MessageHelper
- def validate(context)
- context.visitor[GraphQL::Language::Nodes::Field] << ->(node, parent) {
- field_defn = context.field_definition
- validate_field_selections(node, field_defn.type.unwrap, context)
- }
+ def on_field(node, parent)
+ field_defn = field_definition
+ if validate_field_selections(node, field_defn.type.unwrap)
+ super
+ end
+ end
- context.visitor[GraphQL::Language::Nodes::OperationDefinition] << ->(node, parent) {
- validate_field_selections(node, context.type_definition, context)
- }
+ def on_operation_definition(node, _parent)
+ if validate_field_selections(node, type_definition)
+ super
+ end
end
private
- def validate_field_selections(ast_node, resolved_type, context)
+ def validate_field_selections(ast_node, resolved_type)
msg = if resolved_type.nil?
nil
elsif resolved_type.kind.scalar? && ast_node.selections.any?
@@ -48,8 +50,10 @@ def validate_field_selections(ast_node, resolved_type, context)
else
raise("Unexpected node #{ast_node}")
end
- context.errors << message(msg % { node_name: node_name }, ast_node, context: context)
- GraphQL::Language::Visitor::SKIP
+ add_error(msg % { node_name: node_name }, ast_node)
+ false
+ else
+ true
end
end
end
diff --git a/lib/graphql/static_validation/rules/fields_will_merge.rb b/lib/graphql/static_validation/rules/fields_will_merge.rb
index 5897e936d7..af9e9df3c9 100644
--- a/lib/graphql/static_validation/rules/fields_will_merge.rb
+++ b/lib/graphql/static_validation/rules/fields_will_merge.rb
@@ -1,11 +1,13 @@
# frozen_string_literal: true
module GraphQL
module StaticValidation
- class FieldsWillMerge
+ module FieldsWillMerge
# Special handling for fields without arguments
NO_ARGS = {}.freeze
- def validate(context)
+ def initialize(*)
+ super
+
context.each_irep_node do |node|
if node.ast_nodes.size > 1
defn_names = Set.new(node.ast_nodes.map(&:name))
diff --git a/lib/graphql/static_validation/rules/fragment_names_are_unique.rb b/lib/graphql/static_validation/rules/fragment_names_are_unique.rb
index f4e9dff7de..aa650ccfbe 100644
--- a/lib/graphql/static_validation/rules/fragment_names_are_unique.rb
+++ b/lib/graphql/static_validation/rules/fragment_names_are_unique.rb
@@ -1,22 +1,25 @@
# frozen_string_literal: true
module GraphQL
module StaticValidation
- class FragmentNamesAreUnique
- include GraphQL::StaticValidation::Message::MessageHelper
+ module FragmentNamesAreUnique
- def validate(context)
- fragments_by_name = Hash.new { |h, k| h[k] = [] }
- context.visitor[GraphQL::Language::Nodes::FragmentDefinition] << ->(node, parent) {
- fragments_by_name[node.name] << node
- }
+ def initialize(*)
+ super
+ @fragments_by_name = Hash.new { |h, k| h[k] = [] }
+ end
+
+ def on_fragment_definition(node, parent)
+ @fragments_by_name[node.name] << node
+ super
+ end
- context.visitor[GraphQL::Language::Nodes::Document].leave << ->(node, parent) {
- fragments_by_name.each do |name, fragments|
- if fragments.length > 1
- context.errors << message(%|Fragment name "#{name}" must be unique|, fragments, context: context)
- end
+ def on_document(_n, _p)
+ super
+ @fragments_by_name.each do |name, fragments|
+ if fragments.length > 1
+ add_error(%|Fragment name "#{name}" must be unique|, fragments)
end
- }
+ end
end
end
end
diff --git a/lib/graphql/static_validation/rules/fragment_spreads_are_possible.rb b/lib/graphql/static_validation/rules/fragment_spreads_are_possible.rb
index 5351c28bbb..57421d1a38 100644
--- a/lib/graphql/static_validation/rules/fragment_spreads_are_possible.rb
+++ b/lib/graphql/static_validation/rules/fragment_spreads_are_possible.rb
@@ -1,39 +1,40 @@
# frozen_string_literal: true
module GraphQL
module StaticValidation
- class FragmentSpreadsArePossible
- include GraphQL::StaticValidation::Message::MessageHelper
-
- def validate(context)
-
- context.visitor[GraphQL::Language::Nodes::InlineFragment] << ->(node, parent) {
- fragment_parent = context.object_types[-2]
- fragment_child = context.object_types.last
- if fragment_child
- validate_fragment_in_scope(fragment_parent, fragment_child, node, context, context.path)
- end
- }
+ module FragmentSpreadsArePossible
+ def initialize(*)
+ super
+ @spreads_to_validate = []
+ end
- spreads_to_validate = []
+ def on_inline_fragment(node, parent)
+ fragment_parent = context.object_types[-2]
+ fragment_child = context.object_types.last
+ if fragment_child
+ validate_fragment_in_scope(fragment_parent, fragment_child, node, context, context.path)
+ end
+ super
+ end
- context.visitor[GraphQL::Language::Nodes::FragmentSpread] << ->(node, parent) {
- fragment_parent = context.object_types.last
- spreads_to_validate << FragmentSpread.new(node: node, parent_type: fragment_parent, path: context.path)
- }
+ def on_fragment_spread(node, parent)
+ fragment_parent = context.object_types.last
+ @spreads_to_validate << FragmentSpread.new(node: node, parent_type: fragment_parent, path: context.path)
+ super
+ end
- context.visitor[GraphQL::Language::Nodes::Document].leave << ->(doc_node, parent) {
- spreads_to_validate.each do |frag_spread|
- frag_node = context.fragments[frag_spread.node.name]
- if frag_node
- fragment_child_name = frag_node.type.name
- fragment_child = context.warden.get_type(fragment_child_name)
- # Might be non-existent type name
- if fragment_child
- validate_fragment_in_scope(frag_spread.parent_type, fragment_child, frag_spread.node, context, frag_spread.path)
- end
+ def on_document(node, parent)
+ super
+ @spreads_to_validate.each do |frag_spread|
+ frag_node = context.fragments[frag_spread.node.name]
+ if frag_node
+ fragment_child_name = frag_node.type.name
+ fragment_child = context.warden.get_type(fragment_child_name)
+ # Might be non-existent type name
+ if fragment_child
+ validate_fragment_in_scope(frag_spread.parent_type, fragment_child, frag_spread.node, context, frag_spread.path)
end
end
- }
+ end
end
private
@@ -48,7 +49,7 @@ def validate_fragment_in_scope(parent_type, child_type, node, context, path)
if child_types.none? { |c| parent_types.include?(c) }
name = node.respond_to?(:name) ? " #{node.name}" : ""
- context.errors << message("Fragment#{name} on #{child_type.name} can't be spread inside #{parent_type.name}", node, path: path)
+ add_error("Fragment#{name} on #{child_type.name} can't be spread inside #{parent_type.name}", node, path: path)
end
end
diff --git a/lib/graphql/static_validation/rules/fragment_types_exist.rb b/lib/graphql/static_validation/rules/fragment_types_exist.rb
index 06fac1e4d8..f087fc6e0b 100644
--- a/lib/graphql/static_validation/rules/fragment_types_exist.rb
+++ b/lib/graphql/static_validation/rules/fragment_types_exist.rb
@@ -1,29 +1,33 @@
# frozen_string_literal: true
module GraphQL
module StaticValidation
- class FragmentTypesExist
- include GraphQL::StaticValidation::Message::MessageHelper
-
- FRAGMENTS_ON_TYPES = [
- GraphQL::Language::Nodes::FragmentDefinition,
- GraphQL::Language::Nodes::InlineFragment,
- ]
+ module FragmentTypesExist
+ def on_fragment_definition(node, _parent)
+ if validate_type_exists(node)
+ super
+ end
+ end
- def validate(context)
- FRAGMENTS_ON_TYPES.each do |node_class|
- context.visitor[node_class] << ->(node, parent) { validate_type_exists(node, context) }
+ def on_inline_fragment(node, _parent)
+ if validate_type_exists(node)
+ super
end
end
private
- def validate_type_exists(node, context)
- return unless node.type
- type_name = node.type.name
- type = context.warden.get_type(type_name)
- if type.nil?
- context.errors << message("No such type #{type_name}, so it can't be a fragment condition", node, context: context)
- GraphQL::Language::Visitor::SKIP
+ def validate_type_exists(fragment_node)
+ if !fragment_node.type
+ true
+ else
+ type_name = fragment_node.type.name
+ type = context.warden.get_type(type_name)
+ if type.nil?
+ add_error("No such type #{type_name}, so it can't be a fragment condition", fragment_node)
+ false
+ else
+ true
+ end
end
end
end
diff --git a/lib/graphql/static_validation/rules/fragments_are_finite.rb b/lib/graphql/static_validation/rules/fragments_are_finite.rb
index da0d198f54..d0bd368e55 100644
--- a/lib/graphql/static_validation/rules/fragments_are_finite.rb
+++ b/lib/graphql/static_validation/rules/fragments_are_finite.rb
@@ -1,16 +1,13 @@
# frozen_string_literal: true
module GraphQL
module StaticValidation
- class FragmentsAreFinite
- include GraphQL::StaticValidation::Message::MessageHelper
-
- def validate(context)
- context.visitor[GraphQL::Language::Nodes::Document].leave << ->(_n, _p) do
- dependency_map = context.dependencies
- dependency_map.cyclical_definitions.each do |defn|
- if defn.node.is_a?(GraphQL::Language::Nodes::FragmentDefinition)
- context.errors << message("Fragment #{defn.name} contains an infinite loop", defn.node, path: defn.path)
- end
+ module FragmentsAreFinite
+ def on_document(_n, _p)
+ super
+ dependency_map = context.dependencies
+ dependency_map.cyclical_definitions.each do |defn|
+ if defn.node.is_a?(GraphQL::Language::Nodes::FragmentDefinition)
+ context.errors << message("Fragment #{defn.name} contains an infinite loop", defn.node, path: defn.path)
end
end
end
diff --git a/lib/graphql/static_validation/rules/fragments_are_named.rb b/lib/graphql/static_validation/rules/fragments_are_named.rb
index d48dc1f6d2..c460ad0adb 100644
--- a/lib/graphql/static_validation/rules/fragments_are_named.rb
+++ b/lib/graphql/static_validation/rules/fragments_are_named.rb
@@ -1,19 +1,12 @@
# frozen_string_literal: true
module GraphQL
module StaticValidation
- class FragmentsAreNamed
- include GraphQL::StaticValidation::Message::MessageHelper
-
- def validate(context)
- context.visitor[GraphQL::Language::Nodes::FragmentDefinition] << ->(node, parent) { validate_name_exists(node, context) }
- end
-
- private
-
- def validate_name_exists(node, context)
+ module FragmentsAreNamed
+ def on_fragment_definition(node, _parent)
if node.name.nil?
- context.errors << message("Fragment definition has no name", node, context: context)
+ add_error("Fragment definition has no name", node)
end
+ super
end
end
end
diff --git a/lib/graphql/static_validation/rules/fragments_are_on_composite_types.rb b/lib/graphql/static_validation/rules/fragments_are_on_composite_types.rb
index 5d92cfa535..ce92ca1beb 100644
--- a/lib/graphql/static_validation/rules/fragments_are_on_composite_types.rb
+++ b/lib/graphql/static_validation/rules/fragments_are_on_composite_types.rb
@@ -1,34 +1,30 @@
# frozen_string_literal: true
module GraphQL
module StaticValidation
- class FragmentsAreOnCompositeTypes
- include GraphQL::StaticValidation::Message::MessageHelper
-
- HAS_TYPE_CONDITION = [
- GraphQL::Language::Nodes::FragmentDefinition,
- GraphQL::Language::Nodes::InlineFragment,
- ]
+ module FragmentsAreOnCompositeTypes
+ def on_fragment_definition(node, parent)
+ validate_type_is_composite(node) && super
+ end
- def validate(context)
- HAS_TYPE_CONDITION.each do |node_class|
- context.visitor[node_class] << ->(node, parent) {
- validate_type_is_composite(node, context)
- }
- end
+ def on_inline_fragment(node, parent)
+ validate_type_is_composite(node) && super
end
private
- def validate_type_is_composite(node, context)
+ def validate_type_is_composite(node)
node_type = node.type
if node_type.nil?
# Inline fragment on the same type
+ true
else
type_name = node_type.to_query_string
type_def = context.warden.get_type(type_name)
if type_def.nil? || !type_def.kind.composite?
- context.errors << message("Invalid fragment on type #{type_name} (must be Union, Interface or Object)", node, context: context)
- GraphQL::Language::Visitor::SKIP
+ add_error("Invalid fragment on type #{type_name} (must be Union, Interface or Object)", node)
+ false
+ else
+ true
end
end
end
diff --git a/lib/graphql/static_validation/rules/fragments_are_used.rb b/lib/graphql/static_validation/rules/fragments_are_used.rb
index 5489818087..0d76ef224a 100644
--- a/lib/graphql/static_validation/rules/fragments_are_used.rb
+++ b/lib/graphql/static_validation/rules/fragments_are_used.rb
@@ -1,22 +1,19 @@
# frozen_string_literal: true
module GraphQL
module StaticValidation
- class FragmentsAreUsed
- include GraphQL::StaticValidation::Message::MessageHelper
-
- def validate(context)
- context.visitor[GraphQL::Language::Nodes::Document].leave << ->(_n, _p) do
- dependency_map = context.dependencies
- dependency_map.unmet_dependencies.each do |op_defn, spreads|
- spreads.each do |fragment_spread|
- context.errors << message("Fragment #{fragment_spread.name} was used, but not defined", fragment_spread.node, path: fragment_spread.path)
- end
+ module FragmentsAreUsed
+ def on_document(node, parent)
+ super
+ dependency_map = context.dependencies
+ dependency_map.unmet_dependencies.each do |op_defn, spreads|
+ spreads.each do |fragment_spread|
+ add_error("Fragment #{fragment_spread.name} was used, but not defined", fragment_spread.node, path: fragment_spread.path)
end
+ end
- dependency_map.unused_dependencies.each do |fragment|
- if !fragment.name.nil?
- context.errors << message("Fragment #{fragment.name} was defined, but not used", fragment.node, path: fragment.path)
- end
+ dependency_map.unused_dependencies.each do |fragment|
+ if !fragment.name.nil?
+ add_error("Fragment #{fragment.name} was defined, but not used", fragment.node, path: fragment.path)
end
end
end
diff --git a/lib/graphql/static_validation/rules/mutation_root_exists.rb b/lib/graphql/static_validation/rules/mutation_root_exists.rb
index 0f56859e68..8d65f95e0e 100644
--- a/lib/graphql/static_validation/rules/mutation_root_exists.rb
+++ b/lib/graphql/static_validation/rules/mutation_root_exists.rb
@@ -1,20 +1,13 @@
# frozen_string_literal: true
module GraphQL
module StaticValidation
- class MutationRootExists
- include GraphQL::StaticValidation::Message::MessageHelper
-
- def validate(context)
- return if context.warden.root_type_for_operation("mutation")
-
- visitor = context.visitor
-
- visitor[GraphQL::Language::Nodes::OperationDefinition].enter << ->(ast_node, prev_ast_node) {
- if ast_node.operation_type == 'mutation'
- context.errors << message('Schema is not configured for mutations', ast_node, context: context)
- return GraphQL::Language::Visitor::SKIP
- end
- }
+ module MutationRootExists
+ def on_operation_definition(node, _parent)
+ if node.operation_type == 'mutation' && context.warden.root_type_for_operation("mutation").nil?
+ add_error('Schema is not configured for mutations', node)
+ else
+ super
+ end
end
end
end
diff --git a/lib/graphql/static_validation/rules/no_definitions_are_present.rb b/lib/graphql/static_validation/rules/no_definitions_are_present.rb
index c2e661fb2e..5818fd7152 100644
--- a/lib/graphql/static_validation/rules/no_definitions_are_present.rb
+++ b/lib/graphql/static_validation/rules/no_definitions_are_present.rb
@@ -1,40 +1,40 @@
# frozen_string_literal: true
module GraphQL
module StaticValidation
- class NoDefinitionsArePresent
+ module NoDefinitionsArePresent
include GraphQL::StaticValidation::Message::MessageHelper
- def validate(context)
- schema_definition_nodes = []
- register_node = ->(node, _p) {
- schema_definition_nodes << node
- GraphQL::Language::Visitor::SKIP
- }
-
- visitor = context.visitor
+ def initialize(*)
+ super
+ @schema_definition_nodes = []
+ end
- visitor[GraphQL::Language::Nodes::DirectiveDefinition] << register_node
- visitor[GraphQL::Language::Nodes::SchemaDefinition] << register_node
- visitor[GraphQL::Language::Nodes::ScalarTypeDefinition] << register_node
- visitor[GraphQL::Language::Nodes::ObjectTypeDefinition] << register_node
- visitor[GraphQL::Language::Nodes::InputObjectTypeDefinition] << register_node
- visitor[GraphQL::Language::Nodes::InterfaceTypeDefinition] << register_node
- visitor[GraphQL::Language::Nodes::UnionTypeDefinition] << register_node
- visitor[GraphQL::Language::Nodes::EnumTypeDefinition] << register_node
+ def on_invalid_node(node, parent)
+ @schema_definition_nodes << node
+ nil
+ end
- visitor[GraphQL::Language::Nodes::SchemaExtension] << register_node
- visitor[GraphQL::Language::Nodes::ScalarTypeExtension] << register_node
- visitor[GraphQL::Language::Nodes::ObjectTypeExtension] << register_node
- visitor[GraphQL::Language::Nodes::InputObjectTypeExtension] << register_node
- visitor[GraphQL::Language::Nodes::InterfaceTypeExtension] << register_node
- visitor[GraphQL::Language::Nodes::UnionTypeExtension] << register_node
- visitor[GraphQL::Language::Nodes::EnumTypeExtension] << register_node
+ alias :on_directive_definition :on_invalid_node
+ alias :on_schema_definition :on_invalid_node
+ alias :on_scalar_type_definition :on_invalid_node
+ alias :on_object_type_definition :on_invalid_node
+ alias :on_input_object_type_definition :on_invalid_node
+ alias :on_interface_type_definition :on_invalid_node
+ alias :on_union_type_definition :on_invalid_node
+ alias :on_enum_type_definition :on_invalid_node
+ alias :on_schema_extension :on_invalid_node
+ alias :on_scalar_type_extension :on_invalid_node
+ alias :on_object_type_extension :on_invalid_node
+ alias :on_input_object_type_extension :on_invalid_node
+ alias :on_interface_type_extension :on_invalid_node
+ alias :on_union_type_extension :on_invalid_node
+ alias :on_enum_type_extension :on_invalid_node
- visitor[GraphQL::Language::Nodes::Document].leave << ->(node, _p) {
- if schema_definition_nodes.any?
- context.errors << message(%|Query cannot contain schema definitions|, schema_definition_nodes, context: context)
- end
- }
+ def on_document(node, parent)
+ super
+ if @schema_definition_nodes.any?
+ add_error(%|Query cannot contain schema definitions|, @schema_definition_nodes)
+ end
end
end
end
diff --git a/lib/graphql/static_validation/rules/operation_names_are_valid.rb b/lib/graphql/static_validation/rules/operation_names_are_valid.rb
index e464dc1d2d..9a7ad2226e 100644
--- a/lib/graphql/static_validation/rules/operation_names_are_valid.rb
+++ b/lib/graphql/static_validation/rules/operation_names_are_valid.rb
@@ -1,27 +1,28 @@
# frozen_string_literal: true
module GraphQL
module StaticValidation
- class OperationNamesAreValid
- include GraphQL::StaticValidation::Message::MessageHelper
-
- def validate(context)
- op_names = Hash.new { |h, k| h[k] = [] }
+ module OperationNamesAreValid
+ def initialize(*)
+ super
+ @operation_names = Hash.new { |h, k| h[k] = [] }
+ end
- context.visitor[GraphQL::Language::Nodes::OperationDefinition].enter << ->(node, _parent) {
- op_names[node.name] << node
- }
+ def on_operation_definition(node, parent)
+ @operation_names[node.name] << node
+ super
+ end
- context.visitor[GraphQL::Language::Nodes::Document].leave << ->(node, _parent) {
- op_count = op_names.values.inject(0) { |m, v| m + v.size }
+ def on_document(node, parent)
+ super
+ op_count = @operation_names.values.inject(0) { |m, v| m + v.size }
- op_names.each do |name, nodes|
- if name.nil? && op_count > 1
- context.errors << message(%|Operation name is required when multiple operations are present|, nodes, context: context)
- elsif nodes.length > 1
- context.errors << message(%|Operation name "#{name}" must be unique|, nodes, context: context)
- end
+ @operation_names.each do |name, nodes|
+ if name.nil? && op_count > 1
+ add_error(%|Operation name is required when multiple operations are present|, nodes)
+ elsif nodes.length > 1
+ add_error(%|Operation name "#{name}" must be unique|, nodes)
end
- }
+ end
end
end
end
diff --git a/lib/graphql/static_validation/rules/required_arguments_are_present.rb b/lib/graphql/static_validation/rules/required_arguments_are_present.rb
index 50e50e97bf..3569dd910f 100644
--- a/lib/graphql/static_validation/rules/required_arguments_are_present.rb
+++ b/lib/graphql/static_validation/rules/required_arguments_are_present.rb
@@ -1,28 +1,21 @@
# frozen_string_literal: true
module GraphQL
module StaticValidation
- class RequiredArgumentsArePresent
- include GraphQL::StaticValidation::Message::MessageHelper
-
- def validate(context)
- v = context.visitor
- v[GraphQL::Language::Nodes::Field] << ->(node, parent) { validate_field(node, context) }
- v[GraphQL::Language::Nodes::Directive] << ->(node, parent) { validate_directive(node, context) }
+ module RequiredArgumentsArePresent
+ def on_field(node, _parent)
+ assert_required_args(node, field_definition)
+ super
end
- private
-
- def validate_directive(ast_directive, context)
- directive_defn = context.schema.directives[ast_directive.name]
- assert_required_args(ast_directive, directive_defn, context)
+ def on_directive(node, _parent)
+ directive_defn = context.schema.directives[node.name]
+ assert_required_args(node, directive_defn)
+ super
end
- def validate_field(ast_field, context)
- defn = context.field_definition
- assert_required_args(ast_field, defn, context)
- end
+ private
- def assert_required_args(ast_node, defn, context)
+ def assert_required_args(ast_node, defn)
present_argument_names = ast_node.arguments.map(&:name)
required_argument_names = defn.arguments.values
.select { |a| a.type.kind.non_null? }
@@ -30,7 +23,7 @@ def assert_required_args(ast_node, defn, context)
missing_names = required_argument_names - present_argument_names
if missing_names.any?
- context.errors << message("#{ast_node.class.name.split("::").last} '#{ast_node.name}' is missing required arguments: #{missing_names.join(", ")}", ast_node, context: context)
+ add_error("#{ast_node.class.name.split("::").last} '#{ast_node.name}' is missing required arguments: #{missing_names.join(", ")}", ast_node)
end
end
end
diff --git a/lib/graphql/static_validation/rules/subscription_root_exists.rb b/lib/graphql/static_validation/rules/subscription_root_exists.rb
index 065b9b0a42..1143aa7ef1 100644
--- a/lib/graphql/static_validation/rules/subscription_root_exists.rb
+++ b/lib/graphql/static_validation/rules/subscription_root_exists.rb
@@ -1,20 +1,13 @@
# frozen_string_literal: true
module GraphQL
module StaticValidation
- class SubscriptionRootExists
- include GraphQL::StaticValidation::Message::MessageHelper
-
- def validate(context)
- return if context.warden.root_type_for_operation("subscription")
-
- visitor = context.visitor
-
- visitor[GraphQL::Language::Nodes::OperationDefinition].enter << ->(ast_node, prev_ast_node) {
- if ast_node.operation_type == 'subscription'
- context.errors << message('Schema is not configured for subscriptions', ast_node, context: context)
- return GraphQL::Language::Visitor::SKIP
- end
- }
+ module SubscriptionRootExists
+ def on_operation_definition(node, _parent)
+ if node.operation_type == "subscription" && context.warden.root_type_for_operation("subscription").nil?
+ add_error('Schema is not configured for subscriptions', node)
+ else
+ super
+ end
end
end
end
diff --git a/lib/graphql/static_validation/rules/unique_directives_per_location.rb b/lib/graphql/static_validation/rules/unique_directives_per_location.rb
index 43944c4225..ec8370cedc 100644
--- a/lib/graphql/static_validation/rules/unique_directives_per_location.rb
+++ b/lib/graphql/static_validation/rules/unique_directives_per_location.rb
@@ -1,33 +1,43 @@
# frozen_string_literal: true
module GraphQL
module StaticValidation
- class UniqueDirectivesPerLocation
- include GraphQL::StaticValidation::Message::MessageHelper
+ module UniqueDirectivesPerLocation
+ DIRECTIVE_NODE_HOOKS = [
+ :on_fragment_definition,
+ :on_fragment_spread,
+ :on_inline_fragment,
+ :on_operation_definition,
+ :on_scalar_type_definition,
+ :on_object_type_definition,
+ :on_input_value_definition,
+ :on_field_definition,
+ :on_interface_type_definition,
+ :on_union_type_definition,
+ :on_enum_type_definition,
+ :on_enum_value_definition,
+ :on_input_object_type_definition,
+ :on_field,
+ ]
- NODES_WITH_DIRECTIVES = GraphQL::Language::Nodes.constants
- .map{|c| GraphQL::Language::Nodes.const_get(c)}
- .select{|c| c.is_a?(Class) && c.instance_methods.include?(:directives)}
-
- def validate(context)
- NODES_WITH_DIRECTIVES.each do |node_class|
- context.visitor[node_class] << ->(node, _) {
- validate_directives(node, context) unless node.directives.empty?
- }
+ DIRECTIVE_NODE_HOOKS.each do |method_name|
+ define_method(method_name) do |node, parent|
+ if node.directives.any?
+ validate_directive_location(node)
+ end
+ super(node, parent)
end
end
private
- def validate_directives(node, context)
+ def validate_directive_location(node)
used_directives = {}
-
node.directives.each do |ast_directive|
directive_name = ast_directive.name
if used_directives[directive_name]
- context.errors << message(
+ add_error(
"The directive \"#{directive_name}\" can only be used once at this location.",
- [used_directives[directive_name], ast_directive],
- context: context
+ [used_directives[directive_name], ast_directive]
)
else
used_directives[directive_name] = ast_directive
diff --git a/lib/graphql/static_validation/rules/variable_default_values_are_correctly_typed.rb b/lib/graphql/static_validation/rules/variable_default_values_are_correctly_typed.rb
index 83d552a7ba..b799e0d202 100644
--- a/lib/graphql/static_validation/rules/variable_default_values_are_correctly_typed.rb
+++ b/lib/graphql/static_validation/rules/variable_default_values_are_correctly_typed.rb
@@ -1,39 +1,33 @@
# frozen_string_literal: true
module GraphQL
module StaticValidation
- class VariableDefaultValuesAreCorrectlyTyped
- include GraphQL::StaticValidation::Message::MessageHelper
-
- def validate(context)
- context.visitor[GraphQL::Language::Nodes::VariableDefinition] << ->(node, parent) {
- if !node.default_value.nil?
- validate_default_value(node, context)
- end
- }
- end
-
- def validate_default_value(node, context)
- value = node.default_value
- if node.type.is_a?(GraphQL::Language::Nodes::NonNullType)
- context.errors << message("Non-null variable $#{node.name} can't have a default value", node, context: context)
- else
- type = context.schema.type_from_ast(node.type)
- if type.nil?
- # This is handled by another validator
+ module VariableDefaultValuesAreCorrectlyTyped
+ def on_variable_definition(node, parent)
+ if !node.default_value.nil?
+ value = node.default_value
+ if node.type.is_a?(GraphQL::Language::Nodes::NonNullType)
+ add_error("Non-null variable $#{node.name} can't have a default value", node)
else
- begin
- valid = context.valid_literal?(value, type)
- rescue GraphQL::CoercionError => err
- error_message = err.message
- end
+ type = context.schema.type_from_ast(node.type)
+ if type.nil?
+ # This is handled by another validator
+ else
+ begin
+ valid = context.valid_literal?(value, type)
+ rescue GraphQL::CoercionError => err
+ error_message = err.message
+ end
- if !valid
- error_message ||= "Default value for $#{node.name} doesn't match type #{type}"
- context.errors << message(error_message, node, context: context)
- end
- end
+ if !valid
+ error_message ||= "Default value for $#{node.name} doesn't match type #{type}"
+ add_error(error_message, node)
+ end
+ end
+ end
end
+
+ super
end
end
- end
+ end
end
diff --git a/lib/graphql/static_validation/rules/variable_names_are_unique.rb b/lib/graphql/static_validation/rules/variable_names_are_unique.rb
index 340aa2eb47..c8117a00ec 100644
--- a/lib/graphql/static_validation/rules/variable_names_are_unique.rb
+++ b/lib/graphql/static_validation/rules/variable_names_are_unique.rb
@@ -1,22 +1,19 @@
# frozen_string_literal: true
module GraphQL
module StaticValidation
- class VariableNamesAreUnique
- include GraphQL::StaticValidation::Message::MessageHelper
-
- def validate(context)
- context.visitor[GraphQL::Language::Nodes::OperationDefinition] << ->(node, parent) {
- var_defns = node.variables
- if var_defns.any?
- vars_by_name = Hash.new { |h, k| h[k] = [] }
- var_defns.each { |v| vars_by_name[v.name] << v }
- vars_by_name.each do |name, defns|
- if defns.size > 1
- context.errors << message("There can only be one variable named \"#{name}\"", defns, context: context)
- end
+ module VariableNamesAreUnique
+ def on_operation_definition(node, parent)
+ var_defns = node.variables
+ if var_defns.any?
+ vars_by_name = Hash.new { |h, k| h[k] = [] }
+ var_defns.each { |v| vars_by_name[v.name] << v }
+ vars_by_name.each do |name, defns|
+ if defns.size > 1
+ add_error("There can only be one variable named \"#{name}\"", defns)
end
end
- }
+ end
+ super
end
end
end
diff --git a/lib/graphql/static_validation/rules/variable_usages_are_allowed.rb b/lib/graphql/static_validation/rules/variable_usages_are_allowed.rb
index 50a9d6fe3e..5ce3605560 100644
--- a/lib/graphql/static_validation/rules/variable_usages_are_allowed.rb
+++ b/lib/graphql/static_validation/rules/variable_usages_are_allowed.rb
@@ -1,53 +1,57 @@
# frozen_string_literal: true
module GraphQL
module StaticValidation
- class VariableUsagesAreAllowed
- include GraphQL::StaticValidation::Message::MessageHelper
-
- def validate(context)
+ module VariableUsagesAreAllowed
+ def initialize(*)
+ super
# holds { name => ast_node } pairs
- declared_variables = {}
- context.visitor[GraphQL::Language::Nodes::OperationDefinition] << ->(node, parent) {
- declared_variables = node.variables.each_with_object({}) { |var, memo| memo[var.name] = var }
- }
-
- context.visitor[GraphQL::Language::Nodes::Argument] << ->(node, parent) {
- node_values = if node.value.is_a?(Array)
- node.value
- else
- [node.value]
- end
- node_values = node_values.select { |value| value.is_a? GraphQL::Language::Nodes::VariableIdentifier }
+ @declared_variables = {}
+ end
- return if node_values.none?
+ def on_operation_definition(node, parent)
+ @declared_variables = node.variables.each_with_object({}) { |var, memo| memo[var.name] = var }
+ super
+ end
- arguments = nil
- case parent
+ def on_argument(node, parent)
+ node_values = if node.value.is_a?(Array)
+ node.value
+ else
+ [node.value]
+ end
+ node_values = node_values.select { |value| value.is_a? GraphQL::Language::Nodes::VariableIdentifier }
+
+ if node_values.any?
+ arguments = case parent
when GraphQL::Language::Nodes::Field
- arguments = context.field_definition.arguments
+ context.field_definition.arguments
when GraphQL::Language::Nodes::Directive
- arguments = context.directive_definition.arguments
+ context.directive_definition.arguments
when GraphQL::Language::Nodes::InputObject
arg_type = context.argument_definition.type.unwrap
if arg_type.is_a?(GraphQL::InputObjectType)
arguments = arg_type.input_fields
+ else
+ # This is some kind of error
+ nil
end
else
raise("Unexpected argument parent: #{parent}")
end
node_values.each do |node_value|
- var_defn_ast = declared_variables[node_value.name]
+ var_defn_ast = @declared_variables[node_value.name]
# Might be undefined :(
# VariablesAreUsedAndDefined can't finalize its search until the end of the document.
- var_defn_ast && arguments && validate_usage(arguments, node, var_defn_ast, context)
+ var_defn_ast && arguments && validate_usage(arguments, node, var_defn_ast)
end
- }
+ end
+ super
end
private
- def validate_usage(arguments, arg_node, ast_var, context)
+ def validate_usage(arguments, arg_node, ast_var)
var_type = context.schema.type_from_ast(ast_var.type)
if var_type.nil?
return
@@ -71,16 +75,16 @@ def validate_usage(arguments, arg_node, ast_var, context)
var_type = wrap_var_type_with_depth_of_arg(var_type, arg_node)
if var_inner_type != arg_inner_type
- context.errors << create_error("Type mismatch", var_type, ast_var, arg_defn, arg_node, context)
+ create_error("Type mismatch", var_type, ast_var, arg_defn, arg_node)
elsif list_dimension(var_type) != list_dimension(arg_defn_type)
- context.errors << create_error("List dimension mismatch", var_type, ast_var, arg_defn, arg_node, context)
+ create_error("List dimension mismatch", var_type, ast_var, arg_defn, arg_node)
elsif !non_null_levels_match(arg_defn_type, var_type)
- context.errors << create_error("Nullability mismatch", var_type, ast_var, arg_defn, arg_node, context)
+ create_error("Nullability mismatch", var_type, ast_var, arg_defn, arg_node)
end
end
- def create_error(error_message, var_type, ast_var, arg_defn, arg_node, context)
- message("#{error_message} on variable $#{ast_var.name} and argument #{arg_node.name} (#{var_type.to_s} / #{arg_defn.type.to_s})", arg_node, context: context)
+ def create_error(error_message, var_type, ast_var, arg_defn, arg_node)
+ add_error("#{error_message} on variable $#{ast_var.name} and argument #{arg_node.name} (#{var_type.to_s} / #{arg_defn.type.to_s})", arg_node)
end
def wrap_var_type_with_depth_of_arg(var_type, arg_node)
diff --git a/lib/graphql/static_validation/rules/variables_are_input_types.rb b/lib/graphql/static_validation/rules/variables_are_input_types.rb
index 0f3f6552af..01f1a5be36 100644
--- a/lib/graphql/static_validation/rules/variables_are_input_types.rb
+++ b/lib/graphql/static_validation/rules/variables_are_input_types.rb
@@ -1,28 +1,22 @@
# frozen_string_literal: true
module GraphQL
module StaticValidation
- class VariablesAreInputTypes
- include GraphQL::StaticValidation::Message::MessageHelper
-
- def validate(context)
- context.visitor[GraphQL::Language::Nodes::VariableDefinition] << ->(node, parent) {
- validate_is_input_type(node, context)
- }
- end
-
- private
-
- def validate_is_input_type(node, context)
+ module VariablesAreInputTypes
+ def on_variable_definition(node, parent)
type_name = get_type_name(node.type)
type = context.warden.get_type(type_name)
if type.nil?
- context.errors << message("#{type_name} isn't a defined input type (on $#{node.name})", node, context: context)
+ add_error("#{type_name} isn't a defined input type (on $#{node.name})", node)
elsif !type.kind.input?
- context.errors << message("#{type.name} isn't a valid input type (on $#{node.name})", node, context: context)
+ add_error("#{type.name} isn't a valid input type (on $#{node.name})", node)
end
+
+ super
end
+ private
+
def get_type_name(ast_type)
if ast_type.respond_to?(:of_type)
get_type_name(ast_type.of_type)
diff --git a/lib/graphql/static_validation/rules/variables_are_used_and_defined.rb b/lib/graphql/static_validation/rules/variables_are_used_and_defined.rb
index c6277e627a..380748d09b 100644
--- a/lib/graphql/static_validation/rules/variables_are_used_and_defined.rb
+++ b/lib/graphql/static_validation/rules/variables_are_used_and_defined.rb
@@ -11,9 +11,7 @@ module StaticValidation
# - re-visiting the AST for each validator
# - allowing validators to say `followSpreads: true`
#
- class VariablesAreUsedAndDefined
- include GraphQL::StaticValidation::Message::MessageHelper
-
+ module VariablesAreUsedAndDefined
class VariableUsage
attr_accessor :ast_node, :used_by, :declared_by, :path
def used?
@@ -25,73 +23,67 @@ def declared?
end
end
- def variable_hash
- Hash.new {|h, k| h[k] = VariableUsage.new }
+ def initialize(*)
+ super
+ @variable_usages_for_context = Hash.new {|hash, key| hash[key] = Hash.new {|h, k| h[k] = VariableUsage.new } }
+ @spreads_for_context = Hash.new {|hash, key| hash[key] = [] }
+ @variable_context_stack = []
end
- def validate(context)
- variable_usages_for_context = Hash.new {|hash, key| hash[key] = variable_hash }
- spreads_for_context = Hash.new {|hash, key| hash[key] = [] }
- variable_context_stack = []
-
- # OperationDefinitions and FragmentDefinitions
- # both push themselves onto the context stack (and pop themselves off)
- push_variable_context_stack = ->(node, parent) {
- # initialize the hash of vars for this context:
- variable_usages_for_context[node]
- variable_context_stack.push(node)
+ def on_operation_definition(node, parent)
+ # initialize the hash of vars for this context:
+ @variable_usages_for_context[node]
+ @variable_context_stack.push(node)
+ # mark variables as defined:
+ var_hash = @variable_usages_for_context[node]
+ node.variables.each { |var|
+ var_usage = var_hash[var.name]
+ var_usage.declared_by = node
+ var_usage.path = context.path
}
+ super
+ @variable_context_stack.pop
+ end
- pop_variable_context_stack = ->(node, parent) {
- variable_context_stack.pop
- }
-
-
- context.visitor[GraphQL::Language::Nodes::OperationDefinition] << push_variable_context_stack
- context.visitor[GraphQL::Language::Nodes::OperationDefinition] << ->(node, parent) {
- # mark variables as defined:
- var_hash = variable_usages_for_context[node]
- node.variables.each { |var|
- var_usage = var_hash[var.name]
- var_usage.declared_by = node
- var_usage.path = context.path
- }
- }
- context.visitor[GraphQL::Language::Nodes::OperationDefinition].leave << pop_variable_context_stack
-
- context.visitor[GraphQL::Language::Nodes::FragmentDefinition] << push_variable_context_stack
- context.visitor[GraphQL::Language::Nodes::FragmentDefinition].leave << pop_variable_context_stack
-
- # For FragmentSpreads:
- # - find the context on the stack
- # - mark the context as containing this spread
- context.visitor[GraphQL::Language::Nodes::FragmentSpread] << ->(node, parent) {
- variable_context = variable_context_stack.last
- spreads_for_context[variable_context] << node.name
- }
+ def on_fragment_definition(node, parent)
+ # initialize the hash of vars for this context:
+ @variable_usages_for_context[node]
+ @variable_context_stack.push(node)
+ super
+ @variable_context_stack.pop
+ end
- # For VariableIdentifiers:
- # - mark the variable as used
- # - assign its AST node
- context.visitor[GraphQL::Language::Nodes::VariableIdentifier] << ->(node, parent) {
- usage_context = variable_context_stack.last
- declared_variables = variable_usages_for_context[usage_context]
- usage = declared_variables[node.name]
- usage.used_by = usage_context
- usage.ast_node = node
- usage.path = context.path
- }
+ # For FragmentSpreads:
+ # - find the context on the stack
+ # - mark the context as containing this spread
+ def on_fragment_spread(node, parent)
+ variable_context = @variable_context_stack.last
+ @spreads_for_context[variable_context] << node.name
+ super
+ end
+ # For VariableIdentifiers:
+ # - mark the variable as used
+ # - assign its AST node
+ def on_variable_identifier(node, parent)
+ usage_context = @variable_context_stack.last
+ declared_variables = @variable_usages_for_context[usage_context]
+ usage = declared_variables[node.name]
+ usage.used_by = usage_context
+ usage.ast_node = node
+ usage.path = context.path
+ super
+ end
- context.visitor[GraphQL::Language::Nodes::Document].leave << ->(node, parent) {
- fragment_definitions = variable_usages_for_context.select { |key, value| key.is_a?(GraphQL::Language::Nodes::FragmentDefinition) }
- operation_definitions = variable_usages_for_context.select { |key, value| key.is_a?(GraphQL::Language::Nodes::OperationDefinition) }
+ def on_document(node, parent)
+ super
+ fragment_definitions = @variable_usages_for_context.select { |key, value| key.is_a?(GraphQL::Language::Nodes::FragmentDefinition) }
+ operation_definitions = @variable_usages_for_context.select { |key, value| key.is_a?(GraphQL::Language::Nodes::OperationDefinition) }
- operation_definitions.each do |node, node_variables|
- follow_spreads(node, node_variables, spreads_for_context, fragment_definitions, [])
- create_errors(node_variables, context)
- end
- }
+ operation_definitions.each do |node, node_variables|
+ follow_spreads(node, node_variables, @spreads_for_context, fragment_definitions, [])
+ create_errors(node_variables)
+ end
end
private
@@ -129,16 +121,16 @@ def follow_spreads(node, parent_variables, spreads_for_context, fragment_definit
# Determine all the error messages,
# Then push messages into the validation context
- def create_errors(node_variables, context)
+ def create_errors(node_variables)
# Declared but not used:
node_variables
.select { |name, usage| usage.declared? && !usage.used? }
- .each { |var_name, usage| context.errors << message("Variable $#{var_name} is declared by #{usage.declared_by.name} but not used", usage.declared_by, path: usage.path) }
+ .each { |var_name, usage| add_error("Variable $#{var_name} is declared by #{usage.declared_by.name} but not used", usage.declared_by, path: usage.path) }
# Used but not declared:
node_variables
.select { |name, usage| usage.used? && !usage.declared? }
- .each { |var_name, usage| context.errors << message("Variable $#{var_name} is used by #{usage.used_by.name} but not declared", usage.ast_node, path: usage.path) }
+ .each { |var_name, usage| add_error("Variable $#{var_name} is used by #{usage.used_by.name} but not declared", usage.ast_node, path: usage.path) }
end
end
end
diff --git a/lib/graphql/static_validation/validation_context.rb b/lib/graphql/static_validation/validation_context.rb
index 7787936cfe..855b5b8c91 100644
--- a/lib/graphql/static_validation/validation_context.rb
+++ b/lib/graphql/static_validation/validation_context.rb
@@ -14,70 +14,35 @@ module StaticValidation
class ValidationContext
extend Forwardable
- attr_reader :query, :errors, :visitor, :dependencies, :each_irep_node_handlers
+
+ attr_reader :query, :errors, :visitor,
+ :on_dependency_resolve_handlers, :each_irep_node_handlers
def_delegators :@query, :schema, :document, :fragments, :operations, :warden
- def initialize(query)
+ def initialize(query, visitor_class)
@query = query
@literal_validator = LiteralValidator.new(context: query.context)
@errors = []
- @visitor = GraphQL::Language::Visitor.new(document)
- @type_stack = GraphQL::StaticValidation::TypeStack.new(schema, visitor)
- definition_dependencies = DefinitionDependencies.mount(self)
- @on_dependency_resolve_handlers = []
+ # TODO it will take some finegalling but I think all this state could
+ # be moved to `Visitor`
@each_irep_node_handlers = []
- visitor[GraphQL::Language::Nodes::Document].leave << ->(_n, _p) {
- @dependencies = definition_dependencies.dependency_map { |defn, spreads, frag|
- @on_dependency_resolve_handlers.each { |h| h.call(defn, spreads, frag) }
- }
- }
+ @on_dependency_resolve_handlers = []
+ @visitor = visitor_class.new(document, self)
end
+ def_delegators :@visitor,
+ :path, :type_definition, :field_definition, :argument_definition,
+ :parent_type_definition, :directive_definition, :object_types, :dependencies
+
def on_dependency_resolve(&handler)
@on_dependency_resolve_handlers << handler
end
- def object_types
- @type_stack.object_types
- end
-
def each_irep_node(&handler)
@each_irep_node_handlers << handler
end
- # @return [GraphQL::BaseType] The current object type
- def type_definition
- object_types.last
- end
-
- # @return [GraphQL::BaseType] The type which the current type came from
- def parent_type_definition
- object_types[-2]
- end
-
- # @return [GraphQL::Field, nil] The most-recently-entered GraphQL::Field, if currently inside one
- def field_definition
- @type_stack.field_definitions.last
- end
-
- # @return [Array] Field names to get to the current field
- def path
- @type_stack.path.dup
- end
-
- # @return [GraphQL::Directive, nil] The most-recently-entered GraphQL::Directive, if currently inside one
- def directive_definition
- @type_stack.directive_definitions.last
- end
-
- # @return [GraphQL::Argument, nil] The most-recently-entered GraphQL::Argument, if currently inside one
- def argument_definition
- # Don't get the _last_ one because that's the current one.
- # Get the second-to-last one, which is the parent of the current one.
- @type_stack.argument_definitions[-2]
- end
-
def valid_literal?(ast_value, type)
@literal_validator.validate(ast_value, type)
end
diff --git a/lib/graphql/static_validation/validator.rb b/lib/graphql/static_validation/validator.rb
index 32d349c3de..2b3eaf1094 100644
--- a/lib/graphql/static_validation/validator.rb
+++ b/lib/graphql/static_validation/validator.rb
@@ -23,23 +23,22 @@ def initialize(schema:, rules: GraphQL::StaticValidation::ALL_RULES)
# @return [Array]
def validate(query, validate: true)
query.trace("validate", { validate: validate, query: query }) do
- context = GraphQL::StaticValidation::ValidationContext.new(query)
- rewrite = GraphQL::InternalRepresentation::Rewrite.new
- # Put this first so its enters and exits are always called
- rewrite.validate(context)
+ rules_to_use = validate ? @rules : []
+ visitor_class = BaseVisitor.including_rules(rules_to_use)
- # If the caller opted out of validation, don't attach these
- if validate
- @rules.each do |rules|
- rules.new.validate(context)
+ context = GraphQL::StaticValidation::ValidationContext.new(query, visitor_class)
+
+ # Attach legacy-style rules
+ rules_to_use.each do |rule_class_or_module|
+ if rule_class_or_module.method_defined?(:validate)
+ rule_class_or_module.new.validate(context)
end
end
context.visitor.visit
- rewrite_result = rewrite.document
-
# Post-validation: allow validators to register handlers on rewritten query nodes
+ rewrite_result = context.visitor.rewrite_document
GraphQL::InternalRepresentation::Visit.visit_each_node(rewrite_result.operation_definitions, context.each_irep_node_handlers)
{
diff --git a/lib/graphql/types/iso_8601_date_time.rb b/lib/graphql/types/iso_8601_date_time.rb
index 497adb9994..0382ccbb97 100644
--- a/lib/graphql/types/iso_8601_date_time.rb
+++ b/lib/graphql/types/iso_8601_date_time.rb
@@ -15,10 +15,24 @@ module Types
class ISO8601DateTime < GraphQL::Schema::Scalar
description "An ISO 8601-encoded datetime"
+ # It's not compatible with Rails' default,
+ # i.e. ActiveSupport::JSON::Encoder.time_precision (3 by default)
+ DEFAULT_TIME_PRECISION = 0
+
+ # @return [Integer]
+ def self.time_precision
+ @time_precision || DEFAULT_TIME_PRECISION
+ end
+
+ # @param [Integer] value
+ def self.time_precision=(value)
+ @time_precision = value
+ end
+
# @param value [DateTime]
# @return [String]
def self.coerce_result(value, _ctx)
- value.iso8601
+ value.iso8601(time_precision)
end
# @param str_value [String]
diff --git a/spec/graphql/introspection/type_type_spec.rb b/spec/graphql/introspection/type_type_spec.rb
index 94a7a27823..d85dfc16aa 100644
--- a/spec/graphql/introspection/type_type_spec.rb
+++ b/spec/graphql/introspection/type_type_spec.rb
@@ -147,7 +147,7 @@
type_result = res["data"]["__schema"]["types"].find { |t| t["name"] == "Faction" }
field_result = type_result["fields"].find { |f| f["name"] == "bases" }
- all_arg_names = ["first", "after", "last", "before", "nameIncludes"]
+ all_arg_names = ["after", "before", "first", "last", "nameIncludes"]
returned_arg_names = field_result["args"].map { |a| a["name"] }
assert_equal all_arg_names, returned_arg_names
end
diff --git a/spec/graphql/language/visitor_spec.rb b/spec/graphql/language/visitor_spec.rb
index 10f5fda046..e5bf69a36c 100644
--- a/spec/graphql/language/visitor_spec.rb
+++ b/spec/graphql/language/visitor_spec.rb
@@ -16,10 +16,11 @@
fragment cheeseFields on Cheese { flavor }
")}
- let(:counts) { {fields_entered: 0, arguments_entered: 0, arguments_left: 0, argument_names: []} }
+ let(:hooks_counts) { {fields_entered: 0, arguments_entered: 0, arguments_left: 0, argument_names: []} }
- let(:visitor) do
+ let(:hooks_visitor) do
v = GraphQL::Language::Visitor.new(document)
+ counts = hooks_counts
v[GraphQL::Language::Nodes::Field] << ->(node, parent) { counts[:fields_entered] += 1 }
# two ways to set up enter hooks:
v[GraphQL::Language::Nodes::Argument] << ->(node, parent) { counts[:argument_names] << node.name }
@@ -30,14 +31,43 @@
v
end
- it "calls hooks during a depth-first tree traversal" do
- assert_equal(2, visitor[GraphQL::Language::Nodes::Argument].enter.length)
- visitor.visit
- assert_equal(6, counts[:fields_entered])
- assert_equal(2, counts[:arguments_entered])
- assert_equal(2, counts[:arguments_left])
- assert_equal(["id", "first"], counts[:argument_names])
- assert(counts[:finished])
+ class VisitorSpecVisitor < GraphQL::Language::Visitor
+ attr_reader :counts
+ def initialize(document)
+ @counts = {fields_entered: 0, arguments_entered: 0, arguments_left: 0, argument_names: []}
+ super
+ end
+
+ def on_field(node, parent)
+ counts[:fields_entered] += 1
+ super(node, parent)
+ end
+
+ def on_argument(node, parent)
+ counts[:argument_names] << node.name
+ counts[:arguments_entered] += 1
+ super
+ ensure
+ counts[:arguments_left] += 1
+ end
+
+ def on_document(node, parent)
+ counts[:finished] = true
+ super
+ end
+ end
+
+ class SkippingVisitor < VisitorSpecVisitor
+ def on_document(_n, _p)
+ SKIP
+ end
+ end
+
+ let(:class_based_visitor) { VisitorSpecVisitor.new(document) }
+ let(:class_based_counts) { class_based_visitor.counts }
+
+ it "has an array of hooks" do
+ assert_equal(2, hooks_visitor[GraphQL::Language::Nodes::Argument].enter.length)
end
it "can visit a document with directive definitions" do
@@ -65,11 +95,30 @@
assert_equal 10, directive_locations.length
end
- describe "Visitor::SKIP" do
- it "skips the rest of the node" do
- visitor[GraphQL::Language::Nodes::Document] << ->(node, parent) { GraphQL::Language::Visitor::SKIP }
+ [:hooks, :class_based].each do |visitor_type|
+ it "#{visitor_type} visitor calls hooks during a depth-first tree traversal" do
+ visitor = public_send("#{visitor_type}_visitor")
visitor.visit
- assert_equal(0, counts[:fields_entered])
+ counts = public_send("#{visitor_type}_counts")
+ assert_equal(6, counts[:fields_entered])
+ assert_equal(2, counts[:arguments_entered])
+ assert_equal(2, counts[:arguments_left])
+ assert_equal(["id", "first"], counts[:argument_names])
+ assert(counts[:finished])
+ end
+
+ describe "Visitor::SKIP" do
+ let(:class_based_visitor) { SkippingVisitor.new(document) }
+
+ it "#{visitor_type} visitor skips the rest of the node" do
+ visitor = public_send("#{visitor_type}_visitor")
+ if visitor_type == :hooks
+ visitor[GraphQL::Language::Nodes::Document] << ->(node, parent) { GraphQL::Language::Visitor::SKIP }
+ end
+ visitor.visit
+ counts = public_send("#{visitor_type}_counts")
+ assert_equal(0, counts[:fields_entered])
+ end
end
end
@@ -91,4 +140,169 @@
assert visited_directive
end
+
+ describe "AST modification" do
+ class ModificationTestVisitor < GraphQL::Language::Visitor
+ def on_field(node, parent)
+ if node.name == "c"
+ new_node = node.merge(name: "renamedC")
+ super(new_node, parent)
+ elsif node.name == "addFields"
+ new_node = node.merge_selection(name: "addedChild")
+ super(new_node, parent)
+ elsif node.name == "anotherAddition"
+ new_node = node
+ .merge_argument(name: "addedArgument", value: 1)
+ .merge_directive(name: "doStuff")
+ super(new_node, parent)
+ else
+ super
+ end
+ end
+
+ def on_argument(node, parent)
+ if node.name == "deleteMe"
+ super(DELETE_NODE, parent)
+ else
+ super
+ end
+ end
+
+ def on_input_object(node, parent)
+ if node.arguments.map(&:name).sort == ["delete", "me"]
+ super(DELETE_NODE, parent)
+ else
+ super
+ end
+ end
+
+ def on_directive(node, parent)
+ if node.name == "doStuff"
+ new_node = node.merge_argument(name: "addedArgument2", value: 2)
+ super(new_node, parent)
+ else
+ super
+ end
+ end
+ end
+
+ def get_result(query_str)
+ document = GraphQL.parse(query_str)
+ visitor = ModificationTestVisitor.new(document)
+ visitor.visit
+ return document, visitor.result
+ end
+
+ it "returns a new AST with modifications applied" do
+ query = <<-GRAPHQL.chop
+query {
+ a(a1: 1) {
+ b(b2: 2) {
+ c(c3: 3)
+ }
+ }
+ d(d4: 4)
+}
+ GRAPHQL
+ document, new_document = get_result(query)
+ refute_equal document, new_document
+ expected_result = <<-GRAPHQL.chop
+query {
+ a(a1: 1) {
+ b(b2: 2) {
+ renamedC(c3: 3)
+ }
+ }
+ d(d4: 4)
+}
+GRAPHQL
+ assert_equal expected_result, new_document.to_query_string, "the result has changes"
+ assert_equal query, document.to_query_string, "the original is unchanged"
+
+ # This is testing the implementation: nodes which aren't affected by modification
+ # should be shared between the two trees
+ orig_c3_argument = document.definitions.first.selections.first.selections.first.selections.first.arguments.first
+ copy_c3_argument = new_document.definitions.first.selections.first.selections.first.selections.first.arguments.first
+ assert_equal "c3", orig_c3_argument.name
+ assert orig_c3_argument.equal?(copy_c3_argument), "Child nodes are persisted"
+
+ orig_d_field = document.definitions.first.selections[1]
+ copy_d_field = new_document.definitions.first.selections[1]
+ assert_equal "d", orig_d_field.name
+ assert orig_d_field.equal?(copy_d_field), "Sibling nodes are persisted"
+
+ orig_b_field = document.definitions.first.selections.first.selections.first
+ copy_b_field = new_document.definitions.first.selections.first.selections.first
+ assert_equal "b", orig_b_field.name
+ refute orig_b_field.equal?(copy_b_field), "Parents with modified children are copied"
+ end
+
+ it "deletes nodes with DELETE_NODE" do
+ before_query = <<-GRAPHQL.chop
+query {
+ f1 {
+ f2(deleteMe: 1) {
+ f3(c1: {deleteMe: {c2: 2}})
+ f4(c2: [{keepMe: 1}, {deleteMe: 2}, {keepMe: 3}])
+ }
+ }
+}
+GRAPHQL
+
+ after_query = <<-GRAPHQL.chop
+query {
+ f1 {
+ f2 {
+ f3(c1: {})
+ f4(c2: [{keepMe: 1}, {}, {keepMe: 3}])
+ }
+ }
+}
+GRAPHQL
+
+ document, new_document = get_result(before_query)
+ assert_equal before_query, document.to_query_string
+ assert_equal after_query, new_document.to_query_string
+ end
+
+ it "Deletes from lists" do
+ before_query = <<-GRAPHQL.chop
+query {
+ f1(arg1: [{a: 1}, {delete: 1, me: 2}, {b: 2}])
+}
+GRAPHQL
+
+ after_query = <<-GRAPHQL.chop
+query {
+ f1(arg1: [{a: 1}, {b: 2}])
+}
+GRAPHQL
+
+ document, new_document = get_result(before_query)
+ assert_equal before_query, document.to_query_string
+ assert_equal after_query, new_document.to_query_string
+ end
+
+ it "can add children" do
+ before_query = <<-GRAPHQL.chop
+query {
+ addFields
+ anotherAddition
+}
+GRAPHQL
+
+ after_query = <<-GRAPHQL.chop
+query {
+ addFields {
+ addedChild
+ }
+ anotherAddition(addedArgument: 1) @doStuff(addedArgument2: 2)
+}
+GRAPHQL
+
+ document, new_document = get_result(before_query)
+ assert_equal before_query, document.to_query_string
+ assert_equal after_query, new_document.to_query_string
+ end
+ end
end
diff --git a/spec/graphql/schema/argument_spec.rb b/spec/graphql/schema/argument_spec.rb
index 8720e3d889..7c0947cb38 100644
--- a/spec/graphql/schema/argument_spec.rb
+++ b/spec/graphql/schema/argument_spec.rb
@@ -13,6 +13,15 @@ class Query < GraphQL::Schema::Object
argument :aliased_arg, String, required: false, as: :renamed
argument :prepared_arg, Int, required: false, prepare: :multiply
+ argument :prepared_by_proc_arg, Int, required: false, prepare: ->(val, context) { context[:multiply_by] * val }
+
+ class Multiply
+ def call(val, context)
+ context[:multiply_by] * val
+ end
+ end
+
+ argument :prepared_by_callable_arg, Int, required: false, prepare: Multiply.new
end
def field(**args)
@@ -93,5 +102,23 @@ class Schema < GraphQL::Schema
# Make sure it's getting the renamed symbol:
assert_equal '{:prepared_arg=>15}', res["data"]["field"]
end
+ it "calls the method on the provided Proc" do
+ query_str = <<-GRAPHQL
+ { field(preparedByProcArg: 5) }
+ GRAPHQL
+
+ res = SchemaArgumentTest::Schema.execute(query_str, context: {multiply_by: 3})
+ # Make sure it's getting the renamed symbol:
+ assert_equal '{:prepared_by_proc_arg=>15}', res["data"]["field"]
+ end
+ it "calls the method on the provided callable object" do
+ query_str = <<-GRAPHQL
+ { field(preparedByCallableArg: 5) }
+ GRAPHQL
+
+ res = SchemaArgumentTest::Schema.execute(query_str, context: {multiply_by: 3})
+ # Make sure it's getting the renamed symbol:
+ assert_equal '{:prepared_by_callable_arg=>15}', res["data"]["field"]
+ end
end
end
diff --git a/spec/graphql/schema/field_extension_spec.rb b/spec/graphql/schema/field_extension_spec.rb
new file mode 100644
index 0000000000..d96c571078
--- /dev/null
+++ b/spec/graphql/schema/field_extension_spec.rb
@@ -0,0 +1,90 @@
+# frozen_string_literal: true
+require "spec_helper"
+
+describe GraphQL::Schema::FieldExtension do
+ module FilterTestSchema
+ class DoubleFilter < GraphQL::Schema::FieldExtension
+ def after_resolve(object:, value:, arguments:, context:, memo:)
+ value * 2
+ end
+ end
+
+ class MultiplyByOption < GraphQL::Schema::FieldExtension
+ def after_resolve(object:, value:, arguments:, context:, memo:)
+ value * options[:factor]
+ end
+ end
+
+ class MultiplyByArgument < GraphQL::Schema::FieldExtension
+ def apply
+ field.argument(:factor, Integer, required: true)
+ end
+
+ def before_resolve(object:, arguments:, context:)
+ factor = arguments.delete(:factor)
+ yield(object, arguments, factor)
+ end
+
+ def after_resolve(object:, value:, arguments:, context:, memo:)
+ value * memo
+ end
+ end
+
+ class BaseObject < GraphQL::Schema::Object
+ end
+
+ class Query < BaseObject
+ field :doubled, Integer, null: false, method: :pass_thru do
+ extension(DoubleFilter)
+ argument :input, Integer, required: true
+ end
+
+ def pass_thru(input:)
+ input # return it as-is, it will be modified by extensions
+ end
+
+ field :trippled_by_option, Integer, null: false, method: :pass_thru do
+ extension(MultiplyByOption, factor: 3)
+ argument :input, Integer, required: true
+ end
+
+ field :multiply_input, Integer, null: false, method: :pass_thru, extensions: [MultiplyByArgument] do
+ argument :input, Integer, required: true
+ end
+ end
+
+ class Schema < GraphQL::Schema
+ query(Query)
+ end
+ end
+
+ def exec_query(query_str, **kwargs)
+ FilterTestSchema::Schema.execute(query_str, **kwargs)
+ end
+
+ describe "reading" do
+ it "has a reader method" do
+ field = FilterTestSchema::Query.fields["multiplyInput"]
+ assert_equal 1, field.extensions.size
+ assert_instance_of FilterTestSchema::MultiplyByArgument, field.extensions.first
+ end
+ end
+
+ describe "modifying return values" do
+ it "returns the modified value" do
+ res = exec_query("{ doubled(input: 5) }")
+ assert_equal 10, res["data"]["doubled"]
+ end
+
+ it "has access to config options" do
+ # The factor of three came from an option
+ res = exec_query("{ trippledByOption(input: 4) }")
+ assert_equal 12, res["data"]["trippledByOption"]
+ end
+
+ it "can hide arguments from resolve methods" do
+ res = exec_query("{ multiplyInput(input: 3, factor: 5) }")
+ assert_equal 15, res["data"]["multiplyInput"]
+ end
+ end
+end
diff --git a/spec/graphql/schema/field_spec.rb b/spec/graphql/schema/field_spec.rb
index 2d420ec9d4..d0beb56be6 100644
--- a/spec/graphql/schema/field_spec.rb
+++ b/spec/graphql/schema/field_spec.rb
@@ -212,14 +212,12 @@
end
it "makes a suggestion when the type is false" do
- thing = Class.new(GraphQL::Schema::Object) do
- graphql_name "Thing"
- # False might come from an invalid `!`
- field :stuff, false, null: false
- end
-
err = assert_raises ArgumentError do
- thing.fields["stuff"].type
+ Class.new(GraphQL::Schema::Object) do
+ graphql_name "Thing"
+ # False might come from an invalid `!`
+ field :stuff, false, null: false
+ end
end
assert_includes err.message, "Thing.stuff"
diff --git a/spec/graphql/static_validation/type_stack_spec.rb b/spec/graphql/static_validation/type_stack_spec.rb
index 6ce5a3478a..a749fe01ca 100644
--- a/spec/graphql/static_validation/type_stack_spec.rb
+++ b/spec/graphql/static_validation/type_stack_spec.rb
@@ -1,19 +1,6 @@
# frozen_string_literal: true
require "spec_helper"
-class TypeCheckValidator
- def self.checks
- @checks ||= []
- end
-
- def validate(context)
- self.class.checks.clear
- context.visitor[GraphQL::Language::Nodes::Field] << ->(node, parent) {
- self.class.checks << context.object_types.map {|t| t.name || t.kind.name }
- }
- end
-end
-
describe GraphQL::StaticValidation::TypeStack do
let(:query_string) {%|
query getCheese {
@@ -22,17 +9,21 @@ def validate(context)
fragment edibleFields on Edible { fatContent @skip(if: false)}
|}
- let(:validator) { GraphQL::StaticValidation::Validator.new(schema: Dummy::Schema, rules: [TypeCheckValidator]) }
- let(:query) { GraphQL::Query.new(Dummy::Schema, query_string) }
-
-
it "stores up types" do
- validator.validate(query)
+ document = GraphQL.parse(query_string)
+ visitor = GraphQL::Language::Visitor.new(document)
+ type_stack = GraphQL::StaticValidation::TypeStack.new(Dummy::Schema, visitor)
+ checks = []
+ visitor[GraphQL::Language::Nodes::Field].enter << ->(node, parent) {
+ checks << type_stack.object_types.map {|t| t.name || t.kind.name }
+ }
+ visitor.visit
+
expected = [
["Query", "Cheese"],
["Query", "Cheese", "NON_NULL"],
["Edible", "NON_NULL"]
]
- assert_equal(expected, TypeCheckValidator.checks)
+ assert_equal(expected, checks)
end
end
diff --git a/spec/graphql/static_validation/validator_spec.rb b/spec/graphql/static_validation/validator_spec.rb
index 3982d824f1..8a41dc61f6 100644
--- a/spec/graphql/static_validation/validator_spec.rb
+++ b/spec/graphql/static_validation/validator_spec.rb
@@ -153,4 +153,50 @@
end
end
end
+
+ describe "Custom ruleset" do
+ let(:query_string) { "
+ fragment Thing on Cheese {
+ __typename
+ similarCheese(source: COW)
+ }
+ "
+ }
+
+ let(:rules) {
+ # This is from graphql-client, eg
+ # https://github.com/github/graphql-client/blob/c86fc05d7eba2370452592bb93572caced4123af/lib/graphql/client.rb#L168
+ GraphQL::StaticValidation::ALL_RULES - [
+ GraphQL::StaticValidation::FragmentsAreUsed,
+ GraphQL::StaticValidation::FieldsHaveAppropriateSelections
+ ]
+ }
+ let(:validator) { GraphQL::StaticValidation::Validator.new(schema: Dummy::Schema, rules: rules) }
+
+ it "runs the specified rules" do
+ assert_equal 0, errors.size
+ end
+
+ describe "With a legacy-style rule" do
+ # GraphQL-Pro's operation store uses this
+ class ValidatorSpecLegacyRule
+ include GraphQL::StaticValidation::Message::MessageHelper
+ def validate(ctx)
+ ctx.visitor[GraphQL::Language::Nodes::OperationDefinition] << ->(n, _p) {
+ ctx.errors << message("Busted!", n, context: ctx)
+ }
+ end
+ end
+
+ let(:rules) {
+ GraphQL::StaticValidation::ALL_RULES + [ValidatorSpecLegacyRule]
+ }
+
+ let(:query_string) { "{ __typename }"}
+
+ it "runs the rule" do
+ assert_equal ["Busted!"], errors.map { |e| e["message"] }
+ end
+ end
+ end
end
diff --git a/spec/graphql/types/iso_8601_date_time_spec.rb b/spec/graphql/types/iso_8601_date_time_spec.rb
index e154b7d603..40d8157fec 100644
--- a/spec/graphql/types/iso_8601_date_time_spec.rb
+++ b/spec/graphql/types/iso_8601_date_time_spec.rb
@@ -90,6 +90,31 @@ def parse_date(date_str)
full_res = DateTimeTest::Schema.execute(query_str, variables: { date: date_str })
assert_equal date_str, full_res["data"]["parseDate"]["iso8601"]
end
+
+ describe "with time_precision = 3 (i.e. 'with milliseconds')" do
+ before do
+ @tp = GraphQL::Types::ISO8601DateTime.time_precision
+ GraphQL::Types::ISO8601DateTime.time_precision = 3
+ end
+
+ after do
+ GraphQL::Types::ISO8601DateTime.time_precision = @tp
+ end
+
+ it "returns a string" do
+ query_str = <<-GRAPHQL
+ query($date: ISO8601DateTime!){
+ parseDate(date: $date) {
+ iso8601
+ }
+ }
+ GRAPHQL
+
+ date_str = "2010-02-02T22:30:30.123-06:00"
+ full_res = DateTimeTest::Schema.execute(query_str, variables: { date: date_str })
+ assert_equal date_str, full_res["data"]["parseDate"]["iso8601"]
+ end
+ end
end
describe "structure" do
diff --git a/spec/integration/mongoid/star_trek/schema.rb b/spec/integration/mongoid/star_trek/schema.rb
index 48ae8e918d..b7afb87271 100644
--- a/spec/integration/mongoid/star_trek/schema.rb
+++ b/spec/integration/mongoid/star_trek/schema.rb
@@ -79,18 +79,18 @@ def field_name
end
end
- # Example of GraphQL::Function used with the connection helper:
- class ShipsWithMaxPageSize < GraphQL::Function
- argument :nameIncludes, GraphQL::STRING_TYPE
- def call(obj, args, ctx)
- all_ships = obj.ships.map { |ship_id| StarTrek::DATA["Ship"][ship_id] }
- if args[:nameIncludes]
- all_ships = all_ships.select { |ship| ship.name.include?(args[:nameIncludes])}
+
+ class ShipsWithMaxPageSize < GraphQL::Schema::Resolver
+ argument :name_includes, String, required: false
+ type Ship.connection_type, null: true
+
+ def resolve(name_includes: nil)
+ all_ships = object.ships.map { |ship_id| StarTrek::DATA["Ship"][ship_id] }
+ if name_includes
+ all_ships = all_ships.select { |ship| ship.name.include?(name_includes)}
end
all_ships
end
-
- type Ship.connection_type
end
class ShipConnectionWithParentType < GraphQL::Types::Relay::BaseConnection
@@ -107,10 +107,14 @@ class Faction < GraphQL::Schema::Object
field :id, ID, null: false, resolve: GraphQL::Relay::GlobalIdResolve.new(type: Faction)
field :name, String, null: true
- field :ships, ShipConnectionWithParentType, connection: true, max_page_size: 1000, null: true, resolve: ->(obj, args, ctx) {
- all_ships = obj.ships.map {|ship_id| StarTrek::DATA["Ship"][ship_id] }
- if args[:nameIncludes]
- case args[:nameIncludes]
+ field :ships, ShipConnectionWithParentType, connection: true, max_page_size: 1000, null: true do
+ argument :name_includes, String, required: false
+ end
+
+ def ships(name_includes: nil)
+ all_ships = object.ships.map {|ship_id| StarTrek::DATA["Ship"][ship_id] }
+ if name_includes
+ case name_includes
when "error"
all_ships = GraphQL::ExecutionError.new("error from within connection")
when "raisedError"
@@ -125,25 +129,24 @@ class Faction < GraphQL::Schema::Object
prev_all_ships = all_ships
all_ships = LazyWrapper.new { prev_all_ships }
else
- all_ships = all_ships.select { |ship| ship.name.include?(args[:nameIncludes])}
+ all_ships = all_ships.select { |ship| ship.name.include?(name_includes)}
end
end
all_ships
- } do
- # You can define arguments here and use them in the connection
- argument :nameIncludes, String, required: false
end
- field :shipsWithMaxPageSize, "Ships with max page size", max_page_size: 2, function: ShipsWithMaxPageSize.new
+ field :shipsWithMaxPageSize, "Ships with max page size", max_page_size: 2, resolver: ShipsWithMaxPageSize
+
+ field :bases, BaseConnectionWithTotalCountType, null: true, connection: true do
+ argument :name_includes, String, required: false
+ end
- field :bases, BaseConnectionWithTotalCountType, null: true, connection: true, resolve: ->(obj, args, ctx) {
- all_bases = obj.bases
- if args[:nameIncludes]
- all_bases = all_bases.where(name: Regexp.new(args[:nameIncludes]))
+ def bases(name_includes: nil)
+ all_bases = object.bases
+ if name_includes
+ all_bases = all_bases.where(name: Regexp.new(name_includes))
end
all_bases
- } do
- argument :nameIncludes, String, required: false
end
field :basesClone, BaseType.connection_type, null: true
@@ -158,13 +161,24 @@ def bases_by_name(order: nil)
end
end
- field :basesWithMaxLimitRelation, BaseType.connection_type, null: true, max_page_size: 2, resolve: Proc.new { Base.all}
- field :basesWithMaxLimitArray, BaseType.connection_type, null: true, max_page_size: 2, resolve: Proc.new { Base.all.to_a }
- field :basesWithDefaultMaxLimitRelation, BaseType.connection_type, null: true, resolve: Proc.new { Base.all }
- field :basesWithDefaultMaxLimitArray, BaseType.connection_type, null: true, resolve: Proc.new { Base.all.to_a }
- field :basesWithLargeMaxLimitRelation, BaseType.connection_type, null: true, max_page_size: 1000, resolve: Proc.new { Base.all }
+ def all_bases
+ Base.all
+ end
- field :basesWithCustomEdge, CustomEdgeBaseConnectionType, null: true, connection: true, resolve: ->(o, a, c) { LazyNodesWrapper.new(o.bases) }
+ def all_bases_array
+ all_bases.to_a
+ end
+
+ field :basesWithMaxLimitRelation, BaseType.connection_type, null: true, max_page_size: 2, method: :all_bases
+ field :basesWithMaxLimitArray, BaseType.connection_type, null: true, max_page_size: 2, method: :all_bases_array
+ field :basesWithDefaultMaxLimitRelation, BaseType.connection_type, null: true, method: :all_bases
+ field :basesWithDefaultMaxLimitArray, BaseType.connection_type, null: true, method: :all_bases_array
+ field :basesWithLargeMaxLimitRelation, BaseType.connection_type, null: true, max_page_size: 1000, method: :all_bases
+
+ field :basesWithCustomEdge, CustomEdgeBaseConnectionType, null: true, connection: true
+ def bases_with_custom_edge
+ LazyNodesWrapper.new(object.bases)
+ end
end
class IntroduceShipMutation < GraphQL::Schema::RelayClassicMutation
@@ -302,7 +316,9 @@ class QueryType < GraphQL::Schema::Object
field :largestBase, BaseType, null: true, resolve: ->(obj, args, ctx) { Base.find(3) }
- field :newestBasesGroupedByFaction, BaseType.connection_type, null: true, resolve: ->(obj, args, ctx) {
+ field :newestBasesGroupedByFaction, BaseType.connection_type, null: true
+
+ def newest_bases_grouped_by_faction
agg = Base.collection.aggregate([{
"$group" => {
"_id" => "$faction_id",
@@ -312,11 +328,13 @@ class QueryType < GraphQL::Schema::Object
Base.
in(id: agg.map { |doc| doc['baseId'] }).
order_by(faction_id: -1)
- }
+ end
+
+ field :basesWithNullName, BaseType.connection_type, null: false
- field :basesWithNullName, BaseType.connection_type, null: false, resolve: ->(obj, args, ctx) {
+ def bases_with_null_name
[OpenStruct.new(id: nil)]
- }
+ end
field :node, field: GraphQL::Relay::Node.field
diff --git a/spec/integration/rails/graphql/relay/connection_instrumentation_spec.rb b/spec/integration/rails/graphql/relay/connection_instrumentation_spec.rb
index aaf2ec3a75..caea27153e 100644
--- a/spec/integration/rails/graphql/relay/connection_instrumentation_spec.rb
+++ b/spec/integration/rails/graphql/relay/connection_instrumentation_spec.rb
@@ -11,11 +11,6 @@
assert_equal ["tests"], test_type.fields.keys
end
- it "keeps a reference to the function" do
- conn_field = StarWars::Faction.graphql_definition.fields["shipsWithMaxPageSize"]
- assert_instance_of StarWars::ShipsWithMaxPageSize, conn_field.function
- end
-
let(:build_schema) {
test_type = nil
@@ -73,9 +68,8 @@
GRAPHQL
ctx = { before_built_ins: [], after_built_ins: [] }
star_wars_query(query_str, {}, context: ctx)
- # The second item is different here:
- # Before the object is wrapped in a connection, the instrumentation sees `Array`
- assert_equal ["StarWars::FactionRecord", "Array", "GraphQL::Relay::ArrayConnection"], ctx[:before_built_ins]
+ # These are data classes, later they're wrapped with type proxies
+ assert_equal ["StarWars::FactionRecord", "GraphQL::Relay::ArrayConnection", "GraphQL::Relay::ArrayConnection"], ctx[:before_built_ins]
# After the object is wrapped in a connection, it sees the connection object
assert_equal ["StarWars::Faction", "StarWars::ShipConnectionWithParentType", "GraphQL::Types::Relay::PageInfo"], ctx[:after_built_ins]
end
diff --git a/spec/integration/rails/graphql/relay/connection_resolve_spec.rb b/spec/integration/rails/graphql/relay/connection_resolve_spec.rb
index 17efd01fe1..2b8c1c0bc8 100644
--- a/spec/integration/rails/graphql/relay/connection_resolve_spec.rb
+++ b/spec/integration/rails/graphql/relay/connection_resolve_spec.rb
@@ -53,6 +53,22 @@
end
end
+ describe "when a resolver is used" do
+ it "returns the items with the correct parent" do
+ resolver_query_str = <<-GRAPHQL
+ {
+ rebels {
+ shipsByResolver {
+ parentClassName
+ }
+ }
+ }
+ GRAPHQL
+ result = star_wars_query(resolver_query_str)
+ assert_equal "StarWars::FactionRecord", result["data"]["rebels"]["shipsByResolver"]["parentClassName"]
+ end
+ end
+
describe "when nil is returned" do
it "becomes null" do
result = star_wars_query(query_string, { "name" => "null" })
diff --git a/spec/support/star_wars/schema.rb b/spec/support/star_wars/schema.rb
index d86bbf92d2..4c9de1ad12 100644
--- a/spec/support/star_wars/schema.rb
+++ b/spec/support/star_wars/schema.rb
@@ -86,18 +86,17 @@ def field_name
end
end
- # Example of GraphQL::Function used with the connection helper:
- class ShipsWithMaxPageSize < GraphQL::Function
- argument :nameIncludes, GraphQL::STRING_TYPE
- def call(obj, args, ctx)
- all_ships = obj.ships.map { |ship_id| StarWars::DATA["Ship"][ship_id] }
- if args[:nameIncludes]
- all_ships = all_ships.select { |ship| ship.name.include?(args[:nameIncludes])}
+ class ShipsWithMaxPageSize < GraphQL::Schema::Resolver
+ argument :name_includes, String, required: false
+ type Ship.connection_type, null: true
+
+ def resolve(name_includes: nil)
+ all_ships = object.ships.map { |ship_id| StarWars::DATA["Ship"][ship_id] }
+ if name_includes
+ all_ships = all_ships.select { |ship| ship.name.include?(name_includes)}
end
all_ships
end
-
- type Ship.connection_type
end
class ShipConnectionWithParentType < GraphQL::Types::Relay::BaseConnection
@@ -109,15 +108,29 @@ def parent_class_name
end
end
+ class ShipsByResolver < GraphQL::Schema::Resolver
+ type ShipConnectionWithParentType, null: false
+
+ def resolve
+ object.ships.map { |ship_id| StarWars::DATA["Ship"][ship_id] }
+ end
+ end
+
class Faction < GraphQL::Schema::Object
implements GraphQL::Relay::Node.interface
field :id, ID, null: false, resolve: GraphQL::Relay::GlobalIdResolve.new(type: Faction)
field :name, String, null: true
- field :ships, ShipConnectionWithParentType, connection: true, max_page_size: 1000, null: true, resolve: ->(obj, args, ctx) {
- all_ships = obj.ships.map {|ship_id| StarWars::DATA["Ship"][ship_id] }
- if args[:nameIncludes]
- case args[:nameIncludes]
+ field :ships, ShipConnectionWithParentType, connection: true, max_page_size: 1000, null: true do
+ argument :name_includes, String, required: false
+ end
+
+ field :shipsByResolver, resolver: ShipsByResolver, connection: true
+
+ def ships(name_includes: nil)
+ all_ships = object.ships.map {|ship_id| StarWars::DATA["Ship"][ship_id] }
+ if name_includes
+ case name_includes
when "error"
all_ships = GraphQL::ExecutionError.new("error from within connection")
when "raisedError"
@@ -132,25 +145,24 @@ class Faction < GraphQL::Schema::Object
prev_all_ships = all_ships
all_ships = LazyWrapper.new { prev_all_ships }
else
- all_ships = all_ships.select { |ship| ship.name.include?(args[:nameIncludes])}
+ all_ships = all_ships.select { |ship| ship.name.include?(name_includes)}
end
end
all_ships
- } do
- # You can define arguments here and use them in the connection
- argument :nameIncludes, String, required: false
end
- field :shipsWithMaxPageSize, "Ships with max page size", max_page_size: 2, function: ShipsWithMaxPageSize.new
+ field :shipsWithMaxPageSize, "Ships with max page size", max_page_size: 2, resolver: ShipsWithMaxPageSize
+
+ field :bases, BasesConnectionWithTotalCountType, null: true, connection: true do
+ argument :name_includes, String, required: false
+ end
- field :bases, BasesConnectionWithTotalCountType, null: true, connection: true, resolve: ->(obj, args, ctx) {
- all_bases = Base.where(id: obj.bases)
- if args[:nameIncludes]
- all_bases = all_bases.where("name LIKE ?", "%#{args[:nameIncludes]}%")
+ def bases(name_includes: nil)
+ all_bases = Base.where(id: object.bases)
+ if name_includes
+ all_bases = all_bases.where("name LIKE ?", "%#{name_includes}%")
end
all_bases
- } do
- argument :nameIncludes, String, required: false
end
field :basesClone, BaseConnection, null: true
@@ -165,12 +177,20 @@ def bases_by_name(order: nil)
end
end
- field :basesWithMaxLimitRelation, BaseConnection, null: true, max_page_size: 2, resolve: Proc.new { Base.all}
- field :basesWithMaxLimitArray, BaseConnection, null: true, max_page_size: 2, resolve: Proc.new { Base.all.to_a }
- field :basesWithDefaultMaxLimitRelation, BaseConnection, null: true, resolve: Proc.new { Base.all }
- field :basesWithDefaultMaxLimitArray, BaseConnection, null: true, resolve: Proc.new { Base.all.to_a }
- field :basesWithLargeMaxLimitRelation, BaseConnection, null: true, max_page_size: 1000, resolve: Proc.new { Base.all }
- field :basesWithoutNodes, BaseConnectionWithoutNodes, null: true, resolve: Proc.new { Base.all.to_a }
+ def all_bases
+ Base.all
+ end
+
+ def all_bases_array
+ all_bases.to_a
+ end
+
+ field :basesWithMaxLimitRelation, BaseConnection, null: true, max_page_size: 2, method: :all_bases
+ field :basesWithMaxLimitArray, BaseConnection, null: true, max_page_size: 2, method: :all_bases_array
+ field :basesWithDefaultMaxLimitRelation, BaseConnection, null: true, method: :all_bases
+ field :basesWithDefaultMaxLimitArray, BaseConnection, null: true, method: :all_bases_array
+ field :basesWithLargeMaxLimitRelation, BaseConnection, null: true, max_page_size: 1000, method: :all_bases
+ field :basesWithoutNodes, BaseConnectionWithoutNodes, null: true, method: :all_bases_array
field :basesAsSequelDataset, BasesConnectionWithTotalCountType, null: true, connection: true, max_page_size: 1000 do
argument :nameIncludes, String, required: false
@@ -184,7 +204,11 @@ def bases_as_sequel_dataset(name_includes: nil)
all_bases
end
- field :basesWithCustomEdge, CustomEdgeBaseConnectionType, null: true, connection: true, resolve: ->(o, a, c) { LazyNodesWrapper.new(o.bases) }
+ field :basesWithCustomEdge, CustomEdgeBaseConnectionType, null: true, connection: true, method: :lazy_bases
+
+ def lazy_bases
+ LazyNodesWrapper.new(object.bases)
+ end
end
class IntroduceShipMutation < GraphQL::Schema::RelayClassicMutation
@@ -320,16 +344,20 @@ class QueryType < GraphQL::Schema::Object
field :largestBase, BaseType, null: true, resolve: ->(obj, args, ctx) { Base.find(3) }
- field :newestBasesGroupedByFaction, BaseConnection, null: true, resolve: ->(obj, args, ctx) {
+ field :newestBasesGroupedByFaction, BaseConnection, null: true
+
+ def newest_bases_grouped_by_faction
Base
.having('id in (select max(id) from bases group by faction_id)')
.group(:id)
.order('faction_id desc')
- }
+ end
- field :basesWithNullName, BaseConnection, null: false, resolve: ->(obj, args, ctx) {
+ field :basesWithNullName, BaseConnection, null: false
+
+ def bases_with_null_name
[OpenStruct.new(id: nil)]
- }
+ end
field :node, field: GraphQL::Relay::Node.field
diff --git a/spec/support/static_validation_helpers.rb b/spec/support/static_validation_helpers.rb
index c620dedfb8..7d510e8b4d 100644
--- a/spec/support/static_validation_helpers.rb
+++ b/spec/support/static_validation_helpers.rb
@@ -12,10 +12,12 @@
# end
module StaticValidationHelpers
def errors
- target_schema = schema
- validator = GraphQL::StaticValidation::Validator.new(schema: target_schema)
- query = GraphQL::Query.new(target_schema, query_string)
- validator.validate(query)[:errors].map(&:to_h)
+ @errors ||= begin
+ target_schema = schema
+ validator = GraphQL::StaticValidation::Validator.new(schema: target_schema)
+ query = GraphQL::Query.new(target_schema, query_string)
+ validator.validate(query)[:errors].map(&:to_h)
+ end
end
def error_messages