diff --git a/sdks/go/pkg/beam/core/runtime/exec/userstate.go b/sdks/go/pkg/beam/core/runtime/exec/userstate.go index a980d2bb3bbf8..e089d843f67f2 100644 --- a/sdks/go/pkg/beam/core/runtime/exec/userstate.go +++ b/sdks/go/pkg/beam/core/runtime/exec/userstate.go @@ -97,15 +97,22 @@ func (s *stateProvider) WriteValueState(val state.Transaction) error { return err } - // TODO(#22736) - optimize this a bit once all state types are added. In the case of sets/clears, - // we can remove the transactions. We can also consider combining other transactions on read (or sooner) - // so that we don't need to use as much memory/time replaying transactions. - if transactions, ok := s.transactionsByKey[val.Key]; ok { - transactions = append(transactions, val) - s.transactionsByKey[val.Key] = transactions - } else { - s.transactionsByKey[val.Key] = []state.Transaction{val} + // Any transactions before a set don't matter + s.transactionsByKey[val.Key] = []state.Transaction{val} + + return nil +} + +// ClearValueState clears a value state from the State API. +func (s *stateProvider) ClearValueState(val state.Transaction) error { + cl, err := s.getBagClearer(val.Key) + if err != nil { + return err } + cl.Write([]byte{}) + + // Any transactions before a clear don't matter + s.transactionsByKey[val.Key] = []state.Transaction{val} return nil } @@ -140,6 +147,20 @@ func (s *stateProvider) ReadBagState(userStateID string) ([]interface{}, []state return initialValue, transactions, nil } +// ClearBagState clears a bag state from the State API +func (s *stateProvider) ClearBagState(val state.Transaction) error { + cl, err := s.getBagClearer(val.Key) + if err != nil { + return err + } + cl.Write([]byte{}) + + // Any transactions before a clear don't matter + s.transactionsByKey[val.Key] = []state.Transaction{val} + + return nil +} + // WriteBagState writes a bag state to the State API func (s *stateProvider) WriteBagState(val state.Transaction) error { ap, err := s.getBagAppender(val.Key) diff --git a/sdks/go/pkg/beam/core/state/state.go b/sdks/go/pkg/beam/core/state/state.go index c3b5fe7f83aea..70050598cca58 100644 --- a/sdks/go/pkg/beam/core/state/state.go +++ b/sdks/go/pkg/beam/core/state/state.go @@ -71,8 +71,10 @@ type Transaction struct { type Provider interface { ReadValueState(id string) (interface{}, []Transaction, error) WriteValueState(val Transaction) error + ClearValueState(val Transaction) error ReadBagState(id string) ([]interface{}, []Transaction, error) WriteBagState(val Transaction) error + ClearBagState(val Transaction) error CreateAccumulatorFn(userStateID string) reflectx.Func AddInputFn(userStateID string) reflectx.Func MergeAccumulatorsFn(userStateID string) reflectx.Func @@ -137,6 +139,14 @@ func (s *Value[T]) Read(p Provider) (T, bool, error) { return cur.(T), true, nil } +// Clear is used to clear this instance of global pipeline state representing a single value. +func (s *Value[T]) Clear(p Provider) error { + return p.ClearValueState(Transaction{ + Key: s.Key, + Type: TransactionTypeClear, + }) +} + // StateKey returns the key for this pipeline state entry. func (s Value[T]) StateKey() string { if s.Key == "" { @@ -212,6 +222,14 @@ func (s *Bag[T]) Read(p Provider) ([]T, bool, error) { return cur, true, nil } +// Clear is used to clear this instance of global pipeline state representing a bag. +func (s *Bag[T]) Clear(p Provider) error { + return p.ClearBagState(Transaction{ + Key: s.Key, + Type: TransactionTypeClear, + }) +} + // StateKey returns the key for this pipeline state entry. func (s Bag[T]) StateKey() string { if s.Key == "" { @@ -322,6 +340,14 @@ func (s *Combining[T1, T2, T3]) Read(p Provider) (T3, bool, error) { return acc.(T3), true, nil } +// Clear is used to clear this instance of global pipeline state representing a combiner. +func (s *Combining[T1, T2, T3]) Clear(p Provider) error { + return p.ClearValueState(Transaction{ + Key: s.Key, + Type: TransactionTypeClear, + }) +} + func (s *Combining[T1, T2, T3]) readAccumulator(p Provider) (interface{}, bool, error) { // This replays any writes that have happened to this value since we last read // For more detail, see "State Transactionality" below for buffered transactions diff --git a/sdks/go/pkg/beam/core/state/state_test.go b/sdks/go/pkg/beam/core/state/state_test.go index e964c9647c511..555e0adbe7b21 100644 --- a/sdks/go/pkg/beam/core/state/state_test.go +++ b/sdks/go/pkg/beam/core/state/state_test.go @@ -51,11 +51,12 @@ func (s *fakeProvider) ReadValueState(userStateID string) (interface{}, []Transa } func (s *fakeProvider) WriteValueState(val Transaction) error { - if transactions, ok := s.transactions[val.Key]; ok { - s.transactions[val.Key] = append(transactions, val) - } else { - s.transactions[val.Key] = []Transaction{val} - } + s.transactions[val.Key] = []Transaction{val} + return nil +} + +func (s *fakeProvider) ClearValueState(val Transaction) error { + s.transactions[val.Key] = []Transaction{val} return nil } @@ -80,6 +81,11 @@ func (s *fakeProvider) WriteBagState(val Transaction) error { return nil } +func (s *fakeProvider) ClearBagState(val Transaction) error { + s.transactions[val.Key] = []Transaction{val} + return nil +} + func (s *fakeProvider) CreateAccumulatorFn(userStateID string) reflectx.Func { if s.createAccumForKey[userStateID] { return reflectx.MakeFunc0x1(func() int { @@ -245,6 +251,44 @@ func TestValueWrite(t *testing.T) { } } +func TestValueClear(t *testing.T) { + var tests = []struct { + writes []int + clears int + }{ + {[]int{}, 1}, + {[]int{3}, 1}, + {[]int{1, 5}, 1}, + {[]int{}, 2}, + {[]int{3}, 2}, + {[]int{1, 5}, 2}, + } + + for _, tt := range tests { + f := fakeProvider{ + initialState: make(map[string]interface{}), + transactions: make(map[string][]Transaction), + err: make(map[string]error), + } + vs := MakeValueState[int]("vs") + for _, val := range tt.writes { + vs.Write(&f, val) + } + for i := 0; i < tt.clears; i++ { + err := vs.Clear(&f) + if err != nil { + t.Errorf("Value.Clear() attempt %v returned error %v", i, err) + } + } + _, ok, err := vs.Read(&f) + if err != nil { + t.Errorf("Value.Read() returned error %v when it shouldn't have after writing: %v", err, tt.writes) + } else if ok { + t.Errorf("Value.Read() returned a value when it shouldn't have after writing %v and performing %v clears", tt.writes, tt.clears) + } + } +} + func TestBagRead(t *testing.T) { is := make(map[string][]interface{}) ts := make(map[string][]Transaction) @@ -356,6 +400,44 @@ func TestBagAdd(t *testing.T) { } } +func TestBagClear(t *testing.T) { + var tests = []struct { + writes []int + clears int + }{ + {[]int{}, 1}, + {[]int{3}, 1}, + {[]int{1, 5}, 1}, + {[]int{}, 2}, + {[]int{3}, 2}, + {[]int{1, 5}, 2}, + } + + for _, tt := range tests { + f := fakeProvider{ + initialState: make(map[string]interface{}), + transactions: make(map[string][]Transaction), + err: make(map[string]error), + } + vs := MakeBagState[int]("vs") + for _, val := range tt.writes { + vs.Add(&f, val) + } + for i := 0; i < tt.clears; i++ { + err := vs.Clear(&f) + if err != nil { + t.Errorf("Bag.Clear() attempt %v returned error %v", i, err) + } + } + _, ok, err := vs.Read(&f) + if err != nil { + t.Errorf("Bag.Read() returned error %v when it shouldn't have after writing: %v", err, tt.writes) + } else if ok { + t.Errorf("Bag.Read() returned a value when it shouldn't have after writing %v and performing %v clears", tt.writes, tt.clears) + } + } +} + func TestCombiningRead(t *testing.T) { is := make(map[string]interface{}) ts := make(map[string][]Transaction) @@ -436,6 +518,80 @@ func TestCombiningRead(t *testing.T) { } } +func TestCombiningClear(t *testing.T) { + var tests = []struct { + vs Combining[int, int, int] + writes []int + val int + clears int + ok bool + }{ + {MakeCombiningState[int, int, int]("no_transactions", func(a, b int) int { + return a + b + }), []int{}, 0, 1, false}, + {MakeCombiningState[int, int, int]("no_transactions", func(a, b int) int { + return a + b + }), []int{2}, 0, 1, false}, + {MakeCombiningState[int, int, int]("no_transactions", func(a, b int) int { + return a + b + }), []int{7, 8, 9}, 0, 1, false}, + {MakeCombiningState[int, int, int]("no_transactions_initial_accum", func(a, b int) int { + return a + b + }), []int{}, 1, 1, true}, + {MakeCombiningState[int, int, int]("no_transactions_initial_accum", func(a, b int) int { + return a + b + }), []int{1}, 1, 1, true}, + {MakeCombiningState[int, int, int]("no_transactions_initial_accum", func(a, b int) int { + return a + b + }), []int{3, 4}, 1, 1, true}, + } + + for _, tt := range tests { + is := make(map[string]interface{}) + ts := make(map[string][]Transaction) + es := make(map[string]error) + ca := make(map[string]bool) + eo := make(map[string]bool) + ma := make(map[string]bool) + ai := make(map[string]bool) + ts["no_transactions"] = nil + ma["no_transactions"] = true + ts["no_transactions_initial_accum"] = nil + ca["no_transactions_initial_accum"] = true + ma["no_transactions_initial_accum"] = true + + f := fakeProvider{ + initialState: is, + transactions: ts, + err: es, + createAccumForKey: ca, + extractOutForKey: eo, + mergeAccumForKey: ma, + addInputForKey: ai, + } + + for _, val := range tt.writes { + tt.vs.Add(&f, val) + } + for i := 0; i < tt.clears; i++ { + err := tt.vs.Clear(&f) + if err != nil { + t.Errorf("Combining.Clear() attempt %v returned error %v", i, err) + } + } + val, ok, err := tt.vs.Read(&f) + if err != nil { + t.Errorf("Combining.Read() returned error %v when it shouldn't have after writing %v and performing %v clears for key %v", err, tt.writes, tt.clears, tt.vs.StateKey()) + } else if ok && !tt.ok { + t.Errorf("Combining.Read() returned a value when it shouldn't have after writing %v and performing %v clears for key %v", tt.writes, tt.clears, tt.vs.StateKey()) + } else if !ok && tt.ok { + t.Errorf("Combining.Read() returned no value when it should have returned %v after writing %v and performing %v clears for key %v", tt.val, tt.writes, tt.clears, tt.vs.StateKey()) + } else if tt.val != val { + t.Errorf("Combining.Read()=%v, want %v after writing %v and performing %v clears for key %v", val, tt.val, tt.writes, tt.clears, tt.vs.StateKey()) + } + } +} + func TestCombiningAdd(t *testing.T) { var tests = []struct { vs Combining[int, int, int] diff --git a/sdks/go/test/integration/integration.go b/sdks/go/test/integration/integration.go index 76fff3b968725..504cf3e58f2d7 100644 --- a/sdks/go/test/integration/integration.go +++ b/sdks/go/test/integration/integration.go @@ -91,7 +91,9 @@ var directFilters = []string{ "TestOomParDo", // The direct runner does not support user state. "TestValueState", + "TestValueState_Clear", "TestBagState", + "TestBagState_Clear", "TestCombiningState", "TestMapState", } @@ -116,7 +118,9 @@ var portableFilters = []string{ "TestOomParDo", // The portable runner does not support user state. "TestValueState", + "TestValueState_Clear", "TestBagState", + "TestBagState_Clear", "TestCombiningState", "TestMapState", } @@ -162,7 +166,9 @@ var samzaFilters = []string{ "TestOomParDo", // The samza runner does not support user state. "TestValueState", + "TestValueState_Clear", "TestBagState", + "TestBagState_Clear", "TestCombiningState", "TestMapState", } diff --git a/sdks/go/test/integration/primitives/state.go b/sdks/go/test/integration/primitives/state.go index 605c03bcf3b94..dd6a0ccf49b25 100644 --- a/sdks/go/test/integration/primitives/state.go +++ b/sdks/go/test/integration/primitives/state.go @@ -29,7 +29,9 @@ import ( func init() { register.DoFn3x1[state.Provider, string, int, string](&valueStateFn{}) + register.DoFn3x1[state.Provider, string, int, string](&valueStateClearFn{}) register.DoFn3x1[state.Provider, string, int, string](&bagStateFn{}) + register.DoFn3x1[state.Provider, string, int, string](&bagStateClearFn{}) register.DoFn3x1[state.Provider, string, int, string](&combiningStateFn{}) register.DoFn3x1[state.Provider, string, int, string](&mapStateFn{}) register.Emitter2[string, int]() @@ -85,6 +87,44 @@ func ValueStateParDo() *beam.Pipeline { return p } +type valueStateClearFn struct { + State1 state.Value[int] +} + +func (f *valueStateClearFn) ProcessElement(s state.Provider, w string, c int) string { + i, ok, err := f.State1.Read(s) + if err != nil { + panic(err) + } + if ok { + err = f.State1.Clear(s) + if err != nil { + panic(err) + } + } else { + err = f.State1.Write(s, 1) + if err != nil { + panic(err) + } + } + + return fmt.Sprintf("%s: %v,%v", w, i, ok) +} + +// ValueStateParDo_Clear tests that a DoFn that uses value state can be cleared. +func ValueStateParDo_Clear() *beam.Pipeline { + p, s := beam.NewPipelineWithRoot() + + in := beam.Create(s, "apple", "pear", "peach", "apple", "apple", "pear", "pear", "apple") + keyed := beam.ParDo(s, func(w string, emit func(string, int)) { + emit(w, 1) + }, in) + counts := beam.ParDo(s, &valueStateClearFn{State1: state.MakeValueState[int]("key1")}, keyed) + passert.Equals(s, counts, "apple: 0,false", "pear: 0,false", "peach: 0,false", "apple: 1,true", "apple: 0,false", "pear: 1,true", "pear: 0,false", "apple: 1,true") + + return p +} + type bagStateFn struct { State1 state.Bag[int] State2 state.Bag[string] @@ -135,6 +175,47 @@ func BagStateParDo() *beam.Pipeline { return p } +type bagStateClearFn struct { + State1 state.Bag[int] +} + +func (f *bagStateClearFn) ProcessElement(s state.Provider, w string, c int) string { + i, ok, err := f.State1.Read(s) + if err != nil { + panic(err) + } + if !ok { + i = []int{} + } + err = f.State1.Add(s, 1) + if err != nil { + panic(err) + } + + sum := 0 + for _, val := range i { + sum += val + } + if sum == 3 { + f.State1.Clear(s) + } + return fmt.Sprintf("%s: %v", w, sum) +} + +// BagStateParDo_Clear tests a DoFn that uses bag state. +func BagStateParDo_Clear() *beam.Pipeline { + p, s := beam.NewPipelineWithRoot() + + in := beam.Create(s, "apple", "pear", "apple", "apple", "pear", "apple", "apple", "pear", "pear", "pear", "apple", "pear") + keyed := beam.ParDo(s, func(w string, emit func(string, int)) { + emit(w, 1) + }, in) + counts := beam.ParDo(s, &bagStateClearFn{State1: state.MakeBagState[int]("key1")}, keyed) + passert.Equals(s, counts, "apple: 0", "pear: 0", "apple: 1", "apple: 2", "pear: 1", "apple: 3", "apple: 0", "pear: 2", "pear: 3", "pear: 0", "apple: 1", "pear: 1") + + return p +} + type combiningStateFn struct { State0 state.Combining[int, int, int] State1 state.Combining[int, int, int] diff --git a/sdks/go/test/integration/primitives/state_test.go b/sdks/go/test/integration/primitives/state_test.go index 4c310e0ceacc8..23d389ba1f17a 100644 --- a/sdks/go/test/integration/primitives/state_test.go +++ b/sdks/go/test/integration/primitives/state_test.go @@ -27,11 +27,21 @@ func TestValueState(t *testing.T) { ptest.RunAndValidate(t, ValueStateParDo()) } +func TestValueState_Clear(t *testing.T) { + integration.CheckFilters(t) + ptest.RunAndValidate(t, ValueStateParDo_Clear()) +} + func TestBagState(t *testing.T) { integration.CheckFilters(t) ptest.RunAndValidate(t, BagStateParDo()) } +func TestBagState_Clear(t *testing.T) { + integration.CheckFilters(t) + ptest.RunAndValidate(t, BagStateParDo_Clear()) +} + func TestCombiningState(t *testing.T) { integration.CheckFilters(t) ptest.RunAndValidate(t, CombiningStateParDo())