Skip to content

Commit

Permalink
Coerce variables
Browse files Browse the repository at this point in the history
  • Loading branch information
Adam Scarr committed Jul 25, 2018
1 parent 8098ed8 commit dba1c80
Show file tree
Hide file tree
Showing 9 changed files with 393 additions and 17 deletions.
6 changes: 0 additions & 6 deletions ast/value_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,4 @@ func TestDefaultValue(t *testing.T) {
value, _ := v.Value(make(map[string]interface{}))
require.Equal(t, int64(99), value)
})

t.Run("returns error when variable has no default", func(t *testing.T) {
v := Value{Raw: "foo", Kind: Variable, VariableDefinition: &VariableDefinition{}}
_, err := v.Value(make(map[string]interface{}))
require.Error(t, err)
})
}
49 changes: 47 additions & 2 deletions gqlerror/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,47 @@ type Location struct {
type List []*Error

func (err *Error) Error() string {
var res bytes.Buffer
if err == nil {
return ""
}
filename, _ := err.Extensions["file"].(string)
if filename == "" {
filename = "input"
}
res.WriteString(filename)

if len(err.Locations) > 0 {
filename += ":" + strconv.Itoa(err.Locations[0].Line)
res.WriteByte(':')
res.WriteString(strconv.Itoa(err.Locations[0].Line))
}

res.WriteString(": ")
if ps := err.pathString(); ps != "" {
res.WriteString(ps)
res.WriteByte(' ')
}

return filename + " " + err.Message
res.WriteString(err.Message)

return res.String()
}

func (err Error) pathString() string {
var str bytes.Buffer
for i, v := range err.Path {

switch v := v.(type) {
case int, int64:
str.WriteString(fmt.Sprintf("[%d]", v))
default:
if i != 0 {
str.WriteByte('.')
}
str.WriteString(fmt.Sprint(v))
}
}
return str.String()
}

func (errs List) Error() string {
Expand All @@ -57,12 +88,26 @@ func (errs List) Error() string {
return buf.String()
}

func WrapPath(path []interface{}, err error) *Error {
return &Error{
Message: err.Error(),
Path: path,
}
}

func Errorf(message string, args ...interface{}) *Error {
return &Error{
Message: fmt.Sprintf(message, args...),
}
}

func ErrorPathf(path []interface{}, message string, args ...interface{}) *Error {
return &Error{
Message: fmt.Sprintf(message, args...),
Path: path,
}
}

func ErrorPosf(pos *ast.Position, message string, args ...interface{}) *Error {
return ErrorLocf(
pos.Src.Name,
Expand Down
10 changes: 8 additions & 2 deletions gqlerror/error_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,20 @@ func TestErrorFormatting(t *testing.T) {
t.Run("without filename", func(t *testing.T) {
err := ErrorLocf("", 66, 2, "kabloom")

require.Equal(t, `input:66 kabloom`, err.Error())
require.Equal(t, `input:66: kabloom`, err.Error())
require.Equal(t, nil, err.Extensions["file"])
})

t.Run("with filename", func(t *testing.T) {
err := ErrorLocf("schema.graphql", 66, 2, "kabloom")

require.Equal(t, `schema.graphql:66 kabloom`, err.Error())
require.Equal(t, `schema.graphql:66: kabloom`, err.Error())
require.Equal(t, "schema.graphql", err.Extensions["file"])
})

t.Run("with path", func(t *testing.T) {
err := ErrorPathf([]interface{}{"a", 1, "b"}, "kabloom")

require.Equal(t, `input: a[1].b kabloom`, err.Error())
})
}
8 changes: 8 additions & 0 deletions gqlparser.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,11 @@ func LoadQuery(schema *ast.Schema, str string) (*ast.QueryDocument, gqlerror.Lis

return query, nil
}

func MustLoadQuery(schema *ast.Schema, str string) *ast.QueryDocument {
q, err := LoadQuery(schema, str)
if err != nil {
panic(err)
}
return q
}
12 changes: 6 additions & 6 deletions parser/parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ func TestParserUtils(t *testing.T) {
p.error(p.peek(), "boom")
}
})
require.EqualError(t, p.err, "input.graphql:1 boom")
require.EqualError(t, p.err, "input.graphql:1: boom")
require.Equal(t, []string{"a", "b"}, arr)
})

