From 502c0a1fdf3704dc0aa22e6e3d7d0f3476b8a9e4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=8D=9C=E6=9C=A8?= Date: Sun, 7 May 2023 22:36:17 +0800 Subject: [PATCH 1/2] feat: dig out anonymous field --- result.go | 31 ++++++++++++++++++++----- result_test.go | 61 +++++++++++++++++++++++++++++++++++--------------- 2 files changed, 68 insertions(+), 24 deletions(-) diff --git a/result.go b/result.go index 369cd218..560acccf 100644 --- a/result.go +++ b/result.go @@ -74,7 +74,7 @@ func newResult(t reflect.Type, opts resultOptions) (result, error) { case isError(t): return nil, newErrInvalidInput("cannot return an error here, return it from the constructor instead", nil) case IsOut(t): - return newResultObject(t, opts) + return newResultObject(t, opts, false) case embedsType(t, _outPtrType): return nil, newErrInvalidInput(fmt.Sprintf( "cannot build a result object by embedding *dig.Out, embed dig.Out instead: %v embeds *dig.Out", t), nil) @@ -353,7 +353,7 @@ func (ro resultObject) DotResult() []*dot.Result { return types } -func newResultObject(t reflect.Type, opts resultOptions) (resultObject, error) { +func newResultObject(t reflect.Type, opts resultOptions, anonymous bool) (resultObject, error) { ro := resultObject{Type: t} if len(opts.Name) > 0 { return ro, newErrInvalidInput(fmt.Sprintf( @@ -372,19 +372,38 @@ func newResultObject(t reflect.Type, opts resultOptions) (resultObject, error) { continue } + if anonymous && !f.IsExported() { + continue + } + rof, err := newResultObjectField(i, f, opts) if err != nil { return ro, newErrInvalidInput(fmt.Sprintf("bad field %q of %v", f.Name, t), err) } ro.Fields = append(ro.Fields, rof) + + if f.Anonymous { + subRo, err := newResultObject(f.Type, opts, true) + if err != nil { + return resultObject{}, err + } + for _, subField := range subRo.Fields { + subField.FieldIndices = append([]int{i}, subField.FieldIndices...) + ro.Fields = append(ro.Fields, subField) + } + } } return ro, nil } func (ro resultObject) Extract(cw containerWriter, decorated bool, v reflect.Value) { for _, f := range ro.Fields { - f.Result.Extract(cw, decorated, v.Field(f.FieldIndex)) + var rv reflect.Value = v + for _, fieldIndex := range f.FieldIndices { + rv = rv.Field(fieldIndex) + } + f.Result.Extract(cw, decorated, rv) } } @@ -397,7 +416,7 @@ type resultObjectField struct { // // We need to track this separately because not all fields of the struct // map to results. - FieldIndex int + FieldIndices []int // Result produced by this field. Result result @@ -411,8 +430,8 @@ func (rof resultObjectField) DotResult() []*dot.Result { // f at index i. func newResultObjectField(idx int, f reflect.StructField, opts resultOptions) (resultObjectField, error) { rof := resultObjectField{ - FieldName: f.Name, - FieldIndex: idx, + FieldName: f.Name, + FieldIndices: []int{idx}, } var r result diff --git a/result_test.go b/result_test.go index c19db20d..f6d682a1 100644 --- a/result_test.go +++ b/result_test.go @@ -116,6 +116,11 @@ func TestNewResultErrors(t *testing.T) { } func TestNewResultObject(t *testing.T) { + type Embed struct { + Writer io.Writer + } + + typeOfEmbed := reflect.TypeOf(&Embed{}).Elem() typeOfReader := reflect.TypeOf((*io.Reader)(nil)).Elem() typeOfWriter := reflect.TypeOf((*io.Writer)(nil)).Elem() @@ -137,14 +142,14 @@ func TestNewResultObject(t *testing.T) { }{}, wantFields: []resultObjectField{ { - FieldName: "Reader", - FieldIndex: 1, - Result: resultSingle{Type: typeOfReader}, + FieldName: "Reader", + FieldIndices: []int{1}, + Result: resultSingle{Type: typeOfReader}, }, { - FieldName: "Writer", - FieldIndex: 2, - Result: resultSingle{Type: typeOfWriter}, + FieldName: "Writer", + FieldIndices: []int{2}, + Result: resultSingle{Type: typeOfWriter}, }, }, }, @@ -158,14 +163,14 @@ func TestNewResultObject(t *testing.T) { }{}, wantFields: []resultObjectField{ { - FieldName: "A", - FieldIndex: 1, - Result: resultSingle{Name: "stream-a", Type: typeOfWriter}, + FieldName: "A", + FieldIndices: []int{1}, + Result: resultSingle{Name: "stream-a", Type: typeOfWriter}, }, { - FieldName: "B", - FieldIndex: 2, - Result: resultSingle{Name: "stream-b", Type: typeOfWriter}, + FieldName: "B", + FieldIndices: []int{2}, + Result: resultSingle{Name: "stream-b", Type: typeOfWriter}, }, }, }, @@ -178,9 +183,29 @@ func TestNewResultObject(t *testing.T) { }{}, wantFields: []resultObjectField{ { - FieldName: "Writer", - FieldIndex: 1, - Result: resultGrouped{Group: "writers", Type: typeOfWriter}, + FieldName: "Writer", + FieldIndices: []int{1}, + Result: resultGrouped{Group: "writers", Type: typeOfWriter}, + }, + }, + }, + { + desc: "anonymous", + give: struct { + Out + + Embed + }{}, + wantFields: []resultObjectField{ + { + FieldName: "Embed", + FieldIndices: []int{1}, + Result: resultSingle{Name: "", Type: typeOfEmbed}, + }, + { + FieldName: "Writer", + FieldIndices: []int{1, 0}, + Result: resultSingle{Name: "", Type: typeOfWriter}, }, }, }, @@ -188,7 +213,7 @@ func TestNewResultObject(t *testing.T) { for _, tt := range tests { t.Run(tt.desc, func(t *testing.T) { - got, err := newResultObject(reflect.TypeOf(tt.give), tt.opts) + got, err := newResultObject(reflect.TypeOf(tt.give), tt.opts, false) require.NoError(t, err) assert.Equal(t, tt.wantFields, got.Fields) }) @@ -302,7 +327,7 @@ func TestNewResultObjectErrors(t *testing.T) { for _, tt := range tests { t.Run(tt.desc, func(t *testing.T) { - _, err := newResultObject(reflect.TypeOf(tt.give), tt.opts) + _, err := newResultObject(reflect.TypeOf(tt.give), tt.opts, false) require.Error(t, err) assert.Contains(t, err.Error(), tt.err) }) @@ -404,7 +429,7 @@ func TestWalkResult(t *testing.T) { } }{}) - ro, err := newResultObject(typ, resultOptions{}) + ro, err := newResultObject(typ, resultOptions{}, false) require.NoError(t, err) v := fakeResultVisits{ From 9338321dc1ce6cf9cea39b3738bf6aa5f7b253ce Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=8D=9C=E6=9C=A8?= Date: Mon, 8 May 2023 00:14:20 +0800 Subject: [PATCH 2/2] chore: add tag extra-anonymous --- result.go | 51 +++++++++++++++++++++++++++++++++++++++++--------- result_test.go | 35 ++++++++++++++++++++++++++++++++++ 2 files changed, 77 insertions(+), 9 deletions(-) diff --git a/result.go b/result.go index 560acccf..f5e37494 100644 --- a/result.go +++ b/result.go @@ -23,11 +23,16 @@ package dig import ( "fmt" "reflect" + "strings" "go.uber.org/dig/internal/digerror" "go.uber.org/dig/internal/dot" ) +const ( + _extraAnonymous = "extra-anonymous" +) + // The result interface represents a result produced by a constructor. // // The following implementations exist: @@ -383,20 +388,48 @@ func newResultObject(t reflect.Type, opts resultOptions, anonymous bool) (result ro.Fields = append(ro.Fields, rof) - if f.Anonymous { - subRo, err := newResultObject(f.Type, opts, true) - if err != nil { - return resultObject{}, err - } - for _, subField := range subRo.Fields { - subField.FieldIndices = append([]int{i}, subField.FieldIndices...) - ro.Fields = append(ro.Fields, subField) - } + if !f.Anonymous || f.Tag.Get(_extraAnonymous) != "true" { + continue + } + if err = extraAnonymous(&ro, &f, &rof, opts); err != nil { + return ro, err } } return ro, nil } +func extraAnonymous(ro *resultObject, f *reflect.StructField, rof *resultObjectField, opts resultOptions) error { + ft := f.Type + if ft.Kind() == reflect.Pointer { + ft = ft.Elem() + } + subRo, err := newResultObject(ft, opts, true) + if err != nil { + return err + } + + for _, subField := range subRo.Fields { + subField.FieldIndices = append(rof.FieldIndices, subField.FieldIndices...) + switch rofResult := rof.Result.(type) { + case resultGrouped: + switch subResult := subField.Result.(type) { + case resultGrouped: + subResult.Group = strings.Join([]string{rofResult.Group, subResult.Group}, ",") + case resultSingle: + subField.Result = resultGrouped{ + Group: rofResult.Group, + Flatten: rofResult.Flatten, + Type: subResult.Type, + } + } + } + + ro.Fields = append(ro.Fields, subField) + } + + return nil +} + func (ro resultObject) Extract(cw containerWriter, decorated bool, v reflect.Value) { for _, f := range ro.Fields { var rv reflect.Value = v diff --git a/result_test.go b/result_test.go index f6d682a1..abd3643e 100644 --- a/result_test.go +++ b/result_test.go @@ -196,6 +196,21 @@ func TestNewResultObject(t *testing.T) { Embed }{}, + wantFields: []resultObjectField{ + { + FieldName: "Embed", + FieldIndices: []int{1}, + Result: resultSingle{Name: "", Type: typeOfEmbed}, + }, + }, + }, + { + desc: "anonymous", + give: struct { + Out + + Embed `extra-anonymous:"true"` + }{}, wantFields: []resultObjectField{ { FieldName: "Embed", @@ -209,6 +224,26 @@ func TestNewResultObject(t *testing.T) { }, }, }, + { + desc: "anonymous group", + give: struct { + Out + + Embed `extra-anonymous:"true" group:"embed_group"` + }{}, + wantFields: []resultObjectField{ + { + FieldName: "Embed", + FieldIndices: []int{1}, + Result: resultGrouped{Group: "embed_group", Type: typeOfEmbed}, + }, + { + FieldName: "Writer", + FieldIndices: []int{1, 0}, + Result: resultGrouped{Group: "embed_group", Type: typeOfWriter}, + }, + }, + }, } for _, tt := range tests {