Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add bag state support #22816

Merged
merged 2 commits into from
Aug 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions sdks/go/pkg/beam/core/graph/fn.go
Original file line number Diff line number Diff line change
Expand Up @@ -1274,11 +1274,10 @@ func validateState(fn *DoFn, numIn mainInputs) error {
"unique per DoFn", k, orig, s)
}
t := s.StateType()
// TODO(#22736) - Add more state types as they become supported
if t != state.StateTypeValue {
err := errors.Errorf("Non-value state type %v for state %v", t, s)
return errors.SetTopLevelMsgf(err, "Non-value state type %v for state %v. Currently the only supported state"+
"type is state.Value", t, s)
if t != state.StateTypeValue && t != state.StateTypeBag {
err := errors.Errorf("Unrecognized state type %v for state %v", t, s)
return errors.SetTopLevelMsgf(err, "Unrecognized state type %v for state %v. Currently the only supported state"+
"type is state.Value and state.Bag", t, s)
}
stateKeys[k] = s
}
Expand Down
11 changes: 10 additions & 1 deletion sdks/go/pkg/beam/core/graph/fn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ func TestNewDoFn(t *testing.T) {
{dfn: &GoodDoFnCoGbk2{}, opt: CoGBKMainInput(3)},
{dfn: &GoodDoFnCoGbk7{}, opt: CoGBKMainInput(8)},
{dfn: &GoodDoFnCoGbk1wSide{}, opt: NumMainInputs(MainKv)},
{dfn: &GoodStatefulDoFn{State1: state.Value[int](state.MakeValueState[int]("state1"))}, opt: NumMainInputs(MainKv)},
{dfn: &GoodStatefulDoFn{State1: state.MakeValueState[int]("state1")}, opt: NumMainInputs(MainKv)},
{dfn: &GoodStatefulDoFn2{State1: state.MakeBagState[int]("state1")}, opt: NumMainInputs(MainKv)},
}

