Skip to content

Commit

Permalink
Merge pull request #1009 from 99designs/interface-regression
Browse files Browse the repository at this point in the history
Interface regression
  • Loading branch information
vektah authored Feb 5, 2020
2 parents 0ddb3ef + ffc419f commit f7667e1
Show file tree
Hide file tree
Showing 8 changed files with 659 additions and 19 deletions.
23 changes: 5 additions & 18 deletions codegen/field.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,25 +196,12 @@ func (b *builder) bindField(obj *Object, f *Field) (errret error) {
}
}

func (b *builder) findBindTarget(in types.Type, name string) (types.Object, error) {
switch t := in.(type) {
case *types.Named:
if _, ok := t.Underlying().(*types.Interface); ok {
return nil, errors.New("can't bind to an interface at root")
}
case *types.Interface:
return nil, errors.New("can't bind to an interface at root")
}

return b.findBindTargetRecur(in, name)
}

// findBindTargetRecur attempts to match the name to a field or method on a Type
// findBindTarget attempts to match the name to a field or method on a Type
// with the following priorites:
// 1. Any Fields with a struct tag (see config.StructTag). Errors if more than one match is found
// 2. Any method or field with a matching name. Errors if more than one match is found
// 3. Same logic again for embedded fields
func (b *builder) findBindTargetRecur(t types.Type, name string) (types.Object, error) {
func (b *builder) findBindTarget(t types.Type, name string) (types.Object, error) {
// NOTE: a struct tag will override both methods and fields
// Bind to struct tag
found, err := b.findBindStructTagTarget(t, name)
Expand Down Expand Up @@ -366,7 +353,7 @@ func (b *builder) findBindStructEmbedsTarget(strukt *types.Struct, name string)
fieldType = ptr.Elem()
}

f, err := b.findBindTargetRecur(fieldType, name)
f, err := b.findBindTarget(fieldType, name)
if err != nil {
return nil, err
}
Expand All @@ -388,7 +375,7 @@ func (b *builder) findBindInterfaceEmbedsTarget(iface *types.Interface, name str
for i := 0; i < iface.NumEmbeddeds(); i++ {
embeddedType := iface.EmbeddedType(i)

f, err := b.findBindTargetRecur(embeddedType, name)
f, err := b.findBindTarget(embeddedType, name)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -481,7 +468,7 @@ func (f *Field) ShortResolverDeclaration() string {
res := "(ctx context.Context"

if !f.Object.Root {
res += fmt.Sprintf(", obj *%s", templates.CurrentImports.LookupType(f.Object.Type))
res += fmt.Sprintf(", obj %s", templates.CurrentImports.LookupType(f.Object.Reference()))
}
for _, arg := range f.Args {
res += fmt.Sprintf(", %s %s", arg.VarName, templates.CurrentImports.LookupType(arg.TypeReference.GO))
Expand Down
7 changes: 6 additions & 1 deletion codegen/object.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,12 @@ func (b *builder) buildObject(typ *ast.Definition) (*Object, error) {
}

func (o *Object) Reference() types.Type {
switch o.Type.(type) {
switch v := o.Type.(type) {
case *types.Named:
_, isInterface := v.Underlying().(*types.Interface)
if isInterface {
return o.Type
}
case *types.Pointer, *types.Slice, *types.Map:
return o.Type
}
Expand Down
Loading

0 comments on commit f7667e1

Please sign in to comment.