From 2d9f8071dcc54e5346d0966d70bf50d1803ea359 Mon Sep 17 00:00:00 2001 From: Danny McCormick Date: Wed, 24 Aug 2022 14:30:01 -0400 Subject: [PATCH] Add combining state support (#22826) --- sdks/go/pkg/beam/core/graph/fn.go | 4 +- sdks/go/pkg/beam/core/graph/fn_test.go | 11 + .../pkg/beam/core/runtime/exec/translate.go | 20 +- .../pkg/beam/core/runtime/exec/userstate.go | 53 ++++- .../beam/core/runtime/exec/userstate_test.go | 4 +- .../pkg/beam/core/runtime/graphx/translate.go | 32 +++ sdks/go/pkg/beam/core/state/state.go | 159 +++++++++++++ sdks/go/pkg/beam/core/state/state_test.go | 224 +++++++++++++++++- 8 files changed, 485 insertions(+), 22 deletions(-) diff --git a/sdks/go/pkg/beam/core/graph/fn.go b/sdks/go/pkg/beam/core/graph/fn.go index 9a9f517862239..837c8a413776c 100644 --- a/sdks/go/pkg/beam/core/graph/fn.go +++ b/sdks/go/pkg/beam/core/graph/fn.go @@ -1274,10 +1274,10 @@ func validateState(fn *DoFn, numIn mainInputs) error { "unique per DoFn", k, orig, s) } t := s.StateType() - if t != state.TypeValue && t != state.TypeBag { + if t != state.TypeValue && t != state.TypeBag && t != state.TypeCombining { err := errors.Errorf("Unrecognized state type %v for state %v", t, s) return errors.SetTopLevelMsgf(err, "Unrecognized state type %v for state %v. Currently the only supported state"+ - "type is state.Value and state.Bag", t, s) + "types are state.Value, state.Combining, and state.Bag", t, s) } stateKeys[k] = s } diff --git a/sdks/go/pkg/beam/core/graph/fn_test.go b/sdks/go/pkg/beam/core/graph/fn_test.go index c2727298b0fa9..19647d88cbb3d 100644 --- a/sdks/go/pkg/beam/core/graph/fn_test.go +++ b/sdks/go/pkg/beam/core/graph/fn_test.go @@ -55,6 +55,9 @@ func TestNewDoFn(t *testing.T) { {dfn: &GoodDoFnCoGbk1wSide{}, opt: NumMainInputs(MainKv)}, {dfn: &GoodStatefulDoFn{State1: state.MakeValueState[int]("state1")}, opt: NumMainInputs(MainKv)}, {dfn: &GoodStatefulDoFn2{State1: state.MakeBagState[int]("state1")}, opt: NumMainInputs(MainKv)}, + {dfn: &GoodStatefulDoFn3{State1: state.MakeCombiningState[int, int, int]("state1", func(a, b int) int { + return a * b + })}, opt: NumMainInputs(MainKv)}, } for _, test := range tests { @@ -1096,6 +1099,14 @@ func (fn *GoodStatefulDoFn2) ProcessElement(state.Provider, int, int) int { return 0 } +type GoodStatefulDoFn3 struct { + State1 state.Combining[int, int, int] +} + +func (fn *GoodStatefulDoFn3) ProcessElement(state.Provider, int, int) int { + return 0 +} + // Examples of incorrect SDF signatures. // Examples with missing methods. diff --git a/sdks/go/pkg/beam/core/runtime/exec/translate.go b/sdks/go/pkg/beam/core/runtime/exec/translate.go index 65b0de9d48d11..fc4844010fa67 100644 --- a/sdks/go/pkg/beam/core/runtime/exec/translate.go +++ b/sdks/go/pkg/beam/core/runtime/exec/translate.go @@ -467,13 +467,29 @@ func (b *builder) makeLink(from string, id linkID) (Node, error) { if len(userState) > 0 { stateIDToCoder := make(map[string]*coder.Coder) + stateIDToCombineFn := make(map[string]*graph.CombineFn) for key, spec := range userState { - // TODO(#22736) - this will eventually need to be aware of which type of state its modifying to support non-Value state types. var cID string if rmw := spec.GetReadModifyWriteSpec(); rmw != nil { cID = rmw.CoderId } else if bs := spec.GetBagSpec(); bs != nil { cID = bs.ElementCoderId + } else if cs := spec.GetCombiningSpec(); cs != nil { + cID = cs.AccumulatorCoderId + cmbData := string(cs.GetCombineFn().GetPayload()) + var cmbTp v1pb.TransformPayload + if err := protox.DecodeBase64(cmbData, &cmbTp); err != nil { + return nil, errors.Wrapf(err, "invalid transform payload %v for %v", cmbData, transform) + } + _, fn, _, _, _, err := graphx.DecodeMultiEdge(cmbTp.GetEdge()) + if err != nil { + return nil, err + } + cfn, err := graph.AsCombineFn(fn) + if err != nil { + return nil, err + } + stateIDToCombineFn[key] = cfn } c, err := b.coders.Coder(cID) if err != nil { @@ -489,7 +505,7 @@ func (b *builder) makeLink(from string, id linkID) (Node, error) { if err != nil { return nil, err } - n.UState = NewUserStateAdapter(sid, coder.NewW(ec, wc), stateIDToCoder) + n.UState = NewUserStateAdapter(sid, coder.NewW(ec, wc), stateIDToCoder, stateIDToCombineFn) } } diff --git a/sdks/go/pkg/beam/core/runtime/exec/userstate.go b/sdks/go/pkg/beam/core/runtime/exec/userstate.go index 05105cda78f68..2587530c838a6 100644 --- a/sdks/go/pkg/beam/core/runtime/exec/userstate.go +++ b/sdks/go/pkg/beam/core/runtime/exec/userstate.go @@ -20,9 +20,11 @@ import ( "fmt" "io" + "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/coder" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/state" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex" + "github.com/apache/beam/sdks/v2/go/pkg/beam/core/util/reflectx" ) type stateProvider struct { @@ -39,6 +41,7 @@ type stateProvider struct { appendersByKey map[string]io.Writer clearersByKey map[string]io.Writer codersByKey map[string]*coder.Coder + combineFnsByKey map[string]*graph.CombineFn } // ReadValueState reads a value state from the State API @@ -159,6 +162,40 @@ func (s *stateProvider) WriteBagState(val state.Transaction) error { return nil } +func (s *stateProvider) CreateAccumulatorFn(userStateID string) reflectx.Func { + a := s.combineFnsByKey[userStateID] + if ca := a.CreateAccumulatorFn(); ca != nil { + return ca.Fn + } + return nil +} + +func (s *stateProvider) AddInputFn(userStateID string) reflectx.Func { + a := s.combineFnsByKey[userStateID] + if ai := a.AddInputFn(); ai != nil { + return ai.Fn + } + + return nil +} + +func (s *stateProvider) MergeAccumulatorsFn(userStateID string) reflectx.Func { + a := s.combineFnsByKey[userStateID] + if ma := a.MergeAccumulatorsFn(); ma != nil { + return ma.Fn + } + + return nil +} + +func (s *stateProvider) ExtractOutputFn(userStateID string) reflectx.Func { + a := s.combineFnsByKey[userStateID] + if eo := a.ExtractOutputFn(); eo != nil { + return eo.Fn + } + return nil +} + func (s *stateProvider) getReader(userStateID string) (io.ReadCloser, error) { if r, ok := s.readersByKey[userStateID]; ok { return r, nil @@ -201,16 +238,17 @@ type UserStateAdapter interface { } type userStateAdapter struct { - sid StreamID - wc WindowEncoder - kc ElementEncoder - stateIDToCoder map[string]*coder.Coder - c *coder.Coder + sid StreamID + wc WindowEncoder + kc ElementEncoder + stateIDToCoder map[string]*coder.Coder + stateIDToCombineFn map[string]*graph.CombineFn + c *coder.Coder } // NewUserStateAdapter returns a user state adapter for the given StreamID and coder. // It expects a W or W> coder, because the protocol requires windowing information. -func NewUserStateAdapter(sid StreamID, c *coder.Coder, stateIDToCoder map[string]*coder.Coder) UserStateAdapter { +func NewUserStateAdapter(sid StreamID, c *coder.Coder, stateIDToCoder map[string]*coder.Coder, stateIDToCombineFn map[string]*graph.CombineFn) UserStateAdapter { if !coder.IsW(c) { panic(fmt.Sprintf("expected WV coder for user state %v: %v", sid, c)) } @@ -220,7 +258,7 @@ func NewUserStateAdapter(sid StreamID, c *coder.Coder, stateIDToCoder map[string if coder.IsKV(coder.SkipW(c)) { kc = MakeElementEncoder(coder.SkipW(c).Components[0]) } - return &userStateAdapter{sid: sid, wc: wc, kc: kc, c: c, stateIDToCoder: stateIDToCoder} + return &userStateAdapter{sid: sid, wc: wc, kc: kc, c: c, stateIDToCoder: stateIDToCoder, stateIDToCombineFn: stateIDToCombineFn} } // NewStateProvider creates a stateProvider with the ability to talk to the state API. @@ -249,6 +287,7 @@ func (s *userStateAdapter) NewStateProvider(ctx context.Context, reader StateRea readersByKey: make(map[string]io.ReadCloser), appendersByKey: make(map[string]io.Writer), clearersByKey: make(map[string]io.Writer), + combineFnsByKey: s.stateIDToCombineFn, codersByKey: s.stateIDToCoder, } diff --git a/sdks/go/pkg/beam/core/runtime/exec/userstate_test.go b/sdks/go/pkg/beam/core/runtime/exec/userstate_test.go index a4091bc7d6580..7efb703b62266 100644 --- a/sdks/go/pkg/beam/core/runtime/exec/userstate_test.go +++ b/sdks/go/pkg/beam/core/runtime/exec/userstate_test.go @@ -22,6 +22,7 @@ import ( "reflect" "testing" + "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/coder" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/coderx" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/state" @@ -86,7 +87,8 @@ func buildStateProvider() stateProvider { readersByKey: make(map[string]io.ReadCloser), appendersByKey: make(map[string]io.Writer), clearersByKey: make(map[string]io.Writer), - codersByKey: make(map[string]*coder.Coder), // Each test can specify coders as needed + combineFnsByKey: make(map[string]*graph.CombineFn), // Each test can specify coders as needed + codersByKey: make(map[string]*coder.Coder), // Each test can specify coders as needed } } diff --git a/sdks/go/pkg/beam/core/runtime/graphx/translate.go b/sdks/go/pkg/beam/core/runtime/graphx/translate.go index 63972be645a85..3774bf71a2ce4 100644 --- a/sdks/go/pkg/beam/core/runtime/graphx/translate.go +++ b/sdks/go/pkg/beam/core/runtime/graphx/translate.go @@ -493,6 +493,38 @@ func (m *marshaller) addMultiEdge(edge NamedEdge) ([]string, error) { Urn: URNBagUserState, }, } + case state.TypeCombining: + cps := ps.(state.CombiningPipelineState).GetCombineFn() + f, err := graph.NewFn(cps) + if err != nil { + return handleErr(err) + } + cf, err := graph.AsCombineFn(f) + if err != nil { + return handleErr(err) + } + me := graph.MultiEdge{ + Op: graph.Combine, + CombineFn: cf, + } + mustEncodeMultiEdge, err := mustEncodeMultiEdgeBase64(&me) + if err != nil { + return handleErr(err) + } + stateSpecs[ps.StateKey()] = &pipepb.StateSpec{ + Spec: &pipepb.StateSpec_CombiningSpec{ + CombiningSpec: &pipepb.CombiningStateSpec{ + AccumulatorCoderId: coderID, + CombineFn: &pipepb.FunctionSpec{ + Urn: "beam:combinefn:gosdk:v1", + Payload: []byte(mustEncodeMultiEdge), + }, + }, + }, + Protocol: &pipepb.FunctionSpec{ + Urn: URNBagUserState, + }, + } default: return nil, errors.Errorf("State type %v not recognized for state %v", ps.StateKey(), ps) } diff --git a/sdks/go/pkg/beam/core/state/state.go b/sdks/go/pkg/beam/core/state/state.go index d655554e825b7..b9730a6405f4e 100644 --- a/sdks/go/pkg/beam/core/state/state.go +++ b/sdks/go/pkg/beam/core/state/state.go @@ -17,7 +17,10 @@ package state import ( + "fmt" "reflect" + + "github.com/apache/beam/sdks/v2/go/pkg/beam/core/util/reflectx" ) // TransactionTypeEnum represents the type of state transaction (e.g. set, clear) @@ -37,6 +40,8 @@ const ( TypeValue TypeEnum = 0 // TypeBag represents a bag state TypeBag TypeEnum = 1 + // TypeCombining represents a combining state + TypeCombining TypeEnum = 2 ) var ( @@ -63,6 +68,10 @@ type Provider interface { WriteValueState(val Transaction) error ReadBagState(id string) ([]interface{}, []Transaction, error) WriteBagState(val Transaction) error + CreateAccumulatorFn(userStateID string) reflectx.Func + AddInputFn(userStateID string) reflectx.Func + MergeAccumulatorsFn(userStateID string) reflectx.Func + ExtractOutputFn(userStateID string) reflectx.Func } // PipelineState is an interface representing different kinds of PipelineState (currently just state.Value). @@ -73,6 +82,12 @@ type PipelineState interface { StateType() TypeEnum } +// CombiningPipelineState is an interface representing combining pipeline state. +// It is primarily meant for Beam packages to use and is probably not useful for most pipeline authors. +type CombiningPipelineState interface { + GetCombineFn() interface{} +} + // Value is used to read and write global pipeline state representing a single value. // Key represents the key used to lookup this state. type Value[T any] struct { @@ -209,3 +224,147 @@ func MakeBagState[T any](k string) Bag[T] { Key: k, } } + +// Combining is used to read and write global pipeline state representing a single combined value. +// It uses 3 generic values, [T1, T2, T3], to represent the accumulator, input, and output types respectively. +// Key represents the key used to lookup this state. +type Combining[T1, T2, T3 any] struct { + Key string + accumFn interface{} +} + +// Add is used to write add an element to the combining pipeline state. +func (s *Combining[T1, T2, T3]) Add(p Provider, val T2) error { + // We will always maintain a single accumulated value as a value state. + // Therefore, when we add we must first read the current accumulator so that we can add to it. + acc, ok, err := s.readAccumulator(p) + if err != nil { + return err + } + if !ok { + // If no accumulator, that means that the CreateAccumulator function doesn't exist + // and our value is our initial accumulator. + return p.WriteValueState(Transaction{ + Key: s.Key, + Type: TransactionTypeSet, + Val: val, + }) + } + + if ai := p.AddInputFn(s.Key); ai != nil { + var newVal interface{} + if f, ok := ai.(reflectx.Func2x1); ok { + newVal = f.Call2x1(acc, val) + } else { + newVal = f.Call([]interface{}{acc, val})[0] + } + return p.WriteValueState(Transaction{ + Key: s.Key, + Type: TransactionTypeSet, + Val: newVal, + }) + } + // If AddInput isn't defined, that means we must just have one accumulator type identical to the input type. + if ma := p.MergeAccumulatorsFn(s.Key); ma != nil { + var newVal interface{} + if f, ok := ma.(reflectx.Func2x1); ok { + newVal = f.Call2x1(acc, val) + } else { + newVal = f.Call([]interface{}{acc, val})[0] + } + return p.WriteValueState(Transaction{ + Key: s.Key, + Type: TransactionTypeSet, + Val: newVal, + }) + } + + // Should be taken care of by previous validation + panic(fmt.Sprintf("MergeAccumulators must be defined on accumulator %v", s)) +} + +// Read is used to read this instance of global pipeline state representing a combiner. +// When a value is not found, returns an empty list and false. +func (s *Combining[T1, T2, T3]) Read(p Provider) (T3, bool, error) { + acc, ok, err := s.readAccumulator(p) + if !ok || err != nil { + var val T3 + return val, ok, err + } + + if eo := p.ExtractOutputFn(s.Key); eo != nil { + f, ok := eo.(reflectx.Func1x1) + if ok { + return f.Call1x1(acc).(T3), true, nil + } + return f.Call([]interface{}{acc})[0].(T3), true, nil + } + + return acc.(T3), true, nil +} + +func (s *Combining[T1, T2, T3]) readAccumulator(p Provider) (interface{}, bool, error) { + // This replays any writes that have happened to this value since we last read + // For more detail, see "State Transactionality" below for buffered transactions + cur, bufferedTransactions, err := p.ReadValueState(s.Key) + if err != nil { + var val T1 + return val, false, err + } + for _, t := range bufferedTransactions { + switch t.Type { + case TransactionTypeSet: + cur = t.Val + case TransactionTypeClear: + cur = nil + } + } + if cur == nil { + if ca := p.CreateAccumulatorFn(s.Key); ca != nil { + f, ok := ca.(reflectx.Func0x1) + if ok { + return f.Call0x1(), true, nil + } + return f.Call([]interface{}{})[0], true, nil + } + var val T1 + return val, false, nil + } + + return cur, true, nil +} + +// StateKey returns the key for this pipeline state entry. +func (s Combining[T1, T2, T3]) StateKey() string { + if s.Key == "" { + // TODO(#22736) - infer the state from the member variable name during pipeline construction. + panic("Value state exists on struct but has not been initialized with a key.") + } + return s.Key +} + +// CoderType returns the type of the bag state which should be used for a coder. +func (s Combining[T1, T2, T3]) CoderType() reflect.Type { + var t T1 + return reflect.TypeOf(t) +} + +// StateType returns the type of the state (in this case always Bag). +func (s Combining[T1, T2, T3]) StateType() TypeEnum { + return TypeCombining +} + +func (s Combining[T1, T2, T3]) GetCombineFn() interface{} { + return s.accumFn +} + +// MakeCombiningState is a factory function to create an instance of Combining state with the given key and combiner +// when the combiner may have different types for its accumulator, input, and output. +// Takes 3 generic constraints [T1, T2, T3 any] representing the accumulator/input/output types respectively. +// If no accumulator or output types are defined, use the input type. +func MakeCombiningState[T1, T2, T3 any](k string, combiner interface{}) Combining[T1, T2, T3] { + return Combining[T1, T2, T3]{ + Key: k, + accumFn: combiner, + } +} diff --git a/sdks/go/pkg/beam/core/state/state_test.go b/sdks/go/pkg/beam/core/state/state_test.go index c8924560ad6c5..9ee99d790cf68 100644 --- a/sdks/go/pkg/beam/core/state/state_test.go +++ b/sdks/go/pkg/beam/core/state/state_test.go @@ -18,6 +18,8 @@ package state import ( "errors" "testing" + + "github.com/apache/beam/sdks/v2/go/pkg/beam/core/util/reflectx" ) var ( @@ -25,10 +27,14 @@ var ( ) type fakeProvider struct { - initialState map[string]interface{} - initialBagState map[string][]interface{} - transactions map[string][]Transaction - err map[string]error + initialState map[string]interface{} + initialBagState map[string][]interface{} + transactions map[string][]Transaction + err map[string]error + createAccumForKey map[string]bool + addInputForKey map[string]bool + mergeAccumForKey map[string]bool + extractOutForKey map[string]bool } func (s *fakeProvider) ReadValueState(userStateID string) (interface{}, []Transaction, error) { @@ -73,6 +79,43 @@ func (s *fakeProvider) WriteBagState(val Transaction) error { return nil } +func (s *fakeProvider) CreateAccumulatorFn(userStateID string) reflectx.Func { + if s.createAccumForKey[userStateID] { + return reflectx.MakeFunc0x1(func() int { + return 1 + }) + } + + return nil +} +func (s *fakeProvider) AddInputFn(userStateID string) reflectx.Func { + if s.addInputForKey[userStateID] { + return reflectx.MakeFunc2x1(func(a, b int) int { + return a + b + }) + } + + return nil +} +func (s *fakeProvider) MergeAccumulatorsFn(userStateID string) reflectx.Func { + if s.mergeAccumForKey[userStateID] { + return reflectx.MakeFunc2x1(func(a, b int) int { + return a + b + }) + } + + return nil +} +func (s *fakeProvider) ExtractOutputFn(userStateID string) reflectx.Func { + if s.extractOutForKey[userStateID] { + return reflectx.MakeFunc1x1(func(a int) int { + return a * 100 + }) + } + + return nil +} + func TestValueRead(t *testing.T) { is := make(map[string]interface{}) ts := make(map[string][]Transaction) @@ -228,7 +271,7 @@ func TestBagRead(t *testing.T) { } } -func TestBagWrite(t *testing.T) { +func TestBagAdd(t *testing.T) { var tests = []struct { writes []int val []int @@ -251,13 +294,13 @@ func TestBagWrite(t *testing.T) { } val, ok, err := vs.Read(&f) if err != nil { - t.Errorf("Bag.Write() returned error %v when it shouldn't have after writing: %v", err, tt.writes) + t.Errorf("Bag.Read() returned error %v when it shouldn't have after writing: %v", err, tt.writes) } else if ok && !tt.ok { - t.Errorf("Bag.Write() returned a value %v when it shouldn't have after writing: %v", val, tt.writes) + t.Errorf("Bag.Read() returned a value %v when it shouldn't have after writing: %v", val, tt.writes) } else if !ok && tt.ok { - t.Errorf("Bag.Write() didn't return a value when it should have returned %v after writing: %v", tt.val, tt.writes) + t.Errorf("Bag.Red() didn't return a value when it should have returned %v after writing: %v", tt.val, tt.writes) } else if len(val) != len(tt.val) { - t.Errorf("Bag.Write()=%v, want %v after writing: %v", val, tt.val, tt.writes) + t.Errorf("Bag.Read()=%v, want %v after writing: %v", val, tt.val, tt.writes) } else { eq := true for idx, v := range val { @@ -266,8 +309,169 @@ func TestBagWrite(t *testing.T) { } } if !eq { - t.Errorf("Bag.Write()=%v, want %v after writing: %v", val, tt.val, tt.writes) + t.Errorf("Bag.Read()=%v, want %v after writing: %v", val, tt.val, tt.writes) } } } } + +func TestCombiningRead(t *testing.T) { + is := make(map[string]interface{}) + ts := make(map[string][]Transaction) + es := make(map[string]error) + ca := make(map[string]bool) + eo := make(map[string]bool) + ts["no_transactions"] = nil + ts["no_transactions_initial_accum"] = nil + ca["no_transactions_initial_accum"] = true + ts["no_transactions_initial_accum_extract_out"] = nil + ca["no_transactions_initial_accum_extract_out"] = true + eo["no_transactions_initial_accum_extract_out"] = true + is["basic_set"] = 1 + ts["basic_set"] = []Transaction{{Key: "basic_set", Type: TransactionTypeSet, Val: 3}} + is["basic_clear"] = 1 + ts["basic_clear"] = []Transaction{{Key: "basic_clear", Type: TransactionTypeClear, Val: nil}} + is["set_then_clear"] = 1 + ts["set_then_clear"] = []Transaction{{Key: "set_then_clear", Type: TransactionTypeSet, Val: 3}, {Key: "set_then_clear", Type: TransactionTypeClear, Val: nil}} + is["set_then_clear_then_set"] = 1 + ts["set_then_clear_then_set"] = []Transaction{{Key: "set_then_clear_then_set", Type: TransactionTypeSet, Val: 3}, {Key: "set_then_clear_then_set", Type: TransactionTypeClear, Val: nil}, {Key: "set_then_clear_then_set", Type: TransactionTypeSet, Val: 4}} + is["err"] = 1 + ts["err"] = []Transaction{{Key: "err", Type: TransactionTypeSet, Val: 3}} + es["err"] = errFake + + f := fakeProvider{ + initialState: is, + transactions: ts, + err: es, + createAccumForKey: ca, + extractOutForKey: eo, + } + + var tests = []struct { + vs Combining[int, int, int] + val int + ok bool + err error + }{ + {MakeCombiningState[int, int, int]("no_transactions", func(a, b int) int { + return a + b + }), 0, false, nil}, + {MakeCombiningState[int, int, int]("no_transactions_initial_accum", func(a, b int) int { + return a + b + }), 1, true, nil}, + {MakeCombiningState[int, int, int]("no_transactions_initial_accum_extract_out", func(a, b int) int { + return a + b + }), 100, true, nil}, + {MakeCombiningState[int, int, int]("basic_set", func(a, b int) int { + return a + b + }), 3, true, nil}, + {MakeCombiningState[int, int, int]("basic_clear", func(a, b int) int { + return a + b + }), 0, false, nil}, + {MakeCombiningState[int, int, int]("set_then_clear", func(a, b int) int { + return a + b + }), 0, false, nil}, + {MakeCombiningState[int, int, int]("set_then_clear_then_set", func(a, b int) int { + return a + b + }), 4, true, nil}, + {MakeCombiningState[int, int, int]("err", func(a, b int) int { + return a + b + }), 0, false, errFake}, + } + + for _, tt := range tests { + val, ok, err := tt.vs.Read(&f) + if err != nil && tt.err == nil { + t.Errorf("Combining.Read() returned error %v for state key %v when it shouldn't have", err, tt.vs.Key) + } else if err == nil && tt.err != nil { + t.Errorf("Combining.Read() returned no error for state key %v when it should have returned %v", tt.vs.Key, err) + } else if ok && !tt.ok { + t.Errorf("Combining.Read() returned a value %v for state key %v when it shouldn't have", val, tt.vs.Key) + } else if !ok && tt.ok { + t.Errorf("Combining.Read() didn't return a value for state key %v when it should have returned %v", tt.vs.Key, tt.val) + } else if val != tt.val { + t.Errorf("Combining.Read()=%v, want %v for state key %v", val, tt.val, tt.vs.Key) + } + } +} + +func TestCombiningAdd(t *testing.T) { + var tests = []struct { + vs Combining[int, int, int] + writes []int + val int + ok bool + }{ + {MakeCombiningState[int, int, int]("no_transactions", func(a, b int) int { + return a + b + }), []int{}, 0, false}, + {MakeCombiningState[int, int, int]("no_transactions", func(a, b int) int { + return a + b + }), []int{2}, 2, true}, + {MakeCombiningState[int, int, int]("no_transactions", func(a, b int) int { + return a + b + }), []int{7, 8, 9}, 24, true}, + {MakeCombiningState[int, int, int]("no_transactions_initial_accum", func(a, b int) int { + return a + b + }), []int{}, 1, true}, + {MakeCombiningState[int, int, int]("no_transactions_initial_accum", func(a, b int) int { + return a + b + }), []int{1}, 2, true}, + {MakeCombiningState[int, int, int]("no_transactions_initial_accum", func(a, b int) int { + return a + b + }), []int{3, 4}, 8, true}, + {MakeCombiningState[int, int, int]("no_transactions_initial_accum_extract_out", func(a, b int) int { + return a + b + }), []int{}, 100, true}, + {MakeCombiningState[int, int, int]("no_transactions_initial_accum_extract_out", func(a, b int) int { + return a + b + }), []int{1}, 200, true}, + {MakeCombiningState[int, int, int]("no_transactions_initial_accum_extract_out", func(a, b int) int { + return a + b + }), []int{1, 2}, 400, true}, + } + + for _, tt := range tests { + is := make(map[string]interface{}) + ts := make(map[string][]Transaction) + es := make(map[string]error) + ca := make(map[string]bool) + eo := make(map[string]bool) + ma := make(map[string]bool) + ai := make(map[string]bool) + ts["no_transactions"] = nil + ma["no_transactions"] = true + ts["no_transactions_initial_accum"] = nil + ca["no_transactions_initial_accum"] = true + ma["no_transactions_initial_accum"] = true + ts["no_transactions_initial_accum_extract_out"] = nil + ca["no_transactions_initial_accum_extract_out"] = true + eo["no_transactions_initial_accum_extract_out"] = true + ai["no_transactions_initial_accum_extract_out"] = true + + f := fakeProvider{ + initialState: is, + transactions: ts, + err: es, + createAccumForKey: ca, + extractOutForKey: eo, + mergeAccumForKey: ma, + addInputForKey: ai, + } + + for _, val := range tt.writes { + tt.vs.Add(&f, val) + } + + val, ok, err := tt.vs.Read(&f) + if err != nil { + t.Errorf("Bag.Read() returned error %v when it shouldn't have after writing: %v", err, tt.writes) + } else if ok && !tt.ok { + t.Errorf("Bag.Read() returned a value %v when it shouldn't have after writing: %v", val, tt.writes) + } else if !ok && tt.ok { + t.Errorf("Bag.Read() didn't return a value when it should have returned %v after writing: %v", tt.val, tt.writes) + } else if val != tt.val { + t.Errorf("Bag.Read()=%v, want %v after writing: %v", val, tt.val, tt.writes) + } + } +}