Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ability to remove/clear map and set state #22938

Merged
merged 8 commits into from
Aug 31, 2022
77 changes: 64 additions & 13 deletions sdks/go/pkg/beam/core/runtime/exec/userstate.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,14 +83,16 @@ func (s *stateProvider) WriteValueState(val state.Transaction) error {
if err != nil {
return err
}
cl.Write([]byte{})
_, err = cl.Write([]byte{})
if err != nil {
return err
}

ap, err := s.getBagAppender(val.Key)
if err != nil {
return err
}
fv := FullValue{Elm: val.Val}
// TODO(#22736) - consider caching this a proprty of stateProvider
enc := MakeElementEncoder(coder.SkipW(s.codersByKey[val.Key]))
err = enc.Encode(&fv, ap)
if err != nil {
Expand All @@ -109,7 +111,10 @@ func (s *stateProvider) ClearValueState(val state.Transaction) error {
if err != nil {
return err
}
cl.Write([]byte{})
_, err = cl.Write([]byte{})
if err != nil {
return err
}

// Any transactions before a clear don't matter
s.transactionsByKey[val.Key] = []state.Transaction{val}
Expand Down Expand Up @@ -153,7 +158,10 @@ func (s *stateProvider) ClearBagState(val state.Transaction) error {
if err != nil {
return err
}
cl.Write([]byte{})
_, err = cl.Write([]byte{})
if err != nil {
return err
}

// Any transactions before a clear don't matter
s.transactionsByKey[val.Key] = []state.Transaction{val}
Expand All @@ -168,14 +176,12 @@ func (s *stateProvider) WriteBagState(val state.Transaction) error {
return err
}
fv := FullValue{Elm: val.Val}
// TODO(#22736) - consider caching this a proprty of stateProvider
enc := MakeElementEncoder(coder.SkipW(s.codersByKey[val.Key]))
err = enc.Encode(&fv, ap)
if err != nil {
return err
}

// TODO(#22736) - optimize this a bit once all state types are added.
if transactions, ok := s.transactionsByKey[val.Key]; ok {
transactions = append(transactions, val)
s.transactionsByKey[val.Key] = transactions
Expand Down Expand Up @@ -254,27 +260,26 @@ func (s *stateProvider) ReadMapStateKeys(userStateID string) ([]interface{}, []s

// WriteMapState writes a key value pair to the global map state.
func (s *stateProvider) WriteMapState(val state.Transaction) error {
cl, err := s.getMultiMapClearer(val.Key, val.MapKey)
cl, err := s.getMultiMapKeyClearer(val.Key, val.MapKey)
if err != nil {
return err
}
_, err = cl.Write([]byte{})
if err != nil {
return err
}
cl.Write([]byte{})

ap, err := s.getMultiMapAppender(val.Key, val.MapKey)
if err != nil {
return err
}
fv := FullValue{Elm: val.Val}
// TODO(#22736) - consider caching this a proprty of stateProvider
enc := MakeElementEncoder(coder.SkipW(s.codersByKey[val.Key]))
err = enc.Encode(&fv, ap)
if err != nil {
return err
}

// TODO(#22736) - optimize this a bit once all state types are added. In the case of sets/clears,
// we can remove the transactions. We can also consider combining other transactions on read (or sooner)
// so that we don't need to use as much memory/time replaying transactions.
if transactions, ok := s.transactionsByKey[val.Key]; ok {
transactions = append(transactions, val)
s.transactionsByKey[val.Key] = transactions
Expand All @@ -285,6 +290,44 @@ func (s *stateProvider) WriteMapState(val state.Transaction) error {
return nil
}

// ClearMapStateKey deletes a key value pair from the global map state.
func (s *stateProvider) ClearMapStateKey(val state.Transaction) error {
cl, err := s.getMultiMapKeyClearer(val.Key, val.MapKey)
if err != nil {
return err
}
_, err = cl.Write([]byte{})
if err != nil {
return err
}

if transactions, ok := s.transactionsByKey[val.Key]; ok {
transactions = append(transactions, val)
s.transactionsByKey[val.Key] = transactions
} else {
s.transactionsByKey[val.Key] = []state.Transaction{val}
}

return nil
}

// ClearMapState deletes all key value pairs from the global map state.
func (s *stateProvider) ClearMapState(val state.Transaction) error {
cl, err := s.getMultiMapClearer(val.Key)
if err != nil {
return err
}
_, err = cl.Write([]byte{})
if err != nil {
return err
}

// Any transactions before a clear don't matter
s.transactionsByKey[val.Key] = []state.Transaction{val}

return nil
}

func (s *stateProvider) CreateAccumulatorFn(userStateID string) reflectx.Func {
a := s.combineFnsByKey[userStateID]
if ca := a.CreateAccumulatorFn(); ca != nil {
Expand Down Expand Up @@ -379,7 +422,7 @@ func (s *stateProvider) getMultiMapAppender(userStateID string, key interface{})
return w, nil
}

func (s *stateProvider) getMultiMapClearer(userStateID string, key interface{}) (io.Writer, error) {
func (s *stateProvider) getMultiMapKeyClearer(userStateID string, key interface{}) (io.Writer, error) {
ek, err := s.encodeKey(userStateID, key)
if err != nil {
return nil, err
Expand All @@ -391,6 +434,14 @@ func (s *stateProvider) getMultiMapClearer(userStateID string, key interface{})
return w, nil
}

func (s *stateProvider) getMultiMapClearer(userStateID string) (io.Writer, error) {
w, err := s.sr.OpenMultimapKeysUserStateClearer(s.ctx, s.SID, userStateID, s.elementKey, s.window)
if err != nil {
return nil, err
}
return w, nil
}

func (s *stateProvider) getMultiMapKeyReader(userStateID string) (io.ReadCloser, error) {
if r, ok := s.readersByKey[userStateID]; ok {
return r, nil
Expand Down
36 changes: 36 additions & 0 deletions sdks/go/pkg/beam/core/state/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ type Provider interface {
ReadMapStateValue(userStateID string, key interface{}) (interface{}, []Transaction, error)
ReadMapStateKeys(userStateID string) ([]interface{}, []Transaction, error)
WriteMapState(val Transaction) error
ClearMapStateKey(val Transaction) error
ClearMapState(val Transaction) error
}

// PipelineState is an interface representing different kinds of PipelineState (currently just state.Value).
Expand Down Expand Up @@ -501,6 +503,23 @@ func (s *Map[K, V]) Get(p Provider, key K) (V, bool, error) {
return cur.(V), true, nil
}

// Remove deletes an entry from this instance of map state.
func (s *Map[K, V]) Remove(p Provider, key K) error {
return p.ClearMapStateKey(Transaction{
Key: s.Key,
Type: TransactionTypeClear,
MapKey: key,
})
}

// Clear deletes all entries from this instance of map state.
func (s *Map[K, V]) Clear(p Provider) error {
return p.ClearMapState(Transaction{
Key: s.Key,
Type: TransactionTypeClear,
})
}

// StateKey returns the key for this pipeline state entry.
func (s Map[K, V]) StateKey() string {
return s.Key
Expand Down Expand Up @@ -620,6 +639,23 @@ func (s *Set[K]) Contains(p Provider, key K) (bool, error) {
return true, nil
}

// Remove deletes an entry from this instance of set state.
func (s Set[K]) Remove(p Provider, key K) error {
return p.ClearMapStateKey(Transaction{
Key: s.Key,
Type: TransactionTypeClear,
MapKey: key,
})
}

// Clear deletes all entries from this instance of set state.
func (s Set[K]) Clear(p Provider) error {
return p.ClearMapState(Transaction{
Key: s.Key,
Type: TransactionTypeClear,
})
}

// StateKey returns the key for this pipeline state entry.
func (s Set[K]) StateKey() string {
return s.Key
Expand Down
Loading