Skip to content

Commit 7dd18c6

Browse files
authored
fx.Self: a parameter to fx.As for providing a type as itself (#1201)
We frequently see requests for folks who want to use `fx.As` to provide a type as another type, while also providing it as itself. To name a few: * #1196 * #1148 * #1079 This is currently not possible via strictly using `fx.As` + a single constructor, since `fx.As` causes a constructor to no longer provide its original type. The workaround we often give is for folks to do something like this, which involves adding a second "constructor": ```go fx.Provide( newConcreteType, func(ct *concreteType) Interface { return ct } ) ``` which is admittedly not very ergonomic. A somewhat common pattern mistaken to be a workaround is providing the constructor twice instead: ```go fx.Provide( newConcreteType, fx.Annotate(newConcreteType, fx.As(new(Interface))), ) ``` This PR adds `fx.Self()`, which returns a special value to indicate an `fx.As` should retain that original return type: ```go fx.Provide( newConcreteType, fx.As(fx.Self()), fx.As(new(Interface)), ) ``` As an alternative, I considered a new annotation altogether, named something like `fx.AlsoAs`, but adding a special type that can be passed as an argument to `fx.As` directly allows for more fine-tuned control over individual positional return values. For example, this function's return types can be easily expressed as `*asStringer` and `io.Writer` using `fx.Self()`: ```go fx.Provide( fx.Annotate( func() (*asStringer, *bytes.Buffer) {/* ... */ }, fx.As(fx.Self(), new(io.Writer)), // return values will be: *asStringer, io.Writer ), ), ``` Whereas something like `fx.AlsoAs` wouldn't provide the ability to skip over the first positional return value entirely.
1 parent cb9cccf commit 7dd18c6

File tree

2 files changed

+161
-10
lines changed

2 files changed

+161
-10
lines changed

annotated.go

+61-10
Original file line numberDiff line numberDiff line change
@@ -1097,7 +1097,19 @@ func OnStop(onStop interface{}) Annotation {
10971097

10981098
type asAnnotation struct {
10991099
targets []interface{}
1100-
types []reflect.Type
1100+
types []asType
1101+
}
1102+
1103+
type asType struct {
1104+
self bool
1105+
typ reflect.Type // May be nil if self is true.
1106+
}
1107+
1108+
func (a asType) String() string {
1109+
if a.self {
1110+
return "self"
1111+
}
1112+
return a.typ.String()
11011113
}
11021114

11031115
func isOut(t reflect.Type) bool {
@@ -1119,7 +1131,7 @@ var _ Annotation = (*asAnnotation)(nil)
11191131
// bytes.NewBuffer (bytes.Buffer) should be provided as io.Writer type:
11201132
//
11211133
// fx.Provide(
1122-
// fx.Annotate(bytes.NewBuffer(...), fx.As(new(io.Writer)))
1134+
// fx.Annotate(bytes.NewBuffer, fx.As(new(io.Writer)))
11231135
// )
11241136
//
11251137
// In other words, the code above is equivalent to:
@@ -1157,15 +1169,50 @@ func As(interfaces ...interface{}) Annotation {
11571169
return &asAnnotation{targets: interfaces}
11581170
}
11591171

1172+
// Self returns a special value that can be passed to [As] to indicate
1173+
// that a type should be provided as its original type, in addition to whatever other
1174+
// types it gets provided as via other [As] annotations.
1175+
//
1176+
// For example,
1177+
//
1178+
// fx.Provide(
1179+
// fx.Annotate(
1180+
// bytes.NewBuffer,
1181+
// fx.As(new(io.Writer)),
1182+
// fx.As(fx.Self()),
1183+
// )
1184+
// )
1185+
//
1186+
// Is equivalent to,
1187+
//
1188+
// fx.Provide(
1189+
// bytes.NewBuffer,
1190+
// func(b *bytes.Buffer) io.Writer {
1191+
// return b
1192+
// },
1193+
// )
1194+
//
1195+
// in that it provides the same *bytes.Buffer instance
1196+
// as both a *bytes.Buffer and an io.Writer.
1197+
func Self() any {
1198+
return &self{}
1199+
}
1200+
1201+
type self struct{}
1202+
11601203
func (at *asAnnotation) apply(ann *annotated) error {
1161-
at.types = make([]reflect.Type, len(at.targets))
1204+
at.types = make([]asType, len(at.targets))
11621205
for i, typ := range at.targets {
1206+
if _, ok := typ.(*self); ok {
1207+
at.types[i] = asType{self: true}
1208+
continue
1209+
}
11631210
t := reflect.TypeOf(typ)
11641211
if t.Kind() != reflect.Ptr || t.Elem().Kind() != reflect.Interface {
11651212
return fmt.Errorf("fx.As: argument must be a pointer to an interface: got %v", t)
11661213
}
11671214
t = t.Elem()
1168-
at.types[i] = t
1215+
at.types[i] = asType{typ: t}
11691216
}
11701217

11711218
ann.As = append(ann.As, at.types)
@@ -1209,12 +1256,16 @@ func (at *asAnnotation) results(ann *annotated) (
12091256
Type: t,
12101257
Tag: f.Tag,
12111258
}
1212-
if i < len(at.types) {
1213-
if !t.Implements(at.types[i]) {
1214-
return nil, nil, fmt.Errorf("invalid fx.As: %v does not implement %v", t, at.types[i])
1215-
}
1216-
field.Type = at.types[i]
1259+
1260+
if i >= len(at.types) || at.types[i].self {
1261+
fields = append(fields, field)
1262+
continue
1263+
}
1264+
1265+
if !t.Implements(at.types[i].typ) {
1266+
return nil, nil, fmt.Errorf("invalid fx.As: %v does not implement %v", t, at.types[i])
12171267
}
1268+
field.Type = at.types[i].typ
12181269
fields = append(fields, field)
12191270
}
12201271
resType := reflect.StructOf(fields)
@@ -1475,7 +1526,7 @@ type annotated struct {
14751526
Annotations []Annotation
14761527
ParamTags []string
14771528
ResultTags []string
1478-
As [][]reflect.Type
1529+
As [][]asType
14791530
From []reflect.Type
14801531
FuncPtr uintptr
14811532
Hooks []*lifecycleHookAnnotation

annotated_test.go

+100
Original file line numberDiff line numberDiff line change
@@ -433,6 +433,12 @@ func TestAnnotatedAs(t *testing.T) {
433433

434434
S fmt.Stringer `name:"goodStringer"`
435435
}
436+
type inSelf struct {
437+
fx.In
438+
439+
S1 fmt.Stringer `name:"goodStringer"`
440+
S2 *asStringer `name:"goodStringer"`
441+
}
436442
type myStringer interface {
437443
String() string
438444
}
@@ -699,6 +705,100 @@ func TestAnnotatedAs(t *testing.T) {
699705
},
700706
startApp: true,
701707
},
708+
{
709+
desc: "self w other As annotations",
710+
provide: fx.Provide(
711+
fx.Annotate(
712+
func() *asStringer {
713+
return &asStringer{name: "stringer"}
714+
},
715+
fx.As(fx.Self()),
716+
fx.As(new(fmt.Stringer)),
717+
),
718+
),
719+
invoke: func(s fmt.Stringer, as *asStringer) {
720+
assert.Equal(t, "stringer", s.String())
721+
assert.Equal(t, "stringer", as.String())
722+
},
723+
},
724+
{
725+
desc: "self as one As target",
726+
provide: fx.Provide(
727+
fx.Annotate(
728+
func() (*asStringer, *bytes.Buffer) {
729+
s := &asStringer{name: "stringer"}
730+
b := &bytes.Buffer{}
731+
return s, b
732+
},
733+
fx.As(fx.Self(), new(io.Writer)),
734+
),
735+
),
736+
invoke: func(s *asStringer, w io.Writer) {
737+
assert.Equal(t, "stringer", s.String())
738+
_, err := w.Write([]byte("."))
739+
assert.NoError(t, err)
740+
},
741+
},
742+
{
743+
desc: "two as, two self, four types",
744+
provide: fx.Provide(
745+
fx.Annotate(
746+
func() (*asStringer, *bytes.Buffer) {
747+
s := &asStringer{name: "stringer"}
748+
b := &bytes.Buffer{}
749+
return s, b
750+
},
751+
fx.As(fx.Self(), new(io.Writer)),
752+
fx.As(new(fmt.Stringer)),
753+
),
754+
),
755+
invoke: func(s1 *asStringer, s2 fmt.Stringer, b *bytes.Buffer, w io.Writer) {
756+
assert.Equal(t, "stringer", s1.String())
757+
assert.Equal(t, "stringer", s2.String())
758+
_, err := w.Write([]byte("."))
759+
assert.NoError(t, err)
760+
_, err = b.Write([]byte("."))
761+
assert.NoError(t, err)
762+
},
763+
},
764+
{
765+
desc: "self with lifecycle hook",
766+
provide: fx.Provide(
767+
fx.Annotate(
768+
func() *asStringer {
769+
return &asStringer{name: "stringer"}
770+
},
771+
fx.As(fx.Self()),
772+
fx.As(new(fmt.Stringer)),
773+
fx.OnStart(func(s fmt.Stringer, as *asStringer) {
774+
assert.Equal(t, "stringer", s.String())
775+
assert.Equal(t, "stringer", as.String())
776+
}),
777+
),
778+
),
779+
invoke: func(s fmt.Stringer, as *asStringer) {
780+
assert.Equal(t, "stringer", s.String())
781+
assert.Equal(t, "stringer", as.String())
782+
},
783+
startApp: true,
784+
},
785+
{
786+
desc: "self with result tags",
787+
provide: fx.Provide(
788+
fx.Annotate(
789+
func() *asStringer {
790+
return &asStringer{name: "stringer"}
791+
},
792+
fx.As(fx.Self()),
793+
fx.As(new(fmt.Stringer)),
794+
fx.ResultTags(`name:"goodStringer"`),
795+
),
796+
),
797+
invoke: func(i inSelf) {
798+
assert.Equal(t, "stringer", i.S1.String())
799+
assert.Equal(t, "stringer", i.S2.String())
800+
},
801+
},
702802
}
703803

704804
for _, tt := range tests {

0 commit comments

Comments
 (0)