Skip to content

Commit

Permalink
check for valuer receivers before generating type switch
Browse files Browse the repository at this point in the history
  • Loading branch information
vektah committed Mar 17, 2018
1 parent dc89840 commit ec5f5e6
Show file tree
Hide file tree
Showing 9 changed files with 92 additions and 28 deletions.
4 changes: 2 additions & 2 deletions codegen/build.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func Models(schema *schema.Schema, userTypes map[string]string, destDir string)

bindTypes(imports, namedTypes, prog)

models := buildModels(namedTypes, schema)
models := buildModels(namedTypes, schema, prog)
return &ModelBuild{
PackageName: filepath.Base(destDir),
Models: models,
Expand All @@ -67,7 +67,7 @@ func Bind(schema *schema.Schema, userTypes map[string]string, destDir string) (*
b := &Build{
PackageName: filepath.Base(destDir),
Objects: objects,
Interfaces: buildInterfaces(namedTypes, schema),
Interfaces: buildInterfaces(namedTypes, schema, prog),
Inputs: inputs,
Imports: imports,
}
Expand Down
8 changes: 7 additions & 1 deletion codegen/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,11 @@ package codegen
type Interface struct {
*NamedType

Implementors []*NamedType
Implementors []InterfaceImplementor
}

type InterfaceImplementor struct {
ValueReceiver bool

*NamedType
}
45 changes: 40 additions & 5 deletions codegen/interface_build.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,21 @@ package codegen

import (
"fmt"
"go/types"
"os"
"sort"
"strings"

"github.com/vektah/gqlgen/neelance/schema"
"golang.org/x/tools/go/loader"
)

func buildInterfaces(types NamedTypes, s *schema.Schema) []*Interface {
func buildInterfaces(types NamedTypes, s *schema.Schema, prog *loader.Program) []*Interface {
var interfaces []*Interface
for _, typ := range s.Types {
switch typ := typ.(type) {
case *schema.Union, *schema.Interface:
interfaces = append(interfaces, buildInterface(types, typ))
interfaces = append(interfaces, buildInterface(types, typ, prog))
default:
continue
}
Expand All @@ -26,14 +29,15 @@ func buildInterfaces(types NamedTypes, s *schema.Schema) []*Interface {
return interfaces
}

func buildInterface(types NamedTypes, typ schema.NamedType) *Interface {
func buildInterface(types NamedTypes, typ schema.NamedType, prog *loader.Program) *Interface {
switch typ := typ.(type) {

case *schema.Union:
i := &Interface{NamedType: types[typ.TypeName()]}

for _, implementor := range typ.PossibleTypes {
i.Implementors = append(i.Implementors, types[implementor.TypeName()])
t := types[implementor.TypeName()]
i.Implementors = append(i.Implementors, InterfaceImplementor{NamedType: t, ValueReceiver: true})
}

return i
Expand All @@ -42,11 +46,42 @@ func buildInterface(types NamedTypes, typ schema.NamedType) *Interface {
i := &Interface{NamedType: types[typ.TypeName()]}

for _, implementor := range typ.PossibleTypes {
i.Implementors = append(i.Implementors, types[implementor.TypeName()])
t := types[implementor.TypeName()]

i.Implementors = append(i.Implementors, InterfaceImplementor{
NamedType: t,
ValueReceiver: isValueReceiver(types[typ.Name], t, prog),
})
}

return i
default:
panic(fmt.Errorf("unknown interface %#v", typ))
}
}

func isValueReceiver(intf *NamedType, implementor *NamedType, prog *loader.Program) bool {
interfaceType := findGoInterface(prog, intf.Package, intf.GoType)
implementorType := findGoNamedType(prog, implementor.Package, implementor.GoType)

if interfaceType == nil || implementorType == nil {
return true
}

for i := 0; i < interfaceType.NumMethods(); i++ {
intfMethod := interfaceType.Method(i)

implMethod := findMethod(implementorType, intfMethod.Name())
if implMethod == nil {
fmt.Fprintf(os.Stderr, "missing method %s on %s\n", intfMethod.Name(), implementor.GoType)
return false
}

sig := implMethod.Type().(*types.Signature)
if _, isPtr := sig.Recv().Type().(*types.Pointer); isPtr {
return false
}
}

return true
}
5 changes: 3 additions & 2 deletions codegen/models_build.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@ import (
"strings"

"github.com/vektah/gqlgen/neelance/schema"
"golang.org/x/tools/go/loader"
)

func buildModels(types NamedTypes, s *schema.Schema) []Model {
func buildModels(types NamedTypes, s *schema.Schema, prog *loader.Program) []Model {
var models []Model

for _, typ := range s.Types {
Expand All @@ -26,7 +27,7 @@ func buildModels(types NamedTypes, s *schema.Schema) []Model {
}
model = obj2Model(obj)
case *schema.Interface, *schema.Union:
intf := buildInterface(types, typ)
intf := buildInterface(types, typ, prog)
if intf.GoType != "" {
continue
}
Expand Down
2 changes: 1 addition & 1 deletion codegen/templates/data.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 6 additions & 6 deletions codegen/templates/interface.gotpl
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@ func (ec *executionContext) _{{$interface.GQLType}}(sel []query.Selection, obj *
case nil:
return graphql.Null
{{- range $implementor := $interface.Implementors }}
case {{$implementor.FullName}}:
return ec._{{$implementor.GQLType}}(sel, &obj)

case *{{$implementor.FullName}}:
return ec._{{$implementor.GQLType}}(sel, obj)

{{- if $implementor.ValueReceiver }}
case {{$implementor.FullName}}:
return ec._{{$implementor.GQLType}}(sel, &obj)
{{- end}}
case *{{$implementor.FullName}}:
return ec._{{$implementor.GQLType}}(sel, obj)
{{- end }}
default:
panic(fmt.Errorf("unexpected type %T", obj))
Expand Down
33 changes: 33 additions & 0 deletions codegen/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,39 @@ func findGoType(prog *loader.Program, pkgName string, typeName string) (types.Ob
return nil, fmt.Errorf("unable to find type %s\n", fullName)
}

func findGoNamedType(prog *loader.Program, pkgName string, typeName string) *types.Named {
def, err := findGoType(prog, pkgName, typeName)
if err != nil {
fmt.Fprintf(os.Stderr, err.Error())
}
if def == nil {
return nil
}

namedType, ok := def.Type().(*types.Named)
if !ok {
fmt.Fprintf(os.Stderr, "expected %s to be a named type, instead found %T\n", typeName, def.Type())
return nil
}

return namedType
}

func findGoInterface(prog *loader.Program, pkgName string, typeName string) *types.Interface {
namedType := findGoNamedType(prog, pkgName, typeName)
if namedType == nil {
return nil
}

underlying, ok := namedType.Underlying().(*types.Interface)
if !ok {
fmt.Fprintf(os.Stderr, "expected %s to be a named interface, instead found %s", typeName, namedType.String())
return nil
}

return underlying
}

func findMethod(typ *types.Named, name string) *types.Func {
for i := 0; i < typ.NumMethods(); i++ {
method := typ.Method(i)
Expand Down
5 changes: 0 additions & 5 deletions example/starwars/generated.go
Original file line number Diff line number Diff line change
Expand Up @@ -1512,12 +1512,10 @@ func (ec *executionContext) _Character(sel []query.Selection, obj *Character) gr
return graphql.Null
case Human:
return ec._Human(sel, &obj)

case *Human:
return ec._Human(sel, obj)
case Droid:
return ec._Droid(sel, &obj)

case *Droid:
return ec._Droid(sel, obj)
default:
Expand All @@ -1531,17 +1529,14 @@ func (ec *executionContext) _SearchResult(sel []query.Selection, obj *SearchResu
return graphql.Null
case Human:
return ec._Human(sel, &obj)

case *Human:
return ec._Human(sel, obj)
case Droid:
return ec._Droid(sel, &obj)

case *Droid:
return ec._Droid(sel, obj)
case Starship:
return ec._Starship(sel, &obj)

case *Starship:
return ec._Starship(sel, obj)
default:
Expand Down
6 changes: 0 additions & 6 deletions test/generated.go
Original file line number Diff line number Diff line change
Expand Up @@ -867,14 +867,8 @@ func (ec *executionContext) _Shape(sel []query.Selection, obj *Shape) graphql.Ma
switch obj := (*obj).(type) {
case nil:
return graphql.Null
case Circle:
return ec._Circle(sel, &obj)

case *Circle:
return ec._Circle(sel, obj)
case Rectangle:
return ec._Rectangle(sel, &obj)

case *Rectangle:
return ec._Rectangle(sel, obj)
default:
Expand Down

0 comments on commit ec5f5e6

Please sign in to comment.