for _, test := range tests {
Expand Down Expand Up @@ -1087,6 +1088,14 @@ func (fn *GoodStatefulDoFn) ProcessElement(state.Provider, int, int) int {
return 0
}

// GoodStatefulDoFn2 is a stateful DoFn fixture whose only state field is a
// bag of ints.
type GoodStatefulDoFn2 struct {
	State1 state.Bag[int]
}

// ProcessElement accepts a state.Provider followed by two ints and always
// returns 0; none of its arguments are used.
func (*GoodStatefulDoFn2) ProcessElement(_ state.Provider, _, _ int) int {
	return 0
}

// Examples of incorrect SDF signatures.
// Examples with missing methods.

Expand Down
7 changes: 6 additions & 1 deletion sdks/go/pkg/beam/core/runtime/exec/translate.go
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,12 @@ func (b *builder) makeLink(from string, id linkID) (Node, error) {
stateIDToCoder := make(map[string]*coder.Coder)
for key, spec := range userState {
// TODO(#22736) - this will eventually need to be aware of which type of state it's modifying to support non-Value state types.
cID := spec.GetReadModifyWriteSpec().CoderId
var cID string
if rmw := spec.GetReadModifyWriteSpec(); rmw != nil {
cID = rmw.CoderId
} else if bs := spec.GetBagSpec(); bs != nil {
cID = bs.ElementCoderId
}
c, err := b.coders.Coder(cID)
if err != nil {
return nil, err
Expand Down
59 changes: 59 additions & 0 deletions sdks/go/pkg/beam/core/runtime/exec/userstate.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ type stateProvider struct {

transactionsByKey map[string][]state.Transaction
initialValueByKey map[string]interface{}
initialBagByKey map[string][]interface{}
readersByKey map[string]io.ReadCloser
appendersByKey map[string]io.Writer
clearersByKey map[string]io.Writer
Expand All @@ -57,6 +58,7 @@ func (s *stateProvider) ReadValueState(userStateID string) (interface{}, []state
return nil, []state.Transaction{}, nil
}
initialValue = resp.Elm
s.initialValueByKey[userStateID] = initialValue
}

transactions, ok := s.transactionsByKey[userStateID]
Expand Down Expand Up @@ -101,6 +103,62 @@ func (s *stateProvider) WriteValueState(val state.Transaction) error {
return nil
}

// ReadBagState returns the contents of the bag state identified by
// userStateID together with the transactions buffered against that ID.
//
// On the first read for an ID, elements are decoded from that ID's state
// reader until io.EOF and the resulting slice is cached; later reads return
// the cached slice without touching the reader again. Any other decode error
// aborts the read and is returned to the caller.
//
// When no transactions are recorded for the ID, an empty slice is returned in
// their place.
func (s *stateProvider) ReadBagState(userStateID string) ([]interface{}, []state.Transaction, error) {
	bag, cached := s.initialBagByKey[userStateID]
	if !cached {
		r, err := s.getReader(userStateID)
		if err != nil {
			return nil, nil, err
		}
		bag = []interface{}{}
		dec := MakeElementDecoder(coder.SkipW(s.codersByKey[userStateID]))
		for {
			fv, err := dec.Decode(r)
			if err == io.EOF {
				// End of the stored bag; everything has been read.
				break
			}
			if err != nil {
				return nil, nil, err
			}
			bag = append(bag, fv.Elm)
		}
		s.initialBagByKey[userStateID] = bag
	}

	pending, found := s.transactionsByKey[userStateID]
	if !found {
		pending = []state.Transaction{}
	}
	return bag, pending, nil
}

// WriteBagState appends a single element to a bag state through the State API.
//
// val.Val is encoded with the coder registered for val.Key and written to
// that key's appender. Once the write succeeds, the transaction is also
// recorded under val.Key so that ReadBagState can hand it back to callers.
// An error from obtaining the appender or from encoding is returned as-is,
// and in that case no transaction is recorded.
func (s *stateProvider) WriteBagState(val state.Transaction) error {
	ap, err := s.getAppender(val.Key)
	if err != nil {
		return err
	}
	fv := FullValue{Elm: val.Val}
	// TODO(#22736) - consider caching this as a property of stateProvider
	enc := MakeElementEncoder(coder.SkipW(s.codersByKey[val.Key]))
	if err := enc.Encode(&fv, ap); err != nil {
		return err
	}

	// TODO(#22736) - optimize this a bit once all state types are added.
	// append works on the nil slice returned for a missing key, so no
	// separate existence check is needed.
	s.transactionsByKey[val.Key] = append(s.transactionsByKey[val.Key], val)
	return nil
}

func (s *stateProvider) getReader(userStateID string) (io.ReadCloser, error) {
if r, ok := s.readersByKey[userStateID]; ok {
return r, nil
Expand Down Expand Up @@ -189,6 +247,7 @@ func (s *userStateAdapter) NewStateProvider(ctx context.Context, reader StateRea
window: win,
transactionsByKey: make(map[string][]state.Transaction),
initialValueByKey: make(map[string]interface{}),
initialBagByKey: make(map[string][]interface{}),
readersByKey: make(map[string]io.ReadCloser),
appendersByKey: make(map[string]io.Writer),
clearersByKey: make(map[string]io.Writer),
Expand Down
1 change: 1 addition & 0 deletions sdks/go/pkg/beam/core/runtime/exec/userstate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ func buildStateProvider() stateProvider {
window: []byte{1},
transactionsByKey: make(map[string][]state.Transaction),
initialValueByKey: make(map[string]interface{}),
initialBagByKey: make(map[string][]interface{}),
readersByKey: make(map[string]io.ReadCloser),
appendersByKey: make(map[string]io.Writer),
clearersByKey: make(map[string]io.Writer),
Expand Down
38 changes: 28 additions & 10 deletions sdks/go/pkg/beam/core/runtime/graphx/translate.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/window/trigger"
v1pb "github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/graphx/v1"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/pipelinex"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/state"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/util/protox"
"github.com/apache/beam/sdks/v2/go/pkg/beam/internal/errors"
pipepb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/pipeline_v1"
Expand Down Expand Up @@ -84,6 +85,9 @@ const (
URNEnvProcess = "beam:env:process:v1"
URNEnvExternal = "beam:env:external:v1"
URNEnvDocker = "beam:env:docker:v1"

// Userstate Urns.
URNBagUserState = "beam:user_state:bag:v1"
)

func goCapabilities() []string {
Expand Down Expand Up @@ -466,17 +470,31 @@ func (m *marshaller) addMultiEdge(edge NamedEdge) ([]string, error) {
if err != nil {
return handleErr(err)
}
stateSpecs[ps.StateKey()] = &pipepb.StateSpec{
// TODO (#22736) - make spec type and protocol conditional on type of State. Right now, assumes ValueState.
// See https://github.com/apache/beam/blob/54b0784da7ccba738deff22bd83fbc374ad21d2e/sdks/go/pkg/beam/model/pipeline_v1/beam_runner_api.pb.go#L2635
Spec: &pipepb.StateSpec_ReadModifyWriteSpec{
ReadModifyWriteSpec: &pipepb.ReadModifyWriteStateSpec{
CoderId: coderID,
switch ps.StateType() {
case state.StateTypeValue:
stateSpecs[ps.StateKey()] = &pipepb.StateSpec{
Spec: &pipepb.StateSpec_ReadModifyWriteSpec{
ReadModifyWriteSpec: &pipepb.ReadModifyWriteStateSpec{
CoderId: coderID,
},
},
},
Protocol: &pipepb.FunctionSpec{
Urn: "beam:user_state:bag:v1",
},
Protocol: &pipepb.FunctionSpec{
Urn: URNBagUserState,
},
}
case state.StateTypeBag:
stateSpecs[ps.StateKey()] = &pipepb.StateSpec{
Spec: &pipepb.StateSpec_BagSpec{
BagSpec: &pipepb.BagStateSpec{
ElementCoderId: coderID,
},
},
Protocol: &pipepb.FunctionSpec{
Urn: URNBagUserState,
},
}
default:
return nil, errors.Errorf("State type %v not recognized for state %v", ps.StateKey(), ps)
}
}
payload.StateSpecs = stateSpecs
Expand Down
76 changes: 76 additions & 0 deletions sdks/go/pkg/beam/core/state/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,12 @@ const (
TransactionTypeSet TransactionTypeEnum = 0
// TransactionTypeClear is the clear transaction type
TransactionTypeClear TransactionTypeEnum = 1
// TransactionTypeAppend is the append transaction type
TransactionTypeAppend TransactionTypeEnum = 2
// StateTypeValue represents a value state
StateTypeValue StateTypeEnum = 0
// StateTypeBag represents a bag state
StateTypeBag StateTypeEnum = 1
)

var (
Expand All @@ -57,6 +61,8 @@ type Transaction struct {
type Provider interface {
	// ReadValueState returns the initial value held for the value state id
	// along with the transactions buffered against it.
	ReadValueState(id string) (initial interface{}, buffered []Transaction, err error)
	// WriteValueState submits val as a write to a value state.
	WriteValueState(val Transaction) error
	// ReadBagState returns the initial elements held for the bag state id
	// along with the transactions buffered against it.
	ReadBagState(id string) (initial []interface{}, buffered []Transaction, err error)
	// WriteBagState submits val as a write to a bag state.
	WriteBagState(val Transaction) error
}

// PipelineState is an interface representing different kinds of PipelineState (currently state.Value and state.Bag).
Expand Down Expand Up @@ -133,3 +139,73 @@ func MakeValueState[T any](k string) Value[T] {
Key: k,
}
}

// Bag is a handle to pipeline state that holds a collection of values of
// type T.
type Bag[T any] struct {
	// Key is the identifier under which this state is looked up.
	Key string
}

// Add appends val to this bag by submitting an append transaction, keyed by
// s.Key, to the given Provider. Any error reported by the Provider is
// returned unchanged.
func (s *Bag[T]) Add(p Provider, val T) error {
	txn := Transaction{
		Type: TransactionTypeAppend,
		Key:  s.Key,
		Val:  val,
	}
	return p.WriteBagState(txn)
}

// Read returns the current contents of this bag as seen through p.
//
// The contents are the initial values reported by the Provider with its
// buffered transactions replayed on top, in order: an append transaction adds
// its value and a clear transaction empties the bag. Transactions of any
// other type are ignored.
//
// The boolean result is true only when the resulting bag is non-empty; an
// empty bag yields an empty (non-nil) slice and false. If the Provider
// returns an error, Read returns a nil slice, false and that error.
// Read panics if a stored value cannot be asserted to T.
func (s *Bag[T]) Read(p Provider) ([]T, bool, error) {
	initial, pending, err := p.ReadBagState(s.Key)
	if err != nil {
		return nil, false, err
	}

	contents := []T{}
	for i := range initial {
		contents = append(contents, initial[i].(T))
	}
	// Replay writes made since the initial contents were read.
	for _, txn := range pending {
		if txn.Type == TransactionTypeAppend {
			contents = append(contents, txn.Val.(T))
		} else if txn.Type == TransactionTypeClear {
			contents = []T{}
		}
	}
	return contents, len(contents) > 0, nil
}

// StateKey returns the key for this pipeline state entry.
//
// It panics if Key is empty, since a Bag with no key has not been initialized
// (for example via MakeBagState).
func (s Bag[T]) StateKey() string {
	if s.Key == "" {
		// TODO(#22736) - infer the state from the member variable name during pipeline construction.
		panic("Bag state exists on struct but has not been initialized with a key.")
	}
	return s.Key
}

// CoderType reports the reflect.Type of the bag's element type T, which is
// the type a coder for this state should handle. It is derived from the zero
// value of T, so reflect.TypeOf yields nil when T is an interface type.
func (Bag[T]) CoderType() reflect.Type {
	return reflect.TypeOf(*new(T))
}

// StateType identifies this state as a bag; it always returns StateTypeBag.
func (Bag[T]) StateType() StateTypeEnum {
	return StateTypeBag
}

// MakeBagState constructs a Bag whose state is stored under the key k.
func MakeBagState[T any](k string) Bag[T] {
	var b Bag[T]
	b.Key = k
	return b
}
Loading