Skip to content

Commit

Permalink
Fix defaulter for nested recursive types
Browse files Browse the repository at this point in the history
Rename
  • Loading branch information
nikhita committed Jul 6, 2017
1 parent cc8100b commit ce68359
Showing 1 changed file with 58 additions and 16 deletions.
74 changes: 58 additions & 16 deletions examples/defaulter-gen/generators/defaulter.go
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,11 @@ func Packages(context *generator.Context, arguments *args.GeneratorArgs) generat
if d.object != nil {
continue
}
if buildCallTreeForType(t, true, existingDefaulters, newDefaulters) != nil {
// currentlyBuildingTypes keeps track of types that have already been visited in the tree.
// This is used to avoid recursion for recursive types.
currentlyBuildingTypes := make(map[*types.Type]bool)
c := NewCallTreeForType(t, true, existingDefaulters, newDefaulters, currentlyBuildingTypes)
if c.buildCallTreeForType() != nil {
args := defaultingArgsFromType(t)
sw.Do("$.inType|objectdefaultfn$", args)
newDefaulters[t] = defaults{
Expand Down Expand Up @@ -387,6 +391,25 @@ func Packages(context *generator.Context, arguments *args.GeneratorArgs) generat
return packages
}

// callTreeForType contains fields necessary to call buildCallTreeForType.
type callTreeForType struct {
t *types.Type
root bool
existingDefaulters defaulterFuncMap
newDefaulters defaulterFuncMap
currentlyBuildingTypes map[*types.Type]bool
}

func NewCallTreeForType(t *types.Type, root bool, existingDefaulters, newDefaulters defaulterFuncMap, currentlyBuildingTypes map[*types.Type]bool) callTreeForType {
return callTreeForType{
t: t,
root: root,
existingDefaulters: existingDefaulters,
newDefaulters: newDefaulters,
currentlyBuildingTypes: currentlyBuildingTypes,
}
}

// buildCallTreeForType creates a tree of paths to fields (based on how they would be accessed in Go - pointer, elem,
// slice, or key) and the functions that should be invoked on each field. An in-order traversal of the resulting tree
// can be used to generate a Go function that invokes each nested function on the appropriate type. The return
Expand All @@ -396,22 +419,22 @@ func Packages(context *generator.Context, arguments *args.GeneratorArgs) generat
// that could be or will be generated. If newDefaulters has an entry for a type, but the 'object' field is nil,
// this function skips adding that defaulter - this allows us to avoid generating object defaulter functions for
// list types that call empty defaulters.
func buildCallTreeForType(t *types.Type, root bool, existingDefaulters, newDefaulters defaulterFuncMap) *callNode {
func (c callTreeForType) buildCallTreeForType() *callNode {
parent := &callNode{}

if root {
if c.root {
// the root node is always a pointer
parent.elem = true
}

defaults, _ := existingDefaulters[t]
newDefaults, generated := newDefaulters[t]
defaults, _ := c.existingDefaulters[c.t]
newDefaults, generated := c.newDefaulters[c.t]
switch {
case !root && generated && newDefaults.object != nil:
case !c.root && generated && newDefaults.object != nil:
parent.call = append(parent.call, newDefaults.object)
// if we will be generating the defaulter, it by definition is a covering
// defaulter, so we halt recursion
glog.V(6).Infof("the defaulter %s will be generated as an object defaulter", t.Name)
glog.V(6).Infof("the defaulter %s will be generated as an object defaulter", c.t.Name)
return parent

case defaults.object != nil:
Expand All @@ -424,32 +447,43 @@ func buildCallTreeForType(t *types.Type, root bool, existingDefaulters, newDefau
// if the base function indicates it "covers" (it already includes defaulters)
// we can halt recursion
if checkTag(defaults.base.CommentLines, "covers") {
glog.V(6).Infof("the defaulter %s indicates it covers all sub generators", t.Name)
glog.V(6).Infof("the defaulter %s indicates it covers all sub generators", c.t.Name)
return parent
}
}

// base has been added already, now add any additional defaulters defined for this object
parent.call = append(parent.call, defaults.additional...)

switch t.Kind {
// if the type already exists, don't build the tree for it and don't generate anything.
// This is used to avoid recursion for nested recursive types.
if c.currentlyBuildingTypes[c.t] {
return nil
}
// if type doesn't exist, mark it as existing
c.currentlyBuildingTypes[c.t] = true

switch c.t.Kind {
case types.Pointer:
if child := buildCallTreeForType(t.Elem, false, existingDefaulters, newDefaulters); child != nil {
r := NewCallTreeForType(c.t.Elem, false, c.existingDefaulters, c.newDefaulters, c.currentlyBuildingTypes)
if child := r.buildCallTreeForType(); child != nil {
child.elem = true
parent.children = append(parent.children, *child)
}
case types.Slice, types.Array:
if child := buildCallTreeForType(t.Elem, false, existingDefaulters, newDefaulters); child != nil {
r := NewCallTreeForType(c.t.Elem, false, c.existingDefaulters, c.newDefaulters, c.currentlyBuildingTypes)
if child := r.buildCallTreeForType(); child != nil {
child.index = true
parent.children = append(parent.children, *child)
}
case types.Map:
if child := buildCallTreeForType(t.Elem, false, existingDefaulters, newDefaulters); child != nil {
r := NewCallTreeForType(c.t.Elem, false, c.existingDefaulters, c.newDefaulters, c.currentlyBuildingTypes)
if child := r.buildCallTreeForType(); child != nil {
child.key = true
parent.children = append(parent.children, *child)
}
case types.Struct:
for _, field := range t.Members {
for _, field := range c.t.Members {
name := field.Name
if len(name) == 0 {
if field.Type.Kind == types.Pointer {
Expand All @@ -458,20 +492,26 @@ func buildCallTreeForType(t *types.Type, root bool, existingDefaulters, newDefau
name = field.Type.Name.Name
}
}
if child := buildCallTreeForType(field.Type, false, existingDefaulters, newDefaulters); child != nil {
r := NewCallTreeForType(field.Type, false, c.existingDefaulters, c.newDefaulters, c.currentlyBuildingTypes)
if child := r.buildCallTreeForType(); child != nil {
child.field = name
parent.children = append(parent.children, *child)
}
}
case types.Alias:
if child := buildCallTreeForType(t.Underlying, false, existingDefaulters, newDefaulters); child != nil {
r := NewCallTreeForType(c.t.Underlying, false, c.existingDefaulters, c.newDefaulters, c.currentlyBuildingTypes)
if child := r.buildCallTreeForType(); child != nil {
parent.children = append(parent.children, *child)
}
}
if len(parent.children) == 0 && len(parent.call) == 0 {
//glog.V(6).Infof("decided type %s needs no generation", t.Name)
return nil
}

// The type now acts as a parent, not a nested recursive type.
// We can now build the tree for it safely.
c.currentlyBuildingTypes[c.t] = false
return parent
}

Expand Down Expand Up @@ -571,7 +611,9 @@ func (g *genDefaulter) GenerateType(c *generator.Context, t *types.Type, w io.Wr

glog.V(5).Infof("generating for type %v", t)

callTree := buildCallTreeForType(t, true, g.existingDefaulters, g.newDefaulters)
currentlyBuildingTypes := make(map[*types.Type]bool)
newcallTreeForType := NewCallTreeForType(t, true, g.existingDefaulters, g.newDefaulters, currentlyBuildingTypes)
callTree := newcallTreeForType.buildCallTreeForType()
if callTree == nil {
glog.V(5).Infof(" no defaulters defined")
return nil
Expand Down

0 comments on commit ce68359

Please sign in to comment.