Skip to content

Commit

Permalink
Add clear function for bag state types (apache#22917)
Browse files Browse the repository at this point in the history
* Add clear function for bag state types

* Doc comment

* Register functions
  • Loading branch information
damccorm authored and Kanishk Karanawat committed Sep 29, 2022
1 parent 78895ac commit 96331da
Show file tree
Hide file tree
Showing 6 changed files with 313 additions and 13 deletions.
37 changes: 29 additions & 8 deletions sdks/go/pkg/beam/core/runtime/exec/userstate.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,15 +97,22 @@ func (s *stateProvider) WriteValueState(val state.Transaction) error {
return err
}

// TODO(#22736) - optimize this a bit once all state types are added. In the case of sets/clears,
// we can remove the transactions. We can also consider combining other transactions on read (or sooner)
// so that we don't need to use as much memory/time replaying transactions.
if transactions, ok := s.transactionsByKey[val.Key]; ok {
transactions = append(transactions, val)
s.transactionsByKey[val.Key] = transactions
} else {
s.transactionsByKey[val.Key] = []state.Transaction{val}
// Any transactions before a set don't matter
s.transactionsByKey[val.Key] = []state.Transaction{val}

return nil
}

// ClearValueState clears a value state from the State API.
//
// It looks up the clearer for val.Key, issues the clear, and then replaces the
// buffered transactions for that key with val alone. An error from obtaining
// the clearer or from the clear write is returned, and in that case the
// buffered transactions are left untouched.
func (s *stateProvider) ClearValueState(val state.Transaction) error {
	cl, err := s.getBagClearer(val.Key)
	if err != nil {
		return err
	}
	// The payload is empty; presumably the write itself signals the clear.
	// Its error must not be dropped, otherwise a failed clear would look
	// like a success to the caller.
	if _, err := cl.Write([]byte{}); err != nil {
		return err
	}

	// Any transactions before a clear don't matter
	s.transactionsByKey[val.Key] = []state.Transaction{val}

	return nil
}
Expand Down Expand Up @@ -140,6 +147,20 @@ func (s *stateProvider) ReadBagState(userStateID string) ([]interface{}, []state
return initialValue, transactions, nil
}

// ClearBagState clears a bag state from the State API.
//
// It looks up the clearer for val.Key, issues the clear, and then replaces the
// buffered transactions for that key with val alone. An error from obtaining
// the clearer or from the clear write is returned, and in that case the
// buffered transactions are left untouched.
func (s *stateProvider) ClearBagState(val state.Transaction) error {
	cl, err := s.getBagClearer(val.Key)
	if err != nil {
		return err
	}
	// The payload is empty; presumably the write itself signals the clear.
	// Its error must not be dropped, otherwise a failed clear would look
	// like a success to the caller.
	if _, err := cl.Write([]byte{}); err != nil {
		return err
	}

	// Any transactions before a clear don't matter
	s.transactionsByKey[val.Key] = []state.Transaction{val}

	return nil
}

// WriteBagState writes a bag state to the State API
func (s *stateProvider) WriteBagState(val state.Transaction) error {
ap, err := s.getBagAppender(val.Key)
Expand Down
26 changes: 26 additions & 0 deletions sdks/go/pkg/beam/core/state/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,10 @@ type Transaction struct {
type Provider interface {
ReadValueState(id string) (interface{}, []Transaction, error)
WriteValueState(val Transaction) error
ClearValueState(val Transaction) error
ReadBagState(id string) ([]interface{}, []Transaction, error)
WriteBagState(val Transaction) error
ClearBagState(val Transaction) error
CreateAccumulatorFn(userStateID string) reflectx.Func
AddInputFn(userStateID string) reflectx.Func
MergeAccumulatorsFn(userStateID string) reflectx.Func
Expand Down Expand Up @@ -137,6 +139,14 @@ func (s *Value[T]) Read(p Provider) (T, bool, error) {
return cur.(T), true, nil
}

// Clear removes whatever is stored for this single-value state entry.
// It hands the provider a clear transaction carrying this state's key and
// returns the provider's error, if any.
func (s *Value[T]) Clear(p Provider) error {
	clearTx := Transaction{
		Type: TransactionTypeClear,
		Key:  s.Key,
	}
	return p.ClearValueState(clearTx)
}

// StateKey returns the key for this pipeline state entry.
func (s Value[T]) StateKey() string {
if s.Key == "" {
Expand Down Expand Up @@ -212,6 +222,14 @@ func (s *Bag[T]) Read(p Provider) ([]T, bool, error) {
return cur, true, nil
}

// Clear removes whatever is stored for this bag state entry.
// It hands the provider a clear transaction carrying this state's key and
// returns the provider's error, if any.
func (s *Bag[T]) Clear(p Provider) error {
	clearTx := Transaction{
		Type: TransactionTypeClear,
		Key:  s.Key,
	}
	return p.ClearBagState(clearTx)
}

// StateKey returns the key for this pipeline state entry.
func (s Bag[T]) StateKey() string {
if s.Key == "" {
Expand Down Expand Up @@ -322,6 +340,14 @@ func (s *Combining[T1, T2, T3]) Read(p Provider) (T3, bool, error) {
return acc.(T3), true, nil
}

// Clear removes whatever is stored for this combining state entry.
// It hands the provider a value-state clear transaction carrying this state's
// key and returns the provider's error, if any.
func (s *Combining[T1, T2, T3]) Clear(p Provider) error {
	clearTx := Transaction{
		Type: TransactionTypeClear,
		Key:  s.Key,
	}
	return p.ClearValueState(clearTx)
}

func (s *Combining[T1, T2, T3]) readAccumulator(p Provider) (interface{}, bool, error) {
// This replays any writes that have happened to this value since we last read
// For more detail, see "State Transactionality" below for buffered transactions
Expand Down
166 changes: 161 additions & 5 deletions sdks/go/pkg/beam/core/state/state_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,12 @@ func (s *fakeProvider) ReadValueState(userStateID string) (interface{}, []Transa
}

// WriteValueState records val as the only buffered transaction for val.Key.
// Any transactions before a set don't matter, so earlier entries for the key
// are dropped rather than appended to. It always returns nil.
func (s *fakeProvider) WriteValueState(val Transaction) error {
	s.transactions[val.Key] = []Transaction{val}
	return nil
}

// ClearValueState replaces everything buffered for val.Key with the clear
// transaction itself. It always returns nil.
func (s *fakeProvider) ClearValueState(val Transaction) error {
	onlyClear := []Transaction{val}
	s.transactions[val.Key] = onlyClear
	return nil
}

Expand All @@ -80,6 +81,11 @@ func (s *fakeProvider) WriteBagState(val Transaction) error {
return nil
}

// ClearBagState replaces everything buffered for val.Key with the clear
// transaction itself. It always returns nil.
func (s *fakeProvider) ClearBagState(val Transaction) error {
	onlyClear := []Transaction{val}
	s.transactions[val.Key] = onlyClear
	return nil
}

func (s *fakeProvider) CreateAccumulatorFn(userStateID string) reflectx.Func {
if s.createAccumForKey[userStateID] {
return reflectx.MakeFunc0x1(func() int {
Expand Down Expand Up @@ -245,6 +251,44 @@ func TestValueWrite(t *testing.T) {
}
}

// TestValueClear checks that reading a value state after one or two clears
// reports no value, whether zero, one, or two writes preceded the clears.
func TestValueClear(t *testing.T) {
	writeSets := [][]int{{}, {3}, {1, 5}}

	// Same case order as a flat table: every write set with one clear,
	// then every write set with two clears.
	for clears := 1; clears <= 2; clears++ {
		for _, writes := range writeSets {
			provider := fakeProvider{
				initialState: make(map[string]interface{}),
				transactions: make(map[string][]Transaction),
				err:          make(map[string]error),
			}
			vs := MakeValueState[int]("vs")
			for _, w := range writes {
				vs.Write(&provider, w)
			}
			for attempt := 0; attempt < clears; attempt++ {
				if err := vs.Clear(&provider); err != nil {
					t.Errorf("Value.Clear() attempt %v returned error %v", attempt, err)
				}
			}
			_, ok, err := vs.Read(&provider)
			switch {
			case err != nil:
				t.Errorf("Value.Read() returned error %v when it shouldn't have after writing: %v", err, writes)
			case ok:
				t.Errorf("Value.Read() returned a value when it shouldn't have after writing %v and performing %v clears", writes, clears)
			}
		}
	}
}

func TestBagRead(t *testing.T) {
is := make(map[string][]interface{})
ts := make(map[string][]Transaction)
Expand Down Expand Up @@ -356,6 +400,44 @@ func TestBagAdd(t *testing.T) {
}
}

// TestBagClear checks that reading a bag state after one or two clears
// reports no value, whether zero, one, or two adds preceded the clears.
func TestBagClear(t *testing.T) {
	writeSets := [][]int{{}, {3}, {1, 5}}

	// Same case order as a flat table: every write set with one clear,
	// then every write set with two clears.
	for clears := 1; clears <= 2; clears++ {
		for _, writes := range writeSets {
			provider := fakeProvider{
				initialState: make(map[string]interface{}),
				transactions: make(map[string][]Transaction),
				err:          make(map[string]error),
			}
			bs := MakeBagState[int]("vs")
			for _, w := range writes {
				bs.Add(&provider, w)
			}
			for attempt := 0; attempt < clears; attempt++ {
				if err := bs.Clear(&provider); err != nil {
					t.Errorf("Bag.Clear() attempt %v returned error %v", attempt, err)
				}
			}
			_, ok, err := bs.Read(&provider)
			switch {
			case err != nil:
				t.Errorf("Bag.Read() returned error %v when it shouldn't have after writing: %v", err, writes)
			case ok:
				t.Errorf("Bag.Read() returned a value when it shouldn't have after writing %v and performing %v clears", writes, clears)
			}
		}
	}
}

func TestCombiningRead(t *testing.T) {
is := make(map[string]interface{})
ts := make(map[string][]Transaction)
Expand Down Expand Up @@ -436,6 +518,80 @@ func TestCombiningRead(t *testing.T) {
}
}

// TestCombiningClear checks what a combining state reads back after adds
// followed by a clear. For the key without a create-accumulator function the
// read must report no value; for the key with one, the read must report the
// freshly created accumulator value.
func TestCombiningClear(t *testing.T) {
	sum := func(a, b int) int {
		return a + b
	}

	type clearCase struct {
		vs     Combining[int, int, int]
		writes []int
		val    int
		clears int
		ok     bool
	}
	cases := []clearCase{
		{vs: MakeCombiningState[int, int, int]("no_transactions", sum), writes: []int{}, val: 0, clears: 1, ok: false},
		{vs: MakeCombiningState[int, int, int]("no_transactions", sum), writes: []int{2}, val: 0, clears: 1, ok: false},
		{vs: MakeCombiningState[int, int, int]("no_transactions", sum), writes: []int{7, 8, 9}, val: 0, clears: 1, ok: false},
		{vs: MakeCombiningState[int, int, int]("no_transactions_initial_accum", sum), writes: []int{}, val: 1, clears: 1, ok: true},
		{vs: MakeCombiningState[int, int, int]("no_transactions_initial_accum", sum), writes: []int{1}, val: 1, clears: 1, ok: true},
		{vs: MakeCombiningState[int, int, int]("no_transactions_initial_accum", sum), writes: []int{3, 4}, val: 1, clears: 1, ok: true},
	}

	for _, tc := range cases {
		// Every case gets a fresh provider so state never leaks between cases.
		provider := fakeProvider{
			initialState: make(map[string]interface{}),
			transactions: map[string][]Transaction{
				"no_transactions":               nil,
				"no_transactions_initial_accum": nil,
			},
			err: make(map[string]error),
			createAccumForKey: map[string]bool{
				"no_transactions_initial_accum": true,
			},
			extractOutForKey: make(map[string]bool),
			mergeAccumForKey: map[string]bool{
				"no_transactions":               true,
				"no_transactions_initial_accum": true,
			},
			addInputForKey: make(map[string]bool),
		}

		for _, w := range tc.writes {
			tc.vs.Add(&provider, w)
		}
		for attempt := 0; attempt < tc.clears; attempt++ {
			if err := tc.vs.Clear(&provider); err != nil {
				t.Errorf("Combining.Clear() attempt %v returned error %v", attempt, err)
			}
		}

		got, ok, err := tc.vs.Read(&provider)
		switch {
		case err != nil:
			t.Errorf("Combining.Read() returned error %v when it shouldn't have after writing %v and performing %v clears for key %v", err, tc.writes, tc.clears, tc.vs.StateKey())
		case ok && !tc.ok:
			t.Errorf("Combining.Read() returned a value when it shouldn't have after writing %v and performing %v clears for key %v", tc.writes, tc.clears, tc.vs.StateKey())
		case !ok && tc.ok:
			t.Errorf("Combining.Read() returned no value when it should have returned %v after writing %v and performing %v clears for key %v", tc.val, tc.writes, tc.clears, tc.vs.StateKey())
		case tc.val != got:
			t.Errorf("Combining.Read()=%v, want %v after writing %v and performing %v clears for key %v", got, tc.val, tc.writes, tc.clears, tc.vs.StateKey())
		}
	}
}

func TestCombiningAdd(t *testing.T) {
var tests = []struct {
vs Combining[int, int, int]
Expand Down
6 changes: 6 additions & 0 deletions sdks/go/test/integration/integration.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,9 @@ var directFilters = []string{
"TestOomParDo",
// The direct runner does not support user state.
"TestValueState",
"TestValueState_Clear",
"TestBagState",
"TestBagState_Clear",
"TestCombiningState",
"TestMapState",
}
Expand All @@ -116,7 +118,9 @@ var portableFilters = []string{
"TestOomParDo",
// The portable runner does not support user state.
"TestValueState",
"TestValueState_Clear",
"TestBagState",
"TestBagState_Clear",
"TestCombiningState",
"TestMapState",
}
Expand Down Expand Up @@ -162,7 +166,9 @@ var samzaFilters = []string{
"TestOomParDo",
// The samza runner does not support user state.
"TestValueState",
"TestValueState_Clear",
"TestBagState",
"TestBagState_Clear",
"TestCombiningState",
"TestMapState",
}
Expand Down
Loading

0 comments on commit 96331da

Please sign in to comment.