From eb8f3d14336f39fd1b4a4b096db3df8b679e0b2f Mon Sep 17 00:00:00 2001 From: Andrew Sisley Date: Mon, 12 Sep 2022 16:55:53 -0400 Subject: [PATCH] Add null support Code is mostly copy-pasted from the PR https://github.com/graphql-go/graphql/pull/536 - main difference is that I haven't copied over a couple of new tests, and that I permitted null within arrays (unanswered question in original PR, and I see no reason for it too). --- language/ast/node.go | 1 + language/ast/values.go | 28 +++++++++++++++ language/kinds/kinds.go | 1 + language/parser/parser.go | 14 ++++++-- language/parser/parser_test.go | 10 +----- language/printer/printer.go | 9 +++++ rules.go | 6 +++- values.go | 63 +++++++++++++++++++++++++--------- variables_test.go | 6 ++-- 9 files changed, 106 insertions(+), 32 deletions(-) diff --git a/language/ast/node.go b/language/ast/node.go index cd63a0fc..d7cc3a80 100644 --- a/language/ast/node.go +++ b/language/ast/node.go @@ -22,6 +22,7 @@ var _ Node = (*IntValue)(nil) var _ Node = (*FloatValue)(nil) var _ Node = (*StringValue)(nil) var _ Node = (*BooleanValue)(nil) +var _ Node = (*NullValue)(nil) var _ Node = (*EnumValue)(nil) var _ Node = (*ListValue)(nil) var _ Node = (*ObjectValue)(nil) diff --git a/language/ast/values.go b/language/ast/values.go index 6c3c8864..36ac1052 100644 --- a/language/ast/values.go +++ b/language/ast/values.go @@ -16,6 +16,7 @@ var _ Value = (*IntValue)(nil) var _ Value = (*FloatValue)(nil) var _ Value = (*StringValue)(nil) var _ Value = (*BooleanValue)(nil) +var _ Value = (*NullValue)(nil) var _ Value = (*EnumValue)(nil) var _ Value = (*ListValue)(nil) var _ Value = (*ObjectValue)(nil) @@ -172,6 +173,33 @@ func (v *BooleanValue) GetValue() interface{} { return v.Value } +type NullValue struct { + Kind string + Loc *Location + Value interface{} +} + +func NewNullValue(v *NullValue) *NullValue { + + return &NullValue{ + Kind: kinds.NullValue, + Loc: v.Loc, + Value: v.Value, + } +} + +func (v *NullValue) GetKind() string { + return v.Kind +} + +func (v *NullValue) GetLoc() *Location { + return v.Loc +} + +func (v *NullValue) GetValue() interface{} { + return nil +} + // EnumValue implements Node, Value type EnumValue struct { Kind string diff --git a/language/kinds/kinds.go b/language/kinds/kinds.go index 40bc994e..370ba6fe 100644 --- a/language/kinds/kinds.go +++ b/language/kinds/kinds.go @@ -23,6 +23,7 @@ const ( FloatValue = "FloatValue" StringValue = "StringValue" BooleanValue = "BooleanValue" + NullValue = "NullValue" EnumValue = "EnumValue" ListValue = "ListValue" ObjectValue = "ObjectValue" diff --git a/language/parser/parser.go b/language/parser/parser.go index ea7c7f0b..238dc4a6 100644 --- a/language/parser/parser.go +++ b/language/parser/parser.go @@ -614,15 +614,23 @@ func parseValueLiteral(parser *Parser, isConst bool) (ast.Value, error) { Value: value, Loc: loc(parser, token.Start), }), nil - } else if token.Value != "null" { + } else if token.Value == "null" { if err := advance(parser); err != nil { return nil, err } - return ast.NewEnumValue(&ast.EnumValue{ - Value: token.Value, + return ast.NewNullValue(&ast.NullValue{ + Value: nil, Loc: loc(parser, token.Start), }), nil } + + if err := advance(parser); err != nil { + return nil, err + } + return ast.NewEnumValue(&ast.EnumValue{ + Value: token.Value, + Loc: loc(parser, token.Start), + }), nil case lexer.DOLLAR: if !isConst { return parseVariable(parser) diff --git a/language/parser/parser_test.go b/language/parser/parser_test.go index 3cc4253a..ec260ea2 100644 --- a/language/parser/parser_test.go +++ b/language/parser/parser_test.go @@ -183,15 +183,6 @@ func TestDoesNotAcceptFragmentsSpreadOfOn(t *testing.T) { testErrorMessage(t, test) } -func TestDoesNotAllowNullAsValue(t *testing.T) { - test := errorMessageTest{ - `{ fieldWithNullableStringInput(input: null) }'`, - `Syntax Error GraphQL (1:39) Unexpected Name "null"`, - false, - } - testErrorMessage(t, test) -} - func TestParsesMultiByteCharacters_Unicode(t *testing.T) { doc := ` @@ -367,6 +358,7 @@ func TestAllowsNonKeywordsAnywhereNameIsAllowed(t *testing.T) { "subscription", "true", "false", + "null", } for _, keyword := range nonKeywords { fragmentName := keyword diff --git a/language/printer/printer.go b/language/printer/printer.go index ac771ba6..bbc4c440 100644 --- a/language/printer/printer.go +++ b/language/printer/printer.go @@ -388,6 +388,15 @@ var printDocASTReducer = map[string]visitor.VisitFunc{ } return visitor.ActionNoChange, nil }, + "NullValue": func(p visitor.VisitFuncParams) (string, interface{}) { + switch node := p.Node.(type) { + case *ast.NullValue: + return visitor.ActionUpdate, fmt.Sprintf("%v", node.Value) + case map[string]interface{}: + return visitor.ActionUpdate, getMapValueString(node, "Value") + } + return visitor.ActionNoChange, nil + }, "EnumValue": func(p visitor.VisitFuncParams) (string, interface{}) { switch node := p.Node.(type) { case *ast.EnumValue: diff --git a/rules.go b/rules.go index ae0c75b9..f2d7315b 100644 --- a/rules.go +++ b/rules.go @@ -1730,6 +1730,10 @@ func isValidLiteralValue(ttype Input, valueAST ast.Value) (bool, []string) { return true, nil } + if valueAST.GetKind() == kinds.NullValue { + return true, nil + } + // This function only tests literals, and assumes variables will provide // values of the correct type. if valueAST.GetKind() == kinds.Variable { @@ -1742,7 +1746,7 @@ func isValidLiteralValue(ttype Input, valueAST ast.Value) (bool, []string) { if e := ttype.Error(); e != nil { return false, []string{e.Error()} } - if valueAST == nil { + if valueAST == nil || valueAST.GetKind() == kinds.NullValue { if ttype.OfType.Name() != "" { return false, []string{fmt.Sprintf(`Expected "%v!", found null.`, ttype.OfType.Name())} } diff --git a/values.go b/values.go index 06c08af6..608490ee 100644 --- a/values.go +++ b/values.go @@ -14,6 +14,9 @@ import ( "github.com/graphql-go/graphql/language/printer" ) +// Used to detect the difference between a "null" literal and not present +type nullValue struct{} + // Prepares an object map of variableValues of the correct type based on the // provided variable definitions and arbitrary input. If the input cannot be // parsed to match the variable definitions, a GraphQLError will be returned. @@ -27,7 +30,7 @@ func getVariableValues( continue } varName := defAST.Variable.Name.Value - if varValue, err := getVariableValue(schema, defAST, inputs[varName]); err != nil { + if varValue, err := getVariableValue(schema, defAST, getValueOrNull(inputs, varName)); err != nil { return values, err } else { values[varName] = varValue @@ -36,6 +39,25 @@ func getVariableValues( return values, nil } +func getValueOrNull(values map[string]interface{}, name string) interface{} { + if tmp, ok := values[name]; ok { // Is present + if tmp == nil { + return nullValue{} // Null value + } else { + return tmp + } + } + return nil // Not present +} + +func addValueOrNull(values map[string]interface{}, name string, value interface{}) { + if _, ok := value.(nullValue); ok { // Null value + values[name] = nil + } else if !isNullish(value) { // Not present + values[name] = value + } +} + // Prepares an object map of argument values given a list of argument // definitions and list of argument AST nodes. func getArgumentValues( @@ -60,9 +82,7 @@ func getArgumentValues( if tmp = valueFromAST(value, argDef.Type, variableValues); isNullish(tmp) { tmp = argDef.DefaultValue } - if !isNullish(tmp) { - results[argDef.PrivateName] = tmp - } + addValueOrNull(results, argDef.PrivateName, tmp) } return results } @@ -97,7 +117,7 @@ func getVariableValue(schema Schema, definitionAST *ast.VariableDefinition, inpu } return coerceValue(ttype, input), nil } - if isNullish(input) { + if _, ok := input.(nullValue); ok || isNullish(input) { return "", gqlerrors.NewError( fmt.Sprintf(`Variable "$%v" of required type `+ `"%v" was not provided.`, variable.Name.Value, printer.Print(definitionAST.Type)), @@ -134,6 +154,11 @@ func coerceValue(ttype Input, value interface{}) interface{} { if isNullish(value) { return nil } + + if _, ok := value.(nullValue); ok { + return nullValue{} + } + switch ttype := ttype.(type) { case *NonNull: return coerceValue(ttype.OfType, value) @@ -156,13 +181,11 @@ func coerceValue(ttype Input, value interface{}) interface{} { } for name, field := range ttype.Fields() { - fieldValue := coerceValue(field.Type, valueMap[name]) + fieldValue := coerceValue(field.Type, getValueOrNull(valueMap, name)) if isNullish(fieldValue) { fieldValue = field.DefaultValue } - if !isNullish(fieldValue) { - obj[name] = fieldValue - } + addValueOrNull(obj, name, fieldValue) } return obj case *Scalar: @@ -212,7 +235,7 @@ func typeFromAST(schema Schema, inputTypeAST ast.Type) (Type, error) { // accepted for that type. This is primarily useful for validating the // runtime values of query variables. func isValidInputValue(value interface{}, ttype Input) (bool, []string) { - if isNullish(value) { + if _, ok := value.(nullValue); ok || isNullish(value) { if ttype, ok := ttype.(*NonNull); ok { if ttype.OfType.Name() != "" { return false, []string{fmt.Sprintf(`Expected "%v!", found null.`, ttype.OfType.Name())} @@ -233,9 +256,14 @@ func isValidInputValue(value interface{}, ttype Input) (bool, []string) { messagesReduce := []string{} for i := 0; i < valType.Len(); i++ { val := valType.Index(i).Interface() - _, messages := isValidInputValue(val, ttype.OfType) - for idx, message := range messages { - messagesReduce = append(messagesReduce, fmt.Sprintf(`In element #%v: %v`, idx+1, message)) + var messages []string + if _, ok := val.(nullValue); ok { + messages = []string{"Unexpected null value."} + } else { + _, messages = isValidInputValue(val, ttype.OfType) + } + for _, message := range messages { + messagesReduce = append(messagesReduce, fmt.Sprintf(`In element #%v: %v`, i+1, message)) } } return (len(messagesReduce) == 0), messagesReduce @@ -352,6 +380,11 @@ func valueFromAST(valueAST ast.Value, ttype Input, variables map[string]interfac if valueAST == nil { return nil } + + if valueAST.GetKind() == kinds.NullValue { + return nullValue{} + } + // precedence: value > type if valueAST, ok := valueAST.(*ast.Variable); ok { if valueAST.Name == nil || variables == nil { @@ -398,9 +431,7 @@ func valueFromAST(valueAST ast.Value, ttype Input, variables map[string]interfac } else { value = field.DefaultValue } - if !isNullish(value) { - obj[name] = value - } + addValueOrNull(obj, name, value) } return obj case *Scalar: diff --git a/variables_test.go b/variables_test.go index 9dc430df..bcd49e60 100644 --- a/variables_test.go +++ b/variables_test.go @@ -67,7 +67,7 @@ var testNestedInputObject *graphql.InputObject = graphql.NewInputObject(graphql. func inputResolved(p graphql.ResolveParams) (interface{}, error) { input, ok := p.Args["input"] - if !ok { + if !ok || input == nil { return nil, nil } b, err := json.Marshal(input) @@ -1188,7 +1188,7 @@ func TestVariables_ListsAndNullability_DoesNotAllowListOfNonNullsToContainNull(t { Message: `Variable "$input" got invalid value ` + `["A",null,"B"].` + - "\nIn element #1: Expected \"String!\", found null.", + "\nIn element #2: Expected \"String!\", found null.", Locations: []location.SourceLocation{ { Line: 2, Column: 17, @@ -1290,7 +1290,7 @@ func TestVariables_ListsAndNullability_DoesNotAllowNonNullListOfNonNullsToContai { Message: `Variable "$input" got invalid value ` + `["A",null,"B"].` + - "\nIn element #1: Expected \"String!\", found null.", + "\nIn element #2: Expected \"String!\", found null.", Locations: []location.SourceLocation{ { Line: 2, Column: 17,