Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add InvokeOption and ProvideOption for Names #300

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 44 additions & 3 deletions dig.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ type provideOptions struct {
Info *ProvideInfo
As []interface{}
Location *digreflect.Func
Names []string
}

func (o *provideOptions) Validate() error {
Expand Down Expand Up @@ -282,10 +283,40 @@ func LocationForPC(pc uintptr) ProvideOption {
})
}

type invokeOptions struct {
Names []string
}

func (*invokeOptions) Validate() error {
return nil
}

// An InvokeOption modifies the default behavior of Invoke. It's included for
// future functionality; currently, there are no concrete implementations.
type InvokeOption interface {
unimplemented()
applyInvokeOption(*invokeOptions)
}

type invokeOptionFunc func(*invokeOptions)

func (f invokeOptionFunc) applyInvokeOption(opts *invokeOptions) { f(opts) }

type InvokeAndProvideOption interface {
InvokeOption
ProvideOption
}

type namesOption []string

func (n namesOption) applyInvokeOption(opts *invokeOptions) {
opts.Names = n
}
func (n namesOption) applyProvideOption(opts *provideOptions) {
opts.Names = n
}

func Names(names ...string) InvokeAndProvideOption {
return namesOption(names)
}

// Container is a directed acyclic graph of types and their dependencies.
Expand Down Expand Up @@ -566,7 +597,15 @@ func (c *Container) Invoke(function interface{}, opts ...InvokeOption) error {
return errf("can't invoke non-function %v (type %v)", function, ftype)
}

pl, err := newParamList(ftype)
var options invokeOptions
for _, o := range opts {
o.applyInvokeOption(&options)
}
if err := options.Validate(); err != nil {
return err
}

pl, err := newParamList(ftype, options.Names)
if err != nil {
return err
}
Expand Down Expand Up @@ -624,6 +663,7 @@ func (c *Container) provide(ctor interface{}, opts provideOptions) error {
ResultGroup: opts.Group,
ResultAs: opts.As,
Location: opts.Location,
ParamNames: opts.Names,
},
)
if err != nil {
Expand Down Expand Up @@ -842,14 +882,15 @@ type nodeOptions struct {
ResultGroup string
ResultAs []interface{}
Location *digreflect.Func
ParamNames []string
}

func newNode(ctor interface{}, opts nodeOptions) (*node, error) {
cval := reflect.ValueOf(ctor)
ctype := cval.Type()
cptr := cval.Pointer()

params, err := newParamList(ctype)
params, err := newParamList(ctype, opts.ParamNames)
if err != nil {
return nil, err
}
Expand Down
58 changes: 58 additions & 0 deletions dig_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -536,6 +536,45 @@ func TestEndToEndSuccess(t *testing.T) {
}), "invoke should succeed, pulling out two named instances")
})

t.Run("named instances can be used to Provide another instance", func(t *testing.T) {
c := New()

type A struct{ idx int }

buildConstructor := func(idx int) func() A {
return func() A { return A{idx: idx} }
}

require.NoError(t, c.Provide(buildConstructor(1), Name("first")))
require.NoError(t, c.Provide(buildConstructor(2), Name("second")))
require.NoError(t, c.Provide(func(a A) int {
return a.idx + 5
}, Names("first")))

require.NoError(t, c.Invoke(func(i int) {
assert.Equal(t, 6, i)
}), "invoke should succeed, pulling out one named instances")
})

t.Run("named instances can be invoked Name option", func(t *testing.T) {
c := New()

type A struct{ idx int }

buildConstructor := func(idx int) func() A {
return func() A { return A{idx: idx} }
}

require.NoError(t, c.Provide(buildConstructor(1), Name("first")))
require.NoError(t, c.Provide(buildConstructor(2), Name("second")))
require.NoError(t, c.Provide(buildConstructor(3), Name("third")))

require.NoError(t, c.Invoke(func(a1 A, a3 A) {
assert.Equal(t, 1, a1.idx)
assert.Equal(t, 3, a3.idx)
}, Names("first", "third")), "invoke should succeed, using two named instances")
})

