Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix defaulter-gen for recursive types #61

Merged
merged 1 commit into from
Jul 6, 2017
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 40 additions & 11 deletions examples/defaulter-gen/generators/defaulter.go
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ func Packages(context *generator.Context, arguments *args.GeneratorArgs) generat
if d.object != nil {
continue
}
if buildCallTreeForType(t, true, existingDefaulters, newDefaulters) != nil {
if newCallTreeForType(existingDefaulters, newDefaulters).build(t, true) != nil {
args := defaultingArgsFromType(t)
sw.Do("$.inType|objectdefaultfn$", args)
newDefaulters[t] = defaults{
Expand Down Expand Up @@ -387,7 +387,22 @@ func Packages(context *generator.Context, arguments *args.GeneratorArgs) generat
return packages
}

// buildCallTreeForType creates a tree of paths to fields (based on how they would be accessed in Go - pointer, elem,
// callTreeForType contains fields necessary to build a tree for types.
type callTreeForType struct {
existingDefaulters defaulterFuncMap
newDefaulters defaulterFuncMap
currentlyBuildingTypes map[*types.Type]bool
}

func newCallTreeForType(existingDefaulters, newDefaulters defaulterFuncMap) *callTreeForType {
return &callTreeForType{
existingDefaulters: existingDefaulters,
newDefaulters: newDefaulters,
currentlyBuildingTypes: make(map[*types.Type]bool),
}
}

// build creates a tree of paths to fields (based on how they would be accessed in Go - pointer, elem,
// slice, or key) and the functions that should be invoked on each field. An in-order traversal of the resulting tree
// can be used to generate a Go function that invokes each nested function on the appropriate type. The return
// value may be nil if there are no functions to call on type or the type is a primitive (Defaulters can only be
Expand All @@ -396,16 +411,16 @@ func Packages(context *generator.Context, arguments *args.GeneratorArgs) generat
// that could be or will be generated. If newDefaulters has an entry for a type, but the 'object' field is nil,
// this function skips adding that defaulter - this allows us to avoid generating object defaulter functions for
// list types that call empty defaulters.
func buildCallTreeForType(t *types.Type, root bool, existingDefaulters, newDefaulters defaulterFuncMap) *callNode {
func (c *callTreeForType) build(t *types.Type, root bool) *callNode {
parent := &callNode{}

if root {
// the root node is always a pointer
parent.elem = true
}

defaults, _ := existingDefaulters[t]
newDefaults, generated := newDefaulters[t]
defaults, _ := c.existingDefaulters[t]
newDefaults, generated := c.newDefaulters[t]
switch {
case !root && generated && newDefaults.object != nil:
parent.call = append(parent.call, newDefaults.object)
Expand All @@ -432,19 +447,33 @@ func buildCallTreeForType(t *types.Type, root bool, existingDefaulters, newDefau
// base has been added already, now add any additional defaulters defined for this object
parent.call = append(parent.call, defaults.additional...)

// if the type already exists, don't build the tree for it and don't generate anything.
// This is used to avoid recursion for nested recursive types.
if c.currentlyBuildingTypes[t] {
return nil
}
// if type doesn't exist, mark it as existing
c.currentlyBuildingTypes[t] = true

defer func() {
// The type will now acts as a parent, not a nested recursive type.
// We can now build the tree for it safely.
c.currentlyBuildingTypes[t] = false
}()

switch t.Kind {
case types.Pointer:
if child := buildCallTreeForType(t.Elem, false, existingDefaulters, newDefaulters); child != nil {
if child := c.build(t.Elem, false); child != nil {
child.elem = true
parent.children = append(parent.children, *child)
}
case types.Slice, types.Array:
if child := buildCallTreeForType(t.Elem, false, existingDefaulters, newDefaulters); child != nil {
if child := c.build(t.Elem, false); child != nil {
child.index = true
parent.children = append(parent.children, *child)
}
case types.Map:
if child := buildCallTreeForType(t.Elem, false, existingDefaulters, newDefaulters); child != nil {
if child := c.build(t.Elem, false); child != nil {
child.key = true
parent.children = append(parent.children, *child)
}
Expand All @@ -458,13 +487,13 @@ func buildCallTreeForType(t *types.Type, root bool, existingDefaulters, newDefau
name = field.Type.Name.Name
}
}
if child := buildCallTreeForType(field.Type, false, existingDefaulters, newDefaulters); child != nil {
if child := c.build(field.Type, false); child != nil {
child.field = name
parent.children = append(parent.children, *child)
}
}
case types.Alias:
if child := buildCallTreeForType(t.Underlying, false, existingDefaulters, newDefaulters); child != nil {
if child := c.build(t.Underlying, false); child != nil {
parent.children = append(parent.children, *child)
}
}
Expand Down Expand Up @@ -571,7 +600,7 @@ func (g *genDefaulter) GenerateType(c *generator.Context, t *types.Type, w io.Wr

glog.V(5).Infof("generating for type %v", t)

callTree := buildCallTreeForType(t, true, g.existingDefaulters, g.newDefaulters)
callTree := newCallTreeForType(g.existingDefaulters, g.newDefaulters).build(t, true)
if callTree == nil {
glog.V(5).Infof(" no defaulters defined")
return nil
Expand Down