Expand All @@ -74,7 +74,7 @@ func TestParserUtils(t *testing.T) {
p.error(p.peek(), "test error")
p.error(p.peek(), "secondary error")

require.EqualError(t, p.err, "input.graphql:1 test error")
require.EqualError(t, p.err, "input.graphql:1: test error")

require.Equal(t, "foo", p.peek().Value)
require.Equal(t, "foo", p.next().Value)
Expand All @@ -84,27 +84,27 @@ func TestParserUtils(t *testing.T) {
t.Run("unexpected error", func(t *testing.T) {
p := newParser("1 3")
p.unexpectedError()
require.EqualError(t, p.err, "input.graphql:1 Unexpected Int \"1\"")
require.EqualError(t, p.err, "input.graphql:1: Unexpected Int \"1\"")
})

t.Run("unexpected error", func(t *testing.T) {
p := newParser("1 3")
p.unexpectedToken(p.next())
require.EqualError(t, p.err, "input.graphql:1 Unexpected Int \"1\"")
require.EqualError(t, p.err, "input.graphql:1: Unexpected Int \"1\"")
})

t.Run("expect error", func(t *testing.T) {
p := newParser("foo bar")
p.expect(lexer.Float)

require.EqualError(t, p.err, "input.graphql:1 Expected Float, found Name")
require.EqualError(t, p.err, "input.graphql:1: Expected Float, found Name")
})

t.Run("expectKeyword error", func(t *testing.T) {
p := newParser("foo bar")
p.expectKeyword("baz")

require.EqualError(t, p.err, "input.graphql:1 Expected \"baz\", found Name \"foo\"")
require.EqualError(t, p.err, "input.graphql:1: Expected \"baz\", found Name \"foo\"")
})
}

Expand Down
152 changes: 152 additions & 0 deletions validator/coercevars.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
package validator

import (
"fmt"
"reflect"

"github.com/vektah/gqlparser/ast"
"github.com/vektah/gqlparser/gqlerror"
)

// CoerceVariableValues checks the variables for a given operation are valid. mutates variables to include default values where they were not provided
func CoerceVariableValues(schema *ast.Schema, op *ast.OperationDefinition, variables map[string]interface{}) (map[string]interface{}, *gqlerror.Error) {
coercedVars := map[string]interface{}{}

validator := operationValidator{
path: []interface{}{"variable"},
schema: schema,
}

for _, v := range op.VariableDefinitions {
validator.path = append(validator.path, v.Variable)

if !v.Definition.IsInputType() {
return nil, gqlerror.ErrorPathf(validator.path, "must an input type")
}

val, hasValue := variables[v.Variable]
if !hasValue {
if v.DefaultValue != nil {
var err error
val, err = v.DefaultValue.Value(variables)
if err != nil {
return nil, gqlerror.WrapPath(validator.path, err)
}
hasValue = true
} else if v.Type.NonNull {
return nil, gqlerror.ErrorPathf(validator.path, "must be defined")
}
}

rv := reflect.ValueOf(val)
if v.Type.NonNull && val == nil {
return nil, gqlerror.ErrorPathf(validator.path, "cannot be null")
}

if rv.Kind() == reflect.Ptr || rv.Kind() == reflect.Interface {
rv = rv.Elem()
}

if err := validator.validateVarType(v.Type, rv); err != nil {
return nil, err
}

if hasValue {
coercedVars[v.Variable] = val
}

validator.path = validator.path[0 : len(validator.path)-1]
}

return coercedVars, nil
}

type operationValidator struct {
path []interface{}
schema *ast.Schema
}

func (v *operationValidator) validateVarType(typ *ast.Type, val reflect.Value) *gqlerror.Error {
if typ.Elem != nil {
if val.Kind() != reflect.Slice {
return gqlerror.ErrorPathf(v.path, "must be an array")
}

for i := 0; i < val.Len(); i++ {
v.path = append(v.path, i)
field := val.Index(i)

fmt.Println(field.Kind(), field.IsNil())
if field.Kind() == reflect.Ptr || field.Kind() == reflect.Interface {
if typ.Elem.NonNull && field.IsNil() {
return gqlerror.ErrorPathf(v.path, "cannot be null")
}
field = field.Elem()
}

if err := v.validateVarType(typ.Elem, field); err != nil {
return err
}

v.path = v.path[0 : len(v.path)-1]
}

return nil
}

def := v.schema.Types[typ.NamedType]
if def == nil {
panic(fmt.Errorf("missing def for %s", typ.NamedType))
}

switch def.Kind {
case ast.Scalar, ast.Enum:
// todo scalar coercion, assuming valid for now
case ast.InputObject:
if val.Kind() != reflect.Map {
return gqlerror.ErrorPathf(v.path, "must be a %s", def.Name)
}

// check for unknown fields
for _, name := range val.MapKeys() {
val.MapIndex(name)
fieldDef := def.Fields.ForName(name.String())
v.path = append(v.path, name)

if fieldDef == nil {
return gqlerror.ErrorPathf(v.path, "unknown field")
}
v.path = v.path[0 : len(v.path)-1]
}

for _, fieldDef := range def.Fields {
v.path = append(v.path, fieldDef.Name)

field := val.MapIndex(reflect.ValueOf(fieldDef.Name))
if !field.IsValid() {
if fieldDef.Type.NonNull {
return gqlerror.ErrorPathf(v.path, "must be defined")
}
continue
}

if field.Kind() == reflect.Ptr || field.Kind() == reflect.Interface {
if typ.NonNull && field.IsNil() {
return gqlerror.ErrorPathf(v.path, "cannot be null")
}
field = field.Elem()
}

err := v.validateVarType(fieldDef.Type, field)
if err != nil {
return err
}

v.path = v.path[0 : len(v.path)-1]
}
default:
panic(fmt.Errorf("unsupported type %s", def.Kind))
}

return nil
}
Loading

0 comments on commit dba1c80

Please sign in to comment.