t.Run("named and unnamed instances coexist", func(t *testing.T) {
c := New()
type A struct{ idx int }
Expand All @@ -561,6 +600,25 @@ func TestEndToEndSuccess(t *testing.T) {
}))
})

t.Run("named and unnamed instances can be invoked with Names option", func(t *testing.T) {
c := New()

type A struct{ idx int }

buildConstructor := func(idx int) func() A {
return func() A { return A{idx: idx} }
}

require.NoError(t, c.Provide(buildConstructor(1), Name("first")))
require.NoError(t, c.Provide(buildConstructor(2), Name("second")))
require.NoError(t, c.Provide(buildConstructor(3)))

require.NoError(t, c.Invoke(func(a1 A, a3 A) {
assert.Equal(t, 1, a1.idx)
assert.Equal(t, 3, a3.idx)
}, Names("first")), "invoke should succeed, using two named instances")
})

t.Run("named instances recurse", func(t *testing.T) {
c := New()
type A struct{ idx int }
Expand Down
21 changes: 16 additions & 5 deletions param.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,14 @@ var (

// newParam builds a param from the given type. If the provided type is a
// dig.In struct, an paramObject will be returned.
func newParam(t reflect.Type) (param, error) {
func newParam(t reflect.Type, paramName string) (param, error) {
switch {
case IsOut(t) || (t.Kind() == reflect.Ptr && IsOut(t.Elem())) || embedsType(t, _outPtrType):
return nil, errf("cannot depend on result objects", "%v embeds a dig.Out", t)
case IsIn(t):
if paramName != "" {
return nil, errf("cannot have a paramName (%s) with a struct that has dig.In", paramName)
}
return newParamObject(t)
case embedsType(t, _inPtrType):
return nil, errf(
Expand All @@ -77,7 +80,7 @@ func newParam(t reflect.Type) (param, error) {
"cannot depend on a pointer to a parameter object, use a value instead",
"%v is a pointer to a struct that embeds dig.In", t)
default:
return paramSingle{Type: t}, nil
return paramSingle{Type: t, Name: paramName}, nil
}
}

Expand Down Expand Up @@ -158,7 +161,7 @@ func (pl paramList) DotParam() []*dot.Param {
//
// Variadic arguments of a constructor are ignored and not included as
// dependencies.
func newParamList(ctype reflect.Type) (paramList, error) {
func newParamList(ctype reflect.Type, names []string) (paramList, error) {
numArgs := ctype.NumIn()
if ctype.IsVariadic() {
// NOTE: If the function is variadic, we skip the last argument
Expand All @@ -171,8 +174,16 @@ func newParamList(ctype reflect.Type) (paramList, error) {
Params: make([]param, 0, numArgs),
}

if numArgs < len(names) {
return pl, errf("can't create a constructor with more names=%s than args=%s", names, ctype)
}

for i := 0; i < numArgs; i++ {
p, err := newParam(ctype.In(i))
name := ""
if i < len(names) {
name = names[i]
}
p, err := newParam(ctype.In(i), name)
if err != nil {
return pl, errf("bad argument %d", i+1, err)
}
Expand Down Expand Up @@ -370,7 +381,7 @@ func newParamObjectField(idx int, f reflect.StructField) (paramObjectField, erro

default:
var err error
p, err = newParam(f.Type)
p, err = newParam(f.Type, "")
if err != nil {
return pof, err
}
Expand Down
4 changes: 2 additions & 2 deletions param_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ import (
)

func TestParamListBuild(t *testing.T) {
p, err := newParamList(reflect.TypeOf(func() io.Writer { return nil }))
p, err := newParamList(reflect.TypeOf(func() io.Writer { return nil }), []string{})
require.NoError(t, err)
assert.Panics(t, func() {
p.Build(New())
Expand Down Expand Up @@ -238,7 +238,7 @@ func TestParamVisitorChecksEverything(t *testing.T) {

pl, err := newParamList(reflect.TypeOf(func(io.Reader, params, io.Writer) {
t.Fatalf("this function should not be called")
}))
}), []string{})
require.NoError(t, err)

idx := 0
Expand Down