From 96e0e86564a82acaefe09ecefef03e9ccf5c4df4 Mon Sep 17 00:00:00 2001 From: David Muir Sharnoff Date: Mon, 6 Mar 2023 20:52:35 -0800 Subject: [PATCH] feat: limited support for running inner() in parallel --- api.go | 13 +++++++++++ doc.go | 44 ++++++++++++++++++----------------- generate.go | 20 +++++++++++----- matrix_test.go | 63 ++++++++++++++++++++++++++++++++++++++++++++++++++ nject.go | 2 ++ 5 files changed, 115 insertions(+), 27 deletions(-) create mode 100644 matrix_test.go diff --git a/api.go b/api.go index 565a4a4..b4d68fb 100644 --- a/api.go +++ b/api.go @@ -300,6 +300,19 @@ func callsInner(fn interface{}) Provider { }) } +// Parallel annotates a wrap function to indicate that +// the inner function may be invoked in parallel. +// +// At the current time, support for this is very +// limited. Returned values cannot be propagated +// across such a call and the resulting lack of +// initialization can cause a panic. +func Parallel(fn interface{}) Provider { + return newThing(fn).modify(func(fm *provider) { + fm.parallel = true + }) +} + // TODO: add ExampleLoose // Loose annotates a wrap function to indicate that when trying diff --git a/doc.go b/doc.go index c0cd9c0..34185df 100644 --- a/doc.go +++ b/doc.go @@ -1,14 +1,13 @@ // Obligatory // comment /* - Package nject is a general purpose dependency injection framework. It provides wrapping, pruning, and indirect variable passing. It is type safe and using it requires no type assertions. There are two main injection APIs: Run and Bind. Bind is designed to be used at program initialization and does as much work as possible then rather than during main execution. -List of providers +# List of providers The API for nject is a list of providers (injectors) that are run in order. The final function in the list must be called. The other functions are called @@ -32,7 +31,7 @@ is a simple example: In this example, context.Background and log.Default are not invoked because their outputs are not used by the final function (http.ListenAndServe). -How to use +# How to use The basic idea of nject is to assemble a Collection of providers and then use that collection to supply inputs for functions that may use some or all of @@ -81,7 +80,7 @@ is 1st2nd -Collections +# Collections Providers are grouped as into linear sequences. When building an injection chain, the providers are grouped into several sets: LITERAL, STATIC, RUN. The LITERAL @@ -100,7 +99,7 @@ The STATIC set is composed of the cacheable injectors. The RUN set if everything else. -Injectors +# Injectors All injectors have the following type signature: @@ -114,7 +113,7 @@ are dropped from the handler chain. They are not invoked. Injectors that have no output values are a special case and they are always retained in the handler chain. -Cached injectors +# Cached injectors In injector that is annotated as Cacheable() may promoted to the STATIC set. An injector that is annotated as MustCache() must be promoted to @@ -139,7 +138,7 @@ is injected, all chains will share the same pointer. return &j })) -Memoized injectors +# Memoized injectors Injectors in the STATIC set are only run for initialization. For some things, like opening a database, that may still be too often. Injectors that are marked @@ -155,7 +154,7 @@ Memoized injectors may not have any inputs that are go maps, slices, or function Arrays, structs, and interfaces are okay. This requirement is recursive so a struct that that has a slice in it is not okay. -Fallible injectors +# Fallible injectors Fallible injectors are special injectors that change the behavior of the injection chain if they return error. Fallible injectors in the RUN set, that return error @@ -208,7 +207,7 @@ Some examples: return nil } -Wrap functions and middleware +# Wrap functions and middleware A wrap function interrupts the linear sequence of providers. It may or may invoke the remainder of the sequence that comes after it. The remainder of @@ -248,7 +247,11 @@ other kinds of functions: one call to reflect.MakeFunc(). Wrap functions serve the same role as middleware, but are usually easier to write. -Final functions +Wrap functions that invoke inner() multiple times in parallel are +are not well supported at this time and such invocations must have +the wrap function decorated with Parallel(). + +# Final functions Final functions are simply the last provider in the chain. They look like regular Go functions. Their input parameters come @@ -280,11 +283,11 @@ because they internally control if the downstream chain is called. return nil } -Literal values +# Literal values Literal values are values in the provider chain that are not functions. -Invalid provider chains +# Invalid provider chains Provider chains can be invalid for many reasons: inputs of a type not provided earlier in the chain; annotations that cannot be honored @@ -293,7 +296,7 @@ functions that take or return functions with an anymous type other than wrapper functions; A chain that does not terminate with a function; etc. Bind() and Run() will return error when presented with an invalid provider chain. -Panics +# Panics Bind() and Run() will return error rather than panic. After Bind()ing an init and invoke function, calling them will not panic unless a provider @@ -322,7 +325,7 @@ can be added with Shun(). var ErrorOfLastResort = nject.Shun(func() error { return nil }) -Chain evaluation +# Chain evaluation Bind() uses a complex and somewhat expensive O(n^2) set of rules to evaluate which providers should be included in a chain and which can be dropped. The goal @@ -350,14 +353,14 @@ from the closest provider. Providers that have unmet dependencies will be eliminated from the chain unless they're Required. -Best practices +# Best practices The remainder of this document consists of suggestions for how to use nject. Contributions to this section would be welcome. Also links to blogs or other discussions of using nject in practice. -For tests +# For tests The best practice for using nject inside a large project is to have a few common chains that everyone imports. @@ -398,7 +401,7 @@ to write tests. }) } -Displaying errors +# Displaying errors If nject cannot bind or run a chain, it will return error. The returned error is generally very good, but it does not contain the full debugging @@ -417,7 +420,7 @@ Remove the comments to hide the original type names. log.Fatal(err) } -Reorder +# Reorder The Reorder() decorator allows injection chains to be fully or partially reordered. Reorder is currently limited to a single pass and does not know which injectors are @@ -459,7 +462,7 @@ and used. OverrideThingOptions(thing.Option1, thing.Option2), ) -Self-cleaning +# Self-cleaning Recommended best practice is to have injectors shutdown the things they themselves start. They should do their own cleanup. @@ -500,7 +503,7 @@ defines. Wrapper functions have a small runtime performance penalty, so if you have more than a couple of providers that need cleanup, it makes sense to include something like CleaningService. -Forcing inclusion +# Forcing inclusion The normal direction of forced inclusion is that an upstream provider is required because a downstream provider uses a type produced by the upstream provider. @@ -514,6 +517,5 @@ produce a type that is only consumed by the downstream provider. Lastly, the providers can be grouped with Cluster so that they'll be included or excluded as a group. - */ package nject diff --git a/generate.go b/generate.go index 5114de1..201800e 100644 --- a/generate.go +++ b/generate.go @@ -3,6 +3,7 @@ package nject import ( "fmt" "reflect" + "sync/atomic" ) type valueCollection []reflect.Value @@ -197,7 +198,7 @@ func generateWrappers( in0Type, reflective := getInZero(fv) fm.wrapWrapper = func(v valueCollection, next func(valueCollection)) { vCopy := v.Copy() - callCount := 0 + var callCount int32 rTypes := make([]reflect.Type, len(fm.flows[receivedParams])) for i, tc := range fm.flows[receivedParams] { @@ -206,12 +207,19 @@ func generateWrappers( // for thread safety, this is not built outside WrapWrapper inner := func(i []reflect.Value) []reflect.Value { - if callCount > 0 { - copy(v, vCopy) + if !fm.parallel { + callCount++ + if callCount > 1 { + v = vCopy.Copy() + } + outMap(v, i) + next(v) + } else { + atomic.AddInt32(&callCount, 1) + vc := vCopy.Copy() + outMap(vc, i) + next(vc) } - callCount++ - outMap(v, i) - next(v) r := retMap(v) for i, retV := range r { if rTypes[i].Kind() == reflect.Interface { diff --git a/matrix_test.go b/matrix_test.go new file mode 100644 index 0000000..1cf8217 --- /dev/null +++ b/matrix_test.go @@ -0,0 +1,63 @@ +package nject_test + +import ( + "testing" + + "github.com/muir/nject" + "github.com/stretchr/testify/assert" +) + +type PT01 string +type PT02 string +type PT03 string +type PT04 string + +func TestParallelCallsToInner(t *testing.T) { + t.Parallel() + assert.NoError(t, nject.Run(t.Name(), + t, + nject.Parallel(func(inner func(*testing.T, PT01), t *testing.T) { + for _, s := range []PT01{"A1", "A2", "A3", "A4"} { + s := s + t.Run(string(s), func(t *testing.T) { + t.Log("branching") + t.Parallel() + inner(t, s) + }) + } + }), + nject.Parallel(func(inner func(*testing.T, PT02), t *testing.T) { + for _, s := range []PT02{"B1", "B2", "B3", "B4"} { + s := s + t.Run(string(s), func(t *testing.T) { + t.Log("branching") + t.Parallel() + inner(t, s) + }) + } + }), + nject.Parallel(func(inner func(*testing.T, PT03), t *testing.T) { + for _, s := range []PT03{"C1", "C2", "C3", "C4"} { + s := s + t.Run(string(s), func(t *testing.T) { + t.Log("branching") + t.Parallel() + inner(t, s) + }) + } + }), + nject.Parallel(func(inner func(*testing.T, PT04), t *testing.T) { + for _, s := range []PT04{"D1", "D2", "D3", "D4"} { + s := s + t.Run(string(s), func(t *testing.T) { + t.Log("branching") + t.Parallel() + inner(t, s) + }) + } + }), + func(t *testing.T, a PT01, b PT02, c PT03, d PT04) { + assert.Equal(t, t.Name(), "TestParallelCallsToInner/"+string(a)+"/"+string(b)+"/"+string(c)+"/"+string(d)) + }, + )) +} diff --git a/nject.go b/nject.go index 4e483b2..ae04d76 100644 --- a/nject.go +++ b/nject.go @@ -34,6 +34,7 @@ type provider struct { consumptionOptional bool singleton bool cluster int32 + parallel bool // added by characterize memoized bool @@ -94,6 +95,7 @@ func (fm *provider) copy() *provider { consumptionOptional: fm.consumptionOptional, singleton: fm.singleton, cluster: fm.cluster, + parallel: fm.parallel, memoized: fm.memoized, class: fm.class, group: fm.group,