From ec5f5e66a1303bc6dcca06b717e8e1cdd45e6b96 Mon Sep 17 00:00:00 2001 From: Adam Scarr Date: Sat, 17 Mar 2018 11:55:32 +1100 Subject: [PATCH] check for valuer receivers before generating type switch --- codegen/build.go | 4 +-- codegen/interface.go | 8 +++++- codegen/interface_build.go | 45 +++++++++++++++++++++++++++---- codegen/models_build.go | 5 ++-- codegen/templates/data.go | 2 +- codegen/templates/interface.gotpl | 12 ++++----- codegen/util.go | 33 +++++++++++++++++++++++ example/starwars/generated.go | 5 ---- test/generated.go | 6 ----- 9 files changed, 92 insertions(+), 28 deletions(-) diff --git a/codegen/build.go b/codegen/build.go index c4cfcb4655d..d24149a9548 100644 --- a/codegen/build.go +++ b/codegen/build.go @@ -41,7 +41,7 @@ func Models(schema *schema.Schema, userTypes map[string]string, destDir string) bindTypes(imports, namedTypes, prog) - models := buildModels(namedTypes, schema) + models := buildModels(namedTypes, schema, prog) return &ModelBuild{ PackageName: filepath.Base(destDir), Models: models, @@ -67,7 +67,7 @@ func Bind(schema *schema.Schema, userTypes map[string]string, destDir string) (* b := &Build{ PackageName: filepath.Base(destDir), Objects: objects, - Interfaces: buildInterfaces(namedTypes, schema), + Interfaces: buildInterfaces(namedTypes, schema, prog), Inputs: inputs, Imports: imports, } diff --git a/codegen/interface.go b/codegen/interface.go index 98c9bc17751..2de0c88a9b3 100644 --- a/codegen/interface.go +++ b/codegen/interface.go @@ -3,5 +3,11 @@ package codegen type Interface struct { *NamedType - Implementors []*NamedType + Implementors []InterfaceImplementor +} + +type InterfaceImplementor struct { + ValueReceiver bool + + *NamedType } diff --git a/codegen/interface_build.go b/codegen/interface_build.go index b45d6ba7d18..672a9e82bff 100644 --- a/codegen/interface_build.go +++ b/codegen/interface_build.go @@ -2,18 +2,21 @@ package codegen import ( "fmt" + "go/types" + "os" "sort" "strings" "github.com/vektah/gqlgen/neelance/schema" + "golang.org/x/tools/go/loader" ) -func buildInterfaces(types NamedTypes, s *schema.Schema) []*Interface { +func buildInterfaces(types NamedTypes, s *schema.Schema, prog *loader.Program) []*Interface { var interfaces []*Interface for _, typ := range s.Types { switch typ := typ.(type) { case *schema.Union, *schema.Interface: - interfaces = append(interfaces, buildInterface(types, typ)) + interfaces = append(interfaces, buildInterface(types, typ, prog)) default: continue } @@ -26,14 +29,15 @@ func buildInterfaces(types NamedTypes, s *schema.Schema) []*Interface { return interfaces } -func buildInterface(types NamedTypes, typ schema.NamedType) *Interface { +func buildInterface(types NamedTypes, typ schema.NamedType, prog *loader.Program) *Interface { switch typ := typ.(type) { case *schema.Union: i := &Interface{NamedType: types[typ.TypeName()]} for _, implementor := range typ.PossibleTypes { - i.Implementors = append(i.Implementors, types[implementor.TypeName()]) + t := types[implementor.TypeName()] + i.Implementors = append(i.Implementors, InterfaceImplementor{NamedType: t, ValueReceiver: true}) } return i @@ -42,7 +46,12 @@ func buildInterface(types NamedTypes, typ schema.NamedType) *Interface { i := &Interface{NamedType: types[typ.TypeName()]} for _, implementor := range typ.PossibleTypes { - i.Implementors = append(i.Implementors, types[implementor.TypeName()]) + t := types[implementor.TypeName()] + + i.Implementors = append(i.Implementors, InterfaceImplementor{ + NamedType: t, + ValueReceiver: isValueReceiver(types[typ.Name], t, prog), + }) } return i @@ -50,3 +59,29 @@ func buildInterface(types NamedTypes, typ schema.NamedType) *Interface { panic(fmt.Errorf("unknown interface %#v", typ)) } } + +func isValueReceiver(intf *NamedType, implementor *NamedType, prog *loader.Program) bool { + interfaceType := findGoInterface(prog, intf.Package, intf.GoType) + implementorType := findGoNamedType(prog, implementor.Package, implementor.GoType) + + if interfaceType == nil || implementorType == nil { + return true + } + + for i := 0; i < interfaceType.NumMethods(); i++ { + intfMethod := interfaceType.Method(i) + + implMethod := findMethod(implementorType, intfMethod.Name()) + if implMethod == nil { + fmt.Fprintf(os.Stderr, "missing method %s on %s\n", intfMethod.Name(), implementor.GoType) + return false + } + + sig := implMethod.Type().(*types.Signature) + if _, isPtr := sig.Recv().Type().(*types.Pointer); isPtr { + return false + } + } + + return true +} diff --git a/codegen/models_build.go b/codegen/models_build.go index 139020372fb..317a4e7360f 100644 --- a/codegen/models_build.go +++ b/codegen/models_build.go @@ -5,9 +5,10 @@ import ( "strings" "github.com/vektah/gqlgen/neelance/schema" + "golang.org/x/tools/go/loader" ) -func buildModels(types NamedTypes, s *schema.Schema) []Model { +func buildModels(types NamedTypes, s *schema.Schema, prog *loader.Program) []Model { var models []Model for _, typ := range s.Types { @@ -26,7 +27,7 @@ func buildModels(types NamedTypes, s *schema.Schema) []Model { } model = obj2Model(obj) case *schema.Interface, *schema.Union: - intf := buildInterface(types, typ) + intf := buildInterface(types, typ, prog) if intf.GoType != "" { continue } diff --git a/codegen/templates/data.go b/codegen/templates/data.go index dafbe2855e1..3041b73a62e 100644 --- a/codegen/templates/data.go +++ b/codegen/templates/data.go @@ -5,7 +5,7 @@ var data = map[string]string{ "field.gotpl": "{{ $field := . }}\n{{ $object := $field.Object }}\n\n{{- if $object.Stream }}\n\tfunc (ec *executionContext) _{{$object.GQLType}}_{{$field.GQLName}}(field graphql.CollectedField) func() graphql.Marshaler {\n\t\t{{- template \"args.gotpl\" $field.Args }}\n\t\tresults, err := ec.resolvers.{{ $object.GQLType }}_{{ $field.GQLName }}({{ $field.CallArgs }})\n\t\tif err != nil {\n\t\t\tec.Error(err)\n\t\t\treturn nil\n\t\t}\n\t\treturn func() graphql.Marshaler {\n\t\t\tres, ok := <-results\n\t\t\tif !ok {\n\t\t\t\treturn nil\n\t\t\t}\n\t\t\tvar out graphql.OrderedMap\n\t\t\tout.Add(field.Alias, func() graphql.Marshaler { {{ $field.WriteJson }} }())\n\t\t\treturn &out\n\t\t}\n\t}\n{{ else }}\n\tfunc (ec *executionContext) _{{$object.GQLType}}_{{$field.GQLName}}(field graphql.CollectedField, {{if not $object.Root}}obj *{{$object.FullName}}{{end}}) graphql.Marshaler {\n\t\t{{- template \"args.gotpl\" $field.Args }}\n\n\t\t{{- if $field.IsConcurrent }}\n\t\t\treturn graphql.Defer(func() (ret graphql.Marshaler) {\n\t\t\t\tdefer func() {\n\t\t\t\t\tif r := recover(); r != nil {\n\t\t\t\t\t\tuserErr := ec.recover(r)\n\t\t\t\t\t\tec.Error(userErr)\n\t\t\t\t\t\tret = graphql.Null\n\t\t\t\t\t}\n\t\t\t\t}()\n\t\t{{- end }}\n\n\t\t\t{{- if $field.GoVarName }}\n\t\t\t\tres := obj.{{$field.GoVarName}}\n\t\t\t{{- else if $field.GoMethodName }}\n\t\t\t\t{{- if $field.NoErr }}\n\t\t\t\t\tres := {{$field.GoMethodName}}({{ $field.CallArgs }})\n\t\t\t\t{{- else }}\n\t\t\t\t\tres, err := {{$field.GoMethodName}}({{ $field.CallArgs }})\n\t\t\t\t\tif err != nil {\n\t\t\t\t\t\tec.Error(err)\n\t\t\t\t\t\treturn graphql.Null\n\t\t\t\t\t}\n\t\t\t\t{{- end }}\n\t\t\t{{- else }}\n\t\t\t\tres, err := ec.resolvers.{{ $object.GQLType }}_{{ $field.GQLName }}({{ $field.CallArgs }})\n\t\t\t\tif err != nil {\n\t\t\t\t\tec.Error(err)\n\t\t\t\t\treturn graphql.Null\n\t\t\t\t}\n\t\t\t{{- end }}\n\t\t\t{{ $field.WriteJson }}\n\t\t{{- if $field.IsConcurrent }}\n\t\t\t})\n\t\t{{- end }}\n\t}\n{{ end }}\n", "generated.gotpl": "// This file was generated by github.com/vektah/gqlgen, DO NOT EDIT\n\npackage {{ .PackageName }}\n\nimport (\n{{- range $import := .Imports }}\n\t{{- $import.Write }}\n{{ end }}\n)\n\nfunc MakeExecutableSchema(resolvers Resolvers) graphql.ExecutableSchema {\n\treturn &executableSchema{resolvers}\n}\n\ntype Resolvers interface {\n{{- range $object := .Objects -}}\n\t{{ range $field := $object.Fields -}}\n\t\t{{ $field.ResolverDeclaration }}\n\t{{ end }}\n{{- end }}\n}\n\ntype executableSchema struct {\n\tresolvers Resolvers\n}\n\nfunc (e *executableSchema) Schema() *schema.Schema {\n\treturn parsedSchema\n}\n\nfunc (e *executableSchema) Query(ctx context.Context, doc *query.Document, variables map[string]interface{}, op *query.Operation, recover graphql.RecoverFunc) *graphql.Response {\n\t{{- if .QueryRoot }}\n\t\tec := executionContext{resolvers: e.resolvers, variables: variables, doc: doc, ctx: ctx, recover: recover}\n\n\t\tdata := ec._{{.QueryRoot.GQLType}}(op.Selections)\n\t\tvar buf bytes.Buffer\n\t\tdata.MarshalGQL(&buf)\n\n\t\treturn &graphql.Response{\n\t\t\tData: buf.Bytes(),\n\t\t\tErrors: ec.Errors,\n\t\t}\n\t{{- else }}\n\t\treturn &graphql.Response{Errors: []*errors.QueryError{ {Message: \"queries are not supported\"} }}\n\t{{- end }}\n}\n\nfunc (e *executableSchema) Mutation(ctx context.Context, doc *query.Document, variables map[string]interface{}, op *query.Operation, recover graphql.RecoverFunc) *graphql.Response {\n\t{{- if .MutationRoot }}\n\t\tec := executionContext{resolvers: e.resolvers, variables: variables, doc: doc, ctx: ctx, recover: recover}\n\n\t\tdata := ec._{{.MutationRoot.GQLType}}(op.Selections)\n\t\tvar buf bytes.Buffer\n\t\tdata.MarshalGQL(&buf)\n\n\t\treturn &graphql.Response{\n\t\t\tData: buf.Bytes(),\n\t\t\tErrors: ec.Errors,\n\t\t}\n\t{{- else }}\n\t\treturn &graphql.Response{Errors: []*errors.QueryError{ {Message: \"mutations are not supported\"} }}\n\t{{- end }}\n}\n\nfunc (e *executableSchema) Subscription(ctx context.Context, doc *query.Document, variables map[string]interface{}, op *query.Operation, recover graphql.RecoverFunc) func() *graphql.Response {\n\t{{- if .SubscriptionRoot }}\n\t\tec := executionContext{resolvers: e.resolvers, variables: variables, doc: doc, ctx: ctx, recover: recover}\n\n\t\tnext := ec._{{.SubscriptionRoot.GQLType}}(op.Selections)\n\t\tif ec.Errors != nil {\n\t\t\treturn graphql.OneShot(&graphql.Response{Data: []byte(\"null\"), Errors: ec.Errors})\n\t\t}\n\n\t\tvar buf bytes.Buffer\n\t\treturn func() *graphql.Response {\n\t\t\tbuf.Reset()\n\t\t\tdata := next()\n\t\t\tif data == nil {\n\t\t\t\treturn nil\n\t\t\t}\n\t\t\tdata.MarshalGQL(&buf)\n\n\t\t\terrs := ec.Errors\n\t\t\tec.Errors = nil\n\t\t\treturn &graphql.Response{\n\t\t\t\tData: buf.Bytes(),\n\t\t\t\tErrors: errs,\n\t\t\t}\n\t\t}\n\t{{- else }}\n\t\treturn graphql.OneShot(&graphql.Response{Errors: []*errors.QueryError{ {Message: \"subscriptions are not supported\"} }})\n\t{{- end }}\n}\n\ntype executionContext struct {\n\terrors.Builder\n\tresolvers Resolvers\n\tvariables map[string]interface{}\n\tdoc *query.Document\n\tctx context.Context\n\trecover graphql.RecoverFunc\n}\n\n{{- range $object := .Objects }}\n\t{{ template \"object.gotpl\" $object }}\n\n\t{{- range $field := $object.Fields }}\n\t\t{{ template \"field.gotpl\" $field }}\n\t{{ end }}\n{{- end}}\n\n{{- range $interface := .Interfaces }}\n\t{{ template \"interface.gotpl\" $interface }}\n{{- end }}\n\n{{- range $input := .Inputs }}\n\t{{ template \"input.gotpl\" $input }}\n{{- end }}\n\nvar parsedSchema = schema.MustParse({{.SchemaRaw|quote}})\n\nfunc (ec *executionContext) introspectSchema() *introspection.Schema {\n\treturn introspection.WrapSchema(parsedSchema)\n}\n\nfunc (ec *executionContext) introspectType(name string) *introspection.Type {\n\tt := parsedSchema.Resolve(name)\n\tif t == nil {\n\t\treturn nil\n\t}\n\treturn introspection.WrapType(t)\n}\n", "input.gotpl": "\t{{- if .IsMarshaled }}\n\tfunc Unmarshal{{ .GQLType }}(v interface{}) ({{.FullName}}, error) {\n\t\tvar it {{.FullName}}\n\n\t\tfor k, v := range v.(map[string]interface{}) {\n\t\t\tswitch k {\n\t\t\t{{- range $field := .Fields }}\n\t\t\tcase {{$field.GQLName|quote}}:\n\t\t\t\tvar err error\n\t\t\t\t{{ $field.Unmarshal (print \"it.\" $field.GoVarName) \"v\" }}\n\t\t\t\tif err != nil {\n\t\t\t\t\treturn it, err\n\t\t\t\t}\n\t\t\t{{- end }}\n\t\t\t}\n\t\t}\n\n\t\treturn it, nil\n\t}\n\t{{- end }}\n", - "interface.gotpl": "{{- $interface := . }}\n\nfunc (ec *executionContext) _{{$interface.GQLType}}(sel []query.Selection, obj *{{$interface.FullName}}) graphql.Marshaler {\n\tswitch obj := (*obj).(type) {\n\tcase nil:\n\t\treturn graphql.Null\n\t{{- range $implementor := $interface.Implementors }}\n\tcase {{$implementor.FullName}}:\n\t\treturn ec._{{$implementor.GQLType}}(sel, &obj)\n\n\tcase *{{$implementor.FullName}}:\n\t\treturn ec._{{$implementor.GQLType}}(sel, obj)\n\n\t{{- end }}\n\tdefault:\n\t\tpanic(fmt.Errorf(\"unexpected type %T\", obj))\n\t}\n}\n", + "interface.gotpl": "{{- $interface := . }}\n\nfunc (ec *executionContext) _{{$interface.GQLType}}(sel []query.Selection, obj *{{$interface.FullName}}) graphql.Marshaler {\n\tswitch obj := (*obj).(type) {\n\tcase nil:\n\t\treturn graphql.Null\n\t{{- range $implementor := $interface.Implementors }}\n\t\t{{- if $implementor.ValueReceiver }}\n\t\t\tcase {{$implementor.FullName}}:\n\t\t\t\treturn ec._{{$implementor.GQLType}}(sel, &obj)\n\t\t{{- end}}\n\t\tcase *{{$implementor.FullName}}:\n\t\t\treturn ec._{{$implementor.GQLType}}(sel, obj)\n\t{{- end }}\n\tdefault:\n\t\tpanic(fmt.Errorf(\"unexpected type %T\", obj))\n\t}\n}\n", "models.gotpl": "// This file was generated by github.com/vektah/gqlgen, DO NOT EDIT\n\npackage {{ .PackageName }}\n\nimport (\n{{- range $import := .Imports }}\n\t{{- $import.Write }}\n{{ end }}\n)\n\n{{ range $model := .Models }}\n\t{{- if .IsInterface }}\n\t\ttype {{.GoType}} interface {}\n\t{{- else }}\n\t\ttype {{.GoType}} struct {\n\t\t\t{{- range $field := .Fields }}\n\t\t\t\t{{- if $field.GoVarName }}\n\t\t\t\t\t{{ $field.GoVarName }} {{$field.Signature}}\n\t\t\t\t{{- else }}\n\t\t\t\t\t{{ $field.GoFKName }} {{$field.GoFKType}}\n\t\t\t\t{{- end }}\n\t\t\t{{- end }}\n\t\t}\n\t{{- end }}\n{{- end}}\n", "object.gotpl": "{{ $object := . }}\n\nvar {{ $object.GQLType|lcFirst}}Implementors = {{$object.Implementors}}\n\n// nolint: gocyclo, errcheck, gas, goconst\n{{- if .Stream }}\nfunc (ec *executionContext) _{{$object.GQLType}}(sel []query.Selection) func() graphql.Marshaler {\n\tfields := graphql.CollectFields(ec.doc, sel, {{$object.GQLType|lcFirst}}Implementors, ec.variables)\n\n\tif len(fields) != 1 {\n\t\tec.Errorf(\"must subscribe to exactly one stream\")\n\t\treturn nil\n\t}\n\n\tswitch fields[0].Name {\n\t{{- range $field := $object.Fields }}\n\tcase \"{{$field.GQLName}}\":\n\t\treturn ec._{{$object.GQLType}}_{{$field.GQLName}}(fields[0])\n\t{{- end }}\n\tdefault:\n\t\tpanic(\"unknown field \" + strconv.Quote(fields[0].Name))\n\t}\n}\n{{- else }}\nfunc (ec *executionContext) _{{$object.GQLType}}(sel []query.Selection{{if not $object.Root}}, obj *{{$object.FullName}} {{end}}) graphql.Marshaler {\n\tfields := graphql.CollectFields(ec.doc, sel, {{$object.GQLType|lcFirst}}Implementors, ec.variables)\n\tout := graphql.NewOrderedMap(len(fields))\n\tfor i, field := range fields {\n\t\tout.Keys[i] = field.Alias\n\n\t\tswitch field.Name {\n\t\tcase \"__typename\":\n\t\t\tout.Values[i] = graphql.MarshalString({{$object.GQLType|quote}})\n\t\t{{- range $field := $object.Fields }}\n\t\tcase \"{{$field.GQLName}}\":\n\t\t\tout.Values[i] = ec._{{$object.GQLType}}_{{$field.GQLName}}(field{{if not $object.Root}}, obj{{end}})\n\t\t{{- end }}\n\t\tdefault:\n\t\t\tpanic(\"unknown field \" + strconv.Quote(field.Name))\n\t\t}\n\t}\n\n\treturn out\n}\n{{- end }}\n", } diff --git a/codegen/templates/interface.gotpl b/codegen/templates/interface.gotpl index 58cd1d3a19b..0392a98b694 100644 --- a/codegen/templates/interface.gotpl +++ b/codegen/templates/interface.gotpl @@ -5,12 +5,12 @@ func (ec *executionContext) _{{$interface.GQLType}}(sel []query.Selection, obj * case nil: return graphql.Null {{- range $implementor := $interface.Implementors }} - case {{$implementor.FullName}}: - return ec._{{$implementor.GQLType}}(sel, &obj) - - case *{{$implementor.FullName}}: - return ec._{{$implementor.GQLType}}(sel, obj) - + {{- if $implementor.ValueReceiver }} + case {{$implementor.FullName}}: + return ec._{{$implementor.GQLType}}(sel, &obj) + {{- end}} + case *{{$implementor.FullName}}: + return ec._{{$implementor.GQLType}}(sel, obj) {{- end }} default: panic(fmt.Errorf("unexpected type %T", obj)) diff --git a/codegen/util.go b/codegen/util.go index 81bb069deda..5f1b7328312 100644 --- a/codegen/util.go +++ b/codegen/util.go @@ -38,6 +38,39 @@ func findGoType(prog *loader.Program, pkgName string, typeName string) (types.Ob return nil, fmt.Errorf("unable to find type %s\n", fullName) } +func findGoNamedType(prog *loader.Program, pkgName string, typeName string) *types.Named { + def, err := findGoType(prog, pkgName, typeName) + if err != nil { + fmt.Fprintf(os.Stderr, err.Error()) + } + if def == nil { + return nil + } + + namedType, ok := def.Type().(*types.Named) + if !ok { + fmt.Fprintf(os.Stderr, "expected %s to be a named type, instead found %T\n", typeName, def.Type()) + return nil + } + + return namedType +} + +func findGoInterface(prog *loader.Program, pkgName string, typeName string) *types.Interface { + namedType := findGoNamedType(prog, pkgName, typeName) + if namedType == nil { + return nil + } + + underlying, ok := namedType.Underlying().(*types.Interface) + if !ok { + fmt.Fprintf(os.Stderr, "expected %s to be a named interface, instead found %s", typeName, namedType.String()) + return nil + } + + return underlying +} + func findMethod(typ *types.Named, name string) *types.Func { for i := 0; i < typ.NumMethods(); i++ { method := typ.Method(i) diff --git a/example/starwars/generated.go b/example/starwars/generated.go index f88d61e518b..4e3d389ea2f 100644 --- a/example/starwars/generated.go +++ b/example/starwars/generated.go @@ -1512,12 +1512,10 @@ func (ec *executionContext) _Character(sel []query.Selection, obj *Character) gr return graphql.Null case Human: return ec._Human(sel, &obj) - case *Human: return ec._Human(sel, obj) case Droid: return ec._Droid(sel, &obj) - case *Droid: return ec._Droid(sel, obj) default: @@ -1531,17 +1529,14 @@ func (ec *executionContext) _SearchResult(sel []query.Selection, obj *SearchResu return graphql.Null case Human: return ec._Human(sel, &obj) - case *Human: return ec._Human(sel, obj) case Droid: return ec._Droid(sel, &obj) - case *Droid: return ec._Droid(sel, obj) case Starship: return ec._Starship(sel, &obj) - case *Starship: return ec._Starship(sel, obj) default: diff --git a/test/generated.go b/test/generated.go index ade26740076..1f61592891b 100644 --- a/test/generated.go +++ b/test/generated.go @@ -867,14 +867,8 @@ func (ec *executionContext) _Shape(sel []query.Selection, obj *Shape) graphql.Ma switch obj := (*obj).(type) { case nil: return graphql.Null - case Circle: - return ec._Circle(sel, &obj) - case *Circle: return ec._Circle(sel, obj) - case Rectangle: - return ec._Rectangle(sel, &obj) - case *Rectangle: return ec._Rectangle(sel, obj) default: