diff --git a/result.go b/result.go index 369cd218..f5e37494 100644 --- a/result.go +++ b/result.go @@ -23,11 +23,16 @@ package dig import ( "fmt" "reflect" + "strings" "go.uber.org/dig/internal/digerror" "go.uber.org/dig/internal/dot" ) +const ( + _extraAnonymous = "extra-anonymous" +) + // The result interface represents a result produced by a constructor. // // The following implementations exist: @@ -74,7 +79,7 @@ func newResult(t reflect.Type, opts resultOptions) (result, error) { case isError(t): return nil, newErrInvalidInput("cannot return an error here, return it from the constructor instead", nil) case IsOut(t): - return newResultObject(t, opts) + return newResultObject(t, opts, false) case embedsType(t, _outPtrType): return nil, newErrInvalidInput(fmt.Sprintf( "cannot build a result object by embedding *dig.Out, embed dig.Out instead: %v embeds *dig.Out", t), nil) @@ -353,7 +358,7 @@ func (ro resultObject) DotResult() []*dot.Result { return types } -func newResultObject(t reflect.Type, opts resultOptions) (resultObject, error) { +func newResultObject(t reflect.Type, opts resultOptions, anonymous bool) (resultObject, error) { ro := resultObject{Type: t} if len(opts.Name) > 0 { return ro, newErrInvalidInput(fmt.Sprintf( @@ -372,19 +377,66 @@ func newResultObject(t reflect.Type, opts resultOptions) (resultObject, error) { continue } + if anonymous && !f.IsExported() { + continue + } + rof, err := newResultObjectField(i, f, opts) if err != nil { return ro, newErrInvalidInput(fmt.Sprintf("bad field %q of %v", f.Name, t), err) } ro.Fields = append(ro.Fields, rof) + + if !f.Anonymous || f.Tag.Get(_extraAnonymous) != "true" { + continue + } + if err = extraAnonymous(&ro, &f, &rof, opts); err != nil { + return ro, err + } } return ro, nil } +func extraAnonymous(ro *resultObject, f *reflect.StructField, rof *resultObjectField, opts resultOptions) error { + ft := f.Type + if ft.Kind() == reflect.Pointer { + ft = ft.Elem() + } + subRo, err := newResultObject(ft, opts, true) + if err != nil { + return err + } + + for _, subField := range subRo.Fields { + subField.FieldIndices = append(rof.FieldIndices, subField.FieldIndices...) + switch rofResult := rof.Result.(type) { + case resultGrouped: + switch subResult := subField.Result.(type) { + case resultGrouped: + subResult.Group = strings.Join([]string{rofResult.Group, subResult.Group}, ",") + case resultSingle: + subField.Result = resultGrouped{ + Group: rofResult.Group, + Flatten: rofResult.Flatten, + Type: subResult.Type, + } + } + } + + ro.Fields = append(ro.Fields, subField) + } + + return nil +} + func (ro resultObject) Extract(cw containerWriter, decorated bool, v reflect.Value) { for _, f := range ro.Fields { - f.Result.Extract(cw, decorated, v.Field(f.FieldIndex)) + var rv reflect.Value = v + for _, fieldIndex := range f.FieldIndices { + rv = rv.Field(fieldIndex) + } + f.Result.Extract(cw, decorated, rv) } } @@ -397,7 +449,7 @@ type resultObjectField struct { // // We need to track this separately because not all fields of the struct // map to results. - FieldIndex int + FieldIndices []int // Result produced by this field. Result result @@ -411,8 +463,8 @@ func (rof resultObjectField) DotResult() []*dot.Result { // f at index i. func newResultObjectField(idx int, f reflect.StructField, opts resultOptions) (resultObjectField, error) { rof := resultObjectField{ - FieldName: f.Name, - FieldIndex: idx, + FieldName: f.Name, + FieldIndices: []int{idx}, } var r result diff --git a/result_test.go b/result_test.go index c19db20d..abd3643e 100644 --- a/result_test.go +++ b/result_test.go @@ -116,6 +116,11 @@ func TestNewResultErrors(t *testing.T) { } func TestNewResultObject(t *testing.T) { + type Embed struct { + Writer io.Writer + } + + typeOfEmbed := reflect.TypeOf(&Embed{}).Elem() typeOfReader := reflect.TypeOf((*io.Reader)(nil)).Elem() typeOfWriter := reflect.TypeOf((*io.Writer)(nil)).Elem() @@ -137,14 +142,14 @@ func TestNewResultObject(t *testing.T) { }{}, wantFields: []resultObjectField{ { - FieldName: "Reader", - FieldIndex: 1, - Result: resultSingle{Type: typeOfReader}, + FieldName: "Reader", + FieldIndices: []int{1}, + Result: resultSingle{Type: typeOfReader}, }, { - FieldName: "Writer", - FieldIndex: 2, - Result: resultSingle{Type: typeOfWriter}, + FieldName: "Writer", + FieldIndices: []int{2}, + Result: resultSingle{Type: typeOfWriter}, }, }, }, @@ -158,14 +163,14 @@ func TestNewResultObject(t *testing.T) { }{}, wantFields: []resultObjectField{ { - FieldName: "A", - FieldIndex: 1, - Result: resultSingle{Name: "stream-a", Type: typeOfWriter}, + FieldName: "A", + FieldIndices: []int{1}, + Result: resultSingle{Name: "stream-a", Type: typeOfWriter}, }, { - FieldName: "B", - FieldIndex: 2, - Result: resultSingle{Name: "stream-b", Type: typeOfWriter}, + FieldName: "B", + FieldIndices: []int{2}, + Result: resultSingle{Name: "stream-b", Type: typeOfWriter}, }, }, }, @@ -178,9 +183,64 @@ func TestNewResultObject(t *testing.T) { }{}, wantFields: []resultObjectField{ { - FieldName: "Writer", - FieldIndex: 1, - Result: resultGrouped{Group: "writers", Type: typeOfWriter}, + FieldName: "Writer", + FieldIndices: []int{1}, + Result: resultGrouped{Group: "writers", Type: typeOfWriter}, + }, + }, + }, + { + desc: "anonymous", + give: struct { + Out + + Embed + }{}, + wantFields: []resultObjectField{ + { + FieldName: "Embed", + FieldIndices: []int{1}, + Result: resultSingle{Name: "", Type: typeOfEmbed}, + }, + }, + }, + { + desc: "anonymous", + give: struct { + Out + + Embed `extra-anonymous:"true"` + }{}, + wantFields: []resultObjectField{ + { + FieldName: "Embed", + FieldIndices: []int{1}, + Result: resultSingle{Name: "", Type: typeOfEmbed}, + }, + { + FieldName: "Writer", + FieldIndices: []int{1, 0}, + Result: resultSingle{Name: "", Type: typeOfWriter}, + }, + }, + }, + { + desc: "anonymous group", + give: struct { + Out + + Embed `extra-anonymous:"true" group:"embed_group"` + }{}, + wantFields: []resultObjectField{ + { + FieldName: "Embed", + FieldIndices: []int{1}, + Result: resultGrouped{Group: "embed_group", Type: typeOfEmbed}, + }, + { + FieldName: "Writer", + FieldIndices: []int{1, 0}, + Result: resultGrouped{Group: "embed_group", Type: typeOfWriter}, }, }, }, @@ -188,7 +248,7 @@ func TestNewResultObject(t *testing.T) { for _, tt := range tests { t.Run(tt.desc, func(t *testing.T) { - got, err := newResultObject(reflect.TypeOf(tt.give), tt.opts) + got, err := newResultObject(reflect.TypeOf(tt.give), tt.opts, false) require.NoError(t, err) assert.Equal(t, tt.wantFields, got.Fields) }) @@ -302,7 +362,7 @@ func TestNewResultObjectErrors(t *testing.T) { for _, tt := range tests { t.Run(tt.desc, func(t *testing.T) { - _, err := newResultObject(reflect.TypeOf(tt.give), tt.opts) + _, err := newResultObject(reflect.TypeOf(tt.give), tt.opts, false) require.Error(t, err) assert.Contains(t, err.Error(), tt.err) }) @@ -404,7 +464,7 @@ func TestWalkResult(t *testing.T) { } }{}) - ro, err := newResultObject(typ, resultOptions{}) + ro, err := newResultObject(typ, resultOptions{}, false) require.NoError(t, err) v := fakeResultVisits{